diff --git a/.gitattributes b/.gitattributes index 1e37d4b45bff8ede805ee81098714823320b012a..45722fd5fddf0aa96a7a50fcdf4497cc362cac65 100644 --- a/.gitattributes +++ b/.gitattributes @@ -34,3 +34,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text data/Processed[[:space:]]Data.pdf filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/altair/vegalite/v5/schema/__pycache__/channels.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/altair/vegalite/v5/schema/__pycache__/core.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/cassandra/cluster.cp311-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/avx/ctransformers.dll filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/avx/libctransformers.dylib filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/avx/libctransformers.so filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/avx2/ctransformers.dll filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/avx2/libctransformers.dylib filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/avx2/libctransformers.so filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/basic/ctransformers.dll filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/basic/libctransformers.dylib filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/basic/libctransformers.so filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/cuda/ctransformers.dll filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/ctransformers/lib/cuda/libctransformers.so filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/faiss_cpu.libs/flang-d38962844214aa9b06fc3989f9adae5b.dll filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/faiss_cpu.libs/openblas-1ba25ee8d70fa3c45ede15bdc95fbee3.dll filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/faiss/_swigfaiss_avx2.cp311-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +llm/Lib/site-packages/faiss/_swigfaiss.cp311-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text diff --git a/llm/Include/site/python3.11/greenlet/greenlet.h b/llm/Include/site/python3.11/greenlet/greenlet.h new file mode 100644 index 0000000000000000000000000000000000000000..d02a16e43426fb1c1bb286f1cda463cb9b1185ad --- /dev/null +++ b/llm/Include/site/python3.11/greenlet/greenlet.h @@ -0,0 +1,164 @@ +/* -*- indent-tabs-mode: nil; tab-width: 4; -*- */ + +/* Greenlet object interface */ + +#ifndef Py_GREENLETOBJECT_H +#define Py_GREENLETOBJECT_H + + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/* This is deprecated and undocumented. It does not change. 
*/ +#define GREENLET_VERSION "1.0.0" + +#ifndef GREENLET_MODULE +#define implementation_ptr_t void* +#endif + +typedef struct _greenlet { + PyObject_HEAD + PyObject* weakreflist; + PyObject* dict; + implementation_ptr_t pimpl; +} PyGreenlet; + +#define PyGreenlet_Check(op) (op && PyObject_TypeCheck(op, &PyGreenlet_Type)) + + +/* C API functions */ + +/* Total number of symbols that are exported */ +#define PyGreenlet_API_pointers 12 + +#define PyGreenlet_Type_NUM 0 +#define PyExc_GreenletError_NUM 1 +#define PyExc_GreenletExit_NUM 2 + +#define PyGreenlet_New_NUM 3 +#define PyGreenlet_GetCurrent_NUM 4 +#define PyGreenlet_Throw_NUM 5 +#define PyGreenlet_Switch_NUM 6 +#define PyGreenlet_SetParent_NUM 7 + +#define PyGreenlet_MAIN_NUM 8 +#define PyGreenlet_STARTED_NUM 9 +#define PyGreenlet_ACTIVE_NUM 10 +#define PyGreenlet_GET_PARENT_NUM 11 + +#ifndef GREENLET_MODULE +/* This section is used by modules that uses the greenlet C API */ +static void** _PyGreenlet_API = NULL; + +# define PyGreenlet_Type \ + (*(PyTypeObject*)_PyGreenlet_API[PyGreenlet_Type_NUM]) + +# define PyExc_GreenletError \ + ((PyObject*)_PyGreenlet_API[PyExc_GreenletError_NUM]) + +# define PyExc_GreenletExit \ + ((PyObject*)_PyGreenlet_API[PyExc_GreenletExit_NUM]) + +/* + * PyGreenlet_New(PyObject *args) + * + * greenlet.greenlet(run, parent=None) + */ +# define PyGreenlet_New \ + (*(PyGreenlet * (*)(PyObject * run, PyGreenlet * parent)) \ + _PyGreenlet_API[PyGreenlet_New_NUM]) + +/* + * PyGreenlet_GetCurrent(void) + * + * greenlet.getcurrent() + */ +# define PyGreenlet_GetCurrent \ + (*(PyGreenlet * (*)(void)) _PyGreenlet_API[PyGreenlet_GetCurrent_NUM]) + +/* + * PyGreenlet_Throw( + * PyGreenlet *greenlet, + * PyObject *typ, + * PyObject *val, + * PyObject *tb) + * + * g.throw(...) + */ +# define PyGreenlet_Throw \ + (*(PyObject * (*)(PyGreenlet * self, \ + PyObject * typ, \ + PyObject * val, \ + PyObject * tb)) \ + _PyGreenlet_API[PyGreenlet_Throw_NUM]) + +/* + * PyGreenlet_Switch(PyGreenlet *greenlet, PyObject *args) + * + * g.switch(*args, **kwargs) + */ +# define PyGreenlet_Switch \ + (*(PyObject * \ + (*)(PyGreenlet * greenlet, PyObject * args, PyObject * kwargs)) \ + _PyGreenlet_API[PyGreenlet_Switch_NUM]) + +/* + * PyGreenlet_SetParent(PyObject *greenlet, PyObject *new_parent) + * + * g.parent = new_parent + */ +# define PyGreenlet_SetParent \ + (*(int (*)(PyGreenlet * greenlet, PyGreenlet * nparent)) \ + _PyGreenlet_API[PyGreenlet_SetParent_NUM]) + +/* + * PyGreenlet_GetParent(PyObject* greenlet) + * + * return greenlet.parent; + * + * This could return NULL even if there is no exception active. + * If it does not return NULL, you are responsible for decrementing the + * reference count. + */ +# define PyGreenlet_GetParent \ + (*(PyGreenlet* (*)(PyGreenlet*)) \ + _PyGreenlet_API[PyGreenlet_GET_PARENT_NUM]) + +/* + * deprecated, undocumented alias. + */ +# define PyGreenlet_GET_PARENT PyGreenlet_GetParent + +# define PyGreenlet_MAIN \ + (*(int (*)(PyGreenlet*)) \ + _PyGreenlet_API[PyGreenlet_MAIN_NUM]) + +# define PyGreenlet_STARTED \ + (*(int (*)(PyGreenlet*)) \ + _PyGreenlet_API[PyGreenlet_STARTED_NUM]) + +# define PyGreenlet_ACTIVE \ + (*(int (*)(PyGreenlet*)) \ + _PyGreenlet_API[PyGreenlet_ACTIVE_NUM]) + + + + +/* Macro that imports greenlet and initializes C API */ +/* NOTE: This has actually moved to ``greenlet._greenlet._C_API``, but we + keep the older definition to be sure older code that might have a copy of + the header still works. 
*/ +# define PyGreenlet_Import() \ + { \ + _PyGreenlet_API = (void**)PyCapsule_Import("greenlet._C_API", 0); \ + } + +#endif /* GREENLET_MODULE */ + +#ifdef __cplusplus +} +#endif +#endif /* !Py_GREENLETOBJECT_H */ diff --git a/llm/Lib/site-packages/GitPython-3.1.43.dist-info/AUTHORS b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/AUTHORS new file mode 100644 index 0000000000000000000000000000000000000000..9311b39626f9cab901128b2442841a7218774cc0 --- /dev/null +++ b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/AUTHORS @@ -0,0 +1,58 @@ +GitPython was originally written by Michael Trier. +GitPython 0.2 was partially (re)written by Sebastian Thiel, based on 0.1.6 and git-dulwich. + +Contributors are: + +-Michael Trier +-Alan Briolat +-Florian Apolloner +-David Aguilar +-Jelmer Vernooij +-Steve Frécinaux +-Kai Lautaportti +-Paul Sowden +-Sebastian Thiel +-Jonathan Chu +-Vincent Driessen +-Phil Elson +-Bernard `Guyzmo` Pratz +-Timothy B. Hartman +-Konstantin Popov +-Peter Jones +-Anson Mansfield +-Ken Odegard +-Alexis Horgix Chotard +-Piotr Babij +-Mikuláš Poul +-Charles Bouchard-Légaré +-Yaroslav Halchenko +-Tim Swast +-William Luc Ritchie +-David Host +-A. Jesse Jiryu Davis +-Steven Whitman +-Stefan Stancu +-César Izurieta +-Arthur Milchior +-Anil Khatri +-JJ Graham +-Ben Thayer +-Dries Kennes +-Pratik Anurag +-Harmon +-Liam Beguin +-Ram Rachum +-Alba Mendez +-Robert Westman +-Hugo van Kemenade +-Hiroki Tokunaga +-Julien Mauroy +-Patrick Gerard +-Luke Twist +-Joseph Hale +-Santos Gallegos +-Wenhan Zhu +-Eliah Kagan +-Ethan Lin + +Portions derived from other open source works and are clearly marked. diff --git a/llm/Lib/site-packages/GitPython-3.1.43.dist-info/INSTALLER b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/llm/Lib/site-packages/GitPython-3.1.43.dist-info/LICENSE b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..ba8a219fe1f27c10b50df8cd4f26c0ab833bbbc8 --- /dev/null +++ b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/LICENSE @@ -0,0 +1,29 @@ +Copyright (C) 2008, 2009 Michael Trier and contributors +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions +are met: + +* Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright +notice, this list of conditions and the following disclaimer in the +documentation and/or other materials provided with the distribution. + +* Neither the name of the GitPython project nor the names of +its contributors may be used to endorse or promote products derived +from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/llm/Lib/site-packages/GitPython-3.1.43.dist-info/METADATA b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..be628100532a16d615656573f53d4d69b67e4c5a --- /dev/null +++ b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/METADATA @@ -0,0 +1,297 @@ +Metadata-Version: 2.1 +Name: GitPython +Version: 3.1.43 +Summary: GitPython is a Python library used to interact with Git repositories +Home-page: https://github.com/gitpython-developers/GitPython +Author: Sebastian Thiel, Michael Trier +Author-email: byronimo@gmail.com, mtrier@gmail.com +License: BSD-3-Clause +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Operating System :: POSIX +Classifier: Operating System :: Microsoft :: Windows +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Typing :: Typed +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Requires-Python: >=3.7 +Description-Content-Type: text/markdown +License-File: LICENSE +License-File: AUTHORS +Requires-Dist: gitdb <5,>=4.0.1 +Requires-Dist: typing-extensions >=3.7.4.3 ; python_version < "3.8" +Provides-Extra: doc +Requires-Dist: sphinx ==4.3.2 ; extra == 'doc' +Requires-Dist: sphinx-rtd-theme ; extra == 'doc' +Requires-Dist: sphinxcontrib-applehelp <=1.0.4,>=1.0.2 ; extra == 'doc' +Requires-Dist: sphinxcontrib-devhelp ==1.0.2 ; extra == 'doc' +Requires-Dist: sphinxcontrib-htmlhelp <=2.0.1,>=2.0.0 ; extra == 'doc' +Requires-Dist: sphinxcontrib-qthelp ==1.0.3 ; extra == 'doc' +Requires-Dist: sphinxcontrib-serializinghtml ==1.1.5 ; extra == 'doc' +Requires-Dist: sphinx-autodoc-typehints ; extra == 'doc' +Provides-Extra: test +Requires-Dist: coverage[toml] ; extra == 'test' +Requires-Dist: ddt !=1.4.3,>=1.1.1 ; extra == 'test' +Requires-Dist: mypy ; extra == 'test' +Requires-Dist: pre-commit ; extra == 'test' +Requires-Dist: pytest >=7.3.1 ; extra == 'test' +Requires-Dist: pytest-cov ; extra == 'test' +Requires-Dist: pytest-instafail ; extra == 'test' +Requires-Dist: pytest-mock ; extra == 'test' +Requires-Dist: pytest-sugar ; extra == 'test' +Requires-Dist: typing-extensions ; (python_version < "3.11") and extra == 'test' +Requires-Dist: mock ; (python_version < "3.8") and extra == 'test' + +![Python package](https://github.com/gitpython-developers/GitPython/workflows/Python%20package/badge.svg) +[![Documentation 
Status](https://readthedocs.org/projects/gitpython/badge/?version=stable)](https://readthedocs.org/projects/gitpython/?badge=stable) +[![Packaging status](https://repology.org/badge/tiny-repos/python:gitpython.svg)](https://repology.org/metapackage/python:gitpython/versions) + +## [Gitoxide](https://github.com/Byron/gitoxide): A peek into the future… + +I started working on GitPython in 2009, back in the days when Python was 'my thing' and I had great plans with it. +Of course, back in the days, I didn't really know what I was doing and this shows in many places. Somewhat similar to +Python this happens to be 'good enough', but at the same time is deeply flawed and broken beyond repair. + +By now, GitPython is widely used and I am sure there is a good reason for that, it's something to be proud of and happy about. +The community is maintaining the software and is keeping it relevant for which I am absolutely grateful. For the time to come I am happy to continue maintaining GitPython, remaining hopeful that one day it won't be needed anymore. + +More than 15 years after my first meeting with 'git' I am still in excited about it, and am happy to finally have the tools and +probably the skills to scratch that itch of mine: implement `git` in a way that makes tool creation a piece of cake for most. + +If you like the idea and want to learn more, please head over to [gitoxide](https://github.com/Byron/gitoxide), an +implementation of 'git' in [Rust](https://www.rust-lang.org). + +*(Please note that `gitoxide` is not currently available for use in Python, and that Rust is required.)* + +## GitPython + +GitPython is a python library used to interact with git repositories, high-level like git-porcelain, +or low-level like git-plumbing. + +It provides abstractions of git objects for easy access of repository data often backed by calling the `git` +command-line program. + +### DEVELOPMENT STATUS + +This project is in **maintenance mode**, which means that + +- …there will be no feature development, unless these are contributed +- …there will be no bug fixes, unless they are relevant to the safety of users, or contributed +- …issues will be responded to with waiting times of up to a month + +The project is open to contributions of all kinds, as well as new maintainers. + +### REQUIREMENTS + +GitPython needs the `git` executable to be installed on the system and available in your +`PATH` for most operations. If it is not in your `PATH`, you can help GitPython find it +by setting the `GIT_PYTHON_GIT_EXECUTABLE=` environment variable. + +- Git (1.7.x or newer) +- Python >= 3.7 + +The list of dependencies are listed in `./requirements.txt` and `./test-requirements.txt`. +The installer takes care of installing them for you. + +### INSTALL + +GitPython and its required package dependencies can be installed in any of the following ways, all of which should typically be done in a [virtual environment](https://docs.python.org/3/tutorial/venv.html). + +#### From PyPI + +To obtain and install a copy [from PyPI](https://pypi.org/project/GitPython/), run: + +```sh +pip install GitPython +``` + +(A distribution package can also be downloaded for manual installation at [the PyPI page](https://pypi.org/project/GitPython/).) + +#### From downloaded source code + +If you have downloaded the source code, run this from inside the unpacked `GitPython` directory: + +```sh +pip install . 
+``` + +#### By cloning the source code repository + +To clone the [the GitHub repository](https://github.com/gitpython-developers/GitPython) from source to work on the code, you can do it like so: + +```sh +git clone https://github.com/gitpython-developers/GitPython +cd GitPython +./init-tests-after-clone.sh +``` + +On Windows, `./init-tests-after-clone.sh` can be run in a Git Bash shell. + +If you are cloning [your own fork](https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/about-forks), then replace the above `git clone` command with one that gives the URL of your fork. Or use this [`gh`](https://cli.github.com/) command (assuming you have `gh` and your fork is called `GitPython`): + +```sh +gh repo clone GitPython +``` + +Having cloned the repo, create and activate your [virtual environment](https://docs.python.org/3/tutorial/venv.html). + +Then make an [editable install](https://pip.pypa.io/en/stable/topics/local-project-installs/#editable-installs): + +```sh +pip install -e ".[test]" +``` + +In the less common case that you do not want to install test dependencies, `pip install -e .` can be used instead. + +#### With editable *dependencies* (not preferred, and rarely needed) + +In rare cases, you may want to work on GitPython and one or both of its [gitdb](https://github.com/gitpython-developers/gitdb) and [smmap](https://github.com/gitpython-developers/smmap) dependencies at the same time, with changes in your local working copy of gitdb or smmap immediatley reflected in the behavior of your local working copy of GitPython. This can be done by making editable installations of those dependencies in the same virtual environment where you install GitPython. + +If you want to do that *and* you want the versions in GitPython's git submodules to be used, then pass `-e git/ext/gitdb` and/or `-e git/ext/gitdb/gitdb/ext/smmap` to `pip install`. This can be done in any order, and in separate `pip install` commands or the same one, so long as `-e` appears before *each* path. For example, you can install GitPython, gitdb, and smmap editably in the currently active virtual environment this way: + +```sh +pip install -e ".[test]" -e git/ext/gitdb -e git/ext/gitdb/gitdb/ext/smmap +``` + +The submodules must have been cloned for that to work, but that will already be the case if you have run `./init-tests-after-clone.sh`. You can use `pip list` to check which packages are installed editably and which are installed normally. + +To reiterate, this approach should only rarely be used. For most development it is preferable to allow the gitdb and smmap dependencices to be retrieved automatically from PyPI in their latest stable packaged versions. + +### Limitations + +#### Leakage of System Resources + +GitPython is not suited for long-running processes (like daemons) as it tends to +leak system resources. It was written in a time where destructors (as implemented +in the `__del__` method) still ran deterministically. + +In case you still want to use it in such a context, you will want to search the +codebase for `__del__` implementations and call these yourself when you see fit. + +Another way assure proper cleanup of resources is to factor out GitPython into a +separate process which can be dropped periodically. + +#### Windows support + +See [Issue #525](https://github.com/gitpython-developers/GitPython/issues/525). 
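The "Leakage of System Resources" note above can be made concrete with a small sketch. This is an editorial illustration, not part of the upstream README; it assumes only that `git.Repo` supports the context-manager protocol, so `close()` runs deterministically instead of waiting for `__del__`:

```python
# Editorial sketch: scope each Repo to a context manager so cached git
# processes and memory maps are released promptly in a long-running service.
from git import Repo

def head_commit(path: str) -> str:
    with Repo(path) as repo:  # __exit__ calls repo.close()
        return repo.head.commit.hexsha

if __name__ == "__main__":
    print(head_commit("."))  # assumes "." is a git working tree
```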
+ +### RUNNING TESTS + +_Important_: Right after cloning this repository, please be sure to have executed +the `./init-tests-after-clone.sh` script in the repository root. Otherwise +you will encounter test failures. + +#### Install test dependencies + +Ensure testing libraries are installed. This is taken care of already if you installed with: + +```sh +pip install -e ".[test]" +``` + +If you had installed with a command like `pip install -e .` instead, you can still run +the above command to add the testing dependencies. + +#### Test commands + +To test, run: + +```sh +pytest +``` + +To lint, and apply some linting fixes as well as automatic code formatting, run: + +```sh +pre-commit run --all-files +``` + +This includes the linting and autoformatting done by Ruff, as well as some other checks. + +To typecheck, run: + +```sh +mypy +``` + +#### CI (and tox) + +Style and formatting checks, and running tests on all the different supported Python versions, will be performed: + +- Upon submitting a pull request. +- On each push, *if* you have a fork with GitHub Actions enabled. +- Locally, if you run [`tox`](https://tox.wiki/) (this skips any Python versions you don't have installed). + +#### Configuration files + +Specific tools are all configured in the `./pyproject.toml` file: + +- `pytest` (test runner) +- `coverage.py` (code coverage) +- `ruff` (linter and formatter) +- `mypy` (type checker) + +Orchestration tools: + +- Configuration for `pre-commit` is in the `./.pre-commit-config.yaml` file. +- Configuration for `tox` is in `./tox.ini`. +- Configuration for GitHub Actions (CI) is in files inside `./.github/workflows/`. + +### Contributions + +Please have a look at the [contributions file][contributing]. + +### INFRASTRUCTURE + +- [User Documentation](http://gitpython.readthedocs.org) +- [Questions and Answers](http://stackexchange.com/filters/167317/gitpython) +- Please post on Stack Overflow and use the `gitpython` tag +- [Issue Tracker](https://github.com/gitpython-developers/GitPython/issues) + - Post reproducible bugs and feature requests as a new issue. + Please be sure to provide the following information if posting bugs: + - GitPython version (e.g. `import git; git.__version__`) + - Python version (e.g. `python --version`) + - The encountered stack-trace, if applicable + - Enough information to allow reproducing the issue + +### How to make a new release + +1. Update/verify the **version** in the `VERSION` file. +2. Update/verify that the `doc/source/changes.rst` changelog file was updated. It should include a link to the forthcoming release page: `https://github.com/gitpython-developers/GitPython/releases/tag/` +3. Commit everything. +4. Run `git tag -s ` to tag the version in Git. +5. _Optionally_ create and activate a [virtual environment](https://packaging.python.org/en/latest/guides/installing-using-pip-and-virtual-environments/#creating-a-virtual-environment). (Then the next step can install `build` and `twine`.) +6. Run `make release`. +7. Go to [GitHub Releases](https://github.com/gitpython-developers/GitPython/releases) and publish a new one with the recently pushed tag. Generate the changelog. 
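For orientation, the numbered release steps can be condensed into a shell sketch in the README's own `sh` code-fence style. The version string is a hypothetical placeholder; the commands simply restate steps 1-7 above:

```sh
# Hypothetical walk-through; "3.1.44" is a placeholder version, not a real release.
"$EDITOR" VERSION doc/source/changes.rst      # steps 1-2: bump version, update changelog
git commit -am "Release 3.1.44"               # step 3
git tag -s 3.1.44                             # step 4: signed tag
python -m venv .release && . .release/bin/activate   # step 5 (optional)
make release                                  # step 6: build and upload
# step 7: publish the pushed tag on GitHub Releases and generate the changelog there
```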
+ +### Projects using GitPython + +- [PyDriller](https://github.com/ishepard/pydriller) +- [Kivy Designer](https://github.com/kivy/kivy-designer) +- [Prowl](https://github.com/nettitude/Prowl) +- [Python Taint](https://github.com/python-security/pyt) +- [Buster](https://github.com/axitkhurana/buster) +- [git-ftp](https://github.com/ezyang/git-ftp) +- [Git-Pandas](https://github.com/wdm0006/git-pandas) +- [PyGitUp](https://github.com/msiemens/PyGitUp) +- [PyJFuzz](https://github.com/mseclab/PyJFuzz) +- [Loki](https://github.com/Neo23x0/Loki) +- [Omniwallet](https://github.com/OmniLayer/omniwallet) +- [GitViper](https://github.com/BeayemX/GitViper) +- [Git Gud](https://github.com/bthayer2365/git-gud) + +### LICENSE + +[3-Clause BSD License](https://opensource.org/license/bsd-3-clause/), also known as the New BSD License. See the [LICENSE file][license]. + +[contributing]: https://github.com/gitpython-developers/GitPython/blob/main/CONTRIBUTING.md +[license]: https://github.com/gitpython-developers/GitPython/blob/main/LICENSE diff --git a/llm/Lib/site-packages/GitPython-3.1.43.dist-info/RECORD b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..8aee05f3847254112d30113a97df6626fbf1c78c --- /dev/null +++ b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/RECORD @@ -0,0 +1,82 @@ +GitPython-3.1.43.dist-info/AUTHORS,sha256=h1TlPKfp05GA1eKQ15Yl4biR0C0FgivuGSeRA6Q1dz0,2286 +GitPython-3.1.43.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +GitPython-3.1.43.dist-info/LICENSE,sha256=hvyUwyGpr7wRUUcTURuv3tIl8lEA3MD3NQ6CvCMbi-s,1503 +GitPython-3.1.43.dist-info/METADATA,sha256=sAh3r1BMVw5_olGgDmpMS69zBpVr7UEOeRivNHKznfU,13376 +GitPython-3.1.43.dist-info/RECORD,, +GitPython-3.1.43.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92 +GitPython-3.1.43.dist-info/top_level.txt,sha256=0hzDuIp8obv624V3GmbqsagBWkk8ohtGU-Bc1PmTT0o,4 +git/__init__.py,sha256=w6fnS0QmwTfEFUSL6rfnpP0lUId2goSguZFOvVX3N3U,8899 +git/__pycache__/__init__.cpython-311.pyc,, +git/__pycache__/cmd.cpython-311.pyc,, +git/__pycache__/compat.cpython-311.pyc,, +git/__pycache__/config.cpython-311.pyc,, +git/__pycache__/db.cpython-311.pyc,, +git/__pycache__/diff.cpython-311.pyc,, +git/__pycache__/exc.cpython-311.pyc,, +git/__pycache__/remote.cpython-311.pyc,, +git/__pycache__/types.cpython-311.pyc,, +git/__pycache__/util.cpython-311.pyc,, +git/cmd.py,sha256=qd-gIHSk4mfsYjd9YA08cPyO8TMxaibTXAbFnHK71uc,67659 +git/compat.py,sha256=y1E6y6O2q5r8clSlr8ZNmuIWG9nmHuehQEsVsmBffs8,4526 +git/config.py,sha256=Ald8Xc-G9Shcgx3QCISyXTkL4a6nbc3qll-xUw4YdyY,34924 +git/db.py,sha256=vIW9uWSbqu99zbuU2ZDmOhVOv1UPTmxrnqiCtRHCfjE,2368 +git/diff.py,sha256=IE5aeHL7aP9yxBluYj06IX8nZjoJ_TOM3gG31-Evf_8,27058 +git/exc.py,sha256=Gc7g1pHpn8OmTse30NHmJVsBJ2CYH8LxaR8y8UA3lIM,7119 +git/index/__init__.py,sha256=i-Nqb8Lufp9aFbmxpQBORmmQnjEVVM1Pn58fsQkyGgQ,406 +git/index/__pycache__/__init__.cpython-311.pyc,, +git/index/__pycache__/base.cpython-311.pyc,, +git/index/__pycache__/fun.cpython-311.pyc,, +git/index/__pycache__/typ.cpython-311.pyc,, +git/index/__pycache__/util.cpython-311.pyc,, +git/index/base.py,sha256=A4q4cN_Ifxi8CsAR-7h4KsQ2d3JazBNFZ1ltbAKttgs,60734 +git/index/fun.py,sha256=37cA3DBC9vpAnSVu5TGA072SnoF5XZOkOukExwlejHs,16736 +git/index/typ.py,sha256=uuKNwitUw83FhVaLSwo4pY7PHDQudtZTLJrLGym4jcI,6570 +git/index/util.py,sha256=fULi7GPG-MvprKrRCD5c15GNdzku_1E38We0d97WB3A,3659 +git/objects/__init__.py,sha256=O6ZL_olX7e5-8iIbKviRPkVSJxN37WA-EC0q9d48U5Y,637 
+git/objects/__pycache__/__init__.cpython-311.pyc,, +git/objects/__pycache__/base.cpython-311.pyc,, +git/objects/__pycache__/blob.cpython-311.pyc,, +git/objects/__pycache__/commit.cpython-311.pyc,, +git/objects/__pycache__/fun.cpython-311.pyc,, +git/objects/__pycache__/tag.cpython-311.pyc,, +git/objects/__pycache__/tree.cpython-311.pyc,, +git/objects/__pycache__/util.cpython-311.pyc,, +git/objects/base.py,sha256=0dqNkSRVH0mk0-7ZKIkGBK7iNYrzLTVxwQFUd6CagsE,10277 +git/objects/blob.py,sha256=zwwq0KfOMYeP5J2tW5CQatoLyeqFRlfkxP1Vwx1h07s,1215 +git/objects/commit.py,sha256=vLZNl1I9zp17Rpge7J66CvsryirEs90jyPTQzoP0JJs,30208 +git/objects/fun.py,sha256=B4jCqhAjm6Hl79GK58FPzW1H9K6Wc7Tx0rssyWmAcEE,8935 +git/objects/submodule/__init__.py,sha256=6xySp767LVz3UylWgUalntS_nGXRuVzXxDuFAv_Wc2c,303 +git/objects/submodule/__pycache__/__init__.cpython-311.pyc,, +git/objects/submodule/__pycache__/base.cpython-311.pyc,, +git/objects/submodule/__pycache__/root.cpython-311.pyc,, +git/objects/submodule/__pycache__/util.cpython-311.pyc,, +git/objects/submodule/base.py,sha256=MQ-2xV8JznGwy2hLQv1aeQNgAkhBhgc5tdtClFL3DmE,63901 +git/objects/submodule/root.py,sha256=5eTtYNHasqdPq6q0oDCPr7IaO6uAHL3b4DxMoiO2LhE,20246 +git/objects/submodule/util.py,sha256=sQqAYaiSJdFkZa9NlAuK_wTsMNiS-kkQnQjvIoJtc_o,3509 +git/objects/tag.py,sha256=gAx8i-DEwy_Z3R2zLkvetYRV8A56BCcTr3iLuTUTfEM,4467 +git/objects/tree.py,sha256=jJH888SHiP4dGzE-ra1yenQOyya_0C_MkHr06c1gHpM,13849 +git/objects/util.py,sha256=Ml2eqZPKO4y9Hc2vWbXJgpsK3nkN3KGMzbn8AlzLyYQ,23834 +git/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +git/refs/__init__.py,sha256=DWlJNnsx-4jM_E-VycbP-FZUdn6iWhjnH_uZ_pZXBro,509 +git/refs/__pycache__/__init__.cpython-311.pyc,, +git/refs/__pycache__/head.cpython-311.pyc,, +git/refs/__pycache__/log.cpython-311.pyc,, +git/refs/__pycache__/reference.cpython-311.pyc,, +git/refs/__pycache__/remote.cpython-311.pyc,, +git/refs/__pycache__/symbolic.cpython-311.pyc,, +git/refs/__pycache__/tag.cpython-311.pyc,, +git/refs/head.py,sha256=GAZpD5EfqSciDXPtgjHY8ZbBixKExJRhojUB-HrrJPg,10491 +git/refs/log.py,sha256=kXiuAgTo1DIuM_BfbDUk9gQ0YO-mutIMVdHv1_ES90o,12493 +git/refs/reference.py,sha256=l6mhF4YLSEwtjz6b9PpOQH-fkng7EYWMaJhkjn-2jXA,5630 +git/refs/remote.py,sha256=WwqV9T7BbYf3F_WZNUQivu9xktIIKGklCjDpwQrhD-A,2806 +git/refs/symbolic.py,sha256=c8zOwaqzcg-J-rGrpuWdvh8zwMvSUqAHghd4vJoYG_s,34552 +git/refs/tag.py,sha256=kgzV2vhpL4FD2TqHb0BJuMRAHgAvJF-TcoyWlaB-djQ,5010 +git/remote.py,sha256=IHQ3BvXgoIN1EvHlyH3vrSaQoDkLOE6nooSC0w183sU,46561 +git/repo/__init__.py,sha256=CILSVH36fX_WxVFSjD9o1WF5LgsNedPiJvSngKZqfVU,210 +git/repo/__pycache__/__init__.cpython-311.pyc,, +git/repo/__pycache__/base.cpython-311.pyc,, +git/repo/__pycache__/fun.cpython-311.pyc,, +git/repo/base.py,sha256=mitfJ8u99CsMpDd7_VRyx-SF8omu2tpf3lqzSaQkKoQ,59353 +git/repo/fun.py,sha256=tEsClpmbOrKMSNIdncOB_6JdikrL1-AfkOFd7xMpD8k,13582 +git/types.py,sha256=xCwpp2Y01lhS0MapHhj04m0P_x34kwSD1Gsou_ZPWj8,10251 +git/util.py,sha256=1E883mnPAFLyFk7ivwnEremsp-uJOTc3ks_QypyLung,43651 diff --git a/llm/Lib/site-packages/GitPython-3.1.43.dist-info/WHEEL b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..bab98d675883cc7567a79df485cd7b4f015e376f --- /dev/null +++ b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.43.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/llm/Lib/site-packages/GitPython-3.1.43.dist-info/top_level.txt 
b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..5664e303b5dc2e9ef8e14a0845d9486ec1920afd --- /dev/null +++ b/llm/Lib/site-packages/GitPython-3.1.43.dist-info/top_level.txt @@ -0,0 +1 @@ +git diff --git a/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/INSTALLER b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/LICENSE.rst b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/LICENSE.rst new file mode 100644 index 0000000000000000000000000000000000000000..c37cae49ec77ad6ebb25568c1605f1fee5313cfb --- /dev/null +++ b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/LICENSE.rst @@ -0,0 +1,28 @@ +Copyright 2007 Pallets + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + +1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in the + documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A +PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED +TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/METADATA b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..56e942902a96e7f012479a582c5cf89511219f9a --- /dev/null +++ b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/METADATA @@ -0,0 +1,105 @@ +Metadata-Version: 2.1 +Name: Jinja2 +Version: 3.1.3 +Summary: A very fast and expressive template engine. 
+Home-page: https://palletsprojects.com/p/jinja/ +Maintainer: Pallets +Maintainer-email: contact@palletsprojects.com +License: BSD-3-Clause +Project-URL: Donate, https://palletsprojects.com/donate +Project-URL: Documentation, https://jinja.palletsprojects.com/ +Project-URL: Changes, https://jinja.palletsprojects.com/changes/ +Project-URL: Source Code, https://github.com/pallets/jinja/ +Project-URL: Issue Tracker, https://github.com/pallets/jinja/issues/ +Project-URL: Chat, https://discord.gg/pallets +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Web Environment +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: BSD License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python +Classifier: Topic :: Internet :: WWW/HTTP :: Dynamic Content +Classifier: Topic :: Text Processing :: Markup :: HTML +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE.rst +Requires-Dist: MarkupSafe >=2.0 +Provides-Extra: i18n +Requires-Dist: Babel >=2.7 ; extra == 'i18n' + +Jinja +===== + +Jinja is a fast, expressive, extensible templating engine. Special +placeholders in the template allow writing code similar to Python +syntax. Then the template is passed data to render the final document. + +It includes: + +- Template inheritance and inclusion. +- Define and import macros within templates. +- HTML templates can use autoescaping to prevent XSS from untrusted + user input. +- A sandboxed environment can safely render untrusted templates. +- AsyncIO support for generating templates and calling async + functions. +- I18N support with Babel. +- Templates are compiled to optimized Python code just-in-time and + cached, or can be compiled ahead-of-time. +- Exceptions point to the correct line in templates to make debugging + easier. +- Extensible filters, tests, functions, and even syntax. + +Jinja's philosophy is that while application logic belongs in Python if +possible, it shouldn't make the template designer's job difficult by +restricting functionality too much. + + +Installing +---------- + +Install and update using `pip`_: + +.. code-block:: text + + $ pip install -U Jinja2 + +.. _pip: https://pip.pypa.io/en/stable/getting-started/ + + +In A Nutshell +------------- + +.. code-block:: jinja + + {% extends "base.html" %} + {% block title %}Members{% endblock %} + {% block content %} + + {% endblock %} + + +Donate +------ + +The Pallets organization develops and supports Jinja and other popular +packages. In order to grow the community of contributors and users, and +allow the maintainers to devote more time to the projects, `please +donate today`_. + +.. 
_please donate today: https://palletsprojects.com/donate + + +Links +----- + +- Documentation: https://jinja.palletsprojects.com/ +- Changes: https://jinja.palletsprojects.com/changes/ +- PyPI Releases: https://pypi.org/project/Jinja2/ +- Source Code: https://github.com/pallets/jinja/ +- Issue Tracker: https://github.com/pallets/jinja/issues/ +- Chat: https://discord.gg/pallets diff --git a/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/RECORD b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..c2a7387852f4e22b43477a32fd7884c8f655a392 --- /dev/null +++ b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/RECORD @@ -0,0 +1,58 @@ +Jinja2-3.1.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +Jinja2-3.1.3.dist-info/LICENSE.rst,sha256=O0nc7kEF6ze6wQ-vG-JgQI_oXSUrjp3y4JefweCUQ3s,1475 +Jinja2-3.1.3.dist-info/METADATA,sha256=0cLNbRCI91jytc7Bzv3XAQfZzFDF2gxkJuH46eF5vew,3301 +Jinja2-3.1.3.dist-info/RECORD,, +Jinja2-3.1.3.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92 +Jinja2-3.1.3.dist-info/entry_points.txt,sha256=zRd62fbqIyfUpsRtU7EVIFyiu1tPwfgO7EvPErnxgTE,59 +Jinja2-3.1.3.dist-info/top_level.txt,sha256=PkeVWtLb3-CqjWi1fO29OCbj55EhX_chhKrCdrVe_zs,7 +jinja2/__init__.py,sha256=NTBwMwsECrdHmxeXF7seusHLzrh6Ldn1A9qhS5cDuf0,1927 +jinja2/__pycache__/__init__.cpython-311.pyc,, +jinja2/__pycache__/_identifier.cpython-311.pyc,, +jinja2/__pycache__/async_utils.cpython-311.pyc,, +jinja2/__pycache__/bccache.cpython-311.pyc,, +jinja2/__pycache__/compiler.cpython-311.pyc,, +jinja2/__pycache__/constants.cpython-311.pyc,, +jinja2/__pycache__/debug.cpython-311.pyc,, +jinja2/__pycache__/defaults.cpython-311.pyc,, +jinja2/__pycache__/environment.cpython-311.pyc,, +jinja2/__pycache__/exceptions.cpython-311.pyc,, +jinja2/__pycache__/ext.cpython-311.pyc,, +jinja2/__pycache__/filters.cpython-311.pyc,, +jinja2/__pycache__/idtracking.cpython-311.pyc,, +jinja2/__pycache__/lexer.cpython-311.pyc,, +jinja2/__pycache__/loaders.cpython-311.pyc,, +jinja2/__pycache__/meta.cpython-311.pyc,, +jinja2/__pycache__/nativetypes.cpython-311.pyc,, +jinja2/__pycache__/nodes.cpython-311.pyc,, +jinja2/__pycache__/optimizer.cpython-311.pyc,, +jinja2/__pycache__/parser.cpython-311.pyc,, +jinja2/__pycache__/runtime.cpython-311.pyc,, +jinja2/__pycache__/sandbox.cpython-311.pyc,, +jinja2/__pycache__/tests.cpython-311.pyc,, +jinja2/__pycache__/utils.cpython-311.pyc,, +jinja2/__pycache__/visitor.cpython-311.pyc,, +jinja2/_identifier.py,sha256=_zYctNKzRqlk_murTNlzrju1FFJL7Va_Ijqqd7ii2lU,1958 +jinja2/async_utils.py,sha256=dFcmh6lMNfbh7eLKrBio8JqAKLHdZbpCuurFN4OERtY,2447 +jinja2/bccache.py,sha256=mhz5xtLxCcHRAa56azOhphIAe19u1we0ojifNMClDio,14061 +jinja2/compiler.py,sha256=PJzYdRLStlEOqmnQs1YxlizPrJoj3jTZuUleREn6AIQ,72199 +jinja2/constants.py,sha256=GMoFydBF_kdpaRKPoM5cl5MviquVRLVyZtfp5-16jg0,1433 +jinja2/debug.py,sha256=iWJ432RadxJNnaMOPrjIDInz50UEgni3_HKuFXi2vuQ,6299 +jinja2/defaults.py,sha256=boBcSw78h-lp20YbaXSJsqkAI2uN_mD_TtCydpeq5wU,1267 +jinja2/environment.py,sha256=0qldX3VQKZcm6lgn7zHz94oRFow7YPYERiqkquomNjU,61253 +jinja2/exceptions.py,sha256=ioHeHrWwCWNaXX1inHmHVblvc4haO7AXsjCp3GfWvx0,5071 +jinja2/ext.py,sha256=5fnMpllaXkfm2P_93RIvi-OnK7Tk8mCW8Du-GcD12Hc,31844 +jinja2/filters.py,sha256=vYjKb2zaPShvYtn_LpSmqfS8SScbrA_KOanNibsMDIE,53862 +jinja2/idtracking.py,sha256=GfNmadir4oDALVxzn3DL9YInhJDr69ebXeA2ygfuCGA,10704 +jinja2/lexer.py,sha256=DW2nX9zk-6MWp65YR2bqqj0xqCvLtD-u9NWT8AnFRxQ,29726 
+jinja2/loaders.py,sha256=ayAwxfrA1SAffQta0nwSDm3TDT4KYiIGN_D9Z45B310,23085 +jinja2/meta.py,sha256=GNPEvifmSaU3CMxlbheBOZjeZ277HThOPUTf1RkppKQ,4396 +jinja2/nativetypes.py,sha256=7GIGALVJgdyL80oZJdQUaUfwSt5q2lSSZbXt0dNf_M4,4210 +jinja2/nodes.py,sha256=i34GPRAZexXMT6bwuf5SEyvdmS-bRCy9KMjwN5O6pjk,34550 +jinja2/optimizer.py,sha256=tHkMwXxfZkbfA1KmLcqmBMSaz7RLIvvItrJcPoXTyD8,1650 +jinja2/parser.py,sha256=Y199wPL-G67gJoi5G_5sHuu9uEP1PJkjjLEW_xTH8-k,39736 +jinja2/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +jinja2/runtime.py,sha256=_6LkKIWFJjQdqlrgA3K39zBFQ-7Orm3wGDm96RwxQoE,33406 +jinja2/sandbox.py,sha256=Y0xZeXQnH6EX5VjaV2YixESxoepnRbW_3UeQosaBU3M,14584 +jinja2/tests.py,sha256=Am5Z6Lmfr2XaH_npIfJJ8MdXtWsbLjMULZJulTAj30E,5905 +jinja2/utils.py,sha256=IMwRIcN1SsTw2-jdQtlH2KzNABsXZBW_-tnFXafQBvY,23933 +jinja2/visitor.py,sha256=MH14C6yq24G_KVtWzjwaI7Wg14PCJIYlWW1kpkxYak0,3568 diff --git a/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/WHEEL b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..98c0d20b7a64f4f998d7913e1d38a05dba20916c --- /dev/null +++ b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.42.0) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/entry_points.txt b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..7b9666c8ea311ea0f0cfe7bed861aaa5469f92bb --- /dev/null +++ b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/entry_points.txt @@ -0,0 +1,2 @@ +[babel.extractors] +jinja2 = jinja2.ext:babel_extract[i18n] diff --git a/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/top_level.txt b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..7f7afbf3bf54b346092be6a72070fcbd305ead1e --- /dev/null +++ b/llm/Lib/site-packages/Jinja2-3.1.3.dist-info/top_level.txt @@ -0,0 +1 @@ +jinja2 diff --git a/llm/Lib/site-packages/accelerate-0.29.3.dist-info/INSTALLER b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/llm/Lib/site-packages/accelerate-0.29.3.dist-info/LICENSE b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/llm/Lib/site-packages/accelerate-0.29.3.dist-info/METADATA b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..3fc723bb7bb61d73562595210ddb57c130cefb53 --- /dev/null +++ b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/METADATA @@ -0,0 +1,378 @@ +Metadata-Version: 2.1 +Name: accelerate +Version: 0.29.3 +Summary: Accelerate +Home-page: https://github.com/huggingface/accelerate +Author: The HuggingFace team +Author-email: zach.mueller@huggingface.co +License: Apache +Keywords: deep learning +Classifier: Development Status :: 5 - Production/Stable +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence +Requires-Python: >=3.8.0 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: numpy (>=1.17) +Requires-Dist: packaging (>=20.0) +Requires-Dist: psutil +Requires-Dist: pyyaml +Requires-Dist: torch (>=1.10.0) +Requires-Dist: huggingface-hub +Requires-Dist: safetensors (>=0.3.1) +Provides-Extra: dev +Requires-Dist: black (~=23.1) ; extra == 'dev' +Requires-Dist: hf-doc-builder (>=0.3.0) ; extra == 'dev' +Requires-Dist: ruff (~=0.2.1) ; extra == 'dev' +Requires-Dist: pytest (<=8.0.0,>=7.2.0) ; extra == 'dev' +Requires-Dist: pytest-xdist ; extra == 'dev' +Requires-Dist: pytest-subtests ; extra == 'dev' +Requires-Dist: parameterized ; extra == 'dev' +Requires-Dist: datasets ; extra == 'dev' +Requires-Dist: evaluate ; extra == 'dev' +Requires-Dist: torchpippy (>=0.2.0) ; extra == 'dev' +Requires-Dist: transformers ; extra == 'dev' +Requires-Dist: scipy ; extra == 'dev' +Requires-Dist: scikit-learn ; extra == 'dev' +Requires-Dist: deepspeed ; extra == 'dev' +Requires-Dist: tqdm ; extra == 'dev' +Requires-Dist: bitsandbytes ; extra == 'dev' +Requires-Dist: timm ; extra == 'dev' +Requires-Dist: rich ; extra == 'dev' +Provides-Extra: docs +Provides-Extra: quality +Requires-Dist: black (~=23.1) ; extra == 'quality' +Requires-Dist: hf-doc-builder (>=0.3.0) ; extra == 'quality' +Requires-Dist: ruff (~=0.2.1) ; extra == 'quality' +Provides-Extra: rich +Requires-Dist: rich ; extra == 'rich' +Provides-Extra: sagemaker +Requires-Dist: sagemaker ; extra == 'sagemaker' +Provides-Extra: test_dev +Requires-Dist: datasets ; extra == 'test_dev' +Requires-Dist: evaluate ; extra == 'test_dev' +Requires-Dist: torchpippy (>=0.2.0) ; extra == 'test_dev' +Requires-Dist: transformers ; extra == 'test_dev' +Requires-Dist: scipy ; extra == 'test_dev' +Requires-Dist: scikit-learn ; extra == 'test_dev' +Requires-Dist: deepspeed ; extra == 'test_dev' +Requires-Dist: tqdm ; extra == 'test_dev' +Requires-Dist: bitsandbytes ; extra == 'test_dev' +Requires-Dist: timm ; extra == 'test_dev' +Provides-Extra: test_prod +Requires-Dist: pytest (<=8.0.0,>=7.2.0) ; extra == 'test_prod' +Requires-Dist: pytest-xdist ; extra == 'test_prod' +Requires-Dist: pytest-subtests ; extra == 'test_prod' +Requires-Dist: parameterized ; extra == 'test_prod' +Provides-Extra: test_trackers +Requires-Dist: wandb ; extra == 'test_trackers' +Requires-Dist: comet-ml ; extra == 'test_trackers' +Requires-Dist: tensorboard ; extra == 'test_trackers' +Requires-Dist: dvclive ; extra == 
'test_trackers' +Provides-Extra: testing +Requires-Dist: pytest (<=8.0.0,>=7.2.0) ; extra == 'testing' +Requires-Dist: pytest-xdist ; extra == 'testing' +Requires-Dist: pytest-subtests ; extra == 'testing' +Requires-Dist: parameterized ; extra == 'testing' +Requires-Dist: datasets ; extra == 'testing' +Requires-Dist: evaluate ; extra == 'testing' +Requires-Dist: torchpippy (>=0.2.0) ; extra == 'testing' +Requires-Dist: transformers ; extra == 'testing' +Requires-Dist: scipy ; extra == 'testing' +Requires-Dist: scikit-learn ; extra == 'testing' +Requires-Dist: deepspeed ; extra == 'testing' +Requires-Dist: tqdm ; extra == 'testing' +Requires-Dist: bitsandbytes ; extra == 'testing' +Requires-Dist: timm ; extra == 'testing' + + + +

+
+[README header: project logo and badges — License, Documentation, GitHub release, Contributor Covenant]
+
+Run your *raw* PyTorch training script on any kind of device
+
+ +## Easy to integrate + +🤗 Accelerate was created for PyTorch users who like to write the training loop of PyTorch models but are reluctant to write and maintain the boilerplate code needed to use multi-GPUs/TPU/fp16. + +🤗 Accelerate abstracts exactly and only the boilerplate code related to multi-GPUs/TPU/fp16 and leaves the rest of your code unchanged. + +Here is an example: + +```diff + import torch + import torch.nn.functional as F + from datasets import load_dataset ++ from accelerate import Accelerator + ++ accelerator = Accelerator() +- device = 'cpu' ++ device = accelerator.device + + model = torch.nn.Transformer().to(device) + optimizer = torch.optim.Adam(model.parameters()) + + dataset = load_dataset('my_dataset') + data = torch.utils.data.DataLoader(dataset, shuffle=True) + ++ model, optimizer, data = accelerator.prepare(model, optimizer, data) + + model.train() + for epoch in range(10): + for source, targets in data: + source = source.to(device) + targets = targets.to(device) + + optimizer.zero_grad() + + output = model(source) + loss = F.cross_entropy(output, targets) + +- loss.backward() ++ accelerator.backward(loss) + + optimizer.step() +``` + +As you can see in this example, by adding 5-lines to any standard PyTorch training script you can now run on any kind of single or distributed node setting (single CPU, single GPU, multi-GPUs and TPUs) as well as with or without mixed precision (fp8, fp16, bf16). + +In particular, the same code can then be run without modification on your local machine for debugging or your training environment. + +🤗 Accelerate even handles the device placement for you (which requires a few more changes to your code, but is safer in general), so you can even simplify your training loop further: + +```diff + import torch + import torch.nn.functional as F + from datasets import load_dataset ++ from accelerate import Accelerator + +- device = 'cpu' ++ accelerator = Accelerator() + +- model = torch.nn.Transformer().to(device) ++ model = torch.nn.Transformer() + optimizer = torch.optim.Adam(model.parameters()) + + dataset = load_dataset('my_dataset') + data = torch.utils.data.DataLoader(dataset, shuffle=True) + ++ model, optimizer, data = accelerator.prepare(model, optimizer, data) + + model.train() + for epoch in range(10): + for source, targets in data: +- source = source.to(device) +- targets = targets.to(device) + + optimizer.zero_grad() + + output = model(source) + loss = F.cross_entropy(output, targets) + +- loss.backward() ++ accelerator.backward(loss) + + optimizer.step() +``` + +Want to learn more? Check out the [documentation](https://huggingface.co/docs/accelerate) or have a look at our [examples](https://github.com/huggingface/accelerate/tree/main/examples). + +## Launching script + +🤗 Accelerate also provides an optional CLI tool that allows you to quickly configure and test your training environment before launching the scripts. No need to remember how to use `torch.distributed.run` or to write a specific launcher for TPU training! +On your machine(s) just run: + +```bash +accelerate config +``` + +and answer the questions asked. 
This will generate a config file that will be used automatically to properly set the default options when doing + +```bash +accelerate launch my_script.py --args_to_my_script +``` + +For instance, here is how you would run the GLUE example on the MRPC task (from the root of the repo): + +```bash +accelerate launch examples/nlp_example.py +``` + +This CLI tool is **optional**, and you can still use `python my_script.py` or `python -m torch.distributed.run my_script.py` at your convenience. + +You can also directly pass in the arguments you would to `torchrun` as arguments to `accelerate launch` if you wish to not run `accelerate config`. + +For example, here is how to launch on two GPUs: + +```bash +accelerate launch --multi_gpu --num_processes 2 examples/nlp_example.py +``` + +To learn more, check the CLI documentation available [here](https://huggingface.co/docs/accelerate/package_reference/cli). + +## Launching multi-CPU run using MPI + +🤗 Here is another way to launch a multi-CPU run using MPI. You can learn how to install Open MPI on [this page](https://www.open-mpi.org/faq/?category=building#easy-build). You can use Intel MPI or MVAPICH as well. +Once you have MPI set up on your cluster, just run: +```bash +accelerate config +``` +Answer the questions that are asked, selecting to run using multi-CPU, and answer "yes" when asked if you want accelerate to launch mpirun. +Then, use `accelerate launch` with your script like: +```bash +accelerate launch examples/nlp_example.py +``` +Alternatively, you can use mpirun directly, without using the CLI: +```bash +mpirun -np 2 python examples/nlp_example.py +``` + +## Launching training using DeepSpeed + +🤗 Accelerate supports training on single/multiple GPUs using DeepSpeed. To use it, you don't need to change anything in your training code; you can set everything using just `accelerate config`. However, if you want to tweak your DeepSpeed-related args from your Python script, we provide the `DeepSpeedPlugin`. + +```python +from accelerate import Accelerator, DeepSpeedPlugin + +# DeepSpeed needs to know your gradient accumulation steps beforehand, so don't forget to pass it +# Remember you still need to do gradient accumulation yourself, just like you would have done without DeepSpeed +deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2) +accelerator = Accelerator(mixed_precision='fp16', deepspeed_plugin=deepspeed_plugin) + +# How to save your 🤗 Transformer? +accelerator.wait_for_everyone() +unwrapped_model = accelerator.unwrap_model(model) +unwrapped_model.save_pretrained(save_dir, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model)) +``` + +Note: DeepSpeed support is experimental for now. If you run into a problem, please open an issue. + +## Launching your training from a notebook + +🤗 Accelerate also provides a `notebook_launcher` function you can use in a notebook to launch distributed training. This is especially useful for Colab or Kaggle notebooks with a TPU backend. Just define your training loop in a `training_function`, then, in your last cell, add: + +```python +from accelerate import notebook_launcher + +notebook_launcher(training_function) +``` + +An example can be found in [this notebook](https://github.com/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb).
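Beyond the minimal call above, `notebook_launcher` can also spawn a chosen number of processes and forward arguments to your function, which is handy on multi-GPU runtimes. The sketch below is a minimal, illustrative setup (the toy `training_function` and the two-process count are assumptions, not taken from the linked notebook); see also the Colab badge that follows.

```python
from accelerate import Accelerator, notebook_launcher


def training_function(num_epochs=1):
    # Build everything (model, optimizer, dataloaders) inside the function:
    # each spawned process executes it from scratch.
    accelerator = Accelerator()
    accelerator.print(f"{accelerator.num_processes} processes, {num_epochs} epochs")
    # ... create your objects, call accelerator.prepare(...), and train here ...


# Spawn 2 processes (e.g. a 2-GPU machine); args are forwarded to training_function.
notebook_launcher(training_function, args=(3,), num_processes=2)
```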
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/accelerate_examples/simple_nlp_example.ipynb) + +## Why should I use 🤗 Accelerate? + +You should use 🤗 Accelerate when you want to easily run your training scripts in a distributed environment without giving up full control over your training loop. This is not a high-level framework above PyTorch, just a thin wrapper so you don't have to learn a new library. In fact, the whole API of 🤗 Accelerate is in one class, the `Accelerator` object. + +## Why shouldn't I use 🤗 Accelerate? + +You shouldn't use 🤗 Accelerate if you don't want to write a training loop yourself. There are plenty of high-level libraries above PyTorch that will offer you that; 🤗 Accelerate is not one of them. + +## Frameworks using 🤗 Accelerate + +If you like the simplicity of 🤗 Accelerate but would prefer a higher-level abstraction around its capabilities, some frameworks and libraries that are built on top of 🤗 Accelerate are listed below: + +* [Amphion](https://github.com/open-mmlab/Amphion) is a toolkit for Audio, Music, and Speech Generation. Its purpose is to support reproducible research and help junior researchers and engineers get started in the field of audio, music, and speech generation research and development. +* [Animus](https://github.com/Scitator/animus) is a minimalistic framework to run machine learning experiments. Animus highlights common "breakpoints" in ML experiments and provides a unified interface for them within [IExperiment](https://github.com/Scitator/animus/blob/main/animus/core.py#L76). +* [Catalyst](https://github.com/catalyst-team/catalyst#getting-started) is a PyTorch framework for Deep Learning Research and Development. It focuses on reproducibility, rapid experimentation, and codebase reuse so you can create something new rather than write yet another train loop. Catalyst provides a [Runner](https://catalyst-team.github.io/catalyst/api/core.html#runner) to connect all parts of the experiment: hardware backend, data transformations, model training, and inference logic. +* [fastai](https://github.com/fastai/fastai#installing) is a PyTorch framework for Deep Learning that simplifies training fast and accurate neural nets using modern best practices. fastai provides a [Learner](https://docs.fast.ai/learner.html#Learner) to handle the training, fine-tuning, and inference of deep learning algorithms. +* [Finetuner](https://github.com/jina-ai/finetuner) is a service that enables models to create higher-quality embeddings for semantic search, visual similarity search, cross-modal text<->image search, recommendation systems, clustering, duplication detection, anomaly detection, or other uses. +* [InvokeAI](https://github.com/invoke-ai/InvokeAI) is a creative engine for Stable Diffusion models, offering an industry-leading WebUI, support for terminal usage, and serving as the foundation for many commercial products. +* [Kornia](https://kornia.readthedocs.io/en/latest/get-started/introduction.html) is a differentiable library that allows classical computer vision to be integrated into deep learning models. Kornia provides a [Trainer](https://kornia.readthedocs.io/en/latest/x.html#kornia.x.Trainer) with the specific purpose of training and fine-tuning the supported deep learning algorithms within the library.
+* [Open Assistant](https://projects.laion.ai/Open-Assistant/) is a chat-based assistant that understands tasks, can interact with third-party systems, and retrieve information dynamically to do so. +* [pytorch-accelerated](https://github.com/Chris-hughes10/pytorch-accelerated) is a lightweight training library, with a streamlined feature set centered around a general-purpose [Trainer](https://pytorch-accelerated.readthedocs.io/en/latest/trainer.html), that places a huge emphasis on simplicity and transparency, enabling users to understand exactly what is going on under the hood without having to write and maintain the boilerplate themselves. +* [Stable Diffusion web UI](https://github.com/AUTOMATIC1111/stable-diffusion-webui) is an open-source, browser-based, easy-to-use interface based on the Gradio library for Stable Diffusion. +* [torchkeras](https://github.com/lyhue1991/torchkeras) is a simple tool for training PyTorch models in a Keras-like style; a dynamic plot is provided in the notebook to monitor your loss or metric. +* [transformers](https://github.com/huggingface/transformers) is a tool for training state-of-the-art machine learning models in PyTorch, TensorFlow, and JAX. (Accelerate is the backend for the PyTorch side). + + +## Installation + +This repository is tested on Python 3.8+ and PyTorch 1.10.0+. + +You should install 🤗 Accelerate in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/). + +First, create a virtual environment with the version of Python you're going to use and activate it. + +Then, you will need to install PyTorch: refer to the [official installation page](https://pytorch.org/get-started/locally/#start-locally) for the specific install command for your platform. Then 🤗 Accelerate can be installed using pip as follows: + +```bash +pip install accelerate +``` + +## Supported integrations + +- CPU only +- multi-CPU on one node (machine) +- multi-CPU on several nodes (machines) +- single GPU +- multi-GPU on one node (machine) +- multi-GPU on several nodes (machines) +- TPU +- FP16/BFloat16 mixed precision +- FP8 mixed precision with [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) +- DeepSpeed support (Experimental) +- PyTorch Fully Sharded Data Parallel (FSDP) support (Experimental) +- Megatron-LM support (Experimental) + +## Citing 🤗 Accelerate + +If you use 🤗 Accelerate in your publication, please cite it by using the following BibTeX entry.
+ +```bibtex +@Misc{accelerate, + title = {Accelerate: Training and inference at scale made simple, efficient and adaptable.}, + author = {Sylvain Gugger and Lysandre Debut and Thomas Wolf and Philipp Schmid and Zachary Mueller and Sourab Mangrulkar and Marc Sun and Benjamin Bossan}, + howpublished = {\url{https://github.com/huggingface/accelerate}}, + year = {2022} +} +``` diff --git a/llm/Lib/site-packages/accelerate-0.29.3.dist-info/RECORD b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..a4e7cd37a814c31d600a1dddc487dbd22b97d6f9 --- /dev/null +++ b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/RECORD @@ -0,0 +1,164 @@ +../../Scripts/accelerate-config.exe,sha256=Vc33m1EHdtjr3Go97mmzD40pHC0udwU7e2YlKu-NM6Q,108393 +../../Scripts/accelerate-estimate-memory.exe,sha256=X-r55bc3vnkw-IVoXywYuIOAVJgqbUiCHJcbDW2hPXA,108395 +../../Scripts/accelerate-launch.exe,sha256=t4JMeZ4RvYQdn-ldK88qgDsWQx1l12rMbomdlG5IFHU,108393 +../../Scripts/accelerate.exe,sha256=pReYe3Amm_rR1mdNNM1ghovR9bfE0U1BvtFmpT8cEHM,108401 +accelerate-0.29.3.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +accelerate-0.29.3.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357 +accelerate-0.29.3.dist-info/METADATA,sha256=DNiQffLlP8RQMZZvHtnR0loDVK60yC3FhB8UQKkthgo,18942 +accelerate-0.29.3.dist-info/RECORD,, +accelerate-0.29.3.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +accelerate-0.29.3.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92 +accelerate-0.29.3.dist-info/entry_points.txt,sha256=Z_KV59tIt4oZtUDEQ0w8JThJ6_1dd8vR8heH24DeAXI,238 +accelerate-0.29.3.dist-info/top_level.txt,sha256=esVfdxTidsjQ90zsN_rPpjLFJ4ijRlx4mnLrG09hlt4,11 +accelerate/__init__.py,sha256=UUqSsQQDFMm6aAZGCgNyrbTFPtwkguZA2KnoPb0XbWo,1456 +accelerate/__pycache__/__init__.cpython-311.pyc,, +accelerate/__pycache__/accelerator.cpython-311.pyc,, +accelerate/__pycache__/big_modeling.cpython-311.pyc,, +accelerate/__pycache__/checkpointing.cpython-311.pyc,, +accelerate/__pycache__/data_loader.cpython-311.pyc,, +accelerate/__pycache__/hooks.cpython-311.pyc,, +accelerate/__pycache__/inference.cpython-311.pyc,, +accelerate/__pycache__/launchers.cpython-311.pyc,, +accelerate/__pycache__/local_sgd.cpython-311.pyc,, +accelerate/__pycache__/logging.cpython-311.pyc,, +accelerate/__pycache__/memory_utils.cpython-311.pyc,, +accelerate/__pycache__/optimizer.cpython-311.pyc,, +accelerate/__pycache__/scheduler.cpython-311.pyc,, +accelerate/__pycache__/state.cpython-311.pyc,, +accelerate/__pycache__/tracking.cpython-311.pyc,, +accelerate/accelerator.py,sha256=rh4-KBMCkCGLldjKo1CRtBIbsXG76fJqYWdgOugaw7w,143024 +accelerate/big_modeling.py,sha256=pmtLTKTf8mJK1E2o51E3H5TBAuw_zLX_7pWtogtbP1w,29278 +accelerate/checkpointing.py,sha256=vFyLNg9-8qsPBYhAkcm-WwKEeK5Lrq9qLrQWNGFKoPk,11378 +accelerate/commands/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606 +accelerate/commands/__pycache__/__init__.cpython-311.pyc,, +accelerate/commands/__pycache__/accelerate_cli.cpython-311.pyc,, +accelerate/commands/__pycache__/env.cpython-311.pyc,, +accelerate/commands/__pycache__/estimate.cpython-311.pyc,, +accelerate/commands/__pycache__/launch.cpython-311.pyc,, +accelerate/commands/__pycache__/test.cpython-311.pyc,, +accelerate/commands/__pycache__/tpu.cpython-311.pyc,, +accelerate/commands/__pycache__/utils.cpython-311.pyc,, 
+accelerate/commands/accelerate_cli.py,sha256=i3nge5Wj8i4zkV0CVIk9P8veleRZbTZY0AU4fJOrKF8,1749 +accelerate/commands/config/__init__.py,sha256=iJK8dgj3pc5Vdr1E7UuGoFu-BlybyXLxYDoTg9gXngE,1645 +accelerate/commands/config/__pycache__/__init__.cpython-311.pyc,, +accelerate/commands/config/__pycache__/cluster.cpython-311.pyc,, +accelerate/commands/config/__pycache__/config.cpython-311.pyc,, +accelerate/commands/config/__pycache__/config_args.cpython-311.pyc,, +accelerate/commands/config/__pycache__/config_utils.cpython-311.pyc,, +accelerate/commands/config/__pycache__/default.cpython-311.pyc,, +accelerate/commands/config/__pycache__/sagemaker.cpython-311.pyc,, +accelerate/commands/config/__pycache__/update.cpython-311.pyc,, +accelerate/commands/config/cluster.py,sha256=lA55beGeo0fAowfffKhf8nGcy6lBjaOxTtV-Yg_Rz6s,29926 +accelerate/commands/config/config.py,sha256=FuRlQvOjgATEtyqOSsGD-KEtOCvACOHjs2C-krrtldk,3035 +accelerate/commands/config/config_args.py,sha256=hE42coVnn0UU-ysqp2ZH-jlqaXoPaHt5E_3qxT42GIM,10024 +accelerate/commands/config/config_utils.py,sha256=DcjIV1mDInFmct2_XQ-9KYAkREINs6YuHRbZe5HFjT8,2926 +accelerate/commands/config/default.py,sha256=3-SdEhl_zXM9S3f-FxkSVtiBQ5VY-QNsC4O26u60bss,5350 +accelerate/commands/config/sagemaker.py,sha256=GjHE2-h4tRr1P_PFtMF3miiAtJlzkbHbMb6kFXqn8eo,10341 +accelerate/commands/config/update.py,sha256=NXW1J7GkUHpg71QlIXsmMB_0z8S8IZo2FWax5POwrhc,2395 +accelerate/commands/env.py,sha256=HXXUozMFlxs0b-bU2a3nEcXwYz-5EBkfCvE9svqeN2U,3595 +accelerate/commands/estimate.py,sha256=shEn2nXyHmz94zpAzV2R8__lcNYW9f9djl7bOHoo04k,12398 +accelerate/commands/launch.py,sha256=rYmkdc0Kbcux4TOqBG_sJN-NNc4nmV90vuwHqhGNfWw,41439 +accelerate/commands/menu/__init__.py,sha256=uqSlBM0TFHBwzdv3p3SXfpAk1lZFp4h1a7mbBdscPHs,645 +accelerate/commands/menu/__pycache__/__init__.cpython-311.pyc,, +accelerate/commands/menu/__pycache__/cursor.cpython-311.pyc,, +accelerate/commands/menu/__pycache__/helpers.cpython-311.pyc,, +accelerate/commands/menu/__pycache__/input.cpython-311.pyc,, +accelerate/commands/menu/__pycache__/keymap.cpython-311.pyc,, +accelerate/commands/menu/__pycache__/selection_menu.cpython-311.pyc,, +accelerate/commands/menu/cursor.py,sha256=-lmpJVAzvNc0c3EOtSuLoKB59zqylVCbYyWLPnrOmvQ,2028 +accelerate/commands/menu/helpers.py,sha256=KrSB5fJjH4MUEUAQJ6bYaN16AYcnl9UalDrPD3DYeeg,1483 +accelerate/commands/menu/input.py,sha256=Uj9eDp8-Mb0Fe49nuogqo9W_RCfYd6udfjiPKx7Wjmg,2537 +accelerate/commands/menu/keymap.py,sha256=eXj-suyYs1m5dEHoUKN4mKAMLc8DWHnwhP6G6JSU0jQ,4086 +accelerate/commands/menu/selection_menu.py,sha256=bxy-DHaKKC6SCToOlMBv5_z0MdUzylEg6Sio9OuV3GM,4921 +accelerate/commands/test.py,sha256=YrPYEaAACOGZ6btn2MV6NbMSEdBUcMWADLbQWaZSHtk,2149 +accelerate/commands/tpu.py,sha256=KyxDP7IuveidZrbW4rx2s8Ku3o_ptI6tzwr_R7ck0os,5548 +accelerate/commands/utils.py,sha256=ilcfE32oHh28EToM00nc_SR6upfZiuxUI0AjjZu8KYY,3995 +accelerate/data_loader.py,sha256=qQojnHAW0cjTL7jLQN_g-oHlRZBkKzti3ifk84Izuw4,48307 +accelerate/hooks.py,sha256=x0FBwwoy6PKSwulavYTpc4gERIoB7RHGPF0Qe6qjXNA,31244 +accelerate/inference.py,sha256=Ci7kkw2cocNpuvmbo1ytW2QgcI_HKWoXkIdonFOr0tg,7977 +accelerate/launchers.py,sha256=iFDZ7seDdRwHAHy1BbVPmPccAONiPdV2aBOHNuT2ZD8,11375 +accelerate/local_sgd.py,sha256=v0-AxldUSCYCI-rqjLiEHsVtSqyEIWTC5ppn7CW7qfY,4002 +accelerate/logging.py,sha256=kvUvk33r_7T2BNzIwqRZBOhuC-50Ju4rm4HbsM6h2G8,4897 +accelerate/memory_utils.py,sha256=3R5LoeHl6GgTZ-IMPrDZMdaEehWarGdPqODushb-6pg,862 +accelerate/optimizer.py,sha256=H7e1XwEysZ_GFR8V_3bHjFAY7zzrzO8samCyW_r7dZo,7453 
+accelerate/scheduler.py,sha256=des_4M_Tt1W8gCYZZbLla0GHBEgJY3Wx2EGBQPTzeiY,4238 +accelerate/state.py,sha256=yOpKq0xf-yY7qPeQMKWqG05PiU_uUsIkyGqyAlOIJNQ,50409 +accelerate/test_utils/__init__.py,sha256=amEDYw-ztgIvHkYT3mv3ixk1QJirUnf6jfPJzqUUYkQ,1459 +accelerate/test_utils/__pycache__/__init__.cpython-311.pyc,, +accelerate/test_utils/__pycache__/examples.cpython-311.pyc,, +accelerate/test_utils/__pycache__/testing.cpython-311.pyc,, +accelerate/test_utils/__pycache__/training.cpython-311.pyc,, +accelerate/test_utils/examples.py,sha256=jRm1S9TkmeoLaqprBvtVFN4LesiaDZtKMNIoLNY2euw,7281 +accelerate/test_utils/scripts/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606 +accelerate/test_utils/scripts/__pycache__/__init__.cpython-311.pyc,, +accelerate/test_utils/scripts/__pycache__/test_cli.cpython-311.pyc,, +accelerate/test_utils/scripts/__pycache__/test_distributed_data_loop.cpython-311.pyc,, +accelerate/test_utils/scripts/__pycache__/test_notebook.cpython-311.pyc,, +accelerate/test_utils/scripts/__pycache__/test_ops.cpython-311.pyc,, +accelerate/test_utils/scripts/__pycache__/test_script.cpython-311.pyc,, +accelerate/test_utils/scripts/__pycache__/test_sync.cpython-311.pyc,, +accelerate/test_utils/scripts/external_deps/__init__.py,sha256=m1PPTDT4ziIAvM0-FDSgIMIZ69Konn126s6LwuzH6v8,606 +accelerate/test_utils/scripts/external_deps/__pycache__/__init__.cpython-311.pyc,, +accelerate/test_utils/scripts/external_deps/__pycache__/test_checkpointing.cpython-311.pyc,, +accelerate/test_utils/scripts/external_deps/__pycache__/test_metrics.cpython-311.pyc,, +accelerate/test_utils/scripts/external_deps/__pycache__/test_peak_memory_usage.cpython-311.pyc,, +accelerate/test_utils/scripts/external_deps/__pycache__/test_performance.cpython-311.pyc,, +accelerate/test_utils/scripts/external_deps/__pycache__/test_pippy.cpython-311.pyc,, +accelerate/test_utils/scripts/external_deps/__pycache__/test_zero3_integration.cpython-311.pyc,, +accelerate/test_utils/scripts/external_deps/test_checkpointing.py,sha256=zILzHevzqxB1NPPDrJ1furaitI8MTvhBeG9QzzL0bmE,10668 +accelerate/test_utils/scripts/external_deps/test_metrics.py,sha256=67-S1qeCpCL9ceaH22RsIsBJscMS7VQWaO4Krcszzbw,12133 +accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py,sha256=D0YnKCxkI4ZwDOmZ5Ev6hL9jPyP7SU4WffpVFiK14bs,11072 +accelerate/test_utils/scripts/external_deps/test_performance.py,sha256=8fV3wCM1H9HVRRyC5C4EGWt-9aHILX_y3-E7LfSiv7M,9803 +accelerate/test_utils/scripts/external_deps/test_pippy.py,sha256=RdMoD1rlLKMyjyl0soSqR3iDbGidS6-z5GHo3bJUOw8,4647 +accelerate/test_utils/scripts/external_deps/test_zero3_integration.py,sha256=bJ0Jio-6OCyS2FIgFmZi3duqG1gbkOoTEcHsrORYIL4,1503 +accelerate/test_utils/scripts/test_cli.py,sha256=qfk1aYFtdvYFCYPkl05602SNGvk08QTv0xZVVcFVtzM,833 +accelerate/test_utils/scripts/test_distributed_data_loop.py,sha256=VqFPKNRu8yx2MoZ4nHy5wRocEthSymcIA2mg1knqDq8,8315 +accelerate/test_utils/scripts/test_notebook.py,sha256=Q4OOWHa_GMmzwfiq71BTpKYmhCHLC02J42OO94ut9xk,1629 +accelerate/test_utils/scripts/test_ops.py,sha256=BcGn3xJT2wUJ0Yk_6VLNkneSv9z24JeAoQjsgdIIRr4,6170 +accelerate/test_utils/scripts/test_script.py,sha256=QyHRWvHQm1XWkAH7YilQ0gZe3zwvEkyqD6JXmneWqak,32059 +accelerate/test_utils/scripts/test_sync.py,sha256=3kltq-GuUjOVuo6_FOuWiPyc5f3pGiqiwEAbex5x_-o,18263 +accelerate/test_utils/testing.py,sha256=HIp7n6qPMh8KPbwEzNWu5mzfxnQRcU15EQ1AQKehpo0,20571 +accelerate/test_utils/training.py,sha256=8k_YAQ21MzUdb2aFWq1t2fihW1b-iBGh1OJSL3whY68,4019 
+accelerate/tracking.py,sha256=WLY-H1DTsxrz4BVzle7QZMp0Irg84yFMbA1e6JaY3pM,39789 +accelerate/utils/__init__.py,sha256=SEP34Od2TbTZt7AbhPJoWWDxFoNMeNEyAuVfaPgVu7k,6065 +accelerate/utils/__pycache__/__init__.cpython-311.pyc,, +accelerate/utils/__pycache__/bnb.cpython-311.pyc,, +accelerate/utils/__pycache__/constants.cpython-311.pyc,, +accelerate/utils/__pycache__/dataclasses.cpython-311.pyc,, +accelerate/utils/__pycache__/deepspeed.cpython-311.pyc,, +accelerate/utils/__pycache__/environment.cpython-311.pyc,, +accelerate/utils/__pycache__/fsdp_utils.cpython-311.pyc,, +accelerate/utils/__pycache__/imports.cpython-311.pyc,, +accelerate/utils/__pycache__/launch.cpython-311.pyc,, +accelerate/utils/__pycache__/megatron_lm.cpython-311.pyc,, +accelerate/utils/__pycache__/memory.cpython-311.pyc,, +accelerate/utils/__pycache__/modeling.cpython-311.pyc,, +accelerate/utils/__pycache__/offload.cpython-311.pyc,, +accelerate/utils/__pycache__/operations.cpython-311.pyc,, +accelerate/utils/__pycache__/other.cpython-311.pyc,, +accelerate/utils/__pycache__/random.cpython-311.pyc,, +accelerate/utils/__pycache__/rich.cpython-311.pyc,, +accelerate/utils/__pycache__/torch_xla.cpython-311.pyc,, +accelerate/utils/__pycache__/tqdm.cpython-311.pyc,, +accelerate/utils/__pycache__/transformer_engine.cpython-311.pyc,, +accelerate/utils/__pycache__/versions.cpython-311.pyc,, +accelerate/utils/bnb.py,sha256=3i59dy8EcBYJEnT2alJ5_M-zeIpFsrceQ4bImiJJKOk,20570 +accelerate/utils/constants.py,sha256=e6Bpf7gSZLFkvfr-1B1841b6lVoKJ5uyyf5kefe0aT4,2566 +accelerate/utils/dataclasses.py,sha256=QSP-gYjXz68s0PAseKwLHRBQUnzcBQwPk80otV4X20k,74253 +accelerate/utils/deepspeed.py,sha256=1JFnz-dY6xP9yHywnX8bzZNq-d-8Cpg5CvVNLZ74b_0,10276 +accelerate/utils/environment.py,sha256=8eVGMCu7xT1y0Hxochnxz_RghDePtWo2TghDlOm5Gf0,10409 +accelerate/utils/fsdp_utils.py,sha256=QURWBtK8D00zppqJko0yeznEovXvnkRLI0NpPPkog1Q,10667 +accelerate/utils/imports.py,sha256=gYj_W3E5V83dYlSqqYE89OAK6JonzwhlcEjsJcOpB3E,12232 +accelerate/utils/launch.py,sha256=hHpcnR0NrSmqaT7AIaeIeXOAJVIhWnWdq3kA1XSnOYs,27459 +accelerate/utils/megatron_lm.py,sha256=IfHrtMiPSwuzh5ri96rTTIcEluuMNuIj3O8Y4jW6Fzk,57124 +accelerate/utils/memory.py,sha256=VxJCU-tMX8uE34GbJnxtDXYPHh4D9p2Y-d6rkGxqSa0,5200 +accelerate/utils/modeling.py,sha256=OfTHPg7oM9-jzYotLZjJKj6TrhCTFV3qOtQAOhKXmzQ,80246 +accelerate/utils/offload.py,sha256=qjaVai81wbkA0YH2WkmOXvZT0BRphygfRV_4Ua4j4U4,7837 +accelerate/utils/operations.py,sha256=zsmRx8mP2eoImPc42pOmBIqaHX7RDugw8AZ_HF3onpg,30610 +accelerate/utils/other.py,sha256=kgON65EhzQN3oQZqzgAOmmNC2vsQkeO77qEuzN7Zv7c,12283 +accelerate/utils/random.py,sha256=t-HsLQRm8etSiLSyONCU9wNhj-0VjDUyDme9p6RxDNU,4881 +accelerate/utils/rich.py,sha256=8JZX_uGMQX-BufdXxJpdne7BWd1KyLHSgbiGxrDMYr8,847 +accelerate/utils/torch_xla.py,sha256=Pq1tuqN0X_pWDVza6YgjfO45uoJdoRVRForLeLQzFus,1908 +accelerate/utils/tqdm.py,sha256=9Ovx4GL8AvjSaBd_OysoUGPW9ZJ3ZBOde6776HMEMOA,1344 +accelerate/utils/transformer_engine.py,sha256=gNPkOv_D1SDLm6nVZtxWIjyA6snxWtAQeBWUZLIErJE,3582 +accelerate/utils/versions.py,sha256=UgmcbjBm--6CIx1ZamSAMjAK_B_2l48LbeaNygqej8M,2149 diff --git a/llm/Lib/site-packages/accelerate-0.29.3.dist-info/REQUESTED b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/REQUESTED new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llm/Lib/site-packages/accelerate-0.29.3.dist-info/WHEEL b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/WHEEL new file mode 100644 index 
0000000000000000000000000000000000000000..becc9a66ea739ba941d48a749e248761cc6e658a --- /dev/null +++ b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.37.1) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/llm/Lib/site-packages/accelerate-0.29.3.dist-info/entry_points.txt b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/entry_points.txt new file mode 100644 index 0000000000000000000000000000000000000000..db491c83cf468945b19a098b171fcbfa7d7db1d1 --- /dev/null +++ b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/entry_points.txt @@ -0,0 +1,5 @@ +[console_scripts] +accelerate = accelerate.commands.accelerate_cli:main +accelerate-config = accelerate.commands.config:main +accelerate-estimate-memory = accelerate.commands.estimate:main +accelerate-launch = accelerate.commands.launch:main diff --git a/llm/Lib/site-packages/accelerate-0.29.3.dist-info/top_level.txt b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..a9368375be0e0e13fdad0eea4b92541bd9e1f594 --- /dev/null +++ b/llm/Lib/site-packages/accelerate-0.29.3.dist-info/top_level.txt @@ -0,0 +1 @@ +accelerate diff --git a/llm/Lib/site-packages/accelerate/__init__.py b/llm/Lib/site-packages/accelerate/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d596762c554a4f07598401d2b1b22c46250dcb55 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/__init__.py @@ -0,0 +1,48 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+__version__ = "0.29.3" + +from .accelerator import Accelerator +from .big_modeling import ( + cpu_offload, + cpu_offload_with_hook, + disk_offload, + dispatch_model, + init_empty_weights, + init_on_device, + load_checkpoint_and_dispatch, +) +from .data_loader import skip_first_batches +from .inference import prepare_pippy +from .launchers import debug_launcher, notebook_launcher +from .state import PartialState +from .utils import ( + AutocastKwargs, + DataLoaderConfiguration, + DeepSpeedPlugin, + DistributedDataParallelKwargs, + DistributedType, + FullyShardedDataParallelPlugin, + GradScalerKwargs, + InitProcessGroupKwargs, + find_executable_batch_size, + infer_auto_device_map, + is_rich_available, + load_checkpoint_in_model, + synchronize_rng_states, +) + + +if is_rich_available(): + from .utils import rich diff --git a/llm/Lib/site-packages/accelerate/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6d6edd9914b7e1f9cba4d32fbb5423e52fe77f2 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/accelerator.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/accelerator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..406505080b964ee4ababb9d223cf495a34740506 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/accelerator.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/big_modeling.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/big_modeling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f08849909f0b722898b7db12bf9b2abd75594b71 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/big_modeling.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/checkpointing.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/checkpointing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8919c57d3a4fc44b3ec89656a5fec504654360b5 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/checkpointing.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/data_loader.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/data_loader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8fb3eddc41c540a1fb6449b0e37e10e457a8bba Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/data_loader.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/hooks.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/hooks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81a0009a08f8a8dad26d0ae8f9ea9947f2013f56 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/hooks.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/inference.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/inference.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..713c9ae11f342a4adcc38fe4b6e6df0390eb1771 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/inference.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/launchers.cpython-311.pyc 
b/llm/Lib/site-packages/accelerate/__pycache__/launchers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c18ea73257b782b47a5ed1d831787418134ed43 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/launchers.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/local_sgd.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/local_sgd.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93d48a3d1c232e3d8a84b726577139d251715054 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/local_sgd.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/logging.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/logging.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d728be1122a2090219cb78c3916030ab790bba6 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/logging.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/memory_utils.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/memory_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d96a3f67f69930a8d49908f1bc305c2291711c94 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/memory_utils.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/optimizer.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/optimizer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78daf90c49b5045cfc152916c4a848e07e9acd4a Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/optimizer.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/scheduler.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/scheduler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..398bfc6fac8a572b092a5f8f3b0e12bcffa7ba51 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/scheduler.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/state.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/state.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d8c67d25c6aaa9b8cd8cdec64b73b40d2cb24b6 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/state.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/__pycache__/tracking.cpython-311.pyc b/llm/Lib/site-packages/accelerate/__pycache__/tracking.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d39b552c2ce546e85b4907a1ac648414e8b522a5 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/__pycache__/tracking.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/accelerator.py b/llm/Lib/site-packages/accelerate/accelerator.py new file mode 100644 index 0000000000000000000000000000000000000000..4786946c6da68e12a29b5f3cd799d40e6af49c4a --- /dev/null +++ b/llm/Lib/site-packages/accelerate/accelerator.py @@ -0,0 +1,3259 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import contextlib +import functools +import json +import math +import os +import re +import shutil +import sys +import warnings +from collections import OrderedDict +from contextlib import contextmanager +from functools import partial +from types import MethodType +from typing import Any, Callable, Union + +import torch +import torch.utils.hooks as hooks + +from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state +from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches +from .hooks import AlignDevicesHook +from .logging import get_logger +from .optimizer import AcceleratedOptimizer +from .scheduler import AcceleratedScheduler +from .state import AcceleratorState, GradientState, PartialState +from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers +from .utils import ( + MODEL_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, + AutocastKwargs, + DataLoaderConfiguration, + DeepSpeedPlugin, + DistributedDataParallelKwargs, + DistributedType, + DynamoBackend, + FP8RecipeKwargs, + FullyShardedDataParallelPlugin, + GradientAccumulationPlugin, + GradScalerKwargs, + InitProcessGroupKwargs, + KwargsHandler, + LoggerType, + MegatronLMPlugin, + PrecisionType, + ProjectConfiguration, + RNGType, + TorchDynamoPlugin, + check_os_kernel, + clean_state_dict_for_safetensors, + compare_versions, + convert_model, + convert_outputs_to_fp32, + extract_model_from_parallel, + gather, + gather_object, + get_mixed_precision_context_manager, + get_pretty_name, + has_transformer_engine_layers, + is_bf16_available, + is_deepspeed_available, + is_fp8_available, + is_ipex_available, + is_megatron_lm_available, + is_mlu_available, + is_msamp_available, + is_npu_available, + is_torch_version, + is_torch_xla_available, + is_xpu_available, + load_fsdp_model, + load_fsdp_optimizer, + pad_across_processes, + parse_choice_from_env, + recursively_apply, + reduce, + release_memory, + save, + save_fsdp_model, + save_fsdp_optimizer, + shard_checkpoint, + wait_for_everyone, +) +from .utils.constants import FSDP_PYTORCH_VERSION +from .utils.modeling import get_state_dict_offloaded_model +from .utils.other import is_compiled_module + + +if is_deepspeed_available(): + from .utils import ( + DeepSpeedEngineWrapper, + DeepSpeedOptimizerWrapper, + DeepSpeedSchedulerWrapper, + DummyOptim, + DummyScheduler, + ) + +if is_fp8_available(): + import transformer_engine.common.recipe as te_recipe + from transformer_engine.pytorch import fp8_autocast + + +if is_megatron_lm_available(): + from .utils import ( + MegatronEngine, + MegatronLMDummyDataLoader, + MegatronLMDummyScheduler, + MegatronLMOptimizerWrapper, + MegatronLMSchedulerWrapper, + megatron_lm_initialize, + megatron_lm_prepare_data_loader, + megatron_lm_prepare_model, + megatron_lm_prepare_optimizer, + megatron_lm_prepare_scheduler, + ) + +from torch.distributed.algorithms.join import Join + + +if is_torch_xla_available(): + import torch_xla.amp as xamp + import torch_xla.core.xla_model as xm 
+ import torch_xla.distributed.xla_multiprocessing as xmp + + +if is_npu_available(check_device=False): + import torch_npu # noqa: F401 + + +try: + from torch.optim.lr_scheduler import LRScheduler +except ImportError: + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +logger = get_logger(__name__) + +# Sentinel values for defaults +_split_batches = object() +_dispatch_batches = object() +_even_batches = object() +_use_seedable_sampler = object() + + +class Accelerator: + """ + Creates an instance of an accelerator for distributed training (on multi-GPU, TPU) or mixed precision training. + + Args: + device_placement (`bool`, *optional*, defaults to `True`): + Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model, + etc...). + mixed_precision (`str`, *optional*): + Whether or not to use mixed precision training. Choose from 'no','fp16','bf16 or 'fp8'. Will default to the + value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the + accelerate config of the current system or the flag passed with the `accelerate.launch` command. 'fp8' + requires the installation of transformers-engine. + gradient_accumulation_steps (`int`, *optional*, default to 1): + The number of steps that should pass before gradients are accumulated. A number > 1 should be combined with + `Accelerator.accumulate`. If not passed, will default to the value in the environment variable + `ACCELERATE_GRADIENT_ACCUMULATION_STEPS`. Can also be configured through a `GradientAccumulationPlugin`. + cpu (`bool`, *optional*): + Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force + the execution on one process only. + dataloader_config (`DataLoaderConfiguration`, *optional*): + A configuration for how the dataloaders should be handled in distributed scenarios. + deepspeed_plugin ([`~utils.DeepSpeedPlugin`], *optional*): + Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured + directly using *accelerate config* + fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*): + Tweak your FSDP related args using this argument. This argument is optional and can be configured directly + using *accelerate config* + megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*): + Tweak your MegatronLM related args using this argument. This argument is optional and can be configured + directly using *accelerate config* + rng_types (list of `str` or [`~utils.RNGType`]): + The list of random number generators to synchronize at the beginning of each iteration in your prepared + dataloaders. Should be one or several of: + + - `"torch"`: the base torch random number generator + - `"cuda"`: the CUDA random number generator (GPU only) + - `"xla"`: the XLA random number generator (TPU only) + - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your + dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type. + + Will default to `["torch"]` for PyTorch versions <=1.5.1 and `["generator"]` for PyTorch versions >= 1.6. + log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*): + A list of loggers to be setup for experiment tracking. 
Should be one or several of: + + - `"all"` + - `"tensorboard"` + - `"wandb"` + - `"comet_ml"` + If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can + also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`. + project_config ([`~utils.ProjectConfiguration`], *optional*): + A configuration for how saving the state can be handled. + project_dir (`str`, `os.PathLike`, *optional*): + A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved + checkpoints. + step_scheduler_with_optimizer (`bool`, *optional`, defaults to `True`): + Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only + done under certain circumstances (at the end of each epoch, for instance). + kwargs_handlers (list of [`~utils.KwargsHandler`], *optional*) + A list of [`~utils.KwargsHandler`] to customize how the objects related to distributed training or mixed + precision are created. See [kwargs](kwargs) for more information. + dynamo_backend (`str` or [`~utils.DynamoBackend`], *optional*, defaults to `"no"`): + Set to one of the possible dynamo backends to optimize your training with torch dynamo. + gradient_accumulation_plugin ([`~utils.GradientAccumulationPlugin`], *optional*): + A configuration for how gradient accumulation should be handled, if more tweaking than just the + `gradient_accumulation_steps` is needed. + + **Available attributes:** + + - **device** (`torch.device`) -- The device to use. + - **distributed_type** ([`~utils.DistributedType`]) -- The distributed training configuration. + - **local_process_index** (`int`) -- The process index on the current machine. + - **mixed_precision** (`str`) -- The configured mixed precision mode. + - **num_processes** (`int`) -- The total number of processes used for training. + - **optimizer_step_was_skipped** (`bool`) -- Whether or not the optimizer update was skipped (because of + gradient overflow in mixed precision), in which + case the learning rate should not be changed. + - **process_index** (`int`) -- The overall index of the current process among all processes. + - **state** ([`~state.AcceleratorState`]) -- The distributed setup state. + - **sync_gradients** (`bool`) -- Whether the gradients are currently being synced across all processes. + - **use_distributed** (`bool`) -- Whether the current configuration is for distributed training. 
+ """ + + def __init__( + self, + device_placement: bool = True, + split_batches: bool = _split_batches, + mixed_precision: PrecisionType | str | None = None, + gradient_accumulation_steps: int = 1, + cpu: bool = False, + dataloader_config: DataLoaderConfiguration | None = None, + deepspeed_plugin: DeepSpeedPlugin | None = None, + fsdp_plugin: FullyShardedDataParallelPlugin | None = None, + megatron_lm_plugin: MegatronLMPlugin | None = None, + rng_types: list[str | RNGType] | None = None, + log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None, + project_dir: str | os.PathLike | None = None, + project_config: ProjectConfiguration | None = None, + gradient_accumulation_plugin: GradientAccumulationPlugin | None = None, + dispatch_batches: bool | None = _dispatch_batches, + even_batches: bool = _even_batches, + use_seedable_sampler: bool = _use_seedable_sampler, + step_scheduler_with_optimizer: bool = True, + kwargs_handlers: list[KwargsHandler] | None = None, + dynamo_backend: DynamoBackend | str | None = None, + ): + self.trackers = [] + if project_config is not None: + self.project_configuration = project_config + else: + self.project_configuration = ProjectConfiguration(project_dir=project_dir) + if project_dir is not None and self.project_dir is None: + self.project_configuration.set_directories(project_dir) + if mixed_precision is not None: + mixed_precision = str(mixed_precision) + if mixed_precision not in PrecisionType: + raise ValueError( + f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}" + ) + + dynamo_plugin = TorchDynamoPlugin() if dynamo_backend is None else TorchDynamoPlugin(backend=dynamo_backend) + + if deepspeed_plugin is None: # init from env variables + deepspeed_plugin = ( + DeepSpeedPlugin() if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" else None + ) + else: + assert isinstance( + deepspeed_plugin, DeepSpeedPlugin + ), "`deepspeed_plugin` must be an `accelerate.utils.DeepSpeedPlugin` object." + os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" # use DeepSpeed if plugin is provided + if deepspeed_plugin: + if not is_deepspeed_available(): + raise ImportError("DeepSpeed is not installed => run `pip install deepspeed` or build it from source.") + if is_mlu_available(): + if compare_versions("deepspeed-mlu", "<", "0.10.1"): + raise ImportError("DeepSpeed MLU version must be >= 0.10.1. Please update DeepSpeed MLU.") + elif compare_versions("deepspeed", "<", "0.9.3"): + raise ImportError("DeepSpeed version must be >= 0.9.3. 
Please update DeepSpeed.") + + mixed_precision = ( + os.environ.get("ACCELERATE_MIXED_PRECISION", "no") if mixed_precision is None else mixed_precision + ) + deepspeed_plugin.set_mixed_precision(mixed_precision) + deepspeed_plugin.set_deepspeed_weakref() + + if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" or isinstance( + fsdp_plugin, FullyShardedDataParallelPlugin + ): + if is_torch_version("<", FSDP_PYTORCH_VERSION): + raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}") + + if fsdp_plugin is None: # init from env variables + fsdp_plugin = ( + FullyShardedDataParallelPlugin() if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true" else None + ) + else: + if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin): + raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.") + os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided + + if megatron_lm_plugin is None: # init from env variables + megatron_lm_plugin = ( + MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" else None + ) + else: + if not isinstance(megatron_lm_plugin, MegatronLMPlugin): + raise TypeError("`megatron_lm_plugin` must be a MegatronLMPlugin object.") + os.environ["ACCELERATE_USE_MEGATRON_LM"] = "true" # use MegatronLM if plugin is provided + + if megatron_lm_plugin: + if not is_megatron_lm_available(): + raise ImportError("Megatron is not installed. please build it from source.") + + # Kwargs handlers + self.ddp_handler = None + self.scaler_handler = None + self.init_handler = None + self.fp8_recipe_handler = None + self.autocast_handler = None + if kwargs_handlers is not None: + for handler in kwargs_handlers: + assert isinstance( + handler, KwargsHandler + ), f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`." 
+ if isinstance(handler, DistributedDataParallelKwargs): + if self.ddp_handler is not None: + raise ValueError("You can only pass one `DistributedDataParallelKwargs` in `kwargs_handler`.") + else: + self.ddp_handler = handler + elif isinstance(handler, GradScalerKwargs): + if self.scaler_handler is not None: + raise ValueError("You can only pass one `GradScalerKwargs` in `kwargs_handler`.") + else: + self.scaler_handler = handler + elif isinstance(handler, InitProcessGroupKwargs): + if self.init_handler is not None: + raise ValueError("You can only pass one `InitProcessGroupKwargs` in `kwargs_handler`.") + else: + self.init_handler = handler + elif isinstance(handler, FP8RecipeKwargs): + if self.fp8_recipe_handler is not None: + raise ValueError("You can only pass one `FP8RecipeKwargs` in `kwargs_handler`.") + else: + self.fp8_recipe_handler = handler + elif isinstance(handler, AutocastKwargs): + if self.autocast_handler is not None: + raise ValueError("You can only pass one `AutocastKwargs` in `kwargs_handler`.") + else: + self.autocast_handler = handler + + kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {} + self.state = AcceleratorState( + mixed_precision=mixed_precision, + cpu=cpu, + dynamo_plugin=dynamo_plugin, + deepspeed_plugin=deepspeed_plugin, + fsdp_plugin=fsdp_plugin, + megatron_lm_plugin=megatron_lm_plugin, + _from_accelerator=True, + **kwargs, + ) + + if self.fp8_recipe_handler is None and self.state.mixed_precision == "fp8": + self.fp8_recipe_handler = FP8RecipeKwargs(backend="MSAMP" if is_msamp_available() else "TE") + + trackers = filter_trackers(log_with, self.logging_dir) + if len(trackers) < 1 and log_with is not None: + warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.") + self.log_with = trackers + + if ( + (mixed_precision != "bf16") + and getattr(self.state, "downcast_bfloat", False) + and (self.state.distributedType != DistributedType.XLA) + ): + raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU") + + if gradient_accumulation_plugin is not None: + if gradient_accumulation_steps != 1: + raise ValueError( + "You can only pass one of `gradient_accumulation_steps` and `gradient_accumulation_plugin`. Please only pass in the created `GradientAccumulationPlugin` object." 
+ ) + else: + gradient_accumulation_steps = int( + parse_choice_from_env("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", gradient_accumulation_steps) + ) + gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=gradient_accumulation_steps) + self.gradient_state = GradientState( + gradient_accumulation_plugin=gradient_accumulation_plugin, + ) + + self.device_placement = device_placement + if dataloader_config is None: + dataloader_config = DataLoaderConfiguration() + self.dataloader_config = dataloader_config + # Deal with deprecated args + # TODO: Remove in v1.0.0 + deprecated_dl_args = {} + if dispatch_batches is not _dispatch_batches: + deprecated_dl_args["dispatch_batches"] = dispatch_batches + self.dataloader_config.dispatch_batches = dispatch_batches + if split_batches is not _split_batches: + deprecated_dl_args["split_batches"] = split_batches + self.dataloader_config.split_batches = split_batches + if even_batches is not _even_batches: + deprecated_dl_args["even_batches"] = even_batches + self.dataloader_config.even_batches = even_batches + if use_seedable_sampler is not _use_seedable_sampler: + deprecated_dl_args["use_seedable_sampler"] = use_seedable_sampler + self.dataloader_config.use_seedable_sampler = use_seedable_sampler + if len(deprecated_dl_args) > 0: + values = ", ".join([f"{k}={v}" for k, v in deprecated_dl_args.items()]) + warnings.warn( + f"Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: {deprecated_dl_args.keys()}. " + "Please pass an `accelerate.DataLoaderConfiguration` instead: \n" + f"dataloader_config = DataLoaderConfiguration({values})", + FutureWarning, + ) + self.step_scheduler_with_optimizer = step_scheduler_with_optimizer + + # Mixed precision attributes + self.scaler = None + self.native_amp = False + if ( + self.state.mixed_precision == "fp16" + and self.device.type != "cpu" + and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM) + ): + self.native_amp = True + if self.device.type not in ("xpu", "cuda", "mps", "npu", "xla", "mlu") or is_torch_xla_available( + check_is_tpu=True + ): + raise ValueError(f"fp16 mixed precision requires a GPU (not {self.device.type!r}).") + kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {} + if self.distributed_type == DistributedType.FSDP: + from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler + + self.scaler = ShardedGradScaler(**kwargs) + elif is_torch_xla_available(check_is_gpu=True): + self.scaler = xamp.GradScaler(**kwargs) + elif is_mlu_available(): + self.scaler = torch.mlu.amp.GradScaler(**kwargs) + elif is_npu_available(): + self.scaler = torch.npu.amp.GradScaler(**kwargs) + else: + self.scaler = torch.cuda.amp.GradScaler(**kwargs) + + elif self.state.mixed_precision == "bf16" and self.distributed_type not in ( + DistributedType.DEEPSPEED, + DistributedType.MEGATRON_LM, + ): + if self.device.type in ["cpu", "xpu"]: + self.native_amp = True + else: + self.native_amp = is_bf16_available(True) + if mixed_precision == "bf16" and not self.native_amp and not is_torch_xla_available(): + raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.") + + # Start of internal step tracking + self.step = 0 + + # Internal references to the training objects + self._optimizers = [] + self._models = [] + self._schedulers = [] + self._dataloaders = [] + self._custom_objects = [] + + # Hooks + self._load_model_state_pre_hook = OrderedDict() + 
self._save_model_state_pre_hook = OrderedDict() + + # RNG Types + self.rng_types = rng_types + if self.rng_types is None: + self.rng_types = ["generator"] + + # Set a flag tensor for early stopping and other breakpoints + self.flag_tensor = None + + check_os_kernel() + + @property + def use_distributed(self): + """ + Whether the Accelerator is configured for distributed training + """ + return self.state.use_distributed + + @property + def distributed_type(self): + return self.state.distributed_type + + @property + def num_processes(self): + return self.state.num_processes + + @property + def process_index(self): + return self.state.process_index + + @property + def local_process_index(self): + return self.state.local_process_index + + @property + def device(self): + return self.state.device + + @property + def split_batches(self): + return self.dataloader_config.split_batches + + @property + def dispatch_batches(self): + return self.dataloader_config.dispatch_batches + + @property + def even_batches(self): + return self.dataloader_config.even_batches + + @even_batches.setter + def even_batches(self, value: bool): + self.dataloader_config.even_batches = value + + @property + def use_seedable_sampler(self): + return self.dataloader_config.use_seedable_sampler + + @property + def project_dir(self): + return self.project_configuration.project_dir + + @property + def logging_dir(self): + return self.project_configuration.logging_dir + + @property + def save_iteration(self): + return self.project_configuration.iteration + + @property + def is_main_process(self): + """True for one process only.""" + return self.state.is_main_process + + @property + def is_local_main_process(self): + """True for one process per server.""" + return self.state.is_local_main_process + + @property + def use_fp16(self): + warnings.warn( + "The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use " + "`Accelerator.mixed_precision == 'fp16'` instead.", + FutureWarning, + ) + return self.mixed_precision != "no" + + @property + def is_last_process(self): + return self.process_index == self.num_processes - 1 + + @property + def mixed_precision(self): + return self.state.mixed_precision + + @contextmanager + def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False): + """ + Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing + distributed inference, such as with different prompts. + + Note that when using a `dict`, all keys need to have the same number of elements. + + Args: + inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`): + The input to split between processes. + apply_padding (`bool`, `optional`, defaults to `False`): + Whether to apply padding by repeating the last element of the input so that all processes have the same + number of elements. Useful when trying to perform actions such as `Accelerator.gather()` on the outputs + or passing in less inputs than there are processes. If so, just remember to drop the padded elements + afterwards. 
+ + Example: + + ```python + # Assume there are two processes + from accelerate import Accelerator + + accelerator = Accelerator() + with accelerator.split_between_processes(["A", "B", "C"]) as inputs: + print(inputs) + # Process 0 + ["A", "B"] + # Process 1 + ["C"] + + with accelerator.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs: + print(inputs) + # Process 0 + ["A", "B"] + # Process 1 + ["C", "C"] + ``` + """ + with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs: + yield inputs + + def on_main_process(self, function: Callable[..., Any] = None): + """ + A decorator that will run the decorated function on the main process only. Can also be called using the + `PartialState` class. + + Args: + function (`Callable`): The function to decorate. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + + + >>> @accelerator.on_main_process + ... def print_something(): + ... print("This will be printed by process 0 only.") + + + >>> print_something() + "This will be printed by process 0 only" + ``` + """ + # For times when the `Accelerator` object itself utilizes this decorator. + if function is None: + if "Accelerator." in self.__qualname__: + function = self + else: + raise ValueError( + "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object." + ) + + def _inner(*args, **kwargs): + return PartialState().on_main_process(function)(*args, **kwargs) + + return _inner + + def on_local_main_process(self, function: Callable[..., Any] = None): + """ + A decorator that will run the decorated function on the local main process only. Can also be called using the + `PartialState` class. + + Args: + function (`Callable`): The function to decorate. + + Example: + ```python + # Assume we have 2 servers with 4 processes each. + from accelerate import Accelerator + + accelerator = Accelerator() + + + @accelerator.on_local_main_process + def print_something(): + print("This will be printed by process 0 only on each server.") + + + print_something() + # On server 1: + "This will be printed by process 0 only" + # On server 2: + "This will be printed by process 0 only" + ``` + """ + # For times when the `Accelerator` object itself utilizes this decorator. + if function is None: + if "Accelerator." in self.__qualname__: + function = self + else: + raise ValueError( + "The `on_local_main_process` decorator must be called with a function on an instantiated `Accelerator` object." + ) + + def _inner(*args, **kwargs): + return PartialState().on_local_main_process(function)(*args, **kwargs) + + return _inner + + def on_last_process(self, function: Callable[..., Any]): + """ + A decorator that will run the decorated function on the last process only. Can also be called using the + `PartialState` class. + + Args: + function (`Callable`): The function to decorate. + + Example: + ```python + # Assume we have 4 processes. + from accelerate import Accelerator + + accelerator = Accelerator() + + + @accelerator.on_last_process + def print_something(): + print(f"Printed on process {accelerator.process_index}") + + + print_something() + "Printed on process 3" + ``` + """ + # For times when the `Accelerator` object itself utilizes this decorator. + if function is None: + if "Accelerator." in self.__qualname__: + function = self + else: + raise ValueError( + "The `on_last_process` decorator must be called with a function on an instantiated `Accelerator` object." 
+ ) + + def _inner(*args, **kwargs): + return PartialState().on_last_process(function)(*args, **kwargs) + + return _inner + + def on_process(self, function: Callable[..., Any] = None, process_index: int = None): + """ + A decorator that will run the decorated function on a given process index only. Can also be called using the + `PartialState` class. + + Args: + function (`Callable`, `optional`): + The function to decorate. + process_index (`int`, `optional`): + The index of the process on which to run the function. + + Example: + ```python + # Assume we have 4 processes. + from accelerate import Accelerator + + accelerator = Accelerator() + + + @accelerator.on_process(process_index=2) + def print_something(): + print(f"Printed on process {accelerator.process_index}") + + + print_something() + "Printed on process 2" + ``` + """ + # Initial construction of the decorator. + if (self is not None) and (process_index is not None) and (function is None): + return partial(self.on_process, process_index=process_index) + # For times when the `Accelerator` object itself utilizes this decorator. + if function is None: + if "Accelerator." in self.__qualname__: + function = self + else: + raise ValueError( + "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object." + ) + + def _inner(*args, **kwargs): + return PartialState().on_process(function, process_index)(*args, **kwargs) + + return _inner + + def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None): + """ + A decorator that will run the decorated function on a given local process index only. Can also be called using + the `PartialState` class. + + Args: + function (`Callable`, *optional*): + The function to decorate. + local_process_index (`int`, *optional*): + The index of the local process on which to run the function. + + Example: + ```python + # Assume we have 2 servers with 4 processes each. + from accelerate import Accelerator + + accelerator = Accelerator() + + + @accelerator.on_local_process(local_process_index=2) + def print_something(): + print(f"Printed on process {accelerator.local_process_index}") + + + print_something() + # On server 1: + "Printed on process 2" + # On server 2: + "Printed on process 2" + ``` + """ + # Initial construction of the decorator. + if (self is not None) and (local_process_index is not None) and (function is None): + return partial(self.on_local_process, local_process_index=local_process_index) + # For times when the `Accelerator` object itself utilizes this decorator. + if function is None: + if "Accelerator." in self.__qualname__: + function = self + else: + raise ValueError( + "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object." + ) + + def _inner(*args, **kwargs): + return PartialState().on_local_process(function, local_process_index)(*args, **kwargs) + + return _inner + + @contextmanager + def main_process_first(self): + """ + Lets the main process go first inside a with block. + + The other processes will enter the with block after the main process exits. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> with accelerator.main_process_first(): + ... # This will be printed first by process 0 then in a seemingly + ... # random order by the other processes. + ... 
print(f"This will be printed by process {accelerator.process_index}") + ``` + """ + with self.state.main_process_first(): + yield + + @contextmanager + def local_main_process_first(self): + """ + Lets the local main process go inside a with block. + + The other processes will enter the with block after the main process exits. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> with accelerator.local_main_process_first(): + ... # This will be printed first by local process 0 then in a seemingly + ... # random order by the other processes. + ... print(f"This will be printed by process {accelerator.local_process_index}") + ``` + """ + with self.state.local_main_process_first(): + yield + + @contextmanager + def no_sync(self, model): + """ + A context manager to disable gradient synchronizations across DDP processes by calling + `torch.nn.parallel.DistributedDataParallel.no_sync`. + + If `model` is not in DDP, this context manager does nothing + + Args: + model (`torch.nn.Module`): + PyTorch Module that was prepared with `Accelerator.prepare` + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer) + >>> input_a = next(iter(dataloader)) + >>> input_b = next(iter(dataloader)) + + >>> with accelerator.no_sync(): + ... outputs = model(input_a) + ... loss = loss_func(outputs) + ... accelerator.backward(loss) + ... # No synchronization across processes, only accumulate gradients + >>> outputs = model(input_b) + >>> accelerator.backward(loss) + >>> # Synchronization across all processes + >>> optimizer.step() + >>> optimizer.zero_grad() + ``` + """ + context = contextlib.nullcontext + if self.use_distributed: + context = getattr(model, "no_sync", context) + + with context(): + yield + + @staticmethod + @contextmanager + def trigger_sync_in_backward(model): + """Trigger the sync of the gradients in the next backward pass of the model after multiple forward passes under + `Accelerator.no_sync` (only applicable in multi-GPU scenarios). + + If the script is not launched in distributed mode, this context manager does nothing. + + Args: + model (`torch.nn.Module`): + The model for which to trigger the gradient synchronization. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer) + + >>> with accelerator.no_sync(): + ... loss_a = loss_func(model(input_a)) # first forward pass + ... loss_b = loss_func(model(input_b)) # second forward pass + >>> accelerator.backward(loss_a) # No synchronization across processes, only accumulate gradients + >>> with accelerator.trigger_sync_in_backward(model): + ... accelerator.backward(loss_b) # Synchronization across all processes + >>> optimizer.step() + >>> optimizer.zero_grad() + ``` + """ + if not isinstance(model, torch.nn.parallel.DistributedDataParallel): + yield + return + + old_require_backward_grad_sync = model.require_backward_grad_sync + old_require_forward_param_sync = model.require_forward_param_sync + + # EXPERIMENTAL: This will force grad sync during `backward()`, but it is unknown if it breaks other DDP features. 
+ # https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/torch/nn/parallel/distributed.py#L1453-L1466 + model.require_backward_grad_sync = True + model.require_forward_param_sync = True + # https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/torch/csrc/distributed/c10d/reducer.cpp#L1371-L1402 + model.reducer.prepare_for_backward([]) + try: + yield + finally: + model.require_backward_grad_sync = old_require_backward_grad_sync + model.require_forward_param_sync = old_require_forward_param_sync + + def _do_sync(self, force: bool = False): + "Sets the right `sync_gradients` context and either resets or increases `self.step`" + if self.gradient_state.sync_with_dataloader and self.gradient_state.end_of_dataloader: + self.step = 0 + self.gradient_state._set_sync_gradients(True) + else: + self.step += 1 + self.gradient_state._set_sync_gradients(force or ((self.step % self.gradient_state.num_steps) == 0)) + + @property + def sync_gradients(self): + return self.gradient_state.sync_gradients + + @sync_gradients.setter + def sync_gradients(self, sync_gradients): + self.gradient_state.sync_gradients = sync_gradients + + @property + def gradient_accumulation_steps(self): + return self.gradient_state.num_steps + + @gradient_accumulation_steps.setter + def gradient_accumulation_steps(self, gradient_accumulation_steps): + self.gradient_state.plugin_kwargs.update({"num_steps": gradient_accumulation_steps}) + + @contextmanager + def accumulate(self, *models): + """ + A context manager that will lightly wrap around and perform gradient accumulation automatically + + Args: + *models (list of `torch.nn.Module`): + PyTorch Modules that were prepared with `Accelerator.prepare`. Models passed to `accumulate()` will + skip gradient syncing during backward pass in distributed training + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(gradient_accumulation_steps=1) + >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler) + + >>> for input, output in dataloader: + ... with accelerator.accumulate(model): + ... outputs = model(input) + ... loss = loss_func(outputs) + ... loss.backward() + ... optimizer.step() + ... scheduler.step() + ... optimizer.zero_grad() + ``` + """ + # sync_each_batch=True will guarantee below that self.sync_gradients=True, therefore + # resulting in the nullcontext always being selected. + self._do_sync(force=self.gradient_state.plugin_kwargs.get("sync_each_batch", False)) + with contextlib.ExitStack() as cm_stack: + for m in models: + cm_stack.enter_context(contextlib.nullcontext() if self.sync_gradients else self.no_sync(m)) + yield + + @contextmanager + def join_uneven_inputs(self, joinables, even_batches=None): + """ + A context manager that facilitates distributed training or evaluation on uneven inputs, which acts as a wrapper + around `torch.distributed.algorithms.join`. This is useful when the total batch size does not evenly divide the + length of the dataset. + + Args: + joinables (`list[torch.distributed.algorithms.Joinable]`): + A list of models or optimizers that subclass `torch.distributed.algorithms.Joinable`. Most commonly, a + PyTorch Module that was prepared with `Accelerator.prepare` for DistributedDataParallel training. + even_batches (`bool`, *optional*) + If set, this will override the value of `even_batches` set in the `Accelerator`. If it is not provided, + the default `Accelerator` value wil be used. 
+ + + + `join_uneven_inputs` is only supported for Distributed Data Parallel training on multiple GPUs. For any other + configuration, this method will have no effect. + + + + + + Overidding `even_batches` will not affect iterable-style data loaders. + + + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(even_batches=True) + >>> ddp_model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader) + + >>> with accelerator.join_uneven_inputs([ddp_model], even_batches=False): + ... for input, output in dataloader: + ... outputs = model(input) + ... loss = loss_func(outputs) + ... loss.backward() + ... optimizer.step() + ... optimizer.zero_grad() + ``` + """ + if self.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_XPU, + ): + dl_even_batches_values = [] + + if even_batches is not None: + iterable_dl_seen = False + # override value in batch sampler for map-style datasets + for dl_idx, dl in enumerate(self._dataloaders): + if isinstance(dl, DataLoaderDispatcher): + iterable_dl_seen = True + continue + dl_even_batches_values.append((dl_idx, dl.batch_sampler.even_batches)) + dl.batch_sampler.even_batches = even_batches + + if iterable_dl_seen: + warnings.warn( + "Overridding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable" + ) + else: + even_batches = self.even_batches + + enable_join = False if even_batches else True + try: + with Join(joinables, enable=enable_join, throw_on_early_termination=False): + yield + finally: + # reset any batch samplers that have been modified + for dl_idx, even_batches_value in dl_even_batches_values: + self._dataloaders[dl_idx].batch_sampler.even_batches = even_batches_value + else: + # Even when disabled, Join expects models to subclass Joinable, so skip entirely for single process runs + if self.distributed_type != DistributedType.NO: + warnings.warn( + "Joining uneven inputs is only supported for multi-GPU training, as a result `join_uneven_inputs` will have no effect." + ) + + with contextlib.nullcontext(joinables): + yield + + def print(self, *args, **kwargs): + """ + Drop in replacement of `print()` to only print once per server. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> accelerator.print("Hello world!") + ``` + """ + self.state.print(*args, **kwargs) + + def _prepare_one(self, obj, first_pass=False, device_placement=None): + # First pass of preparation: DataLoader, model, optimizer + if first_pass: + if isinstance(obj, torch.utils.data.DataLoader): + return self.prepare_data_loader(obj, device_placement=device_placement) + elif isinstance(obj, torch.nn.Module): + return self.prepare_model(obj, device_placement=device_placement) + elif isinstance(obj, torch.optim.Optimizer): + optimizer = self.prepare_optimizer(obj, device_placement=device_placement) + return optimizer + # Second pass of preparation: LR scheduler (which need the full list of optimizers) + elif isinstance(obj, LRScheduler): + scheduler = self.prepare_scheduler(obj) + return scheduler + # Return the unprocessed object if previous criteria was not met + return obj + + def prepare(self, *args, device_placement=None): + """ + Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the same + order. 
+ + Args: + *args (list of objects): + Any of the following type of objects: + + - `torch.utils.data.DataLoader`: PyTorch Dataloader + - `torch.nn.Module`: PyTorch Module + - `torch.optim.Optimizer`: PyTorch Optimizer + - `torch.optim.lr_scheduler.LRScheduler`: PyTorch LR Scheduler + + device_placement (`list[bool]`, *optional*): + Used to customize whether automatic device placement should be performed for each object passed. Needs + to be a list of the same length as `args`. Not compatible with DeepSpeed or FSDP. + + + + You don't need to prepare a model if you only use it for inference without any kind of mixed precision + + + + Examples: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> # Assume a model, optimizer, data_loader and scheduler are defined + >>> model, optimizer, data_loader, scheduler = accelerator.prepare(model, optimizer, data_loader, scheduler) + ``` + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> # Assume a model, optimizer, data_loader and scheduler are defined + >>> device_placement = [True, True, False, False] + >>> # Will place the first to items passed in automatically to the right device but not the last two. + >>> model, optimizer, data_loader, scheduler = accelerator.prepare( + ... model, optimizer, data_loader, scheduler, device_placement=device_placement + ... ) + ``` + """ + if device_placement is None: + device_placement = [None for _ in args] + elif self.distributed_type in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM): + raise ValueError("You can't customize device placements with DeepSpeed or Megatron-LM.") + elif len(device_placement) != len(args): + raise ValueError( + f"`device_placement` should be a list with {len(args)} elements (the number of objects passed)." + ) + + for obj in args: + # TODO: Look at enabling native TP training directly with a proper config + if ( + isinstance(obj, torch.nn.Module) + and self.verify_device_map(obj) + and self.distributed_type != DistributedType.NO + and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true" + ): + raise ValueError( + "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode." + " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`." + ) + + if self.distributed_type == DistributedType.DEEPSPEED: + model_count = 0 + for obj in args: + if isinstance(obj, torch.nn.Module): + model_count += 1 + if model_count > 1: + raise AssertionError( + "You can't use same `Accelerator()` instance with multiple models when using DeepSpeed" + ) + + # On TPUs, putting the model on the XLA device will create new parameters, so the corresponding optimizer will + # have parameters disconnected from the model (so no training :-( ). + # If the model and optimizer have parameters on different devices we raise an error. + if self.distributed_type == DistributedType.XLA: + model_device, optimizer_device = self._get_devices() + if model_device is not None and optimizer_device is not None and model_device != optimizer_device: + raise ValueError( + "The model and the optimizer parameters are not on the same device, which probably means you " + "created an optimizer around your model **before** putting on the device. 
Make sure the line "
+                    "model.to(device) is before the optimizer creation in your script or remove it entirely and use "
+                    "the flag default value for `device_placement` in your `Accelerator` to let it handle that "
+                    "part for you."
+                )
+
+        # If we're dealing with XLA device placement (or TE fp8 conversion), grab the old model parameters
+        # now so the optimizers can be re-pointed at the new parameters after preparation (steps 1-4 below).
+        tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
+        if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
+            # 1. grabbing old model parameters
+            old_named_params = self._get_named_parameters(*args)
+
+        if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
+            if self.device.type == "cpu" and self.state.use_ipex:
+                args = self._prepare_ipex(*args)
+            elif self.device.type == "xpu" and is_xpu_available():
+                args = self._prepare_ipex(*args)
+        if self.distributed_type == DistributedType.DEEPSPEED:
+            result = self._prepare_deepspeed(*args)
+        elif self.distributed_type == DistributedType.MEGATRON_LM:
+            result = self._prepare_megatron_lm(*args)
+        else:
+            if self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "MSAMP":
+                args = self._prepare_msamp(*args)
+                # MS-AMP will handle the device placement
+                device_placement = [False for _ in args]
+            result = tuple(
+                self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
+            )
+            result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
+
+        if tpu_should_fix_optimizer or (self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE"):
+            # 2. grabbing new model parameters
+            new_named_params = self._get_named_parameters(*result)
+            # 3. building a map from the first to the second
+            mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
+            # 4. using that map to update the parameters of the optimizer
+            for obj in result:
+                if isinstance(obj, torch.optim.Optimizer):
+                    obj._switch_parameters(mapping)
+
+        for item in result:
+            if any(
+                item in container
+                for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
+            ):
+                item._is_accelerate_prepared = True
+
+        return result if len(result) > 1 else result[0]
+
+    def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
+        """
+        Prepares a PyTorch model for training in any distributed setup. It is recommended to use
+        [`Accelerator.prepare`] instead.
+
+        Args:
+            model (`torch.nn.Module`):
+                A PyTorch model to prepare. You don't need to prepare a model if it is used only for inference without
+                any kind of mixed precision
+            device_placement (`bool`, *optional*):
+                Whether or not to place the model on the proper device. Will default to `self.device_placement`.
+            evaluation_mode (`bool`, *optional*, defaults to `False`):
+                Whether or not to set the model for evaluation only, by just applying mixed precision and
+                `torch.compile` (if configured in the `Accelerator` object).
+ + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> # Assume a model is defined + >>> model = accelerator.prepare_model(model) + ``` + """ + if device_placement is None: + device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP + self._models.append(model) + + # TODO: Look at enabling native TP training directly with a proper config + if ( + self.verify_device_map(model) + and self.distributed_type != DistributedType.NO + and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true" + ): + raise ValueError( + "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode." + " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`." + ) + + if self.native_amp: + model._original_forward = model.forward + model_forward_func = model.forward.__func__ if hasattr(model.forward, "__func__") else model.forward + autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler) + new_forward = autocast_context(model_forward_func) + if hasattr(model.forward, "__func__"): + model.forward = MethodType(new_forward, model) + model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) + else: + model.forward = convert_outputs_to_fp32(new_forward) + elif self.mixed_precision == "fp8" and self.fp8_recipe_handler.backend == "TE": + if not has_transformer_engine_layers(model): + with torch.no_grad(): + convert_model(model) + model._converted_to_transformer_engine = True + model._original_forward = model.forward + + kwargs = self.fp8_recipe_handler.to_kwargs() if self.fp8_recipe_handler is not None else {} + if "fp8_format" in kwargs: + kwargs["fp8_format"] = getattr(te_recipe.Format, kwargs["fp8_format"]) + fp8_recipe = te_recipe.DelayedScaling(**kwargs) + model.forward = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)(model.forward) + + if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr( + model, "hf_device_map", False + ): + model_devices = set(model.hf_device_map.values()) + if len(model_devices) > 1 and self.distributed_type != DistributedType.NO: + raise ValueError( + "You can't train a model that has been loaded in 8-bit precision on multiple devices in any distributed mode." + " In order to use 8-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism." + " Therefore you should not specify that you are under any distributed regime in your accelerate config." + ) + current_device = list(model_devices)[0] + current_device_index = current_device.index if isinstance(current_device, torch.device) else current_device + + if torch.device(current_device_index) != self.device: + # if on the first device (GPU 0) we don't care + if (self.device.index is not None) or (current_device_index != 0): + raise ValueError( + "You can't train a model that has been loaded in 8-bit precision on a different device than the one " + "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device() or device_map={'':torch.xpu.current_device()}" + ) + + if "cpu" in model_devices or "disk" in model_devices: + raise ValueError( + "You can't train a model that has been loaded in 8-bit precision with CPU or disk offload." 
+ ) + elif device_placement and not self.verify_device_map(model): + model = model.to(self.device) + if not evaluation_mode: + if self.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + ): + if any(p.requires_grad for p in model.parameters()): + kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {} + # TODO: Look at enabling native TP training directly with a proper config + if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true": + device_ids, output_device = [self.local_process_index], self.local_process_index + else: + device_ids, output_device = None, None + + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=device_ids, output_device=output_device, **kwargs + ) + elif self.distributed_type == DistributedType.FSDP: + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, + # don't wrap it again + # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it + # is a FSDP model, don't wrap it again + is_type_fsdp = isinstance(model, FSDP) or ( + is_compiled_module(model) and isinstance(model._orig_mod, FSDP) + ) + + if not is_type_fsdp: + self.state.fsdp_plugin.set_auto_wrap_policy(model) + fsdp_plugin = self.state.fsdp_plugin + kwargs = { + "sharding_strategy": fsdp_plugin.sharding_strategy, + "cpu_offload": fsdp_plugin.cpu_offload, + "auto_wrap_policy": fsdp_plugin.auto_wrap_policy, + "mixed_precision": fsdp_plugin.mixed_precision_policy, + "sync_module_states": fsdp_plugin.sync_module_states, + "backward_prefetch": fsdp_plugin.backward_prefetch, + "forward_prefetch": fsdp_plugin.forward_prefetch, + "use_orig_params": fsdp_plugin.use_orig_params, + "param_init_fn": fsdp_plugin.param_init_fn, + "ignored_modules": fsdp_plugin.ignored_modules, + "limit_all_gathers": fsdp_plugin.limit_all_gathers, + "device_id": self.device, + } + model = FSDP(model, **kwargs) + if fsdp_plugin.activation_checkpointing: + from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + apply_activation_checkpointing, + checkpoint_wrapper, + ) + + apply_activation_checkpointing( + model, + checkpoint_wrapper_fn=functools.partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, + ), + auto_wrap_policy=fsdp_plugin.auto_wrap_policy, + ) + # if the previous and current models are same, delete the previous one + if len(self._models) > 1 and (self._models[-2] is self._models[-1]): + del self._models[-2] + self._models[-1] = model + elif self.distributed_type == DistributedType.MULTI_CPU: + kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {} + model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) + elif self.distributed_type == DistributedType.XLA and self.state.fork_launched: + model = xmp.MpModelWrapper(model).to(self.device) + # torch.compile should be called last and only if the model isn't already compiled. 
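+        # Illustrative sketch (an assumption about a typical plugin configuration, not taken from this file):
+        # with a dynamo plugin configured, the block below amounts to roughly
+        #     model = torch.compile(model, backend="inductor", mode="default")
+        # with the actual keyword arguments supplied by `self.state.dynamo_plugin.to_kwargs()`.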
+ if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model): + if not is_torch_version(">=", "2.0"): + raise ValueError("Using `torch.compile` requires PyTorch 2.0 or higher.") + model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) + return model + + def _prepare_deepspeed(self, *args): + import deepspeed + + deepspeed_plugin = self.state.deepspeed_plugin + + is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args) + result = [ + self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj + for obj in args + ] + + if deepspeed_plugin.is_auto("train_micro_batch_size_per_gpu"): + if is_dataloader_present: + batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")] + if any(bs is None for bs in batch_sizes): + raise ValueError( + "At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size. " + "Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file " + "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`." + ) + if self.split_batches: + batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes] + + batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes) + if len(batch_sizes) > 1: + logger.info( + "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here " + f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})." + ) + else: + raise ValueError( + "When using DeepSpeed, `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders " + "with `batch_size` attribute returning an integer value " + "or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file " + "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`." + ) + else: + batch_size_per_device = deepspeed_plugin.get_value("train_micro_batch_size_per_gpu") + + # handle `gradient_accumulation_steps` when the value is `auto` + deepspeed_plugin.fill_match( + "gradient_accumulation_steps", + must_match=False, + gradient_accumulation_steps=self.gradient_accumulation_steps, + ) + + config_kwargs = { + "train_micro_batch_size_per_gpu": batch_size_per_device, + "train_batch_size": batch_size_per_device + * deepspeed_plugin.get_value("gradient_accumulation_steps") + * self.num_processes, + "gradient_clipping": 1.0, + "zero_optimization.stage3_gather_16bit_weights_on_model_save": False, + } + + model = None + optimizer = None + scheduler = None + for obj in result: + if isinstance(obj, torch.nn.Module): + model = obj + elif isinstance(obj, (torch.optim.Optimizer, DummyOptim)): + optimizer = obj + elif (isinstance(obj, (LRScheduler, DummyScheduler))) or ( + type(obj).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES + ): + scheduler = obj + + if optimizer is not None: + if "optimizer" in deepspeed_plugin.deepspeed_config and not isinstance(optimizer, (DummyOptim)): + raise ValueError( + "You cannot specify an optimizer in the config file and in the code at the same time. " + "Please remove the optimizer from the config file or " + "create `accelerate.utils.DummyOptim` in the code." 
+ ) + elif "optimizer" not in deepspeed_plugin.deepspeed_config and isinstance(optimizer, (DummyOptim)): + raise ValueError( + "You cannot create a `DummyOptim` without specifying an optimizer in the config file." + ) + + if isinstance(optimizer, (torch.optim.Optimizer)): + deepspeed_plugin.deepspeed_config["zero_allow_untested_optimizer"] = True + + if scheduler is not None: + if "scheduler" in deepspeed_plugin.deepspeed_config and not isinstance(scheduler, (DummyScheduler)): + raise ValueError( + "You cannot specify a scheduler in the config file and in the code at the same time. " + "Please remove the scheduler from the config file or " + "create `accelerate.utils.DummyScheduler` in the code." + ) + elif ( + "scheduler" not in deepspeed_plugin.deepspeed_config + and isinstance(scheduler, (DummyScheduler)) + and scheduler.lr_scheduler_callable is None + ): + raise ValueError( + "Either specify a scheduler in the config file or " + "pass in the `lr_scheduler_callable` parameter when using `accelerate.utils.DummyScheduler`." + ) + + if optimizer is not None and scheduler is not None: + if isinstance(optimizer, (DummyOptim)) and not isinstance(scheduler, (DummyScheduler)): + raise ValueError( + "You can only specify `accelerate.utils.DummyScheduler` in the code when using " + "`accelerate.utils.DummyOptim`." + ) + + if model is not None: + # deal with config keys that use `auto` value and rely on model's hidden_size + hidden_size_based_keys = [ + "zero_optimization.reduce_bucket_size", + "zero_optimization.stage3_prefetch_bucket_size", + "zero_optimization.stage3_param_persistence_threshold", + ] + hidden_size_auto_keys = [x for x in hidden_size_based_keys if deepspeed_plugin.is_auto(x)] + if len(hidden_size_auto_keys) > 0: + reasoning = ( + "therefore it's not possible to automatically fill out the following `auto` entries " + + f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing " + + "`auto` values for these keys with an integer value of your choice." 
+ ) + if not hasattr(model, "config"): + raise ValueError("Can't find `model.config` entry, " + reasoning) + + if hasattr(model.config, "hidden_size"): + hidden_size = model.config.hidden_size + elif hasattr(model.config, "hidden_sizes"): + # if there are many hidden sizes pick the largest one + hidden_size = max(model.config.hidden_sizes) + else: + raise ValueError( + "Can find neither `model.config.hidden_size` nor `model.config.hidden_sizes`, " + reasoning + ) + + config_kwargs.update( + { + "zero_optimization.reduce_bucket_size": hidden_size * hidden_size, + "zero_optimization.stage3_prefetch_bucket_size": 0.9 * hidden_size * hidden_size, + "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size, + } + ) + + if isinstance(optimizer, (DummyOptim)): + config_kwargs.update( + {"optimizer.params.lr": optimizer.lr, "optimizer.params.weight_decay": optimizer.weight_decay} + ) + if isinstance(scheduler, (DummyScheduler)) and scheduler.lr_scheduler_callable is None: + max_lr = ( + getattr(scheduler.optimizer, "lr", None) + if getattr(scheduler.optimizer, "defaults", None) is None + else scheduler.optimizer.defaults["lr"] + ) + config_kwargs.update( + { + "scheduler.params.warmup_min_lr": 0, + "scheduler.params.warmup_max_lr": max_lr, + "scheduler.params.warmup_num_steps": scheduler.warmup_num_steps, + } + ) + if scheduler.total_num_steps is not None: + config_kwargs["scheduler.params.total_num_steps"] = ( + math.ceil(scheduler.total_num_steps / self.num_processes) + if not self.split_batches + else scheduler.total_num_steps + ) + deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs) + self.deepspeed_config = deepspeed_plugin.deepspeed_config + kwargs = dict(model=model, config_params=self.deepspeed_config) + if optimizer is not None: + if isinstance(optimizer, (DummyOptim)): + kwargs["model_parameters"] = optimizer.params + if isinstance(scheduler, (DummyScheduler)) and scheduler.lr_scheduler_callable is not None: + kwargs["lr_scheduler"] = scheduler.lr_scheduler_callable + else: + if self.deepspeed_config["zero_optimization"].get("offload_optimizer", {}).get( + "device", "none" + ) != "none" and self.deepspeed_config.get("zero_force_ds_cpu_optimizer", True): + from deepspeed.ops.adam import DeepSpeedCPUAdam + + defaults = {k: v for k, v in optimizer.defaults.items() if k in ["lr", "weight_decay"]} + optimizer = DeepSpeedCPUAdam(optimizer.param_groups, **defaults) + kwargs["optimizer"] = optimizer + if scheduler is not None: + if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES: + kwargs["lr_scheduler"] = scheduler + + engine, optimizer, _, lr_scheduler = deepspeed.initialize(**kwargs) + if optimizer is not None: + optimizer = DeepSpeedOptimizerWrapper(optimizer) + if scheduler is not None: + if lr_scheduler is None: + scheduler = AcceleratedScheduler( + scheduler, + optimizer, + step_with_optimizer=self.step_scheduler_with_optimizer, + split_batches=self.split_batches, + ) + else: + scheduler = DeepSpeedSchedulerWrapper(lr_scheduler, optimizer) + + for i in range(len(result)): + if isinstance(result[i], torch.nn.Module): + result[i] = engine + elif isinstance(result[i], (torch.optim.Optimizer, DummyOptim)): + result[i] = optimizer + elif (isinstance(result[i], (LRScheduler, DummyScheduler))) or ( + type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES + ): + result[i] = scheduler + # pointing for deepspeed_engine_wrapped.backward() + self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine) + 
self._models.append(engine) + if optimizer is not None: + self._optimizers.append(optimizer) + if scheduler is not None: + self._schedulers.append(scheduler) + if len(self._models) > 1: + raise AssertionError( + "You can't use same `Accelerator()` instance with multiple models when using DeepSpeed" + ) + return tuple(result) + + def _prepare_megatron_lm(self, *args): + megatron_lm_plugin = self.state.megatron_lm_plugin + if not megatron_lm_plugin.megatron_dataset_flag: + batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")] + if len(batch_sizes) == 0: + raise ValueError( + "You must specify a training or evaluation dataloader in `accelerate.prepare()` when using Megatron-LM." + ) + + micro_batch_size = min(batch_sizes) if megatron_lm_plugin.is_train_batch_min else max(batch_sizes) + if len(batch_sizes) > 1: + logger.info( + "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here " + f"{megatron_lm_plugin.is_train_batch_min} will decide the `train_batch_size` ({micro_batch_size})." + ) + else: + for obj in args: + if isinstance(obj, MegatronLMDummyDataLoader): + micro_batch_size = obj.dataset_args["micro_batch_size"] + break + + dp_degree = self.num_processes // (megatron_lm_plugin.tp_degree * megatron_lm_plugin.pp_degree) + megatron_lm_plugin.set_training_args(micro_batch_size, dp_degree) + + model = None + optimizer = None + scheduler = None + is_dummy_scheduler = False + batch_data = None + for obj in args: + if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None: + batch_data = next(iter(obj)) + if isinstance(obj, torch.nn.Module): + model = obj + elif isinstance(obj, (torch.optim.Optimizer)): + optimizer = obj + elif isinstance(obj, (LRScheduler, MegatronLMDummyScheduler)): + scheduler = obj + + if model is not None: + megatron_lm_plugin.set_network_size_args(model, batch_data) + if optimizer is not None: + megatron_lm_plugin.set_optimizer_type(optimizer) + if scheduler is not None: + is_dummy_scheduler = isinstance(scheduler, MegatronLMDummyScheduler) + if not is_dummy_scheduler: + raise ValueError( + "You can't use a custom scheduler with Megatron-LM. Please use the `accelerate.utils.MegatronLMDummyScheduler` instead." 
+ ) + megatron_lm_plugin.set_scheduler_args(scheduler) + + # initialize megatron-lm + megatron_lm_initialize(self, args_defaults=megatron_lm_plugin.megatron_lm_default_args) + counter = 0 + result = [] + for obj in args: + if isinstance(obj, torch.utils.data.DataLoader): + result.append(megatron_lm_prepare_data_loader(self, obj)) + counter += 1 + elif isinstance(obj, MegatronLMDummyDataLoader): + if counter == 0: + obj.set_megatron_data_args() + dataloaders = megatron_lm_prepare_data_loader(self, obj) + result.append(dataloaders[counter]) + counter += 1 + else: + result.append(obj) + + if model is not None: + model = megatron_lm_prepare_model(self) + if optimizer is not None: + optimizer = megatron_lm_prepare_optimizer(self, model) + if scheduler is not None: + scheduler = megatron_lm_prepare_scheduler(self, optimizer, scheduler) + + if model is not None: + model = MegatronEngine(self, model, optimizer, scheduler) + if optimizer is not None: + optimizer = MegatronLMOptimizerWrapper(optimizer) + if scheduler is not None: + scheduler = MegatronLMSchedulerWrapper(scheduler, optimizer) + + for i in range(len(result)): + if isinstance(result[i], torch.nn.Module): + result[i] = model + elif isinstance(result[i], torch.optim.Optimizer): + result[i] = optimizer + elif isinstance(result[i], MegatronLMDummyScheduler): + result[i] = scheduler + if model is not None: + self._models.append(model) + if optimizer is not None: + self._optimizers.append(optimizer) + if scheduler is not None: + self._schedulers.append(scheduler) + if len(self._models) > 1: + raise AssertionError( + "You can't use same `Accelerator()` instance with multiple models when using Megatron-LM" + ) + return tuple(result) + + def _prepare_ipex(self, *args): + if not is_ipex_available(): + raise ImportError( + "IPEX is not installed or IPEX's version does not match current PyTorch version. Please refer" + " to https://github.com/intel/intel-extension-for-pytorch." + ) + else: + import intel_extension_for_pytorch as ipex + + model = None + optimizer = None + result = [obj for obj in args] + for obj in result: + if isinstance(obj, torch.nn.Module): + model = obj + model.train() + elif isinstance(obj, (torch.optim.Optimizer)): + optimizer = obj + if optimizer is not None and model is not None: + dtype = torch.bfloat16 if self.state.mixed_precision == "bf16" else None + if self.device.type == "xpu" and is_xpu_available(): + model = model.to(self.device) + model, optimizer = torch.xpu.optimize( + model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1" + ) + else: + model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1") + for i in range(len(result)): + if isinstance(result[i], torch.nn.Module): + result[i] = model + elif isinstance(result[i], (torch.optim.Optimizer)): + result[i] = optimizer + return tuple(result) + + def _prepare_msamp(self, *args): + if not is_msamp_available(): + raise ImportError( + "MS-AMP was not found on your system. Please ensure that MS-AMP is available " + " or choose `'te'` as the backend for FP8 mixed precision training." 
+ ) + else: + import msamp + + model, optimizer = None, None + num_models, num_optimizers = 0, 0 + result = [obj for obj in args] + for obj in result: + if isinstance(obj, torch.nn.Module): + model = obj + num_models += 1 + elif isinstance(obj, (torch.optim.Optimizer)): + optimizer = obj + num_optimizers += 1 + if optimizer is None or model is None: + raise ValueError( + "You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP." + ) + elif num_models > 1 or num_optimizers > 1: + raise ValueError( + f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with MS-AMP." + ) + else: + model, optimizer = msamp.initialize(model, optimizer, opt_level=self.fp8_recipe_handler.opt_level) + for i in range(len(result)): + if isinstance(result[i], torch.nn.Module): + result[i] = model + elif isinstance(result[i], (torch.optim.Optimizer)): + result[i] = optimizer + return tuple(result) + + def prepare_data_loader( + self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None + ): + """ + Prepares a PyTorch DataLoader for training in any distributed setup. It is recommended to use + [`Accelerator.prepare`] instead. + + Args: + data_loader (`torch.utils.data.DataLoader`): + A vanilla PyTorch DataLoader to prepare + device_placement (`bool`, *optional*): + Whether or not to place the batches on the proper device in the prepared dataloader. Will default to + `self.device_placement`. + slice_fn_for_dispatch (`Callable`, *optional*`): + If passed, this function will be used to slice tensors across `num_processes`. Will default to + [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will + be ignored otherwise. + + Example: + + ```python + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> data_loader = torch.utils.data.DataLoader(...) + >>> data_loader = accelerator.prepare_data_loader(data_loader, device_placement=True) + ``` + """ + # Ensure we can't double wrap a DataLoader due to `find_batch_size` + if getattr(data_loader, "_is_accelerate_prepared", False): + if data_loader not in self._dataloaders: + self._dataloaders.append(data_loader) + return data_loader + if device_placement is None: + device_placement = self.device_placement if self.distributed_type != DistributedType.XLA else False + prepared_data_loader = prepare_data_loader( + data_loader, + self.device, + num_processes=self.num_processes, + process_index=self.process_index, + split_batches=self.split_batches, + put_on_device=device_placement, + rng_types=self.rng_types.copy(), + dispatch_batches=self.dispatch_batches, + even_batches=self.even_batches, + slice_fn_for_dispatch=slice_fn_for_dispatch, + use_seedable_sampler=self.use_seedable_sampler, + ) + self._dataloaders.append(prepared_data_loader) + return prepared_data_loader + + def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=None): + """ + Prepares a PyTorch Optimizer for training in any distributed setup. It is recommended to use + [`Accelerator.prepare`] instead. + + Args: + optimizer (`torch.optim.Optimizer`): + A vanilla PyTorch optimizer to prepare + device_placement (`bool`, *optional*): + Whether or not to place the optimizer on the proper device. Will default to `self.device_placement`. + + Example: + + ```python + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> optimizer = torch.optim.Adam(...) 
+        >>> optimizer = accelerator.prepare_optimizer(optimizer, device_placement=True)
+        ```
+        """
+        # Ensure we can't double wrap an optimizer due to `find_batch_size`
+        if getattr(optimizer, "_is_accelerate_prepared", False):
+            if optimizer not in self._optimizers:
+                self._optimizers.append(optimizer)
+            return optimizer
+        if device_placement is None:
+            device_placement = self.device_placement
+        optimizer = AcceleratedOptimizer(optimizer, device_placement=device_placement, scaler=self.scaler)
+        self._optimizers.append(optimizer)
+        return optimizer
+
+    def prepare_scheduler(self, scheduler: LRScheduler):
+        """
+        Prepares a PyTorch Scheduler for training in any distributed setup. It is recommended to use
+        [`Accelerator.prepare`] instead.
+
+        Args:
+            scheduler (`torch.optim.lr_scheduler.LRScheduler`):
+                A vanilla PyTorch scheduler to prepare
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from accelerate import Accelerator
+
+        >>> accelerator = Accelerator()
+        >>> optimizer = torch.optim.Adam(...)
+        >>> scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
+        >>> scheduler = accelerator.prepare_scheduler(scheduler)
+        ```
+        """
+        # Ensure we can't double wrap a scheduler due to `find_batch_size`
+        if getattr(scheduler, "_is_accelerate_prepared", False):
+            if scheduler not in self._schedulers:
+                self._schedulers.append(scheduler)
+            return scheduler
+        # We try to find the optimizer associated with `scheduler`, the default is the full list.
+        optimizer = self._optimizers
+        for opt in self._optimizers:
+            if getattr(scheduler, "optimizer", None) == opt.optimizer:
+                optimizer = opt
+                break
+        scheduler = AcceleratedScheduler(
+            scheduler,
+            optimizer,
+            step_with_optimizer=self.step_scheduler_with_optimizer,
+            split_batches=self.split_batches,
+        )
+        self._schedulers.append(scheduler)
+        return scheduler
+
+    def backward(self, loss, **kwargs):
+        """
+        Scales the gradients in accordance with the `GradientAccumulationPlugin` and calls the correct `backward()` based
+        on the configuration.
+
+        Should be used in lieu of `loss.backward()`.
+
+        Example:
+
+        ```python
+        >>> from accelerate import Accelerator
+
+        >>> accelerator = Accelerator(gradient_accumulation_steps=2)
+        >>> outputs = model(inputs)
+        >>> loss = loss_fn(outputs, labels)
+        >>> accelerator.backward(loss)
+        ```
+        """
+        if self.distributed_type != DistributedType.DEEPSPEED:
+            # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
+            loss = loss / self.gradient_accumulation_steps
+        if self.distributed_type == DistributedType.DEEPSPEED:
+            self.deepspeed_engine_wrapped.backward(loss, **kwargs)
+        elif self.distributed_type == DistributedType.MEGATRON_LM:
+            return
+        elif self.scaler is not None:
+            self.scaler.scale(loss).backward(**kwargs)
+        else:
+            loss.backward(**kwargs)
+
+    def set_trigger(self):
+        """
+        Sets the internal trigger tensor to 1 on the current process. A later check should follow using this, which
+        will check across all processes.
+
+        Note:
+            Does not require `wait_for_everyone()`
+
+        Example:
+
+        ```python
+        >>> from accelerate import Accelerator
+
+        >>> accelerator = Accelerator()
+        >>> # Assume later in the training script
+        >>> # `should_do_breakpoint` is a custom function to monitor when to break,
+        >>> # e.g. when the loss is NaN
+        >>> if should_do_breakpoint(loss):
+        ...     accelerator.set_trigger()
+        >>> # Assume later in the training script
+        >>> if accelerator.check_trigger():
+        ... 
break + ``` + """ + self.flag_tensor = torch.tensor(1, device=self.device) + + def check_trigger(self): + """ + Checks if the internal trigger tensor has been set to 1 in any of the processes. If so, will return `True` and + reset the trigger tensor to 0. + + Note: + Does not require `wait_for_everyone()` + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> # Assume later in the training script + >>> # `should_do_breakpoint` is a custom function to monitor when to break, + >>> # e.g. when the loss is NaN + >>> if should_do_breakpoint(loss): + ... accelerator.set_trigger() + >>> # Assume later in the training script + >>> if accelerator.check_trigger(): + ... break + ``` + """ + # Now that we are outside `__init__`, we can initialize it if it is `None` on device + if self.flag_tensor is None: + self.flag_tensor = torch.tensor(0, device=self.device) + flag_tensor = self.reduce(self.flag_tensor) + if flag_tensor.item() >= 1: + self.flag_tensor = torch.tensor(0, device=self.device) + return True + return False + + def unscale_gradients(self, optimizer=None): + """ + Unscale the gradients in mixed precision training with AMP. This is a noop in all other settings. + + Likely should be called through [`Accelerator.clip_grad_norm_`] or [`Accelerator.clip_grad_value_`] + + Args: + optimizer (`torch.optim.Optimizer` or `list[torch.optim.Optimizer]`, *optional*): + The optimizer(s) for which to unscale gradients. If not set, will unscale gradients on all optimizers + that were passed to [`~Accelerator.prepare`]. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> model, optimizer = accelerator.prepare(model, optimizer) + >>> outputs = model(inputs) + >>> loss = loss_fn(outputs, labels) + >>> accelerator.backward(loss) + >>> accelerator.unscale_gradients(optimizer=optimizer) + ``` + """ + if self.native_amp and self.mixed_precision == "fp16": + if optimizer is None: + # TODO: this unscales all optimizers where we should only unscale the one where parameters are. + optimizer = self._optimizers + elif not isinstance(optimizer, (tuple, list)): + optimizer = [optimizer] + for opt in optimizer: + while isinstance(opt, AcceleratedOptimizer): + opt = opt.optimizer + self.scaler.unscale_(opt) + + def clip_grad_norm_(self, parameters, max_norm, norm_type=2): + """ + Should be used in place of `torch.nn.utils.clip_grad_norm_`. + + Returns: + `torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector). + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(gradient_accumulation_steps=2) + >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler) + + >>> for input, target in dataloader: + ... optimizer.zero_grad() + ... output = model(input) + ... loss = loss_func(output, target) + ... accelerator.backward(loss) + ... if accelerator.sync_gradients: + ... accelerator.clip_grad_norm_(model.parameters(), max_grad_norm) + ... optimizer.step() + ``` + """ + if self.distributed_type == DistributedType.FSDP: + self.unscale_gradients() + parameters = [p for p in parameters] + for model in self._models: + if parameters == [p for p in model.parameters()]: + return model.clip_grad_norm_(max_norm, norm_type) + elif self.distributed_type == DistributedType.DEEPSPEED: + # `accelerator.backward(loss)` is doing that automatically. 
Therefore, its implementation is not needed + # We cannot return the gradient norm because DeepSpeed does it. + return None + elif self.distributed_type == DistributedType.XLA: + # Reduce gradients first for XLA + for acc_opt in self._optimizers: + if not acc_opt.gradient_state.is_xla_gradients_synced: + opt = acc_opt + while isinstance(opt, AcceleratedOptimizer): + opt = opt.optimizer + gradients = xm._fetch_gradients(opt) + # Use xm.all_reduce to perform an in-place all-reduce. Recusrsive all-reduce each tensor + # one by one in self.reduce is non-inplace. + xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes) + # Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step. + acc_opt.gradient_state.is_xla_gradients_synced = True + self.unscale_gradients() + return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type) + + def clip_grad_value_(self, parameters, clip_value): + """ + Should be used in place of `torch.nn.utils.clip_grad_value_`. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(gradient_accumulation_steps=2) + >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler) + + >>> for input, target in dataloader: + ... optimizer.zero_grad() + ... output = model(input) + ... loss = loss_func(output, target) + ... accelerator.backward(loss) + ... if accelerator.sync_gradients: + ... accelerator.clip_grad_value_(model.parameters(), clip_value) + ... optimizer.step() + ``` + """ + if self.distributed_type in [DistributedType.DEEPSPEED, DistributedType.FSDP]: + raise Exception("DeepSpeed and FSDP do not support `clip_grad_value_`. Use `clip_grad_norm_` instead.") + self.unscale_gradients() + torch.nn.utils.clip_grad_value_(parameters, clip_value) + + def gather(self, tensor): + """ + Gather the values in *tensor* across all processes and concatenate them on the first dimension. Useful to + regroup the predictions from all processes when doing evaluation. + + Note: + This gather happens in all processes. + + Args: + tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`): + The tensors to gather across all processes. + + Returns: + `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: The gathered tensor(s). Note that the + first dimension of the result is *num_processes* multiplied by the first dimension of the input tensors. + + Example: + + ```python + >>> # Assuming four processes + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> process_tensor = torch.tensor([accelerator.process_index]) + >>> gathered_tensor = accelerator.gather(process_tensor) + >>> gathered_tensor + tensor([0, 1, 2, 3]) + ``` + """ + return gather(tensor) + + def gather_for_metrics(self, input_data): + """ + Gathers `input_data` and potentially drops duplicates in the last batch if on a distributed system. Should be + used for gathering the inputs and targets for metric calculation. 
+ + Args: + input (`torch.Tensor`, `object`, a nested tuple/list/dictionary of `torch.Tensor`, or a nested tuple/list/dictionary of `object`): + The tensors or objects for calculating metrics across all processes + + Example: + + ```python + >>> # Assuming two processes, with a batch size of 5 on a dataset with 9 samples + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> dataloader = torch.utils.data.DataLoader(range(9), batch_size=5) + >>> dataloader = accelerator.prepare(dataloader) + >>> batch = next(iter(dataloader)) + >>> gathered_items = accelerator.gather_for_metrics(batch) + >>> len(gathered_items) + 9 + ``` + """ + + try: + recursively_apply(lambda x: x, input_data, error_on_other_type=True) + all_tensors = True + except TypeError: + all_tensors = False + + if not all_tensors: + data = gather_object(input_data) + else: + data = self.gather(input_data) + + try: + if self.gradient_state.end_of_dataloader: + # at the end of a dataloader, `gather_for_metrics` regresses to + # `gather` unless the dataset has a remainder so log. + if self.gradient_state.remainder == -1: + logger.info( + "The used dataset had no length, returning gathered tensors. You should drop the remainder yourself." + ) + return data + elif self.gradient_state.remainder > 0: + # Last batch needs to be truncated on distributed systems as it contains additional samples + def _adjust_samples(tensor): + return tensor[: self.gradient_state.remainder] + + return recursively_apply(_adjust_samples, data) + else: # remainder is 0 + # no remainder even though at end of dataloader, so nothing to do. + return data + else: + # Not at the end of the dataloader, no need to adjust the tensors + return data + except Exception: + # Dataset had no length or raised an error + return data + + def reduce(self, tensor, reduction="sum", scale=1.0): + """ + Reduce the values in *tensor* across all processes based on *reduction*. + + Note: + All processes get the reduced value. + + Args: + tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`): + The tensors to reduce across all processes. + reduction (`str`, *optional*, defaults to "sum"): + A reduction type, can be one of 'sum', 'mean', or 'none'. If 'none', will not perform any operation. + scale (`float`, *optional*, defaults to 1.0): + A default scaling value to be applied after the reduce, only valied on XLA. + + Returns: + `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: + The reduced tensor(s). + + Example: + + ```python + >>> # Assuming two processes + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> process_tensor = torch.arange(accelerator.num_processes) + 1 + (2 * accelerator.process_index) + >>> process_tensor = process_tensor.to(accelerator.device) + >>> reduced_tensor = accelerator.reduce(process_tensor, reduction="sum") + >>> reduced_tensor + tensor([4, 6]) + ``` + """ + return reduce(tensor, reduction, scale) + + def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False): + """ + Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so + they can safely be gathered. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to gather. + dim (`int`, *optional*, defaults to 0): + The dimension on which to pad. + pad_index (`int`, *optional*, defaults to 0): + The value with which to pad. 
+ pad_first (`bool`, *optional*, defaults to `False`): + Whether to pad at the beginning or the end. + + Returns: + `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: + The padded tensor(s). + + Example: + + ```python + >>> # Assuming two processes, with the first processes having a tensor of size 1 and the second of size 2 + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> process_tensor = torch.arange(accelerator.process_index + 1).to(accelerator.device) + >>> padded_tensor = accelerator.pad_across_processes(process_tensor) + >>> padded_tensor.shape + torch.Size([2]) + ``` + """ + return pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first) + + def unwrap_model(self, model, keep_fp32_wrapper: bool = True): + """ + Unwraps the `model` from the additional layer possible added by [`~Accelerator.prepare`]. Useful before saving + the model. + + Args: + model (`torch.nn.Module`): + The model to unwrap. + keep_fp32_wrapper (`bool`, *optional*, defaults to `True`): + Whether to not remove the mixed precision hook if it was added. + + Returns: + `torch.nn.Module`: The unwrapped model. + + Example: + + ```python + >>> # Assuming two GPU processes + >>> from torch.nn.parallel import DistributedDataParallel + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> model = accelerator.prepare(MyModel()) + >>> print(model.__class__.__name__) + DistributedDataParallel + + >>> model = accelerator.unwrap_model(model) + >>> print(model.__class__.__name__) + MyModel + ``` + """ + return extract_model_from_parallel(model, keep_fp32_wrapper) + + def wait_for_everyone(self): + """ + Will stop the execution of the current process until every other process has reached that point (so this does + nothing when the script is only run in one process). Useful to do before saving a model. + + Example: + + ```python + >>> # Assuming two GPU processes + >>> import time + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> if accelerator.is_main_process: + ... time.sleep(2) + >>> else: + ... print("I'm waiting for the main process to finish its sleep...") + >>> accelerator.wait_for_everyone() + >>> # Should print on every process at the same time + >>> print("Everyone is here") + ``` + """ + wait_for_everyone() + + @on_main_process + def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = {}): + """ + Initializes a run for all trackers stored in `self.log_with`, potentially with starting configurations + + Args: + project_name (`str`): + The name of the project. All trackers will save their data based on this + config (`dict`, *optional*): + Optional starting configuration to be logged. + init_kwargs (`dict`, *optional*): + A nested dictionary of kwargs to be passed to a specific tracker's `__init__` function. Should be + formatted like so: + ```python + {"wandb": {"tags": ["tag_a", "tag_b"]}} + ``` + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(log_with="tensorboard") + >>> accelerator.init_trackers( + ... project_name="my_project", + ... config={"learning_rate": 0.001, "batch_size": 32}, + ... init_kwargs={"tensorboard": {"flush_secs": 60}}, + ... 
) + ``` + """ + for tracker in self.log_with: + if issubclass(type(tracker), GeneralTracker): + # Custom trackers are already initialized + self.trackers.append(tracker) + else: + tracker_init = LOGGER_TYPE_TO_CLASS[str(tracker)] + if tracker_init.requires_logging_directory: + # We can skip this check since it was done in `__init__` + self.trackers.append( + tracker_init(project_name, self.logging_dir, **init_kwargs.get(str(tracker), {})) + ) + else: + self.trackers.append(tracker_init(project_name, **init_kwargs.get(str(tracker), {}))) + if config is not None: + for tracker in self.trackers: + tracker.store_init_configuration(config) + + def get_tracker(self, name: str, unwrap: bool = False): + """ + Returns a `tracker` from `self.trackers` based on `name` on the main process only. + + Args: + name (`str`): + The name of a tracker, corresponding to the `.name` property. + unwrap (`bool`): + Whether to return the internal tracking mechanism or to return the wrapped tracker instead + (recommended). + + Returns: + `GeneralTracker`: The tracker corresponding to `name` if it exists. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(log_with="tensorboard") + >>> accelerator.init_trackers("my_project") + >>> tensorboard_tracker = accelerator.get_tracker("tensorboard") + ``` + """ + if len(self.trackers) > 0: + for tracker in self.trackers: + if tracker.name == name: + return tracker.tracker if unwrap else tracker + raise ValueError(f"{name} is not an available tracker stored inside the `Accelerator`.") + # Handle tracker only made on main process + return GeneralTracker(_blank=True) + + @on_main_process + def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = {}): + """ + Logs `values` to all stored trackers in `self.trackers` on the main process only. + + Args: + values (`dict`): + Values should be a dictionary-like object containing only types `int`, `float`, or `str`. + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + log_kwargs (`dict`, *optional*): + A nested dictionary of kwargs to be passed to a specific tracker's `log` function. Should be formatted + like so: + ```python + {"wandb": {"tags": ["tag_a", "tag_b"]}} + ``` + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(log_with="tensorboard") + >>> accelerator.init_trackers("my_project") + >>> accelerator.log({"loss": 0.5, "accuracy": 0.9}) + ``` + """ + for tracker in self.trackers: + tracker.log(values, step=step, **log_kwargs.get(tracker.name, {})) + + @on_main_process + def end_training(self): + """ + Runs any special end training behaviors, such as stopping trackers on the main process only. Should always be + called at the end of your script if using experiment tracking. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(log_with="tensorboard") + >>> accelerator.init_trackers("my_project") + >>> # Do training + >>> accelerator.end_training() + ``` + """ + for tracker in self.trackers: + tracker.finish() + + def save(self, obj, f, safe_serialization=False): + """ + Save the object passed to disk once per machine. Use in place of `torch.save`. + + Args: + obj (`object`): The object to save. + f (`str` or `os.PathLike`): Where to save the content of `obj`. 
+ safe_serialization (`bool`, *optional*, defaults to `False`): Whether to save `obj` using `safetensors` + + Note: + If `save_on_each_node` was passed in as a `ProjectConfiguration`, will save the object once per node, + rather than only once on the main node. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> arr = [0, 1, 2, 3] + >>> accelerator.save(arr, "array.pkl") + ``` + """ + save( + obj, + f, + save_on_each_node=self.project_configuration.save_on_each_node, + safe_serialization=safe_serialization, + ) + + def save_model( + self, + model: torch.nn.Module, + save_directory: Union[str, os.PathLike], + max_shard_size: Union[int, str] = "10GB", + safe_serialization: bool = True, + ): + """ + Save a model so that it can be re-loaded using load_checkpoint_in_model + + Arguments: + model: (`torch.nn.Module`): + Model to be saved. The model can be wrapped or unwraped. + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> model = ... + >>> accelerator.save_model(model, save_directory) + ``` + """ + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + # get the state_dict of the model + if any( + [ + module._hf_hook.offload + for module in model.modules() + if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) + ] + ): + state_dict = get_state_dict_offloaded_model(model) + else: + if any(param.device == torch.device("meta") for param in model.parameters()): + raise RuntimeError("You can't save the model since some parameters are on the meta device.") + state_dict = self.get_state_dict(model) + + if safe_serialization: + state_dict = clean_state_dict_for_safetensors(state_dict) + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + + # Shard the model if it is too big. + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "") + + # make sure that file to be deleted matches format of sharded file, e.g. 
pytorch_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "") + reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shards.keys() + and reg.fullmatch(filename_no_suffix) is not None + and PartialState().is_main_process + ): + os.remove(full_filename) + + # Save the model + for shard_file, shard in shards.items(): + self.save(shard, os.path.join(save_directory, shard_file), safe_serialization=safe_serialization) + + if index is None: + path_to_weights = os.path.join(save_directory, WEIGHTS_NAME) + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, save_index_file) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.RemovableHandle: + """ + Registers a pre hook to be run before `save_checkpoint` is called in [`Accelerator.save_state`]. + + Args: + hook (`Callable`): + A function to be called in [`Accelerator.save_state`] before `save_checkpoint`. + + The hook should have the following signature: + + `hook(models: list[torch.nn.Module], weights: list[dict[str, torch.Tensor]], input_dir: str) -> None` + + The `models` argument are the models as saved in the accelerator state under `accelerator._models`, `weigths` + argument are the state dicts of the `models`, and the `input_dir` argument is the `input_dir` argument passed + to [`Accelerator.load_state`]. + + + + Should only be used in conjunction with [`Accelerator.register_load_state_pre_hook`]. Can be useful to save + configurations in addition to model weights. Can also be used to overwrite model saving with a customized + method. In this case, make sure to remove already loaded weights from the weights list. + + + + Returns: + `torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling + `handle.remove()` + """ + handle = hooks.RemovableHandle(self._save_model_state_pre_hook) + self._save_model_state_pre_hook[handle.id] = hook + return handle + + def save_state(self, output_dir: str = None, safe_serialization: bool = True, **save_model_func_kwargs): + """ + Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder. + + If a `ProjectConfiguration` was passed to the `Accelerator` object with `automatic_checkpoint_naming` enabled + then checkpoints will be saved to `self.project_dir/checkpoints`. If the number of current saves is greater + than `total_limit` then the oldest save is deleted. Each checkpoint is saved in seperate folders named + `checkpoint_`. + + Otherwise they are just saved to `output_dir`. + + + + Should only be used when wanting to save a checkpoint during training and restoring the state in the same + environment. + + + + Args: + output_dir (`str` or `os.PathLike`): + The name of the folder to save all relevant weights and states. 
+ safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). + save_model_func_kwargs (`dict`, *optional*): + Additional keyword arguments for saving model which can be passed to the underlying save function, such + as optional arguments for DeepSpeed's `save_checkpoint` function. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> model, optimizer, lr_scheduler = ... + >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + >>> accelerator.save_state(output_dir="my_checkpoint") + ``` + """ + if self.project_configuration.automatic_checkpoint_naming: + output_dir = os.path.join(self.project_dir, "checkpoints") + os.makedirs(output_dir, exist_ok=True) + if self.project_configuration.automatic_checkpoint_naming: + folders = [os.path.join(output_dir, folder) for folder in os.listdir(output_dir)] + if ( + self.project_configuration.total_limit is not None + and (len(folders) + 1 > self.project_configuration.total_limit) + and self.is_main_process + ): + + def _inner(folder): + return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0] + + folders.sort(key=_inner) + logger.warning( + f"Deleting {len(folders) + 1 - self.project_configuration.total_limit} checkpoints to make room for new checkpoint." + ) + for folder in folders[: len(folders) + 1 - self.project_configuration.total_limit]: + shutil.rmtree(folder) + output_dir = os.path.join(output_dir, f"checkpoint_{self.save_iteration}") + if os.path.exists(output_dir): + raise ValueError( + f"Checkpoint directory {output_dir} ({self.save_iteration}) already exists. Please manually override `self.save_iteration` with what iteration to start with." 
+ ) + self.wait_for_everyone() + os.makedirs(output_dir, exist_ok=True) + logger.info(f"Saving current state to {output_dir}") + + if self.distributed_type == DistributedType.XLA: + # Finish running the previous step before checkpointing + xm.mark_step() + + # Save the models taking care of FSDP and DeepSpeed nuances + weights = [] + for i, model in enumerate(self._models): + if self.distributed_type == DistributedType.FSDP: + logger.info("Saving FSDP model") + save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i) + logger.info(f"FSDP Model saved to output dir {output_dir}") + elif self.distributed_type == DistributedType.DEEPSPEED: + logger.info("Saving DeepSpeed Model and Optimizer") + ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}" + model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs) + logger.info(f"DeepSpeed Model and Optimizer saved to output dir {os.path.join(output_dir, ckpt_id)}") + elif self.distributed_type == DistributedType.MEGATRON_LM: + logger.info("Saving Megatron-LM Model, Optimizer and Scheduler") + model.save_checkpoint(output_dir) + logger.info(f"Megatron-LM Model , Optimizer and Scheduler saved to output dir {output_dir}") + else: + weights.append(self.get_state_dict(model, unwrap=False)) + + # Save the optimizers taking care of FSDP and DeepSpeed nuances + optimizers = [] + if self.distributed_type == DistributedType.FSDP: + for i, opt in enumerate(self._optimizers): + logger.info("Saving FSDP Optimizer") + save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i) + logger.info(f"FSDP Optimizer saved to output dir {output_dir}") + elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]: + optimizers = self._optimizers + + # Save the lr schedulers taking care of DeepSpeed nuances + schedulers = [] + if self.distributed_type == DistributedType.DEEPSPEED: + for i, scheduler in enumerate(self._schedulers): + if isinstance(scheduler, DeepSpeedSchedulerWrapper): + continue + schedulers.append(scheduler) + elif self.distributed_type not in [DistributedType.MEGATRON_LM]: + schedulers = self._schedulers + + # Save the samplers of the dataloaders + dataloaders = self._dataloaders + + # Call model loading hooks that might have been registered with + # accelerator.register_model_state_hook + for hook in self._save_model_state_pre_hook.values(): + hook(self._models, weights, output_dir) + + save_location = save_accelerator_state( + output_dir, + weights, + optimizers, + schedulers, + dataloaders, + self.state.process_index, + self.scaler, + save_on_each_node=self.project_configuration.save_on_each_node, + safe_serialization=safe_serialization, + ) + for i, obj in enumerate(self._custom_objects): + save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node) + self.project_configuration.iteration += 1 + return save_location + + def register_load_state_pre_hook(self, hook: Callable[..., None]) -> hooks.RemovableHandle: + """ + Registers a pre hook to be run before [`load_checkpoint`] is called in [`Accelerator.load_state`]. + + Args: + hook (`Callable`): + A function to be called in [`Accelerator.load_state`] before `load_checkpoint`. 
+ + The hook should have the following signature: + + `hook(models: list[torch.nn.Module], input_dir: str) -> None` + + The `models` argument are the models as saved in the accelerator state under `accelerator._models`, and the + `input_dir` argument is the `input_dir` argument passed to [`Accelerator.load_state`]. + + + + Should only be used in conjunction with [`Accelerator.register_save_state_pre_hook`]. Can be useful to load + configurations in addition to model weights. Can also be used to overwrite model loading with a customized + method. In this case, make sure to remove already loaded models from the models list. + + + + Returns: + `torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling + `handle.remove()` + """ + handle = hooks.RemovableHandle(self._load_model_state_pre_hook) + self._load_model_state_pre_hook[handle.id] = hook + return handle + + def load_state(self, input_dir: str = None, **load_model_func_kwargs): + """ + Loads the current states of the model, optimizer, scaler, RNG generators, and registered objects. + + + + Should only be used in conjunction with [`Accelerator.save_state`]. If a file is not registered for + checkpointing, it will not be loaded if stored in the directory. + + + + Args: + input_dir (`str` or `os.PathLike`): + The name of the folder all relevant weights and states were saved in. Can be `None` if + `automatic_checkpoint_naming` is used, and will pick up from the latest checkpoint. + load_model_func_kwargs (`dict`, *optional*): + Additional keyword arguments for loading model which can be passed to the underlying load function, + such as optional arguments for DeepSpeed's `load_checkpoint` function or a `map_location` to load the + model and optimizer on. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> model, optimizer, lr_scheduler = ... 
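+            >>> # Note (illustrative comment, not part of the original example): the objects must be
+            >>> # prepared (and any extra objects registered for checkpointing) before `load_state`,
+            >>> # so the saved states can be mapped back onto them.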
+ >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler) + >>> accelerator.load_state("my_checkpoint") + ``` + """ + if input_dir is not None: + # Check if folder exists + input_dir = os.path.expanduser(input_dir) + if not os.path.isdir(input_dir): + raise ValueError(f"Tried to find {input_dir} but folder does not exist") + elif self.project_configuration.automatic_checkpoint_naming: + # Pick up from automatic checkpoint naming + input_dir = os.path.join(self.project_dir, "checkpoints") + folders = [os.path.join(input_dir, folder) for folder in os.listdir(input_dir)] + + def _inner(folder): + return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0] + + folders.sort(key=_inner) + input_dir = folders[-1] + else: + raise ValueError("No input_dir provided and automatic checkpoint naming is disabled.") + logger.info(f"Loading states from {input_dir}") + + # Load the models taking care of FSDP and DeepSpeed nuances + models = [] + for i, model in enumerate(self._models): + if self.distributed_type == DistributedType.FSDP: + logger.info("Loading FSDP model") + load_fsdp_model(self.state.fsdp_plugin, self, model, input_dir, i) + logger.info(f"FSDP Model loaded from input dir {input_dir}") + elif self.distributed_type == DistributedType.DEEPSPEED: + logger.info("Loading DeepSpeed Model and Optimizer") + ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}" + model.load_checkpoint(input_dir, ckpt_id, **load_model_func_kwargs) + logger.info(f"DeepSpeed Model and Optimizer loaded from input dir {os.path.join(input_dir, ckpt_id)}") + elif self.distributed_type == DistributedType.MEGATRON_LM: + logger.info("Loading Megatron-LM Model, Optimizer and Scheduler") + model.load_checkpoint(input_dir) + logger.info(f"Megatron-LM Model , Optimizer and Scheduler loaded from input dir {input_dir}") + else: + models.append(model) + + # Load the optimizers taking care of FSDP and DeepSpeed nuances + optimizers = [] + if self.distributed_type == DistributedType.FSDP: + for i, opt in enumerate(self._optimizers): + logger.info("Loading FSDP Optimizer") + load_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], input_dir, i) + logger.info(f"FSDP Optimizer loaded from input dir {input_dir}") + elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]: + optimizers = self._optimizers + + # Load the lr schedulers taking care of DeepSpeed nuances + schedulers = [] + if self.distributed_type == DistributedType.DEEPSPEED: + for i, scheduler in enumerate(self._schedulers): + if isinstance(scheduler, DeepSpeedSchedulerWrapper): + continue + schedulers.append(scheduler) + elif self.distributed_type not in [DistributedType.MEGATRON_LM]: + schedulers = self._schedulers + + dataloaders = self._dataloaders + + # Call model loading hooks that might have been registered with + # accelerator.register_model_state_hook + for hook in self._load_model_state_pre_hook.values(): + hook(models, input_dir) + + map_location = load_model_func_kwargs.pop("map_location", None) + if map_location is None: + if self.num_processes > 1 and self.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + ): + map_location = "on_device" + else: + map_location = "cpu" + + load_accelerator_state( + input_dir, + models, + optimizers, + schedulers, + dataloaders, + self.state.process_index, + self.scaler, + map_location, + **load_model_func_kwargs, + ) + custom_checkpoints = [ + f for f in 
os.listdir(input_dir) if re.search(r"^custom_checkpoint_\d+\.pkl$", f) is not None + ] + if len(custom_checkpoints) != len(self._custom_objects): + err = "Number of custom checkpoints in folder {input_dir} does not match the number of registered objects:" + err += f"\n\tFound checkpoints: {len(custom_checkpoints)}" + err += f"\n\tRegistered objects: {len(self._custom_objects)}\n" + err += "Please make sure to only load checkpoints from folders that were created with the same set of registered objects," + err += "or avoid using `custom_checkpoint` in the filename for files in that same directory and load them in manually." + raise RuntimeError(err) + else: + logger.info(f"Loading in {len(custom_checkpoints)} custom states") + for index, obj in enumerate(self._custom_objects): + load_custom_state(obj, input_dir, index) + + def free_memory(self): + """ + Will release all references to the internal objects stored and call the garbage collector. You should call this + method between two trainings with different models/optimizers. Also will reset `Accelerator.step` to 0. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> model, optimizer, scheduler = ... + >>> model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) + >>> accelerator.free_memory() + >>> del model, optimizer, scheduler + ``` + """ + self._schedulers = [] + self._optimizers = [] + self._models = [] + self._dataloaders = [] + self.deepspeed_engine_wrapped = None + self.step = 0 + release_memory() + + def clear(self): + """ + Alias for [`Accelerate.free_memory`], releases all references to the internal objects stored and call the + garbage collector. You should call this method between two trainings with different models/optimizers. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> model, optimizer, scheduler = ... + >>> model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler) + >>> accelerator.free_memory() + >>> del model, optimizer, scheduler + ``` + """ + self.free_memory() + + def _get_named_parameters(self, *args): + named_parameters = {} + for obj in args: + if isinstance(obj, torch.nn.Module): + obj = extract_model_from_parallel(obj) + named_parameters.update({n: p for n, p in obj.named_parameters()}) + return named_parameters + + def _get_devices(self, *args): + model_device = None + optimizer_device = None + for obj in args: + # Loop through model parameters and stop at the first once we have its device. + if isinstance(obj, torch.nn.Module): + for param in obj.parameters(): + model_device = param.device + break + # Loop through optimizer parameters groups and stop at the first once we have its device. + if isinstance(obj, torch.optim.Optimizer): + for param_group in obj.param_groups: + if len(param_group["params"]) > 0: + optimizer_device = param_group["params"][0].device + break + return (model_device, optimizer_device) + + def get_state_dict(self, model, unwrap=True): + """ + Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full + precision. + + Args: + model (`torch.nn.Module`): + A PyTorch model sent through [`Accelerator.prepare`] + unwrap (`bool`, *optional*, defaults to `True`): + Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict + + Returns: + `dict`: The state dictionary of the model potentially without full precision. 
+ + Example: + + ```python + >>> import torch + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> net = torch.nn.Linear(2, 2) + >>> net = accelerator.prepare(net) + >>> state_dict = accelerator.get_state_dict(net) + ``` + """ + + if self.distributed_type == DistributedType.DEEPSPEED: + if self.deepspeed_config["zero_optimization"]["stage"] == 3: + if model.zero_gather_16bit_weights_on_model_save(): + state_dict = model._zero3_consolidated_16bit_state_dict() + else: + raise ValueError( + "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. " + "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or " + "set `zero3_save_16bit_model` to True when using `accelerate config`. " + "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights." + ) + else: + from deepspeed.checkpoint.utils import clone_tensors_for_torch_save + + state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict()) + elif self.distributed_type == DistributedType.FSDP: + from torch.distributed.fsdp import FullStateDictConfig, StateDictType + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config): + state_dict = model.state_dict() + else: + if unwrap: + model = self.unwrap_model(model) + state_dict = model.state_dict() + + return state_dict + + def register_for_checkpointing(self, *objects): + """ + Makes note of `objects` and will save or load them in during `save_state` or `load_state`. + + These should be utilized when the state is being loaded or saved in the same script. It is not designed to be + used in different scripts. + + + + Every `object` must have a `load_state_dict` and `state_dict` function to be stored. + + + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> # Assume `CustomObject` has a `state_dict` and `load_state_dict` function. + >>> obj = CustomObject() + >>> accelerator.register_for_checkpointing(obj) + >>> accelerator.save_state("checkpoint.pt") + ``` + """ + invalid_objects = [] + for obj in objects: + if not hasattr(obj, "state_dict") or not hasattr(obj, "load_state_dict"): + invalid_objects.append(obj) + if len(invalid_objects) > 0: + err = "All `objects` must include a `state_dict` and `load_state_dict` function to be stored. The following inputs are invalid:" + for index, obj in enumerate(invalid_objects): + err += f"\n\t- Item at index {index}, `{get_pretty_name(obj)}`" + raise ValueError(err) + self._custom_objects.extend(objects) + + @contextmanager + def autocast(self, cache_enabled: bool = False, autocast_handler: AutocastKwargs = None): + """ + Will apply automatic mixed-precision inside the block inside this context manager, if it is enabled. Nothing + different will happen otherwise. + + A different `autocast_handler` can be passed in to override the one set in the `Accelerator` object. This is + useful in blocks under `autocast` where you want to revert to fp32. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator(mixed_precision="fp16") + >>> with accelerator.autocast(): + ... 
train() + ``` + """ + if cache_enabled: + warnings.warn( + "Passing `cache_enabled=True` to `accelerator.autocast` is deprecated and will be removed in v0.23.0. " + "Please use the `AutocastKwargs` class instead and pass it to the `Accelerator` as a `kwarg_handler`.", + FutureWarning, + ) + if self.autocast_handler is not None: + self.autocast_handler.cache_enabled = True + else: + self.autocast_handler = AutocastKwargs(cache_enabled=True) + if autocast_handler is None: + autocast_handler = self.autocast_handler + autocast_context = get_mixed_precision_context_manager(self.native_amp, autocast_handler) + autocast_context.__enter__() + # TODO: should the `yield` be in a try/finally block? + yield + autocast_context.__exit__(*sys.exc_info()) + + @property + def optimizer_step_was_skipped(self): + """ + Whether or not the optimizer update was skipped (because of gradient overflow in mixed precision), in which + case the learning rate should not be changed. + """ + for optimizer in self._optimizers: + if optimizer.step_was_skipped: + return True + return False + + def skip_first_batches(self, dataloader, num_batches: int = 0): + """ + Creates a new `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. + + Args: + dataloader (`torch.utils.data.DataLoader`): The data loader in which to skip batches. + num_batches (`int`, *optional*, defaults to 0): The number of batches to skip + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler) + >>> skipped_dataloader = accelerator.skip_first_batches(dataloader, num_batches=2) + >>> # for the first epoch only + >>> for input, target in skipped_dataloader: + ... optimizer.zero_grad() + ... output = model(input) + ... loss = loss_func(output, target) + ... accelerator.backward(loss) + ... optimizer.step() + + >>> # subsequent epochs + >>> for input, target in dataloader: + ... optimizer.zero_grad() + ... ... + ``` + """ + return skip_first_batches(dataloader, num_batches=num_batches) + + def __deepcopy__(self, memo): + logger.info("Deep copying the `Accelerator` object, note that this will point to the same original object.") + return self + + def verify_device_map(self, model: torch.nn.Module) -> bool: + """ + Verifies that `model` has not been prepared with big model inference with a device-map resembling `auto`. + """ + # Checks if any of the child modules has the attribute `hf_device_map` and this map has more than one entry. + for m in model.modules(): + if hasattr(m, "hf_device_map") and len(m.hf_device_map) > 1: + return True + + return False diff --git a/llm/Lib/site-packages/accelerate/big_modeling.py b/llm/Lib/site-packages/accelerate/big_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..94febb5d3dde35689d99ebbf26c1c78346f0ab17 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/big_modeling.py @@ -0,0 +1,627 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from contextlib import contextmanager +from functools import wraps +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from .hooks import ( + AlignDevicesHook, + CpuOffload, + UserCpuOffloadHook, + add_hook_to_module, + attach_align_device_hook, + attach_align_device_hook_on_blocks, +) +from .utils import ( + OffloadedWeightsLoader, + check_cuda_p2p_ib_support, + check_device_map, + extract_submodules_state_dict, + find_tied_parameters, + get_balanced_memory, + infer_auto_device_map, + is_mlu_available, + is_npu_available, + is_torch_version, + is_xpu_available, + load_checkpoint_in_model, + offload_state_dict, + parse_flag_from_env, + retie_parameters, +) +from .utils.other import recursive_getattr + + +logger = logging.getLogger(__name__) + + +@contextmanager +def init_empty_weights(include_buffers: bool = None): + """ + A context manager under which models are initialized with all parameters on the meta device, therefore creating an + empty model. Useful when just initializing the model would blow the available RAM. + + Args: + include_buffers (`bool`, *optional*): + Whether or not to also put all buffers on the meta device while initializing. + + Example: + + ```python + import torch.nn as nn + from accelerate import init_empty_weights + + # Initialize a model with 100 billions parameters in no time and without using any RAM. + with init_empty_weights(): + tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)]) + ``` + + + + Any model created under this context manager has no weights. As such you can't do something like + `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`]. + Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not + called. + + + """ + if include_buffers is None: + include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False) + with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f: + yield f + + +@contextmanager +def init_on_device(device: torch.device, include_buffers: bool = None): + """ + A context manager under which models are initialized with all parameters on the specified device. + + Args: + device (`torch.device`): + Device to initialize all parameters on. + include_buffers (`bool`, *optional*): + Whether or not to also put all buffers on the meta device while initializing. 
+ + Example: + + ```python + import torch.nn as nn + from accelerate import init_on_device + + with init_on_device(device=torch.device("cuda")): + tst = nn.Liner(100, 100) # on `cuda` device + ``` + """ + if include_buffers is None: + include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False) + + # TODO(shingjan): remove the torch version check once older versions are deprecated + if is_torch_version(">=", "2.0") and include_buffers: + with device: + yield + return + + old_register_parameter = nn.Module.register_parameter + if include_buffers: + old_register_buffer = nn.Module.register_buffer + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + if param is not None: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + + def register_empty_buffer(module, name, buffer, persistent=True): + old_register_buffer(module, name, buffer, persistent=persistent) + if buffer is not None: + module._buffers[name] = module._buffers[name].to(device) + + # Patch tensor creation + if include_buffers: + tensor_constructors_to_patch = { + torch_function_name: getattr(torch, torch_function_name) + for torch_function_name in ["empty", "zeros", "ones", "full"] + } + else: + tensor_constructors_to_patch = {} + + def patch_tensor_constructor(fn): + def wrapper(*args, **kwargs): + kwargs["device"] = device + return fn(*args, **kwargs) + + return wrapper + + try: + nn.Module.register_parameter = register_empty_parameter + if include_buffers: + nn.Module.register_buffer = register_empty_buffer + for torch_function_name in tensor_constructors_to_patch.keys(): + setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) + yield + finally: + nn.Module.register_parameter = old_register_parameter + if include_buffers: + nn.Module.register_buffer = old_register_buffer + for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): + setattr(torch, torch_function_name, old_torch_function) + + +def cpu_offload( + model: nn.Module, + execution_device: Optional[torch.device] = None, + offload_buffers: bool = False, + state_dict: Optional[Dict[str, torch.Tensor]] = None, + preload_module_classes: Optional[List[str]] = None, +): + """ + Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one + copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that + state dict and put on the execution device passed as they are needed, then offloaded again. + + Args: + model (`torch.nn.Module`): + The model to offload. + execution_device (`torch.device`, *optional*): + The device on which the forward pass of the model will be executed (should be a GPU). Will default to the + model first parameter device. + offload_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to offload the buffers with the model parameters. + state_dict (`Dict[str, torch.Tensor]`, *optional*): + The state dict of the model that will be kept on CPU. + preload_module_classes (`List[str]`, *optional*): + A list of classes whose instances should load all their weights (even in the submodules) at the beginning + of the forward. 
This should only be used for classes that have submodules which are registered but not + called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, + `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + """ + if execution_device is None: + execution_device = next(iter(model.parameters())).device + if state_dict is None: + state_dict = {n: p.to("cpu") for n, p in model.state_dict().items()} + + add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True) + attach_align_device_hook( + model, + execution_device=execution_device, + offload=True, + offload_buffers=offload_buffers, + weights_map=state_dict, + preload_module_classes=preload_module_classes, + ) + + return model + + +def cpu_offload_with_hook( + model: torch.nn.Module, + execution_device: Optional[Union[int, str, torch.device]] = None, + prev_module_hook: Optional[UserCpuOffloadHook] = None, +): + """ + Offloads a model on the CPU and puts it back to an execution device when executed. The difference with + [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when + the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop. + + Args: + model (`torch.nn.Module`): + The model to offload. + execution_device(`str`, `int` or `torch.device`, *optional*): + The device on which the model should be executed. Will default to the MPS device if it's available, then + GPU 0 if there is a GPU, and finally to the CPU. + prev_module_hook (`UserCpuOffloadHook`, *optional*): + The hook sent back by this function for a previous model in the pipeline you are running. If passed, its + offload method will be called just before the forward of the model to which this hook is attached. + + Example: + + ```py + model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device) + model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1) + model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2) + + hid_1 = model_1(input) + for i in range(50): + # model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop. + hid_2 = model_2(hid_1) + # model2 is offloaded to the CPU just before this forward. + hid_3 = model_3(hid_3) + + # For model3, you need to manually call the hook offload method. + hook_3.offload() + ``` + """ + hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook) + add_hook_to_module(model, hook, append=True) + user_hook = UserCpuOffloadHook(model, hook) + return model, user_hook + + +def disk_offload( + model: nn.Module, + offload_dir: Union[str, os.PathLike], + execution_device: Optional[torch.device] = None, + offload_buffers: bool = False, + preload_module_classes: Optional[List[str]] = None, +): + """ + Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as + memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and + put on the execution device passed as they are needed, then offloaded again. + + Args: + model (`torch.nn.Module`): The model to offload. + offload_dir (`str` or `os.PathLike`): + The folder in which to offload the model weights (or where the model weights are already offloaded). 
+ execution_device (`torch.device`, *optional*): + The device on which the forward pass of the model will be executed (should be a GPU). Will default to the + model's first parameter device. + offload_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to offload the buffers with the model parameters. + preload_module_classes (`List[str]`, *optional*): + A list of classes whose instances should load all their weights (even in the submodules) at the beginning + of the forward. This should only be used for classes that have submodules which are registered but not + called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, + `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + """ + if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")): + offload_state_dict(offload_dir, model.state_dict()) + if execution_device is None: + execution_device = next(iter(model.parameters())).device + weights_map = OffloadedWeightsLoader(save_folder=offload_dir) + + add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True) + attach_align_device_hook( + model, + execution_device=execution_device, + offload=True, + offload_buffers=offload_buffers, + weights_map=weights_map, + preload_module_classes=preload_module_classes, + ) + + return model + + +def dispatch_model( + model: nn.Module, + device_map: Dict[str, Union[str, int, torch.device]], + main_device: Optional[torch.device] = None, + state_dict: Optional[Dict[str, torch.Tensor]] = None, + offload_dir: Optional[Union[str, os.PathLike]] = None, + offload_index: Optional[Dict[str, str]] = None, + offload_buffers: bool = False, + skip_keys: Optional[Union[str, List[str]]] = None, + preload_module_classes: Optional[List[str]] = None, + force_hooks: bool = False, +): + """ + Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on + the CPU or even the disk. + + Args: + model (`torch.nn.Module`): + The model to dispatch. + device_map (`Dict[str, Union[str, int, torch.device]]`): + A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that + `"disk"` is accepted even if it's not a proper value for `torch.device`. + main_device (`str`, `int` or `torch.device`, *optional*): + The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or + `"disk"`. + state_dict (`Dict[str, torch.Tensor]`, *optional*): + The state dict of the part of the model that will be kept on CPU. + offload_dir (`str` or `os.PathLike`): + The folder in which to offload the model weights (or where the model weights are already offloaded). + offload_index (`Dict`, *optional*): + A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default + to the index saved in `save_folder`. + offload_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to offload the buffers with the model parameters. + skip_keys (`str` or `List[str]`, *optional*): + A list of keys to ignore when moving inputs or outputs between devices. + preload_module_classes (`List[str]`, *optional*): + A list of classes whose instances should load all their weights (even in the submodules) at the beginning + of the forward. 
This should only be used for classes that have submodules which are registered but not + called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, + `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + force_hooks (`bool`, *optional*, defaults to `False`): + Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a + single device. + """ + # Error early if the device map is incomplete. + check_device_map(model, device_map) + + # for backward compatibility + is_bnb_quantized = ( + getattr(model, "is_quantized", False) or getattr(model, "is_loaded_in_8bit", False) + ) and getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes" + + # We attach hooks if the device_map has at least 2 different devices or if + # force_hooks is set to `True`. Otherwise, the model in already loaded + # in the unique device and the user can decide where to dispatch the model. + # If the model is quantized, we always force-dispatch the model + if (len(set(device_map.values())) > 1) or is_bnb_quantized or force_hooks: + if main_device is None: + if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}: + main_device = "cpu" + else: + main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0] + + if main_device != "cpu": + cpu_modules = [name for name, device in device_map.items() if device == "cpu"] + if state_dict is None and len(cpu_modules) > 0: + state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules) + + disk_modules = [name for name, device in device_map.items() if device == "disk"] + if offload_dir is None and offload_index is None and len(disk_modules) > 0: + raise ValueError( + "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules " + f"need to be offloaded: {', '.join(disk_modules)}." + ) + if ( + len(disk_modules) > 0 + and offload_index is None + and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json"))) + ): + disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules) + offload_state_dict(offload_dir, disk_state_dict) + + execution_device = { + name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items() + } + execution_device[""] = main_device + offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"] + offload = {name: device in offloaded_devices for name, device in device_map.items()} + save_folder = offload_dir if len(disk_modules) > 0 else None + if state_dict is not None or save_folder is not None or offload_index is not None: + device = main_device if offload_index is not None else None + weights_map = OffloadedWeightsLoader( + state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device + ) + else: + weights_map = None + + # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the + # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its + # original pointer) on each devices. 
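+    # Illustration with hypothetical names: if an embedding weight is tied to the LM head,
+    # `find_tied_parameters` returns a group such as ["wte.weight", "lm_head.weight"]; the map
+    # built below then keys an initially empty per-device cache by that shared storage's
+    # data_ptr, which the alignment hooks fill lazily the first time each device needs the weight.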
+ tied_params = find_tied_parameters(model) + + tied_params_map = {} + for group in tied_params: + for param_name in group: + # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need + # to care about views of tensors through storage_offset. + data_ptr = recursive_getattr(model, param_name).data_ptr() + tied_params_map[data_ptr] = {} + + # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer, + # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer. + + attach_align_device_hook_on_blocks( + model, + execution_device=execution_device, + offload=offload, + offload_buffers=offload_buffers, + weights_map=weights_map, + skip_keys=skip_keys, + preload_module_classes=preload_module_classes, + tied_params_map=tied_params_map, + ) + + # warn if there is any params on the meta device + offloaded_devices_str = " and ".join( + [device for device in set(device_map.values()) if device in ("cpu", "disk")] + ) + if len(offloaded_devices_str) > 0: + logging.warning( + f"Some parameters are on the meta device device because they were offloaded to the {offloaded_devices_str}." + ) + + # Attaching the hook may break tied weights, so we retie them + retie_parameters(model, tied_params) + + # add warning to cuda and to method + def add_warning(fn, model): + @wraps(fn) + def wrapper(*args, **kwargs): + warning_msg = "You shouldn't move a model that is dispatched using accelerate hooks." + if str(fn.__name__) == "to": + to_device = torch._C._nn._parse_to(*args, **kwargs)[0] + if to_device is not None: + logger.warning(warning_msg) + else: + logger.warning(warning_msg) + for param in model.parameters(): + if param.device == torch.device("meta"): + raise RuntimeError("You can't move a model that has some modules offloaded to cpu or disk.") + return fn(*args, **kwargs) + + return wrapper + + model.to = add_warning(model.to, model) + if is_npu_available(): + model.npu = add_warning(model.npu, model) + elif is_mlu_available(): + model.mlu = add_warning(model.mlu, model) + elif is_xpu_available(): + model.xpu = add_warning(model.xpu, model) + else: + model.cuda = add_warning(model.cuda, model) + + # Check if we are using multi-gpus with RTX 4000 series + use_multi_gpu = len([device for device in set(device_map.values()) if device not in ("cpu", "disk")]) > 1 + if use_multi_gpu and not check_cuda_p2p_ib_support(): + logger.warning( + "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. " + "This can affect the multi-gpu inference when using accelerate device_map." + "Please make sure to update your driver to the latest version which resolves this." + ) + else: + device = list(device_map.values())[0] + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + if is_npu_available() and isinstance(device, int): + device = f"npu:{device}" + elif is_mlu_available() and isinstance(device, int): + device = f"mlu:{device}" + elif is_xpu_available() and isinstance(device, int): + device = f"xpu:{device}" + if device != "disk": + model.to(device) + else: + raise ValueError( + "You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead." 
+ ) + # Convert OrderedDict back to dict for easier usage + model.hf_device_map = dict(device_map) + return model + + +def load_checkpoint_and_dispatch( + model: nn.Module, + checkpoint: Union[str, os.PathLike], + device_map: Optional[Union[str, Dict[str, Union[int, str, torch.device]]]] = None, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[List[str]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = None, + offload_buffers: bool = False, + dtype: Optional[Union[str, torch.dtype]] = None, + offload_state_dict: Optional[bool] = None, + skip_keys: Optional[Union[str, List[str]]] = None, + preload_module_classes: Optional[List[str]] = None, + force_hooks: bool = False, + strict: bool = False, +): + """ + Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are + loaded and adds the various hooks that will make this model run properly (even if split across devices). + + Args: + model (`torch.nn.Module`): The model in which we want to load a checkpoint. + checkpoint (`str` or `os.PathLike`): + The folder checkpoint to load. It can be: + - a path to a file containing a whole model state dict + - a path to a `.json` file containing the index to a sharded checkpoint + - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint. + device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more + information about each option see [here](../concept_guides/big_model_inference#designing-a-device-map). + Defaults to None, which means [`dispatch_model`] will not be called. + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU + and the available CPU RAM if unset. + no_split_module_classes (`List[str]`, *optional*): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_buffers (`bool`, *optional*, defaults to `False`): + In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as + well as the parameters. + dtype (`str` or `torch.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + offload_state_dict (`bool`, *optional*): + If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if + the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map + picked contains `"disk"` values. + skip_keys (`str` or `List[str]`, *optional*): + A list of keys to ignore when moving inputs or outputs between devices. + preload_module_classes (`List[str]`, *optional*): + A list of classes whose instances should load all their weights (even in the submodules) at the beginning + of the forward. 
This should only be used for classes that have submodules which are registered but not + called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, + `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + force_hooks (`bool`, *optional*, defaults to `False`): + Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a + single device. + strict (`bool`, *optional*, defaults to `False`): + Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's + state_dict. + + Example: + + ```python + >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch + >>> from huggingface_hub import hf_hub_download + >>> from transformers import AutoConfig, AutoModelForCausalLM + + >>> # Download the Weights + >>> checkpoint = "EleutherAI/gpt-j-6B" + >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin") + + >>> # Create a model and initialize it with empty weights + >>> config = AutoConfig.from_pretrained(checkpoint) + >>> with init_empty_weights(): + ... model = AutoModelForCausalLM.from_config(config) + + >>> # Load the checkpoint and dispatch it to the right devices + >>> model = load_checkpoint_and_dispatch( + ... model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"] + ... ) + ``` + """ + if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + raise ValueError( + "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " + "'sequential'." + ) + if isinstance(device_map, str): + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=dtype, + low_zero=(device_map == "balanced_low_0"), + ) + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + dtype=dtype, + offload_buffers=offload_buffers, + ) + if offload_state_dict is None and device_map is not None and "disk" in device_map.values(): + offload_state_dict = True + load_checkpoint_in_model( + model, + checkpoint, + device_map=device_map, + offload_folder=offload_folder, + dtype=dtype, + offload_state_dict=offload_state_dict, + offload_buffers=offload_buffers, + strict=strict, + ) + if device_map is None: + return model + return dispatch_model( + model, + device_map=device_map, + offload_dir=offload_folder, + offload_buffers=offload_buffers, + skip_keys=skip_keys, + preload_module_classes=preload_module_classes, + force_hooks=force_hooks, + ) diff --git a/llm/Lib/site-packages/accelerate/checkpointing.py b/llm/Lib/site-packages/accelerate/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..307eca49d7c93dcbf450c3de5f5f356e11db1d51 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/checkpointing.py @@ -0,0 +1,275 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from pathlib import Path +from typing import List + +import numpy as np +import torch +from safetensors.torch import load_file +from torch.cuda.amp import GradScaler + +from .utils import ( + MODEL_NAME, + OPTIMIZER_NAME, + RNG_STATE_NAME, + SAFE_MODEL_NAME, + SAFE_WEIGHTS_NAME, + SAMPLER_NAME, + SCALER_NAME, + SCHEDULER_NAME, + WEIGHTS_NAME, + get_pretty_name, + is_torch_xla_available, + is_xpu_available, + save, +) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + +from .logging import get_logger +from .state import PartialState + + +logger = get_logger(__name__) + + +def save_accelerator_state( + output_dir: str, + model_states: List[dict], + optimizers: list, + schedulers: list, + dataloaders: list, + process_index: int, + scaler: GradScaler = None, + save_on_each_node: bool = False, + safe_serialization: bool = True, +): + """ + Saves the current states of the models, optimizers, scaler, and RNG generators to a given directory. + + + + If `safe_serialization` is `True`, models will be saved with `safetensors` while the rest are saved using native + `pickle`. + + + + Args: + output_dir (`str` or `os.PathLike`): + The name of the folder to save all relevant weights and states. + model_states (`List[torch.nn.Module]`): + A list of model states + optimizers (`List[torch.optim.Optimizer]`): + A list of optimizer instances + schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`): + A list of learning rate schedulers + dataloaders (`List[torch.utils.data.DataLoader]`): + A list of dataloader instances to save their sampler states + process_index (`int`): + The current process index in the Accelerator state + scaler (`torch.cuda.amp.GradScaler`, *optional*): + An optional gradient scaler instance to save + save_on_each_node (`bool`, *optional*): + Whether to save on every node, or only the main node. + safe_serialization (`bool`, *optional*, defaults to `True`): + Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`). 
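+
+    Example (an illustrative sketch only -- in practice `Accelerator.save_state` calls this for you;
+    `model`, `optimizer`, `scheduler` and `train_dataloader` are placeholder objects from your own
+    training script, and the output directory is assumed to already exist):
+
+    ```python
+    >>> from accelerate.checkpointing import save_accelerator_state
+
+    >>> save_accelerator_state(
+    ...     output_dir="checkpoints/step_100",
+    ...     model_states=[model.state_dict()],
+    ...     optimizers=[optimizer],
+    ...     schedulers=[scheduler],
+    ...     dataloaders=[train_dataloader],
+    ...     process_index=0,
+    ... )
+    ```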
+ """ + output_dir = Path(output_dir) + # Model states + for i, state in enumerate(model_states): + weights_name = WEIGHTS_NAME if not safe_serialization else SAFE_WEIGHTS_NAME + if i > 0: + weights_name = weights_name.replace(".", f"_{i}.") + output_model_file = output_dir.joinpath(weights_name) + save(state, output_model_file, save_on_each_node=save_on_each_node, safe_serialization=safe_serialization) + logger.info(f"Model weights saved in {output_model_file}") + # Optimizer states + for i, opt in enumerate(optimizers): + state = opt.state_dict() + optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin" + output_optimizer_file = output_dir.joinpath(optimizer_name) + save(state, output_optimizer_file, save_on_each_node=save_on_each_node, safe_serialization=False) + logger.info(f"Optimizer state saved in {output_optimizer_file}") + # Scheduler states + for i, scheduler in enumerate(schedulers): + state = scheduler.state_dict() + scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin" + output_scheduler_file = output_dir.joinpath(scheduler_name) + save(state, output_scheduler_file, save_on_each_node=save_on_each_node, safe_serialization=False) + logger.info(f"Scheduler state saved in {output_scheduler_file}") + # DataLoader states + for i, dataloader in enumerate(dataloaders): + sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin" + output_sampler_file = output_dir.joinpath(sampler_name) + # Only save if we have our custom sampler + from .data_loader import IterableDatasetShard, SeedableRandomSampler + + if isinstance(dataloader.dataset, IterableDatasetShard): + sampler = dataloader.sampler.sampler + + if isinstance(sampler, SeedableRandomSampler): + save(sampler, output_sampler_file, save_on_each_node=save_on_each_node, safe_serialization=False) + logger.info(f"Sampler state for dataloader {i} saved in {output_sampler_file}") + + # GradScaler state + if scaler is not None: + state = scaler.state_dict() + output_scaler_file = output_dir.joinpath(SCALER_NAME) + torch.save(state, output_scaler_file) + logger.info(f"Gradient scaler state saved in {output_scaler_file}") + # Random number generator states + states = {} + states_name = f"{RNG_STATE_NAME}_{process_index}.pkl" + states["random_state"] = random.getstate() + states["numpy_random_seed"] = np.random.get_state() + states["torch_manual_seed"] = torch.get_rng_state() + if is_xpu_available(): + states["torch_xpu_manual_seed"] = torch.xpu.get_rng_state_all() + else: + states["torch_cuda_manual_seed"] = torch.cuda.get_rng_state_all() + if is_torch_xla_available(): + states["xm_seed"] = xm.get_rng_state() + output_states_file = output_dir.joinpath(states_name) + torch.save(states, output_states_file) + logger.info(f"Random states saved in {output_states_file}") + return output_dir + + +def load_accelerator_state( + input_dir, + models, + optimizers, + schedulers, + dataloaders, + process_index, + scaler=None, + map_location=None, + **load_model_func_kwargs, +): + """ + Loads states of the models, optimizers, scaler, and RNG generators from a given directory. + + Args: + input_dir (`str` or `os.PathLike`): + The name of the folder to load all relevant weights and states. 
+ models (`List[torch.nn.Module]`): + A list of model instances + optimizers (`List[torch.optim.Optimizer]`): + A list of optimizer instances + schedulers (`List[torch.optim.lr_scheduler._LRScheduler]`): + A list of learning rate schedulers + process_index (`int`): + The current process index in the Accelerator state + scaler (`torch.cuda.amp.GradScaler`, *optional*): + An optional *GradScaler* instance to load + map_location (`str`, *optional*): + What device to load the optimizer state onto. Should be one of either "cpu" or "on_device". + load_model_func_kwargs (`dict`, *optional*): + Additional arguments that can be passed to the model's `load_state_dict` method. + """ + if map_location not in [None, "cpu", "on_device"]: + raise TypeError( + "Unsupported optimizer map location passed, please choose one of `None`, `'cpu'`, or `'on_device'`" + ) + if map_location is None: + map_location = "cpu" + elif map_location == "on_device": + map_location = PartialState().device + + input_dir = Path(input_dir) + # Model states + for i, model in enumerate(models): + ending = f"_{i}" if i > 0 else "" + input_model_file = input_dir.joinpath(f"{SAFE_MODEL_NAME}{ending}.safetensors") + if input_model_file.exists(): + state_dict = load_file(input_model_file, device=str(map_location)) + else: + # Load with torch + input_model_file = input_dir.joinpath(f"{MODEL_NAME}{ending}.bin") + state_dict = torch.load(input_model_file, map_location=map_location) + models[i].load_state_dict(state_dict, **load_model_func_kwargs) + logger.info("All model weights loaded successfully") + + # Optimizer states + for i, opt in enumerate(optimizers): + optimizer_name = f"{OPTIMIZER_NAME}.bin" if i == 0 else f"{OPTIMIZER_NAME}_{i}.bin" + input_optimizer_file = input_dir.joinpath(optimizer_name) + optimizer_state = torch.load(input_optimizer_file, map_location=map_location) + optimizers[i].load_state_dict(optimizer_state) + logger.info("All optimizer states loaded successfully") + + # Scheduler states + for i, scheduler in enumerate(schedulers): + scheduler_name = f"{SCHEDULER_NAME}.bin" if i == 0 else f"{SCHEDULER_NAME}_{i}.bin" + input_scheduler_file = input_dir.joinpath(scheduler_name) + scheduler.load_state_dict(torch.load(input_scheduler_file)) + logger.info("All scheduler states loaded successfully") + + for i, dataloader in enumerate(dataloaders): + sampler_name = f"{SAMPLER_NAME}.bin" if i == 0 else f"{SAMPLER_NAME}_{i}.bin" + input_sampler_file = input_dir.joinpath(sampler_name) + # Only load if we have our custom sampler + from .data_loader import IterableDatasetShard, SeedableRandomSampler + + if isinstance(dataloader.dataset, IterableDatasetShard): + sampler = dataloader.sampler.sampler + + if isinstance(sampler, SeedableRandomSampler): + dataloader.sampler.sampler = torch.load(input_sampler_file) + logger.info("All dataloader sampler states loaded successfully") + + # GradScaler state + if scaler is not None: + input_scaler_file = input_dir.joinpath(SCALER_NAME) + scaler.load_state_dict(torch.load(input_scaler_file)) + logger.info("GradScaler state loaded successfully") + + # Random states + try: + states = torch.load(input_dir.joinpath(f"{RNG_STATE_NAME}_{process_index}.pkl")) + random.setstate(states["random_state"]) + np.random.set_state(states["numpy_random_seed"]) + torch.set_rng_state(states["torch_manual_seed"]) + if is_xpu_available(): + torch.xpu.set_rng_state_all(states["torch_xpu_manual_seed"]) + else: + torch.cuda.set_rng_state_all(states["torch_cuda_manual_seed"]) + if is_torch_xla_available(): + 
xm.set_rng_state(states["xm_seed"]) + logger.info("All random states loaded successfully") + except Exception: + logger.info("Could not load random states") + + +def save_custom_state(obj, path, index: int = 0, save_on_each_node: bool = False): + """ + Saves the state of `obj` to `{path}/custom_checkpoint_{index}.pkl` + """ + # Should this be the right way to get a qual_name type value from `obj`? + save_location = Path(path) / f"custom_checkpoint_{index}.pkl" + logger.info(f"Saving the state of {get_pretty_name(obj)} to {save_location}") + save(obj.state_dict(), save_location, save_on_each_node=save_on_each_node) + + +def load_custom_state(obj, path, index: int = 0): + """ + Loads the state of `obj` at `{path}/custom_checkpoint_{index}.pkl` + """ + load_location = f"{path}/custom_checkpoint_{index}.pkl" + logger.info(f"Loading the state of {get_pretty_name(obj)} from {load_location}") + obj.load_state_dict(torch.load(load_location, map_location="cpu")) diff --git a/llm/Lib/site-packages/accelerate/commands/__init__.py b/llm/Lib/site-packages/accelerate/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cbe26c257b515f657c05e1996d517e69613972 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
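The checkpointing helpers above (`save_accelerator_state`, `load_accelerator_state`, `save_custom_state`, `load_custom_state`) are normally driven through the `Accelerator` API rather than called directly. A minimal sketch of that flow, assuming a training script where `model`, `optimizer` and `train_dataloader` are your own objects and `lr_tracker` is any placeholder object exposing `state_dict()`/`load_state_dict()`:

```python
>>> from accelerate import Accelerator

>>> accelerator = Accelerator()
>>> model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
>>> accelerator.register_for_checkpointing(lr_tracker)  # round-tripped via save_custom_state / load_custom_state

>>> accelerator.save_state("checkpoints/step_100")  # model, optimizer, sampler and RNG states
>>> accelerator.load_state("checkpoints/step_100")  # restores them, including the registered lr_tracker
```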
diff --git a/llm/Lib/site-packages/accelerate/commands/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16f1dd37ea9ae45a807168ac50f9b5a99628e94d Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20e2f0736df3a79789a6de262213ceb64389273a Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/__pycache__/accelerate_cli.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/__pycache__/env.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/__pycache__/env.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a17b281bee53bb5a87195cb2c6b81a1e11d6b86 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/__pycache__/env.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/__pycache__/estimate.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/__pycache__/estimate.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63bb96c4f3df3012ca168b6f2450df84cd6bcd35 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/__pycache__/estimate.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/__pycache__/launch.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/__pycache__/launch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad3e6d6061a8d0ced0cef3629fe33a44b4dcf76e Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/__pycache__/launch.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/__pycache__/test.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/__pycache__/test.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..186ea10d6a3fe569bb52d21fbfa66ae628c57cc8 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/__pycache__/test.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/__pycache__/tpu.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/__pycache__/tpu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cddd39cb881c428353b457d4861b78905b8d0817 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/__pycache__/tpu.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/__pycache__/utils.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..369d72714263c32eca965e9c0b28a1bfcaa57777 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/__pycache__/utils.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/accelerate_cli.py b/llm/Lib/site-packages/accelerate/commands/accelerate_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..ea5a08abf51a83ca048524ea0b8758f9d52b7edc --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/accelerate_cli.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from accelerate.commands.config import get_config_parser +from accelerate.commands.env import env_command_parser +from accelerate.commands.estimate import estimate_command_parser +from accelerate.commands.launch import launch_command_parser +from accelerate.commands.test import test_command_parser +from accelerate.commands.tpu import tpu_command_parser +from accelerate.commands.utils import CustomArgumentParser + + +def main(): + parser = CustomArgumentParser("Accelerate CLI tool", usage="accelerate []", allow_abbrev=False) + subparsers = parser.add_subparsers(help="accelerate command helpers") + + # Register commands + get_config_parser(subparsers=subparsers) + estimate_command_parser(subparsers=subparsers) + env_command_parser(subparsers=subparsers) + launch_command_parser(subparsers=subparsers) + tpu_command_parser(subparsers=subparsers) + test_command_parser(subparsers=subparsers) + + # Let's go + args = parser.parse_args() + + if not hasattr(args, "func"): + parser.print_help() + exit(1) + + # Run + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/commands/config/__init__.py b/llm/Lib/site-packages/accelerate/commands/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..649a15888cccd070b3d4ca9a600457c6ad59d4d3 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/config/__init__.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
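+
+# Illustrative overview of the subcommands wired up below (as registered through `get_config_parser`):
+#   accelerate config           -> config_command_parser (interactive questionnaire)
+#   accelerate config default   -> default_command_parser (write a default config file without prompts)
+#   accelerate config update    -> update_command_parser  (refresh an existing config file with current defaults)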
+ +import argparse + +from .config import config_command_parser +from .config_args import default_config_file, load_config_from_file # noqa: F401 +from .default import default_command_parser +from .update import update_command_parser + + +def get_config_parser(subparsers=None): + parent_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + # The main config parser + config_parser = config_command_parser(subparsers) + # The subparser to add commands to + subcommands = config_parser.add_subparsers(title="subcommands", dest="subcommand") + + # Then add other parsers with the parent parser + default_command_parser(subcommands, parents=[parent_parser]) + update_command_parser(subcommands, parents=[parent_parser]) + + return config_parser + + +def main(): + config_parser = get_config_parser() + args = config_parser.parse_args() + + if not hasattr(args, "func"): + config_parser.print_help() + exit(1) + + # Run + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/commands/config/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c6fabb342858461f732a5716456b16dec63fc2f Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2472fe6a8424da9421a523772889851c677723e0 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/cluster.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..950488de2ef705da91a1f6cd562338dc3ff66c6b Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..688a26c2f6b1e5ed494de2f03f67a0cc2acdb4a7 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config_args.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1972ef2332460ee2ee3197fcb6941c5cc54c6f3 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/config_utils.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/config/__pycache__/default.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/default.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da3ab55a3e2e9a5f4b866194ae171c7b6d81d22e Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/default.cpython-311.pyc differ diff --git 
a/llm/Lib/site-packages/accelerate/commands/config/__pycache__/sagemaker.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/sagemaker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5c166519103ce1b7c178152205eec4363da8025 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/sagemaker.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/config/__pycache__/update.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/update.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6ff6a2d7287a92c9a7b0eb014f2fc4606339c94 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/config/__pycache__/update.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/config/cluster.py b/llm/Lib/site-packages/accelerate/commands/config/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..4d8821b0660932a946dc238151b2c6599de625d1 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/config/cluster.py @@ -0,0 +1,705 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from ...utils import ( + ComputeEnvironment, + DistributedType, + is_deepspeed_available, + is_mlu_available, + is_mps_available, + is_npu_available, + is_transformers_available, + is_xpu_available, +) +from ...utils.constants import ( + DEEPSPEED_MULTINODE_LAUNCHERS, + FSDP_AUTO_WRAP_POLICY, + FSDP_BACKWARD_PREFETCH, + FSDP_SHARDING_STRATEGY, + FSDP_STATE_DICT_TYPE, + TORCH_DYNAMO_MODES, +) +from .config_args import ClusterConfig +from .config_utils import ( + DYNAMO_BACKENDS, + _ask_field, + _ask_options, + _convert_distributed_mode, + _convert_dynamo_backend, + _convert_mixed_precision, + _convert_yes_no_to_bool, +) + + +def get_cluster_input(): + distributed_type = _ask_options( + "Which type of machine are you using?", + ["No distributed training", "multi-CPU", "multi-XPU", "multi-GPU", "multi-NPU", "multi-MLU", "TPU"], + _convert_distributed_mode, + ) + + machine_rank = 0 + num_machines = 1 + num_processes = 1 + gpu_ids = None + main_process_ip = None + main_process_port = None + rdzv_backend = "static" + same_network = True + debug = False + + if distributed_type in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_CPU, + ]: + num_machines = _ask_field( + "How many different machines will you use (use more than 1 for multi-node training)? [1]: ", + int, + default=1, + ) + if num_machines > 1: + machine_rank = _ask_options( + "What is the rank of this machine?", + list(range(num_machines)), + int, + ) + main_process_ip = _ask_field( + "What is the IP address of the machine that will host the main process? 
", + ) + main_process_port = _ask_field( + "What is the port you will use to communicate with the main process? ", + int, + ) + same_network = _ask_field( + "Are all the machines on the same local network? Answer `no` if nodes are on the cloud and/or on different network hosts [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + if not same_network: + rdzv_backend = _ask_field( + "What rendezvous backend will you use? ('static', 'c10d', ...): ", default="static" + ) + debug = _ask_field( + "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + if distributed_type == DistributedType.NO: + use_cpu = _ask_field( + "Do you want to run your training on CPU only (even if a GPU / Apple Silicon / Ascend NPU device is available)? [yes/NO]:", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + elif distributed_type == DistributedType.MULTI_CPU: + use_cpu = True + else: + use_cpu = False + + ipex_config = {} + mpirun_config = {} + if use_cpu: + ipex_config["ipex"] = _ask_field( + "Do you want to use Intel PyTorch Extension (IPEX) to speed up training on CPU? [yes/NO]:", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if distributed_type == DistributedType.MULTI_CPU: + use_mpirun = _ask_field( + "Do you want accelerate to launch mpirun? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_mpirun: + mpirun_hostfile = _ask_field( + "Please enter the path to the hostfile to use with mpirun [~/hostfile]: ", + str, + default="~/hostfile", + ) + mpirun_config["mpirun_hostfile"] = os.path.expanduser(mpirun_hostfile.strip()) + mpirun_config["mpirun_ccl"] = _ask_field("Enter the number of oneCCL worker threads [1]: ", default=1) + if ( + not use_cpu + and is_xpu_available() + and distributed_type + not in [DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_MLU, DistributedType.XLA] + ): + ipex_config["use_xpu"] = _ask_field( + "Do you want to use XPU plugin to speed up training on XPU? [yes/NO]:", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + dynamo_config = {} + use_dynamo = _ask_field( + "Do you wish to optimize your script with torch dynamo?[yes/NO]:", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_dynamo: + prefix = "dynamo_" + dynamo_config[prefix + "backend"] = _ask_options( + "Which dynamo backend would you like to use?", + [x.lower() for x in DYNAMO_BACKENDS], + _convert_dynamo_backend, + default=2, + ) + use_custom_options = _ask_field( + "Do you want to customize the defaults sent to torch.compile? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + if use_custom_options: + dynamo_config[prefix + "mode"] = _ask_options( + "Which mode do you want to use?", + TORCH_DYNAMO_MODES, + lambda x: TORCH_DYNAMO_MODES[int(x)], + default=0, + ) + dynamo_config[prefix + "use_fullgraph"] = _ask_field( + "Do you want the fullgraph mode or it is ok to break model into several subgraphs? 
[yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + dynamo_config[prefix + "use_dynamic"] = _ask_field( + "Do you want to enable dynamic shape tracing? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + use_mps = not use_cpu and is_mps_available() + deepspeed_config = {} + if ( + distributed_type + in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.NO, + ] + and not use_mps + ): + use_deepspeed = _ask_field( + "Do you want to use DeepSpeed? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_deepspeed: + distributed_type = DistributedType.DEEPSPEED + assert ( + is_deepspeed_available() + ), "DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source" + + if distributed_type == DistributedType.DEEPSPEED: + use_deepspeed_config = _ask_field( + "Do you want to specify a json file to a DeepSpeed config? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_deepspeed_config: + deepspeed_config["deepspeed_config_file"] = _ask_field( + "Please enter the path to the json DeepSpeed config file: ", + str, + default="none", + ) + else: + deepspeed_config["zero_stage"] = _ask_options( + "What should be your DeepSpeed's ZeRO optimization stage?", + [0, 1, 2, 3], + int, + default=2, + ) + + deepspeed_devices = ["none", "cpu", "nvme"] + if deepspeed_config["zero_stage"] >= 2: + deepspeed_config["offload_optimizer_device"] = _ask_options( + "Where to offload optimizer states?", deepspeed_devices, lambda x: deepspeed_devices[int(x)] + ) + deepspeed_config["offload_param_device"] = _ask_options( + "Where to offload parameters?", deepspeed_devices, lambda x: deepspeed_devices[int(x)] + ) + if deepspeed_config["offload_param_device"] == "nvme": + deepspeed_config["offload_param_nvme_path"] = _ask_field( + "Nvme Path to offload parameters?", + str, + default="/nvme", + ) + if deepspeed_config["offload_optimizer_device"] == "nvme": + deepspeed_config["offload_optimizer_nvme_path"] = _ask_field( + "Nvme Path to offload optimizer states?", + str, + default="/nvme", + ) + deepspeed_config["gradient_accumulation_steps"] = _ask_field( + "How many gradient accumulation steps you're passing in your script? [1]: ", + int, + default=1, + ) + use_gradient_clipping = _ask_field( + "Do you want to use gradient clipping? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_gradient_clipping: + deepspeed_config["gradient_clipping"] = _ask_field( + "What is the gradient clipping value? [1.0]: ", + float, + default=1.0, + ) + if deepspeed_config["zero_stage"] == 3: + deepspeed_config["zero3_save_16bit_model"] = _ask_field( + "Do you want to save 16-bit model weights when using ZeRO Stage-3? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + deepspeed_config["zero3_init_flag"] = _ask_field( + "Do you want to enable `deepspeed.zero.Init` when using ZeRO Stage-3 for constructing massive models? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if deepspeed_config["zero3_init_flag"]: + if not is_transformers_available(): + raise Exception( + "When `zero3_init_flag` is set, it requires Transformers to be installed. 
" + "Please run `pip3 install transformers`." + ) + + if num_machines > 1: + launcher_query = "Which Type of launcher do you want to use?" + deepspeed_config["deepspeed_multinode_launcher"] = _ask_options( + launcher_query, + DEEPSPEED_MULTINODE_LAUNCHERS, + lambda x: DEEPSPEED_MULTINODE_LAUNCHERS[int(x)], + ) + + if deepspeed_config["deepspeed_multinode_launcher"] != DEEPSPEED_MULTINODE_LAUNCHERS[1]: + deepspeed_config["deepspeed_hostfile"] = _ask_field( + "DeepSpeed configures multi-node compute resources with hostfile. " + "Each row is of the format `hostname slots=[num_gpus]`, e.g., `localhost slots=2`; " + "for more information please refer official [documentation]" + "(https://www.deepspeed.ai/getting-started/#resource-configuration-multi-node). " + "Please specify the location of hostfile: ", + str, + ) + + is_exclusion_filter = _ask_field( + "Do you want to specify exclusion filter string? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if is_exclusion_filter: + deepspeed_config["deepspeed_exclusion_filter"] = _ask_field( + "DeepSpeed exclusion filter string: ", + str, + ) + + is_inclusion_filter = _ask_field( + "Do you want to specify inclusion filter string? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if is_inclusion_filter: + deepspeed_config["deepspeed_inclusion_filter"] = _ask_field( + "DeepSpeed inclusion filter string: ", + str, + ) + + fsdp_config = {} + if distributed_type in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_XPU, + ]: + use_fsdp = _ask_field( + "Do you want to use FullyShardedDataParallel? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_fsdp: + distributed_type = DistributedType.FSDP + if distributed_type == DistributedType.FSDP: + sharding_strategy_query = "What should be your sharding strategy?" + fsdp_config["fsdp_sharding_strategy"] = _ask_options( + sharding_strategy_query, + FSDP_SHARDING_STRATEGY, + lambda x: FSDP_SHARDING_STRATEGY[int(x)], + ) + fsdp_config["fsdp_offload_params"] = _ask_field( + "Do you want to offload parameters and gradients to CPU? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + fsdp_wrap_query = "What should be your auto wrap policy?" + fsdp_config["fsdp_auto_wrap_policy"] = _ask_options( + fsdp_wrap_query, + FSDP_AUTO_WRAP_POLICY, + lambda x: FSDP_AUTO_WRAP_POLICY[int(x)], + ) + if fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[0]: + use_no_split_modules = _ask_field( + "Do you want to use the model's `_no_split_modules` to wrap. Only applicable for 🤗 Transformers [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if not use_no_split_modules: + fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = _ask_field( + "Specify the comma-separated list of transformer layer class names (case-sensitive) to wrap ,e.g, :" + "`BertLayer`, `GPTJBlock`, `T5Block`, `BertLayer,BertEmbeddings,BertSelfOutput` ...? : ", + str, + ) + elif fsdp_config["fsdp_auto_wrap_policy"] == FSDP_AUTO_WRAP_POLICY[1]: + fsdp_config["fsdp_min_num_params"] = _ask_field( + "What should be your FSDP's minimum number of parameters for Default Auto Wrapping Policy? [1e8]: ", + int, + default=100000000, + ) + fsdp_backward_prefetch_query = "What should be your FSDP's backward prefetch policy?" 
+ fsdp_config["fsdp_backward_prefetch"] = _ask_options( + fsdp_backward_prefetch_query, + FSDP_BACKWARD_PREFETCH, + lambda x: FSDP_BACKWARD_PREFETCH[int(x)], + ) + fsdp_state_dict_type_query = "What should be your FSDP's state dict type?" + fsdp_config["fsdp_state_dict_type"] = _ask_options( + fsdp_state_dict_type_query, + FSDP_STATE_DICT_TYPE, + lambda x: FSDP_STATE_DICT_TYPE[int(x)], + default=2, + ) + fsdp_config["fsdp_forward_prefetch"] = _ask_field( + "Do you want to enable FSDP's forward prefetch policy? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + fsdp_config["fsdp_use_orig_params"] = _ask_field( + "Do you want to enable FSDP's `use_orig_params` feature? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + fsdp_config["fsdp_cpu_ram_efficient_loading"] = _ask_field( + "Do you want to enable CPU RAM efficient model loading? Only applicable for 🤗 Transformers models. [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + if fsdp_config["fsdp_cpu_ram_efficient_loading"]: + fsdp_config["fsdp_sync_module_states"] = True + else: + fsdp_config["fsdp_sync_module_states"] = _ask_field( + "Do you want each individually wrapped FSDP unit to broadcast module parameters from rank 0 at the start? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + + megatron_lm_config = {} + if distributed_type in [DistributedType.MULTI_GPU]: + use_megatron_lm = _ask_field( + "Do you want to use Megatron-LM ? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_megatron_lm: + distributed_type = DistributedType.MEGATRON_LM + if distributed_type == DistributedType.MEGATRON_LM: + prefix = "megatron_lm_" + megatron_lm_config[prefix + "tp_degree"] = _ask_field( + "What is the Tensor Parallelism degree/size? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + if megatron_lm_config[prefix + "tp_degree"] > 1: + megatron_lm_config[prefix + "sequence_parallelism"] = _ask_field( + "Do you want to enable Sequence Parallelism? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + + megatron_lm_config[prefix + "pp_degree"] = _ask_field( + "What is the Pipeline Parallelism degree/size? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + if megatron_lm_config[prefix + "pp_degree"] > 1: + megatron_lm_config[prefix + "num_micro_batches"] = _ask_field( + "What is the number of micro-batches? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + + megatron_lm_config[prefix + "recompute_activations"] = _ask_field( + "Do you want to enable selective activation recomputation? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + + megatron_lm_config[prefix + "use_distributed_optimizer"] = _ask_field( + "Do you want to use distributed optimizer " + "which shards optimizer state and gradients across data parallel ranks? [YES/no]: ", + _convert_yes_no_to_bool, + default=True, + error_message="Please enter yes or no.", + ) + + megatron_lm_config[prefix + "gradient_clipping"] = _ask_field( + "What is the gradient clipping value based on global L2 Norm (0 to disable)? 
[1.0]: ", + float, + default=1.0, + ) + # TPU specific defaults + tpu_commands = None + tpu_command_file = None + tpu_downcast_bf16 = "no" + tpu_env = [] + tpu_name = None + tpu_vm = None + tpu_zone = None + tpu_use_sudo = False + tpu_use_cluster = False + + if distributed_type in [ + DistributedType.MULTI_CPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + DistributedType.XLA, + ]: + machine_type = str(distributed_type).split(".")[1].replace("MULTI_", "") + if machine_type == "TPU": + machine_type += " cores" + elif machine_type == "CPU": + machine_type = "processes" + else: + machine_type += "(s)" + num_processes = _ask_field( + f"How many {machine_type} should be used for distributed training? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + elif distributed_type in [DistributedType.FSDP, DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]: + num_processes = _ask_field( + "How many GPU(s) should be used for distributed training? [1]:", + int, + default=1, + error_message="Please enter an integer.", + ) + else: + num_processes = 1 + + if (distributed_type == DistributedType.MULTI_GPU) and (num_machines == 1) and (num_processes == 1): + raise ValueError( + f"Specified distributed type {distributed_type} but only using 1 GPU on a single machine. Please select `No distributed training` for the type of machine you are using." + ) + + if ( + distributed_type + in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.NO, + ] + and not use_cpu + and not use_mps + ): + if is_npu_available(): + machine_type = "NPU(s)" + elif is_mlu_available(): + machine_type = "MLU(s)" + else: + machine_type = "GPU(s)" + gpu_ids = _ask_field( + f"What {machine_type} (by id) should be used for training on this machine as a comma-seperated list? [all]:", + default="all", + ) + + # CPU affinity is only supported on NVIDIA hardware for now + enable_cpu_affinity = False + if distributed_type == (DistributedType.NO, DistributedType.MULTI_GPU) and not use_cpu and not use_mps: + enable_cpu_affinity = _ask_field( + "Would you like to enable numa efficiency? (Currently only supported on NVIDIA hardware). [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + if distributed_type == DistributedType.XLA: + mixed_precision = "no" + main_training_function = _ask_field( + "What is the name of the function in your script that should be launched in all parallel scripts? [main]: ", + default="main", + ) + tpu_use_cluster = _ask_field( + "Are you using a TPU cluster? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if tpu_use_cluster: + tpu_name = _ask_field( + "What is the name of your TPU cluster? ", + default=None, + error_message="Please enter the name of your TPU cluster.", + ) + tpu_zone = _ask_field( + "What is the zone of your TPU cluster? ", + default=None, + error_message="Please enter the zone of your TPU cluster.", + ) + tpu_use_sudo = _ask_field( + "To run a python script in a TPU pod, should `sudo` be used? [yes/NO]: ", + default=False, + error_message="Please enter yes or no.", + ) + run_commands = _ask_field( + "Do you have code you wish to run on startup in each pod? 
[yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if run_commands: + use_command_file = _ask_field( + "Is this code located in a bash script? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_command_file: + tpu_command_file = _ask_field( + "What is the path to your bash script? ", + default=None, + error_message="Please enter the path to your bash script.", + ) + tpu_command_file = os.path.abspath(tpu_command_file) + else: + print("Please enter each command seperately you wish to run on startup in each pod.") + tpu_commands = [] + another_command = True + while another_command: + tpu_commands.append( + _ask_field( + "Please enter a single command to be ran ", + default=None, + error_message="Please enter the commands you wish to run on startup in each pod as a single string.", + ) + ) + another_command = _ask_field( + "Do you wish to add another command? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + tpu_vm = _ask_field( + "If not using an instance group, what are the names of the Compute VM instances to be used, seperated by a comma: ", + default="", + ).split(",") + tpu_env = _ask_field( + "What environment variables do you wish to set in each pod, seperated by a comma: ", + default="", + ).split(",") + + else: + main_training_function = "main" + if distributed_type == DistributedType.DEEPSPEED and use_deepspeed_config: + mixed_precision = None + else: + mixed_precision = _ask_options( + "Do you wish to use FP16 or BF16 (mixed precision)?", + ["no", "fp16", "bf16", "fp8"], + _convert_mixed_precision, + ) + + if use_dynamo and mixed_precision == "no" and not use_cpu: + print( + "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts." + ) + + if distributed_type == DistributedType.XLA and mixed_precision == "bf16": + tpu_downcast_bf16 = _ask_field( + "Should `torch.float` be cast as `bfloat16` and `torch.double` remain `float32` on TPUs?", default="no" + ) + + return ClusterConfig( + compute_environment=ComputeEnvironment.LOCAL_MACHINE, + distributed_type=distributed_type, + num_processes=num_processes, + gpu_ids=gpu_ids, + mixed_precision=mixed_precision, + downcast_bf16=tpu_downcast_bf16, + machine_rank=machine_rank, + num_machines=num_machines, + main_process_ip=main_process_ip, + main_process_port=main_process_port, + main_training_function=main_training_function, + deepspeed_config=deepspeed_config, + fsdp_config=fsdp_config, + megatron_lm_config=megatron_lm_config, + ipex_config=ipex_config, + mpirun_config=mpirun_config, + use_cpu=use_cpu, + rdzv_backend=rdzv_backend, + same_network=same_network, + commands=tpu_commands, + command_file=tpu_command_file, + tpu_env=tpu_env, + tpu_name=tpu_name, + tpu_vm=tpu_vm, + tpu_zone=tpu_zone, + tpu_use_sudo=tpu_use_sudo, + tpu_use_cluster=tpu_use_cluster, + dynamo_config=dynamo_config, + debug=debug, + enable_cpu_affinity=enable_cpu_affinity, + ) diff --git a/llm/Lib/site-packages/accelerate/commands/config/config.py b/llm/Lib/site-packages/accelerate/commands/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..72414f2abe62d76bd5133f4b0ed99bf34133f6f6 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/config/config.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os + +from accelerate.utils import ComputeEnvironment + +from .cluster import get_cluster_input +from .config_args import cache_dir, default_config_file, default_yaml_config_file, load_config_from_file # noqa: F401 +from .config_utils import _ask_field, _ask_options, _convert_compute_environment # noqa: F401 +from .sagemaker import get_sagemaker_input + + +description = "Launches a series of prompts to create and save a `default_config.yaml` configuration file for your training system. Should always be ran first on your machine" + + +def get_user_input(): + compute_environment = _ask_options( + "In which compute environment are you running?", + ["This machine", "AWS (Amazon SageMaker)"], + _convert_compute_environment, + ) + if compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: + config = get_sagemaker_input() + else: + config = get_cluster_input() + return config + + +def config_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("config", description=description) + else: + parser = argparse.ArgumentParser("Accelerate config command", description=description) + + parser.add_argument( + "--config_file", + default=None, + help=( + "The path to use to store the config file. Will default to a file named default_config.yaml in the cache " + "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have " + "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed " + "with 'huggingface'." + ), + ) + + if subparsers is not None: + parser.set_defaults(func=config_command) + return parser + + +def config_command(args): + config = get_user_input() + if args.config_file is not None: + config_file = args.config_file + else: + if not os.path.isdir(cache_dir): + os.makedirs(cache_dir) + config_file = default_yaml_config_file + + if config_file.endswith(".json"): + config.to_json_file(config_file) + else: + config.to_yaml_file(config_file) + print(f"accelerate configuration saved at {config_file}") + + +def main(): + parser = config_command_parser() + args = parser.parse_args() + config_command(args) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/commands/config/config_args.py b/llm/Lib/site-packages/accelerate/commands/config/config_args.py new file mode 100644 index 0000000000000000000000000000000000000000..c50f1c34a42d354903a80b506290958807a7b7c0 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/config/config_args.py @@ -0,0 +1,243 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional, Union + +import yaml + +from ...utils import ComputeEnvironment, DistributedType, SageMakerDistributedType +from ...utils.constants import SAGEMAKER_PYTHON_VERSION, SAGEMAKER_PYTORCH_VERSION, SAGEMAKER_TRANSFORMERS_VERSION + + +hf_cache_home = os.path.expanduser( + os.environ.get("HF_HOME", os.path.join(os.environ.get("XDG_CACHE_HOME", "~/.cache"), "huggingface")) +) +cache_dir = os.path.join(hf_cache_home, "accelerate") +default_json_config_file = os.path.join(cache_dir, "default_config.yaml") +default_yaml_config_file = os.path.join(cache_dir, "default_config.yaml") + +# For backward compatibility: the default config is the json one if it's the only existing file. +if os.path.isfile(default_yaml_config_file) or not os.path.isfile(default_json_config_file): + default_config_file = default_yaml_config_file +else: + default_config_file = default_json_config_file + + +def load_config_from_file(config_file): + if config_file is not None: + if not os.path.isfile(config_file): + raise FileNotFoundError( + f"The passed configuration file `{config_file}` does not exist. " + "Please pass an existing file to `accelerate launch`, or use the default one " + "created through `accelerate config` and run `accelerate launch` " + "without the `--config_file` argument." + ) + else: + config_file = default_config_file + with open(config_file, encoding="utf-8") as f: + if config_file.endswith(".json"): + if ( + json.load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE) + == ComputeEnvironment.LOCAL_MACHINE + ): + config_class = ClusterConfig + else: + config_class = SageMakerConfig + return config_class.from_json_file(json_file=config_file) + else: + if ( + yaml.safe_load(f).get("compute_environment", ComputeEnvironment.LOCAL_MACHINE) + == ComputeEnvironment.LOCAL_MACHINE + ): + config_class = ClusterConfig + else: + config_class = SageMakerConfig + return config_class.from_yaml_file(yaml_file=config_file) + + +@dataclass +class BaseConfig: + compute_environment: ComputeEnvironment + distributed_type: Union[DistributedType, SageMakerDistributedType] + mixed_precision: str + use_cpu: bool + debug: bool + + def to_dict(self): + result = self.__dict__ + # For serialization, it's best to convert Enums to strings (or their underlying value type). 
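+        # For example (illustrative values): `ComputeEnvironment.LOCAL_MACHINE` becomes "LOCAL_MACHINE"
+        # and `DistributedType.MULTI_GPU` becomes "MULTI_GPU"; empty sub-config dicts are mapped to None
+        # below and then filtered out, so the written YAML/JSON only keeps meaningful entries.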
+ for key, value in result.items(): + if isinstance(value, Enum): + result[key] = value.value + if isinstance(value, dict) and not bool(value): + result[key] = None + result = {k: v for k, v in result.items() if v is not None} + return result + + @classmethod + def from_json_file(cls, json_file=None): + json_file = default_json_config_file if json_file is None else json_file + with open(json_file, encoding="utf-8") as f: + config_dict = json.load(f) + if "compute_environment" not in config_dict: + config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE + if "mixed_precision" not in config_dict: + config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None + if "fp16" in config_dict: # Convert the config to the new format. + del config_dict["fp16"] + if "dynamo_backend" in config_dict: # Convert the config to the new format. + dynamo_backend = config_dict.pop("dynamo_backend") + config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend} + if "use_cpu" not in config_dict: + config_dict["use_cpu"] = False + if "debug" not in config_dict: + config_dict["debug"] = False + if "enable_cpu_affinity" not in config_dict: + config_dict["enable_cpu_affinity"] = False + extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys())) + if len(extra_keys) > 0: + raise ValueError( + f"The config file at {json_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`" + " version or fix (and potentially remove) these keys from your config file." + ) + + return cls(**config_dict) + + def to_json_file(self, json_file): + with open(json_file, "w", encoding="utf-8") as f: + content = json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" + f.write(content) + + @classmethod + def from_yaml_file(cls, yaml_file=None): + yaml_file = default_yaml_config_file if yaml_file is None else yaml_file + with open(yaml_file, encoding="utf-8") as f: + config_dict = yaml.safe_load(f) + if "compute_environment" not in config_dict: + config_dict["compute_environment"] = ComputeEnvironment.LOCAL_MACHINE + if "mixed_precision" not in config_dict: + config_dict["mixed_precision"] = "fp16" if ("fp16" in config_dict and config_dict["fp16"]) else None + if isinstance(config_dict["mixed_precision"], bool) and not config_dict["mixed_precision"]: + config_dict["mixed_precision"] = "no" + if "fp16" in config_dict: # Convert the config to the new format. + del config_dict["fp16"] + if "dynamo_backend" in config_dict: # Convert the config to the new format. + dynamo_backend = config_dict.pop("dynamo_backend") + config_dict["dynamo_config"] = {} if dynamo_backend == "NO" else {"dynamo_backend": dynamo_backend} + if "use_cpu" not in config_dict: + config_dict["use_cpu"] = False + if "debug" not in config_dict: + config_dict["debug"] = False + if "enable_cpu_affinity" not in config_dict: + config_dict["enable_cpu_affinity"] = False + extra_keys = sorted(set(config_dict.keys()) - set(cls.__dataclass_fields__.keys())) + if len(extra_keys) > 0: + raise ValueError( + f"The config file at {yaml_file} had unknown keys ({extra_keys}), please try upgrading your `accelerate`" + " version or fix (and potentially remove) these keys from your config file." 
+ ) + return cls(**config_dict) + + def to_yaml_file(self, yaml_file): + with open(yaml_file, "w", encoding="utf-8") as f: + yaml.safe_dump(self.to_dict(), f) + + def __post_init__(self): + if isinstance(self.compute_environment, str): + self.compute_environment = ComputeEnvironment(self.compute_environment) + if isinstance(self.distributed_type, str): + if self.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: + self.distributed_type = SageMakerDistributedType(self.distributed_type) + else: + self.distributed_type = DistributedType(self.distributed_type) + if getattr(self, "dynamo_config", None) is None: + self.dynamo_config = {} + + +@dataclass +class ClusterConfig(BaseConfig): + num_processes: int + machine_rank: int = 0 + num_machines: int = 1 + gpu_ids: Optional[str] = None + main_process_ip: Optional[str] = None + main_process_port: Optional[int] = None + rdzv_backend: Optional[str] = "static" + same_network: Optional[bool] = False + main_training_function: str = "main" + enable_cpu_affinity: bool = False + + # args for deepspeed_plugin + deepspeed_config: dict = None + # args for fsdp + fsdp_config: dict = None + # args for megatron_lm + megatron_lm_config: dict = None + # args for ipex + ipex_config: dict = None + # args for mpirun + mpirun_config: dict = None + # args for TPU + downcast_bf16: bool = False + + # args for TPU pods + tpu_name: str = None + tpu_zone: str = None + tpu_use_cluster: bool = False + tpu_use_sudo: bool = False + command_file: str = None + commands: List[str] = None + tpu_vm: List[str] = None + tpu_env: List[str] = None + + # args for dynamo + dynamo_config: dict = None + + def __post_init__(self): + if self.deepspeed_config is None: + self.deepspeed_config = {} + if self.fsdp_config is None: + self.fsdp_config = {} + if self.megatron_lm_config is None: + self.megatron_lm_config = {} + if self.ipex_config is None: + self.ipex_config = {} + if self.mpirun_config is None: + self.mpirun_config = {} + return super().__post_init__() + + +@dataclass +class SageMakerConfig(BaseConfig): + ec2_instance_type: str + iam_role_name: str + image_uri: Optional[str] = None + profile: Optional[str] = None + region: str = "us-east-1" + num_machines: int = 1 + gpu_ids: str = "all" + base_job_name: str = f"accelerate-sagemaker-{num_machines}" + pytorch_version: str = SAGEMAKER_PYTORCH_VERSION + transformers_version: str = SAGEMAKER_TRANSFORMERS_VERSION + py_version: str = SAGEMAKER_PYTHON_VERSION + sagemaker_inputs_file: str = None + sagemaker_metrics_file: str = None + additional_args: dict = None + dynamo_config: dict = None diff --git a/llm/Lib/site-packages/accelerate/commands/config/config_utils.py b/llm/Lib/site-packages/accelerate/commands/config/config_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..091da5b03c771bccefc2cd21b047536fbc07bcbf --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/config/config_utils.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
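# --- Illustrative sketch: constructing the ClusterConfig dataclass defined above by hand ---
# and serializing it, much like `accelerate config` does after the prompts. The field
# values and output path are examples only.
from accelerate.commands.config.config_args import ClusterConfig

cfg = ClusterConfig(
    compute_environment="LOCAL_MACHINE",  # coerced to ComputeEnvironment in __post_init__
    distributed_type="NO",                # coerced to DistributedType in __post_init__
    mixed_precision="no",
    use_cpu=True,
    debug=False,
    num_processes=1,
)
cfg.to_yaml_file("/tmp/example_default_config.yaml")  # hypothetical path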
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from ...utils.dataclasses import ( + ComputeEnvironment, + DistributedType, + DynamoBackend, + PrecisionType, + SageMakerDistributedType, +) +from ..menu import BulletMenu + + +DYNAMO_BACKENDS = [ + "EAGER", + "AOT_EAGER", + "INDUCTOR", + "AOT_TS_NVFUSER", + "NVPRIMS_NVFUSER", + "CUDAGRAPHS", + "OFI", + "FX2TRT", + "ONNXRT", + "TENSORRT", + "IPEX", + "TVM", +] + + +def _ask_field(input_text, convert_value=None, default=None, error_message=None): + ask_again = True + while ask_again: + result = input(input_text) + try: + if default is not None and len(result) == 0: + return default + return convert_value(result) if convert_value is not None else result + except Exception: + if error_message is not None: + print(error_message) + + +def _ask_options(input_text, options=[], convert_value=None, default=0): + menu = BulletMenu(input_text, options) + result = menu.run(default_choice=default) + return convert_value(result) if convert_value is not None else result + + +def _convert_compute_environment(value): + value = int(value) + return ComputeEnvironment(["LOCAL_MACHINE", "AMAZON_SAGEMAKER"][value]) + + +def _convert_distributed_mode(value): + value = int(value) + return DistributedType(["NO", "MULTI_CPU", "MULTI_XPU", "MULTI_GPU", "MULTI_NPU", "MULTI_MLU", "XLA"][value]) + + +def _convert_dynamo_backend(value): + value = int(value) + return DynamoBackend(DYNAMO_BACKENDS[value]).value + + +def _convert_mixed_precision(value): + value = int(value) + return PrecisionType(["no", "fp16", "bf16", "fp8"][value]) + + +def _convert_sagemaker_distributed_mode(value): + value = int(value) + return SageMakerDistributedType(["NO", "DATA_PARALLEL", "MODEL_PARALLEL"][value]) + + +def _convert_yes_no_to_bool(value): + return {"yes": True, "no": False}[value.lower()] + + +class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter): + """ + A custom formatter that will remove the usage line from the help message for subcommands. + """ + + def _format_usage(self, usage, actions, groups, prefix): + usage = super()._format_usage(usage, actions, groups, prefix) + usage = usage.replace(" [] ", "") + return usage diff --git a/llm/Lib/site-packages/accelerate/commands/config/default.py b/llm/Lib/site-packages/accelerate/commands/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..e33331b98e6c8eacbaf8e9710b40e2ca6fc88b3d --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/config/default.py @@ -0,0 +1,133 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
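# --- Quick sketch of the prompt converters from config_utils.py above. ---
# Each helper maps the raw answer (or menu index) returned by the interactive prompt
# onto a typed value; assumes `accelerate` is importable.
from accelerate.commands.config.config_utils import (
    _convert_compute_environment,
    _convert_yes_no_to_bool,
)

assert _convert_yes_no_to_bool("YES") is True
assert _convert_yes_no_to_bool("no") is False
print(_convert_compute_environment(0))  # ComputeEnvironment.LOCAL_MACHINE
print(_convert_compute_environment(1))  # ComputeEnvironment.AMAZON_SAGEMAKER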
+ +from pathlib import Path + +import torch + +from ...utils import is_mlu_available, is_npu_available, is_xpu_available +from .config_args import ClusterConfig, default_json_config_file +from .config_utils import SubcommandHelpFormatter + + +description = "Create a default config file for Accelerate with only a few flags set." + + +def write_basic_config(mixed_precision="no", save_location: str = default_json_config_file, use_xpu: bool = False): + """ + Creates and saves a basic cluster config to be used on a local machine with potentially multiple GPUs. Will also + set CPU if it is a CPU-only machine. + + Args: + mixed_precision (`str`, *optional*, defaults to "no"): + Mixed Precision to use. Should be one of "no", "fp16", or "bf16" + save_location (`str`, *optional*, defaults to `default_json_config_file`): + Optional custom save location. Should be passed to `--config_file` when using `accelerate launch`. Default + location is inside the huggingface cache folder (`~/.cache/huggingface`) but can be overriden by setting + the `HF_HOME` environmental variable, followed by `accelerate/default_config.yaml`. + use_xpu (`bool`, *optional*, defaults to `False`): + Whether to use XPU if available. + """ + path = Path(save_location) + path.parent.mkdir(parents=True, exist_ok=True) + if path.exists(): + print( + f"Configuration already exists at {save_location}, will not override. Run `accelerate config` manually or pass a different `save_location`." + ) + return False + mixed_precision = mixed_precision.lower() + if mixed_precision not in ["no", "fp16", "bf16", "fp8"]: + raise ValueError( + f"`mixed_precision` should be one of 'no', 'fp16', 'bf16', or 'fp8'. Received {mixed_precision}" + ) + config = { + "compute_environment": "LOCAL_MACHINE", + "mixed_precision": mixed_precision, + } + if is_mlu_available(): + num_mlus = torch.mlu.device_count() + config["num_processes"] = num_mlus + config["use_cpu"] = False + if num_mlus > 1: + config["distributed_type"] = "MULTI_MLU" + else: + config["distributed_type"] = "NO" + elif torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + config["num_processes"] = num_gpus + config["use_cpu"] = False + if num_gpus > 1: + config["distributed_type"] = "MULTI_GPU" + else: + config["distributed_type"] = "NO" + elif is_xpu_available() and use_xpu: + num_xpus = torch.xpu.device_count() + config["num_processes"] = num_xpus + config["use_cpu"] = False + if num_xpus > 1: + config["distributed_type"] = "MULTI_XPU" + else: + config["distributed_type"] = "NO" + elif is_npu_available(): + num_npus = torch.npu.device_count() + config["num_processes"] = num_npus + config["use_cpu"] = False + if num_npus > 1: + config["distributed_type"] = "MULTI_NPU" + else: + config["distributed_type"] = "NO" + else: + num_xpus = 0 + config["use_cpu"] = True + config["num_processes"] = 1 + config["distributed_type"] = "NO" + config["debug"] = False + config = ClusterConfig(**config) + config.to_json_file(path) + return path + + +def default_command_parser(parser, parents): + parser = parser.add_parser("default", parents=parents, help=description, formatter_class=SubcommandHelpFormatter) + parser.add_argument( + "--config_file", + default=default_json_config_file, + help=( + "The path to use to store the config file. 
Will default to a file named default_config.yaml in the cache " + "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have " + "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed " + "with 'huggingface'." + ), + dest="save_location", + ) + + parser.add_argument( + "--mixed_precision", + choices=["no", "fp16", "bf16"], + type=str, + help="Whether or not to use mixed precision training. " + "Choose between FP16 and BF16 (bfloat16) training. " + "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.", + default="no", + ) + parser.set_defaults(func=default_config_command) + return parser + + +def default_config_command(args): + config_file = write_basic_config(args.mixed_precision, args.save_location) + if config_file: + print(f"accelerate configuration saved at {config_file}") diff --git a/llm/Lib/site-packages/accelerate/commands/config/sagemaker.py b/llm/Lib/site-packages/accelerate/commands/config/sagemaker.py new file mode 100644 index 0000000000000000000000000000000000000000..1e3491fee0ad28df82683a89d128bbc097053c2f --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/config/sagemaker.py @@ -0,0 +1,267 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
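# --- Minimal usage sketch of write_basic_config() from default.py above. ---
# It probes the available accelerators, builds a ClusterConfig, and writes it out.
# The save location here is illustrative; if a file already exists there, the function
# returns False instead of overwriting it.
from accelerate.commands.config.default import write_basic_config

result = write_basic_config(mixed_precision="bf16", save_location="/tmp/accelerate_basic_config.yaml")
print(result)  # Path to the written file, or False if one already existed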
+import json +import os + +from ...utils.constants import SAGEMAKER_PARALLEL_EC2_INSTANCES, TORCH_DYNAMO_MODES +from ...utils.dataclasses import ComputeEnvironment, SageMakerDistributedType +from ...utils.imports import is_boto3_available +from .config_args import SageMakerConfig +from .config_utils import ( + DYNAMO_BACKENDS, + _ask_field, + _ask_options, + _convert_dynamo_backend, + _convert_mixed_precision, + _convert_sagemaker_distributed_mode, + _convert_yes_no_to_bool, +) + + +if is_boto3_available(): + import boto3 # noqa: F401 + + +def _create_iam_role_for_sagemaker(role_name): + iam_client = boto3.client("iam") + + sagemaker_trust_policy = { + "Version": "2012-10-17", + "Statement": [ + {"Effect": "Allow", "Principal": {"Service": "sagemaker.amazonaws.com"}, "Action": "sts:AssumeRole"} + ], + } + try: + # create the role, associated with the chosen trust policy + iam_client.create_role( + RoleName=role_name, AssumeRolePolicyDocument=json.dumps(sagemaker_trust_policy, indent=2) + ) + policy_document = { + "Version": "2012-10-17", + "Statement": [ + { + "Effect": "Allow", + "Action": [ + "sagemaker:*", + "ecr:GetDownloadUrlForLayer", + "ecr:BatchGetImage", + "ecr:BatchCheckLayerAvailability", + "ecr:GetAuthorizationToken", + "cloudwatch:PutMetricData", + "cloudwatch:GetMetricData", + "cloudwatch:GetMetricStatistics", + "cloudwatch:ListMetrics", + "logs:CreateLogGroup", + "logs:CreateLogStream", + "logs:DescribeLogStreams", + "logs:PutLogEvents", + "logs:GetLogEvents", + "s3:CreateBucket", + "s3:ListBucket", + "s3:GetBucketLocation", + "s3:GetObject", + "s3:PutObject", + ], + "Resource": "*", + } + ], + } + # attach policy to role + iam_client.put_role_policy( + RoleName=role_name, + PolicyName=f"{role_name}_policy_permission", + PolicyDocument=json.dumps(policy_document, indent=2), + ) + except iam_client.exceptions.EntityAlreadyExistsException: + print(f"role {role_name} already exists. 
Using existing one") + + +def _get_iam_role_arn(role_name): + iam_client = boto3.client("iam") + return iam_client.get_role(RoleName=role_name)["Role"]["Arn"] + + +def get_sagemaker_input(): + credentials_configuration = _ask_options( + "How do you want to authorize?", + ["AWS Profile", "Credentials (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY) "], + int, + ) + aws_profile = None + if credentials_configuration == 0: + aws_profile = _ask_field("Enter your AWS Profile name: [default] ", default="default") + os.environ["AWS_PROFILE"] = aws_profile + else: + print( + "Note you will need to provide AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY when you launch you training script with," + "`accelerate launch --aws_access_key_id XXX --aws_secret_access_key YYY`" + ) + aws_access_key_id = _ask_field("AWS Access Key ID: ") + os.environ["AWS_ACCESS_KEY_ID"] = aws_access_key_id + + aws_secret_access_key = _ask_field("AWS Secret Access Key: ") + os.environ["AWS_SECRET_ACCESS_KEY"] = aws_secret_access_key + + aws_region = _ask_field("Enter your AWS Region: [us-east-1]", default="us-east-1") + os.environ["AWS_DEFAULT_REGION"] = aws_region + + role_management = _ask_options( + "Do you already have an IAM Role for executing Amazon SageMaker Training Jobs?", + ["Provide IAM Role name", "Create new IAM role using credentials"], + int, + ) + if role_management == 0: + iam_role_name = _ask_field("Enter your IAM role name: ") + else: + iam_role_name = "accelerate_sagemaker_execution_role" + print(f'Accelerate will create an iam role "{iam_role_name}" using the provided credentials') + _create_iam_role_for_sagemaker(iam_role_name) + + is_custom_docker_image = _ask_field( + "Do you want to use custom Docker image? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + docker_image = None + if is_custom_docker_image: + docker_image = _ask_field("Enter your Docker image: ", lambda x: str(x).lower()) + + is_sagemaker_inputs_enabled = _ask_field( + "Do you want to provide SageMaker input channels with data locations? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + sagemaker_inputs_file = None + if is_sagemaker_inputs_enabled: + sagemaker_inputs_file = _ask_field( + "Enter the path to the SageMaker inputs TSV file with columns (channel_name, data_location): ", + lambda x: str(x).lower(), + ) + + is_sagemaker_metrics_enabled = _ask_field( + "Do you want to enable SageMaker metrics? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + sagemaker_metrics_file = None + if is_sagemaker_metrics_enabled: + sagemaker_metrics_file = _ask_field( + "Enter the path to the SageMaker metrics TSV file with columns (metric_name, metric_regex): ", + lambda x: str(x).lower(), + ) + + distributed_type = _ask_options( + "What is the distributed mode?", + ["No distributed training", "Data parallelism"], + _convert_sagemaker_distributed_mode, + ) + dynamo_config = {} + use_dynamo = _ask_field( + "Do you wish to optimize your script with torch dynamo?[yes/NO]:", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + if use_dynamo: + prefix = "dynamo_" + dynamo_config[prefix + "backend"] = _ask_options( + "Which dynamo backend would you like to use?", + [x.lower() for x in DYNAMO_BACKENDS], + _convert_dynamo_backend, + default=2, + ) + use_custom_options = _ask_field( + "Do you want to customize the defaults sent to torch.compile? 
[yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + if use_custom_options: + dynamo_config[prefix + "mode"] = _ask_options( + "Which mode do you want to use?", + TORCH_DYNAMO_MODES, + lambda x: TORCH_DYNAMO_MODES[int(x)], + default="default", + ) + dynamo_config[prefix + "use_fullgraph"] = _ask_field( + "Do you want the fullgraph mode or it is ok to break model into several subgraphs? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + dynamo_config[prefix + "use_dynamic"] = _ask_field( + "Do you want to enable dynamic shape tracing? [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + ec2_instance_query = "Which EC2 instance type you want to use for your training?" + if distributed_type != SageMakerDistributedType.NO: + ec2_instance_type = _ask_options( + ec2_instance_query, SAGEMAKER_PARALLEL_EC2_INSTANCES, lambda x: SAGEMAKER_PARALLEL_EC2_INSTANCES[int(x)] + ) + else: + ec2_instance_query += "? [ml.p3.2xlarge]:" + ec2_instance_type = _ask_field(ec2_instance_query, lambda x: str(x).lower(), default="ml.p3.2xlarge") + + debug = False + if distributed_type != SageMakerDistributedType.NO: + debug = _ask_field( + "Should distributed operations be checked while running for errors? This can avoid timeout issues but will be slower. [yes/NO]: ", + _convert_yes_no_to_bool, + default=False, + error_message="Please enter yes or no.", + ) + + num_machines = 1 + if distributed_type in (SageMakerDistributedType.DATA_PARALLEL, SageMakerDistributedType.MODEL_PARALLEL): + num_machines = _ask_field( + "How many machines do you want use? [1]: ", + int, + default=1, + ) + + mixed_precision = _ask_options( + "Do you wish to use FP16 or BF16 (mixed precision)?", + ["no", "fp16", "bf16", "fp8"], + _convert_mixed_precision, + ) + + if use_dynamo and mixed_precision == "no": + print( + "Torch dynamo used without mixed precision requires TF32 to be efficient. Accelerate will enable it by default when launching your scripts." + ) + + return SageMakerConfig( + image_uri=docker_image, + compute_environment=ComputeEnvironment.AMAZON_SAGEMAKER, + distributed_type=distributed_type, + use_cpu=False, + dynamo_config=dynamo_config, + ec2_instance_type=ec2_instance_type, + profile=aws_profile, + region=aws_region, + iam_role_name=iam_role_name, + mixed_precision=mixed_precision, + num_machines=num_machines, + sagemaker_inputs_file=sagemaker_inputs_file, + sagemaker_metrics_file=sagemaker_metrics_file, + debug=debug, + ) diff --git a/llm/Lib/site-packages/accelerate/commands/config/update.py b/llm/Lib/site-packages/accelerate/commands/config/update.py new file mode 100644 index 0000000000000000000000000000000000000000..5f025594b04ada3e3a78687befc5c1bc1d236adf --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/config/update.py @@ -0,0 +1,63 @@ +#!/usr/bin/env python + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
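# --- Small sketch reusing the IAM helpers from sagemaker.py above. ---
# Resolves the ARN of the execution role that get_sagemaker_input() creates (or reuses)
# by default. Requires boto3 and valid AWS credentials; the role name is the default
# used in the prompt flow.
from accelerate.commands.config.sagemaker import _get_iam_role_arn

print(_get_iam_role_arn("accelerate_sagemaker_execution_role"))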
+# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +from .config_args import default_config_file, load_config_from_file +from .config_utils import SubcommandHelpFormatter + + +description = "Update an existing config file with the latest defaults while maintaining the old configuration." + + +def update_config(args): + """ + Update an existing config file with the latest defaults while maintaining the old configuration. + """ + config_file = args.config_file + if config_file is None and Path(default_config_file).exists(): + config_file = default_config_file + elif not Path(config_file).exists(): + raise ValueError(f"The passed config file located at {config_file} doesn't exist.") + config = load_config_from_file(config_file) + + if config_file.endswith(".json"): + config.to_json_file(config_file) + else: + config.to_yaml_file(config_file) + return config_file + + +def update_command_parser(parser, parents): + parser = parser.add_parser("update", parents=parents, help=description, formatter_class=SubcommandHelpFormatter) + parser.add_argument( + "--config_file", + default=None, + help=( + "The path to the config file to update. Will default to a file named default_config.yaml in the cache " + "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have " + "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed " + "with 'huggingface'." + ), + ) + + parser.set_defaults(func=update_config_command) + return parser + + +def update_config_command(args): + config_file = update_config(args) + print(f"Sucessfully updated the configuration file at {config_file}.") diff --git a/llm/Lib/site-packages/accelerate/commands/env.py b/llm/Lib/site-packages/accelerate/commands/env.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2f60f787a9eba3f75b6ac9171aefd0ffc61647 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/env.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import platform +import subprocess + +import numpy as np +import psutil +import torch + +from accelerate import __version__ as version +from accelerate.commands.config import default_config_file, load_config_from_file + +from ..utils import is_mlu_available, is_npu_available, is_xpu_available + + +def env_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("env") + else: + parser = argparse.ArgumentParser("Accelerate env command") + + parser.add_argument( + "--config_file", default=None, help="The config file to use for the default values in the launching script." 
+ ) + + if subparsers is not None: + parser.set_defaults(func=env_command) + return parser + + +def env_command(args): + pt_version = torch.__version__ + pt_cuda_available = torch.cuda.is_available() + pt_xpu_available = is_xpu_available() + pt_mlu_available = is_mlu_available() + pt_npu_available = is_npu_available() + + accelerate_config = "Not found" + # Get the default from the config file. + if args.config_file is not None or os.path.isfile(default_config_file): + accelerate_config = load_config_from_file(args.config_file).to_dict() + + # if we can run which, get it + command = None + bash_location = "Not found" + if os.name == "nt": + command = ["where", "accelerate"] + elif os.name == "posix": + command = ["which", "accelerate"] + if command is not None: + bash_location = subprocess.check_output(command, text=True, stderr=subprocess.STDOUT).strip() + info = { + "`Accelerate` version": version, + "Platform": platform.platform(), + "`accelerate` bash location": bash_location, + "Python version": platform.python_version(), + "Numpy version": np.__version__, + "PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})", + "PyTorch XPU available": str(pt_xpu_available), + "PyTorch NPU available": str(pt_npu_available), + "PyTorch MLU available": str(pt_mlu_available), + "System RAM": f"{psutil.virtual_memory().total / 1024 ** 3:.2f} GB", + } + if pt_cuda_available: + info["GPU type"] = torch.cuda.get_device_name() + + print("\nCopy-and-paste the text below in your GitHub issue\n") + print("\n".join([f"- {prop}: {val}" for prop, val in info.items()])) + + print("- `Accelerate` default config:" if args.config_file is None else "- `Accelerate` config passed:") + accelerate_config_str = ( + "\n".join([f"\t- {prop}: {val}" for prop, val in accelerate_config.items()]) + if isinstance(accelerate_config, dict) + else f"\t{accelerate_config}" + ) + print(accelerate_config_str) + + info["`Accelerate` configs"] = accelerate_config + + return info + + +def main() -> int: + parser = env_command_parser() + args = parser.parse_args() + env_command(args) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/llm/Lib/site-packages/accelerate/commands/estimate.py b/llm/Lib/site-packages/accelerate/commands/estimate.py new file mode 100644 index 0000000000000000000000000000000000000000..56da3c5ad9e953687fab71dfc1fb0a878309d1d6 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/estimate.py @@ -0,0 +1,309 @@ +#!/usr/bin/env python + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
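# --- Minimal sketch mirroring main() in env.py above. ---
# Builds the standalone parser, then prints the environment report; pass --config_file
# to inspect a specific config. Assumes the `accelerate` CLI is on PATH, since
# env_command() shells out to `which`/`where` to locate it.
from accelerate.commands.env import env_command, env_command_parser

args = env_command_parser().parse_args([])
env_command(args)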
+from huggingface_hub import model_info +from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError + +from accelerate import init_empty_weights +from accelerate.commands.utils import CustomArgumentParser +from accelerate.utils import ( + calculate_maximum_sizes, + convert_bytes, + is_timm_available, + is_transformers_available, +) + + +if is_transformers_available(): + import transformers + from transformers import AutoConfig, AutoModel + +if is_timm_available(): + import timm + + +def verify_on_hub(repo: str, token: str = None): + "Verifies that the model is on the hub and returns the model info." + try: + return model_info(repo, token=token) + except GatedRepoError: + return "gated" + except RepositoryNotFoundError: + return "repo" + + +def check_has_model(error): + """ + Checks what library spawned `error` when a model is not found + """ + if is_timm_available() and isinstance(error, RuntimeError) and "Unknown model" in error.args[0]: + return "timm" + elif ( + is_transformers_available() + and isinstance(error, OSError) + and "does not appear to have a file named" in error.args[0] + ): + return "transformers" + else: + return "unknown" + + +def create_empty_model(model_name: str, library_name: str, trust_remote_code: bool = False, access_token: str = None): + """ + Creates an empty model from its parent library on the `Hub` to calculate the overall memory consumption. + + Args: + model_name (`str`): + The model name on the Hub + library_name (`str`): + The library the model has an integration with, such as `transformers`. Will be used if `model_name` has no + metadata on the Hub to determine the library. + trust_remote_code (`bool`, `optional`, defaults to `False`): + Whether or not to allow for custom models defined on the Hub in their own modeling files. This option + should only be set to `True` for repositories you trust and in which you have read the code, as it will + execute code present on the Hub on your local machine. + access_token (`str`, `optional`, defaults to `None`): + The access token to use to access private or gated models on the Hub. (for use on the Gradio app) + + Returns: + `torch.nn.Module`: The torch model that has been initialized on the `meta` device. + + """ + model_info = verify_on_hub(model_name, access_token) + # Simplified errors + if model_info == "gated": + raise GatedRepoError( + f"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`." + ) + elif model_info == "repo": + raise RepositoryNotFoundError( + f"Repo for model `{model_name}` does not exist on the Hub. If you are trying to access a private repo," + " make sure you are authenticated via `huggingface-cli login` and have access." + ) + if library_name is None: + library_name = getattr(model_info, "library_name", False) + if not library_name: + raise ValueError( + f"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)" + ) + if library_name == "transformers": + if not is_transformers_available(): + raise ImportError( + f"To check `{model_name}`, `transformers` must be installed. 
Please install it via `pip install transformers`" + ) + print(f"Loading pretrained config for `{model_name}` from `transformers`...") + if model_info.config is None: + raise RuntimeError(f"Tried to load `{model_name}` with `transformers` but it does not have any metadata.") + + auto_map = model_info.config.get("auto_map", False) + config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token) + with init_empty_weights(): + # remote code could specify a specific `AutoModel` class in the `auto_map` + constructor = AutoModel + if isinstance(auto_map, dict): + value = None + for key in auto_map.keys(): + if key.startswith("AutoModelFor"): + value = key + break + if value is not None: + constructor = getattr(transformers, value) + model = constructor.from_config(config, trust_remote_code=trust_remote_code) + elif library_name == "timm": + if not is_timm_available(): + raise ImportError( + f"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`" + ) + print(f"Loading pretrained config for `{model_name}` from `timm`...") + with init_empty_weights(): + model = timm.create_model(model_name, pretrained=False) + else: + raise ValueError( + f"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support." + ) + return model + + +def create_ascii_table(headers: list, rows: list, title: str): + "Creates a pretty table from a list of rows, minimal version of `tabulate`." + sep_char, in_between = "│", "─" + column_widths = [] + for i in range(len(headers)): + column_values = [row[i] for row in rows] + [headers[i]] + max_column_width = max(len(value) for value in column_values) + column_widths.append(max_column_width) + + formats = [f"%{column_widths[i]}s" for i in range(len(rows[0]))] + + pattern = f"{sep_char}{sep_char.join(formats)}{sep_char}" + diff = 0 + + def make_row(left_char, middle_char, right_char): + return f"{left_char}{middle_char.join([in_between * n for n in column_widths])}{in_between * diff}{right_char}" + + separator = make_row("├", "┼", "┤") + if len(title) > sum(column_widths): + diff = abs(len(title) - len(separator)) + column_widths[-1] += diff + + # Update with diff + separator = make_row("├", "┼", "┤") + initial_rows = [ + make_row("┌", in_between, "┐"), + f"{sep_char}{title.center(len(separator) - 2)}{sep_char}", + make_row("├", "┬", "┤"), + ] + table = "\n".join(initial_rows) + "\n" + column_widths[-1] += diff + centered_line = [text.center(column_widths[i]) for i, text in enumerate(headers)] + table += f"{pattern % tuple(centered_line)}\n{separator}\n" + for i, line in enumerate(rows): + centered_line = [t.center(column_widths[i]) for i, t in enumerate(line)] + table += f"{pattern % tuple(centered_line)}\n" + table += f'└{"┴".join([in_between * n for n in column_widths])}┘' + + return table + + +def estimate_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("estimate-memory") + else: + parser = CustomArgumentParser(description="Model size estimator for fitting a model onto CUDA memory.") + + parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.") + parser.add_argument( + "--library_name", + type=str, + help="The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.", + choices=["timm", "transformers"], + ) + parser.add_argument( + "--dtypes", + type=str, + nargs="+", + default=["float32", "float16", "int8", 
"int4"], + help="The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`", + choices=["float32", "float16", "int8", "int4"], + ) + parser.add_argument( + "--trust_remote_code", + action="store_true", + help="""Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag + should only be used for repositories you trust and in which you have read the code, as it will execute + code present on the Hub on your local machine.""", + default=False, + ) + + if subparsers is not None: + parser.set_defaults(func=estimate_command) + return parser + + +def estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: str = None) -> dict: + """ + Given an amount of `bytes` and `mixed_precision`, calculates how much training memory is needed for a batch size of + 1. + + Args: + bytes (`int`): + The size of the model being trained. + mixed_precision (`str`): + The mixed precision that would be ran. + msamp_config (`str`): + The msamp config to estimate the training memory for if `mixed_precision` is set to `"fp8"`. + """ + memory_sizes = {"model": -1, "optimizer": -1, "gradients": -1, "step": -1} + fp32_size = bytes + fp16_size = bytes // 2 + + if mixed_precision == "float32": + memory_sizes["model"] = fp32_size + memory_sizes["gradients"] = fp32_size + memory_sizes["optimizer"] = fp32_size * 2 + memory_sizes["step"] = fp32_size * 4 + elif mixed_precision in ("float16", "bfloat16") or (mixed_precision == "fp8" and msamp_config is None): + # With native `TransformersEngine`, there is no memory savings with FP8 + # With mixed precision training, the model has weights stored + # in FP16 and FP32 + memory_sizes["model"] = fp32_size + # 1.5 from weight gradient + computation (GEMM) + memory_sizes["gradients"] = fp32_size + fp16_size + # 2x from optimizer states + memory_sizes["optimizer"] = fp32_size * 2 # Optimizer states + memory_sizes["step"] = memory_sizes["optimizer"] + return memory_sizes + + +def gather_data(args): + "Creates an empty model and gathers the data for the sizes" + try: + model = create_empty_model( + args.model_name, library_name=args.library_name, trust_remote_code=args.trust_remote_code + ) + except (RuntimeError, OSError) as e: + library = check_has_model(e) + if library != "unknown": + raise RuntimeError( + f"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo." 
+ ) + raise e + + total_size, largest_layer = calculate_maximum_sizes(model) + + data = [] + + for dtype in args.dtypes: + dtype_total_size = total_size + dtype_largest_layer = largest_layer[0] + dtype_training_size = estimate_training_usage(dtype_total_size, dtype) + if dtype == "float16": + dtype_total_size /= 2 + dtype_largest_layer /= 2 + elif dtype == "int8": + dtype_total_size /= 4 + dtype_largest_layer /= 4 + elif dtype == "int4": + dtype_total_size /= 8 + dtype_largest_layer /= 8 + data.append([dtype, dtype_largest_layer, dtype_total_size, dtype_training_size]) + return data + + +def estimate_command(args): + data = gather_data(args) + for row in data: + for i, item in enumerate(row): + if isinstance(item, (int, float)): + row[i] = convert_bytes(item) + elif isinstance(item, dict): + training_usage = max(item.values()) + row[i] = convert_bytes(training_usage) if training_usage != -1 else "N/A" + + headers = ["dtype", "Largest Layer", "Total Size", "Training using Adam"] + + title = f"Memory Usage for loading `{args.model_name}`" + table = create_ascii_table(headers, data, title) + print(table) + + +def main(): + parser = estimate_command_parser() + args = parser.parse_args() + estimate_command(args) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/commands/launch.py b/llm/Lib/site-packages/accelerate/commands/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..f6b91654bc63c3cd0db9cca5f72be511458a20fb --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/launch.py @@ -0,0 +1,1085 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
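# --- Toy sketch exercising create_ascii_table() from estimate.py above, ---
# which estimate_command() uses to render its report; the rows below are made-up values.
from accelerate.commands.estimate import create_ascii_table

headers = ["dtype", "Largest Layer", "Total Size"]
rows = [
    ["float32", "348.3 MB", "1.3 GB"],
    ["float16", "174.2 MB", "650.7 MB"],
]
print(create_ascii_table(headers, rows, title="Memory Usage"))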
+ +import argparse +import importlib +import logging +import os +import subprocess +import sys +from pathlib import Path + +import psutil +import torch + +from accelerate.commands.config import default_config_file, load_config_from_file +from accelerate.commands.config.config_args import SageMakerConfig +from accelerate.commands.config.config_utils import DYNAMO_BACKENDS +from accelerate.commands.utils import CustomArgumentParser +from accelerate.state import get_int_from_env +from accelerate.utils import ( + ComputeEnvironment, + DistributedType, + PrepareForLaunch, + _filter_args, + check_cuda_p2p_ib_support, + convert_dict_to_env_variables, + is_bf16_available, + is_deepspeed_available, + is_mlu_available, + is_npu_available, + is_rich_available, + is_sagemaker_available, + is_torch_version, + is_torch_xla_available, + is_xpu_available, + patch_environment, + prepare_deepspeed_cmd_env, + prepare_multi_gpu_env, + prepare_sagemager_args_inputs, + prepare_simple_launcher_cmd_env, + prepare_tpu, +) +from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES + + +if is_rich_available(): + from rich import get_console + from rich.logging import RichHandler + + FORMAT = "%(message)s" + logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()]) + + +logger = logging.getLogger(__name__) + + +options_to_group = { + "multi_gpu": "Distributed GPUs", + "tpu": "TPU", + "use_deepspeed": "DeepSpeed Arguments", + "use_fsdp": "FSDP Arguments", + "use_megatron_lm": "Megatron-LM Arguments", +} + + +def clean_option(option): + "Finds all cases of - after the first two characters and changes them to _" + if option.startswith("--"): + return option[2:].replace("-", "_") + + +class CustomHelpFormatter(argparse.HelpFormatter): + """ + This is a custom help formatter that will hide all arguments that are not used in the command line when the help is + called. This is useful for the case where the user is using a specific platform and only wants to see the arguments + for that platform. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.titles = [ + "Hardware Selection Arguments", + "Resource Selection Arguments", + "Training Paradigm Arguments", + "positional arguments", + "optional arguments", + ] + + def add_argument(self, action: argparse.Action): + if "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]: + args = sys.argv[2:] + else: + args = sys.argv[1:] + + if len(args) > 1: + args = list(map(clean_option, args)) + used_platforms = [arg for arg in args if arg in options_to_group.keys()] + used_titles = [options_to_group[o] for o in used_platforms] + if action.container.title not in self.titles + used_titles: + action.help = argparse.SUPPRESS + elif action.container.title == "Hardware Selection Arguments": + if set(action.option_strings).isdisjoint(set(args)): + action.help = argparse.SUPPRESS + else: + action.help = action.help + " (currently selected)" + elif action.container.title == "Training Paradigm Arguments": + if set(action.option_strings).isdisjoint(set(args)): + action.help = argparse.SUPPRESS + else: + action.help = action.help + " (currently selected)" + + action.option_strings = [s for s in action.option_strings if "-" not in s[2:]] + super().add_argument(action) + + def end_section(self): + if len(self._current_section.items) < 2: + self._current_section.items = [] + self._current_section.heading = "" + super().end_section() + + +def launch_command_parser(subparsers=None): + description = "Launch a python script in a distributed scenario. Arguments can be passed in with either hyphens (`--num-processes=2`) or underscores (`--num_processes=2`)" + if subparsers is not None: + parser = subparsers.add_parser( + "launch", description=description, add_help=False, allow_abbrev=False, formatter_class=CustomHelpFormatter + ) + else: + parser = CustomArgumentParser( + "Accelerate launch command", + description=description, + add_help=False, + allow_abbrev=False, + formatter_class=CustomHelpFormatter, + ) + + parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.") + + parser.add_argument( + "--config_file", + default=None, + help="The config file to use for the default values in the launching script.", + ) + parser.add_argument( + "--quiet", + "-q", + action="store_true", + help="Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations)", + ) + # Hardware selection arguments + hardware_args = parser.add_argument_group( + "Hardware Selection Arguments", "Arguments for selecting the hardware to be used." + ) + hardware_args.add_argument( + "--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU." + ) + hardware_args.add_argument( + "--multi_gpu", + default=False, + action="store_true", + help="Whether or not this should launch a distributed GPU training.", + ) + hardware_args.add_argument( + "--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training." + ) + hardware_args.add_argument( + "--ipex", + default=False, + action="store_true", + help="Whether or not this should launch a Intel PyTorch Extension (IPEX) training.", + ) + + # Resource selection arguments + resource_args = parser.add_argument_group( + "Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used." 
+ ) + resource_args.add_argument( + "--mixed_precision", + type=str, + choices=["no", "fp16", "bf16", "fp8"], + help="Whether or not to use mixed precision training. " + "Choose between FP16 and BF16 (bfloat16) training. " + "BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.", + ) + resource_args.add_argument( + "--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel." + ) + resource_args.add_argument( + "--num_machines", type=int, default=None, help="The total number of machines used in this training." + ) + resource_args.add_argument( + "--num_cpu_threads_per_process", + type=int, + default=None, + help="The number of CPU threads per process. Can be tuned for optimal performance.", + ) + resource_args.add_argument( + "--enable_cpu_affinity", + default=False, + action="store_true", + help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.", + ) + + # Dynamo arguments + resource_args.add_argument( + "--dynamo_backend", + type=str, + choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS], + help="Choose a backend to optimize your training with dynamo, see more at " + "https://github.com/pytorch/torchdynamo.", + ) + resource_args.add_argument( + "--dynamo_mode", + type=str, + default="default", + choices=TORCH_DYNAMO_MODES, + help="Choose a mode to optimize your training with dynamo.", + ) + resource_args.add_argument( + "--dynamo_use_fullgraph", + default=False, + action="store_true", + help="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs", + ) + resource_args.add_argument( + "--dynamo_use_dynamic", + default=False, + action="store_true", + help="Whether to enable dynamic shape tracing.", + ) + + # Training Paradigm arguments + paradigm_args = parser.add_argument_group( + "Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used." + ) + paradigm_args.add_argument( + "--use_deepspeed", + default=False, + action="store_true", + help="Whether to use deepspeed.", + ) + paradigm_args.add_argument( + "--use_fsdp", + default=False, + action="store_true", + help="Whether to use fsdp.", + ) + paradigm_args.add_argument( + "--use_megatron_lm", + default=False, + action="store_true", + help="Whether to use Megatron-LM.", + ) + paradigm_args.add_argument( + "--use_xpu", + default=False, + action="store_true", + help="Whether to use IPEX plugin to speed up training on XPU specifically.", + ) + + # distributed GPU training arguments + distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.") + distributed_args.add_argument( + "--gpu_ids", + default=None, + help="What GPUs (by id) should be used for training on this machine as a comma-seperated list", + ) + distributed_args.add_argument( + "--same_network", + default=False, + action="store_true", + help="Whether all machines used for multinode training exist on the same local network.", + ) + distributed_args.add_argument( + "--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched." + ) + distributed_args.add_argument( + "--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0." 
+ ) + distributed_args.add_argument( + "--main_process_port", + type=int, + default=None, + help="The port to use to communicate with the machine of rank 0.", + ) + distributed_args.add_argument( + "-t", + "--tee", + default="0", + type=str, + help="Tee std streams into a log file and also to console.", + ) + distributed_args.add_argument( + "--role", + type=str, + default="default", + help="User-defined role for the workers.", + ) + # Rendezvous related arguments + distributed_args.add_argument( + "--rdzv_backend", + type=str, + default="static", + help="The rendezvous method to use, such as 'static' (the default) or 'c10d'", + ) + distributed_args.add_argument( + "--rdzv_conf", + type=str, + default="", + help="Additional rendezvous configuration (=,=,...).", + ) + distributed_args.add_argument( + "--max_restarts", + type=int, + default=0, + help="Maximum number of worker group restarts before failing.", + ) + distributed_args.add_argument( + "--monitor_interval", + type=float, + default=5, + help="Interval, in seconds, to monitor the state of workers.", + ) + parser.add_argument( + "-m", + "--module", + action="store_true", + help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.", + ) + parser.add_argument( + "--no_python", + action="store_true", + help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.", + ) + + # TPU arguments + tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.") + tpu_args.add_argument( + "--tpu_cluster", + action="store_true", + dest="tpu_use_cluster", + help="Whether to use a GCP TPU pod for training.", + ) + tpu_args.add_argument( + "--no_tpu_cluster", + action="store_false", + dest="tpu_use_cluster", + help="Should not be passed explicitly, this is for internal use only.", + ) + tpu_args.add_argument( + "--tpu_use_sudo", + action="store_true", + help="Whether to use `sudo` when running the TPU training script in each pod.", + ) + tpu_args.add_argument( + "--vm", + type=str, + action="append", + help=( + "List of single Compute VM instance names. " + "If not provided we assume usage of instance groups. For TPU pods." + ), + ) + tpu_args.add_argument( + "--env", + type=str, + action="append", + help="List of environment variables to set on the Compute VM instances. For TPU pods.", + ) + tpu_args.add_argument( + "--main_training_function", + type=str, + default=None, + help="The name of the main function to be executed in your script (only for TPU training).", + ) + tpu_args.add_argument( + "--downcast_bf16", + action="store_true", + help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.", + ) + + # DeepSpeed arguments + deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.") + deepspeed_args.add_argument( + "--deepspeed_config_file", + default=None, + type=str, + help="DeepSpeed config file.", + ) + deepspeed_args.add_argument( + "--zero_stage", + default=None, + type=int, + help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to `2`.", + ) + deepspeed_args.add_argument( + "--offload_optimizer_device", + default=None, + type=str, + help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). 
" + "If unspecified, will default to 'none'.", + ) + deepspeed_args.add_argument( + "--offload_param_device", + default=None, + type=str, + help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to 'none'.", + ) + deepspeed_args.add_argument( + "--offload_optimizer_nvme_path", + default=None, + type=str, + help="Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to 'none'.", + ) + deepspeed_args.add_argument( + "--offload_param_nvme_path", + default=None, + type=str, + help="Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to 'none'.", + ) + deepspeed_args.add_argument( + "--gradient_accumulation_steps", + default=None, + type=int, + help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to `1`.", + ) + deepspeed_args.add_argument( + "--gradient_clipping", + default=None, + type=float, + help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). " + "If unspecified, will default to `1.0`.", + ) + deepspeed_args.add_argument( + "--zero3_init_flag", + default=None, + type=str, + help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. " + "Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.", + ) + deepspeed_args.add_argument( + "--zero3_save_16bit_model", + default=None, + type=str, + help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. " + "Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `false`.", + ) + deepspeed_args.add_argument( + "--deepspeed_hostfile", + default=None, + type=str, + help="DeepSpeed hostfile for configuring multi-node compute resources.", + ) + deepspeed_args.add_argument( + "--deepspeed_exclusion_filter", + default=None, + type=str, + help="DeepSpeed exclusion filter string when using mutli-node setup.", + ) + deepspeed_args.add_argument( + "--deepspeed_inclusion_filter", + default=None, + type=str, + help="DeepSpeed inclusion filter string when using mutli-node setup.", + ) + deepspeed_args.add_argument( + "--deepspeed_multinode_launcher", + default=None, + type=str, + help="DeepSpeed multi-node launcher to use. If unspecified, will default to `pdsh`.", + ) + + # fsdp arguments + fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.") + fsdp_args.add_argument( + "--fsdp_offload_params", + default="false", + type=str, + help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_min_num_params", + type=int, + default=1e8, + help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_sharding_strategy", + type=str, + default="FULL_SHARD", + help="FSDP's Sharding Strategy. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_auto_wrap_policy", + type=str, + default=None, + help="FSDP's auto wrap policy. 
(useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_transformer_layer_cls_to_wrap", + default=None, + type=str, + help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... " + "(useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_backward_prefetch_policy", + default=None, + type=str, + help="This argument is deprecated and will be removed in version 0.27.0 of 🤗 Accelerate. Use `fsdp_backward_prefetch` instead.", + ) + fsdp_args.add_argument( + "--fsdp_backward_prefetch", + default=None, + type=str, + help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_state_dict_type", + default=None, + type=str, + help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_forward_prefetch", + default="false", + type=str, + help="If True, then FSDP explicitly prefetches the next upcoming " + "all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_use_orig_params", + default="true", + type=str, + help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres." + " (useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_cpu_ram_efficient_loading", + default="true", + type=str, + help="If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. " + "Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to True. " + "(useful only when `use_fsdp` flag is passed).", + ) + fsdp_args.add_argument( + "--fsdp_sync_module_states", + default="true", + type=str, + help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0." + " (useful only when `use_fsdp` flag is passed).", + ) + + # megatron_lm args + megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.") + megatron_lm_args.add_argument( + "--megatron_lm_tp_degree", + type=int, + default=1, + help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_pp_degree", + type=int, + default=1, + help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_num_micro_batches", + type=int, + default=None, + help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_sequence_parallelism", + default=None, + type=str, + help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_recompute_activations", + default=None, + type=str, + help="Decides Whether (true|false) to enable Selective Activation Recomputation. " + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_use_distributed_optimizer", + default=None, + type=str, + help="Decides Whether (true|false) to use distributed optimizer " + "which shards optimizer state and gradients across Data Pralellel (DP) ranks. 
" + "(useful only when `use_megatron_lm` flag is passed).", + ) + megatron_lm_args.add_argument( + "--megatron_lm_gradient_clipping", + default=1.0, + type=float, + help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). " + "(useful only when `use_megatron_lm` flag is passed).", + ) + + # AWS arguments + aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.") + aws_args.add_argument( + "--aws_access_key_id", + type=str, + default=None, + help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job", + ) + aws_args.add_argument( + "--aws_secret_access_key", + type=str, + default=None, + help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.", + ) + parser.add_argument( + "--debug", + action="store_true", + help="Whether to print out the torch.distributed stack trace when something fails.", + ) + parser.add_argument( + "training_script", + type=str, + help=( + "The full path to the script to be launched in parallel, followed by all the arguments for the training " + "script." + ), + ) + + # MPI arguments + mpirun_args = parser.add_argument_group("MPI Arguments", "Arguments related to mpirun for Multi-CPU") + mpirun_args.add_argument( + "--mpirun_hostfile", + type=str, + default=None, + help="Location for a hostfile for using Accelerate to launch a multi-CPU training job with mpirun. This will " + "get passed to the MPI --hostfile or -f parameter, depending on which MPI program is installed.", + ) + mpirun_args.add_argument( + "--mpirun_ccl", + type=int, + default=1, + help="The number of oneCCL worker threads when using Accelerate to launch multi-CPU training with mpirun.", + ) + + # Other arguments of the training scripts + parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.") + + if subparsers is not None: + parser.set_defaults(func=launch_command) + return parser + + +def simple_launcher(args): + cmd, current_env = prepare_simple_launcher_cmd_env(args) + + process = subprocess.Popen(cmd, env=current_env) + process.wait() + if process.returncode != 0: + if not args.quiet: + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + else: + sys.exit(1) + + +def multi_gpu_launcher(args): + import torch.distributed.run as distrib_run + + current_env = prepare_multi_gpu_env(args) + if not check_cuda_p2p_ib_support(): + message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled." 
+ warn = False + if "NCCL_P2P_DISABLE" not in current_env: + current_env["NCCL_P2P_DISABLE"] = "1" + warn = True + if "NCCL_IB_DISABLE" not in current_env: + current_env["NCCL_IB_DISABLE"] = "1" + warn = True + if warn: + logger.warning(message) + + debug = getattr(args, "debug", False) + args = _filter_args( + args, + distrib_run.get_args_parser(), + ["--training_script", args.training_script, "--training_script_args", args.training_script_args], + ) + + with patch_environment(**current_env): + try: + distrib_run.run(args) + except Exception: + if is_rich_available() and debug: + console = get_console() + console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]") + console.print_exception(suppress=[__file__], show_locals=False) + else: + raise + + +def deepspeed_launcher(args): + import torch.distributed.run as distrib_run + + if not is_deepspeed_available(): + raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.") + else: + from deepspeed.launcher.runner import DEEPSPEED_ENVIRONMENT_NAME + + cmd, current_env = prepare_deepspeed_cmd_env(args) + if not check_cuda_p2p_ib_support(): + message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled." + warn = False + if "NCCL_P2P_DISABLE" not in current_env: + current_env["NCCL_P2P_DISABLE"] = "1" + warn = True + if "NCCL_IB_DISABLE" not in current_env: + current_env["NCCL_IB_DISABLE"] = "1" + warn = True + if warn: + logger.warning(message) + + if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: + with open(DEEPSPEED_ENVIRONMENT_NAME, "a") as f: + valid_env_items = convert_dict_to_env_variables(current_env) + if len(valid_env_items) > 1: + f.writelines(valid_env_items) + + process = subprocess.Popen(cmd, env=current_env) + process.wait() + if process.returncode != 0: + if not args.quiet: + raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd) + else: + sys.exit(1) + else: + debug = getattr(args, "debug", False) + args = _filter_args( + args, + distrib_run.get_args_parser(), + ["--training_script", args.training_script, "--training_script_args", args.training_script_args], + ) + with patch_environment(**current_env): + try: + distrib_run.run(args) + except Exception: + if is_rich_available() and debug: + console = get_console() + console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]") + console.print_exception(suppress=[__file__], show_locals=False) + else: + raise + + +def tpu_launcher(args): + import torch_xla.distributed.xla_multiprocessing as xmp + + if args.no_python: + raise ValueError("--no_python cannot be used with TPU launcher") + + args, current_env = prepare_tpu(args, {}) + + if args.module: + mod_name = args.training_script + else: + # Import training_script as a module + script_path = Path(args.training_script) + sys.path.append(str(script_path.parent.resolve())) + mod_name = script_path.stem + + mod = importlib.import_module(mod_name) + if not hasattr(mod, args.main_training_function): + raise ValueError( + f"Your training script should have a function named {args.main_training_function}, or you should pass a " + "different value to `--main_training_function`." 
+ ) + + # Patch sys.argv + sys.argv = [mod.__file__] + args.training_script_args + + main_function = getattr(mod, args.main_training_function) + with patch_environment(**current_env): + xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes) + + +def tpu_pod_launcher(args): + from torch_xla.distributed import xla_dist + + current_env = {} + args, current_env = prepare_tpu(args, current_env, True) + debug = getattr(args, "debug", False) + + training_script = args.training_script + training_script_args = args.training_script_args + new_args = _filter_args( + args, xla_dist.get_args_parser(), ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"] + ) + + if args.tpu_use_sudo: + new_cmd = ["sudo"] + else: + new_cmd = [] + + new_cmd += [ + "accelerate-launch", + "--tpu", + "--no_tpu_cluster", + "--num_machines", + "1", + "--mixed_precision", + "no", + "--dynamo_backend", + "no", + "--num_processes", + str(args.num_processes), + "--main_training_function", + str(args.main_training_function), + training_script, + ] + training_script_args + + new_args.positional = new_cmd + bad_flags = "" + for arg in vars(new_args): + if arg.startswith("docker_"): + value = getattr(new_args, arg) + if value != "" and value is not None: + bad_flags += f'{arg}="{value}"\n' + if bad_flags != "": + raise ValueError( + f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}" + ) + new_args.env = [f"{k}={v}" for k, v in current_env.items()] + new_args.env.append("ACCELERATE_IN_TPU_POD=1") + try: + xla_dist.resolve_and_execute(new_args) + except Exception: + if is_rich_available() and debug: + console = get_console() + console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]") + console.print_exception(suppress=[__file__], show_locals=False) + else: + raise + + +def sagemaker_launcher(sagemaker_config: SageMakerConfig, args): + if not is_sagemaker_available(): + raise ImportError( + "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`" + ) + if args.module or args.no_python: + raise ValueError( + "SageMaker requires a python training script file and cannot be used with --module or --no_python" + ) + + from sagemaker.huggingface import HuggingFace + + args, sagemaker_inputs = prepare_sagemager_args_inputs(sagemaker_config, args) + + huggingface_estimator = HuggingFace(**args) + + huggingface_estimator.fit(inputs=sagemaker_inputs) + print(f"You can find your model data at: {huggingface_estimator.model_data}") + + +def _validate_launch_command(args): + # Sanity checks + if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1: + raise ValueError( + "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time." + ) + if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2): + raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.") + + defaults = None + warned = [] + mp_from_config_flag = False + # Get the default from the config file. 
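+    # Note that `and` binds tighter than `or`, so the condition below is parsed as
+    # `args.config_file is not None or (os.path.isfile(default_config_file) and not args.cpu)`.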
+ if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu: + defaults = load_config_from_file(args.config_file) + if ( + not args.multi_gpu + and not args.tpu + and not args.tpu_use_cluster + and not args.use_deepspeed + and not args.use_fsdp + and not args.use_megatron_lm + ): + args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED + args.multi_gpu = ( + True + if defaults.distributed_type + in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_XPU, + ) + else False + ) + args.tpu = defaults.distributed_type == DistributedType.XLA + args.use_fsdp = defaults.distributed_type == DistributedType.FSDP + args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM + args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False + if args.gpu_ids is None: + if defaults.gpu_ids is not None: + args.gpu_ids = defaults.gpu_ids + else: + args.gpu_ids = "all" + + if args.multi_gpu and args.num_machines is None: + args.num_machines = defaults.num_machines + + if len(args.gpu_ids.split(",")) < 2 and (args.gpu_ids != "all") and args.multi_gpu and args.num_machines <= 1: + raise ValueError( + "Less than two GPU ids were configured and tried to run on on multiple GPUs. " + "Please ensure at least two are specified for `--gpu_ids`, or use `--gpu_ids='all'`." + ) + if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE: + # Update args with the defaults + for name, attr in defaults.__dict__.items(): + if isinstance(attr, dict): + for k in defaults.deepspeed_config: + setattr(args, k, defaults.deepspeed_config[k]) + for k in defaults.fsdp_config: + arg_to_set = k + if "fsdp" not in arg_to_set: + arg_to_set = "fsdp_" + arg_to_set + setattr(args, arg_to_set, defaults.fsdp_config[k]) + for k in defaults.megatron_lm_config: + setattr(args, k, defaults.megatron_lm_config[k]) + for k in defaults.dynamo_config: + setattr(args, k, defaults.dynamo_config[k]) + for k in defaults.ipex_config: + setattr(args, k, defaults.ipex_config[k]) + for k in defaults.mpirun_config: + setattr(args, k, defaults.mpirun_config[k]) + continue + + # Those args are handled separately + if ( + name not in ["compute_environment", "mixed_precision", "distributed_type"] + and getattr(args, name, None) is None + ): + setattr(args, name, attr) + if not args.debug: + args.debug = defaults.debug + + if not args.mixed_precision: + if defaults.mixed_precision is None: + args.mixed_precision = "no" + else: + args.mixed_precision = defaults.mixed_precision + mp_from_config_flag = True + else: + if args.use_cpu or (args.use_xpu and torch.xpu.is_available()): + native_amp = is_torch_version(">=", "1.10") + else: + native_amp = is_bf16_available(True) + if ( + args.mixed_precision == "bf16" + and not native_amp + and not (args.tpu and is_torch_xla_available(check_is_tpu=True)) + ): + raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.") + + # Silently set the default here + if args.dynamo_backend is None: + args.dynamo_backend = "no" + else: + if args.num_processes is None: + if args.use_xpu and is_xpu_available(): + args.num_processes = torch.xpu.device_count() + elif is_mlu_available(): + args.num_processes = torch.mlu.device_count() + elif is_npu_available(): + args.num_processes = torch.npu.device_count() + else: + args.num_processes = torch.cuda.device_count() + warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`") + if args.debug 
is None: + args.debug = False + if not args.multi_gpu and ( + (args.use_xpu and is_xpu_available() and torch.xpu.device_count() > 1) + or (is_mlu_available() and torch.mlu.device_count() > 1) + or (is_npu_available() and torch.npu.device_count() > 1) + or (torch.cuda.device_count() > 1) + ): + warned.append( + "\t\tMore than one GPU was found, enabling multi-GPU training.\n" + "\t\tIf this was unintended please pass in `--num_processes=1`." + ) + args.multi_gpu = True + if args.num_machines is None: + warned.append("\t`--num_machines` was set to a value of `1`") + args.num_machines = 1 + if args.mixed_precision is None: + warned.append("\t`--mixed_precision` was set to a value of `'no'`") + args.mixed_precision = "no" + if not hasattr(args, "use_cpu"): + args.use_cpu = args.cpu + if args.dynamo_backend is None: + warned.append("\t`--dynamo_backend` was set to a value of `'no'`") + args.dynamo_backend = "no" + if args.debug: + logger.debug("Running script in debug mode, expect distributed operations to be slightly slower.") + + is_aws_env_disabled = defaults is None or ( + defaults is not None and defaults.compute_environment != ComputeEnvironment.AMAZON_SAGEMAKER + ) + if is_aws_env_disabled and args.num_cpu_threads_per_process is None: + args.num_cpu_threads_per_process = 1 + if args.use_cpu and args.num_processes >= 1: + local_size = get_int_from_env( + ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], 1 + ) + threads_per_process = int(psutil.cpu_count(logical=False) / local_size) + if threads_per_process > 1: + args.num_cpu_threads_per_process = threads_per_process + warned.append( + f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs" + ) + + if any(warned): + message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n" + message += "\n".join(warned) + message += ( + "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`." 
+ ) + logger.warning(message) + return args, defaults, mp_from_config_flag + + +def launch_command(args): + args, defaults, mp_from_config_flag = _validate_launch_command(args) + # Use the proper launcher + if args.use_deepspeed and not args.cpu: + args.deepspeed_fields_from_accelerate_config = list(defaults.deepspeed_config.keys()) if defaults else [] + if mp_from_config_flag: + args.deepspeed_fields_from_accelerate_config.append("mixed_precision") + args.deepspeed_fields_from_accelerate_config = ",".join(args.deepspeed_fields_from_accelerate_config) + deepspeed_launcher(args) + elif args.use_fsdp and not args.cpu: + multi_gpu_launcher(args) + elif args.use_megatron_lm and not args.cpu: + multi_gpu_launcher(args) + elif args.multi_gpu and not args.cpu: + multi_gpu_launcher(args) + elif args.tpu and not args.cpu: + if args.tpu_use_cluster: + tpu_pod_launcher(args) + else: + tpu_launcher(args) + elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER: + sagemaker_launcher(defaults, args) + else: + simple_launcher(args) + + +def main(): + parser = launch_command_parser() + args = parser.parse_args() + launch_command(args) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/commands/menu/__init__.py b/llm/Lib/site-packages/accelerate/commands/menu/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2c851cc0b192ab8207d3fa68d7409868c84354c --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/menu/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
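+# Re-exports the interactive selection menu. Illustrative usage (a sketch, not part of the package):
+#
+#     from accelerate.commands.menu import BulletMenu
+#     index = BulletMenu("Which backend?", ["no", "deepspeed", "fsdp"]).run(default_choice=0)
+#
+# `run()` returns the integer index of the selected choice.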
+from .selection_menu import BulletMenu diff --git a/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..273ad63861ed1efbb4a998150c260c56d38f9b4f Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a21d7849431c39b763e20907f6db08582fc69ade Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/cursor.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/helpers.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e7e60e2c4ba4919ac90cbc167300d5557e96872 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/helpers.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/input.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/input.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5ae84461c2e47070b2d361d2e61dd5b0d35cd06 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/input.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79f312ba62324160b3ae54085d7c1ababf70c5f7 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/keymap.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-311.pyc b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a9715d0ede7bdb017d381c2d9ea2b45a0f47a29 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/commands/menu/__pycache__/selection_menu.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/commands/menu/cursor.py b/llm/Lib/site-packages/accelerate/commands/menu/cursor.py new file mode 100644 index 0000000000000000000000000000000000000000..c1f0bb7b68025ae4fe0c2c76c095eb36b4e64f2c --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/menu/cursor.py @@ -0,0 +1,65 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +A utility for showing and hiding the terminal cursor on Windows and Linux, based on https://github.com/bchao1/bullet +""" + +import os +import sys +from contextlib import contextmanager + + +# Windows only +if os.name == "nt": + import ctypes + import msvcrt # noqa + + class CursorInfo(ctypes.Structure): + # _fields is a specific attr expected by ctypes + _fields_ = [("size", ctypes.c_int), ("visible", ctypes.c_byte)] + + +def hide_cursor(): + if os.name == "nt": + ci = CursorInfo() + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci)) + ci.visible = False + ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci)) + elif os.name == "posix": + sys.stdout.write("\033[?25l") + sys.stdout.flush() + + +def show_cursor(): + if os.name == "nt": + ci = CursorInfo() + handle = ctypes.windll.kernel32.GetStdHandle(-11) + ctypes.windll.kernel32.GetConsoleCursorInfo(handle, ctypes.byref(ci)) + ci.visible = True + ctypes.windll.kernel32.SetConsoleCursorInfo(handle, ctypes.byref(ci)) + elif os.name == "posix": + sys.stdout.write("\033[?25h") + sys.stdout.flush() + + +@contextmanager +def hide(): + "Context manager to hide the terminal cursor" + try: + hide_cursor() + yield + finally: + show_cursor() diff --git a/llm/Lib/site-packages/accelerate/commands/menu/helpers.py b/llm/Lib/site-packages/accelerate/commands/menu/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..de46f37ddcf4591167e3e01791391e4b1729034f --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/menu/helpers.py @@ -0,0 +1,59 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A variety of helper functions and constants when dealing with terminal menu choices, based on +https://github.com/bchao1/bullet +""" + +import enum +import shutil +import sys + + +TERMINAL_WIDTH, _ = shutil.get_terminal_size() + +CURSOR_TO_CHAR = {"UP": "A", "DOWN": "B", "RIGHT": "C", "LEFT": "D"} + + +class Direction(enum.Enum): + UP = 0 + DOWN = 1 + + +def forceWrite(content, end=""): + sys.stdout.write(str(content) + end) + sys.stdout.flush() + + +def writeColor(content, color, end=""): + forceWrite(f"\u001b[{color}m{content}\u001b[0m", end) + + +def reset_cursor(): + forceWrite("\r") + + +def move_cursor(num_lines: int, direction: str): + forceWrite(f"\033[{num_lines}{CURSOR_TO_CHAR[direction.upper()]}") + + +def clear_line(): + forceWrite(" " * TERMINAL_WIDTH) + reset_cursor() + + +def linebreak(): + reset_cursor() + forceWrite("-" * TERMINAL_WIDTH) diff --git a/llm/Lib/site-packages/accelerate/commands/menu/input.py b/llm/Lib/site-packages/accelerate/commands/menu/input.py new file mode 100644 index 0000000000000000000000000000000000000000..2690f86aa61f7ac648f4a9c2040a34ee35147201 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/menu/input.py @@ -0,0 +1,86 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file contains utilities for handling input from the user and registering specific keys to specific functions, +based on https://github.com/bchao1/bullet +""" + +from typing import List + +from .keymap import KEYMAP, get_character + + +def mark(key: str): + """ + Mark the function with the key code so it can be handled in the register + """ + + def decorator(func): + handle = getattr(func, "handle_key", []) + handle += [key] + func.handle_key = handle + return func + + return decorator + + +def mark_multiple(*keys: List[str]): + """ + Mark the function with the key codes so it can be handled in the register + """ + + def decorator(func): + handle = getattr(func, "handle_key", []) + handle += keys + func.handle_key = handle + return func + + return decorator + + +class KeyHandler(type): + """ + Metaclass that adds the key handlers to the class + """ + + def __new__(cls, name, bases, attrs): + new_cls = super().__new__(cls, name, bases, attrs) + if not hasattr(new_cls, "key_handler"): + new_cls.key_handler = {} + new_cls.handle_input = KeyHandler.handle_input + + for value in attrs.values(): + handled_keys = getattr(value, "handle_key", []) + for key in handled_keys: + new_cls.key_handler[key] = value + return new_cls + + @staticmethod + def handle_input(cls): + "Finds and returns the selected character if it exists in the handler" + char = get_character() + if char != KEYMAP["undefined"]: + char = ord(char) + handler = cls.key_handler.get(char) + if handler: + cls.current_selection = char + return handler(cls) + else: + return None + + +def register(cls): + """Adds KeyHandler metaclass to the class""" + return KeyHandler(cls.__name__, cls.__bases__, cls.__dict__.copy()) diff --git a/llm/Lib/site-packages/accelerate/commands/menu/keymap.py b/llm/Lib/site-packages/accelerate/commands/menu/keymap.py new file mode 100644 index 0000000000000000000000000000000000000000..787db12860fe21c6786dda69c34fcccab114f2f8 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/menu/keymap.py @@ -0,0 +1,133 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Utilities relating to parsing raw characters from the keyboard, based on https://github.com/bchao1/bullet +""" + +import os +import string +import sys + + +ARROW_KEY_FLAG = 1 << 8 + +KEYMAP = { + "tab": ord("\t"), + "newline": ord("\r"), + "esc": 27, + "up": 65 + ARROW_KEY_FLAG, + "down": 66 + ARROW_KEY_FLAG, + "right": 67 + ARROW_KEY_FLAG, + "left": 68 + ARROW_KEY_FLAG, + "mod_int": 91, + "undefined": sys.maxsize, + "interrupt": 3, + "insert": 50, + "delete": 51, + "pg_up": 53, + "pg_down": 54, +} + +KEYMAP["arrow_begin"] = KEYMAP["up"] +KEYMAP["arrow_end"] = KEYMAP["left"] + +if sys.platform == "win32": + WIN_CH_BUFFER = [] + WIN_KEYMAP = { + b"\xe0H": KEYMAP["up"] - ARROW_KEY_FLAG, + b"\x00H": KEYMAP["up"] - ARROW_KEY_FLAG, + b"\xe0P": KEYMAP["down"] - ARROW_KEY_FLAG, + b"\x00P": KEYMAP["down"] - ARROW_KEY_FLAG, + b"\xe0M": KEYMAP["right"] - ARROW_KEY_FLAG, + b"\x00M": KEYMAP["right"] - ARROW_KEY_FLAG, + b"\xe0K": KEYMAP["left"] - ARROW_KEY_FLAG, + b"\x00K": KEYMAP["left"] - ARROW_KEY_FLAG, + } + +for i in range(10): + KEYMAP[str(i)] = ord(str(i)) + + +def get_raw_chars(): + "Gets raw characters from inputs" + if os.name == "nt": + import msvcrt + + encoding = "mbcs" + # Flush the keyboard buffer + while msvcrt.kbhit(): + msvcrt.getch() + if len(WIN_CH_BUFFER) == 0: + # Read the keystroke + ch = msvcrt.getch() + + # If it is a prefix char, get second part + if ch in (b"\x00", b"\xe0"): + ch2 = ch + msvcrt.getch() + # Translate actual Win chars to bullet char types + try: + chx = chr(WIN_KEYMAP[ch2]) + WIN_CH_BUFFER.append(chr(KEYMAP["mod_int"])) + WIN_CH_BUFFER.append(chx) + if ord(chx) in ( + KEYMAP["insert"] - 1 << 9, + KEYMAP["delete"] - 1 << 9, + KEYMAP["pg_up"] - 1 << 9, + KEYMAP["pg_down"] - 1 << 9, + ): + WIN_CH_BUFFER.append(chr(126)) + ch = chr(KEYMAP["esc"]) + except KeyError: + ch = ch2[1] + else: + ch = ch.decode(encoding) + else: + ch = WIN_CH_BUFFER.pop(0) + elif os.name == "posix": + import termios + import tty + + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + ch = sys.stdin.read(1) + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + return ch + + +def get_character(): + "Gets a character from the keyboard and returns the key code" + char = get_raw_chars() + if ord(char) in [KEYMAP["interrupt"], KEYMAP["newline"]]: + return char + + elif ord(char) == KEYMAP["esc"]: + combo = get_raw_chars() + if ord(combo) == KEYMAP["mod_int"]: + key = get_raw_chars() + if ord(key) >= KEYMAP["arrow_begin"] - ARROW_KEY_FLAG and ord(key) <= KEYMAP["arrow_end"] - ARROW_KEY_FLAG: + return chr(ord(key) + ARROW_KEY_FLAG) + else: + return KEYMAP["undefined"] + else: + return get_raw_chars() + + else: + if char in string.printable: + return char + else: + return KEYMAP["undefined"] diff --git a/llm/Lib/site-packages/accelerate/commands/menu/selection_menu.py b/llm/Lib/site-packages/accelerate/commands/menu/selection_menu.py new file mode 100644 index 0000000000000000000000000000000000000000..ee9a771a54ef666ee46b67ae6c75fb957d49efdd --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/menu/selection_menu.py @@ -0,0 +1,144 @@ +# Copyright 2022 The HuggingFace Team and Brian Chao. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Main driver for the selection menu, based on https://github.com/bchao1/bullet +""" + +import builtins +import sys + +from ...utils.imports import _is_package_available +from . import cursor, input +from .helpers import Direction, clear_line, forceWrite, linebreak, move_cursor, reset_cursor, writeColor +from .keymap import KEYMAP + + +in_colab = False +try: + in_colab = _is_package_available("google.colab") +except ModuleNotFoundError: + pass + + +@input.register +class BulletMenu: + """ + A CLI menu to select a choice from a list of choices using the keyboard. + """ + + def __init__(self, prompt: str = None, choices: list = []): + self.position = 0 + self.choices = choices + self.prompt = prompt + if sys.platform == "win32": + self.arrow_char = "*" + else: + self.arrow_char = "➔ " + + def write_choice(self, index, end: str = ""): + if sys.platform != "win32": + writeColor(self.choices[index], 32, end) + else: + forceWrite(self.choices[index], end) + + def print_choice(self, index: int): + "Prints the choice at the given index" + if index == self.position: + forceWrite(f" {self.arrow_char} ") + self.write_choice(index) + else: + forceWrite(f" {self.choices[index]}") + reset_cursor() + + def move_direction(self, direction: Direction, num_spaces: int = 1): + "Should not be directly called, used to move a direction of either up or down" + old_position = self.position + if direction == Direction.DOWN: + if self.position + 1 >= len(self.choices): + return + self.position += num_spaces + else: + if self.position - 1 < 0: + return + self.position -= num_spaces + clear_line() + self.print_choice(old_position) + move_cursor(num_spaces, direction.name) + self.print_choice(self.position) + + @input.mark(KEYMAP["up"]) + def move_up(self): + self.move_direction(Direction.UP) + + @input.mark(KEYMAP["down"]) + def move_down(self): + self.move_direction(Direction.DOWN) + + @input.mark(KEYMAP["newline"]) + def select(self): + move_cursor(len(self.choices) - self.position, "DOWN") + return self.position + + @input.mark(KEYMAP["interrupt"]) + def interrupt(self): + move_cursor(len(self.choices) - self.position, "DOWN") + raise KeyboardInterrupt + + @input.mark_multiple(*[KEYMAP[str(number)] for number in range(10)]) + def select_row(self): + index = int(chr(self.current_selection)) + movement = index - self.position + if index == self.position: + return + if index < len(self.choices): + if self.position > index: + self.move_direction(Direction.UP, -movement) + elif self.position < index: + self.move_direction(Direction.DOWN, movement) + else: + return + else: + return + + def run(self, default_choice: int = 0): + "Start the menu and return the selected choice" + if self.prompt: + linebreak() + forceWrite(self.prompt, "\n") + if in_colab: + forceWrite("Please input a choice index (starting from 0), and press enter", "\n") + else: + forceWrite("Please select a choice using the arrow or number keys, and selecting with enter", "\n") + self.position = default_choice + for i in range(len(self.choices)): + self.print_choice(i) + forceWrite("\n") + move_cursor(len(self.choices) - self.position, "UP") 
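+        # Hide the cursor while polling for input; in Colab a numeric index is read from
+        # stdin instead of handling raw key presses.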
+ with cursor.hide(): + while True: + if in_colab: + try: + choice = int(builtins.input()) + except ValueError: + choice = default_choice + else: + choice = self.handle_input() + if choice is not None: + reset_cursor() + for _ in range(len(self.choices) + 1): + move_cursor(1, "UP") + clear_line() + self.write_choice(choice, "\n") + return choice diff --git a/llm/Lib/site-packages/accelerate/commands/test.py b/llm/Lib/site-packages/accelerate/commands/test.py new file mode 100644 index 0000000000000000000000000000000000000000..a0d2f7bcf14727aa13e3438f4cd6e6f140f5bb2f --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/test.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + +from accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package + + +def test_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("test") + else: + parser = argparse.ArgumentParser("Accelerate test command") + + parser.add_argument( + "--config_file", + default=None, + help=( + "The path to use to store the config file. Will default to a file named default_config.yaml in the cache " + "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have " + "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed " + "with 'huggingface'." + ), + ) + + if subparsers is not None: + parser.set_defaults(func=test_command) + return parser + + +def test_command(args): + script_name = path_in_accelerate_package("test_utils", "scripts", "test_script.py") + + if args.config_file is None: + test_args = [script_name] + else: + test_args = f"--config_file={args.config_file} {script_name}".split() + + cmd = ["accelerate-launch"] + test_args + result = execute_subprocess_async(cmd) + if result.returncode == 0: + print("Test is a success! You are ready for your distributed training!") + + +def main(): + parser = test_command_parser() + args = parser.parse_args() + test_command(args) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/commands/tpu.py b/llm/Lib/site-packages/accelerate/commands/tpu.py new file mode 100644 index 0000000000000000000000000000000000000000..fc0f07bf8697bfdb6484d3bf817f2e18b1313b00 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/tpu.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import subprocess + +from packaging.version import Version, parse + +from accelerate.commands.config.config_args import default_config_file, load_config_from_file + + +_description = "Run commands across TPU VMs for initial setup before running `accelerate launch`." + + +def tpu_command_parser(subparsers=None): + if subparsers is not None: + parser = subparsers.add_parser("tpu-config", description=_description) + else: + parser = argparse.ArgumentParser("Accelerate tpu-config command", description=_description) + # Core arguments + config_args = parser.add_argument_group( + "Config Arguments", "Arguments that can be configured through `accelerate config`." + ) + config_args.add_argument( + "--config_file", + type=str, + default=None, + help="Path to the config file to use for accelerate.", + ) + config_args.add_argument( + "--tpu_name", + default=None, + help="The name of the TPU to use. If not specified, will use the TPU specified in the config file.", + ) + config_args.add_argument( + "--tpu_zone", + default=None, + help="The zone of the TPU to use. If not specified, will use the zone specified in the config file.", + ) + pod_args = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.") + pod_args.add_argument( + "--use_alpha", + action="store_true", + help="Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.", + ) + pod_args.add_argument( + "--command_file", + default=None, + help="The path to the file containing the commands to run on the pod on startup.", + ) + pod_args.add_argument( + "--command", + action="append", + nargs="+", + help="A command to run on the pod. Can be passed multiple times.", + ) + pod_args.add_argument( + "--install_accelerate", + action="store_true", + help="Whether to install accelerate on the pod. Defaults to False.", + ) + pod_args.add_argument( + "--accelerate_version", + default="latest", + help="The version of accelerate to install on the pod. If not specified, will use the latest pypi version. Specify 'dev' to install from GitHub.", + ) + pod_args.add_argument( + "--debug", action="store_true", help="If set, will print the command that would be run instead of running it." + ) + + if subparsers is not None: + parser.set_defaults(func=tpu_command_launcher) + return parser + + +def tpu_command_launcher(args): + defaults = None + + # Get the default from the config file if it exists. 
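+    # Config-file values only fill in options that were not passed on the command line;
+    # explicit CLI flags take precedence.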
+ if args.config_file is not None or os.path.isfile(default_config_file): + defaults = load_config_from_file(args.config_file) + if not args.command_file and defaults.command_file is not None and not args.command: + args.command_file = defaults.command_file + if not args.command and defaults.commands is not None: + args.command = defaults.commands + if not args.tpu_name: + args.tpu_name = defaults.tpu_name + if not args.tpu_zone: + args.tpu_zone = defaults.tpu_zone + if args.accelerate_version == "dev": + args.accelerate_version = "git+https://github.com/huggingface/accelerate.git" + elif args.accelerate_version == "latest": + args.accelerate_version = "accelerate -U" + elif isinstance(parse(args.accelerate_version), Version): + args.accelerate_version = f"accelerate=={args.accelerate_version}" + + if not args.command_file and not args.command: + raise ValueError("You must specify either a command file or a command to run on the pod.") + + if args.command_file: + with open(args.command_file) as f: + args.command = [f.read().splitlines()] + + # To turn list of lists into list of strings + if isinstance(args.command[0], list): + args.command = [line for cmd in args.command for line in cmd] + # Default to the shared folder and install accelerate + new_cmd = ["cd /usr/share"] + if args.install_accelerate: + new_cmd += [f"pip install {args.accelerate_version}"] + new_cmd += args.command + args.command = "; ".join(new_cmd) + + # Then send it to gcloud + # Eventually try to use google-api-core to do this instead of subprocess + cmd = ["gcloud"] + if args.use_alpha: + cmd += ["alpha"] + cmd += [ + "compute", + "tpus", + "tpu-vm", + "ssh", + args.tpu_name, + "--zone", + args.tpu_zone, + "--command", + args.command, + "--worker", + "all", + ] + if args.debug: + print(f"Running {' '.join(cmd)}") + return + subprocess.run(cmd) + print("Successfully setup pod.") + + +def main(): + parser = tpu_command_parser() + args = parser.parse_args() + + tpu_command_launcher(args) diff --git a/llm/Lib/site-packages/accelerate/commands/utils.py b/llm/Lib/site-packages/accelerate/commands/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b65215fac7666b475af98b17e264ef6701239bc1 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/commands/utils.py @@ -0,0 +1,120 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse + + +class _StoreAction(argparse.Action): + """ + Custom action that allows for `-` or `_` to be passed in for an argument. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + new_option_strings = [] + for option_string in self.option_strings: + new_option_strings.append(option_string) + if "_" in option_string[2:]: + # Add `-` version to the option string + new_option_strings.append(option_string.replace("_", "-")) + self.option_strings = new_option_strings + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, values) + + +class _StoreConstAction(_StoreAction): + """ + Same as `argparse._StoreConstAction` but uses the custom `_StoreAction`. + """ + + def __init__(self, option_strings, dest, const, default=None, required=False, help=None): + super().__init__( + option_strings=option_strings, + dest=dest, + nargs=0, + const=const, + default=default, + required=required, + help=help, + ) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, self.const) + + +class _StoreTrueAction(_StoreConstAction): + """ + Same as `argparse._StoreTrueAction` but uses the custom `_StoreConstAction`. + """ + + def __init__( + self, + option_strings, + dest, + default=None, + required=False, + help=None, + ): + super().__init__( + option_strings=option_strings, dest=dest, const=True, default=default, required=required, help=help + ) + + +class CustomArgumentGroup(argparse._ArgumentGroup): + """ + Custom argument group that allows for the use of `-` or `_` in arguments passed and overrides the help for each + when applicable. + """ + + def _add_action(self, action): + args = vars(action) + if isinstance(action, argparse._StoreTrueAction): + action = _StoreTrueAction( + args["option_strings"], args["dest"], args["default"], args["required"], args["help"] + ) + elif isinstance(action, argparse._StoreConstAction): + action = _StoreConstAction( + args["option_strings"], + args["dest"], + args["const"], + args["default"], + args["required"], + args["help"], + ) + elif isinstance(action, argparse._StoreAction): + action = _StoreAction(**args) + action = super()._add_action(action) + return action + + +class CustomArgumentParser(argparse.ArgumentParser): + """ + Custom argument parser that allows for the use of `-` or `_` in arguments passed and overrides the help for each + when applicable. + """ + + def add_argument(self, *args, **kwargs): + if "action" in kwargs: + # Translate action -> class + if kwargs["action"] == "store_true": + kwargs["action"] = _StoreTrueAction + else: + kwargs["action"] = _StoreAction + super().add_argument(*args, **kwargs) + + def add_argument_group(self, *args, **kwargs): + group = CustomArgumentGroup(self, *args, **kwargs) + self._action_groups.append(group) + return group diff --git a/llm/Lib/site-packages/accelerate/data_loader.py b/llm/Lib/site-packages/accelerate/data_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..0764e0971a3845d04dc1c7fc500d0c06f67d2c0e --- /dev/null +++ b/llm/Lib/site-packages/accelerate/data_loader.py @@ -0,0 +1,1093 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from contextlib import suppress +from typing import Callable, List, Optional, Union + +import torch +from torch.utils.data import BatchSampler, DataLoader, IterableDataset, RandomSampler + +from .logging import get_logger +from .state import AcceleratorState, DistributedType, GradientState, is_torch_xla_available +from .utils import ( + RNGType, + broadcast, + broadcast_object_list, + concatenate, + find_batch_size, + get_data_structure, + initialize_tensors, + is_torch_version, + send_to_device, + slice_tensors, + synchronize_rng_states, +) + + +logger = get_logger(__name__) + +# kwargs of the DataLoader in min version 1.4.0. +_PYTORCH_DATALOADER_KWARGS = { + "batch_size": 1, + "shuffle": False, + "sampler": None, + "batch_sampler": None, + "num_workers": 0, + "collate_fn": None, + "pin_memory": False, + "drop_last": False, + "timeout": 0, + "worker_init_fn": None, + "multiprocessing_context": None, + "generator": None, + "prefetch_factor": 2, + "persistent_workers": False, +} + +# kwargs added after by version +_PYTORCH_DATALOADER_ADDITIONAL_KWARGS = {} + +for v, additional_kwargs in _PYTORCH_DATALOADER_ADDITIONAL_KWARGS.items(): + if is_torch_version(">=", v): + _PYTORCH_DATALOADER_KWARGS.update(additional_kwargs) + + +class SeedableRandomSampler(RandomSampler): + """ + Same as a random sampler, except that in `__iter__` a seed can be used. + + Needed specifically in distributed cases, when the random generator for each GPU needs to start from the same seed + and be fully reproducable on multiple iterations. + + If a custom `generator` is passed, it will rely on its initial seed as well as the current iteration it is on + (stored in `self.epoch`). + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.epoch = 0 + self.initial_seed = torch.random.initial_seed() + + def __iter__(self): + if self.generator is None: + self.generator = torch.Generator() + self.generator.manual_seed(self.initial_seed) + + # Allow `self.epoch` to modify the seed of the generator + seed = self.epoch + self.initial_seed + # print("Setting seed at epoch", self.epoch, seed) + self.generator.manual_seed(seed) + yield from super().__iter__() + self.set_epoch(self.epoch + 1) + + def set_epoch(self, epoch: int): + "Sets the current iteration of the sampler." + self.epoch = epoch + + +class BatchSamplerShard(BatchSampler): + """ + Wraps a PyTorch `BatchSampler` to generate batches for one of the processes only. Instances of this class will + always yield a number of batches that is a round multiple of `num_processes` and that all have the same size. + Depending on the value of the `drop_last` attribute of the batch sampler passed, it will either stop the iteration + at the first batch that would be too small / not present on all processes or loop with indices from the beginning. + + Args: + batch_sampler (`torch.utils.data.sampler.BatchSampler`): + The batch sampler to split in several shards. + num_processes (`int`, *optional*, defaults to 1): + The number of processes running concurrently. + process_index (`int`, *optional*, defaults to 0): + The index of the current process. + split_batches (`bool`, *optional*, defaults to `False`): + Whether the shards should be created by splitting a batch to give a piece of it on each process, or by + yielding different full batches on each process. 
+ + On two processes with a sampler of `[[0, 1, 2, 3], [4, 5, 6, 7]]`, this will result in: + + - the sampler on process 0 to yield `[0, 1, 2, 3]` and the sampler on process 1 to yield `[4, 5, 6, 7]` if + this argument is set to `False`. + - the sampler on process 0 to yield `[0, 1]` then `[4, 5]` and the sampler on process 1 to yield `[2, 3]` + then `[6, 7]` if this argument is set to `True`. + even_batches (`bool`, *optional*, defaults to `True`): + Whether or not to loop back at the beginning of the sampler when the number of samples is not a round + multiple of (original batch size / number of processes). + + + + `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches` + equal to `False` + + """ + + def __init__( + self, + batch_sampler: BatchSampler, + num_processes: int = 1, + process_index: int = 0, + split_batches: bool = False, + even_batches: bool = True, + ): + if split_batches and batch_sampler.batch_size % num_processes != 0: + raise ValueError( + f"To use `BatchSamplerShard` in `split_batches` mode, the batch size ({batch_sampler.batch_size}) " + f"needs to be a round multiple of the number of processes ({num_processes})." + ) + self.batch_sampler = batch_sampler + self.num_processes = num_processes + self.process_index = process_index + self.split_batches = split_batches + self.even_batches = even_batches + self.batch_size = getattr(batch_sampler, "batch_size", None) + self.drop_last = getattr(batch_sampler, "drop_last", False) + if self.batch_size is None and self.even_batches: + raise ValueError( + "You need to use `even_batches=False` when the batch sampler has no batch size. If you " + "are not calling this method directly, set `accelerator.even_batches=False` instead." + ) + + @property + def total_length(self): + return len(self.batch_sampler) + + def __len__(self): + if self.split_batches: + # Split batches does not change the length of the batch sampler + return len(self.batch_sampler) + if len(self.batch_sampler) % self.num_processes == 0: + # If the length is a round multiple of the number of processes, it's easy. + return len(self.batch_sampler) // self.num_processes + length = len(self.batch_sampler) // self.num_processes + if self.drop_last: + # Same if we drop the remainder. + return length + elif self.even_batches: + # When we even batches we always get +1 + return length + 1 + else: + # Otherwise it depends on the process index. + return length + 1 if self.process_index < len(self.batch_sampler) % self.num_processes else length + + def __iter__(self): + return self._iter_with_split() if self.split_batches else self._iter_with_no_split() + + def _iter_with_split(self): + initial_data = [] + batch_length = self.batch_sampler.batch_size // self.num_processes + for idx, batch in enumerate(self.batch_sampler): + if idx == 0: + initial_data = batch + if len(batch) == self.batch_size: + # If the batch is full, we yield the part of it this process is responsible of. + yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)] + + # If drop_last is True of the last batch was full, iteration is over, otherwise... 
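+        # ... the short final batch is handled below: with `even_batches`, indices from the
+        # first batch are reused to pad it so that every process still receives a full slice.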
+ if not self.drop_last and len(initial_data) > 0 and len(batch) < self.batch_size: + if not self.even_batches: + if len(batch) > batch_length * self.process_index: + yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)] + else: + # For degenerate cases where the dataset has less than num_process * batch_size samples + while len(initial_data) < self.batch_size: + initial_data += initial_data + batch = batch + initial_data + yield batch[batch_length * self.process_index : batch_length * (self.process_index + 1)] + + def _iter_with_no_split(self): + initial_data = [] + batch_to_yield = [] + for idx, batch in enumerate(self.batch_sampler): + # We gather the initial indices in case we need to circle back at the end. + if not self.drop_last and idx < self.num_processes: + initial_data += batch + # We identify the batch to yield but wait until we ar sure every process gets a full batch before actually + # yielding it. + if idx % self.num_processes == self.process_index: + batch_to_yield = batch + if idx % self.num_processes == self.num_processes - 1 and ( + self.batch_size is None or len(batch) == self.batch_size + ): + yield batch_to_yield + batch_to_yield = [] + + # If drop_last is True, iteration is over, otherwise... + if not self.drop_last and len(initial_data) > 0: + if not self.even_batches: + if len(batch_to_yield) > 0: + yield batch_to_yield + else: + # ... we yield the complete batch we had saved before if it has the proper length + if len(batch_to_yield) == self.batch_size: + yield batch_to_yield + + # For degenerate cases where the dataset has less than num_process * batch_size samples + while len(initial_data) < self.num_processes * self.batch_size: + initial_data += initial_data + + # If the last batch seen was of the proper size, it has been yielded by its process so we move to the next + if len(batch) == self.batch_size: + batch = [] + idx += 1 + + # Make sure we yield a multiple of self.num_processes batches + cycle_index = 0 + while idx % self.num_processes != 0 or len(batch) > 0: + end_index = cycle_index + self.batch_size - len(batch) + batch += initial_data[cycle_index:end_index] + if idx % self.num_processes == self.process_index: + yield batch + cycle_index = end_index + batch = [] + idx += 1 + + +class IterableDatasetShard(IterableDataset): + """ + Wraps a PyTorch `IterableDataset` to generate samples for one of the processes only. Instances of this class will + always yield a number of samples that is a round multiple of the actual batch size (depending of the value of + `split_batches`, this is either `batch_size` or `batch_size x num_processes`). Depending on the value of the + `drop_last` attribute of the batch sampler passed, it will either stop the iteration at the first batch that would + be too small or loop with indices from the beginning. + + Args: + dataset (`torch.utils.data.dataset.IterableDataset`): + The batch sampler to split in several shards. + batch_size (`int`, *optional*, defaults to 1): + The size of the batches per shard (if `split_batches=False`) or the size of the batches (if + `split_batches=True`). + drop_last (`bool`, *optional*, defaults to `False`): + Whether or not to drop the last incomplete batch or complete the last batches by using the samples from the + beginning. + num_processes (`int`, *optional*, defaults to 1): + The number of processes running concurrently. + process_index (`int`, *optional*, defaults to 0): + The index of the current process. 
+ split_batches (`bool`, *optional*, defaults to `False`): + Whether the shards should be created by splitting a batch to give a piece of it on each process, or by + yielding different full batches on each process. + + On two processes with an iterable dataset yielding of `[0, 1, 2, 3, 4, 5, 6, 7]`, this will result in: + + - the shard on process 0 to yield `[0, 1, 2, 3]` and the shard on process 1 to yield `[4, 5, 6, 7]` if this + argument is set to `False`. + - the shard on process 0 to yield `[0, 1, 4, 5]` and the sampler on process 1 to yield `[2, 3, 6, 7]` if + this argument is set to `True`. + """ + + def __init__( + self, + dataset: IterableDataset, + batch_size: int = 1, + drop_last: bool = False, + num_processes: int = 1, + process_index: int = 0, + split_batches: bool = False, + ): + if split_batches and batch_size > 1 and batch_size % num_processes != 0: + raise ValueError( + f"To use `IterableDatasetShard` in `split_batches` mode, the batch size ({batch_size}) " + f"needs to be a round multiple of the number of processes ({num_processes})." + ) + self.dataset = dataset + self.batch_size = batch_size + self.drop_last = drop_last + self.num_processes = num_processes + self.process_index = process_index + self.split_batches = split_batches + + def set_epoch(self, epoch): + self.epoch = epoch + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + def __len__(self): + # We will just raise the downstream error if the underlying dataset is not sized + if self.drop_last: + return (len(self.dataset) // (self.batch_size * self.num_processes)) * self.batch_size + else: + return math.ceil(len(self.dataset) / (self.batch_size * self.num_processes)) * self.batch_size + + def __iter__(self): + if ( + not hasattr(self.dataset, "set_epoch") + and hasattr(self.dataset, "generator") + and isinstance(self.dataset.generator, torch.Generator) + ): + self.dataset.generator.manual_seed(self.epoch) + real_batch_size = self.batch_size if self.split_batches else (self.batch_size * self.num_processes) + process_batch_size = (self.batch_size // self.num_processes) if self.split_batches else self.batch_size + process_slice = range(self.process_index * process_batch_size, (self.process_index + 1) * process_batch_size) + + first_batch = None + current_batch = [] + for element in self.dataset: + current_batch.append(element) + # Wait to have a full batch before yielding elements. + if len(current_batch) == real_batch_size: + for i in process_slice: + yield current_batch[i] + if first_batch is None: + first_batch = current_batch.copy() + current_batch = [] + + # Finished if drop_last is True, otherwise complete the last batch with elements from the beginning. + if not self.drop_last and len(current_batch) > 0: + if first_batch is None: + first_batch = current_batch.copy() + while len(current_batch) < real_batch_size: + current_batch += first_batch + for i in process_slice: + yield current_batch[i] + + +class DataLoaderStateMixin: + """ + Mixin class that adds a state to a `DataLoader` to keep track of the status inside the dataloader such as at the + end of the iteration, the number of items in the dataset in the last batch relative to the batch size, and other + useful information that might be needed. 
+ + **Available attributes:** + + - **end_of_dataloader** (`bool`) -- Whether at the last iteration or batch + - **remainder** (`int`) -- The number of items that are remaining in the last batch, relative to the total + batch size + + """ + + def __init_subclass__(cls, **kwargs): + cls.end_of_dataloader = False + cls.remainder = -1 + + def reset(self): + self.end_of_dataloader = False + self.remainder = -1 + + def begin(self): + "Prepares the gradient state for the current dataloader" + self.reset() + with suppress(Exception): + if not self._drop_last: + length = getattr(self.dataset, "total_dataset_length", len(self.dataset)) + self.remainder = length % self.total_batch_size + self.gradient_state._add_dataloader(self) + + def end(self): + "Cleans up the gradient state after exiting the dataloader" + self.gradient_state._remove_dataloader(self) + + +class DataLoaderShard(DataLoader, DataLoaderStateMixin): + """ + Subclass of a PyTorch `DataLoader` that will deal with device placement and current distributed setup. + + Args: + dataset (`torch.utils.data.dataset.Dataset`): + The dataset to use to build this datalaoder. + device (`torch.device`, *optional*): + If passed, the device to put all batches on. + rng_types (list of `str` or [`~utils.RNGType`]): + The list of random number generators to synchronize at the beginning of each iteration. Should be one or + several of: + + - `"torch"`: the base torch random number generator + - `"cuda"`: the CUDA random number generator (GPU only) + - `"xla"`: the XLA random number generator (TPU only) + - `"generator"`: an optional `torch.Generator` + synchronized_generator (`torch.Generator`, *optional*): + A random number generator to keep synchronized across processes. + skip_batches (`int`, *optional*, defaults to 0): + The number of batches to skip at the beginning. + **kwargs (additional keyword arguments, *optional*): + All other keyword arguments to pass to the regular `DataLoader` initialization. + + **Available attributes:** + + - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes. + Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total + number of processes + + - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes. 
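+
+    Example (a minimal sketch assuming a single-machine CPU run; this class is usually created for you by
+    `Accelerator.prepare` or `prepare_data_loader` rather than instantiated directly):
+
+        import torch
+        from torch.utils.data import DataLoader, TensorDataset
+
+        from accelerate import Accelerator
+
+        accelerator = Accelerator()
+        dataloader = DataLoader(TensorDataset(torch.arange(16).float()), batch_size=4)
+        # On a typical (non-dispatched, non-XLA) setup this returns a `DataLoaderShard`.
+        dataloader = accelerator.prepare(dataloader)
+        for batch in dataloader:
+            pass  # batches arrive already placed on `accelerator.device`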
+ """ + + def __init__( + self, + dataset, + device=None, + rng_types=None, + synchronized_generator=None, + skip_batches=0, + _drop_last: bool = False, + **kwargs, + ): + super().__init__(dataset, **kwargs) + self.device = device + self.rng_types = rng_types + self.synchronized_generator = synchronized_generator + self.skip_batches = skip_batches + self.gradient_state = GradientState() + self._drop_last = _drop_last + self.iteration = 0 + + def __iter__(self): + if self.rng_types is not None: + synchronize_rng_states(self.rng_types, self.synchronized_generator) + self.begin() + + self.set_epoch(self.iteration) + dataloader_iter = super().__iter__() + # We iterate one batch ahead to check when we are at the end + try: + current_batch = next(dataloader_iter) + except StopIteration: + yield + + batch_index = 0 + while True: + try: + # But we still move it to the device so it is done before `StopIteration` is reached + if self.device is not None: + current_batch = send_to_device(current_batch, self.device) + next_batch = next(dataloader_iter) + if batch_index >= self.skip_batches: + yield current_batch + batch_index += 1 + current_batch = next_batch + except StopIteration: + self.end_of_dataloader = True + if batch_index >= self.skip_batches: + yield current_batch + break + + self.iteration += 1 + self.end() + + def set_epoch(self, epoch: int): + # In case it is manually passed in, the user can set it to what they like + if self.iteration != epoch: + self.iteration = epoch + if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(epoch) + # We support if a custom `Dataset` implementation has `set_epoch` + # or in general HF datasets `Datasets` + elif hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + @property + def total_batch_size(self): + batch_sampler = self.sampler if isinstance(self.sampler, BatchSampler) else self.batch_sampler + return ( + batch_sampler.batch_size + if getattr(batch_sampler, "split_batches", False) + else (batch_sampler.batch_size * getattr(batch_sampler, "num_processes", 1)) + ) + + @property + def total_dataset_length(self): + if hasattr(self.dataset, "total_length"): + return self.dataset.total_length + else: + return len(self.dataset) + + +if is_torch_xla_available(): + import torch_xla.distributed.parallel_loader as xpl + + class MpDeviceLoaderWrapper(xpl.MpDeviceLoader): + """ + Wrapper for the xpl.MpDeviceLoader class that knows the total batch size. + + XLA preloading threads will all call DataLoaderShard's __iter__(). Remove rng_types from DataLoaderShard to + prevent it from using the XLA device in the preloading threads, and synchronize the RNG once from the main + thread only. + + **Available attributes:** + + - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes. + Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total + number of processes + + - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes. 
+ """ + + def __init__(self, dataloader: DataLoaderShard, device: torch.device): + super().__init__(dataloader, device) + self._rng_types = self._loader.rng_types + self._loader.rng_types = None + + def __iter__(self): + if self._rng_types is not None: + synchronize_rng_states(self._rng_types, self._loader.synchronized_generator) + + return super().__iter__() + + @property + def total_batch_size(self): + return self._loader.total_batch_size + + @property + def total_dataset_length(self): + return self._loader.total_dataset_length + + @property + def batch_sampler(self): + return self._loader.batch_sampler + + +class DataLoaderDispatcher(DataLoader, DataLoaderStateMixin): + """ + Subclass of a PyTorch `DataLoader` that will iterate and preprocess on process 0 only, then dispatch on each + process their part of the batch. + + Args: + split_batches (`bool`, *optional*, defaults to `False`): + Whether the resulting `DataLoader` should split the batches of the original data loader across devices or + yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of + `num_processes` batches at each iteration). Another way to see this is that the observed batch size will be + the same as the initial `dataloader` if this option is set to `True`, the batch size of the initial + `dataloader` multiplied by `num_processes` otherwise. Setting this option to `True` requires that the batch + size of the `dataloader` is a round multiple of `batch_size`. + skip_batches (`int`, *optional*, defaults to 0): + The number of batches to skip at the beginning of an iteration. + + **Available attributes:** + + - **total_batch_size** (`int`) -- Total batch size of the dataloader across all processes. + Equal to the original batch size when `split_batches=True`; otherwise the original batch size * the total + number of processes + + - **total_dataset_length** (`int`) -- Total length of the inner dataset across all processes. + """ + + def __init__( + self, dataset, split_batches: bool = False, skip_batches=0, _drop_last: bool = False, slice_fn=None, **kwargs + ): + shuffle = False + if is_torch_version(">=", "1.11.0"): + from torch.utils.data.datapipes.iter.combinatorics import ShufflerIterDataPipe + + # We need to save the shuffling state of the DataPipe + if isinstance(dataset, ShufflerIterDataPipe): + shuffle = dataset._shuffle_enabled + super().__init__(dataset, **kwargs) + self.split_batches = split_batches + if shuffle: + torch.utils.data.graph_settings.apply_shuffle_settings(dataset, shuffle=shuffle) + + self.gradient_state = GradientState() + self.state = AcceleratorState() + self._drop_last = _drop_last + self.skip_batches = skip_batches + + self.slice_fn = slice_tensors if slice_fn is None else slice_fn + self.iteration = 0 + + def _fetch_batches(self, iterator): + batches, batch = None, None + # On process 0, we gather the batch to dispatch. + if self.state.process_index == 0: + try: + if self.split_batches: + # One batch of the main iterator is dispatched and split. + batch = next(iterator) + else: + # num_processes batches of the main iterator are concatenated then dispatched and split. + # We add the batches one by one so we have the remainder available when drop_last=False. 
+ batches = [] + for _ in range(self.state.num_processes): + batches.append(next(iterator)) + try: + batch = concatenate(batches, dim=0) + except RuntimeError as e: + raise RuntimeError( + "You can't use batches of different size with `dispatch_batches=True` or when using an `IterableDataset`." + "either pass `dispatch_batches=False` and have each process fetch its own batch " + " or pass `split_batches=True`. By doing so, the main process will fetch a full batch and " + "slice it into `num_processes` batches for each process." + ) from e + # In both cases, we need to get the structure of the batch that we will broadcast on other + # processes to initialize the tensors with the right shape. + # data_structure, stop_iteration + batch_info = [get_data_structure(batch), False] + except StopIteration: + batch_info = [None, True] + else: + batch_info = [None, self._stop_iteration] + # This is inplace, so after this instruction, every process has the same `batch_info` as process 0. + broadcast_object_list(batch_info) + self._stop_iteration = batch_info[1] + if self._stop_iteration: + # If drop_last is False and split_batches is False, we may have a remainder to take care of. + if not self.split_batches and not self._drop_last: + if self.state.process_index == 0 and len(batches) > 0: + batch = concatenate(batches, dim=0) + batch_info = [get_data_structure(batch), False] + else: + batch_info = [None, True] + broadcast_object_list(batch_info) + return batch, batch_info + + def __iter__(self): + self.begin() + self.set_epoch(self.iteration) + main_iterator = None + if is_torch_version(">=", "2.0.1"): + # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts + # shared seed to all dist processes. Thus, we need to create iterator for all dist processes. + # But, we only iterate through the DataLoader on process 0. + main_iterator = super().__iter__() + elif self.state.process_index == 0: + main_iterator = super().__iter__() + stop_iteration = False + self._stop_iteration = False + first_batch = None + next_batch, next_batch_info = self._fetch_batches(main_iterator) + batch_index = 0 + while not stop_iteration: + batch, batch_info = next_batch, next_batch_info + + if self.state.process_index != 0: + # Initialize tensors on other processes than process 0. + batch = initialize_tensors(batch_info[0]) + batch = send_to_device(batch, self.state.device) + # Broadcast the batch before splitting it. + batch = broadcast(batch, from_process=0) + + if not self._drop_last and first_batch is None: + # We keep at least num processes elements of the first batch to be able to complete the last batch + first_batch = self.slice_fn( + batch, + slice(0, self.state.num_processes), + process_index=self.state.process_index, + num_processes=self.state.num_processes, + ) + + if batch is None: + raise ValueError( + f"Batch does not contain any data (`{batch}`). At the end of all iterable data available before expected stop iteration." + ) + + observed_batch_size = find_batch_size(batch) + batch_size = observed_batch_size // self.state.num_processes + + stop_iteration = self._stop_iteration + if not stop_iteration: + # We may still be at the end of the dataloader without knowing it yet: if there is nothing left in + # the dataloader since the number of batches is a round multiple of the number of processes. + next_batch, next_batch_info = self._fetch_batches(main_iterator) + # next_batch_info[0] is None when there are no more batches, otherwise we still need to process them. 
+ if self._stop_iteration and next_batch_info[0] is None: + stop_iteration = True + + if not self._drop_last and stop_iteration and observed_batch_size % self.state.num_processes != 0: + # If the last batch is not complete, let's add the first batch to it. + batch = concatenate([batch, first_batch], dim=0) + # Batch size computation above is wrong, it's off by 1 so we fix it. + batch_size += 1 + + data_slice = slice(self.state.process_index * batch_size, (self.state.process_index + 1) * batch_size) + batch = self.slice_fn( + batch, + data_slice, + process_index=self.state.process_index, + num_processes=self.state.num_processes, + ) + + if stop_iteration: + self.end_of_dataloader = True + self.remainder = observed_batch_size + if batch_index >= self.skip_batches: + yield batch + batch_index += 1 + self.iteration += 1 + self.end() + + def set_epoch(self, epoch: int): + # In case it is manually passed in, the user can set it to what they like + if self.iteration != epoch: + self.iteration = epoch + if hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(epoch) + elif hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) + + def __len__(self): + whole_length = super().__len__() + if self.split_batches: + return whole_length + elif self._drop_last: + return whole_length // self.state.num_processes + else: + return math.ceil(whole_length / self.state.num_processes) + + @property + def total_batch_size(self): + return ( + self.dataset.batch_size if self.split_batches else (self.dataset.batch_size * self.dataset.num_processes) + ) + + @property + def total_dataset_length(self): + return len(self.dataset) + + +def prepare_data_loader( + dataloader: DataLoader, + device: Optional[torch.device] = None, + num_processes: Optional[int] = None, + process_index: Optional[int] = None, + split_batches: bool = False, + put_on_device: bool = False, + rng_types: Optional[List[Union[str, RNGType]]] = None, + dispatch_batches: Optional[bool] = None, + even_batches: bool = True, + slice_fn_for_dispatch: Optional[Callable] = None, + use_seedable_sampler: bool = False, +) -> DataLoader: + """ + Wraps a PyTorch `DataLoader` to generate batches for one of the processes only. + + Depending on the value of the `drop_last` attribute of the `dataloader` passed, it will either stop the iteration + at the first batch that would be too small / not present on all processes or loop with indices from the beginning. + + Args: + dataloader (`torch.utils.data.dataloader.DataLoader`): + The data loader to split across several devices. + device (`torch.device`): + The target device for the returned `DataLoader`. + num_processes (`int`, *optional*): + The number of processes running concurrently. Will default to the value given by + [`~state.AcceleratorState`]. + process_index (`int`, *optional*): + The index of the current process. Will default to the value given by [`~state.AcceleratorState`]. + split_batches (`bool`, *optional*, defaults to `False`): + Whether the resulting `DataLoader` should split the batches of the original data loader across devices or + yield full batches (in which case it will yield batches starting at the `process_index`-th and advancing of + `num_processes` batches at each iteration). + + Another way to see this is that the observed batch size will be the same as the initial `dataloader` if + this option is set to `True`, the batch size of the initial `dataloader` multiplied by `num_processes` + otherwise. 
+ + Setting this option to `True` requires that the batch size of the `dataloader` is a round multiple of + `batch_size`. + put_on_device (`bool`, *optional*, defaults to `False`): + Whether or not to put the batches on `device` (only works if the batches are nested list, tuples or + dictionaries of tensors). + rng_types (list of `str` or [`~utils.RNGType`]): + The list of random number generators to synchronize at the beginning of each iteration. Should be one or + several of: + + - `"torch"`: the base torch random number generator + - `"cuda"`: the CUDA random number generator (GPU only) + - `"xla"`: the XLA random number generator (TPU only) + - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your + dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type. + + dispatch_batches (`bool`, *optional*): + If set to `True`, the datalaoder prepared is only iterated through on the main process and then the batches + are split and broadcast to each process. Will default to `True` when the underlying dataset is an + `IterableDataset`, `False` otherwise. + even_batches (`bool`, *optional*, defaults to `True`): + If set to `True`, in cases where the total batch size across all processes does not exactly divide the + dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among + all workers. + slice_fn_for_dispatch (`Callable`, *optional*`): + If passed, this function will be used to slice tensors across `num_processes`. Will default to + [`~utils.slice_tensors`]. This argument is used only when `dispatch_batches` is set to `True` and will be + ignored otherwise. + use_seedable_sampler (`bool`, *optional*, defaults to `False`): + Whether to use the [`~data_loader.SeedableRandomSampler`] instead of a `RandomSampler` for better + reproducability. Comes at a cost of potentially different performances due to different shuffling + algorithms but ensures results will be the *exact* same. Should be paired with `set_seed()` at every + `self.set_epoch` + + Returns: + `torch.utils.data.dataloader.DataLoader`: A new data loader that will yield the portion of the batches + + + + `BatchSampler`s with varying batch sizes are not enabled by default. To enable this behaviour, set `even_batches` + equal to `False` + + + """ + if dispatch_batches is None: + if not put_on_device: + dispatch_batches = False + else: + dispatch_batches = isinstance(dataloader.dataset, IterableDataset) + + if dispatch_batches and not put_on_device: + raise ValueError("Using `dispatch_batches=True` requires `put_on_device=True`.") + # Grab defaults from AcceleratorState + state = AcceleratorState() + if num_processes is None: + num_processes = state.num_processes + if process_index is None: + process_index = state.process_index + + # Sanity check + if split_batches: + if dataloader.batch_size is not None: + batch_size_for_check = dataloader.batch_size + else: + # For custom batch_sampler + if hasattr(dataloader.batch_sampler, "batch_size"): + batch_size_for_check = dataloader.batch_sampler.batch_size + else: + raise ValueError( + "In order to use `split_batches==True` you must have a `batch_size` attribute either in the passed " + "`dataloader` or `dataloader.batch_sampler` objects, and it has to return a natural number. " + "Your `dataloader.batch_size` is None and `dataloader.batch_sampler` " + f"(`{type(dataloader.batch_sampler)}`) does not have the `batch_size` attribute set." 
+ ) + + if batch_size_for_check > 1 and batch_size_for_check % num_processes != 0: + raise ValueError( + f"To use a `DataLoader` in `split_batches` mode, the batch size ({dataloader.batch_size}) " + f"needs to be a round multiple of the number of processes ({num_processes})." + ) + + new_dataset = dataloader.dataset + # Iterable dataset doesn't like batch_sampler, but data_loader creates a default one for it + new_batch_sampler = dataloader.batch_sampler if not isinstance(new_dataset, IterableDataset) else None + sampler_is_batch_sampler = False + synchronized_generator = None + sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) + if sampler_is_batch_sampler: + sampler = getattr(dataloader.sampler, "sampler", None) + else: + sampler = getattr(dataloader.batch_sampler, "sampler", None) + if isinstance(sampler, RandomSampler) and use_seedable_sampler: + # When iterating through the dataloader during distributed processes + # we want to ensure that on each process we are iterating through the same + # samples in the same order if a seed is set. This requires a tweak + # to the `torch.utils.data.RandomSampler` class (if used). + sampler = SeedableRandomSampler( + data_source=sampler.data_source, + replacement=sampler.replacement, + num_samples=sampler._num_samples, + generator=getattr(sampler, "generator", torch.Generator()), + ) + + if isinstance(dataloader.sampler, RandomSampler) and state.distributed_type == DistributedType.XLA: + # isinstance(dataloader.sampler, RandomSampler) indicates the original dataloader has `shuffle` enabled. + generator = torch.Generator().manual_seed(42) + dataloader.generator = generator + dataloader.sampler.generator = generator + # No change if no multiprocess + if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches: + if isinstance(new_dataset, IterableDataset): + if getattr(dataloader.dataset, "generator", None) is not None: + synchronized_generator = dataloader.dataset.generator + new_dataset = IterableDatasetShard( + new_dataset, + batch_size=dataloader.batch_size, + drop_last=dataloader.drop_last, + num_processes=num_processes, + process_index=process_index, + split_batches=split_batches, + ) + else: + batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler + new_batch_sampler = BatchSamplerShard( + batch_sampler, + num_processes=num_processes, + process_index=process_index, + split_batches=split_batches, + even_batches=even_batches, + ) + + # We ignore all of those since they are all dealt with by our new_batch_sampler + ignore_kwargs = [ + "batch_size", + "shuffle", + "sampler", + "batch_sampler", + "drop_last", + ] + + if rng_types is not None and synchronized_generator is None and "generator" in rng_types: + rng_types.remove("generator") + + kwargs = { + k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) + for k in _PYTORCH_DATALOADER_KWARGS + if k not in ignore_kwargs + } + + # Need to provide batch_size as batch_sampler is None for Iterable dataset + if new_batch_sampler is None: + kwargs["drop_last"] = dataloader.drop_last + kwargs["batch_size"] = ( + dataloader.batch_size // num_processes if split_batches and not dispatch_batches else dataloader.batch_size + ) + if dispatch_batches: + kwargs.pop("generator") + dataloader = DataLoaderDispatcher( + new_dataset, + split_batches=split_batches, + batch_sampler=new_batch_sampler, + _drop_last=dataloader.drop_last, + slice_fn=slice_fn_for_dispatch, + **kwargs, + ) + elif 
sampler_is_batch_sampler: + dataloader = DataLoaderShard( + new_dataset, + device=device if put_on_device and state.distributed_type != DistributedType.XLA else None, + sampler=new_batch_sampler, + batch_size=dataloader.batch_size, + rng_types=rng_types, + _drop_last=dataloader.drop_last, + synchronized_generator=synchronized_generator, + **kwargs, + ) + else: + dataloader = DataLoaderShard( + new_dataset, + device=device if put_on_device and state.distributed_type != DistributedType.XLA else None, + batch_sampler=new_batch_sampler, + rng_types=rng_types, + synchronized_generator=synchronized_generator, + _drop_last=dataloader.drop_last, + **kwargs, + ) + + if isinstance(sampler, SeedableRandomSampler) and use_seedable_sampler: + if sampler_is_batch_sampler: + dataloader.sampler.sampler = sampler + else: + dataloader.batch_sampler.sampler = sampler + if hasattr(dataloader.batch_sampler, "batch_sampler"): + dataloader.batch_sampler.batch_sampler.sampler = sampler + if state.distributed_type == DistributedType.XLA: + return MpDeviceLoaderWrapper(dataloader, device) + return dataloader + + +class SkipBatchSampler(BatchSampler): + """ + A `torch.utils.data.BatchSampler` that skips the first `n` batches of another `torch.utils.data.BatchSampler`. + """ + + def __init__(self, batch_sampler, skip_batches=0): + self.batch_sampler = batch_sampler + self.skip_batches = skip_batches + + def __iter__(self): + for index, samples in enumerate(self.batch_sampler): + if index >= self.skip_batches: + yield samples + + @property + def total_length(self): + return len(self.batch_sampler) + + def __len__(self): + return len(self.batch_sampler) - self.skip_batches + + +class SkipDataLoader(DataLoader): + """ + Subclass of a PyTorch `DataLoader` that will skip the first batches. + + Args: + dataset (`torch.utils.data.dataset.Dataset`): + The dataset to use to build this datalaoder. + skip_batches (`int`, *optional*, defaults to 0): + The number of batches to skip at the beginning. + kwargs: + All other keyword arguments to pass to the regular `DataLoader` initialization. + """ + + def __init__(self, dataset, skip_batches=0, **kwargs): + super().__init__(dataset, **kwargs) + self.skip_batches = skip_batches + + def __iter__(self): + for index, batch in enumerate(super().__iter__()): + if index >= self.skip_batches: + yield batch + + +def skip_first_batches(dataloader, num_batches=0): + """ + Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`. 
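+
+    Example (a minimal sketch with a plain, unprepared `DataLoader`; when given a prepared dataloader the function
+    instead rebuilds the matching `DataLoaderShard` / `DataLoaderDispatcher`):
+
+        import torch
+        from torch.utils.data import DataLoader, TensorDataset
+
+        from accelerate.data_loader import skip_first_batches
+
+        dataloader = DataLoader(TensorDataset(torch.arange(10).float()), batch_size=2)
+        resumed = skip_first_batches(dataloader, num_batches=2)
+        # `resumed` yields only the 3rd, 4th and 5th batches of `dataloader`.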
+ """ + dataset = dataloader.dataset + sampler_is_batch_sampler = False + if isinstance(dataset, IterableDataset): + new_batch_sampler = None + else: + sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler) + batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler + new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches) + + # We ignore all of those since they are all dealt with by our new_batch_sampler + ignore_kwargs = [ + "batch_size", + "shuffle", + "sampler", + "batch_sampler", + "drop_last", + ] + + kwargs = { + k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) + for k in _PYTORCH_DATALOADER_KWARGS + if k not in ignore_kwargs + } + + # Need to provide batch_size as batch_sampler is None for Iterable dataset + if new_batch_sampler is None: + kwargs["drop_last"] = dataloader.drop_last + kwargs["batch_size"] = dataloader.batch_size + + if isinstance(dataloader, DataLoaderDispatcher): + if new_batch_sampler is None: + # Need to manually skip batches in the dataloader + kwargs["skip_batches"] = num_batches + dataloader = DataLoaderDispatcher( + dataset, + split_batches=dataloader.split_batches, + batch_sampler=new_batch_sampler, + _drop_last=dataloader._drop_last, + **kwargs, + ) + elif isinstance(dataloader, DataLoaderShard): + if new_batch_sampler is None: + # Need to manually skip batches in the dataloader + kwargs["skip_batches"] = num_batches + elif sampler_is_batch_sampler: + kwargs["sampler"] = new_batch_sampler + kwargs["batch_size"] = dataloader.batch_size + else: + kwargs["batch_sampler"] = new_batch_sampler + dataloader = DataLoaderShard( + dataset, + device=dataloader.device, + rng_types=dataloader.rng_types, + synchronized_generator=dataloader.synchronized_generator, + **kwargs, + ) + else: + if new_batch_sampler is None: + # Need to manually skip batches in the dataloader + dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs) + else: + dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs) + + return dataloader diff --git a/llm/Lib/site-packages/accelerate/hooks.py b/llm/Lib/site-packages/accelerate/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..e9a4b384f3cac39e7bedabb1f5e7c0320aae6a7f --- /dev/null +++ b/llm/Lib/site-packages/accelerate/hooks.py @@ -0,0 +1,709 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import Dict, List, Mapping, Optional, Union + +import torch +import torch.nn as nn + +from .state import PartialState +from .utils import ( + PrefixedDataset, + find_device, + named_module_tensors, + send_to_device, + set_module_tensor_to_device, +) +from .utils.modeling import get_non_persistent_buffers +from .utils.other import recursive_getattr + + +class ModelHook: + """ + A hook that contains callbacks to be executed just before and after the forward method of a model. 
The difference + with PyTorch existing hooks is that they get passed along the kwargs. + + Class attribute: + - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under + the `torch.no_grad()` context manager. + """ + + no_grad = False + + def init_hook(self, module): + """ + To be executed when the hook is attached to the module. + + Args: + module (`torch.nn.Module`): The module attached to this hook. + """ + return module + + def pre_forward(self, module, *args, **kwargs): + """ + To be executed just before the forward method of the model. + + Args: + module (`torch.nn.Module`): The module whose forward pass will be executed just after this event. + args (`Tuple[Any]`): The positional arguments passed to the module. + kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module. + + Returns: + `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`. + """ + return args, kwargs + + def post_forward(self, module, output): + """ + To be executed just after the forward method of the model. + + Args: + module (`torch.nn.Module`): The module whose forward pass been executed just before this event. + output (`Any`): The output of the module. + + Returns: + `Any`: The processed `output`. + """ + return output + + def detach_hook(self, module): + """ + To be executed when the hook is detached from a module. + + Args: + module (`torch.nn.Module`): The module detached from this hook. + """ + return module + + +class SequentialHook(ModelHook): + """ + A hook that can contain several hooks and iterates through them at each event. + """ + + def __init__(self, *hooks): + self.hooks = hooks + + def init_hook(self, module): + for hook in self.hooks: + module = hook.init_hook(module) + return module + + def pre_forward(self, module, *args, **kwargs): + for hook in self.hooks: + args, kwargs = hook.pre_forward(module, *args, **kwargs) + return args, kwargs + + def post_forward(self, module, output): + for hook in self.hooks: + output = hook.post_forward(module, output) + return output + + def detach_hook(self, module): + for hook in self.hooks: + module = hook.detach_hook(module) + return module + + +def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False): + """ + Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove + this behavior and restore the original `forward` method, use `remove_hook_from_module`. + + + + If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks + together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class. + + + + Args: + module (`torch.nn.Module`): + The module to attach a hook to. + hook (`ModelHook`): + The hook to attach. + append (`bool`, *optional*, defaults to `False`): + Whether the hook should be chained with an existing one (if module already contains a hook) or not. + + Returns: + `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can + be discarded). + """ + + if append and (getattr(module, "_hf_hook", None) is not None): + old_hook = module._hf_hook + remove_hook_from_module(module) + hook = SequentialHook(old_hook, hook) + + if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"): + # If we already put some hook on this module, we replace it with the new one. 
+ old_forward = module._old_forward + else: + old_forward = module.forward + module._old_forward = old_forward + + module = hook.init_hook(module) + module._hf_hook = hook + + def new_forward(module, *args, **kwargs): + args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs) + if module._hf_hook.no_grad: + with torch.no_grad(): + output = module._old_forward(*args, **kwargs) + else: + output = module._old_forward(*args, **kwargs) + return module._hf_hook.post_forward(module, output) + + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + else: + module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward) + + return module + + +def remove_hook_from_module(module: nn.Module, recurse=False): + """ + Removes any hook attached to a module via `add_hook_to_module`. + + Args: + module (`torch.nn.Module`): The module to attach a hook to. + recurse (`bool`, **optional**): Whether to remove the hooks recursively + + Returns: + `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can + be discarded). + """ + + if hasattr(module, "_hf_hook"): + module._hf_hook.detach_hook(module) + delattr(module, "_hf_hook") + + if hasattr(module, "_old_forward"): + # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail. + # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409 + if "GraphModuleImpl" in str(type(module)): + module.__class__.forward = module._old_forward + else: + module.forward = module._old_forward + delattr(module, "_old_forward") + + if recurse: + for child in module.children(): + remove_hook_from_module(child, recurse) + + return module + + +class AlignDevicesHook(ModelHook): + """ + A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the + associated module, potentially offloading the weights after the forward pass. + + Args: + execution_device (`torch.device`, *optional*): + The device on which inputs and model weights should be placed before the forward pass. + offload (`bool`, *optional*, defaults to `False`): + Whether or not the weights should be offloaded after the forward pass. + io_same_device (`bool`, *optional*, defaults to `False`): + Whether or not the output should be placed on the same device as the input was. + weights_map (`Mapping[str, torch.Tensor]`, *optional*): + When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values. + offload_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to include the associated module's buffers when offloading. + place_submodules (`bool`, *optional*, defaults to `False`): + Whether to place the submodules on `execution_device` during the `init_hook` event. 
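+
+    Example (a minimal CPU-only sketch; in practice these hooks are attached for you, e.g. by `dispatch_model` or
+    `attach_align_device_hook`, and `execution_device` would normally be a GPU):
+
+        import torch
+        import torch.nn as nn
+
+        from accelerate.hooks import AlignDevicesHook, add_hook_to_module
+
+        model = nn.Linear(4, 4)
+        add_hook_to_module(model, AlignDevicesHook(execution_device="cpu", offload=True))
+        # The parameters now live on the `meta` device and are materialized from `weights_map`
+        # only for the duration of each forward pass.
+        output = model(torch.randn(2, 4))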
+ """ + + def __init__( + self, + execution_device: Optional[Union[int, str, torch.device]] = None, + offload: bool = False, + io_same_device: bool = False, + weights_map: Optional[Mapping] = None, + offload_buffers: bool = False, + place_submodules: bool = False, + skip_keys: Optional[Union[str, List[str]]] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, + ): + self.execution_device = execution_device + self.offload = offload + self.io_same_device = io_same_device + self.weights_map = weights_map + self.offload_buffers = offload_buffers + self.place_submodules = place_submodules + self.skip_keys = skip_keys + + # Will contain the input device when `io_same_device=True`. + self.input_device = None + self.param_original_devices = {} + self.buffer_original_devices = {} + self.tied_params_names = set() + + # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory + # for tied weights already loaded on the target execution device. + self.tied_params_map = tied_params_map + + def __repr__(self): + return ( + f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, " + f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, " + f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})" + ) + + def init_hook(self, module): + # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero. + if self.execution_device == "meta" or self.execution_device == torch.device("meta"): + self.tied_params_map = None + + if not self.offload and self.execution_device is not None: + for name, _ in named_module_tensors(module, recurse=self.place_submodules): + set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map) + elif self.offload: + self.original_devices = { + name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules) + } + if self.weights_map is None: + self.weights_map = { + name: param.to("cpu") + for name, param in named_module_tensors( + module, include_buffers=self.offload_buffers, recurse=self.place_submodules + ) + } + for name, _ in named_module_tensors( + module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True + ): + # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer, + # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer. + # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str] + # to add on the fly pointers to `tied_params_map` in the pre_forward call. 
+ if ( + self.tied_params_map is not None + and recursive_getattr(module, name).data_ptr() in self.tied_params_map + ): + self.tied_params_names.add(name) + + set_module_tensor_to_device(module, name, "meta") + + if not self.offload_buffers and self.execution_device is not None: + for name, _ in module.named_buffers(recurse=self.place_submodules): + set_module_tensor_to_device( + module, name, self.execution_device, tied_params_map=self.tied_params_map + ) + elif self.offload_buffers and self.execution_device is not None: + for name in get_non_persistent_buffers(module, recurse=self.place_submodules): + set_module_tensor_to_device( + module, name, self.execution_device, tied_params_map=self.tied_params_map + ) + + return module + + def pre_forward(self, module, *args, **kwargs): + if self.io_same_device: + self.input_device = find_device([args, kwargs]) + if self.offload: + self.tied_pointers_to_remove = set() + + for name, _ in named_module_tensors( + module, + include_buffers=self.offload_buffers, + recurse=self.place_submodules, + remove_non_persistent=True, + ): + fp16_statistics = None + value = self.weights_map[name] + if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys(): + if value.dtype == torch.int8: + fp16_statistics = self.weights_map[name.replace("weight", "SCB")] + + # In case we are using offloading with tied weights, we need to keep track of the offloaded weights + # that are loaded on device at this point, as we will need to remove them as well from the dictionary + # self.tied_params_map in order to allow to free memory. + if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map: + self.tied_params_map[value.data_ptr()] = {} + + if ( + value is not None + and self.tied_params_map is not None + and value.data_ptr() in self.tied_params_map + and self.execution_device not in self.tied_params_map[value.data_ptr()] + ): + self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device)) + + set_module_tensor_to_device( + module, + name, + self.execution_device, + value=value, + fp16_statistics=fp16_statistics, + tied_params_map=self.tied_params_map, + ) + + return send_to_device(args, self.execution_device), send_to_device( + kwargs, self.execution_device, skip_keys=self.skip_keys + ) + + def post_forward(self, module, output): + if self.offload: + for name, _ in named_module_tensors( + module, + include_buffers=self.offload_buffers, + recurse=self.place_submodules, + remove_non_persistent=True, + ): + set_module_tensor_to_device(module, name, "meta") + if type(module).__name__ == "Linear8bitLt": + module.state.SCB = None + module.state.CxB = None + + # We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from + # this dictionary to allow the garbage collector to do its job. 
+ for value_pointer, device in self.tied_pointers_to_remove: + del self.tied_params_map[value_pointer][device] + self.tied_pointers_to_remove = set() + + if self.io_same_device and self.input_device is not None: + output = send_to_device(output, self.input_device, skip_keys=self.skip_keys) + + return output + + def detach_hook(self, module): + if self.offload: + for name, device in self.original_devices.items(): + if device != torch.device("meta"): + set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None)) + return module + + +def attach_execution_device_hook( + module: torch.nn.Module, + execution_device: Union[int, str, torch.device], + skip_keys: Optional[Union[str, List[str]]] = None, + preload_module_classes: Optional[List[str]] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, +): + """ + Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right + execution device + + Args: + module (`torch.nn.Module`): + The module where we want to attach the hooks. + execution_device (`int`, `str` or `torch.device`): + The device on which inputs and model weights should be placed before the forward pass. + skip_keys (`str` or `List[str]`, *optional*): + A list of keys to ignore when moving inputs or outputs between devices. + preload_module_classes (`List[str]`, *optional*): + A list of classes whose instances should load all their weights (even in the submodules) at the beginning + of the forward. This should only be used for classes that have submodules which are registered but not + called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, + `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`): + A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution + device, this parameter is useful to reuse the first available pointer of a shared weight for all others, + instead of duplicating memory. + """ + if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0: + add_hook_to_module( + module, + AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map), + ) + + # Break the recursion if we get to a preload module. + if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes: + return + + for child in module.children(): + attach_execution_device_hook(child, execution_device, tied_params_map=tied_params_map) + + +def attach_align_device_hook( + module: torch.nn.Module, + execution_device: Optional[torch.device] = None, + offload: bool = False, + weights_map: Optional[Mapping] = None, + offload_buffers: bool = False, + module_name: str = "", + skip_keys: Optional[Union[str, List[str]]] = None, + preload_module_classes: Optional[List[str]] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, +): + """ + Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or + buffers. + + Args: + module (`torch.nn.Module`): + The module where we want to attach the hooks. + execution_device (`torch.device`, *optional*): + The device on which inputs and model weights should be placed before the forward pass. 
+ offload (`bool`, *optional*, defaults to `False`): + Whether or not the weights should be offloaded after the forward pass. + weights_map (`Mapping[str, torch.Tensor]`, *optional*): + When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values. + offload_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to include the associated module's buffers when offloading. + module_name (`str`, *optional*, defaults to `""`): + The name of the module. + skip_keys (`str` or `List[str]`, *optional*): + A list of keys to ignore when moving inputs or outputs between devices. + preload_module_classes (`List[str]`, *optional*): + A list of classes whose instances should load all their weights (even in the submodules) at the beginning + of the forward. This should only be used for classes that have submodules which are registered but not + called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, + `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`): + A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution + device, this parameter is useful to reuse the first available pointer of a shared weight for all others, + instead of duplicating memory. + """ + # Attach the hook on this module if it has any direct tensor. + directs = named_module_tensors(module) + full_offload = ( + offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes + ) + + if len(list(directs)) > 0 or full_offload: + if weights_map is not None: + prefix = f"{module_name}." if len(module_name) > 0 else "" + prefixed_weights_map = PrefixedDataset(weights_map, prefix) + else: + prefixed_weights_map = None + hook = AlignDevicesHook( + execution_device=execution_device, + offload=offload, + weights_map=prefixed_weights_map, + offload_buffers=offload_buffers, + place_submodules=full_offload, + skip_keys=skip_keys, + tied_params_map=tied_params_map, + ) + add_hook_to_module(module, hook, append=True) + + # We stop the recursion in case we hit the full offload. + if full_offload: + return + + # Recurse on all children of the module. + for child_name, child in module.named_children(): + child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name + attach_align_device_hook( + child, + execution_device=execution_device, + offload=offload, + weights_map=weights_map, + offload_buffers=offload_buffers, + module_name=child_name, + preload_module_classes=preload_module_classes, + skip_keys=skip_keys, + tied_params_map=tied_params_map, + ) + + +def remove_hook_from_submodules(module: nn.Module): + """ + Recursively removes all hooks attached on the submodules of a given model. + + Args: + module (`torch.nn.Module`): The module on which to remove all hooks. 
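+
+    Example (a minimal sketch; `attach_align_device_hook` with CPU offload is only one of the ways such hooks end up
+    on the submodules):
+
+        import torch
+        import torch.nn as nn
+
+        from accelerate.hooks import attach_align_device_hook, remove_hook_from_submodules
+
+        model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 2))
+        attach_align_device_hook(model, execution_device="cpu", offload=True)
+        model(torch.randn(1, 8))  # weights are brought back from the offload map for each forward pass
+        remove_hook_from_submodules(model)  # restores the original `forward` methods and parameter devices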
+ """ + remove_hook_from_module(module) + for child in module.children(): + remove_hook_from_submodules(child) + + +def attach_align_device_hook_on_blocks( + module: nn.Module, + execution_device: Optional[Union[torch.device, Dict[str, torch.device]]] = None, + offload: Union[bool, Dict[str, bool]] = False, + weights_map: Mapping = None, + offload_buffers: bool = False, + module_name: str = "", + skip_keys: Optional[Union[str, List[str]]] = None, + preload_module_classes: Optional[List[str]] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, +): + """ + Attaches `AlignDevicesHook` to all blocks of a given model as needed. + + Args: + module (`torch.nn.Module`): + The module where we want to attach the hooks. + execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*): + The device on which inputs and model weights should be placed before the forward pass. It can be one device + for the whole module, or a dictionary mapping module name to device. + offload (`bool`, *optional*, defaults to `False`): + Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole + module, or a dictionary mapping module name to boolean. + weights_map (`Mapping[str, torch.Tensor]`, *optional*): + When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values. + offload_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to include the associated module's buffers when offloading. + module_name (`str`, *optional*, defaults to `""`): + The name of the module. + skip_keys (`str` or `List[str]`, *optional*): + A list of keys to ignore when moving inputs or outputs between devices. + preload_module_classes (`List[str]`, *optional*): + A list of classes whose instances should load all their weights (even in the submodules) at the beginning + of the forward. This should only be used for classes that have submodules which are registered but not + called directly during the forward, for instance if a `dense` linear layer is registered, but at forward, + `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly. + tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`): + A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution + device, this parameter is useful to reuse the first available pointer of a shared weight for all others, + instead of duplicating memory. + """ + # If one device and one offload, we've got one hook. 
+ if not isinstance(execution_device, Mapping) and not isinstance(offload, dict): + if not offload: + hook = AlignDevicesHook( + execution_device=execution_device, + io_same_device=True, + skip_keys=skip_keys, + place_submodules=True, + tied_params_map=tied_params_map, + ) + add_hook_to_module(module, hook) + else: + attach_align_device_hook( + module, + execution_device=execution_device, + offload=True, + weights_map=weights_map, + offload_buffers=offload_buffers, + module_name=module_name, + skip_keys=skip_keys, + tied_params_map=tied_params_map, + ) + return + + if not isinstance(execution_device, Mapping): + execution_device = {key: execution_device for key in offload.keys()} + if not isinstance(offload, Mapping): + offload = {key: offload for key in execution_device.keys()} + + if module_name in execution_device and module_name in offload and not offload[module_name]: + hook = AlignDevicesHook( + execution_device=execution_device[module_name], + offload_buffers=offload_buffers, + io_same_device=(module_name == ""), + place_submodules=True, + skip_keys=skip_keys, + tied_params_map=tied_params_map, + ) + add_hook_to_module(module, hook) + attach_execution_device_hook(module, execution_device[module_name], tied_params_map=tied_params_map) + elif module_name in execution_device and module_name in offload: + attach_align_device_hook( + module, + execution_device=execution_device[module_name], + offload=True, + weights_map=weights_map, + offload_buffers=offload_buffers, + module_name=module_name, + skip_keys=skip_keys, + preload_module_classes=preload_module_classes, + tied_params_map=tied_params_map, + ) + if not hasattr(module, "_hf_hook"): + hook = AlignDevicesHook( + execution_device=execution_device[module_name], + io_same_device=(module_name == ""), + skip_keys=skip_keys, + tied_params_map=tied_params_map, + ) + add_hook_to_module(module, hook) + attach_execution_device_hook( + module, + execution_device[module_name], + preload_module_classes=preload_module_classes, + skip_keys=skip_keys, + tied_params_map=tied_params_map, + ) + elif module_name == "": + hook = AlignDevicesHook( + execution_device=execution_device.get(""), + io_same_device=True, + skip_keys=skip_keys, + tied_params_map=tied_params_map, + ) + add_hook_to_module(module, hook) + + for child_name, child in module.named_children(): + child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name + attach_align_device_hook_on_blocks( + child, + execution_device=execution_device, + offload=offload, + weights_map=weights_map, + offload_buffers=offload_buffers, + module_name=child_name, + preload_module_classes=preload_module_classes, + skip_keys=skip_keys, + tied_params_map=tied_params_map, + ) + + +class CpuOffload(ModelHook): + """ + Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU after + the forward, the user needs to call the `init_hook` method again for this. + + Args: + execution_device(`str`, `int` or `torch.device`, *optional*): + The device on which the model should be executed. Will default to the MPS device if it's available, then + GPU 0 if there is a GPU, and finally to the CPU. + prev_module_hook (`UserCpuOffloadHook`, *optional*): + The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If + passed, its offload method will be called just before the forward of the model to which this hook is + attached. 
+ """ + + def __init__( + self, + execution_device: Optional[Union[str, int, torch.device]] = None, + prev_module_hook: Optional["UserCpuOffloadHook"] = None, + ): + self.prev_module_hook = prev_module_hook + + self.execution_device = execution_device if execution_device is not None else PartialState().default_device + + def init_hook(self, module): + return module.to("cpu") + + def pre_forward(self, module, *args, **kwargs): + if self.prev_module_hook is not None: + self.prev_module_hook.offload() + module.to(self.execution_device) + return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device) + + +class UserCpuOffloadHook: + """ + A simple hook grouping a model and a `ModelHook`, which provides easy APIs for to call the init method of the hook + or remove it entirely. + """ + + def __init__(self, model, hook): + self.model = model + self.hook = hook + + def offload(self): + self.hook.init_hook(self.model) + + def remove(self): + remove_hook_from_module(self.model) diff --git a/llm/Lib/site-packages/accelerate/inference.py b/llm/Lib/site-packages/accelerate/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..cf4cf15017938e34867d4eeaad120745051ab385 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/inference.py @@ -0,0 +1,188 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import math +from types import MethodType +from typing import Any, Dict, List, Optional, Tuple, Union + +from .state import PartialState +from .utils import ( + calculate_maximum_sizes, + convert_bytes, + copy_tensor_to_devices, + ignorant_find_batch_size, + infer_auto_device_map, + is_pippy_available, + pad_input_tensors, + send_to_device, +) + + +if is_pippy_available(): + from pippy.IR import Pipe, PipeSplitWrapper, annotate_split_points + from pippy.PipelineStage import PipelineStage + + +def generate_device_map(model, num_processes: int = 1, no_split_module_classes=None, max_memory: dict = None): + """ + Calculates the device map for `model` with an offset for PiPPy + """ + if num_processes == 1: + return infer_auto_device_map(model, no_split_module_classes=no_split_module_classes, clean_result=False) + if max_memory is None: + model_size, shared = calculate_maximum_sizes(model) + + # Split into `n` chunks for each GPU + memory = (model_size + shared[0]) / num_processes + memory = convert_bytes(memory) + value, ending = memory.split(" ") + + # Add a chunk to deal with potential extra shared memory instances + memory = math.ceil(float(value)) * 1.1 + memory = f"{memory} {ending}" + max_memory = {i: memory for i in range(num_processes)} + device_map = infer_auto_device_map( + model, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + clean_result=False, + ) + return device_map + + +def find_pippy_batch_size(args, kwargs): + found_batch_size = None + if args is not None: + for arg in args: + found_batch_size = ignorant_find_batch_size(arg) + if found_batch_size is not None: + break + if kwargs is not None and found_batch_size is None: + for kwarg in kwargs.values(): + found_batch_size = ignorant_find_batch_size(kwarg) + if found_batch_size is not None: + break + return found_batch_size + + +def build_pipeline(model, split_points, args, kwargs, num_chunks): + """ + Attaches the split points to the model based on `self.device_map` and generates a `PipelineStage`. Requires passing + in needed `args` and `kwargs` as the model needs on the CPU. + + Users can pass in custom `num_chunks` as an optional hyper-parameter. 
By default will use + `AcceleratorState.num_processes` + """ + # We need to annotate the split points in the model for PiPPy + state = PartialState() + annotate_split_points(model, {split_point: PipeSplitWrapper.SplitPoint.BEGINNING for split_point in split_points}) + found_batch_size = find_pippy_batch_size(args, kwargs) + if found_batch_size != num_chunks: + if args is not None: + args = pad_input_tensors(args, found_batch_size, num_chunks) + if kwargs is not None: + kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks) + pipe = Pipe.from_tracing(model, num_chunks=num_chunks, example_args=args, example_kwargs=kwargs) + stage = PipelineStage(pipe, state.local_process_index, device=state.device) + + return stage + + +def pippy_forward(forward, num_chunks, gather_output, *args, **kwargs): + state = PartialState() + output = None + + if state.num_processes == 1: + output = forward(*args, **kwargs) + elif state.is_local_main_process: + found_batch_size = find_pippy_batch_size(args, kwargs) + if found_batch_size is None: + raise ValueError("Could not find batch size from args or kwargs") + else: + if found_batch_size != num_chunks: + args = pad_input_tensors(args, found_batch_size, num_chunks) + kwargs = pad_input_tensors(kwargs, found_batch_size, num_chunks) + forward(*args, **kwargs) + elif state.is_last_process: + output = forward() + else: + forward() + if gather_output: + # Each node will get a copy of the full output which is only on the last GPU + output = copy_tensor_to_devices(output) + return output + + +def prepare_pippy( + model, + split_points: Optional[Union[str, List[str]]] = "auto", + no_split_module_classes: Optional[List[str]] = None, + example_args: Optional[Tuple[Any]] = (), + example_kwargs: Optional[Dict[str, Any]] = None, + num_chunks: Optional[int] = None, + gather_output: Optional[bool] = False, +): + """ + Wraps `model` for pipeline parallel inference. + + Args: + model (`torch.nn.Module`): + A model we want to split for pipeline-parallel inference + split_points (`str` or `List[str]`, defaults to 'auto'): + How to generate the split points and chunk the model across each GPU. 'auto' will find the best balanced + split given any model. Should be a list of layer names in the model to split by otherwise. + no_split_module_classes (`List[str]`): + A list of class names for layers we don't want to be split. + example_args (tuple of model inputs): + The expected inputs for the model that uses order-based inputs. Recommended to use this method if possible. + example_kwargs (dict of model inputs) + The expected inputs for the model that uses dictionary-based inputs. This is a *highly* limiting structure + that requires the same keys be present at *all* inference calls. Not recommended unless the prior condition + is true for all cases. + num_chunks (`int`, defaults to the number of available GPUs): + The number of different stages the Pipeline will have. By default it will assign one chunk per GPU, but + this can be tuned and played with. In general one should have num_chunks >= num_gpus. + gather_output (`bool`, defaults to `False`): + If `True`, the output from the last GPU (which holds the true outputs) is sent across to all GPUs. + """ + if not is_pippy_available(): + raise ImportError( + "`pippy` was not found to be installed on your system. 
Please " + "install using `pip install torchpippy` or ensure you have at least version 0.2.0" + ) + state = PartialState() + example_args = send_to_device(example_args, "cpu") + example_kwargs = send_to_device(example_kwargs, "cpu") + if num_chunks is None: + num_chunks = state.num_processes + if split_points == "auto": + device_map = generate_device_map(model, num_chunks, no_split_module_classes=no_split_module_classes) + split_points = [] + for i in range(1, num_chunks): + split_points.append(next(k for k, v in device_map.items() if v == i)) + model.hf_split_points = split_points + stage = build_pipeline(model, split_points, example_args, example_kwargs, num_chunks) + model._original_forward = model.forward + model._original_call = model.__call__ + model.pippy_stage = stage + model.hf_split_points = split_points + + def forward(*args, **kwargs): + return pippy_forward(stage.forward, num_chunks, gather_output, *args, **kwargs) + + # To act like a decorator so that it can be popped when doing `extract_model_from_parallel` + # Note: creates an infinite recursion loop with `generate` + model_forward = MethodType(forward, model) + forward.__wrapped__ = model_forward + model.forward = forward + return model diff --git a/llm/Lib/site-packages/accelerate/launchers.py b/llm/Lib/site-packages/accelerate/launchers.py new file mode 100644 index 0000000000000000000000000000000000000000..0265b25187f813356cfb49768097d6cf2599b0d3 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/launchers.py @@ -0,0 +1,258 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import tempfile + +import torch + +from .state import AcceleratorState, PartialState +from .utils import ( + PrecisionType, + PrepareForLaunch, + are_libraries_initialized, + check_cuda_p2p_ib_support, + get_gpu_info, + is_mps_available, + patch_environment, +) + + +def test_launch(): + "Verify a `PartialState` can be initialized." + _ = PartialState() + + +def notebook_launcher( + function, + args=(), + num_processes=None, + mixed_precision="no", + use_port="29500", + master_addr="127.0.0.1", + node_rank=0, + num_nodes=1, +): + """ + Launches a training function, using several processes or multiple nodes if it's possible in the current environment + (TPU with multiple cores for instance). + + + + To use this function absolutely zero calls to a CUDA device must be made in the notebook session before calling. If + any have been made, you will need to restart the notebook and make sure no cells use any CUDA capability. + + Setting `ACCELERATE_DEBUG_MODE="1"` in your environment will run a test before truly launching to ensure that none + of those calls have been made. + + + + Args: + function (`Callable`): + The training function to execute. If it accepts arguments, the first argument should be the index of the + process run. + args (`Tuple`): + Tuple of arguments to pass to the function (it will receive `*args`). 
+ num_processes (`int`, *optional*): + The number of processes to use for training. Will default to 8 in Colab/Kaggle if a TPU is available, to + the number of GPUs available otherwise. + mixed_precision (`str`, *optional*, defaults to `"no"`): + If `fp16` or `bf16`, will use mixed precision training on multi-GPU. + use_port (`str`, *optional*, defaults to `"29500"`): + The port to use to communicate between processes when launching a multi-GPU training. + master_addr (`str`, *optional*, defaults to `"127.0.0.1"`): + The address to use for communication between processes. + node_rank (`int`, *optional*, defaults to 0): + The rank of the current node. + num_nodes (`int`, *optional*, defaults to 1): + The number of nodes to use for training. + + Example: + + ```python + # Assume this is defined in a Jupyter Notebook on an instance with two GPUs + from accelerate import notebook_launcher + + + def train(*args): + # Your training function here + ... + + + notebook_launcher(train, args=(arg1, arg2), num_processes=2, mixed_precision="fp16") + ``` + """ + # Are we in a google colab or a Kaggle Kernel? + in_colab = False + in_kaggle = False + if any(key.startswith("KAGGLE") for key in os.environ.keys()): + in_kaggle = True + elif "IPython" in sys.modules: + in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython()) + + try: + mixed_precision = PrecisionType(mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + if (in_colab or in_kaggle) and (os.environ.get("TPU_NAME", None) is not None): + # TPU launch + import torch_xla.distributed.xla_multiprocessing as xmp + + if len(AcceleratorState._shared_state) > 0: + raise ValueError( + "To train on TPU in Colab or Kaggle Kernel, the `Accelerator` should only be initialized inside " + "your training function. Restart your notebook and make sure no cells initializes an " + "`Accelerator`." + ) + if num_processes is None: + num_processes = 8 + + launcher = PrepareForLaunch(function, distributed_type="TPU") + print(f"Launching a training on {num_processes} TPU cores.") + xmp.spawn(launcher, args=args, nprocs=num_processes, start_method="fork") + elif in_colab and get_gpu_info()[1] < 2: + # No need for a distributed launch otherwise as it's either CPU or one GPU. + if torch.cuda.is_available(): + print("Launching training on one GPU.") + else: + print("Launching training on one CPU.") + function(*args) + else: + if num_processes is None: + raise ValueError( + "You have to specify the number of GPUs you would like to use, add `num_processes=...` to your call." + ) + if node_rank >= num_nodes: + raise ValueError("The node_rank must be less than the number of nodes.") + if num_processes > 1: + # Multi-GPU launch + from torch.multiprocessing import start_processes + from torch.multiprocessing.spawn import ProcessRaisedException + + if len(AcceleratorState._shared_state) > 0: + raise ValueError( + "To launch a multi-GPU training from your notebook, the `Accelerator` should only be initialized " + "inside your training function. Restart your notebook and make sure no cells initializes an " + "`Accelerator`." + ) + # Check for specific libraries known to initialize CUDA that users constantly use + problematic_imports = are_libraries_initialized("bitsandbytes") + if len(problematic_imports) > 0: + err = ( + "Could not start distributed process. Libraries known to initialize CUDA upon import have been " + "imported already. 
Please keep these imports inside your training function to try and help with this:" + ) + for lib_name in problematic_imports: + err += f"\n\t* `{lib_name}`" + raise RuntimeError(err) + + patched_env = dict( + nproc=num_processes, + node_rank=node_rank, + world_size=num_nodes * num_processes, + master_addr=master_addr, + master_port=use_port, + mixed_precision=mixed_precision, + ) + + # Check for CUDA P2P and IB issues + if not check_cuda_p2p_ib_support(): + patched_env["nccl_p2p_disable"] = "1" + patched_env["nccl_ib_disable"] = "1" + + # torch.distributed will expect a few environment variable to be here. We set the ones common to each + # process here (the other ones will be set be the launcher). + with patch_environment(**patched_env): + # First dummy launch + if os.environ.get("ACCELERATE_DEBUG_MODE", "false").lower() == "true": + launcher = PrepareForLaunch(test_launch, distributed_type="MULTI_GPU") + try: + start_processes(launcher, args=(), nprocs=num_processes, start_method="fork") + except ProcessRaisedException as e: + err = "An issue was found when verifying a stable environment for the notebook launcher." + if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]: + raise RuntimeError( + f"{err}" + "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. " + "Please review your imports and test them when running the `notebook_launcher()` to identify " + "which one is problematic and causing CUDA to be initialized." + ) from e + else: + raise RuntimeError(f"{err} The following error was raised: {e}") from e + # Now the actual launch + launcher = PrepareForLaunch(function, distributed_type="MULTI_GPU") + print(f"Launching training on {num_processes} GPUs.") + try: + start_processes(launcher, args=args, nprocs=num_processes, start_method="fork") + except ProcessRaisedException as e: + if "Cannot re-initialize CUDA in forked subprocess" in e.args[0]: + raise RuntimeError( + "CUDA has been initialized before the `notebook_launcher` could create a forked subprocess. " + "This likely stems from an outside import causing issues once the `notebook_launcher()` is called. " + "Please review your imports and test them when running the `notebook_launcher()` to identify " + "which one is problematic and causing CUDA to be initialized." + ) from e + else: + raise RuntimeError(f"An issue was found when launching the training: {e}") from e + + else: + # No need for a distributed launch otherwise as it's either CPU, GPU or MPS. + if is_mps_available(): + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + print("Launching training on MPS.") + elif torch.cuda.is_available(): + print("Launching training on one GPU.") + else: + print("Launching training on CPU.") + function(*args) + + +def debug_launcher(function, args=(), num_processes=2): + """ + Launches a training function using several processes on CPU for debugging purposes. + + + + This function is provided for internal testing and debugging, but it's not intended for real trainings. It will + only use the CPU. + + + + Args: + function (`Callable`): + The training function to execute. + args (`Tuple`): + Tuple of arguments to pass to the function (it will receive `*args`). + num_processes (`int`, *optional*, defaults to 2): + The number of processes to use for training. + """ + from torch.multiprocessing import start_processes + + with tempfile.NamedTemporaryFile() as tmp_file: + # torch.distributed will expect a few environment variable to be here. 
We set the ones common to each + # process here (the other ones will be set be the launcher). + with patch_environment( + world_size=num_processes, + master_addr="127.0.0.1", + master_port="29500", + accelerate_mixed_precision="no", + accelerate_debug_rdv_file=tmp_file.name, + accelerate_use_cpu="yes", + ): + launcher = PrepareForLaunch(function, debug=True) + start_processes(launcher, args=args, nprocs=num_processes, start_method="fork") diff --git a/llm/Lib/site-packages/accelerate/local_sgd.py b/llm/Lib/site-packages/accelerate/local_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..7f2657fcc8b057b4396cf299e6cf681fa7b83aa8 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/local_sgd.py @@ -0,0 +1,102 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +from accelerate import Accelerator, DistributedType + + +class LocalSGD: + """ + A helper class to support local SGD on top of Accelerator. It simply runs a given number of updates independently + on each device, and averages model weights every K synchronization step. + + It should be used only in the multi-GPU (or multi-CPU) setup without extensions such as DeepSpeed. In particular, + this is a simple implementation that cannot support scenarios such as model parallelism. + + + Although we are not aware of the true origins of this simple approach, the idea of local SGD is quite old and goes + back to at least: + + Zhang, J., De Sa, C., Mitliagkas, I., & Ré, C. (2016). [Parallel SGD: When does averaging help?. arXiv preprint + arXiv:1606.07365.](https://arxiv.org/abs/1606.07365) + + We credit the term Local SGD to the following paper (but there might be earlier references we are not aware of). + + Stich, Sebastian Urban. ["Local SGD Converges Fast and Communicates Little." ICLR 2019-International Conference on + Learning Representations. No. CONF. 2019.](https://arxiv.org/abs/1805.09767) + + """ + + def __enter__(self): + if self.enabled: + self.model_sync_obj = self.model.no_sync() + self.model_sync_obj.__enter__() + + return self + + def __exit__(self, type, value, tb): + if self.enabled: + # Average all models on exit + self._sync_and_avg_model_params() + self.model_sync_obj.__exit__(type, value, tb) + + def __init__(self, accelerator: Accelerator, model: torch.nn.Module, local_sgd_steps: int, enabled: bool = True): + """ + Constructor. + + Args: + model (`torch.nn.Module): + The model whose parameters we need to average. + accelerator (`Accelerator`): + Accelerator object. + local_sgd_steps (`int`): + A number of local SGD steps (before model parameters are synchronized). + enabled (`bool): + Local SGD is disabled if this parameter set to `False`. 
+ """ + if accelerator.distributed_type not in [ + DistributedType.NO, + DistributedType.MULTI_CPU, + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + ]: + raise NotImplementedError("LocalSGD is supported only for CPUs and GPUs (no DeepSpeed or MegatronLM)") + self.enabled = enabled and accelerator.distributed_type != DistributedType.NO + self.num_steps = 0 + if self.enabled: + self.accelerator = accelerator + self.model = model + self.local_sgd_steps = local_sgd_steps + + def step(self): + """ + This function makes a "step" and synchronizes model parameters if necessary. + """ + self.num_steps += 1 + if not self.enabled: + return + + if self.num_steps % self.local_sgd_steps == 0: + self._sync_and_avg_model_params() + + def _sync_and_avg_model_params(self): + """ + Synchronize + Average model parameters across all GPUs + """ + + self.accelerator.wait_for_everyone() + with self.accelerator.autocast(): + for param in self.model.parameters(): + param.data = self.accelerator.reduce(param.data, reduction="mean") diff --git a/llm/Lib/site-packages/accelerate/logging.py b/llm/Lib/site-packages/accelerate/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb8c1eb830e54e3f2870cb3a84afd33b7631ea6 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/logging.py @@ -0,0 +1,123 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +import logging +import os + +from .state import PartialState + + +class MultiProcessAdapter(logging.LoggerAdapter): + """ + An adapter to assist with logging in multiprocess. + + `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes + or only the main executed one. Default is `main_process_only=True`. + + Does not require an `Accelerator` object to be created first. + """ + + @staticmethod + def _should_log(main_process_only): + "Check if log should be performed" + state = PartialState() + return not main_process_only or (main_process_only and state.is_main_process) + + def log(self, level, msg, *args, **kwargs): + """ + Delegates logger call after checking if we should log. + + Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes + or only the main executed one. Default is `True` if not passed + + Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to + read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not + break with the previous behavior. + + `in_order` is ignored if `main_process_only` is passed. + """ + if PartialState._shared_state == {}: + raise RuntimeError( + "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility." 
+ ) + main_process_only = kwargs.pop("main_process_only", True) + in_order = kwargs.pop("in_order", False) + + if self.isEnabledFor(level): + if self._should_log(main_process_only): + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + + elif in_order: + state = PartialState() + for i in range(state.num_processes): + if i == state.process_index: + msg, kwargs = self.process(msg, kwargs) + self.logger.log(level, msg, *args, **kwargs) + state.wait_for_everyone() + + @functools.lru_cache(None) + def warning_once(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but will emit the warning with the same message only once + + Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the + cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to + switch to another type of cache that includes the caller frame information in the hashing function. + """ + self.warning(*args, **kwargs) + + +def get_logger(name: str, log_level: str = None): + """ + Returns a `logging.Logger` for `name` that can handle multiprocessing. + + If a log should be called on all processes, pass `main_process_only=False` If a log should be called on all + processes and in order, also pass `in_order=True` + + Args: + name (`str`): + The name for the logger, such as `__file__` + log_level (`str`, *optional*): + The log level to use. If not passed, will default to the `LOG_LEVEL` environment variable, or `INFO` if not + + Example: + + ```python + >>> from accelerate.logging import get_logger + >>> from accelerate import Accelerator + + >>> logger = get_logger(__name__) + + >>> accelerator = Accelerator() + >>> logger.info("My log", main_process_only=False) + >>> logger.debug("My log", main_process_only=True) + + >>> logger = get_logger(__name__, log_level="DEBUG") + >>> logger.info("My log") + >>> logger.debug("My second log") + + >>> array = ["a", "b", "c", "d"] + >>> letter_at_rank = array[accelerator.process_index] + >>> logger.info(letter_at_rank, in_order=True) + ``` + """ + if log_level is None: + log_level = os.environ.get("ACCELERATE_LOG_LEVEL", None) + logger = logging.getLogger(name) + if log_level is not None: + logger.setLevel(log_level.upper()) + logger.root.setLevel(log_level.upper()) + return MultiProcessAdapter(logger, {}) diff --git a/llm/Lib/site-packages/accelerate/memory_utils.py b/llm/Lib/site-packages/accelerate/memory_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2e2c8b9d7d0064c3e5e282737a7ad6919bde29 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/memory_utils.py @@ -0,0 +1,22 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + + +warnings.warn( + "memory_utils has been reorganized to utils.memory. 
Import `find_executable_batchsize` from the main `__init__`: " + "`from accelerate import find_executable_batch_size` to avoid this warning.", + FutureWarning, +) diff --git a/llm/Lib/site-packages/accelerate/optimizer.py b/llm/Lib/site-packages/accelerate/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c2fc3e9f1b7592b29ed18ce1ce78a0859286f438 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/optimizer.py @@ -0,0 +1,193 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import warnings + +import torch + +from .state import AcceleratorState, GradientState +from .utils import DistributedType, honor_type, is_torch_xla_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +def move_to_device(state, device): + if isinstance(state, (list, tuple)): + return honor_type(state, (move_to_device(t, device) for t in state)) + elif isinstance(state, dict): + return type(state)({k: move_to_device(v, device) for k, v in state.items()}) + elif isinstance(state, torch.Tensor): + return state.to(device) + return state + + +class AcceleratedOptimizer(torch.optim.Optimizer): + """ + Internal wrapper around a torch optimizer. + + Conditionally will perform `step` and `zero_grad` if gradients should be synchronized when performing gradient + accumulation. + + Args: + optimizer (`torch.optim.optimizer.Optimizer`): + The optimizer to wrap. + device_placement (`bool`, *optional*, defaults to `True`): + Whether or not the optimizer should handle device placement. If so, it will place the state dictionary of + `optimizer` on the right device. + scaler (`torch.cuda.amp.grad_scaler.GradScaler`, *optional*): + The scaler to use in the step function if training with mixed precision. 
+ """ + + def __init__(self, optimizer, device_placement=True, scaler=None): + self.optimizer = optimizer + self.scaler = scaler + self.accelerator_state = AcceleratorState() + self.gradient_state = GradientState() + self.device_placement = device_placement + self._is_overflow = False + + if self.scaler is not None: + self._accelerate_step_called = False + self._optimizer_original_step_method = self.optimizer.step + self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step) + + # Handle device placement + if device_placement: + state_dict = self.optimizer.state_dict() + if self.accelerator_state.distributed_type == DistributedType.XLA: + xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device) + else: + state_dict = move_to_device(state_dict, self.accelerator_state.device) + self.optimizer.load_state_dict(state_dict) + + @property + def state(self): + return self.optimizer.state + + @state.setter + def state(self, state): + self.optimizer.state = state + + @property + def param_groups(self): + return self.optimizer.param_groups + + @param_groups.setter + def param_groups(self, param_groups): + self.optimizer.param_groups = param_groups + + @property + def defaults(self): + return self.optimizer.defaults + + @defaults.setter + def defaults(self, defaults): + self.optimizer.defaults = defaults + + def add_param_group(self, param_group): + self.optimizer.add_param_group(param_group) + + def load_state_dict(self, state_dict): + if self.accelerator_state.distributed_type == DistributedType.XLA and self.device_placement: + xm.send_cpu_data_to_device(state_dict, self.accelerator_state.device) + self.optimizer.load_state_dict(state_dict) + + def state_dict(self): + return self.optimizer.state_dict() + + def zero_grad(self, set_to_none=None): + if self.gradient_state.sync_gradients: + accept_arg = "set_to_none" in inspect.signature(self.optimizer.zero_grad).parameters + if accept_arg: + if set_to_none is None: + set_to_none = True + self.optimizer.zero_grad(set_to_none=set_to_none) + else: + if set_to_none is not None: + raise ValueError("`set_to_none` for Optimizer.zero_grad` is not supported by this optimizer.") + self.optimizer.zero_grad() + + def step(self, closure=None): + if ( + not self.gradient_state.is_xla_gradients_synced + and self.accelerator_state.distributed_type == DistributedType.XLA + ): + gradients = xm._fetch_gradients(self.optimizer) + xm.all_reduce("sum", gradients, scale=1.0 / xm.xrt_world_size()) + self.gradient_state.is_xla_gradients_synced = True + if self.gradient_state.sync_gradients: + if self.scaler is not None: + self.optimizer.step = self._optimizer_patched_step_method + + self.scaler.step(self.optimizer, closure) + self.scaler.update() + + if not self._accelerate_step_called: + # If the optimizer step was skipped, gradient overflow was detected. 
+ self._is_overflow = True + else: + self._is_overflow = False + # Reset the step method to the original one + self.optimizer.step = self._optimizer_original_step_method + # Reset the indicator + self._accelerate_step_called = False + else: + self.optimizer.step(closure) + if self.accelerator_state.distributed_type == DistributedType.XLA: + self.gradient_state.is_xla_gradients_synced = False + + def _switch_parameters(self, parameters_map): + for param_group in self.optimizer.param_groups: + param_group["params"] = [parameters_map.get(p, p) for p in param_group["params"]] + + @property + def is_overflow(self): + """Whether or not the optimizer step was done, or skipped because of gradient overflow.""" + warnings.warn( + "The `is_overflow` property is deprecated and will be removed in version 1.0 of Accelerate use " + "`optimizer.step_was_skipped` instead.", + FutureWarning, + ) + return self._is_overflow + + @property + def step_was_skipped(self): + """Whether or not the optimizer step was skipped.""" + return self._is_overflow + + def __getstate__(self): + _ignored_keys = [ + "_accelerate_step_called", + "_optimizer_original_step_method", + "_optimizer_patched_step_method", + ] + return {k: v for k, v in self.__dict__.items() if k not in _ignored_keys} + + def __setstate__(self, state): + self.__dict__.update(state) + if self.scaler is not None: + self._accelerate_step_called = False + self._optimizer_original_step_method = self.optimizer.step + self._optimizer_patched_step_method = patch_optimizer_step(self, self.optimizer.step) + + +def patch_optimizer_step(accelerated_optimizer: AcceleratedOptimizer, method): + def patched_step(*args, **kwargs): + accelerated_optimizer._accelerate_step_called = True + return method(*args, **kwargs) + + return patched_step diff --git a/llm/Lib/site-packages/accelerate/scheduler.py b/llm/Lib/site-packages/accelerate/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa8a13f238afd7b908ee8e8cb8e0620f48d4ff8 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/scheduler.py @@ -0,0 +1,98 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# We ignore warnings about stepping the scheduler since we step it ourselves during gradient accumulation + +import warnings + +from .state import AcceleratorState, GradientState + + +warnings.filterwarnings("ignore", category=UserWarning, module="torch.optim.lr_scheduler") + + +class AcceleratedScheduler: + """ + A wrapper around a learning rate scheduler that will only step when the optimizer(s) have a training step. Useful + to avoid making a scheduler step too fast when gradients went overflow and there was no training step (in mixed + precision training) + + When performing gradient accumulation scheduler lengths should not be changed accordingly, Accelerate will always + step the scheduler to account for it. + + Args: + scheduler (`torch.optim.lr_scheduler._LRScheduler`): + The scheduler to wrap. 
+ optimizers (one or a list of `torch.optim.Optimizer`): + The optimizers used. + step_with_optimizer (`bool`, *optional*, defaults to `True`): + Whether or not the scheduler should be stepped at each optimizer step. + split_batches (`bool`, *optional*, defaults to `False`): + Whether or not the dataloaders split one batch across the different processes (so batch size is the same + regardless of the number of processes) or create batches on each process (so batch size is the original + batch size multiplied by the number of processes). + """ + + def __init__(self, scheduler, optimizers, step_with_optimizer: bool = True, split_batches: bool = False): + self.scheduler = scheduler + self.optimizers = optimizers if isinstance(optimizers, (list, tuple)) else [optimizers] + self.split_batches = split_batches + self.step_with_optimizer = step_with_optimizer + self.gradient_state = GradientState() + + def step(self, *args, **kwargs): + if not self.step_with_optimizer: + # No link between scheduler and optimizer -> just step + self.scheduler.step(*args, **kwargs) + return + + # Otherwise, first make sure the optimizer was stepped. + if not self.gradient_state.sync_gradients: + if self.gradient_state.adjust_scheduler: + self.scheduler._step_count += 1 + return + + for opt in self.optimizers: + if opt.step_was_skipped: + return + if self.split_batches: + # Split batches -> the training dataloader batch size is not changed so one step per training step + self.scheduler.step(*args, **kwargs) + else: + # Otherwise the training dataloader batch size was multiplied by `num_processes`, so we need to do + # num_processes steps per training step + num_processes = AcceleratorState().num_processes + for _ in range(num_processes): + # Special case when using OneCycle and `drop_last` was not used + if hasattr(self.scheduler, "total_steps"): + if self.scheduler._step_count <= self.scheduler.total_steps: + self.scheduler.step(*args, **kwargs) + else: + self.scheduler.step(*args, **kwargs) + + # Passthroughs + def get_last_lr(self): + return self.scheduler.get_last_lr() + + def state_dict(self): + return self.scheduler.state_dict() + + def load_state_dict(self, state_dict): + self.scheduler.load_state_dict(state_dict) + + def get_lr(self): + return self.scheduler.get_lr() + + def print_lr(self, *args, **kwargs): + return self.scheduler.print_lr(*args, **kwargs) diff --git a/llm/Lib/site-packages/accelerate/state.py b/llm/Lib/site-packages/accelerate/state.py new file mode 100644 index 0000000000000000000000000000000000000000..1d65c5a1314bd9cdb8013808f1540184ba08b12b --- /dev/null +++ b/llm/Lib/site-packages/accelerate/state.py @@ -0,0 +1,1209 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
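Taken together, the `AcceleratedOptimizer` and `AcceleratedScheduler` wrappers above mean that a prepared scheduler only advances when the wrapped optimizer actually performed a step (it is held back after an fp16 gradient overflow), and that it steps `num_processes` times per update to compensate for per-process dataloaders unless `split_batches` is used. A minimal sketch of how the wrappers are normally obtained through `Accelerator.prepare` (the toy model, data, and hyperparameters are placeholders):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# fp16 only makes sense on an accelerator device; fall back to full precision on CPU.
accelerator = Accelerator(mixed_precision="fp16" if torch.cuda.is_available() else "no")

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
loader = DataLoader(TensorDataset(torch.randn(32, 8), torch.randn(32, 1)), batch_size=4)

# prepare() returns the AcceleratedOptimizer / AcceleratedScheduler wrappers defined above.
model, optimizer, loader, scheduler = accelerator.prepare(model, optimizer, loader, scheduler)

for x, y in loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    accelerator.backward(loss)
    optimizer.step()
    scheduler.step()  # internally skipped when the optimizer step was skipped
    if optimizer.step_was_skipped:
        print("gradient overflow detected: the scheduler was not advanced")
```

Note that `scheduler.step()` is still called unconditionally; the wrapper decides whether to forward the call, which is why training loops written for plain PyTorch keep working unchanged.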
+ +from __future__ import annotations + +import logging +import math +import os +import threading +import warnings +from contextlib import contextmanager +from functools import partial +from typing import Any, Callable, Optional + +import torch + +from .utils import ( + DistributedType, + DynamoBackend, + GradientAccumulationPlugin, + check_cuda_p2p_ib_support, + check_fp8_capability, + get_ccl_version, + get_cpu_distributed_information, + get_int_from_env, + is_ccl_available, + is_datasets_available, + is_deepspeed_available, + is_fp8_available, + is_ipex_available, + is_mlu_available, + is_mps_available, + is_npu_available, + is_torch_xla_available, + is_xpu_available, + parse_choice_from_env, + parse_flag_from_env, + set_numa_affinity, +) +from .utils.dataclasses import SageMakerDistributedType + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + +if is_mlu_available(check_device=False): + import torch_mlu # noqa: F401 + +if is_npu_available(check_device=False): + import torch_npu # noqa: F401 + +logger = logging.getLogger(__name__) + + +def is_initialized() -> bool: + """ + Checks if the `AcceleratorState` has been initialized from `Accelerator`. Same as `AcceleratorState.initialized`, + but works as a module method. + """ + return AcceleratorState._shared_state != {} + + +# Lambda function that does nothing +def do_nothing(*args, **kwargs): + return None + + +class ThreadLocalSharedDict(threading.local): + """ + Descriptor that holds a dict shared between instances of a class in the same thread. + + Note: Descriptors have slightly different semantics than just a dict field on its own. + `PartialState(...)._shared_state` and `PartialState._shared_state` (instance vs class) give the same value: the + underlying _storage dict. Likewise, `PartialState(...)._shared_state = {...}` overrides the _storage dict inside + the descriptor as you would expect. However, `PartialState._shared_state = {}` actually replaces the descriptor + object with a dict instead Thus, you should modify the _storage dict in-place (e.g. `_shared_state.clear()`). + + See Python documentation for an explanation of descriptors: https://docs.python.org/3/howto/descriptor.html + + This is required for using PyTorch/XLA with PJRT in multithreaded mode (required for TPU v2 and v3). + + See https://github.com/pytorch/xla/blob/r2.0/docs/pjrt.md#multithreading-on-tpu-v2v3 + """ + + def __init__(self, thread_local: bool = False): + self._storage = {} + + def __get__(self, obj, objtype=None): + return self._storage + + def __set__(self, obj, value): + self._storage = value + + +# Prefer global shared dictionary, except when using TPU. +SharedDict = dict if not is_torch_xla_available() else ThreadLocalSharedDict + + +# Inspired by Alex Martelli's 'Borg'. +class PartialState: + """ + Singleton class that has information about the current training environment and functions to help with process + control. Designed to be used when only process control and device execution states are needed. Does *not* need to + be initialized from `Accelerator`. + + Args: + cpu (`bool`, *optional*): + Whether or not to force the script to execute on CPU. Will ignore any accelerators available if set to + `True` and force the execution on the CPU. + kwargs (additional keyword arguments, *optional*): + Additional keyword arguments to pass to the relevent `init_process_group` function. Valid `kwargs` can be + found in [`utils.InitProcessGroupKwargs`]. See the example section for detailed usage. 
+ + **Available attributes:** + + - **device** (`torch.device`) -- The device to use. + - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently + in use. + - **local_process_index** (`int`) -- The index of the current process on the current server. + - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type + of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8'). + - **num_processes** (`int`) -- The number of processes currently launched in parallel. + - **process_index** (`int`) -- The index of the current process. + - **is_last_process** (`bool`) -- Whether or not the current process is the last one. + - **is_main_process** (`bool`) -- Whether or not the current process is the main one. + - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node. + - **debug** (`bool`) -- Whether or not the current script is being run in debug mode. + + Example: + ```python + from accelerate.utils import InitProcessGroupKwargs + + # To include `InitProcessGroupKwargs`, init then call `.to_kwargs()` + kwargs = InitProcessGroupKwargs(...).to_kwargs() + state = PartialState(**kwargs) + ``` + """ + + _shared_state = SharedDict() + _known_attrs = [ + "_cpu", + "_mixed_precision", + "_shared_state", + "backend", + "debug", + "device", + "distributed_type", + "fork_launched", + "local_process_index", + "num_processes", + "process_index", + ] + + def __init__(self, cpu: bool = False, **kwargs): + self.__dict__ = self._shared_state + if not self.initialized: + self._cpu = cpu + self.backend = None + env_device = os.environ.get("ACCELERATE_TORCH_DEVICE", None) + self.device = torch.device(env_device) if env_device is not None else None + self.debug = parse_flag_from_env("ACCELERATE_DEBUG_MODE") + use_sagemaker_dp = kwargs.pop("_use_sagemaker_dp", None) + dist_information = None + if use_sagemaker_dp is None: + use_sagemaker_dp = ( + os.environ.get("ACCELERATE_USE_SAGEMAKER", "false") == "true" + and os.environ.get("ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE") != SageMakerDistributedType.NO + ) + + # Sets up self.backend + imports + original_backend = kwargs.pop("backend", None) + backend, distributed_type = self._prepare_backend(cpu, use_sagemaker_dp, original_backend) + if original_backend is not None and backend != original_backend: + raise ValueError("Your assigned backend {original_backend} is not avaliable, please use {backend}") + self.backend = backend + self.distributed_type = distributed_type + use_deepspeed = False + if not cpu and self.backend != "xla": + if int(os.environ.get("LOCAL_RANK", -1)) != -1: + # Deal with spawning deepspeed + if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": + if not is_deepspeed_available(): + raise ImportError( + "DeepSpeed is not available => install it using `pip3 install deepspeed` or build it from source" + ) + from deepspeed import comm as dist + + if is_xpu_available() and is_ccl_available(): + os.environ["CCL_PROCESS_LAUNCHER"] = "none" + os.environ["CCL_LOCAL_SIZE"] = os.environ.get("LOCAL_WORLD_SIZE", "1") + os.environ["CCL_LOCAL_RANK"] = os.environ.get("LOCAL_RANK", "0") + + if not dist.is_initialized(): + dist.init_distributed(dist_backend=self.backend, auto_mpi_discovery=False, **kwargs) + # We need to flag to `use_deepspeed` to be True to override `distributed_type` later + use_deepspeed = True + # Deal with all other backends but XPU and CPU, that gets handled special 
later + elif ( + self.distributed_type not in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU) + and not torch.distributed.is_initialized() + ): + torch.distributed.init_process_group(backend=self.backend, **kwargs) + # XPU and CPU require special env configs to be set + if self.distributed_type in (DistributedType.MULTI_XPU, DistributedType.MULTI_CPU): + dist_information = get_cpu_distributed_information() + os.environ["RANK"] = str(dist_information.rank) + os.environ["WORLD_SIZE"] = str(dist_information.world_size) + os.environ["LOCAL_RANK"] = str(dist_information.local_rank) + os.environ["LOCAL_WORLD_SIZE"] = str(dist_information.local_world_size) + if self.backend == "ccl" and self.distributed_type == DistributedType.MULTI_XPU: + os.environ["CCL_PROCESS_LAUNCHER"] = "none" + os.environ["CCL_LOCAL_SIZE"] = os.environ["LOCAL_WORLD_SIZE"] + os.environ["CCL_LOCAL_RANK"] = os.environ["LOCAL_RANK"] + if not os.environ.get("MASTER_PORT", None): + os.environ["MASTER_PORT"] = "29500" + if ( + not os.environ.get("MASTER_ADDR", None) + and dist_information.local_world_size != dist_information.world_size + and self.backend != "mpi" + ): + raise ValueError( + "Tried to launch on distributed with multinode, but `MASTER_ADDR` env was not set, " + "please try exporting rank 0's hostname as `MASTER_ADDR`" + ) + kwargs["rank"] = dist_information.rank + kwargs["world_size"] = dist_information.world_size + + if ( + self.distributed_type == DistributedType.MULTI_CPU + and get_int_from_env(["OMP_NUM_THREADS", "OMP_NUM_THREADS"], 0) > 0 + ): + import psutil + + num_cpu_threads_per_process = int( + psutil.cpu_count(logical=False) / dist_information.local_world_size + ) + if num_cpu_threads_per_process == 0: + num_cpu_threads_per_process = 1 + torch.set_num_threads(num_cpu_threads_per_process) + warnings.warn( + f"OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at {num_cpu_threads_per_process} to improve oob" + " performance." 
+ ) + + if not torch.distributed.is_initialized(): + torch.distributed.init_process_group(backend=self.backend, **kwargs) + + # No backend == no distributed training + if self.backend is None: + self.distributed_type = DistributedType.NO + self.num_processes = 1 + self.process_index = 0 + self.local_process_index = 0 + elif self.backend == "xla": + # XLA needs device setting first for `set_replication` + self.set_device() + xm.set_replication(self.device, xm.get_xla_supported_devices()) + self.num_processes = xm.xrt_world_size() + self.process_index = xm.get_ordinal() + if is_torch_xla_available(check_is_tpu=True): + self.local_process_index = xm.get_local_ordinal() + else: + self.local_process_index = int(os.environ.get("LOCAL_RANK", -1)) + else: + self.num_processes = torch.distributed.get_world_size() + self.process_index = torch.distributed.get_rank() + self.local_process_index = ( + int(os.environ.get("LOCAL_RANK", -1)) if dist_information is None else dist_information.local_rank + ) + self.set_device() + # Now we can change to deepseed + if use_deepspeed: + self.distributed_type = DistributedType.DEEPSPEED + + # Set CPU affinity if enabled + if parse_flag_from_env("ACCELERATE_CPU_AFFINITY", False): + set_numa_affinity(self.local_process_index) + + # Check for old RTX 4000's that can't use P2P or IB and are on old drivers + if self.device.type == "cuda" and not check_cuda_p2p_ib_support(): + if "NCCL_P2P_DISABLE" not in os.environ or "NCCL_IB_DISABLE" not in os.environ: + raise NotImplementedError( + "Using RTX 4000 series doesn't support faster communication broadband via P2P or IB. " + 'Please set `NCCL_P2P_DISABLE="1"` and `NCCL_IB_DISABLE="1" or use `accelerate launch` which ' + "will do this automatically." + ) + # Important: This should be the *only* code outside of `self.initialized!` + self.fork_launched = parse_flag_from_env("FORK_LAUNCHED", 0) + + def __repr__(self) -> str: + return ( + f"Distributed environment: {self.distributed_type}{(' Backend: ' + self.backend) if self.backend else ''}\n" + f"Num processes: {self.num_processes}\n" + f"Process index: {self.process_index}\n" + f"Local process index: {self.local_process_index}\n" + f"Device: {self.device}\n" + ) + + @staticmethod + def _reset_state(): + "Resets `_shared_state`, is used internally and should not be called" + PartialState._shared_state.clear() + + @property + def initialized(self) -> bool: + "Returns whether the `PartialState` has been initialized" + return self._shared_state != {} + + @property + def use_distributed(self): + """ + Whether the Accelerator is configured for distributed training + """ + return self.distributed_type != DistributedType.NO and self.num_processes > 1 + + @property + def is_last_process(self) -> bool: + "Returns whether the current process is the last one" + return self.process_index == self.num_processes - 1 + + @property + def is_main_process(self) -> bool: + "Returns whether the current process is the main process" + return ( + self.process_index == 0 if self.distributed_type != DistributedType.MEGATRON_LM else self.is_last_process + ) + + @property + def is_local_main_process(self) -> bool: + "Returns whether the current process is the main process on the local node" + return ( + self.local_process_index == 0 + if self.distributed_type != DistributedType.MEGATRON_LM + else self.is_last_process + ) + + def wait_for_everyone(self): + """ + Will stop the execution of the current process until every other process has reached that point (so this does + nothing when the script is 
only run in one process). Useful to do before saving a model. + + Example: + + ```python + >>> # Assuming two GPU processes + >>> import time + >>> from accelerate.state import PartialState + + >>> state = PartialState() + >>> if state.is_main_process: + ... time.sleep(2) + >>> else: + ... print("I'm waiting for the main process to finish its sleep...") + >>> state.wait_for_everyone() + >>> # Should print on every process at the same time + >>> print("Everyone is here") + ``` + """ + if self.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_CPU, + DistributedType.DEEPSPEED, + DistributedType.FSDP, + ): + torch.distributed.barrier() + elif self.distributed_type == DistributedType.XLA: + xm.rendezvous("accelerate.utils.wait_for_everyone") + + def _goes_first(self, is_main: bool): + if not is_main: + self.wait_for_everyone() + + yield + + if is_main: + self.wait_for_everyone() + + @contextmanager + def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False): + """ + Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing + distributed inference, such as with different prompts. + + Note that when using a `dict`, all keys need to have the same number of elements. + + Args: + inputs (`list`, `tuple`, `torch.Tensor`, `dict` of `list`/`tuple`/`torch.Tensor`, or `datasets.Dataset`): + The input to split between processes. + apply_padding (`bool`, `optional`, defaults to `False`): + Whether to apply padding by repeating the last element of the input so that all processes have the same + number of elements. Useful when trying to perform actions such as `gather()` on the outputs or passing + in less inputs than there are processes. If so, just remember to drop the padded elements afterwards. 
+ + + Example: + + ```python + # Assume there are two processes + from accelerate import PartialState + + state = PartialState() + with state.split_between_processes(["A", "B", "C"]) as inputs: + print(inputs) + # Process 0 + ["A", "B"] + # Process 1 + ["C"] + + with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs: + print(inputs) + # Process 0 + ["A", "B"] + # Process 1 + ["C", "C"] + ``` + """ + if self.num_processes == 1: + yield inputs + return + length = len(inputs) + # Nested dictionary of any types + if isinstance(inputs, dict): + length = len(inputs[list(inputs.keys())[0]]) + if not all(len(v) == length for v in inputs.values()): + raise ValueError("All values in the dictionary must have the same length") + num_samples_per_process = math.ceil(length / self.num_processes) + start_index = self.process_index * num_samples_per_process + end_index = start_index + num_samples_per_process + if (len(inputs) % self.num_processes != 0) and (self.process_index == self.num_processes - 1): + end_index = length + + def _split_values(inputs, start_index, end_index): + if isinstance(inputs, (list, tuple, torch.Tensor)): + if start_index >= len(inputs): + result = inputs[-1:] + else: + result = inputs[start_index:end_index] + if apply_padding: + if isinstance(result, torch.Tensor): + from accelerate.utils import pad_across_processes, send_to_device + + # The tensor needs to be on the device before we can pad it + tensorized_result = send_to_device(result, self.device) + result = pad_across_processes(tensorized_result, pad_index=inputs[-1]) + else: + result += [result[-1]] * (num_samples_per_process - len(result)) + return result + elif isinstance(inputs, dict): + for key in inputs.keys(): + inputs[key] = _split_values(inputs[key], start_index, end_index) + return inputs + else: + if is_datasets_available(): + from datasets import Dataset + + if isinstance(inputs, Dataset): + if start_index >= len(inputs): + start_index = len(inputs) - 1 + if end_index > len(inputs): + end_index = len(inputs) + result_idcs = list(range(start_index, end_index)) + if apply_padding: + result_idcs += [end_index - 1] * (num_samples_per_process - len(result_idcs)) + return inputs.select(result_idcs) + return inputs + + yield _split_values(inputs, start_index, end_index) + + @contextmanager + def main_process_first(self): + """ + Lets the main process go first inside a with block. + + The other processes will enter the with block after the main process exits. + + Example: + + ```python + >>> from accelerate import Accelerator + + >>> accelerator = Accelerator() + >>> with accelerator.main_process_first(): + ... # This will be printed first by process 0 then in a seemingly + ... # random order by the other processes. + ... print(f"This will be printed by process {accelerator.process_index}") + ``` + """ + yield from self._goes_first(self.is_main_process) + + @contextmanager + def local_main_process_first(self): + """ + Lets the local main process go inside a with block. + + The other processes will enter the with block after the main process exits. + + Example: + + ```python + >>> from accelerate.state import PartialState + + >>> state = PartialState() + >>> with state.local_main_process_first(): + ... # This will be printed first by local process 0 then in a seemingly + ... # random order by the other processes. + ... 
print(f"This will be printed by process {state.local_process_index}") + ``` + """ + yield from self._goes_first(self.is_local_main_process) + + def on_main_process(self, function: Callable[..., Any] = None): + """ + Decorator that only runs the decorated function on the main process. + + Args: + function (`Callable`): The function to decorate. + + Example: + + ```python + >>> from accelerate.state import PartialState + + >>> state = PartialState() + + + >>> @state.on_main_process + ... def print_something(): + ... print("This will be printed by process 0 only.") + + + >>> print_something() + "This will be printed by process 0 only" + ``` + """ + if not self.initialized: + raise ValueError("The `PartialState` or `Accelerator` must be initialized before calling this function.") + if self.is_main_process or not self.use_distributed: + return function + return do_nothing + + def on_local_main_process(self, function: Callable[..., Any] = None): + """ + Decorator that only runs the decorated function on the local main process. + + Args: + function (`Callable`): The function to decorate. + + Example: + ```python + # Assume we have 2 servers with 4 processes each. + from accelerate.state import PartialState + + state = PartialState() + + + @state.on_local_main_process + def print_something(): + print("This will be printed by process 0 only on each server.") + + + print_something() + # On server 1: + "This will be printed by process 0 only" + # On server 2: + "This will be printed by process 0 only" + ``` + """ + if self.is_local_main_process or not self.use_distributed: + return function + return do_nothing + + def on_last_process(self, function: Callable[..., Any]): + """ + Decorator that only runs the decorated function on the last process. + + Args: + function (`Callable`): The function to decorate. + + Example: + ```python + # Assume we have 4 processes. + from accelerate.state import PartialState + + state = PartialState() + + + @state.on_last_process + def print_something(): + print(f"Printed on process {state.process_index}") + + + print_something() + "Printed on process 3" + ``` + """ + if self.is_last_process or not self.use_distributed: + return function + return do_nothing + + def on_process(self, function: Callable[..., Any] = None, process_index: int = None): + """ + Decorator that only runs the decorated function on the process with the given index. + + Args: + function (`Callable`, `optional`): + The function to decorate. + process_index (`int`, `optional`): + The index of the process on which to run the function. + + Example: + ```python + # Assume we have 4 processes. + from accelerate.state import PartialState + + state = PartialState() + + + @state.on_process(process_index=2) + def print_something(): + print(f"Printed on process {state.process_index}") + + + print_something() + "Printed on process 2" + ``` + """ + if function is None: + return partial(self.on_process, process_index=process_index) + if (self.process_index == process_index) or (not self.use_distributed): + return function + return do_nothing + + def on_local_process(self, function: Callable[..., Any] = None, local_process_index: int = None): + """ + Decorator that only runs the decorated function on the process with the given index on the current node. + + Args: + function (`Callable`, *optional*): + The function to decorate. + local_process_index (`int`, *optional*): + The index of the local process on which to run the function. + + Example: + ```python + # Assume we have 2 servers with 4 processes each. 
+ from accelerate import Accelerator + + accelerator = Accelerator() + + + @accelerator.on_local_process(local_process_index=2) + def print_something(): + print(f"Printed on process {accelerator.local_process_index}") + + + print_something() + # On server 1: + "Printed on process 2" + # On server 2: + "Printed on process 2" + ``` + """ + if function is None: + return partial(self.on_local_process, local_process_index=local_process_index) + if (self.local_process_index == local_process_index) or (not self.use_distributed): + return function + return do_nothing + + def print(self, *args, **kwargs): + if self.is_local_main_process: + print(*args, **kwargs) + + @property + def default_device(self) -> torch.device: + """ + Returns the default device which is: + - MPS if `torch.backends.mps.is_available()` and `torch.backends.mps.is_built()` both return True. + - CUDA if `torch.cuda.is_available()` + - MLU if `is_mlu_available()` + - NPU if `is_npu_available()` + - CPU otherwise + """ + if is_mps_available(): + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + return torch.device("mps") + elif is_mlu_available(): + return torch.device("mlu") + elif torch.cuda.is_available(): + return torch.device("cuda") + elif is_xpu_available(): + return torch.device("xpu:0") + elif is_npu_available(): + return torch.device("npu") + else: + return torch.device("cpu") + + def _prepare_backend( + self, cpu: bool = False, sagemaker_dp=False, backend: str = None + ) -> tuple[str, DistributedType]: + "Prepares any imports needed before initializing the distributed backend and sets `self.backend` properly" + distributed_type = None + if sagemaker_dp: + import smdistributed.dataparallel.torch.torch_smddp # noqa + + backend = "smddp" + distributed_type = DistributedType.MULTI_GPU + elif is_torch_xla_available(): + backend = "xla" + distributed_type = DistributedType.XLA + elif int(os.environ.get("LOCAL_RANK", -1)) != -1 and not cpu: + if is_mlu_available(): + backend = "cncl" + distributed_type = DistributedType.MULTI_MLU + elif torch.cuda.is_available(): + if backend is None: + backend = "nccl" + distributed_type = DistributedType.MULTI_GPU + elif is_npu_available(): + backend = "hccl" + distributed_type = DistributedType.MULTI_NPU + + if distributed_type is None and ( + int(os.environ.get("LOCAL_RANK", -1)) != -1 + or get_int_from_env(["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"], 1) > 1 + ): + if not cpu and is_xpu_available(): + distributed_type = DistributedType.MULTI_XPU + else: + distributed_type = DistributedType.MULTI_CPU + + if ( + backend in (None, "ccl") + and is_ccl_available() + and (get_int_from_env(["CCL_WORKER_COUNT"], 0) > 0 or distributed_type == DistributedType.MULTI_XPU) + ): + if get_ccl_version() >= "1.12": + import oneccl_bindings_for_pytorch # noqa: F401 + else: + import torch_ccl # noqa: F401 + + backend = "ccl" + elif backend in (None, "mpi") and torch.distributed.is_mpi_available(): + backend = "mpi" + else: + backend = "gloo" + if distributed_type is None: + distributed_type = DistributedType.NO + + return backend, distributed_type + + def set_device(self): + """ + Sets the device in `self.device` to the current distributed environment. 
+ """ + if self.device is not None: + return + if self.distributed_type == DistributedType.NO: + self.device = torch.device("cpu") if self._cpu else self.default_device + return + device = str(self.distributed_type).split(".")[-1].replace("MULTI_", "").lower() + if device not in ("cpu", "gpu", "mlu", "npu", "xpu", "xla"): + raise ValueError( + f"Can't set device for {self.distributed_type} ({device}), verify we should be calling `_set_device()` for it!" + ) + if device == "xla": + self.device = xm.xla_device() + else: + if device == "gpu": + device = "cuda" + self.device = torch.device(device, self.local_process_index) + if self.device is not None: + if device == "xpu": + torch.xpu.set_device(self.device) + elif device == "mlu": + torch.mlu.set_device(self.device) + elif device == "npu": + torch.npu.set_device(self.device) + elif device == "cuda": + torch.cuda.set_device(self.device) + + def __getattr__(self, name: str): + # By this point we know that no attributes of `self` contain `name`, + # so we just modify the error message + if name in self._known_attrs: + raise AttributeError( + f"`PartialState` object has no attribute `{name}`. " + "This happens if `PartialState._reset_state()` was called and " + "an `Accelerator` or `PartialState` was not reinitialized." + ) + # Raise a typical AttributeError + raise AttributeError(f"'PartialState' object has no attribute '{name}'") + + +class AcceleratorState: + """ + Singleton class that has information about the current training environment. + + **Available attributes:** + + - **device** (`torch.device`) -- The device to use. + - **distributed_type** ([`~accelerate.state.DistributedType`]) -- The type of distributed environment currently + in use. + - **initialized** (`bool`) -- Whether or not the `AcceleratorState` has been initialized from `Accelerator`. + - **local_process_index** (`int`) -- The index of the current process on the current server. + - **mixed_precision** (`str`) -- Whether or not the current script will use mixed precision, and if so the type + of mixed precision being performed. (Choose from 'no','fp16','bf16 or 'fp8'). + - **num_processes** (`int`) -- The number of processes currently launched in parallel. + - **process_index** (`int`) -- The index of the current process. + - **is_last_process** (`bool`) -- Whether or not the current process is the last one. + - **is_main_process** (`bool`) -- Whether or not the current process is the main one. + - **is_local_main_process** (`bool`) -- Whether or not the current process is the main one on the local node. + - **debug** (`bool`) -- Whether or not the current script is being run in debug mode. 
+ """ + + _shared_state = SharedDict() + _known_attrs = PartialState._known_attrs + [ + "deepspeed_plugin", + "use_ipex", + "fsdp_plugin", + "megatron_lm_plugin", + "dynamo_plugin", + ] + + def __init__( + self, + mixed_precision: str = None, + cpu: bool = False, + dynamo_plugin=None, + deepspeed_plugin=None, + fsdp_plugin=None, + megatron_lm_plugin=None, + _from_accelerator: bool = False, + **kwargs, + ): + self.__dict__ = self._shared_state + if parse_flag_from_env("ACCELERATE_USE_CPU"): + cpu = True + if PartialState._shared_state == {}: + PartialState(cpu, **kwargs) + self.__dict__.update(PartialState._shared_state) + self._check_initialized(mixed_precision, cpu) + if not self.initialized: + self.deepspeed_plugin = None + self.use_ipex = None + mixed_precision = ( + parse_choice_from_env("ACCELERATE_MIXED_PRECISION", "no") + if mixed_precision is None + else mixed_precision.lower() + ) + if mixed_precision == "fp8": + if not is_fp8_available(): + raise ValueError( + "Using `fp8` precision requires `transformer_engine` or `MS-AMP` to be installed." + ) + elif not check_fp8_capability(): + logger.warning( + f"The current device has compute capability of {torch.cuda.get_device_capability()} which is " + "insufficient for FP8 mixed precision training (requires a GPU Hopper/Ada Lovelace " + "or higher, compute capability of 8.9 or higher). Will use FP16 instead." + ) + mixed_precision = "fp16" + + self.dynamo_plugin = dynamo_plugin + if not _from_accelerator: + raise ValueError( + "Please make sure to properly initialize your accelerator via `accelerator = Accelerator()` " + "before using any functionality from the `accelerate` library." + ) + # deepspeed handles mixed_precision using deepspeed_config + self._mixed_precision = "no" if self.distributed_type == DistributedType.DEEPSPEED else mixed_precision + if self.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_tpu=True): + if mixed_precision == "bf16": + if os.environ.get("ACCELERATE_DOWNCAST_BF16"): + os.environ["XLA_USE_BF16"] = str(0) + os.environ["XLA_DOWNCAST_BF16"] = str(1) + self.downcast_bfloat = True + else: + os.environ["XLA_USE_BF16"] = str(1) + os.environ["XLA_DOWNCAST_BF16"] = str(0) + self.downcast_bfloat = False + elif os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true" and not cpu: + self.deepspeed_plugin = deepspeed_plugin + elif self.distributed_type in [ + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + ]: + if os.environ.get("ACCELERATE_USE_FSDP", "false") == "true": + self.distributed_type = DistributedType.FSDP + if self._mixed_precision != "no": + fsdp_plugin.set_mixed_precision(self._mixed_precision) + self.fsdp_plugin = fsdp_plugin + if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false") == "true" and self.distributed_type not in [ + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + ]: + self.distributed_type = DistributedType.MEGATRON_LM + megatron_lm_plugin.set_mixed_precision(self._mixed_precision) + self.megatron_lm_plugin = megatron_lm_plugin + elif self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]: + if is_ipex_available(): + # check if user disables it explicitly + self.use_ipex = parse_flag_from_env("ACCELERATE_USE_IPEX", default=True) + else: + self.use_ipex = False + if ( + self.dynamo_plugin.backend != DynamoBackend.NO + and self._mixed_precision == "no" + and self.device.type == "cuda" + ): + torch.backends.cuda.matmul.allow_tf32 = True + 
PartialState._shared_state["distributed_type"] = self.distributed_type + + @property + def initialized(self) -> bool: + return self._shared_state != PartialState._shared_state + + def __repr__(self): + repr = PartialState().__repr__() + f"\nMixed precision type: {self.mixed_precision}\n" + if self.distributed_type == DistributedType.DEEPSPEED: + repr += f"ds_config: {self.deepspeed_plugin.deepspeed_config}\n" + return repr + + def _check_initialized(self, mixed_precision=None, cpu=None): + "Checks if a modification is trying to be made and the `AcceleratorState` has already been initialized" + if self.initialized: + err = "AcceleratorState has already been initialized and cannot be changed, restart your runtime completely and pass `{flag}` to `Accelerator()`." + if cpu and self.device.type != "cpu": + raise ValueError(err.format(flag="cpu=True")) + if ( + mixed_precision is not None + and mixed_precision != self._mixed_precision + and self.distributed_type != DistributedType.DEEPSPEED + ): + raise ValueError(err.format(flag=f"mixed_precision='{mixed_precision}'")) + + # For backward compatibility + @property + def use_fp16(self): + warnings.warn( + "The `use_fp16` property is deprecated and will be removed in version 1.0 of Accelerate use " + "`AcceleratorState.mixed_precision == 'fp16'` instead.", + FutureWarning, + ) + return self._mixed_precision != "no" + + @property + def mixed_precision(self): + if self.distributed_type == DistributedType.DEEPSPEED: + config = self.deepspeed_plugin.deepspeed_config + if config.get("fp16", {}).get("enabled", False): + mixed_precision = "fp16" + elif config.get("bf16", {}).get("enabled", False): + mixed_precision = "bf16" + else: + mixed_precision = "no" + else: + mixed_precision = self._mixed_precision + return mixed_precision + + @staticmethod + def _reset_state(reset_partial_state: bool = False): + "Resets `_shared_state`, is used internally and should not be called" + AcceleratorState._shared_state.clear() + if reset_partial_state: + PartialState._reset_state() + + @property + def use_distributed(self): + """ + Whether the Accelerator is configured for distributed training + """ + return PartialState().use_distributed + + @property + def is_last_process(self) -> bool: + "Returns whether the current process is the last one" + return PartialState().is_last_process + + @property + def is_main_process(self) -> bool: + "Returns whether the current process is the main process" + return PartialState().is_main_process + + @property + def is_local_main_process(self) -> bool: + "Returns whether the current process is the main process on the local node" + return PartialState().is_local_main_process + + def wait_for_everyone(self): + PartialState().wait_for_everyone() + + @contextmanager + def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False): + """ + Splits `input` between `self.num_processes` quickly and can be then used on that process. Useful when doing + distributed inference, such as with different prompts. + + Note that when using a `dict`, all keys need to have the same number of elements. + + Args: + inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`): + The input to split between processes. + apply_padding (`bool`, `optional`, defaults to `False`): + Whether to apply padding by repeating the last element of the input so that all processes have the same + number of elements. 
Useful when trying to perform actions such as `gather()` on the outputs or passing + in less inputs than there are processes. If so, just remember to drop the padded elements afterwards. + + + Example: + + ```python + # Assume there are two processes + from accelerate.state import AcceleratorState + + state = AcceleratorState() + with state.split_between_processes(["A", "B", "C"]) as inputs: + print(inputs) + # Process 0 + ["A", "B"] + # Process 1 + ["C"] + + with state.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs: + print(inputs) + # Process 0 + ["A", "B"] + # Process 1 + ["C", "C"] + ``` + """ + with PartialState().split_between_processes(inputs, apply_padding=apply_padding) as inputs: + yield inputs + + @contextmanager + def main_process_first(self): + """ + Lets the main process go first inside a with block. + + The other processes will enter the with block after the main process exits. + """ + with PartialState().main_process_first(): + yield + + @contextmanager + def local_main_process_first(self): + """ + Lets the local main process go inside a with block. + + The other processes will enter the with block after the main process exits. + """ + with PartialState().local_main_process_first(): + yield + + def print(self, *args, **kwargs): + PartialState().print(*args, **kwargs) + + def __getattr__(self, name: str): + # By this point we know that no attributes of `self` contain `name`, + # so we just modify the error message + if name in self._known_attrs: + raise AttributeError( + f"`AcceleratorState` object has no attribute `{name}`. " + "This happens if `AcceleratorState._reset_state()` was called and " + "an `Accelerator` or `PartialState` was not reinitialized." + ) + # Raise a typical AttributeError + raise AttributeError(f"'AcceleratorState' object has no attribute '{name}'") + + +class GradientState: + """ + Singleton class that has information related to gradient synchronization for gradient accumulation + + **Available attributes:** + + - **end_of_dataloader** (`bool`) -- Whether we have reached the end the current dataloader + - **remainder** (`int`) -- The number of extra samples that were added from padding the dataloader + - **sync_gradients** (`bool`) -- Whether the gradients should be synced across all devices + - **active_dataloader** (`Optional[DataLoader]`) -- The dataloader that is currently being iterated over + - **dataloader_references** (`List[Optional[DataLoader]]`) -- A list of references to the dataloaders that are + being iterated over + - **num_steps** (`int`) -- The number of steps to accumulate over + - **adjust_scheduler** (`bool`) -- Whether the scheduler should be adjusted to account for the gradient + accumulation + - **sync_with_dataloader** (`bool`) -- Whether the gradients should be synced at the end of the dataloader + iteration and the number of total steps reset + - **is_xla_gradients_synced** (`bool`) -- Whether the XLA gradients have been synchronized. It is initialized + as false. Once gradients have been reduced before the optimizer step, this flag is set to true. Subsequently, + after each step, the flag is reset to false. FSDP will always synchronize the gradients, hence + is_xla_gradients_synced is always true. 
+ """ + + _shared_state = SharedDict() + + def __init__(self, gradient_accumulation_plugin: Optional[GradientAccumulationPlugin] = None): + self.__dict__ = self._shared_state + if not self.initialized: + self.sync_gradients = True + self.active_dataloader = None + self.dataloader_references = [None] + self.plugin_kwargs = ( + gradient_accumulation_plugin.to_kwargs() if gradient_accumulation_plugin is not None else {} + ) + self._is_xla_gradients_synced = False + + # Plugin args are different and can be updated + if gradient_accumulation_plugin is not None and self.plugin_kwargs != gradient_accumulation_plugin.to_kwargs(): + self.plugin_kwargs = gradient_accumulation_plugin.to_kwargs() + + @property + def num_steps(self) -> int: + "Returns the number of steps to accumulate over" + return self.plugin_kwargs.get("num_steps", 1) + + @property + def adjust_scheduler(self) -> bool: + "Returns whether the scheduler should be adjusted" + return self.plugin_kwargs.get("adjust_scheduler", False) + + @property + def sync_with_dataloader(self) -> bool: + "Returns whether the gradients should be synced at the end of the dataloader iteration and the number of total steps reset" + return self.plugin_kwargs.get("sync_with_dataloader", True) + + @property + def initialized(self) -> bool: + "Returns whether the `GradientState` has been initialized" + return GradientState._shared_state != {} + + @property + def end_of_dataloader(self) -> bool: + "Returns whether we have reached the end of the current dataloader" + if not self.in_dataloader: + return False + return self.active_dataloader.end_of_dataloader + + @property + def remainder(self) -> int: + "Returns the number of extra samples that were added from padding the dataloader" + if not self.in_dataloader: + return -1 + return self.active_dataloader.remainder + + def __repr__(self): + return ( + f"Sync Gradients: {self.sync_gradients}\n" + f"At end of current dataloader: {self.end_of_dataloader}\n" + f"Extra samples added: {self.remainder}\n" + f"Gradient accumulation plugin: {self.plugin_kwargs}\n" + ) + + @property + def is_xla_gradients_synced(self): + "Returns the value of is_xla_gradients_synced. FSDP will always synchronize the gradients, hence is_xla_gradients_synced is always true." + if parse_flag_from_env("ACCELERATE_USE_FSDP", default=False): + return True + return self._is_xla_gradients_synced + + @is_xla_gradients_synced.setter + def is_xla_gradients_synced(self, is_synced): + "Set the _is_xla_gradients_synced attribute." + self._is_xla_gradients_synced = is_synced + + def _set_sync_gradients(self, sync_gradients): + "Private function that sets whether gradients should be synchronized. Users should not have to call this." + self.sync_gradients = sync_gradients + # Allow grad-sync to automatically work on TPUs + if ( + self.sync_gradients + and is_torch_xla_available(check_is_tpu=True) + and PartialState().distributed_type == DistributedType.XLA + ): + xm.mark_step() + + def _add_dataloader(self, dataloader): + "Private function that adds a dataloader to `self.dataloader_references` and sets `in_dataloader` to `True`. Users should not have to call this." + self.active_dataloader = dataloader + self.dataloader_references.append(self.active_dataloader) + + def _remove_dataloader(self, dataloader): + "Private function that removes a dataloader from `self.dataloader_references` and sets `in_dataloader` to `False` if there are no more dataloaders. Users should not have to call this." 
+ self.dataloader_references.remove(dataloader) + self.active_dataloader = self.dataloader_references[-1] + + @property + def in_dataloader(self) -> bool: + "Returns whether the current process is in a dataloader" + return self.active_dataloader is not None + + @staticmethod + def _reset_state(): + "Resets `_shared_state`, is used internally and should not be called" + GradientState._shared_state.clear() diff --git a/llm/Lib/site-packages/accelerate/test_utils/__init__.py b/llm/Lib/site-packages/accelerate/test_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bd5f6f15da98737459dcf7ab05b68e87be9a384d --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/__init__.py @@ -0,0 +1,50 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .testing import ( + DEFAULT_LAUNCH_COMMAND, + are_the_same_tensors, + assert_exception, + device_count, + execute_subprocess_async, + get_launch_command, + memory_allocated_func, + path_in_accelerate_package, + require_bnb, + require_cpu, + require_cuda, + require_huggingface_suite, + require_mlu, + require_mps, + require_multi_device, + require_multi_gpu, + require_multi_xpu, + require_non_cpu, + require_non_torch_xla, + require_non_xpu, + require_npu, + require_pippy, + require_single_device, + require_single_gpu, + require_single_xpu, + require_torch_min_version, + require_tpu, + require_xpu, + skip, + slow, + torch_device, +) +from .training import RegressionDataset, RegressionModel, RegressionModel4XPU + + +from .scripts import test_script, test_sync, test_ops # isort: skip diff --git a/llm/Lib/site-packages/accelerate/test_utils/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0e1e14943b988a4819407f6ef4947768741f941 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/__pycache__/examples.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/__pycache__/examples.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9983a43a90a39d61df78f8dd7a22bea880a3c95e Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/__pycache__/examples.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/__pycache__/testing.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/__pycache__/testing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84480f78f294000c7132cee947d6b8dd895b138e Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/__pycache__/testing.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/__pycache__/training.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/__pycache__/training.cpython-311.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..c3137e8ac599bc210c1c19a32d7de1c6460c9904 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/__pycache__/training.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/examples.py b/llm/Lib/site-packages/accelerate/test_utils/examples.py new file mode 100644 index 0000000000000000000000000000000000000000..ed41d38c9092385ba9730472aa10b5208f48c67b --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/examples.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python + +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +A collection of utilities for comparing `examples/complete_*_example.py` scripts with the capabilities inside of each +`examples/by_feature` example. `compare_against_test` is the main function that should be used when testing, while the +others are used to either get the code that matters, or to preprocess them (such as stripping comments) +""" + +import os +from typing import List + + +def get_function_contents_by_name(lines: List[str], name: str): + """ + Extracts a function from `lines` of segmented source code with the name `name`. + + Args: + lines (`List[str]`): + Source code of a script seperated by line. + name (`str`): + The name of the function to extract. Should be either `training_function` or `main` + """ + if name != "training_function" and name != "main": + raise ValueError(f"Incorrect function name passed: {name}, choose either 'main' or 'training_function'") + good_lines, found_start = [], False + for line in lines: + if not found_start and f"def {name}" in line: + found_start = True + good_lines.append(line) + continue + if found_start: + if name == "training_function" and "def main" in line: + return good_lines + if name == "main" and "if __name__" in line: + return good_lines + good_lines.append(line) + + +def clean_lines(lines: List[str]): + """ + Filters `lines` and removes any entries that start with a comment ('#') or is just a newline ('\n') + + Args: + lines (`List[str]`): + Source code of a script seperated by line. + """ + return [line for line in lines if not line.lstrip().startswith("#") and line != "\n"] + + +def compare_against_test(base_filename: str, feature_filename: str, parser_only: bool, secondary_filename: str = None): + """ + Tests whether the additional code inside of `feature_filename` was implemented in `base_filename`. This should be + used when testing to see if `complete_*_.py` examples have all of the implementations from each of the + `examples/by_feature/*` scripts. + + It utilizes `nlp_example.py` to extract out all of the repeated training code, so that only the new additional code + is examined and checked. If something *other* than `nlp_example.py` should be used, such as `cv_example.py` for the + `complete_cv_example.py` script, it should be passed in for the `secondary_filename` parameter. 
+ + Args: + base_filename (`str` or `os.PathLike`): + The filepath of a single "complete" example script to test, such as `examples/complete_cv_example.py` + feature_filename (`str` or `os.PathLike`): + The filepath of a single feature example script. The contents of this script are checked to see if they + exist in `base_filename` + parser_only (`bool`): + Whether to compare only the `main()` sections in both files, or to compare the contents of + `training_loop()` + secondary_filename (`str`, *optional*): + A potential secondary filepath that should be included in the check. This function extracts the base + functionalities off of "examples/nlp_example.py", so if `base_filename` is a script other than + `complete_nlp_example.py`, the template script should be included here. Such as `examples/cv_example.py` + """ + with open(base_filename) as f: + base_file_contents = f.readlines() + with open(os.path.abspath(os.path.join("examples", "nlp_example.py"))) as f: + full_file_contents = f.readlines() + with open(feature_filename) as f: + feature_file_contents = f.readlines() + if secondary_filename is not None: + with open(secondary_filename) as f: + secondary_file_contents = f.readlines() + + # This is our base, we remove all the code from here in our `full_filename` and `feature_filename` to find the new content + if parser_only: + base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "main")) + full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "main")) + feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "main")) + if secondary_filename is not None: + secondary_file_func = clean_lines(get_function_contents_by_name(secondary_file_contents, "main")) + else: + base_file_func = clean_lines(get_function_contents_by_name(base_file_contents, "training_function")) + full_file_func = clean_lines(get_function_contents_by_name(full_file_contents, "training_function")) + feature_file_func = clean_lines(get_function_contents_by_name(feature_file_contents, "training_function")) + if secondary_filename is not None: + secondary_file_func = clean_lines( + get_function_contents_by_name(secondary_file_contents, "training_function") + ) + + _dl_line = "train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)\n" + + # Specific code in our script that differs from the full version, aka what is new + new_feature_code = [] + passed_idxs = [] # We keep track of the idxs just in case it's a repeated statement + it = iter(feature_file_func) + for i in range(len(feature_file_func) - 1): + if i not in passed_idxs: + line = next(it) + if (line not in full_file_func) and (line.lstrip() != _dl_line): + if "TESTING_MOCKED_DATALOADERS" not in line: + new_feature_code.append(line) + passed_idxs.append(i) + else: + # Skip over the `config['num_epochs'] = 2` statement + _ = next(it) + + # Extract out just the new parts from the full_file_training_func + new_full_example_parts = [] + passed_idxs = [] # We keep track of the idxs just in case it's a repeated statement + for i, line in enumerate(base_file_func): + if i not in passed_idxs: + if (line not in full_file_func) and (line.lstrip() != _dl_line): + if "TESTING_MOCKED_DATALOADERS" not in line: + new_full_example_parts.append(line) + passed_idxs.append(i) + + # Finally, get the overall diff + diff_from_example = [line for line in new_feature_code if line not in new_full_example_parts] + if secondary_filename is not None: + diff_from_two = [line for line in 
full_file_contents if line not in secondary_file_func] + diff_from_example = [line for line in diff_from_example if line not in diff_from_two] + + return diff_from_example diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/__init__.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cbe26c257b515f657c05e1996d517e69613972 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b61705e7d38d33b6cc3fee11e3c9fe14e3c8ff89 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_cli.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_cli.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a02ca8f675bd710aada57f4adab157b92228be30 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_cli.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_distributed_data_loop.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_distributed_data_loop.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88d885f1635f58316d1f74dad4e9abd2b524e8f1 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_distributed_data_loop.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_notebook.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_notebook.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33c817fb715be1f5f730467ac829b191ae6eebf3 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_notebook.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_ops.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..990ebb9bad41480954d302c4aab98fd6f6354497 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_ops.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_script.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_script.cpython-311.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..c63382264579021430bb6cba4b5326b6cc8ec403 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_script.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_sync.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_sync.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09e2e317c500aea3f4fc0d288eecbd3a062ff53d Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/__pycache__/test_sync.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__init__.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9cbe26c257b515f657c05e1996d517e69613972 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b21d5e7635e07ce514e940382363fc86b7b8e94 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_checkpointing.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_checkpointing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d924cc4dd933a4cb069ae49c535b3fb7d832b5d7 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_checkpointing.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_metrics.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_metrics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d860e38e7eef6552506acf57f0228cb2fd9e062 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_metrics.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_peak_memory_usage.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_peak_memory_usage.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28e83f19b1cca2db813acf9b30d7acaa4bab4c4a Binary files /dev/null and 
b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_peak_memory_usage.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_performance.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_performance.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e7c701c0fb318f23db970c4c45df4fecaea53e3 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_performance.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_pippy.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_pippy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e97efad21eb8e892553f32b182df6f5ec00f708 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_pippy.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_zero3_integration.cpython-311.pyc b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_zero3_integration.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6c1eb163c1f5a615169eac5285f809a2e2f31a7 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/__pycache__/test_zero3_integration.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_checkpointing.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..41c77c7ec5e6e2475a795efdb54702600eac0282 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_checkpointing.py @@ -0,0 +1,268 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os + +import evaluate +import torch +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed + +from accelerate import Accelerator, DistributedType +from accelerate.utils.deepspeed import DummyOptim, DummyScheduler + + +MAX_GPU_BATCH_SIZE = 16 +EVAL_BATCH_SIZE = 32 + + +def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. 
+ model_name (`str`, *optional*): + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def evaluation_loop(accelerator, model, eval_dataloader, metric): + model.eval() + samples_seen = 0 + for step, batch in enumerate(eval_dataloader): + # We could avoid this line since we set the accelerator with `device_placement=True`. + batch.to(accelerator.device) + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + # It is slightly faster to call this once, than multiple times + predictions, references = accelerator.gather( + (predictions, batch["labels"]) + ) # If we are in a multiprocess environment, the last batch has duplicates + if accelerator.use_distributed: + if step == len(eval_dataloader) - 1: + predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] + references = references[: len(eval_dataloader.dataset) - samples_seen] + else: + samples_seen += references.shape[0] + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + return eval_metric["accuracy"] + + +def training_function(config, args): + # Initialize accelerator + accelerator = Accelerator() + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name) + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True) + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = 
accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + else: + gradient_accumulation_steps = 1 + max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps + + # Instantiate scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to keep track of how many total steps we have iterated over + overall_step = 0 + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + metric = evaluate.load("glue", "mrpc") + ending_epoch = num_epochs + + if args.partial_train_epoch is not None: + ending_epoch = args.partial_train_epoch + + if args.resume_from_checkpoint: + accelerator.load_state(args.resume_from_checkpoint) + epoch_string = args.resume_from_checkpoint.split("epoch_")[1] + state_epoch_num = "" + for char in epoch_string: + if char.isdigit(): + state_epoch_num += char + else: + break + starting_epoch = int(state_epoch_num) + 1 + accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric) + accelerator.print("resumed checkpoint performance:", accuracy) + accelerator.print("resumed checkpoint's scheduler's lr:", lr_scheduler.get_lr()[0]) + accelerator.print("resumed optimizers's lr:", optimizer.param_groups[0]["lr"]) + with open(os.path.join(args.output_dir, f"state_{starting_epoch - 1}.json")) as f: + resumed_state = json.load(f) + assert resumed_state["accuracy"] == accuracy, "Accuracy mismatch, loading from checkpoint failed" + assert ( + resumed_state["lr"] == lr_scheduler.get_lr()[0] + ), "Scheduler learning rate mismatch, loading from checkpoint failed" + assert ( + resumed_state["optimizer_lr"] == optimizer.param_groups[0]["lr"] + ), "Optimizer learning rate mismatch, loading from checkpoint failed" + assert resumed_state["epoch"] == starting_epoch - 1, "Epoch mismatch, loading from checkpoint failed" + return + + # Now we train the model + state = {} + for epoch in range(starting_epoch, ending_epoch): + model.train() + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + if step % gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + overall_step += 1 + output_dir = f"epoch_{epoch}" + output_dir = os.path.join(args.output_dir, output_dir) + accelerator.save_state(output_dir) + accuracy = evaluation_loop(accelerator, model, eval_dataloader, metric) + state["accuracy"] = accuracy + state["lr"] = lr_scheduler.get_lr()[0] + state["optimizer_lr"] = optimizer.param_groups[0]["lr"] + state["epoch"] = epoch + state["step"] = overall_step + accelerator.print(f"epoch {epoch}:", state) + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, f"state_{epoch}.json"), "w") as f: + json.dump(state, f) + + 
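+
+# Illustrative sketch, not part of the upstream script: the resume logic in
+# `training_function` above recovers the starting epoch by slicing the checkpoint
+# path after "epoch_", which assumes checkpoint folders are named `epoch_<N>`
+# (the naming used when calling `accelerator.save_state`). The hypothetical helper
+# below only documents that parsing in one place; its name is made up here.
+def _epoch_from_checkpoint_path(path: str) -> int:
+    """Return the epoch number encoded in a checkpoint path such as 'outputs/epoch_3'."""
+    digits = ""
+    for char in path.split("epoch_")[1]:
+        if char.isdigit():
+            digits += char
+        else:
+            break
+    return int(digits)
+
+
+# Example: _epoch_from_checkpoint_path("outputs/epoch_3") == 3
+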
+def main():
+    parser = argparse.ArgumentParser(description="Simple example of a training script that saves and resumes from checkpoints.")
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        default="bert-base-cased",
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+        required=False,
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default=".",
+        help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.",
+    )
+    parser.add_argument(
+        "--resume_from_checkpoint",
+        type=str,
+        default=None,
+        help="If passed, training resumes from this checkpoint folder.",
+    )
+    parser.add_argument(
+        "--partial_train_epoch",
+        type=int,
+        default=None,
+        help="If passed, the training will stop after this number of epochs.",
+    )
+    parser.add_argument(
+        "--num_epochs",
+        type=int,
+        default=2,
+        help="Number of training epochs.",
+    )
+    args = parser.parse_args()
+    config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16}
+
+    training_function(config, args)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_metrics.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ac13aba6266f31f1e0f9eb41b961fc2933d00ab
--- /dev/null
+++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_metrics.py
@@ -0,0 +1,306 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import logging +import math +import os +from copy import deepcopy + +import datasets +import evaluate +import torch +import transformers +from datasets import load_dataset +from torch.utils.data import DataLoader, IterableDataset +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from accelerate import Accelerator, DataLoaderConfiguration, DistributedType +from accelerate.data_loader import DataLoaderDispatcher +from accelerate.test_utils import RegressionDataset, RegressionModel, torch_device +from accelerate.utils import is_torch_xla_available, set_seed + + +os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "true" + + +class ListHandler(logging.Handler): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.logs = [] + + def emit(self, record): + self.logs.append(record) + + +def get_basic_setup(accelerator, num_samples=82, batch_size=16): + "Returns everything needed to perform basic training" + set_seed(42) + model = RegressionModel() + ddp_model = deepcopy(model) + dset = RegressionDataset(length=num_samples) + dataloader = DataLoader(dset, batch_size=batch_size) + model.to(accelerator.device) + ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader) + return model, ddp_model, dataloader + + +def get_dataloader(accelerator: Accelerator, use_longest=False): + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/mrpc-bert-base-cased") + dataset = load_dataset("glue", "mrpc", split="validation") + + def tokenize_function(examples): + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + with accelerator.main_process_first(): + tokenized_datasets = dataset.map( + tokenize_function, + batched=True, + remove_columns=["idx", "sentence1", "sentence2"], + ) + + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + if use_longest: + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + + return DataLoader(tokenized_datasets, shuffle=False, collate_fn=collate_fn, batch_size=16) + + +def get_mrpc_setup(dispatch_batches, split_batches): + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, split_batches=split_batches) + accelerator = Accelerator(dataloader_config=dataloader_config) + dataloader = get_dataloader(accelerator, not dispatch_batches) + model = AutoModelForSequenceClassification.from_pretrained( + "hf-internal-testing/mrpc-bert-base-cased", return_dict=True + ) + ddp_model, ddp_dataloader = accelerator.prepare(model, dataloader) + return { + "ddp": [ddp_model, ddp_dataloader, torch_device], + "no": [model, dataloader, accelerator.device], + }, accelerator + + +def generate_predictions(model, dataloader, accelerator): + logits_and_targets = [] + for batch in dataloader: + input, target = batch.values() + with torch.no_grad(): + logit = model(input) + logit, target = accelerator.gather_for_metrics((logit, target)) + logits_and_targets.append((logit, target)) + logits, targs = [], [] + for logit, targ in logits_and_targets: + logits.append(logit) + targs.append(targ) + logits, targs = torch.cat(logits), torch.cat(targs) + return logits, targs + + +def test_torch_metrics( + accelerator: Accelerator, num_samples=82, dispatch_batches=False, split_batches=False, batch_size=16 +): + _, ddp_model, dataloader = get_basic_setup(accelerator, num_samples, batch_size) + 
logits, _ = generate_predictions(ddp_model, dataloader, accelerator) + assert ( + len(logits) == num_samples + ), f"Unexpected number of inputs:\n Expected: {num_samples}\n Actual: {len(logits)}" + + +def test_mrpc(dispatch_batches: bool = False, split_batches: bool = False): + metric = evaluate.load("glue", "mrpc") + setup, accelerator = get_mrpc_setup(dispatch_batches, split_batches) + # First do baseline + model, dataloader, device = setup["no"] + model.to(device) + model.eval() + for batch in dataloader: + batch.to(device) + with torch.inference_mode(): + outputs = model(**batch) + preds = outputs.logits.argmax(dim=-1) + metric.add_batch(predictions=preds, references=batch["labels"]) + baseline = metric.compute() + + # Then do distributed + model, dataloader, device = setup["ddp"] + model.eval() + for batch in dataloader: + with torch.inference_mode(): + outputs = model(**batch) + preds = outputs.logits.argmax(dim=-1) + references = batch["labels"] + preds, references = accelerator.gather_for_metrics((preds, references)) + metric.add_batch(predictions=preds, references=references) + distributed = metric.compute() + + for key in "accuracy f1".split(): + assert math.isclose( + baseline[key], distributed[key] + ), f"Baseline and Distributed are not the same for key {key}:\n\tBaseline: {baseline[key]}\n\tDistributed: {distributed[key]}\n" + + +def test_gather_for_metrics_with_non_tensor_objects_iterable_dataset(): + class DummyIterableDataset(IterableDataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __iter__(self): + yield from self.data + + iterable_dataset = DummyIterableDataset([n for n in range(30)]) + dataloader = DataLoader(iterable_dataset, batch_size=4) + accelerator = Accelerator() + prepared_dataloader = accelerator.prepare(dataloader) + + if accelerator.is_main_process: + logger = logging.root.manager.loggerDict["accelerate.accelerator"] + list_handler = ListHandler() + logger.addHandler(list_handler) + + batches_for_metrics = [] + for batch in prepared_dataloader: + batches_for_metrics.append(accelerator.gather_for_metrics(batch)) + + assert torch.cat(batches_for_metrics).size(0) == 30 + + if accelerator.is_main_process: + assert len(list_handler.logs) == 0 + logger.removeHandler(list_handler) + + +def test_gather_for_metrics_with_iterable_dataset(): + class DummyIterableDataset(IterableDataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __iter__(self): + yield from self.data + + iterable_dataset = DummyIterableDataset(torch.as_tensor(range(30))) + dataloader = DataLoader(iterable_dataset, batch_size=4) + + accelerator = Accelerator() + prepared_dataloader = accelerator.prepare(dataloader) + + assert isinstance(prepared_dataloader, DataLoaderDispatcher) + + if accelerator.is_main_process: + logger = logging.root.manager.loggerDict["accelerate.accelerator"] + list_handler = ListHandler() + logger.addHandler(list_handler) + + batches_for_metrics = [] + for batch in prepared_dataloader: + batches_for_metrics.append(accelerator.gather_for_metrics(batch)) + + assert torch.cat(batches_for_metrics).size(0) == 30 + + if accelerator.is_main_process: + assert len(list_handler.logs) == 0 + + logger.removeHandler(list_handler) + + +def test_gather_for_metrics_drop_last(): + accelerator = Accelerator() + per_device_batch_size = 5 + num_items = (10 * accelerator.num_processes) + 1 + dataloader = DataLoader(range(num_items), batch_size=per_device_batch_size, 
drop_last=True) + dataloader = accelerator.prepare(dataloader) + + iterator = iter(dataloader) + next(iterator) # Skip first batch tensor([0, 1, 2, 3, 4], device='cuda:0') + batch = next(iterator) + gathered_items = accelerator.gather_for_metrics(batch) + + # Should return a full set of complete batches from each GPU + num_expected_items = per_device_batch_size * accelerator.num_processes + assert gathered_items.size(0) == ( + num_expected_items + ), f"Expected number of items: {num_expected_items}, Actual: {gathered_items.size(0)}" + + +def main(): + dataloader_config = DataLoaderConfiguration(split_batches=False, dispatch_batches=False) + accelerator = Accelerator(dataloader_config=dataloader_config) + if accelerator.is_local_main_process: + datasets.utils.logging.set_verbosity_warning() + transformers.utils.logging.set_verbosity_warning() + else: + datasets.utils.logging.set_verbosity_error() + transformers.utils.logging.set_verbosity_error() + # TorchXLA does not support batch dispatching. 'put_on_device' is always False for + # TorchXLA, which can cause a value error in 'prepare_data_loader' function. + dispatch_batches_options = [False] if accelerator.state.distributed_type == DistributedType.XLA else [True, False] + + # Temporarily close this test for TorchXLA due to the 'Cannot set version_counter for + # inference tensor' error in inference mode. Reopen it after TorchXLA fixes this bug. + # These are a bit slower so they should only be ran on the GPU or TPU + if accelerator.device.type != "cpu" and not is_torch_xla_available(): + if accelerator.is_local_main_process: + print("**Testing gather_for_metrics**") + for split_batches in [True, False]: + for dispatch_batches in dispatch_batches_options: + if accelerator.is_local_main_process: + print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`") + test_mrpc(dispatch_batches, split_batches) + accelerator.state._reset_state() + print("test_gather_for_metrics_with_iterable_dataset") + test_gather_for_metrics_with_iterable_dataset() + print("test gather_for_metrics_with_non_tensor_objects_iterable_dataset") + test_gather_for_metrics_with_non_tensor_objects_iterable_dataset() + + # MpDeviceLoader in TorchXLA is an asynchronous loader that preloads several batches into cache. + # This can cause the 'end_of_dataloader' of DataLoaderStateMixin to be set earlier than intended. + # Skip this test when TorchXLA is enabled. 
+ if accelerator.state.distributed_type != DistributedType.XLA: + if accelerator.is_local_main_process: + print("**Test torch metrics**") + for split_batches in [True, False]: + for dispatch_batches in dispatch_batches_options: + dataloader_config = DataLoaderConfiguration( + split_batches=split_batches, dispatch_batches=dispatch_batches + ) + accelerator = Accelerator(dataloader_config=dataloader_config) + if accelerator.is_local_main_process: + print(f"With: `split_batches={split_batches}`, `dispatch_batches={dispatch_batches}`, length=99") + test_torch_metrics(accelerator, 99) + accelerator.state._reset_state() + if accelerator.is_local_main_process: + print("**Test last batch is not dropped when perfectly divisible**") + accelerator = Accelerator() + test_torch_metrics(accelerator, 512) + accelerator.state._reset_state() + if accelerator.is_local_main_process: + print("**Test that `drop_last` is taken into account**") + test_gather_for_metrics_drop_last() + accelerator.state._reset_state() + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py new file mode 100644 index 0000000000000000000000000000000000000000..aeb55f6c87d7831ed5d7f370a4b9d7810777bd3e --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_peak_memory_usage.py @@ -0,0 +1,282 @@ +# Copyright 2022 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
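+
+# Illustrative note (added for clarity, not part of the upstream script): the
+# peak-memory numbers reported below come from the `TorchTracemalloc` context
+# manager defined further down, roughly in this (assumed) usage pattern:
+#
+#     with TorchTracemalloc() as tracemalloc:
+#         run_one_epoch()                      # hypothetical training step
+#     # memory newly consumed and peak memory, both in MB relative to `begin`
+#     print(tracemalloc.used, tracemalloc.peaked)
+#     # absolute peak in MB = peak delta + what was already allocated at entry
+#     print(tracemalloc.peaked + b2mb(tracemalloc.begin))
+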
+import argparse
+import gc
+import json
+import os
+
+import torch
+from datasets import load_dataset
+from torch.optim import AdamW
+from torch.utils.data import DataLoader
+from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed
+
+from accelerate import Accelerator, DistributedType
+from accelerate.utils import is_mlu_available, is_npu_available, is_xpu_available
+from accelerate.utils.deepspeed import DummyOptim, DummyScheduler
+
+
+MAX_GPU_BATCH_SIZE = 16
+EVAL_BATCH_SIZE = 32
+
+
+# Converting Bytes to Megabytes
+def b2mb(x):
+    return int(x / 2**20)
+
+
+# This context manager is used to track the peak memory usage of the process
+class TorchTracemalloc:
+    def __enter__(self):
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = torch.cuda.memory_allocated()
+        elif is_mlu_available():
+            torch.mlu.empty_cache()
+            torch.mlu.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = torch.mlu.memory_allocated()
+        elif is_npu_available():
+            torch.npu.empty_cache()
+            torch.npu.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = torch.npu.memory_allocated()
+        elif is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = torch.xpu.memory_allocated()
+        return self
+
+    def __exit__(self, *exc):
+        gc.collect()
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            self.end = torch.cuda.memory_allocated()
+            self.peak = torch.cuda.max_memory_allocated()
+        elif is_mlu_available():
+            torch.mlu.empty_cache()
+            self.end = torch.mlu.memory_allocated()
+            self.peak = torch.mlu.max_memory_allocated()
+        elif is_npu_available():
+            torch.npu.empty_cache()
+            self.end = torch.npu.memory_allocated()
+            self.peak = torch.npu.max_memory_allocated()
+        elif is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = torch.xpu.memory_allocated()
+            self.peak = torch.xpu.max_memory_allocated()
+        self.used = b2mb(self.end - self.begin)
+        self.peaked = b2mb(self.peak - self.begin)
+        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")
+
+
+def get_dataloaders(
+    accelerator: Accelerator,
+    batch_size: int = 16,
+    model_name: str = "bert-base-cased",
+    n_train: int = 320,
+    n_val: int = 160,
+):
+    """
+    Creates a set of `DataLoader`s for the `glue` dataset.
+
+    Args:
+        accelerator (`Accelerator`):
+            An `Accelerator` object
+        batch_size (`int`, *optional*):
+            The batch size for the train and validation DataLoaders.
+        model_name (`str`, *optional*):
+            The name of the model to use.
+        n_train (`int`, *optional*):
+            The number of training examples to use.
+        n_val (`int`, *optional*):
+            The number of validation examples to use.
+ """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset( + "glue", "mrpc", split={"train": f"train[:{n_train}]", "validation": f"validation[:{n_val}]"} + ) + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def training_function(config, args): + # Initialize accelerator + accelerator = Accelerator() + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name, args.n_train, args.n_val) + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True) + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + if accelerator.state.deepspeed_plugin is not None: + gradient_accumulation_steps = accelerator.state.deepspeed_plugin.deepspeed_config[ + "gradient_accumulation_steps" + ] + else: + gradient_accumulation_steps = 1 + max_training_steps = (len(train_dataloader) * num_epochs) // gradient_accumulation_steps + + # Instantiate scheduler + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. 
+ model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We need to keep track of how many total steps we have iterated over + overall_step = 0 + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + + # Now we train the model + train_total_peak_memory = {} + for epoch in range(starting_epoch, num_epochs): + with TorchTracemalloc() as tracemalloc: + model.train() + for step, batch in enumerate(train_dataloader): + outputs = model(**batch) + loss = outputs.loss + loss = loss / gradient_accumulation_steps + accelerator.backward(loss) + if step % gradient_accumulation_steps == 0: + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + overall_step += 1 + + # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage + accelerator.print(f"Memory before entering the train : {b2mb(tracemalloc.begin)}") + accelerator.print(f"Memory consumed at the end of the train (end-begin): {tracemalloc.used}") + accelerator.print(f"Peak Memory consumed during the train (max-begin): {tracemalloc.peaked}") + accelerator.print( + f"Total Peak Memory consumed during the train (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}" + ) + train_total_peak_memory[f"epoch-{epoch}"] = tracemalloc.peaked + b2mb(tracemalloc.begin) + if args.peak_memory_upper_bound is not None: + assert ( + train_total_peak_memory[f"epoch-{epoch}"] <= args.peak_memory_upper_bound + ), "Peak memory usage exceeded the upper bound" + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, "peak_memory_utilization.json"), "w") as f: + json.dump(train_total_peak_memory, f) + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.", + ) + parser.add_argument( + "--peak_memory_upper_bound", + type=float, + default=None, + help="The upper bound of peak memory usage in MB. If set, the training will throw an error if the peak memory usage exceeds this value.", + ) + parser.add_argument( + "--n_train", + type=int, + default=320, + help="Number of training examples to use.", + ) + parser.add_argument( + "--n_val", + type=int, + default=160, + help="Number of validation examples to use.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=1, + help="Number of train epochs.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} + training_function(config, args) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_performance.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_performance.py new file mode 100644 index 0000000000000000000000000000000000000000..7051859aa74bbac5b15e4465395b8177e3dd1d27 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_performance.py @@ -0,0 +1,243 @@ +# Copyright 2022 The HuggingFace Inc. team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import json +import os + +import evaluate +import torch +from datasets import load_dataset +from torch.optim import AdamW +from torch.utils.data import DataLoader +from transformers import AutoModelForSequenceClassification, AutoTokenizer, get_linear_schedule_with_warmup, set_seed + +from accelerate import Accelerator, DistributedType +from accelerate.utils.deepspeed import DummyOptim, DummyScheduler + + +MAX_GPU_BATCH_SIZE = 16 +EVAL_BATCH_SIZE = 32 + + +def get_dataloaders(accelerator: Accelerator, batch_size: int = 16, model_name: str = "bert-base-cased"): + """ + Creates a set of `DataLoader`s for the `glue` dataset. + + Args: + accelerator (`Accelerator`): + An `Accelerator` object + batch_size (`int`, *optional*): + The batch size for the train and validation DataLoaders. + model_name (`str`, *optional*): + """ + tokenizer = AutoTokenizer.from_pretrained(model_name) + datasets = load_dataset("glue", "mrpc") + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer(examples["sentence1"], examples["sentence2"], truncation=True, max_length=None) + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, batched=True, remove_columns=["idx", "sentence1", "sentence2"], load_from_cache_file=False + ) + + # We also rename the 'label' column to 'labels' which is the expected name for labels by the models of the + # transformers library + tokenized_datasets = tokenized_datasets.rename_column("label", "labels") + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. 
+ train_dataloader = DataLoader( + tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=batch_size + ) + eval_dataloader = DataLoader( + tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE + ) + + return train_dataloader, eval_dataloader + + +def training_function(config, args): + # Initialize accelerator + accelerator = Accelerator() + + # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs + lr = config["lr"] + num_epochs = int(config["num_epochs"]) + seed = int(config["seed"]) + batch_size = int(config["batch_size"]) + model_name = args.model_name_or_path + + set_seed(seed) + train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size, model_name) + + # Instantiate the model (we build the model here so that the seed also control new weights initialization) + model = AutoModelForSequenceClassification.from_pretrained(model_name, return_dict=True) + + # Instantiate optimizer + optimizer_cls = ( + AdamW + if accelerator.state.deepspeed_plugin is None + or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config + else DummyOptim + ) + optimizer = optimizer_cls(params=model.parameters(), lr=lr) + + max_training_steps = len(train_dataloader) * num_epochs + + # Instantiate scheduler + linear_decay_scheduler = False + if ( + accelerator.state.deepspeed_plugin is None + or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config + ): + lr_scheduler = get_linear_schedule_with_warmup( + optimizer=optimizer, + num_warmup_steps=0, + num_training_steps=max_training_steps, + ) + linear_decay_scheduler = True + else: + lr_scheduler = DummyScheduler(optimizer, total_num_steps=max_training_steps, warmup_num_steps=0) + + # Prepare everything + # There is no specific order to remember, we just need to unpack the objects in the same order we gave them to the + # prepare method. + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare( + model, optimizer, train_dataloader, eval_dataloader, lr_scheduler + ) + + # We also need to keep track of the stating epoch so files are named properly + starting_epoch = 0 + + # Now we train the model + metric = evaluate.load("glue", "mrpc") + best_performance = 0 + performance_metric = {} + expected_lr_after_first_optim_step = lr * ( + 1 - 1 / (max_training_steps / accelerator.num_processes / accelerator.gradient_accumulation_steps) + ) + lr_scheduler_check_completed = False + for epoch in range(starting_epoch, num_epochs): + model.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # assert the learning rate after first optimizer step + if ( + accelerator.sync_gradients + and not lr_scheduler_check_completed + and linear_decay_scheduler + and accelerator.state.mixed_precision == "no" + ): + assert ( + lr_scheduler.get_last_lr()[0] == expected_lr_after_first_optim_step + ), f"Wrong lr found at second step, expected {expected_lr_after_first_optim_step}, got {lr_scheduler.get_last_lr()[0]}" + lr_scheduler_check_completed = True + + model.eval() + samples_seen = 0 + for step, batch in enumerate(eval_dataloader): + # We could avoid this line since we set the accelerator with `device_placement=True`. 
+ batch.to(accelerator.device) + with torch.no_grad(): + outputs = model(**batch) + predictions = outputs.logits.argmax(dim=-1) + # It is slightly faster to call this once, than multiple times + predictions, references = accelerator.gather( + (predictions, batch["labels"]) + ) # If we are in a multiprocess environment, the last batch has duplicates + if accelerator.use_distributed: + if step == len(eval_dataloader) - 1: + predictions = predictions[: len(eval_dataloader.dataset) - samples_seen] + references = references[: len(eval_dataloader.dataset) - samples_seen] + else: + samples_seen += references.shape[0] + metric.add_batch( + predictions=predictions, + references=references, + ) + + eval_metric = metric.compute() + # Use accelerator.print to print only on the main process. + accelerator.print(f"epoch {epoch}:", eval_metric) + performance_metric[f"epoch-{epoch}"] = eval_metric["accuracy"] + + if best_performance < eval_metric["accuracy"]: + best_performance = eval_metric["accuracy"] + + # check that the LR is 0 + if linear_decay_scheduler and accelerator.state.mixed_precision == "no": + assert ( + lr_scheduler.get_last_lr()[0] == 0 + ), f"Wrong lr found at last step, expected 0, got {lr_scheduler.get_last_lr()[0]}" + + if args.performance_lower_bound is not None: + assert ( + args.performance_lower_bound <= best_performance + ), f"Best performance metric {best_performance} is lower than the lower bound {args.performance_lower_bound}" + + accelerator.wait_for_everyone() + if accelerator.is_main_process: + with open(os.path.join(args.output_dir, "all_results.json"), "w") as f: + json.dump(performance_metric, f) + + +def main(): + parser = argparse.ArgumentParser(description="Simple example of training script tracking peak GPU memory usage.") + parser.add_argument( + "--model_name_or_path", + type=str, + default="bert-base-cased", + help="Path to pretrained model or model identifier from huggingface.co/models.", + required=False, + ) + parser.add_argument( + "--output_dir", + type=str, + default=".", + help="Optional save directory where all checkpoint folders will be stored. Default is the current working directory.", + ) + parser.add_argument( + "--performance_lower_bound", + type=float, + default=None, + help="Optional lower bound for the performance metric. If set, the training will throw error when the performance metric drops below this value.", + ) + parser.add_argument( + "--num_epochs", + type=int, + default=3, + help="Number of train epochs.", + ) + args = parser.parse_args() + config = {"lr": 2e-5, "num_epochs": args.num_epochs, "seed": 42, "batch_size": 16} + training_function(config, args) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_pippy.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_pippy.py new file mode 100644 index 0000000000000000000000000000000000000000..f589365649d56fd690b4f4104a8838f885183527 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_pippy.py @@ -0,0 +1,129 @@ +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from torchvision.models import resnet34 +from transformers import ( + BertConfig, + BertForMaskedLM, + GPT2Config, + GPT2ForSequenceClassification, + T5Config, + T5ForConditionalGeneration, +) + +from accelerate import PartialState +from accelerate.inference import prepare_pippy +from accelerate.utils import DistributedType, send_to_device, set_seed + + +model_to_config = { + "t5": (T5ForConditionalGeneration, T5Config, 1024), + "bert": (BertForMaskedLM, BertConfig, 512), + "gpt2": (GPT2ForSequenceClassification, GPT2Config, 1024), +} + + +def get_model_and_data_for_text(model_name, device, num_processes: int = 2): + initializer, config, seq_len = model_to_config[model_name] + config_args = {} + # Eventually needed for batch inference tests on gpt-2 when bs != 1 + # if model_name == "gpt2": + # config_args["pad_token_id"] = 0 + model_config = config(**config_args) + model = initializer(model_config) + return model, torch.randint( + low=0, + high=model_config.vocab_size, + size=(num_processes, seq_len), + device=device, + dtype=torch.int64, + requires_grad=False, + ) + + +def test_gpt2(batch_size: int = 2): + set_seed(42) + state = PartialState() + model, inputs = get_model_and_data_for_text("gpt2", "cpu", batch_size) + model = prepare_pippy(model, example_args=(inputs,), no_split_module_classes=model._no_split_modules) + # For inference args need to be a tuple + inputs = inputs.to("cuda") + with torch.no_grad(): + output = model(inputs) + # Zach: Check that we just grab the real outputs we need at the end + if not state.is_last_process: + assert output is None, "Output was not generated on just the last process!" + else: + assert output is not None, "Output was not generated in the last process!" + + +def test_t5(batch_size: int = 2): + set_seed(42) + state = PartialState() + model, inputs = get_model_and_data_for_text("t5", "cpu", batch_size) + example_inputs = {"input_ids": inputs, "decoder_input_ids": inputs} + model = prepare_pippy( + model, + no_split_module_classes=model._no_split_modules, + example_kwargs=example_inputs, + ) + # For inference args need to be a tuple + inputs = send_to_device(example_inputs, "cuda:0") + with torch.no_grad(): + output = model(*inputs.values()) + # Zach: Check that we just grab the real outputs we need at the end + if not state.is_last_process: + assert output is None, "Output was not generated on just the last process!" + else: + assert output is not None, "Output was not generated in the last process!" + + +def test_resnet(batch_size: int = 2): + set_seed(42) + state = PartialState() + model = resnet34() + input_tensor = torch.rand(batch_size, 3, 224, 224) + model = prepare_pippy( + model, + example_args=(input_tensor,), + ) + inputs = send_to_device(input_tensor, "cuda:0") + with torch.no_grad(): + output = model(inputs) + # Zach: Check that we just grab the real outputs we need at the end + if not state.is_last_process: + assert output is None, "Output was not generated on just the last process!" + else: + assert output is not None, "Output was not generated in the last process!" 
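+
+# Illustrative sketch (added for clarity, not part of the upstream test): the
+# pattern shared by the three tests above. The function name and arguments
+# below are assumptions.
+def _pippy_inference_sketch(model, example_input, state):
+    # `prepare_pippy` splits the model into pipeline stages across the available
+    # processes; every rank runs the forward pass, but only the last rank
+    # receives the final output (all other ranks get `None`).
+    wrapped = prepare_pippy(model, example_args=(example_input,))
+    example_input = send_to_device(example_input, "cuda:0")
+    with torch.no_grad():
+        output = wrapped(example_input)
+    # Mirrors the assertions above: only `state.is_last_process` sees a result.
+    return output if state.is_last_process else None
+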
+ + +if __name__ == "__main__": + state = PartialState() + state.print("Testing pippy integration...") + if state.distributed_type == DistributedType.MULTI_GPU: + state.print("Testing GPT2...") + test_gpt2() + # Issue: When modifying the tokenizer for batch GPT2 inference, there's an issue + # due to references + # NameError: cannot access free variable 'chunk_args_list' where it is not associated with a value in enclosing scope + # test_gpt2(3) + state.print("Testing T5...") + test_t5() + test_t5(1) + test_t5(3) + state.print("Testing CV model...") + test_resnet() + test_resnet(3) + else: + print("Less than two GPUs found, not running tests!") diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py new file mode 100644 index 0000000000000000000000000000000000000000..67e78a7d37c0b82113e1cdbb3e76987b24c8494f --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/external_deps/test_zero3_integration.py @@ -0,0 +1,52 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.distributed + +from accelerate.test_utils import require_huggingface_suite +from accelerate.utils import is_transformers_available + + +if is_transformers_available(): + from transformers import AutoModel, TrainingArguments + + +GPT2_TINY = "sshleifer/tiny-gpt2" + + +@require_huggingface_suite +def init_torch_dist_then_launch_deepspeed(): + torch.distributed.init_process_group(backend="nccl") + deepspeed_config = { + "zero_optimization": { + "stage": 3, + }, + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + } + train_args = TrainingArguments( + output_dir="./", + deepspeed=deepspeed_config, + ) + model = AutoModel.from_pretrained(GPT2_TINY) + assert train_args is not None + assert model is not None + + +def main(): + init_torch_dist_then_launch_deepspeed() + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/test_cli.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_cli.py new file mode 100644 index 0000000000000000000000000000000000000000..c85828cd49624372ae1866082e5580c60f8c9293 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_cli.py @@ -0,0 +1,26 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch + + +def main(): + if torch.cuda.is_available(): + num_gpus = torch.cuda.device_count() + else: + num_gpus = 0 + print(f"Successfully ran on {num_gpus} GPUs") + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/test_distributed_data_loop.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_distributed_data_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..17d577c58ac2e0bea6e63c54b464eef483de12a8 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import warnings +from typing import List +from unittest.mock import Mock + +import torch +from torch.utils.data import DataLoader, IterableDataset, TensorDataset + +from accelerate.accelerator import Accelerator, DataLoaderConfiguration +from accelerate.utils.dataclasses import DistributedType + + +class DummyIterableDataset(IterableDataset): + def __init__(self, data): + self.data = data + + def __iter__(self): + yield from self.data + + +def create_accelerator(even_batches=True): + dataloader_config = DataLoaderConfiguration(even_batches=even_batches) + accelerator = Accelerator(dataloader_config=dataloader_config) + assert accelerator.num_processes == 2, "this script expects that two GPUs are available" + return accelerator + + +def create_dataloader(accelerator: Accelerator, dataset_size: int, batch_size: int, iterable: bool = False): + """ + Create a simple DataLoader to use during the test cases + """ + if iterable: + dataset = DummyIterableDataset(torch.as_tensor(range(dataset_size))) + else: + dataset = TensorDataset(torch.as_tensor(range(dataset_size))) + + dl = DataLoader(dataset, batch_size=batch_size) + dl = accelerator.prepare(dl) + + return dl + + +def verify_dataloader_batch_sizes( + accelerator: Accelerator, + dataset_size: int, + batch_size: int, + process_0_expected_batch_sizes: List[int], + process_1_expected_batch_sizes: List[int], +): + """ + A helper function for verifying the batch sizes coming from a prepared dataloader in each process + """ + dl = create_dataloader(accelerator=accelerator, dataset_size=dataset_size, batch_size=batch_size) + + batch_sizes = [len(batch[0]) for batch in dl] + + if accelerator.process_index == 0: + assert batch_sizes == process_0_expected_batch_sizes + elif accelerator.process_index == 1: + assert batch_sizes == process_1_expected_batch_sizes + + +def test_default_ensures_even_batch_sizes(): + accelerator = create_accelerator() + + # without padding, we would expect a different number of batches + verify_dataloader_batch_sizes( + accelerator, + dataset_size=3, + batch_size=1, + process_0_expected_batch_sizes=[1, 1], + process_1_expected_batch_sizes=[1, 1], + ) + + # without padding, we would expect the same number of batches, but different sizes + verify_dataloader_batch_sizes( + accelerator, 
+ dataset_size=7, + batch_size=2, + process_0_expected_batch_sizes=[2, 2], + process_1_expected_batch_sizes=[2, 2], + ) + + +def test_can_disable_even_batches(): + accelerator = create_accelerator(even_batches=False) + + verify_dataloader_batch_sizes( + accelerator, + dataset_size=3, + batch_size=1, + process_0_expected_batch_sizes=[1, 1], + process_1_expected_batch_sizes=[1], + ) + + verify_dataloader_batch_sizes( + accelerator, + dataset_size=7, + batch_size=2, + process_0_expected_batch_sizes=[2, 2], + process_1_expected_batch_sizes=[2, 1], + ) + + +def test_can_join_uneven_inputs(): + accelerator = create_accelerator(even_batches=False) + + model = torch.nn.Linear(1, 1) + ddp_model = accelerator.prepare(model) + + dl = create_dataloader(accelerator, dataset_size=3, batch_size=1) + + batch_idxs = [] + with accelerator.join_uneven_inputs([ddp_model]): + for batch_idx, batch in enumerate(dl): + output = ddp_model(batch[0].float()) + loss = output.sum() + loss.backward() + batch_idxs.append(batch_idx) + + accelerator.wait_for_everyone() + + if accelerator.process_index == 0: + assert batch_idxs == [0, 1] + elif accelerator.process_index == 1: + assert batch_idxs == [0] + + +def test_join_raises_warning_for_non_ddp_distributed(accelerator): + with warnings.catch_warnings(record=True) as w: + with accelerator.join_uneven_inputs([Mock()]): + pass + + assert issubclass(w[-1].category, UserWarning) + assert "only supported for multi-GPU" in str(w[-1].message) + + +def test_join_can_override_even_batches(): + default_even_batches = True + overridden_even_batches = False + accelerator = create_accelerator(even_batches=default_even_batches) + model = torch.nn.Linear(1, 1) + ddp_model = accelerator.prepare(model) + train_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1) + valid_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1) + + with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches): + train_dl_overridden_value = train_dl.batch_sampler.even_batches + valid_dl_overridden_value = valid_dl.batch_sampler.even_batches + + assert train_dl_overridden_value == overridden_even_batches + assert valid_dl_overridden_value == overridden_even_batches + assert train_dl.batch_sampler.even_batches == default_even_batches + assert valid_dl.batch_sampler.even_batches == default_even_batches + + +def test_join_can_override_for_mixed_type_dataloaders(): + default_even_batches = True + overridden_even_batches = False + accelerator = create_accelerator(even_batches=default_even_batches) + model = torch.nn.Linear(1, 1) + ddp_model = accelerator.prepare(model) + create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True) + batch_dl = create_dataloader(accelerator, dataset_size=3, batch_size=1) + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + try: + with accelerator.join_uneven_inputs([ddp_model], even_batches=overridden_even_batches): + batch_dl_overridden_value = batch_dl.batch_sampler.even_batches + except AttributeError: + # ensure attribute error is not raised when processing iterable dl + raise AssertionError + + assert batch_dl_overridden_value == overridden_even_batches + assert batch_dl.batch_sampler.even_batches == default_even_batches + + +def test_join_raises_warning_for_iterable_when_overriding_even_batches(): + accelerator = create_accelerator() + model = torch.nn.Linear(1, 1) + ddp_model = accelerator.prepare(model) + create_dataloader(accelerator, dataset_size=3, batch_size=1, iterable=True) + + with 
warnings.catch_warnings(record=True) as w: + with accelerator.join_uneven_inputs([ddp_model], even_batches=False): + pass + + assert issubclass(w[-1].category, UserWarning) + assert "only supported for map-style datasets" in str(w[-1].message) + + +def main(): + accelerator = create_accelerator() + + accelerator.print("Test that even_batches variable ensures uniform batches across processes") + test_default_ensures_even_batch_sizes() + + accelerator.print("Run tests with even_batches disabled") + test_can_disable_even_batches() + + accelerator.print("Test joining uneven inputs") + test_can_join_uneven_inputs() + + accelerator.print("Test overriding even_batches when joining uneven inputs") + test_join_can_override_even_batches() + + accelerator.print("Test overriding even_batches for mixed dataloader types") + test_join_can_override_for_mixed_type_dataloaders() + + accelerator.print("Test overriding even_batches raises a warning for iterable dataloaders") + test_join_raises_warning_for_iterable_when_overriding_even_batches() + + accelerator.print("Test join with non DDP distributed raises warning") + original_state = accelerator.state.distributed_type + accelerator.state.distributed_type = DistributedType.FSDP + test_join_raises_warning_for_non_ddp_distributed(accelerator) + accelerator.state.distributed_type = original_state + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/test_notebook.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_notebook.py new file mode 100644 index 0000000000000000000000000000000000000000..f0c073ac3eae5a86c351a5f8232b84bcdfb920a8 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_notebook.py @@ -0,0 +1,56 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Test file to ensure that in general certain situational setups for notebooks work. 
+""" + +import os + +from pytest import raises + +from accelerate import PartialState, notebook_launcher +from accelerate.test_utils import require_bnb +from accelerate.utils import is_bnb_available + + +def basic_function(): + # Just prints the PartialState + print(f"PartialState:\n{PartialState()}") + + +NUM_PROCESSES = int(os.environ.get("ACCELERATE_NUM_PROCESSES", 1)) + + +def test_can_initialize(): + notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES) + + +@require_bnb +def test_problematic_imports(): + with raises(RuntimeError, match="Please keep these imports"): + import bitsandbytes as bnb # noqa: F401 + + notebook_launcher(basic_function, (), num_processes=NUM_PROCESSES) + + +def main(): + print("Test basic notebook can be ran") + test_can_initialize() + if is_bnb_available(): + print("Test problematic imports (bnb)") + test_problematic_imports() + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/test_ops.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..1b18780fa70fdc2b8f579f07910c8682437459d5 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_ops.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python + +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from accelerate import PartialState +from accelerate.test_utils.testing import assert_exception +from accelerate.utils.dataclasses import DistributedType +from accelerate.utils.operations import ( + DistributedOperationException, + broadcast, + copy_tensor_to_devices, + gather, + gather_object, + pad_across_processes, + reduce, +) + + +def create_tensor(state): + return (torch.arange(state.num_processes) + 1.0 + (state.num_processes * state.process_index)).to(state.device) + + +def test_gather(state): + tensor = create_tensor(state) + gathered_tensor = gather(tensor) + assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1)) + + +def test_gather_object(state): + # Gather objects in TorchXLA is not supported. + if state.distributed_type == DistributedType.XLA: + return + obj = [state.process_index] + gathered_obj = gather_object(obj) + assert len(gathered_obj) == state.num_processes, f"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}" + assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}" + + +def test_gather_non_contigous(state): + # Skip this test because the 'is_contiguous' function of XLA tensor always returns True. 
+ if state.distributed_type == DistributedType.XLA: + return + # Create a non-contiguous tensor + tensor = torch.arange(12).view(4, 3).t().to(state.device) + assert not tensor.is_contiguous() + # Shouldn't error out + _ = gather(tensor) + + +def test_broadcast(state): + tensor = create_tensor(state) + broadcasted_tensor = broadcast(tensor) + assert broadcasted_tensor.shape == torch.Size([state.num_processes]) + assert broadcasted_tensor.tolist() == list(range(1, state.num_processes + 1)) + + +def test_pad_across_processes(state): + # We need to pad the tensor with one more element if we are the main process + # to ensure that we can pad + if state.is_main_process: + tensor = torch.arange(state.num_processes + 1).to(state.device) + else: + tensor = torch.arange(state.num_processes).to(state.device) + padded_tensor = pad_across_processes(tensor) + assert padded_tensor.shape == torch.Size([state.num_processes + 1]) + if not state.is_main_process: + assert padded_tensor.tolist() == list(range(0, state.num_processes)) + [0] + + +def test_reduce_sum(state): + # For now runs on only two processes + if state.num_processes != 2: + return + tensor = create_tensor(state) + reduced_tensor = reduce(tensor, "sum") + truth_tensor = torch.tensor([4.0, 6]).to(state.device) + assert torch.allclose(reduced_tensor, truth_tensor), f"{reduced_tensor} != {truth_tensor}" + + +def test_reduce_mean(state): + # For now runs on only two processes + if state.num_processes != 2: + return + tensor = create_tensor(state) + reduced_tensor = reduce(tensor, "mean") + truth_tensor = torch.tensor([2.0, 3]).to(state.device) + assert torch.allclose(reduced_tensor, truth_tensor), f"{reduced_tensor} != {truth_tensor}" + + +def test_op_checker(state): + # Must be in a distributed state, and gathering is currently not supported in TorchXLA. 
+ if state.distributed_type in [DistributedType.NO, DistributedType.XLA]: + return + state.debug = True + # `pad_across_processes` + if state.process_index == 0: + data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)} + else: + data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4, 5]]]).to(state.device)} + + with assert_exception(DistributedOperationException): + pad_across_processes(data, dim=0) + + # `reduce` + if state.process_index == 0: + data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)} + else: + data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)} + + with assert_exception(DistributedOperationException): + reduce(data) + + # `broadcast` + if state.process_index == 0: + data = {"tensor": torch.tensor([[0.0, 1, 2, 3, 4]]).to(state.device)} + else: + data = {"tensor": torch.tensor([[[0.0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]]).to(state.device)} + + with assert_exception(DistributedOperationException): + broadcast(data) + + state.debug = False + + +def test_copy_tensor_to_devices(state): + if state.distributed_type not in [DistributedType.MULTI_GPU, DistributedType.XLA]: + return + if state.is_main_process: + tensor = torch.tensor([1, 2, 3], dtype=torch.int).to(state.device) + else: + tensor = None + tensor = copy_tensor_to_devices(tensor) + assert torch.allclose(tensor, torch.tensor([1, 2, 3], dtype=torch.int, device=state.device)) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +def main(): + state = PartialState() + state.print(f"State: {state}") + state.print("testing gather") + test_gather(state) + state.print("testing gather_object") + test_gather_object(state) + state.print("testing gather non-contigous") + test_gather_non_contigous(state) + state.print("testing broadcast") + test_broadcast(state) + state.print("testing pad_across_processes") + test_pad_across_processes(state) + state.print("testing reduce_sum") + test_reduce_sum(state) + state.print("testing reduce_mean") + test_reduce_mean(state) + state.print("testing op_checker") + test_op_checker(state) + state.print("testing sending tensors across devices") + test_copy_tensor_to_devices(state) + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/test_script.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_script.py new file mode 100644 index 0000000000000000000000000000000000000000..a982612e4463eb807f272b7a65093ac23008861a --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_script.py @@ -0,0 +1,802 @@ +#!/usr/bin/env python + +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import contextlib +import io +import math +import time +from copy import deepcopy +from pathlib import Path + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset + +from accelerate import Accelerator +from accelerate.data_loader import SeedableRandomSampler, prepare_data_loader +from accelerate.state import AcceleratorState +from accelerate.test_utils import RegressionDataset, are_the_same_tensors +from accelerate.utils import ( + DataLoaderConfiguration, + DistributedType, + gather, + is_bf16_available, + is_datasets_available, + is_ipex_available, + is_mlu_available, + is_npu_available, + is_xpu_available, + set_seed, + synchronize_rng_states, +) + + +# TODO: remove RegressionModel4XPU once ccl support empty buffer in broadcasting. +if is_xpu_available(): + from accelerate.test_utils import RegressionModel4XPU as RegressionModel +else: + from accelerate.test_utils import RegressionModel + + +def generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler=False): + "Creates a dataloader that can also use the `SeedableRandomSampler`" + if use_seedable_sampler: + # The SeedableRandomSampler is needed during distributed setups + # for full reproducability across processes with the `DataLoader` + sampler = SeedableRandomSampler( + generator=generator, + data_source=train_set, + num_samples=len(train_set), + ) + return DataLoader(train_set, batch_size=batch_size, sampler=sampler) + else: + return DataLoader(train_set, batch_size=batch_size, shuffle=True, generator=generator) + + +def print_main(state): + print(f"Printing from the main process {state.process_index}") + + +def print_local_main(state): + print(f"Printing from the local main process {state.local_process_index}") + + +def print_last(state): + print(f"Printing from the last process {state.process_index}") + + +def print_on(state, process_idx): + print(f"Printing from process {process_idx}: {state.process_index}") + + +def process_execution_check(): + accelerator = Accelerator() + num_processes = accelerator.num_processes + # Test main_process_first context manager + path = Path("check_main_process_first.txt") + with accelerator.main_process_first(): + if accelerator.is_main_process: + time.sleep(0.1) # ensure main process takes longest + with open(path, "a+") as f: + f.write("Currently in the main process\n") + else: + with open(path, "a+") as f: + f.write("Now on another process\n") + accelerator.wait_for_everyone() + + if accelerator.is_main_process: + with open(path) as f: + text = "".join(f.readlines()) + try: + assert text.startswith("Currently in the main process\n"), "Main process was not first" + if num_processes > 1: + assert text.endswith("Now on another process\n"), "Main process was not first" + assert ( + text.count("Now on another process\n") == accelerator.num_processes - 1 + ), f"Only wrote to file {text.count('Now on another process') + 1} times, not {accelerator.num_processes}" + except AssertionError: + path.unlink() + raise + + if accelerator.is_main_process and path.exists(): + path.unlink() + accelerator.wait_for_everyone() + # Test the decorators + f = io.StringIO() + with contextlib.redirect_stdout(f): + accelerator.on_main_process(print_main)(accelerator.state) + result = f.getvalue().rstrip() + if accelerator.is_main_process: + assert result == "Printing from the main process 0", f"{result} != Printing from the main process 0" + else: + assert f.getvalue().rstrip() == "", f'{result} != ""' + f.truncate(0) + f.seek(0) + + with 
contextlib.redirect_stdout(f): + accelerator.on_local_main_process(print_local_main)(accelerator.state) + if accelerator.is_local_main_process: + assert f.getvalue().rstrip() == "Printing from the local main process 0" + else: + assert f.getvalue().rstrip() == "" + f.truncate(0) + f.seek(0) + + with contextlib.redirect_stdout(f): + accelerator.on_last_process(print_last)(accelerator.state) + if accelerator.is_last_process: + assert f.getvalue().rstrip() == f"Printing from the last process {accelerator.state.num_processes - 1}" + else: + assert f.getvalue().rstrip() == "" + f.truncate(0) + f.seek(0) + + for process_idx in range(num_processes): + with contextlib.redirect_stdout(f): + accelerator.on_process(print_on, process_index=process_idx)(accelerator.state, process_idx) + if accelerator.process_index == process_idx: + assert f.getvalue().rstrip() == f"Printing from process {process_idx}: {accelerator.process_index}" + else: + assert f.getvalue().rstrip() == "" + f.truncate(0) + f.seek(0) + + +def init_state_check(): + # Test we can instantiate this twice in a row. + state = AcceleratorState() + if state.local_process_index == 0: + print("Testing, testing. 1, 2, 3.") + print(state) + + +def rng_sync_check(): + state = AcceleratorState() + synchronize_rng_states(["torch"]) + assert are_the_same_tensors(torch.get_rng_state()), "RNG states improperly synchronized on CPU." + if state.distributed_type == DistributedType.MULTI_GPU: + synchronize_rng_states(["cuda"]) + assert are_the_same_tensors(torch.cuda.get_rng_state()), "RNG states improperly synchronized on GPU." + elif state.distributed_type == DistributedType.MULTI_XPU: + synchronize_rng_states(["xpu"]) + assert are_the_same_tensors(torch.xpu.get_rng_state()), "RNG states improperly synchronized on XPU." + generator = torch.Generator() + synchronize_rng_states(["generator"], generator=generator) + assert are_the_same_tensors(generator.get_state()), "RNG states improperly synchronized in generator." + + if state.local_process_index == 0: + print("All rng are properly synched.") + + +def dl_preparation_check(): + state = AcceleratorState() + length = 32 * state.num_processes + + dl = DataLoader(range(length), batch_size=8) + dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result) + + print(state.process_index, result, type(dl)) + assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result." + + dl = DataLoader(range(length), batch_size=8) + dl = prepare_data_loader( + dl, + state.device, + state.num_processes, + state.process_index, + put_on_device=True, + split_batches=True, + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result) + assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result." + + if state.process_index == 0: + print("Non-shuffled dataloader passing.") + + dl = DataLoader(range(length), batch_size=8, shuffle=True) + dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index, put_on_device=True) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result).tolist() + result.sort() + assert result == list(range(length)), "Wrong shuffled dataloader result." 
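+    # Repeat the shuffled check with `split_batches=True`: each process now receives a
+    # slice of every batch rather than whole batches, but gathering across processes
+    # must still produce every index exactly once.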
+ + dl = DataLoader(range(length), batch_size=8, shuffle=True) + dl = prepare_data_loader( + dl, + state.device, + state.num_processes, + state.process_index, + put_on_device=True, + split_batches=True, + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result).tolist() + result.sort() + assert result == list(range(length)), "Wrong shuffled dataloader result." + + if state.local_process_index == 0: + print("Shuffled dataloader passing.") + + +def central_dl_preparation_check(): + state = AcceleratorState() + length = 32 * state.num_processes + + dl = DataLoader(range(length), batch_size=8) + dl = prepare_data_loader( + dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result) + assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result." + + dl = DataLoader(range(length), batch_size=8) + dl = prepare_data_loader( + dl, + state.device, + state.num_processes, + state.process_index, + put_on_device=True, + split_batches=True, + dispatch_batches=True, + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result) + assert torch.equal(result.cpu(), torch.arange(0, length).long()), "Wrong non-shuffled dataloader result." + + if state.process_index == 0: + print("Non-shuffled central dataloader passing.") + + dl = DataLoader(range(length), batch_size=8, shuffle=True) + dl = prepare_data_loader( + dl, state.device, state.num_processes, state.process_index, put_on_device=True, dispatch_batches=True + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result).tolist() + result.sort() + assert result == list(range(length)), "Wrong shuffled dataloader result." + + dl = DataLoader(range(length), batch_size=8, shuffle=True) + dl = prepare_data_loader( + dl, + state.device, + state.num_processes, + state.process_index, + put_on_device=True, + split_batches=True, + dispatch_batches=True, + ) + result = [] + for batch in dl: + result.append(gather(batch)) + result = torch.cat(result).tolist() + result.sort() + assert result == list(range(length)), "Wrong shuffled dataloader result." 
+ + if state.local_process_index == 0: + print("Shuffled central dataloader passing.") + + +def custom_sampler_check(): + state = AcceleratorState() + + class CustomDataset(Dataset): + def __init__(self, data): + self.data = data + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index] + + class CustomBatchSampler: + def __init__(self, dataset_length: int, batch_size: int, shuffle: bool = True): + self.batch_size = batch_size + self.data_index = np.arange(dataset_length) + self.shuffle = shuffle + + def __iter__(self): + num_batches = len(self) + if self.shuffle: + index = np.random.permutation(self.data_index) + else: + index = self.data_index + output = np.array_split(index, num_batches) + yield from output + + def __len__(self): + return math.ceil(len(self.data_index) / self.batch_size) + + dataset = CustomDataset(range(32 * state.num_processes)) + sampler = CustomBatchSampler(len(dataset), batch_size=8) + dl = DataLoader(dataset, batch_sampler=sampler) + dl = prepare_data_loader(dl, state.device, state.num_processes, state.process_index) + # We need just ensure that `dl.batch_sampler` (or `dl.batch_sampler.batch_sampler` is indeed the old batch sampler + if hasattr(dl.batch_sampler, "batch_sampler"): + assert isinstance( + dl.batch_sampler.batch_sampler, CustomBatchSampler + ), "Custom sampler was changed after calling `prepare_data_loader`" + else: + assert isinstance( + dl.batch_sampler, CustomBatchSampler + ), "Custom sampler was changed after calling `prepare_data_loader`" + + +def check_seedable_sampler(): + # Set seed + set_seed(42) + train_set = RegressionDataset(length=10, seed=42) + train_dl = DataLoader(train_set, batch_size=2, shuffle=True) + + config = DataLoaderConfiguration(use_seedable_sampler=True) + accelerator = Accelerator(dataloader_config=config) + train_dl = accelerator.prepare(train_dl) + original_items = [] + for _ in range(3): + for batch in train_dl: + original_items.append(batch["x"]) + original_items = torch.cat(original_items) + + # Set seed again and the epoch + set_seed(42) + train_dl.set_epoch(0) + new_items = [] + for _ in range(3): + for batch in train_dl: + new_items.append(batch["x"]) + new_items = torch.cat(new_items) + assert torch.allclose(original_items, new_items), "Did not obtain the same items with the same seed and epoch." + + +def check_seedable_sampler_in_batch_sampler_shard(): + set_seed(42) + + config = DataLoaderConfiguration(use_seedable_sampler=True) + accelerator = Accelerator(dataloader_config=config) + assert accelerator.num_processes > 1, "This test requires more than one process." + + dataloader = DataLoader(list(range(10)), batch_size=1, shuffle=True) + prepared_data_loader = prepare_data_loader( + dataloader=dataloader, + use_seedable_sampler=True, + ) + + target_sampler = prepared_data_loader.batch_sampler.batch_sampler.sampler + assert isinstance( + target_sampler, SeedableRandomSampler + ), "Sampler in BatchSamplerShard is not SeedableRandomSampler." 
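The two seedable-sampler checks above encode the reproducibility contract of `DataLoaderConfiguration(use_seedable_sampler=True)`: given the same seed and epoch, a prepared shuffled dataloader must return identical batches. Below is a condensed sketch of that usage, reusing helpers this script already imports; it is illustrative only and not part of the vendored file.

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from accelerate.test_utils import RegressionDataset
from accelerate.utils import DataLoaderConfiguration, set_seed


def collect_one_epoch(accelerator, seed=42):
    # Same recipe as `check_seedable_sampler`: seed, build a shuffled loader, prepare it,
    # pin the epoch, then record every batch.
    set_seed(seed)
    dataloader = DataLoader(RegressionDataset(length=10, seed=seed), batch_size=2, shuffle=True)
    dataloader = accelerator.prepare(dataloader)
    dataloader.set_epoch(0)
    return torch.cat([batch["x"] for batch in dataloader])


accelerator = Accelerator(dataloader_config=DataLoaderConfiguration(use_seedable_sampler=True))
# Two passes with the same seed and epoch should yield identical items.
assert torch.allclose(collect_one_epoch(accelerator), collect_one_epoch(accelerator))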
+ + +def mock_training(length, batch_size, generator, use_seedable_sampler=False): + set_seed(42) + generator.manual_seed(42) + train_set = RegressionDataset(length=length, seed=42) + + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + for epoch in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + loss.backward() + optimizer.step() + return train_set, model + + +def training_check(use_seedable_sampler=False): + state = AcceleratorState() + generator = torch.Generator() + batch_size = 8 + length = batch_size * 4 * state.num_processes + + train_set, old_model = mock_training(length, batch_size * state.num_processes, generator, use_seedable_sampler) + assert are_the_same_tensors(old_model.a), "Did not obtain the same model on both processes." + assert are_the_same_tensors(old_model.b), "Did not obtain the same model on both processes." + + accelerator = Accelerator() + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training." + assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training." + + accelerator.print("Training yielded the same results on one CPU or distributed setup with no batch split.") + + dataloader_config = DataLoaderConfiguration(split_batches=True, use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(dataloader_config=dataloader_config) + train_dl = generate_baseline_dataloader( + train_set, generator, batch_size * state.num_processes, use_seedable_sampler + ) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training." + assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training." 
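+    # At this point the `split_batches=True` run has matched the baseline weights from
+    # `mock_training` (up to `allclose` tolerance), which is what the print below reports.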
+ + accelerator.print("Training yielded the same results on one CPU or distributes setup with batch split.") + + if torch.cuda.is_available() or is_npu_available() or is_mlu_available(): + # Mostly a test that FP16 doesn't crash as the operation inside the model is not converted to FP16 + print("FP16 training check.") + AcceleratorState._reset_state() + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="fp16", dataloader_config=dataloader_config) + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training." + assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training." + + if torch.cuda.is_available(): + # Mostly a test that model.forward will have autocast when running unwrap_model(model, keep_fp32_wrapper=True) + print("Keep fp32 wrapper check.") + AcceleratorState._reset_state() + accelerator = Accelerator(mixed_precision="fp16") + + model = torch.nn.Linear(2, 4) + model = accelerator.prepare(model) + model_with_fp32_wrapper = accelerator.unwrap_model(model, keep_fp32_wrapper=True) + + # Run forward with fp16 as input. + # When the model is with mixed precision wrapper, no error will be raised. + input_tensor = torch.Tensor([1, 2]).to(dtype=torch.float16, device=accelerator.device) + output = model_with_fp32_wrapper(input_tensor) + + # BF16 support is only for CPU + TPU, and some GPU + if is_bf16_available(): + # Mostly a test that BF16 doesn't crash as the operation inside the model is not converted to BF16 + print("BF16 training check.") + AcceleratorState._reset_state() + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", dataloader_config=dataloader_config) + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training." + assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training." 
+ + # IPEX support is only for CPU + if is_ipex_available(): + print("ipex BF16 training check.") + AcceleratorState._reset_state() + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", cpu=True, dataloader_config=dataloader_config) + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on CPU or distributed training." + assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on CPU or distributed training." + + # XPU support is only for XPU + if is_xpu_available(): + print("xpu BF16 training check.") + AcceleratorState._reset_state() + dataloader_config = DataLoaderConfiguration(use_seedable_sampler=use_seedable_sampler) + accelerator = Accelerator(mixed_precision="bf16", cpu=False, dataloader_config=dataloader_config) + train_dl = generate_baseline_dataloader(train_set, generator, batch_size, use_seedable_sampler) + model = RegressionModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + + train_dl, model, optimizer = accelerator.prepare(train_dl, model, optimizer) + set_seed(42) + generator.manual_seed(42) + for _ in range(3): + for batch in train_dl: + model.zero_grad() + output = model(batch["x"]) + loss = torch.nn.functional.mse_loss(output, batch["y"]) + accelerator.backward(loss) + optimizer.step() + + model = accelerator.unwrap_model(model).cpu() + assert torch.allclose(old_model.a, model.a), "Did not obtain the same model on XPU or distributed training." + assert torch.allclose(old_model.b, model.b), "Did not obtain the same model on XPU or distributed training." + + +def test_split_between_processes_dataset(datasets_Dataset): + state = AcceleratorState() + data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes)]) + with state.split_between_processes(data, apply_padding=False) as results: + assert ( + len(results) == 2 + ), f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}" + + data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes - 1)]) + with state.split_between_processes(data, apply_padding=False) as results: + if state.is_last_process: + assert ( + len(results) == 1 + ), f"Last process did not receive a single item. Process index: {state.process_index}; Length: {len(results)}" + else: + assert ( + len(results) == 2 + ), f"One of the intermediate processes did not receive two items. Process index: {state.process_index}; Length: {len(results)}" + + data = datasets_Dataset.from_list([dict(k=v) for v in range(2 * state.num_processes - 1)]) + with state.split_between_processes(data, apply_padding=True) as results: + if state.num_processes == 1: + assert ( + len(results) == 1 + ), f"Single process did not receive a single item. Process index: {state.process_index}; Length: {len(results)}" + else: + assert ( + len(results) == 2 + ), f"Each process did not have two items. 
Process index: {state.process_index}; Length: {len(results)}" + + state.wait_for_everyone() + + +def test_split_between_processes_list(): + state = AcceleratorState() + data = list(range(0, 2 * state.num_processes)) + with state.split_between_processes(data) as results: + assert ( + len(results) == 2 + ), f"Each process did not have two items. Process index: {state.process_index}; Length: {len(results)}" + + data = list(range(0, (3 * state.num_processes) - 1)) + with state.split_between_processes(data, apply_padding=True) as results: + if state.is_last_process: + # Test that the last process gets the extra item(s) + num_samples_per_device = math.ceil(len(data) / state.num_processes) + assert ( + len(results) == num_samples_per_device + ), f"Last process did not get the extra item(s). Process index: {state.process_index}; Length: {len(results)}" + state.wait_for_everyone() + + +def test_split_between_processes_nested_dict(): + state = AcceleratorState() + a = [1, 2, 3, 4, 5, 6, 7, 8] + b = ["a", "b", "c", "d", "e", "f", "g", "h"] + c = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) + if state.num_processes in (1, 2, 4): + data = {"a": a, "b": b, "c": c} + data_copy = deepcopy(data) + with state.split_between_processes(data) as results: + if state.process_index == 0: + assert results["a"] == data_copy["a"][: 8 // state.num_processes] + elif state.num_processes == 2: + assert results["a"] == data_copy["a"][4:] + elif state.process_index == 3: + # We return a list each time + assert results["a"] == data_copy["a"][-2:], f'Expected: {data_copy["a"][-2]}, Actual: {results["a"]}' + if state.process_index == 0: + assert results["b"] == data_copy["b"][: 8 // state.num_processes] + elif state.num_processes == 2: + assert results["b"] == data_copy["b"][4:] + elif state.process_index == 3: + assert results["b"] == data_copy["b"][-2:] + if state.process_index == 0: + assert torch.allclose( + results["c"], data_copy["c"][: 8 // state.num_processes] + ), f"Did not obtain expected values on process 0, expected `{data['c'][:8 // state.num_processes]}`, received: {results['c']}" + elif state.num_processes == 2: + assert torch.allclose( + results["c"], data_copy["c"][4:] + ), f"Did not obtain expected values on process 2, expected `{data['c'][4:]}`, received: {results['c']}" + elif state.process_index == 3: + assert torch.allclose( + results["c"], data_copy["c"][-2:] + ), f"Did not obtain expected values on process 4, expected `{data['c'][-2:]}`, received: {results['c']}" + + state.wait_for_everyone() + + +def test_split_between_processes_tensor(): + state = AcceleratorState() + if state.num_processes > 1: + data = torch.tensor([[0, 1, 2, 3], [4, 5, 6, 7]]).to(state.device) + with state.split_between_processes(data) as results: + if state.process_index == 0: + assert torch.allclose(results, torch.tensor([0, 1, 2, 3]).to(state.device)) + else: + assert torch.allclose(results, torch.tensor([4, 5, 6, 7]).to(state.device)) + state.wait_for_everyone() + + +def test_trigger(): + accelerator = Accelerator() + # should start with being false + assert accelerator.check_trigger() is False + + # set a breakpoint on the main process + if accelerator.is_main_process: + accelerator.set_trigger() + + # check it's been activated across all processes + # calls `all_reduce` and triggers a sync + assert accelerator.check_trigger() is True + + # check it's been reset after the sync + assert accelerator.check_trigger() is False + + +def test_reinstantiated_state(): + import pytest + + AcceleratorState._reset_state() + simple_model = 
torch.nn.Linear(1, 1) + # First define an accelerator + accelerator = Accelerator() + # Then call `reset_state`, breaking the state existing in the accelerator + AcceleratorState._reset_state() + # Now try and prepare a simple model, should raise the custom error early + with pytest.raises(AttributeError) as cm: + accelerator.prepare(simple_model) + assert "`AcceleratorState` object has no attribute" in str(cm.value.args[0]) + assert "This happens if `AcceleratorState._reset_state()`" in str(cm.value.args[0]) + + +def main(): + accelerator = Accelerator() + state = accelerator.state + if state.local_process_index == 0: + print("**Initialization**") + init_state_check() + state.wait_for_everyone() + + if state.distributed_type == DistributedType.MULTI_GPU: + num_processes_per_node = torch.cuda.device_count() + else: + num_processes_per_node = state.num_processes + + # We only run this test on non-multinode + if num_processes_per_node == state.num_processes: + if state.process_index == 0: + print("\n**Test process execution**") + process_execution_check() + + if state.process_index == 0: + print("\n**Test split between processes as a list**") + test_split_between_processes_list() + + if state.process_index == 0: + print("\n**Test split between processes as a dict**") + test_split_between_processes_nested_dict() + + if state.process_index == 0: + print("\n**Test split between processes as a tensor**") + test_split_between_processes_tensor() + + if state.process_index == 0: + print("\n**Test split between processes as a datasets.Dataset**") + if is_datasets_available(): + from datasets import Dataset as datasets_Dataset + + test_split_between_processes_dataset(datasets_Dataset) + else: + print("Skipped because Hugging Face datasets is not available") + + if state.local_process_index == 0: + print("\n**Test random number generator synchronization**") + rng_sync_check() + + if state.local_process_index == 0: + print("\n**DataLoader integration test**") + dl_preparation_check() + if state.distributed_type != DistributedType.XLA: + central_dl_preparation_check() + custom_sampler_check() + check_seedable_sampler() + + if state.num_processes > 1: + check_seedable_sampler_in_batch_sampler_shard() + + # Trainings are not exactly the same in DeepSpeed and CPU mode + if state.distributed_type == DistributedType.DEEPSPEED: + return + + if state.local_process_index == 0: + print("\n**Training integration test**") + training_check(use_seedable_sampler=False) + training_check(use_seedable_sampler=True) + + if state.local_process_index == 0: + print("\n**Breakpoint trigger test**") + test_trigger() + + if state.local_process_index == 0: + print("\n**Test reinstantiated state**") + test_reinstantiated_state() + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/scripts/test_sync.py b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_sync.py new file mode 100644 index 0000000000000000000000000000000000000000..bd458bcab8aaa42409a7c1234a4afffb087e8a7c --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/scripts/test_sync.py @@ -0,0 +1,392 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy + +import torch +import torch.nn.functional as F +from torch.optim import AdamW +from torch.optim.lr_scheduler import LambdaLR +from torch.utils.data import DataLoader + +from accelerate.accelerator import Accelerator, GradientAccumulationPlugin +from accelerate.state import GradientState +from accelerate.test_utils import RegressionDataset, RegressionModel +from accelerate.utils import DistributedType, set_seed + + +def check_model_parameters(model_a, model_b, did_step, iteration, **kwargs): + for param, grad_param in zip(model_a.parameters(), model_b.parameters()): + if not param.requires_grad: + continue + if not did_step: + # Grads should not be in sync + assert ( + torch.allclose(param.grad, grad_param.grad, **kwargs) is False + ), f"Gradients in sync when they should not be at iteration {iteration}:\nmodel_a grad ({param.grad}) == model_b grad ({grad_param.grad})" + else: + # Grads should be in sync + assert ( + torch.allclose(param.grad, grad_param.grad, **kwargs) is True + ), f"Gradients not in sync when they should be at iteration {iteration}:\nmodel_a grad ({param.grad}) != model_b grad ({grad_param.grad})" + + +def step_model(model, input, target, accelerator, do_backward=True): + model.train() + output = model(input) + loss = F.mse_loss(output, target.to(output.device)) + if not do_backward: + loss /= accelerator.gradient_accumulation_steps + loss.backward() + else: + accelerator.backward(loss) + + +def get_training_setup(accelerator, sched=False): + "Returns everything needed to perform basic training" + set_seed(42) + model = RegressionModel() + ddp_model = deepcopy(model) + dset = RegressionDataset(length=80) + dataloader = DataLoader(dset, batch_size=16) + model.to(accelerator.device) + if sched: + opt = AdamW(params=model.parameters(), lr=1e-3) + ddp_opt = AdamW(params=ddp_model.parameters(), lr=1e-3) + sched = LambdaLR(opt, lr_lambda=lambda epoch: epoch**0.65) + ddp_sched = LambdaLR(ddp_opt, lr_lambda=lambda epoch: epoch**0.65) + # Make a copy of `model` + if sched: + ddp_model, ddp_opt, ddp_sched, dataloader = accelerator.prepare(ddp_model, ddp_opt, ddp_sched, dataloader) + else: + ddp_model, dataloader = accelerator.prepare(ddp_model, dataloader) + if sched: + return (model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched) + return model, ddp_model, dataloader + + +def test_noop_sync(accelerator): + # Test when on a single CPU or GPU that the context manager does nothing + model, ddp_model, dataloader = get_training_setup(accelerator) + # Use a single batch + ddp_input, ddp_target = next(iter(dataloader)).values() + for iteration in range(3): + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + # Perform our initial ground truth step in non "DDP" + step_model(model, input, target, accelerator) + # Do "gradient accumulation" (noop) + if iteration % 2 == 0: + # Accumulate grads locally + with accelerator.no_sync(ddp_model): + step_model(ddp_model, ddp_input, ddp_target, 
accelerator) + else: + # Sync grads + step_model(ddp_model, ddp_input, ddp_target, accelerator) + + # Since `no_sync` is a noop, `ddp_model` and `model` grads should always be in sync + check_model_parameters(model, ddp_model, True, iteration) + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + assert torch.allclose( + param.grad, ddp_param.grad + ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})" + + # Shuffle ddp_input on each iteration + torch.manual_seed(1337 + iteration) + ddp_input = ddp_input[torch.randperm(len(ddp_input))] + + +def test_distributed_sync(accelerator): + # Test on distributed setup that context manager behaves properly + model, ddp_model, dataloader = get_training_setup(accelerator) + # Use a single batch + ddp_input, ddp_target = next(iter(dataloader)).values() + for iteration in range(3): + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + # Perform our initial ground truth step in non "DDP" + step_model(model, input, target, accelerator) + # Do "gradient accumulation" (noop) + if iteration % 2 == 0: + # Accumulate grads locally + with accelerator.no_sync(ddp_model): + step_model(ddp_model, ddp_input, ddp_target, accelerator) + else: + # Sync grads + step_model(ddp_model, ddp_input, ddp_target, accelerator) + + # DDP model and model should only be in sync when not (iteration % 2 == 0) + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + if iteration % 2 == 0: + # Grads should not be in sync + assert ( + torch.allclose(param.grad, ddp_param.grad) is False + ), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})" + else: + # Grads should be in sync + assert ( + torch.allclose(param.grad, ddp_param.grad) is True + ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})" + + # Shuffle ddp_input on each iteration + torch.manual_seed(1337 + iteration) + ddp_input = ddp_input[torch.randperm(len(ddp_input))] + + +def test_distributed_sync_multiple_fwd(accelerator): + # Test on distributed setup that context manager behaves properly when used with multiple forwards followed by multiple backwards + model, ddp_model, dataloader = get_training_setup(accelerator) + # Do multiple forwards + losses = [] + num_iterations = 3 + for iteration in range(num_iterations): + ddp_input, ddp_target = next(iter(dataloader)).values() + + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + + # Perform our initial ground truth step in non "DDP" + step_model(model, input, target, accelerator) + + # Accumulate grads locally + with accelerator.no_sync(ddp_model): + ddp_output = ddp_model(ddp_input) + loss = F.mse_loss(ddp_output, ddp_target.to(ddp_output.device)) + losses.append(loss) + + # Do multiple backwards and sync only at the last backward + for iteration in range(num_iterations): + loss = losses[iteration] + + if iteration < num_iterations - 1: + # Accumulate grads locally + accelerator.backward(loss) + + # DDP model and model should only be in sync after last backward + for param, ddp_param in 
zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + # Grads should not be in sync + assert ( + torch.allclose(param.grad, ddp_param.grad) is False + ), f"Gradients in sync when they should not be:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})" + + else: + # Sync grads if last backward + with accelerator.trigger_sync_in_backward(ddp_model): + accelerator.backward(loss) + + # DDP model and model should only be in sync after last backward + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + # Grads should be in sync + assert ( + torch.allclose(param.grad, ddp_param.grad) is True + ), f"Gradients not in sync when they should be:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})" + + +def test_gradient_accumulation(split_batches=False, dispatch_batches=False, sync_each_batch=False): + gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch) + accelerator = Accelerator( + split_batches=split_batches, + dispatch_batches=dispatch_batches, + gradient_accumulation_plugin=gradient_accumulation_plugin, + ) + # Test that context manager behaves properly + model, ddp_model, dataloader = get_training_setup(accelerator) + for iteration, batch in enumerate(dataloader): + ddp_input, ddp_target = batch.values() + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = input.to(accelerator.device), target.to(accelerator.device) + # Perform our initial ground truth step in non "DDP" + step_model(model, input, target, accelerator, False) + # Do "gradient accumulation" (noop) + with accelerator.accumulate(ddp_model): + step_model(ddp_model, ddp_input, ddp_target, accelerator) + + # DDP model and model should only be in sync when not (iteration % 2 == 0) + for param, ddp_param in zip(model.parameters(), ddp_model.parameters()): + if not param.requires_grad: + continue + if ((iteration + 1) % 2 == 0) or (iteration == len(dataloader) - 1) or sync_each_batch: + # Grads should be in sync + assert ( + torch.allclose(param.grad, ddp_param.grad) is True + ), f"Gradients not in sync when they should be at iteration {iteration}:\nModel grad ({param.grad}) != DDP grad ({ddp_param.grad})" + else: + # Grads should not be in sync + assert ( + torch.allclose(param.grad, ddp_param.grad) is False + ), f"Gradients in sync when they should not be at iteration {iteration}:\nModel grad ({param.grad}) == DDP grad ({ddp_param.grad})" + + # Shuffle ddp_input on each iteration + torch.manual_seed(1337 + iteration) + ddp_input = ddp_input[torch.randperm(len(ddp_input))] + GradientState._reset_state() + + +def test_gradient_accumulation_with_opt_and_scheduler( + split_batches=False, dispatch_batches=False, sync_each_batch=False +): + gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2, sync_each_batch=sync_each_batch) + accelerator = Accelerator( + split_batches=split_batches, + dispatch_batches=dispatch_batches, + gradient_accumulation_plugin=gradient_accumulation_plugin, + ) + # Test that context manager behaves properly + model, opt, sched, dataloader, ddp_model, ddp_opt, ddp_sched = get_training_setup(accelerator, True) + for iteration, batch in enumerate(dataloader): + ddp_input, ddp_target = batch.values() + # Gather the distributed inputs and targs for the base model + input, target = accelerator.gather((ddp_input, ddp_target)) + input, target = 
input.to(accelerator.device), target.to(accelerator.device) + # Perform our initial ground truth step in non "DDP" + model.train() + ddp_model.train() + step_model(model, input, target, accelerator, False) + opt.step() + + if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch: + if split_batches: + sched.step() + else: + for _ in range(accelerator.num_processes): + sched.step() + + # Perform gradient accumulation under wrapper + with accelerator.accumulate(ddp_model): + step_model(ddp_model, ddp_input, ddp_target, accelerator) + ddp_opt.step() + ddp_sched.step() + + # Learning rates should be the same + assert ( + opt.param_groups[0]["lr"] == ddp_opt.param_groups[0]["lr"] + ), f'Learning rates found in each optimizer did not align\nopt: {opt.param_groups[0]["lr"]}\nDDP opt: {ddp_opt.param_groups[0]["lr"]}\n' + did_step = (((iteration + 1) % 2) == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch + if accelerator.num_processes > 1: + check_model_parameters( + model, + ddp_model, + did_step, + iteration, + rtol=1e-3, # somehow needs a relative tolerance + ) + + if ((iteration + 1) % 2 == 0) or ((iteration + 1) == len(dataloader)) or sync_each_batch: + opt.zero_grad() # needs to be guarded by logic as to when we should zero grads + ddp_opt.zero_grad() + + # Shuffle ddp_input on each iteration + torch.manual_seed(1337 + iteration) + GradientState._reset_state() + + +def test_dataloader_break(): + accelerator = Accelerator() + + first_dset = RegressionDataset(length=80) + first_dataloader = DataLoader(first_dset, batch_size=16) + second_dset = RegressionDataset(length=96) + second_dataloader = DataLoader(second_dset, batch_size=16) + first_dataloader, second_dataloader = accelerator.prepare(first_dataloader, second_dataloader) + assert accelerator.gradient_state.active_dataloader is None + for iteration, _ in enumerate(first_dataloader): + assert id(accelerator.gradient_state.active_dataloader) == id(first_dataloader) + if iteration < len(first_dataloader) - 1: + assert not accelerator.gradient_state.end_of_dataloader + if iteration == 1: + for batch_num, _ in enumerate(second_dataloader): + assert id(accelerator.gradient_state.active_dataloader) == id(second_dataloader) + if batch_num < len(second_dataloader) - 1: + assert not accelerator.gradient_state.end_of_dataloader + else: + assert accelerator.gradient_state.end_of_dataloader + else: + assert accelerator.gradient_state.end_of_dataloader + assert accelerator.gradient_state.active_dataloader is None + + +def main(): + accelerator = Accelerator() + state = accelerator.state + if state.local_process_index == 0: + print("**Test `accumulate` gradient accumulation with dataloader break**") + if state.distributed_type != DistributedType.XLA: + test_dataloader_break() + if state.distributed_type == DistributedType.NO: + if state.local_process_index == 0: + print("**Test NOOP `no_sync` context manager**") + test_noop_sync(accelerator) + if state.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_CPU, + ): + if state.local_process_index == 0: + print("**Test Distributed `no_sync` context manager**") + test_distributed_sync(accelerator) + if state.local_process_index == 0: + print("**Test Distributed `no_sync` context manager with multiple forwards**") + test_distributed_sync_multiple_fwd(accelerator) + if state.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_MLU): + for 
split_batch in [True, False]: + for dispatch_batches in [True, False]: + for sync_each_batch in [True, False]: + if state.local_process_index == 0: + print( + "**Test `accumulate` gradient accumulation, ", + f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**", + ) + test_gradient_accumulation(split_batch, dispatch_batches, sync_each_batch) + + # Currently will break on torch 2.0 +, need to investigate why + if state.local_process_index == 0: + print( + "**Test `accumulate` gradient accumulation with optimizer and scheduler, ", + "`split_batches=False`, `dispatch_batches=False`, `sync_each_batch=False`**", + ) + test_gradient_accumulation_with_opt_and_scheduler() + if state.distributed_type in (DistributedType.MULTI_GPU, DistributedType.MULTI_NPU, DistributedType.MULTI_MLU): + for split_batch in [True, False]: + for dispatch_batches in [True, False]: + for sync_each_batch in [True, False]: + if not split_batch and not dispatch_batches and not sync_each_batch: + continue + if state.local_process_index == 0: + print( + "**Test `accumulate` gradient accumulation with optimizer and scheduler, ", + f"`split_batches={split_batch}` and `dispatch_batches={dispatch_batches}` and `sync_each_batch={sync_each_batch}`**", + ) + test_gradient_accumulation_with_opt_and_scheduler(split_batch, dispatch_batches, sync_each_batch) + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/llm/Lib/site-packages/accelerate/test_utils/testing.py b/llm/Lib/site-packages/accelerate/test_utils/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..179fff4808a1c36fcdc8026a617a4cad549bd796 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/testing.py @@ -0,0 +1,605 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
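Every accumulation case driven by the loops above exercises the same user-facing pattern: build a `GradientAccumulationPlugin`, pass it to `Accelerator`, and wrap each training step in `accelerator.accumulate(model)`. The assertions in `test_gradient_accumulation` pin down when gradients are synchronized: every `num_steps` batches, on the dataloader's final batch, or on every batch when `sync_each_batch=True`. A minimal sketch of that pattern, reusing the regression helpers these scripts import, follows; it is illustrative only and not part of the vendored files.

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from accelerate.accelerator import Accelerator, GradientAccumulationPlugin
from accelerate.test_utils import RegressionDataset, RegressionModel


def train_with_accumulation(num_steps=2):
    plugin = GradientAccumulationPlugin(num_steps=num_steps)
    accelerator = Accelerator(gradient_accumulation_plugin=plugin)

    model = RegressionModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    dataloader = DataLoader(RegressionDataset(length=80), batch_size=16)
    model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    for batch in dataloader:
        # Gradients are only synchronized every `num_steps` batches or on the last batch
        # (see the assertions in `test_gradient_accumulation`); the remaining iterations
        # accumulate gradients locally. `step`/`zero_grad` are still called each iteration,
        # following the standard `accumulate` usage.
        with accelerator.accumulate(model):
            loss = F.mse_loss(model(batch["x"]), batch["y"])
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()
    return accelerator.unwrap_model(model)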
+ +import asyncio +import inspect +import os +import shutil +import subprocess +import sys +import tempfile +import unittest +from contextlib import contextmanager +from functools import partial +from pathlib import Path +from typing import List, Union +from unittest import mock + +import torch + +import accelerate + +from ..state import AcceleratorState, PartialState +from ..utils import ( + gather, + is_bnb_available, + is_clearml_available, + is_comet_ml_available, + is_cuda_available, + is_datasets_available, + is_deepspeed_available, + is_dvclive_available, + is_mlu_available, + is_mps_available, + is_npu_available, + is_pandas_available, + is_pippy_available, + is_tensorboard_available, + is_timm_available, + is_torch_version, + is_torch_xla_available, + is_transformers_available, + is_wandb_available, + is_xpu_available, + str_to_bool, +) + + +def get_backend(): + if is_torch_xla_available(): + return "xla", torch.cuda.device_count(), torch.cuda.memory_allocated + elif is_cuda_available(): + return "cuda", torch.cuda.device_count(), torch.cuda.memory_allocated + elif is_mps_available(): + return "mps", 1, torch.mps.current_allocated_memory() + elif is_mlu_available(): + return "mlu", torch.mlu.device_count(), torch.mlu.memory_allocated + elif is_npu_available(): + return "npu", torch.npu.device_count(), torch.npu.memory_allocated + elif is_xpu_available(): + return "xpu", torch.xpu.device_count(), torch.xpu.memory_allocated + else: + return "cpu", 1, 0 + + +torch_device, device_count, memory_allocated_func = get_backend() + + +def get_launch_command(**kwargs) -> list: + """ + Wraps around `kwargs` to help simplify launching from `subprocess`. + + Example: + ```python + # returns ['accelerate', 'launch', '--num_processes=2', '--device_count=2'] + get_launch_command(num_processes=2, device_count=2) + ``` + """ + command = ["accelerate", "launch"] + for k, v in kwargs.items(): + if isinstance(v, bool) and v: + command.append(f"--{k}") + elif v is not None: + command.append(f"--{k}={v}") + return command + + +DEFAULT_LAUNCH_COMMAND = get_launch_command(num_processes=device_count) + + +def parse_flag_from_env(key, default=False): + try: + value = os.environ[key] + except KeyError: + # KEY isn't set, default to `default`. + _value = default + else: + # KEY is set, convert it to True or False. + try: + _value = str_to_bool(value) + except ValueError: + # More values are supported, but let's keep the message simple. + raise ValueError(f"If set, {key} must be yes or no.") + return _value + + +_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False) + + +def skip(test_case): + "Decorator that skips a test unconditionally" + return unittest.skip("Test was skipped")(test_case) + + +def slow(test_case): + """ + Decorator marking a test as slow. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a + truthy value to run them. + """ + return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) + + +def require_cpu(test_case): + """ + Decorator marking a test that must be only ran on the CPU. These tests are skipped when a GPU is available. + """ + return unittest.skipUnless(torch_device == "cpu", "test requires only a CPU")(test_case) + + +def require_non_cpu(test_case): + """ + Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no + hardware accelerator available. 
+ """ + return unittest.skipUnless(torch_device != "cpu", "test requires a GPU")(test_case) + + +def require_cuda(test_case): + """ + Decorator marking a test that requires CUDA. These tests are skipped when there are no GPU available or when + TorchXLA is available. + """ + return unittest.skipUnless(is_cuda_available() and not is_torch_xla_available(), "test requires a GPU")(test_case) + + +def require_xpu(test_case): + """ + Decorator marking a test that requires XPU. These tests are skipped when there are no XPU available. + """ + return unittest.skipUnless(is_xpu_available(), "test requires a XPU")(test_case) + + +def require_non_xpu(test_case): + """ + Decorator marking a test that should be skipped for XPU. + """ + return unittest.skipUnless(torch_device != "xpu", "test requires a non-XPU")(test_case) + + +def require_mlu(test_case): + """ + Decorator marking a test that requires MLU. These tests are skipped when there are no MLU available. + """ + return unittest.skipUnless(is_mlu_available(), "test require a MLU")(test_case) + + +def require_npu(test_case): + """ + Decorator marking a test that requires NPU. These tests are skipped when there are no NPU available. + """ + return unittest.skipUnless(is_npu_available(), "test require a NPU")(test_case) + + +def require_mps(test_case): + """ + Decorator marking a test that requires MPS backend. These tests are skipped when torch doesn't support `mps` + backend. + """ + return unittest.skipUnless(is_mps_available(), "test requires a `mps` backend support in `torch`")(test_case) + + +def require_huggingface_suite(test_case): + """ + Decorator marking a test that requires transformers and datasets. These tests are skipped when they are not. + """ + return unittest.skipUnless( + is_transformers_available() and is_datasets_available(), + "test requires the Hugging Face suite", + )(test_case) + + +def require_transformers(test_case): + """ + Decorator marking a test that requires transformers. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_transformers_available(), "test requires the transformers library")(test_case) + + +def require_timm(test_case): + """ + Decorator marking a test that requires transformers. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_timm_available(), "test requires the timm library")(test_case) + + +def require_bnb(test_case): + """ + Decorator marking a test that requires bitsandbytes. These tests are skipped when they are not. + """ + return unittest.skipUnless(is_bnb_available(), "test requires the bitsandbytes library")(test_case) + + +def require_tpu(test_case): + """ + Decorator marking a test that requires TPUs. These tests are skipped when there are no TPUs available. + """ + return unittest.skipUnless(is_torch_xla_available(check_is_tpu=True), "test requires TPU")(test_case) + + +def require_non_torch_xla(test_case): + """ + Decorator marking a test as requiring an environment without TorchXLA. These tests are skipped when TorchXLA is + available. + """ + return unittest.skipUnless(not is_torch_xla_available(), "test requires an env without TorchXLA")(test_case) + + +def require_single_device(test_case): + """ + Decorator marking a test that requires a single device. These tests are skipped when there is no hardware + accelerator available or number of devices is more than one. 
+ """ + return unittest.skipUnless(torch_device != "cpu" and device_count == 1, "test requires a hardware accelerator")( + test_case + ) + + +def require_single_gpu(test_case): + """ + Decorator marking a test that requires CUDA on a single GPU. These tests are skipped when there are no GPU + available or number of GPUs is more than one. + """ + return unittest.skipUnless(torch.cuda.device_count() == 1, "test requires a GPU")(test_case) + + +def require_single_xpu(test_case): + """ + Decorator marking a test that requires CUDA on a single XPU. These tests are skipped when there are no XPU + available or number of xPUs is more than one. + """ + return unittest.skipUnless(torch.xpu.device_count() == 1, "test requires a XPU")(test_case) + + +def require_multi_device(test_case): + """ + Decorator marking a test that requires a multi-device setup. These tests are skipped on a machine without multiple + devices. + """ + return unittest.skipUnless(device_count > 1, "test requires multiple hardware accelerators")(test_case) + + +def require_multi_gpu(test_case): + """ + Decorator marking a test that requires a multi-GPU setup. These tests are skipped on a machine without multiple + GPUs. + """ + return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) + + +def require_multi_xpu(test_case): + """ + Decorator marking a test that requires a multi-XPU setup. These tests are skipped on a machine without multiple + XPUs. + """ + return unittest.skipUnless(torch.xpu.device_count() > 1, "test requires multiple XPUs")(test_case) + + +def require_deepspeed(test_case): + """ + Decorator marking a test that requires DeepSpeed installed. These tests are skipped when DeepSpeed isn't installed + """ + return unittest.skipUnless(is_deepspeed_available(), "test requires DeepSpeed")(test_case) + + +def require_fsdp(test_case): + """ + Decorator marking a test that requires FSDP installed. These tests are skipped when FSDP isn't installed + """ + return unittest.skipUnless(is_torch_version(">=", "1.12.0"), "test requires torch version >= 1.12.0")(test_case) + + +def require_torch_min_version(test_case=None, version=None): + """ + Decorator marking that a test requires a particular torch version to be tested. These tests are skipped when an + installed torch version is less than the required one. + """ + if test_case is None: + return partial(require_torch_min_version, version=version) + return unittest.skipUnless(is_torch_version(">=", version), f"test requires torch version >= {version}")(test_case) + + +def require_tensorboard(test_case): + """ + Decorator marking a test that requires tensorboard installed. These tests are skipped when tensorboard isn't + installed + """ + return unittest.skipUnless(is_tensorboard_available(), "test requires Tensorboard")(test_case) + + +def require_wandb(test_case): + """ + Decorator marking a test that requires wandb installed. These tests are skipped when wandb isn't installed + """ + return unittest.skipUnless(is_wandb_available(), "test requires wandb")(test_case) + + +def require_comet_ml(test_case): + """ + Decorator marking a test that requires comet_ml installed. These tests are skipped when comet_ml isn't installed + """ + return unittest.skipUnless(is_comet_ml_available(), "test requires comet_ml")(test_case) + + +def require_clearml(test_case): + """ + Decorator marking a test that requires clearml installed. 
These tests are skipped when clearml isn't installed + """ + return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case) + + +def require_dvclive(test_case): + """ + Decorator marking a test that requires dvclive installed. These tests are skipped when dvclive isn't installed + """ + return unittest.skipUnless(is_dvclive_available(), "test requires dvclive")(test_case) + + +def require_pandas(test_case): + """ + Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed + """ + return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case) + + +def require_pippy(test_case): + """ + Decorator marking a test that requires pippy installed. These tests are skipped when pippy isn't installed + """ + return unittest.skipUnless(is_pippy_available(), "test requires pippy")(test_case) + + +_atleast_one_tracker_available = ( + any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available() +) + + +def require_trackers(test_case): + """ + Decorator marking that a test requires at least one tracking library installed. These tests are skipped when none + are installed + """ + return unittest.skipUnless( + _atleast_one_tracker_available, + "test requires at least one tracker to be available and for `comet_ml` to not be installed", + )(test_case) + + +class TempDirTestCase(unittest.TestCase): + """ + A TestCase class that keeps a single `tempfile.TemporaryDirectory` open for the duration of the class, wipes its + data at the start of a test, and then destroyes it at the end of the TestCase. + + Useful for when a class or API requires a single constant folder throughout it's use, such as Weights and Biases + + The temporary directory location will be stored in `self.tmpdir` + """ + + clear_on_setup = True + + @classmethod + def setUpClass(cls): + "Creates a `tempfile.TemporaryDirectory` and stores it in `cls.tmpdir`" + cls.tmpdir = Path(tempfile.mkdtemp()) + + @classmethod + def tearDownClass(cls): + "Remove `cls.tmpdir` after test suite has finished" + if os.path.exists(cls.tmpdir): + shutil.rmtree(cls.tmpdir) + + def setUp(self): + "Destroy all contents in `self.tmpdir`, but not `self.tmpdir`" + if self.clear_on_setup: + for path in self.tmpdir.glob("**/*"): + if path.is_file(): + path.unlink() + elif path.is_dir(): + shutil.rmtree(path) + + +class AccelerateTestCase(unittest.TestCase): + """ + A TestCase class that will reset the accelerator state at the end of every test. Every test that checks or utilizes + the `AcceleratorState` class should inherit from this to avoid silent failures due to state being shared between + tests. + """ + + def tearDown(self): + super().tearDown() + # Reset the state of the AcceleratorState singleton. + AcceleratorState._reset_state() + PartialState._reset_state() + + +class MockingTestCase(unittest.TestCase): + """ + A TestCase class designed to dynamically add various mockers that should be used in every test, mimicking the + behavior of a class-wide mock when defining one normally will not do. + + Useful when a mock requires specific information available only initialized after `TestCase.setUpClass`, such as + setting an environment variable with that information. 
+ + The `add_mocks` function should be ran at the end of a `TestCase`'s `setUp` function, after a call to + `super().setUp()` such as: + ```python + def setUp(self): + super().setUp() + mocks = mock.patch.dict(os.environ, {"SOME_ENV_VAR", "SOME_VALUE"}) + self.add_mocks(mocks) + ``` + """ + + def add_mocks(self, mocks: Union[mock.Mock, List[mock.Mock]]): + """ + Add custom mocks for tests that should be repeated on each test. Should be called during + `MockingTestCase.setUp`, after `super().setUp()`. + + Args: + mocks (`mock.Mock` or list of `mock.Mock`): + Mocks that should be added to the `TestCase` after `TestCase.setUpClass` has been run + """ + self.mocks = mocks if isinstance(mocks, (tuple, list)) else [mocks] + for m in self.mocks: + m.start() + self.addCleanup(m.stop) + + +def are_the_same_tensors(tensor): + state = AcceleratorState() + tensor = tensor[None].clone().to(state.device) + tensors = gather(tensor).cpu() + tensor = tensor[0].cpu() + for i in range(tensors.shape[0]): + if not torch.equal(tensors[i], tensor): + return False + return True + + +class _RunOutput: + def __init__(self, returncode, stdout, stderr): + self.returncode = returncode + self.stdout = stdout + self.stderr = stderr + + +async def _read_stream(stream, callback): + while True: + line = await stream.readline() + if line: + callback(line) + else: + break + + +async def _stream_subprocess(cmd, env=None, stdin=None, timeout=None, quiet=False, echo=False) -> _RunOutput: + if echo: + print("\nRunning: ", " ".join(cmd)) + + p = await asyncio.create_subprocess_exec( + cmd[0], + *cmd[1:], + stdin=stdin, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + env=env, + ) + + # note: there is a warning for a possible deadlock when using `wait` with huge amounts of data in the pipe + # https://docs.python.org/3/library/asyncio-subprocess.html#asyncio.asyncio.subprocess.Process.wait + # + # If it starts hanging, will need to switch to the following code. The problem is that no data + # will be seen until it's done and if it hangs for example there will be no debug info. 
+ # out, err = await p.communicate() + # return _RunOutput(p.returncode, out, err) + + out = [] + err = [] + + def tee(line, sink, pipe, label=""): + line = line.decode("utf-8").rstrip() + sink.append(line) + if not quiet: + print(label, line, file=pipe) + + # XXX: the timeout doesn't seem to make any difference here + await asyncio.wait( + [ + asyncio.create_task(_read_stream(p.stdout, lambda l: tee(l, out, sys.stdout, label="stdout:"))), + asyncio.create_task(_read_stream(p.stderr, lambda l: tee(l, err, sys.stderr, label="stderr:"))), + ], + timeout=timeout, + ) + return _RunOutput(await p.wait(), out, err) + + +def execute_subprocess_async(cmd: list, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput: + # Cast every path in `cmd` to a string + for i, c in enumerate(cmd): + if isinstance(c, Path): + cmd[i] = str(c) + loop = asyncio.get_event_loop() + result = loop.run_until_complete( + _stream_subprocess(cmd, env=env, stdin=stdin, timeout=timeout, quiet=quiet, echo=echo) + ) + + cmd_str = " ".join(cmd) + if result.returncode > 0: + stderr = "\n".join(result.stderr) + raise RuntimeError( + f"'{cmd_str}' failed with returncode {result.returncode}\n\n" + f"The combined stderr from workers follows:\n{stderr}" + ) + + return result + + +class SubprocessCallException(Exception): + pass + + +def run_command(command: List[str], return_stdout=False, env=None): + """ + Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture + if an error occured while running `command` + """ + # Cast every path in `command` to a string + for i, c in enumerate(command): + if isinstance(c, Path): + command[i] = str(c) + if env is None: + env = os.environ.copy() + try: + output = subprocess.check_output(command, stderr=subprocess.STDOUT, env=env) + if return_stdout: + if hasattr(output, "decode"): + output = output.decode("utf-8") + return output + except subprocess.CalledProcessError as e: + raise SubprocessCallException( + f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}" + ) from e + + +def path_in_accelerate_package(*components: str) -> Path: + """ + Get a path within the `accelerate` package's directory. + + Args: + *components: Components of the path to join after the package directory. + + Returns: + `Path`: The path to the requested file or directory. + """ + + accelerate_package_dir = Path(inspect.getfile(accelerate)).parent + return accelerate_package_dir.joinpath(*components) + + +@contextmanager +def assert_exception(exception_class: Exception, msg: str = None) -> bool: + """ + Context manager to assert that the right `Exception` class was raised. + + If `msg` is provided, will check that the message is contained in the raised exception. 
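A rough sketch of how `run_command`, `execute_subprocess_async`, and `assert_exception` are typically exercised; the commands and messages are placeholders, and the `accelerate.test_utils.testing` import path is an assumption.

```python
import os
import sys

from accelerate.test_utils.testing import assert_exception, execute_subprocess_async, run_command

# Synchronous helper: raises SubprocessCallException (carrying the captured output) on failure.
stdout = run_command([sys.executable, "-c", "print('hello from a subprocess')"], return_stdout=True)
assert "hello" in stdout

# Async helper: tees stdout/stderr live and raises RuntimeError on a non-zero return code.
result = execute_subprocess_async([sys.executable, "-c", "import sys; sys.exit(0)"], env=os.environ.copy())
assert result.returncode == 0

# assert_exception behaves much like pytest.raises, with an optional substring check on the message.
with assert_exception(ValueError, msg="invalid"):
    raise ValueError("this value is invalid")
```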
+ """ + was_ran = False + try: + yield + was_ran = True + except Exception as e: + assert isinstance(e, exception_class), f"Expected exception of type {exception_class} but got {type(e)}" + if msg is not None: + assert msg in str(e), f"Expected message '{msg}' to be in exception but got '{str(e)}'" + if was_ran: + raise AssertionError(f"Expected exception of type {exception_class} but ran without issue.") diff --git a/llm/Lib/site-packages/accelerate/test_utils/training.py b/llm/Lib/site-packages/accelerate/test_utils/training.py new file mode 100644 index 0000000000000000000000000000000000000000..d89cfd3c71546871d00cb9c2a5cd07494c46cbfe --- /dev/null +++ b/llm/Lib/site-packages/accelerate/test_utils/training.py @@ -0,0 +1,101 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from accelerate.utils.dataclasses import DistributedType + + +class RegressionDataset: + def __init__(self, a=2, b=3, length=64, seed=None): + rng = np.random.default_rng(seed) + self.length = length + self.x = rng.normal(size=(length,)).astype(np.float32) + self.y = a * self.x + b + rng.normal(scale=0.1, size=(length,)).astype(np.float32) + + def __len__(self): + return self.length + + def __getitem__(self, i): + return {"x": self.x[i], "y": self.y[i]} + + +class RegressionModel4XPU(torch.nn.Module): + def __init__(self, a=0, b=0, double_output=False): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor([2, 3]).float()) + self.b = torch.nn.Parameter(torch.tensor([2, 3]).float()) + self.first_batch = True + + def forward(self, x=None): + if self.first_batch: + print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. Input dtype: {x.dtype}") + self.first_batch = False + return x * self.a[0] + self.b[0] + + +class RegressionModel(torch.nn.Module): + def __init__(self, a=0, b=0, double_output=False): + super().__init__() + self.a = torch.nn.Parameter(torch.tensor(a).float()) + self.b = torch.nn.Parameter(torch.tensor(b).float()) + self.first_batch = True + + def forward(self, x=None): + if self.first_batch: + print(f"Model dtype: {self.a.dtype}, {self.b.dtype}. 
Input dtype: {x.dtype}") + self.first_batch = False + return x * self.a + self.b + + +def mocked_dataloaders(accelerator, batch_size: int = 16): + from datasets import load_dataset + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained("bert-base-cased") + data_files = {"train": "tests/test_samples/MRPC/train.csv", "validation": "tests/test_samples/MRPC/dev.csv"} + datasets = load_dataset("csv", data_files=data_files) + label_list = datasets["train"].unique("label") + + label_to_id = {v: i for i, v in enumerate(label_list)} + + def tokenize_function(examples): + # max_length=None => use the model max length (it's actually the default) + outputs = tokenizer( + examples["sentence1"], examples["sentence2"], truncation=True, max_length=None, padding="max_length" + ) + if "label" in examples: + outputs["labels"] = [label_to_id[l] for l in examples["label"]] + return outputs + + # Apply the method we just defined to all the examples in all the splits of the dataset + tokenized_datasets = datasets.map( + tokenize_function, + batched=True, + remove_columns=["sentence1", "sentence2", "label"], + ) + + def collate_fn(examples): + # On TPU it's best to pad everything to the same length or training will be very slow. + if accelerator.distributed_type == DistributedType.XLA: + return tokenizer.pad(examples, padding="max_length", max_length=128, return_tensors="pt") + return tokenizer.pad(examples, padding="longest", return_tensors="pt") + + # Instantiate dataloaders. + train_dataloader = DataLoader(tokenized_datasets["train"], shuffle=True, collate_fn=collate_fn, batch_size=2) + eval_dataloader = DataLoader(tokenized_datasets["validation"], shuffle=False, collate_fn=collate_fn, batch_size=1) + + return train_dataloader, eval_dataloader diff --git a/llm/Lib/site-packages/accelerate/tracking.py b/llm/Lib/site-packages/accelerate/tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..5efba19bc6769d9c70ea8b17b8da784b908f529f --- /dev/null +++ b/llm/Lib/site-packages/accelerate/tracking.py @@ -0,0 +1,1023 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
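A minimal sketch of how the `RegressionDataset`/`RegressionModel` pair from `training.py` above is used to smoke-test a plain training loop; the hyperparameters are arbitrary and the `accelerate.test_utils.training` import path is an assumption.

```python
import torch
from torch.utils.data import DataLoader

from accelerate.test_utils.training import RegressionDataset, RegressionModel

dataset = RegressionDataset(a=2, b=3, length=64, seed=0)
dataloader = DataLoader(dataset, batch_size=16)
model = RegressionModel(a=0.0, b=0.0)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(20):
    for batch in dataloader:
        loss = torch.nn.functional.mse_loss(model(batch["x"]), batch["y"])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

# After training, the parameters should be close to the generating values a≈2, b≈3.
print(model.a.item(), model.b.item())
```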
+ +# Expectation: +# Provide a project dir name, then each type of logger gets stored in project/{`logging_dir`} + +import json +import os +import time +from functools import wraps +from typing import Any, Dict, List, Optional, Union + +import yaml + +from .logging import get_logger +from .state import PartialState +from .utils import ( + LoggerType, + is_aim_available, + is_clearml_available, + is_comet_ml_available, + is_dvclive_available, + is_mlflow_available, + is_tensorboard_available, + is_wandb_available, + listify, +) + + +_available_trackers = [] + +if is_tensorboard_available(): + _available_trackers.append(LoggerType.TENSORBOARD) + +if is_wandb_available(): + _available_trackers.append(LoggerType.WANDB) + +if is_comet_ml_available(): + _available_trackers.append(LoggerType.COMETML) + +if is_aim_available(): + _available_trackers.append(LoggerType.AIM) + +if is_mlflow_available(): + _available_trackers.append(LoggerType.MLFLOW) + +if is_clearml_available(): + _available_trackers.append(LoggerType.CLEARML) + +if is_dvclive_available(): + _available_trackers.append(LoggerType.DVCLIVE) + +logger = get_logger(__name__) + + +def on_main_process(function): + """ + Decorator to selectively run the decorated function on the main process only based on the `main_process_only` + attribute in a class. + + Checks at function execution rather than initialization time, not triggering the initialization of the + `PartialState`. + """ + + @wraps(function) + def execute_on_main_process(self, *args, **kwargs): + if getattr(self, "main_process_only", False): + return PartialState().on_main_process(function)(self, *args, **kwargs) + else: + return function(self, *args, **kwargs) + + return execute_on_main_process + + +def get_available_trackers(): + "Returns a list of all supported available trackers in the system" + return _available_trackers + + +class GeneralTracker: + """ + A base Tracker class to be used for all logging integration implementations. + + Each function should take in `**kwargs` that will automatically be passed in from a base dictionary provided to + [`Accelerator`]. + + Should implement `name`, `requires_logging_directory`, and `tracker` properties such that: + + `name` (`str`): String representation of the tracker class name, such as "TensorBoard" `requires_logging_directory` + (`bool`): Whether the logger requires a directory to store their logs. `tracker` (`object`): Should return internal + tracking mechanism used by a tracker class (such as the `run` for wandb) + + Implementations can also include a `main_process_only` (`bool`) attribute to toggle if relevent logging, init, and + other functions should occur on the main process or across all processes (by default will use `True`) + """ + + main_process_only = True + + def __init__(self, _blank=False): + if not _blank: + err = "" + if not hasattr(self, "name"): + err += "`name`" + if not hasattr(self, "requires_logging_directory"): + if len(err) > 0: + err += ", " + err += "`requires_logging_directory`" + + # as tracker is a @property that relies on post-init + if "tracker" not in dir(self): + if len(err) > 0: + err += ", " + err += "`tracker`" + if len(err) > 0: + raise NotImplementedError( + f"The implementation for this tracker class is missing the following " + f"required attributes. Please define them in the class definition: " + f"{err}" + ) + + def store_init_configuration(self, values: dict): + """ + Logs `values` as hyperparameters for the run. 
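To make the `GeneralTracker` contract above concrete, here is a hedged sketch of a custom tracker that satisfies the three required attributes (`name`, `requires_logging_directory`, and the `tracker` property) by collecting everything in a plain dict; `InMemoryTracker` is an illustrative name, not part of accelerate.

```python
from typing import Optional

from accelerate.tracking import GeneralTracker


class InMemoryTracker(GeneralTracker):
    name = "in_memory"
    requires_logging_directory = False

    def __init__(self, run_name: str, **kwargs):
        super().__init__()
        self.run_name = run_name
        self._store = {"config": {}, "logs": []}

    @property
    def tracker(self):
        # The "internal tracking mechanism": here just the dict itself.
        return self._store

    def store_init_configuration(self, values: dict):
        self._store["config"].update(values)

    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        self._store["logs"].append({"step": step, **values})
```

An instance of such a class can be passed in `log_with` when constructing an `Accelerator`, alongside or instead of the built-in tracker names.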
Implementations should use the experiment configuration + functionality of a tracking API. + + Args: + values (Dictionary `str` to `bool`, `str`, `float` or `int`): + Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`, + `str`, `float`, `int`, or `None`. + """ + pass + + def log(self, values: dict, step: Optional[int], **kwargs): + """ + Logs `values` to the current run. Base `log` implementations of a tracking API should go in here, along with + special behavior for the `step parameter. + + Args: + values (Dictionary `str` to `str`, `float`, or `int`): + Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`. + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + """ + pass + + def finish(self): + """ + Should run any finalizing functions within the tracking API. If the API should not have one, just don't + overwrite that method. + """ + pass + + +class TensorBoardTracker(GeneralTracker): + """ + A `Tracker` class that supports `tensorboard`. Should be initialized at the start of your script. + + Args: + run_name (`str`): + The name of the experiment run + logging_dir (`str`, `os.PathLike`): + Location for TensorBoard logs to be stored. + **kwargs (additional keyword arguments, *optional*): + Additional key word arguments passed along to the `tensorboard.SummaryWriter.__init__` method. + """ + + name = "tensorboard" + requires_logging_directory = True + + @on_main_process + def __init__(self, run_name: str, logging_dir: Union[str, os.PathLike], **kwargs): + try: + from torch.utils import tensorboard + except ModuleNotFoundError: + import tensorboardX as tensorboard + super().__init__() + self.run_name = run_name + self.logging_dir = os.path.join(logging_dir, run_name) + self.writer = tensorboard.SummaryWriter(self.logging_dir, **kwargs) + logger.debug(f"Initialized TensorBoard project {self.run_name} logging to {self.logging_dir}") + logger.debug( + "Make sure to log any initial configurations with `self.store_init_configuration` before training!" + ) + + @property + def tracker(self): + return self.writer + + @on_main_process + def store_init_configuration(self, values: dict): + """ + Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the + hyperparameters in a yaml file for future use. + + Args: + values (Dictionary `str` to `bool`, `str`, `float` or `int`): + Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`, + `str`, `float`, `int`, or `None`. + """ + self.writer.add_hparams(values, metric_dict={}) + self.writer.flush() + project_run_name = time.time() + dir_name = os.path.join(self.logging_dir, str(project_run_name)) + os.makedirs(dir_name, exist_ok=True) + with open(os.path.join(dir_name, "hparams.yml"), "w") as outfile: + try: + yaml.dump(values, outfile) + except yaml.representer.RepresenterError: + logger.error("Serialization to store hyperparameters failed") + raise + logger.debug("Stored initial configuration hyperparameters to TensorBoard and hparams yaml file") + + @on_main_process + def log(self, values: dict, step: Optional[int] = None, **kwargs): + """ + Logs `values` to the current run. + + Args: + values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`): + Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of + `str` to `float`/`int`. 
+ step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to either `SummaryWriter.add_scaler`, + `SummaryWriter.add_text`, or `SummaryWriter.add_scalers` method based on the contents of `values`. + """ + values = listify(values) + for k, v in values.items(): + if isinstance(v, (int, float)): + self.writer.add_scalar(k, v, global_step=step, **kwargs) + elif isinstance(v, str): + self.writer.add_text(k, v, global_step=step, **kwargs) + elif isinstance(v, dict): + self.writer.add_scalars(k, v, global_step=step, **kwargs) + self.writer.flush() + logger.debug("Successfully logged to TensorBoard") + + @on_main_process + def log_images(self, values: dict, step: Optional[int], **kwargs): + """ + Logs `images` to the current run. + + Args: + values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`): + Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to the `SummaryWriter.add_image` method. + """ + for k, v in values.items(): + self.writer.add_images(k, v, global_step=step, **kwargs) + logger.debug("Successfully logged images to TensorBoard") + + @on_main_process + def finish(self): + """ + Closes `TensorBoard` writer + """ + self.writer.close() + logger.debug("TensorBoard writer closed") + + +class WandBTracker(GeneralTracker): + """ + A `Tracker` class that supports `wandb`. Should be initialized at the start of your script. + + Args: + run_name (`str`): + The name of the experiment run. + **kwargs (additional keyword arguments, *optional*): + Additional key word arguments passed along to the `wandb.init` method. + """ + + name = "wandb" + requires_logging_directory = False + main_process_only = False + + @on_main_process + def __init__(self, run_name: str, **kwargs): + super().__init__() + self.run_name = run_name + + import wandb + + self.run = wandb.init(project=self.run_name, **kwargs) + logger.debug(f"Initialized WandB project {self.run_name}") + logger.debug( + "Make sure to log any initial configurations with `self.store_init_configuration` before training!" + ) + + @property + def tracker(self): + return self.run + + @on_main_process + def store_init_configuration(self, values: dict): + """ + Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. + + Args: + values (Dictionary `str` to `bool`, `str`, `float` or `int`): + Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`, + `str`, `float`, `int`, or `None`. + """ + import wandb + + wandb.config.update(values, allow_val_change=True) + logger.debug("Stored initial configuration hyperparameters to WandB") + + @on_main_process + def log(self, values: dict, step: Optional[int] = None, **kwargs): + """ + Logs `values` to the current run. + + Args: + values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`): + Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of + `str` to `float`/`int`. + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to the `wandb.log` method. 
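In practice these tracker classes are rarely constructed by hand; they are reached through the `Accelerator` logging API, roughly as in this sketch (the project name, directory, and metric values are placeholders):

```python
from accelerate import Accelerator

accelerator = Accelerator(log_with="tensorboard", project_dir="runs")
accelerator.init_trackers("my_experiment", config={"lr": 3e-4, "epochs": 3})

for step in range(100):
    loss = 1.0 / (step + 1)  # placeholder metric
    accelerator.log({"train_loss": loss}, step=step)

accelerator.end_training()  # calls .finish() on every active tracker
```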
+ """ + self.run.log(values, step=step, **kwargs) + logger.debug("Successfully logged to WandB") + + @on_main_process + def log_images(self, values: dict, step: Optional[int] = None, **kwargs): + """ + Logs `images` to the current run. + + Args: + values (Dictionary `str` to `List` of `np.ndarray` or `PIL.Image`): + Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to the `wandb.log` method. + """ + import wandb + + for k, v in values.items(): + self.log({k: [wandb.Image(image) for image in v]}, step=step, **kwargs) + logger.debug("Successfully logged images to WandB") + + @on_main_process + def log_table( + self, + table_name: str, + columns: List[str] = None, + data: List[List[Any]] = None, + dataframe: Any = None, + step: Optional[int] = None, + **kwargs, + ): + """ + Log a Table containing any object type (text, image, audio, video, molecule, html, etc). Can be defined either + with `columns` and `data` or with `dataframe`. + + Args: + table_name (`str`): + The name to give to the logged table on the wandb workspace + columns (list of `str`, *optional*): + The name of the columns on the table + data (List of List of Any data type, *optional*): + The data to be logged in the table + dataframe (Any data type, *optional*): + The data to be logged in the table + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + """ + import wandb + + values = {table_name: wandb.Table(columns=columns, data=data, dataframe=dataframe)} + self.log(values, step=step, **kwargs) + + @on_main_process + def finish(self): + """ + Closes `wandb` writer + """ + self.run.finish() + logger.debug("WandB run closed") + + +class CometMLTracker(GeneralTracker): + """ + A `Tracker` class that supports `comet_ml`. Should be initialized at the start of your script. + + API keys must be stored in a Comet config file. + + Args: + run_name (`str`): + The name of the experiment run. + **kwargs (additional keyword arguments, *optional*): + Additional key word arguments passed along to the `Experiment.__init__` method. + """ + + name = "comet_ml" + requires_logging_directory = False + + @on_main_process + def __init__(self, run_name: str, **kwargs): + super().__init__() + self.run_name = run_name + + from comet_ml import Experiment + + self.writer = Experiment(project_name=run_name, **kwargs) + logger.debug(f"Initialized CometML project {self.run_name}") + logger.debug( + "Make sure to log any initial configurations with `self.store_init_configuration` before training!" + ) + + @property + def tracker(self): + return self.writer + + @on_main_process + def store_init_configuration(self, values: dict): + """ + Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. + + Args: + values (Dictionary `str` to `bool`, `str`, `float` or `int`): + Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`, + `str`, `float`, `int`, or `None`. + """ + self.writer.log_parameters(values) + logger.debug("Stored initial configuration hyperparameters to CometML") + + @on_main_process + def log(self, values: dict, step: Optional[int] = None, **kwargs): + """ + Logs `values` to the current run. 
+ + Args: + values (Dictionary `str` to `str`, `float`, `int` or `dict` of `str` to `float`/`int`): + Values to be logged as key-value pairs. The values need to have type `str`, `float`, `int` or `dict` of + `str` to `float`/`int`. + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to either `Experiment.log_metric`, `Experiment.log_other`, + or `Experiment.log_metrics` method based on the contents of `values`. + """ + if step is not None: + self.writer.set_step(step) + for k, v in values.items(): + if isinstance(v, (int, float)): + self.writer.log_metric(k, v, step=step, **kwargs) + elif isinstance(v, str): + self.writer.log_other(k, v, **kwargs) + elif isinstance(v, dict): + self.writer.log_metrics(v, step=step, **kwargs) + logger.debug("Successfully logged to CometML") + + @on_main_process + def finish(self): + """ + Closes `comet-ml` writer + """ + self.writer.end() + logger.debug("CometML run closed") + + +class AimTracker(GeneralTracker): + """ + A `Tracker` class that supports `aim`. Should be initialized at the start of your script. + + Args: + run_name (`str`): + The name of the experiment run. + **kwargs (additional keyword arguments, *optional*): + Additional key word arguments passed along to the `Run.__init__` method. + """ + + name = "aim" + requires_logging_directory = True + + @on_main_process + def __init__(self, run_name: str, logging_dir: Optional[Union[str, os.PathLike]] = ".", **kwargs): + self.run_name = run_name + + from aim import Run + + self.writer = Run(repo=logging_dir, **kwargs) + self.writer.name = self.run_name + logger.debug(f"Initialized Aim project {self.run_name}") + logger.debug( + "Make sure to log any initial configurations with `self.store_init_configuration` before training!" + ) + + @property + def tracker(self): + return self.writer + + @on_main_process + def store_init_configuration(self, values: dict): + """ + Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. + + Args: + values (`dict`): + Values to be stored as initial hyperparameters as key-value pairs. + """ + self.writer["hparams"] = values + + @on_main_process + def log(self, values: dict, step: Optional[int], **kwargs): + """ + Logs `values` to the current run. + + Args: + values (`dict`): + Values to be logged as key-value pairs. + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to the `Run.track` method. + """ + # Note: replace this with the dictionary support when merged + for key, value in values.items(): + self.writer.track(value, name=key, step=step, **kwargs) + + @on_main_process + def log_images(self, values: dict, step: Optional[int] = None, kwargs: Optional[Dict[str, dict]] = None): + """ + Logs `images` to the current run. + + Args: + values (`Dict[str, Union[np.ndarray, PIL.Image, Tuple[np.ndarray, str], Tuple[PIL.Image, str]]]`): + Values to be logged as key-value pairs. The values need to have type `np.ndarray` or PIL.Image. If a + tuple is provided, the first element should be the image and the second element should be the caption. + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs (`Dict[str, dict]`): + Additional key word arguments passed along to the `Run.Image` and `Run.track` method specified by the + keys `aim_image` and `track`, respectively. 
+ """ + import aim + + aim_image_kw = {} + track_kw = {} + + if kwargs is not None: + aim_image_kw = kwargs.get("aim_image", {}) + track_kw = kwargs.get("track", {}) + + for key, value in values.items(): + if isinstance(value, tuple): + img, caption = value + else: + img, caption = value, "" + aim_image = aim.Image(img, caption=caption, **aim_image_kw) + self.writer.track(aim_image, name=key, step=step, **track_kw) + + @on_main_process + def finish(self): + """ + Closes `aim` writer + """ + self.writer.close() + + +class MLflowTracker(GeneralTracker): + """ + A `Tracker` class that supports `mlflow`. Should be initialized at the start of your script. + + Args: + experiment_name (`str`, *optional*): + Name of the experiment. Environment variable MLFLOW_EXPERIMENT_NAME has priority over this argument. + logging_dir (`str` or `os.PathLike`, defaults to `"."`): + Location for mlflow logs to be stored. + run_id (`str`, *optional*): + If specified, get the run with the specified UUID and log parameters and metrics under that run. The run’s + end time is unset and its status is set to running, but the run’s other attributes (source_version, + source_type, etc.) are not changed. Environment variable MLFLOW_RUN_ID has priority over this argument. + tags (`Dict[str, str]`, *optional*): + An optional `dict` of `str` keys and values, or a `str` dump from a `dict`, to set as tags on the run. If a + run is being resumed, these tags are set on the resumed run. If a new run is being created, these tags are + set on the new run. Environment variable MLFLOW_TAGS has priority over this argument. + nested_run (`bool`, *optional*, defaults to `False`): + Controls whether run is nested in parent run. True creates a nested run. Environment variable + MLFLOW_NESTED_RUN has priority over this argument. + run_name (`str`, *optional*): + Name of new run (stored as a mlflow.runName tag). Used only when `run_id` is unspecified. + description (`str`, *optional*): + An optional string that populates the description box of the run. If a run is being resumed, the + description is set on the resumed run. If a new run is being created, the description is set on the new + run. + """ + + name = "mlflow" + requires_logging_directory = False + + @on_main_process + def __init__( + self, + experiment_name: str = None, + logging_dir: Optional[Union[str, os.PathLike]] = None, + run_id: Optional[str] = None, + tags: Optional[Union[Dict[str, Any], str]] = None, + nested_run: Optional[bool] = False, + run_name: Optional[str] = None, + description: Optional[str] = None, + ): + experiment_name = os.environ.get("MLFLOW_EXPERIMENT_NAME", experiment_name) + run_id = os.environ.get("MLFLOW_RUN_ID", run_id) + tags = os.environ.get("MLFLOW_TAGS", tags) + if isinstance(tags, str): + tags = json.loads(tags) + + nested_run = os.environ.get("MLFLOW_NESTED_RUN", nested_run) + + import mlflow + + exps = mlflow.search_experiments(filter_string=f"name = '{experiment_name}'") + if len(exps) > 0: + if len(exps) > 1: + logger.warning("Multiple experiments with the same name found. 
Using first one.") + experiment_id = exps[0].experiment_id + else: + experiment_id = mlflow.create_experiment( + name=experiment_name, + artifact_location=logging_dir, + tags=tags, + ) + + self.active_run = mlflow.start_run( + run_id=run_id, + experiment_id=experiment_id, + run_name=run_name, + nested=nested_run, + tags=tags, + description=description, + ) + + logger.debug(f"Initialized mlflow experiment {experiment_name}") + logger.debug( + "Make sure to log any initial configurations with `self.store_init_configuration` before training!" + ) + + @property + def tracker(self): + return self.active_run + + @on_main_process + def store_init_configuration(self, values: dict): + """ + Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. + + Args: + values (`dict`): + Values to be stored as initial hyperparameters as key-value pairs. + """ + import mlflow + + for name, value in list(values.items()): + # internally, all values are converted to str in MLflow + if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH: + logger.warning_once( + f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s' + f" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute." + ) + del values[name] + + values_list = list(values.items()) + + # MLflow cannot log more than 100 values in one go, so we have to split it + for i in range(0, len(values_list), mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH): + mlflow.log_params(dict(values_list[i : i + mlflow.utils.validation.MAX_PARAMS_TAGS_PER_BATCH])) + + logger.debug("Stored initial configuration hyperparameters to MLflow") + + @on_main_process + def log(self, values: dict, step: Optional[int]): + """ + Logs `values` to the current run. + + Args: + values (`dict`): + Values to be logged as key-value pairs. + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + """ + metrics = {} + for k, v in values.items(): + if isinstance(v, (int, float)): + metrics[k] = v + else: + logger.warning_once( + f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. ' + "MLflow's log_metric() only accepts float and int types so we dropped this attribute." + ) + import mlflow + + mlflow.log_metrics(metrics, step=step) + logger.debug("Successfully logged to mlflow") + + @on_main_process + def finish(self): + """ + End the active MLflow run. + """ + import mlflow + + mlflow.end_run() + + +class ClearMLTracker(GeneralTracker): + """ + A `Tracker` class that supports `clearml`. Should be initialized at the start of your script. + + Args: + run_name (`str`, *optional*): + Name of the experiment. Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this + argument. + **kwargs (additional keyword arguments, *optional*): + Kwargs passed along to the `Task.__init__` method. 
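Tracker-specific constructor arguments, such as the MLflow run name and tags or the ClearML task settings, are usually forwarded through `init_kwargs`, keyed by tracker name. As far as I can tell, the `project_name` passed to `init_trackers` becomes the tracker's first positional argument (the MLflow experiment name here), with the environment variables described above still taking priority. The values below are placeholders.

```python
from accelerate import Accelerator

accelerator = Accelerator(log_with="mlflow")
accelerator.init_trackers(
    project_name="regression-demo",
    config={"lr": 1e-3, "batch_size": 16},
    init_kwargs={"mlflow": {"run_name": "baseline", "tags": {"team": "demo"}}},
)
accelerator.log({"loss": 0.42}, step=1)
accelerator.end_training()
```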
+ """ + + name = "clearml" + requires_logging_directory = False + + @on_main_process + def __init__(self, run_name: str = None, **kwargs): + from clearml import Task + + current_task = Task.current_task() + self._initialized_externally = False + if current_task: + self._initialized_externally = True + self.task = current_task + return + + kwargs.setdefault("project_name", os.environ.get("CLEARML_PROJECT", run_name)) + kwargs.setdefault("task_name", os.environ.get("CLEARML_TASK", run_name)) + self.task = Task.init(**kwargs) + + @property + def tracker(self): + return self.task + + @on_main_process + def store_init_configuration(self, values: dict): + """ + Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment. + + Args: + values (`dict`): + Values to be stored as initial hyperparameters as key-value pairs. + """ + return self.task.connect_configuration(values) + + @on_main_process + def log(self, values: Dict[str, Union[int, float]], step: Optional[int] = None, **kwargs): + """ + Logs `values` dictionary to the current run. The dictionary keys must be strings. The dictionary values must be + ints or floats + + Args: + values (`Dict[str, Union[int, float]]`): + Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will + be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed. + Otherwise, the value will be reported under the 'train' series, and no prefix will be removed. + step (`int`, *optional*): + If specified, the values will be reported as scalars, with the iteration number equal to `step`. + Otherwise they will be reported as single values. + kwargs: + Additional key word arguments passed along to the `clearml.Logger.report_single_value` or + `clearml.Logger.report_scalar` methods. + """ + clearml_logger = self.task.get_logger() + for k, v in values.items(): + if not isinstance(v, (int, float)): + logger.warning_once( + "Accelerator is attempting to log a value of " + f'"{v}" of type {type(v)} for key "{k}" as a scalar. ' + "This invocation of ClearML logger's report_scalar() " + "is incorrect so we dropped this attribute." + ) + continue + if step is None: + clearml_logger.report_single_value(name=k, value=v, **kwargs) + continue + title, series = ClearMLTracker._get_title_series(k) + clearml_logger.report_scalar(title=title, series=series, value=v, iteration=step, **kwargs) + + @on_main_process + def log_images(self, values: dict, step: Optional[int] = None, **kwargs): + """ + Logs `images` to the current run. + + Args: + values (`Dict[str, List[Union[np.ndarray, PIL.Image]]`): + Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to the `clearml.Logger.report_image` method. + """ + clearml_logger = self.task.get_logger() + for k, v in values.items(): + title, series = ClearMLTracker._get_title_series(k) + clearml_logger.report_image(title=title, series=series, iteration=step, image=v, **kwargs) + + @on_main_process + def log_table( + self, + table_name: str, + columns: List[str] = None, + data: List[List[Any]] = None, + dataframe: Any = None, + step: Optional[int] = None, + **kwargs, + ): + """ + Log a Table to the task. Can be defined eitherwith `columns` and `data` or with `dataframe`. 
+ + Args: + table_name (`str`): + The name of the table + columns (list of `str`, *optional*): + The name of the columns on the table + data (List of List of Any data type, *optional*): + The data to be logged in the table. If `columns` is not specified, then the first entry in data will be + the name of the columns of the table + dataframe (Any data type, *optional*): + The data to be logged in the table + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to the `clearml.Logger.report_table` method. + """ + to_report = dataframe + if dataframe is None: + if data is None: + raise ValueError( + "`ClearMLTracker.log_table` requires that `data` to be supplied if `dataframe` is `None`" + ) + to_report = [columns] + data if columns else data + title, series = ClearMLTracker._get_title_series(table_name) + self.task.get_logger().report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs) + + @on_main_process + def finish(self): + """ + Close the ClearML task. If the task was initialized externally (e.g. by manually calling `Task.init`), this + function is a noop + """ + if self.task and not self._initialized_externally: + self.task.close() + + @staticmethod + def _get_title_series(name): + for prefix in ["eval", "test", "train"]: + if name.startswith(prefix + "_"): + return name[len(prefix) + 1 :], prefix + return name, "train" + + +class DVCLiveTracker(GeneralTracker): + """ + A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script. + + Args: + run_name (`str`, *optional*): + Ignored for dvclive. See `kwargs` instead. + kwargs: + Additional key word arguments passed along to [`dvclive.Live()`](https://dvc.org/doc/dvclive/live). + + Example: + + ```py + from accelerate import Accelerator + + accelerator = Accelerator(log_with="dvclive") + accelerator.init_trackers(project_name="my_project", init_kwargs={"dvclive": {"dir": "my_directory"}}) + ``` + """ + + name = "dvclive" + requires_logging_directory = False + + @on_main_process + def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs): + from dvclive import Live + + super().__init__() + self.live = live if live is not None else Live(**kwargs) + + @property + def tracker(self): + return self.live + + @on_main_process + def store_init_configuration(self, values: dict): + """ + Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the + hyperparameters in a yaml file for future use. + + Args: + values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types): + Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`, + `str`, `float`, or `int`. + """ + self.live.log_params(values) + + @on_main_process + def log(self, values: dict, step: Optional[int] = None, **kwargs): + """ + Logs `values` to the current run. + + Args: + values (Dictionary `str` to `str`, `float`, or `int`): + Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`. + step (`int`, *optional*): + The run step. If included, the log will be affiliated with this step. + kwargs: + Additional key word arguments passed along to `dvclive.Live.log_metric()`. 
+ """ + from dvclive.plots import Metric + + if step is not None: + self.live.step = step + for k, v in values.items(): + if Metric.could_log(v): + self.live.log_metric(k, v, **kwargs) + else: + logger.warning_once( + "Accelerator attempted to log a value of " + f'"{v}" of type {type(v)} for key "{k}" as a scalar. ' + "This invocation of DVCLive's Live.log_metric() " + "is incorrect so we dropped this attribute." + ) + self.live.next_step() + + @on_main_process + def finish(self): + """ + Closes `dvclive.Live()`. + """ + self.live.end() + + +LOGGER_TYPE_TO_CLASS = { + "aim": AimTracker, + "comet_ml": CometMLTracker, + "mlflow": MLflowTracker, + "tensorboard": TensorBoardTracker, + "wandb": WandBTracker, + "clearml": ClearMLTracker, + "dvclive": DVCLiveTracker, +} + + +def filter_trackers( + log_with: List[Union[str, LoggerType, GeneralTracker]], + logging_dir: Union[str, os.PathLike] = None, +): + """ + Takes in a list of potential tracker types and checks that: + - The tracker wanted is available in that environment + - Filters out repeats of tracker types + - If `all` is in `log_with`, will return all trackers in the environment + - If a tracker requires a `logging_dir`, ensures that `logging_dir` is not `None` + + Args: + log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*): + A list of loggers to be setup for experiment tracking. Should be one or several of: + + - `"all"` + - `"tensorboard"` + - `"wandb"` + - `"comet_ml"` + - `"mlflow"` + - `"dvclive"` + If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can + also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`. + logging_dir (`str`, `os.PathLike`, *optional*): + A path to a directory for storing logs of locally-compatible loggers. + """ + loggers = [] + if log_with is not None: + if not isinstance(log_with, (list, tuple)): + log_with = [log_with] + if "all" in log_with or LoggerType.ALL in log_with: + loggers = [o for o in log_with if issubclass(type(o), GeneralTracker)] + get_available_trackers() + else: + for log_type in log_with: + if log_type not in LoggerType and not issubclass(type(log_type), GeneralTracker): + raise ValueError(f"Unsupported logging capability: {log_type}. Choose between {LoggerType.list()}") + if issubclass(type(log_type), GeneralTracker): + loggers.append(log_type) + else: + log_type = LoggerType(log_type) + if log_type not in loggers: + if log_type in get_available_trackers(): + tracker_init = LOGGER_TYPE_TO_CLASS[str(log_type)] + if tracker_init.requires_logging_directory: + if logging_dir is None: + raise ValueError( + f"Logging with `{log_type}` requires a `logging_dir` to be passed in." + ) + loggers.append(log_type) + else: + logger.debug(f"Tried adding logger {log_type}, but package is unavailable in the system.") + + return loggers diff --git a/llm/Lib/site-packages/accelerate/utils/__init__.py b/llm/Lib/site-packages/accelerate/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50baa32fdbca940857055445c1988c16b9f01b6a --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/__init__.py @@ -0,0 +1,225 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .constants import ( + MODEL_NAME, + OPTIMIZER_NAME, + RNG_STATE_NAME, + SAFE_MODEL_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + SAMPLER_NAME, + SCALER_NAME, + SCHEDULER_NAME, + TORCH_DISTRIBUTED_OPERATION_TYPES, + TORCH_LAUNCH_PARAMS, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, +) +from .dataclasses import ( + AutocastKwargs, + BnbQuantizationConfig, + ComputeEnvironment, + CustomDtype, + DataLoaderConfiguration, + DeepSpeedPlugin, + DistributedDataParallelKwargs, + DistributedType, + DynamoBackend, + FP8RecipeKwargs, + FullyShardedDataParallelPlugin, + GradientAccumulationPlugin, + GradScalerKwargs, + InitProcessGroupKwargs, + KwargsHandler, + LoggerType, + MegatronLMPlugin, + PrecisionType, + ProjectConfiguration, + RNGType, + SageMakerDistributedType, + TensorInformation, + TorchDynamoPlugin, +) +from .environment import ( + are_libraries_initialized, + check_cuda_p2p_ib_support, + check_fp8_capability, + convert_dict_to_env_variables, + get_cpu_distributed_information, + get_gpu_info, + get_int_from_env, + parse_choice_from_env, + parse_flag_from_env, + set_numa_affinity, + str_to_bool, +) +from .imports import ( + get_ccl_version, + is_4bit_bnb_available, + is_8bit_bnb_available, + is_aim_available, + is_bf16_available, + is_bnb_available, + is_boto3_available, + is_ccl_available, + is_clearml_available, + is_comet_ml_available, + is_cuda_available, + is_datasets_available, + is_deepspeed_available, + is_dvclive_available, + is_fp8_available, + is_ipex_available, + is_megatron_lm_available, + is_mlflow_available, + is_mlu_available, + is_mps_available, + is_msamp_available, + is_npu_available, + is_pandas_available, + is_peft_available, + is_pippy_available, + is_pynvml_available, + is_rich_available, + is_sagemaker_available, + is_tensorboard_available, + is_timm_available, + is_torch_xla_available, + is_transformer_engine_available, + is_transformers_available, + is_wandb_available, + is_xpu_available, +) +from .modeling import ( + calculate_maximum_sizes, + check_device_map, + check_tied_parameters_in_config, + check_tied_parameters_on_same_device, + compute_module_sizes, + convert_file_size_to_int, + dtype_byte_size, + find_tied_parameters, + get_balanced_memory, + get_max_layer_size, + get_max_memory, + get_mixed_precision_context_manager, + id_tensor_storage, + infer_auto_device_map, + is_peft_model, + load_checkpoint_in_model, + load_offloaded_weights, + load_state_dict, + named_module_tensors, + retie_parameters, + set_module_tensor_to_device, + shard_checkpoint, +) +from .offload import ( + OffloadedWeightsLoader, + PrefixedDataset, + extract_submodules_state_dict, + load_offloaded_weight, + offload_state_dict, + offload_weight, + save_offload_index, +) +from .operations import ( + CannotPadNestedTensorWarning, + broadcast, + broadcast_object_list, + concatenate, + convert_outputs_to_fp32, + convert_to_fp32, + copy_tensor_to_devices, + find_batch_size, + find_device, + gather, + gather_object, + get_data_structure, + honor_type, + ignorant_find_batch_size, + initialize_tensors, + is_namedtuple, + is_tensor_information, + is_torch_tensor, + listify, + 
pad_across_processes, + pad_input_tensors, + recursively_apply, + reduce, + send_to_device, + slice_tensors, +) +from .versions import compare_versions, is_torch_version + + +if is_deepspeed_available(): + from .deepspeed import ( + DeepSpeedEngineWrapper, + DeepSpeedOptimizerWrapper, + DeepSpeedSchedulerWrapper, + DummyOptim, + DummyScheduler, + HfDeepSpeedConfig, + ) + +from .bnb import has_4bit_bnb_layers, load_and_quantize_model +from .fsdp_utils import load_fsdp_model, load_fsdp_optimizer, save_fsdp_model, save_fsdp_optimizer +from .launch import ( + PrepareForLaunch, + _filter_args, + prepare_deepspeed_cmd_env, + prepare_multi_gpu_env, + prepare_sagemager_args_inputs, + prepare_simple_launcher_cmd_env, + prepare_tpu, +) +from .megatron_lm import ( + AbstractTrainStep, + BertTrainStep, + GPTTrainStep, + MegatronEngine, + MegatronLMDummyDataLoader, + MegatronLMDummyScheduler, + MegatronLMOptimizerWrapper, + MegatronLMSchedulerWrapper, + T5TrainStep, + avg_losses_across_data_parallel_group, + gather_across_data_parallel_groups, +) +from .megatron_lm import initialize as megatron_lm_initialize +from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader +from .megatron_lm import prepare_model as megatron_lm_prepare_model +from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer +from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler +from .memory import find_executable_batch_size, release_memory +from .other import ( + check_os_kernel, + clean_state_dict_for_safetensors, + clear_environment, + convert_bytes, + extract_model_from_parallel, + get_pretty_name, + is_port_in_use, + merge_dicts, + patch_environment, + recursive_getattr, + save, + wait_for_everyone, + write_basic_config, +) +from .random import set_seed, synchronize_rng_state, synchronize_rng_states +from .torch_xla import install_xla +from .tqdm import tqdm +from .transformer_engine import convert_model, has_transformer_engine_layers diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe45b51e203b4b1bac0913b35f661576b69545e5 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/bnb.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/bnb.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b79e913fa6e1ead4eeb7811c7666b73785e0b567 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/bnb.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/constants.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/constants.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdd1bad85f63527a2d7d1bd21afa765e190faf55 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/constants.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/dataclasses.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/dataclasses.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..633a2b1fba13ca0fb73bbfb07bd0d08fbed333ae Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/dataclasses.cpython-311.pyc differ diff --git 
a/llm/Lib/site-packages/accelerate/utils/__pycache__/deepspeed.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/deepspeed.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6eb41526c539cfd38b4f23388e6072b0bdf95868 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/deepspeed.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/environment.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/environment.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d996518830608453e8c558b10277784a0293d393 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/environment.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/fsdp_utils.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/fsdp_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5703998184e8074285478d2614b6202480f094a9 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/fsdp_utils.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/imports.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/imports.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8589dc317bcc8cfd10403b91b07b6c0a37a0ff20 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/imports.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/launch.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/launch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..917c79a2bbb9887b6a40ff2b07a8dc6f3f3ec4ce Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/launch.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/megatron_lm.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/megatron_lm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15fa4a429a80b9b9080ccce5bf831ce73e2a6525 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/megatron_lm.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/memory.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/memory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..217822a49ee9428187abb708a4d6d671d6bdff95 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/memory.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/modeling.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/modeling.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08bec29668e14109684c76f01556ef84e0514791 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/modeling.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/offload.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/offload.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d218d5f2b55cd260217f1883debaecd9d2c3ab08 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/offload.cpython-311.pyc differ diff --git 
a/llm/Lib/site-packages/accelerate/utils/__pycache__/operations.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/operations.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91ad6ad87ddf4fe74b8d13fa882b211393b6ee31 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/operations.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/other.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/other.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..913dd64c06aafc3e7f60f14ba43540a61bebd12c Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/other.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/random.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/random.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e8fbdf8940fef5371ba818e51dd057ec29a259 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/random.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/rich.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/rich.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28aa3f953df9f1286a6ad16fad3c001b213f58ae Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/rich.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/torch_xla.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/torch_xla.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a041c5ae536102e5d61760a2a51ab7ab922ba846 Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/torch_xla.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/tqdm.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/tqdm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b81150ae0c6de3dfbc2df9846230f59bbc30336d Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/tqdm.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/transformer_engine.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/transformer_engine.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df0bd0cf5884f9adc5fe8be73958f4942e3357f Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/transformer_engine.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/__pycache__/versions.cpython-311.pyc b/llm/Lib/site-packages/accelerate/utils/__pycache__/versions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c54a09c1362fc794b82f1c56e9f963e39d87f41e Binary files /dev/null and b/llm/Lib/site-packages/accelerate/utils/__pycache__/versions.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/accelerate/utils/bnb.py b/llm/Lib/site-packages/accelerate/utils/bnb.py new file mode 100644 index 0000000000000000000000000000000000000000..284ee5df6e89171948745255dd33a3b2b91123a2 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/bnb.py @@ -0,0 +1,467 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import os +from copy import deepcopy +from typing import Dict, List, Optional, Union + +import torch +import torch.nn as nn + +from accelerate.utils.imports import ( + is_4bit_bnb_available, + is_8bit_bnb_available, +) + +from ..big_modeling import dispatch_model, init_empty_weights +from .dataclasses import BnbQuantizationConfig +from .modeling import ( + find_tied_parameters, + get_balanced_memory, + infer_auto_device_map, + load_checkpoint_in_model, + offload_weight, + set_module_tensor_to_device, +) + + +logger = logging.getLogger(__name__) + + +def load_and_quantize_model( + model: torch.nn.Module, + bnb_quantization_config: BnbQuantizationConfig, + weights_location: Union[str, os.PathLike] = None, + device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None, + no_split_module_classes: Optional[List[str]] = None, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = None, + offload_state_dict: bool = False, +): + """ + This function will quantize the input model with the associated config passed in `bnb_quantization_config`. If the + model is in the meta device, we will load and dispatch the weights according to the `device_map` passed. If the + model is already loaded, we will quantize the model and put the model on the GPU, + + Args: + model (`torch.nn.Module`): + Input model. The model can be already loaded or on the meta device + bnb_quantization_config (`BnbQuantizationConfig`): + The bitsandbytes quantization parameters + weights_location (`str` or `os.PathLike`): + The folder weights_location to load. It can be: + - a path to a file containing a whole model state dict + - a path to a `.json` file containing the index to a sharded checkpoint + - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint. + - a path to a folder containing a unique pytorch_model.bin file. + device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + no_split_module_classes (`List[str]`, *optional*): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_state_dict (`bool`, *optional*, defaults to `False`): + If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if + the weight of the CPU state dict + the biggest shard does not fit. 
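A hedged sketch of the flow this docstring describes: build the model skeleton under `init_empty_weights`, then load and quantize from a checkpoint. `MyModel`, its import, and the checkpoint path are placeholders, and the exact `BnbQuantizationConfig` fields should be checked against the installed version.

```python
import torch

from accelerate import init_empty_weights
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model

from mymodels import MyModel  # placeholder: any torch.nn.Module definition

bnb_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)

# Build the skeleton on the meta device so no real weights are allocated yet.
with init_empty_weights():
    empty_model = MyModel()

quantized_model = load_and_quantize_model(
    empty_model,
    bnb_quantization_config=bnb_config,
    weights_location="path/to/checkpoint_folder",  # placeholder: folder or file holding the saved weights
    device_map="auto",
)
```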
+ + Returns: + `torch.nn.Module`: The quantized model + """ + + load_in_4bit = bnb_quantization_config.load_in_4bit + load_in_8bit = bnb_quantization_config.load_in_8bit + + if load_in_8bit and not is_8bit_bnb_available(): + raise ImportError( + "You have a version of `bitsandbytes` that is not compatible with 8bit quantization," + " make sure you have the latest version of `bitsandbytes` installed." + ) + if load_in_4bit and not is_4bit_bnb_available(): + raise ValueError( + "You have a version of `bitsandbytes` that is not compatible with 4bit quantization," + "make sure you have the latest version of `bitsandbytes` installed." + ) + + modules_on_cpu = [] + # custom device map + if isinstance(device_map, dict) and len(device_map.keys()) > 1: + modules_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]] + + # We keep some modules such as the lm_head in their original dtype for numerical stability reasons + if bnb_quantization_config.skip_modules is None: + bnb_quantization_config.skip_modules = get_keys_to_not_convert(model) + + # add cpu modules to skip modules only for 4-bit modules + if load_in_4bit: + bnb_quantization_config.skip_modules.extend(modules_on_cpu) + modules_to_not_convert = bnb_quantization_config.skip_modules + + # We add the modules we want to keep in full precision + if bnb_quantization_config.keep_in_fp32_modules is None: + bnb_quantization_config.keep_in_fp32_modules = [] + keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules + modules_to_not_convert.extend(keep_in_fp32_modules) + + # compatibility with peft + model.is_loaded_in_4bit = load_in_4bit + model.is_loaded_in_8bit = load_in_8bit + + model_device = get_parameter_device(model) + if model_device.type != "meta": + # quantization of an already loaded model + logger.warning( + "It is not recommended to quantize a loaded model. " + "The model should be instantiated under the `init_empty_weights` context manager." + ) + model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert) + # convert param to the right dtype + dtype = bnb_quantization_config.torch_dtype + for name, param in model.state_dict().items(): + if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): + param.to(torch.float32) + if param.dtype != torch.float32: + name = name.replace(".weight", "").replace(".bias", "") + param = getattr(model, name, None) + if param is not None: + param.to(torch.float32) + elif torch.is_floating_point(param): + param.to(dtype) + if model_device.type == "cuda": + # move everything to cpu in the first place because we can't do quantization if the weights are already on cuda + model.cuda(torch.cuda.current_device()) + torch.cuda.empty_cache() + elif torch.cuda.is_available(): + model.to(torch.cuda.current_device()) + else: + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + logger.info( + f"The model device type is {model_device.type}. However, cuda is needed for quantization." + "We move the model to cuda." 
+ ) + return model + + elif weights_location is None: + raise RuntimeError( + f"`weights_location` needs to be the folder path containing the weights of the model, but we found {weights_location} " + ) + + else: + with init_empty_weights(): + model = replace_with_bnb_layers( + model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert + ) + + device_map = get_quantized_model_device_map( + model, + bnb_quantization_config, + device_map, + max_memory=max_memory, + no_split_module_classes=no_split_module_classes, + ) + if offload_state_dict is None and device_map is not None and "disk" in device_map.values(): + offload_state_dict = True + + offload = any(x in list(device_map.values()) for x in ["cpu", "disk"]) + + load_checkpoint_in_model( + model, + weights_location, + device_map, + dtype=bnb_quantization_config.torch_dtype, + offload_folder=offload_folder, + offload_state_dict=offload_state_dict, + keep_in_fp32_modules=bnb_quantization_config.keep_in_fp32_modules, + offload_8bit_bnb=load_in_8bit and offload, + ) + return dispatch_model(model, device_map=device_map, offload_dir=offload_folder) + + +def get_quantized_model_device_map( + model, bnb_quantization_config, device_map=None, max_memory=None, no_split_module_classes=None +): + if device_map is None: + if torch.cuda.is_available(): + device_map = {"": torch.cuda.current_device()} + else: + raise RuntimeError("No GPU found. A GPU is needed for quantization.") + logger.info("The device_map was not initialized." "Setting device_map to `{'':torch.cuda.current_device()}`.") + + if isinstance(device_map, str): + if device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]: + raise ValueError( + "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or " + "'sequential'." + ) + + special_dtypes = {} + special_dtypes.update( + { + name: bnb_quantization_config.torch_dtype + for name, _ in model.named_parameters() + if any(m in name for m in bnb_quantization_config.skip_modules) + } + ) + special_dtypes.update( + { + name: torch.float32 + for name, _ in model.named_parameters() + if any(m in name for m in bnb_quantization_config.keep_in_fp32_modules) + } + ) + + kwargs = {} + kwargs["special_dtypes"] = special_dtypes + kwargs["no_split_module_classes"] = no_split_module_classes + kwargs["dtype"] = bnb_quantization_config.target_dtype + + # get max_memory for each device. + if device_map != "sequential": + max_memory = get_balanced_memory( + model, + low_zero=(device_map == "balanced_low_0"), + max_memory=max_memory, + **kwargs, + ) + + kwargs["max_memory"] = max_memory + device_map = infer_auto_device_map(model, **kwargs) + + if isinstance(device_map, dict): + # check if don't have any quantized module on the cpu + modules_not_to_convert = bnb_quantization_config.skip_modules + bnb_quantization_config.keep_in_fp32_modules + + device_map_without_some_modules = { + key: device_map[key] for key in device_map.keys() if key not in modules_not_to_convert + } + for device in ["cpu", "disk"]: + if device in device_map_without_some_modules.values(): + if bnb_quantization_config.load_in_4bit: + raise ValueError( + """ + Some modules are dispatched on the CPU or the disk. Make sure you have enough GPU RAM to fit + the quantized model. If you want to dispatch the model on the CPU or the disk while keeping + these modules in `torch_dtype`, you need to pass a custom `device_map` to + `load_and_quantize_model`. 
Check + https://huggingface.co/docs/accelerate/main/en/usage_guides/quantization#offload-modules-to-cpu-and-disk + for more details. + """ + ) + else: + logger.info( + "Some modules are are offloaded to the CPU or the disk. Note that these modules will be converted to 8-bit" + ) + del device_map_without_some_modules + return device_map + + +def replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=None, current_key_name=None): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules or by `bnb.nn.Linear4bit` + modules from the `bitsandbytes`library. The function will be run recursively and replace `torch.nn.Linear` modules. + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[str]`): + Names of the modules to not quantize convert. In practice we keep the `lm_head` in full precision for + numerical stability reasons. + current_key_name (`List[str]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert. + """ + + if modules_to_not_convert is None: + modules_to_not_convert = [] + + model, has_been_replaced = _replace_with_bnb_layers( + model, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + if not has_been_replaced: + logger.warning( + "You are loading your model in 8bit or 4bit but no linear modules were found in your model." + " this can happen for some architectures such as gpt2 that uses Conv1D instead of Linear layers." + " Please double check your model architecture, or submit an issue on github if you think this is" + " a bug." + ) + return model + + +def _replace_with_bnb_layers( + model, + bnb_quantization_config, + modules_to_not_convert=None, + current_key_name=None, +): + """ + Private method that wraps the recursion for module replacement. + + Returns the converted model and a boolean that indicates if the conversion has been successfull or not. + """ + # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily + import bitsandbytes as bnb + + has_been_replaced = False + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + current_key_name.append(name) + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + current_key_name_str = ".".join(current_key_name) + proceed = True + for key in modules_to_not_convert: + if ( + (key in current_key_name_str) and (key + "." 
in current_key_name_str) + ) or key == current_key_name_str: + proceed = False + break + if proceed: + # Load bnb module with empty weight and replace ``nn.Linear` module + if bnb_quantization_config.load_in_8bit: + bnb_module = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=False, + threshold=bnb_quantization_config.llm_int8_threshold, + ) + elif bnb_quantization_config.load_in_4bit: + bnb_module = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + module.bias is not None, + bnb_quantization_config.bnb_4bit_compute_dtype, + compress_statistics=bnb_quantization_config.bnb_4bit_use_double_quant, + quant_type=bnb_quantization_config.bnb_4bit_quant_type, + ) + else: + raise ValueError("load_in_8bit and load_in_4bit can't be both False") + bnb_module.weight.data = module.weight.data + if module.bias is not None: + bnb_module.bias.data = module.bias.data + bnb_module.requires_grad_(False) + setattr(model, name, bnb_module) + has_been_replaced = True + if len(list(module.children())) > 0: + _, _has_been_replaced = _replace_with_bnb_layers( + module, bnb_quantization_config, modules_to_not_convert, current_key_name + ) + has_been_replaced = has_been_replaced | _has_been_replaced + # Remove the last key for recursion + current_key_name.pop(-1) + return model, has_been_replaced + + +def get_keys_to_not_convert(model): + r""" + An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + # Create a copy of the model + with init_empty_weights(): + tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + + tied_params = find_tied_parameters(tied_model) + # For compatibility with Accelerate < 0.18 + if isinstance(tied_params, dict): + tied_keys = sum(list(tied_params.values()), []) + list(tied_params.keys()) + else: + tied_keys = sum(tied_params, []) + has_tied_params = len(tied_keys) > 0 + + # Check if it is a base model + is_base_model = False + if hasattr(model, "base_model_prefix"): + is_base_model = not hasattr(model, model.base_model_prefix) + + # Ignore this for base models (BertModel, GPT2Model, etc.) 
+ if (not has_tied_params) and is_base_model: + return [] + + # otherwise they have an attached head + list_modules = list(model.named_children()) + list_last_module = [list_modules[-1][0]] + + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = list(set(tied_keys)) + list(intersection) + + # remove ".weight" from the keys + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + + +def has_4bit_bnb_layers(model): + """Check if we have `bnb.nn.Linear4bit` or `bnb.nn.Linear8bitLt` layers inside our model""" + # bitsandbytes will initialize CUDA on import, so it needs to be imported lazily + import bitsandbytes as bnb + + for m in model.modules(): + if isinstance(m, bnb.nn.Linear4bit): + return True + return False + + +def get_parameter_device(parameter: nn.Module): + return next(parameter.parameters()).device + + +def quantize_and_offload_8bit(model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics): + # if it is not quantized, we quantize and offload the quantized weights and the SCB stats + if fp16_statistics is None: + set_module_tensor_to_device(model, param_name, 0, dtype=new_dtype, value=param) + tensor_name = param_name + module = model + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + # offload weights + module._parameters[tensor_name].requires_grad = False + offload_weight(module._parameters[tensor_name], param_name, offload_folder, index=offload_index) + if hasattr(module._parameters[tensor_name], "SCB"): + offload_weight( + module._parameters[tensor_name].SCB, + param_name.replace("weight", "SCB"), + offload_folder, + index=offload_index, + ) + else: + offload_weight(param, param_name, offload_folder, index=offload_index) + offload_weight(fp16_statistics, param_name.replace("weight", "SCB"), offload_folder, index=offload_index) + + set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype, value=torch.empty(*param.size())) diff --git a/llm/Lib/site-packages/accelerate/utils/constants.py b/llm/Lib/site-packages/accelerate/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..8c299570757cb6a5df93f4794e403d1581dd7c2e --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/constants.py @@ -0,0 +1,72 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
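As a point of reference, the following is a minimal, hypothetical usage sketch for the `load_and_quantize_model` helper added in the `bnb.py` hunk above. The model id, checkpoint folder, and dtype choices are illustrative assumptions, and it presumes `transformers` and `bitsandbytes` are installed in the same environment; it is an editorial annotation, not part of the vendored file.

```python
import torch
from accelerate import init_empty_weights
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model
from transformers import AutoConfig, AutoModelForCausalLM  # assumed extra dependency

# Build the model skeleton on the meta device, as the docstring recommends,
# so weights are loaded, quantized and dispatched in a single pass.
config = AutoConfig.from_pretrained("some-org/some-model")  # hypothetical model id
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

bnb_config = BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
quantized_model = load_and_quantize_model(
    model,
    bnb_quantization_config=bnb_config,
    weights_location="path/to/checkpoint",  # folder with shards or a single state dict
    device_map="auto",
)
```

Because the skeleton lives on the meta device, this takes the `weights_location` branch of the function above: `nn.Linear` layers are swapped for `bitsandbytes` modules, a device map is inferred, the checkpoint is loaded, and the result is dispatched (with CPU/disk offload if the map requires it).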
+ +import operator as op + + +SCALER_NAME = "scaler.pt" +MODEL_NAME = "pytorch_model" +SAFE_MODEL_NAME = "model" +RNG_STATE_NAME = "random_states" +OPTIMIZER_NAME = "optimizer" +SCHEDULER_NAME = "scheduler" +SAMPLER_NAME = "sampler" +WEIGHTS_NAME = f"{MODEL_NAME}.bin" +WEIGHTS_INDEX_NAME = f"{WEIGHTS_NAME}.index.json" +SAFE_WEIGHTS_NAME = f"{SAFE_MODEL_NAME}.safetensors" +SAFE_WEIGHTS_INDEX_NAME = f"{SAFE_WEIGHTS_NAME}.index.json" +SAGEMAKER_PYTORCH_VERSION = "1.10.2" +SAGEMAKER_PYTHON_VERSION = "py38" +SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0" +SAGEMAKER_PARALLEL_EC2_INSTANCES = ["ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4dn.24xlarge"] +FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"] +FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"] +FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"] +FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"] +FSDP_PYTORCH_VERSION = "2.1.0" +FSDP_MODEL_NAME = "pytorch_model_fsdp" +DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich"] +TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"] + +STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt} + +# These are the args for `torch.distributed.launch` for pytorch < 1.9 +TORCH_LAUNCH_PARAMS = [ + "nnodes", + "nproc_per_node", + "rdzv_backend", + "rdzv_endpoint", + "rdzv_id", + "rdzv_conf", + "standalone", + "max_restarts", + "monitor_interval", + "start_method", + "role", + "module", + "m", + "no_python", + "run_path", + "log_dir", + "r", + "redirects", + "t", + "tee", + "node_rank", + "master_addr", + "master_port", +] + +CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM"] +TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + ["MULTI_NPU", "MULTI_MLU", "MULTI_XPU", "MULTI_CPU"] diff --git a/llm/Lib/site-packages/accelerate/utils/dataclasses.py b/llm/Lib/site-packages/accelerate/utils/dataclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..63e8a3a49f7c765214d53c30a279097671c8838e --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/dataclasses.py @@ -0,0 +1,1717 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
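The `STR_OPERATION_TO_FUNC` table defined in `constants.py` above maps operator strings to functions from the `operator` module so that version requirements can be expressed as plain strings (this mirrors how version-comparison helpers elsewhere in the package consume it). A small self-contained sketch, assuming the `packaging` library is available:

```python
import operator as op
from packaging import version  # assumed to be installed

# Same table as in constants.py: operator strings -> comparison functions.
STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}

def satisfies(installed: str, operation: str, required: str) -> bool:
    """Apply a string-encoded comparison to two version strings."""
    compare = STR_OPERATION_TO_FUNC[operation]
    return compare(version.parse(installed), version.parse(required))

# e.g. checking the FSDP_PYTORCH_VERSION floor declared above.
print(satisfies("2.2.1", ">=", "2.1.0"))  # True
```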
+ +""" +General namespace and dataclass related classes +""" + +import argparse +import copy +import enum +import functools +import os +import typing +import warnings +from contextlib import contextmanager +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any, Callable, Dict, Iterable, List, Literal, Optional, Tuple, get_args + +import torch + +from .constants import FSDP_AUTO_WRAP_POLICY, FSDP_BACKWARD_PREFETCH, FSDP_SHARDING_STRATEGY, FSDP_STATE_DICT_TYPE +from .environment import str_to_bool +from .imports import is_cuda_available, is_npu_available, is_xpu_available +from .versions import compare_versions + + +class KwargsHandler: + """ + Internal mixin that implements a `to_kwargs()` method for a dataclass. + """ + + def to_dict(self): + return copy.deepcopy(self.__dict__) + + def to_kwargs(self): + """ + Returns a dictionary containing the attributes with values different from the default of this class. + """ + # import clear_environment here to avoid circular import problem + from .other import clear_environment + + with clear_environment(): + default_dict = self.__class__().to_dict() + this_dict = self.to_dict() + return {k: v for k, v in this_dict.items() if default_dict[k] != v} + + +@dataclass +class AutocastKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize how `torch.autocast` behaves. Please refer to the + documentation of this [context manager](https://pytorch.org/docs/stable/amp.html#torch.autocast) for more + information on each argument. + + Example: + + ```python + from accelerate import Accelerator + from accelerate.utils import AutocastKwargs + + kwargs = AutocastKwargs(cache_enabled=True) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + """ + + enabled: bool = True + cache_enabled: bool = None + + +@dataclass +class DistributedDataParallelKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize how your model is wrapped in a + `torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this + [wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more + information on each argument. + + + + `gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions. + + `static_graph` is only available in PyTorch 1.11.0 and later versions. + + + + Example: + + ```python + from accelerate import Accelerator + from accelerate.utils import DistributedDataParallelKwargs + + kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + """ + + dim: int = 0 + broadcast_buffers: bool = True + bucket_cap_mb: int = 25 + find_unused_parameters: bool = False + check_reduction: bool = False + gradient_as_bucket_view: bool = False + static_graph: bool = False + + +@dataclass +class GradScalerKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the + `torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this + [scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument. + + + + `GradScaler` is only available in PyTorch 1.5.0 and later versions. 
+ + + + Example: + + ```python + from accelerate import Accelerator + from accelerate.utils import GradScalerKwargs + + kwargs = GradScalerKwargs(backoff_filter=0.25) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + """ + + init_scale: float = 65536.0 + growth_factor: float = 2.0 + backoff_factor: float = 0.5 + growth_interval: int = 2000 + enabled: bool = True + + +@dataclass +class InitProcessGroupKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer + to the documentation of this + [method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more + information on each argument. + + ```python + from datetime import timedelta + from accelerate import Accelerator + from accelerate.utils import InitProcessGroupKwargs + + kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=800)) + accelerator = Accelerator(kwargs_handlers=[kwargs]) + ``` + """ + + backend: Optional[str] = "nccl" + init_method: Optional[str] = None + timeout: timedelta = timedelta(seconds=1800) + + +# Literals +Backend = Literal["MSAMP", "TE"] +OptLevel = Literal["O1", "O2"] +FP8Format = Literal["E4M3", "HYBRID"] +AmaxComputeAlgorithm = Literal["max", "most_recent"] + + +@dataclass +class FP8RecipeKwargs(KwargsHandler): + """ + Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision + training with `transformer-engine` or `ms-amp`. + + + + For more information on `transformer-engine` args, please refer to the API + [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html). + + For more information on the `ms-amp` args, please refer to the Optimization Level + [documentation](https://azure.github.io/MS-AMP/docs/user-tutorial/optimization-level). + + + + ```python + from accelerate import Accelerator + from accelerate.utils import FP8RecipeKwargs + + kwargs = FP8RecipeKwargs(backend="te", fp8_format="HYBRID") + accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs]) + ``` + + To use MS-AMP as an engine, pass `backend="msamp"` and the `optimization_level`: + + ```python + kwargs = FP8RecipeKwargs(backend="msamp", optimization_level="02") + ``` + + Args: + backend (`str`, *optional*, defaults to "msamp"): + Which FP8 engine to use. Must be one of `"msamp"` (MS-AMP) or `"te"` (TransformerEngine). + margin (`int`, *optional*, default to 0): + The margin to use for the gradient scaling. + interval (`int`, *optional*, default to 1): + The interval to use for how often the scaling factor is recomputed. + fp8_format (`str`, *optional*, default to "E4M3"): + The format to use for the FP8 recipe. Must be one of `E4M3` or `HYBRID`. + amax_history_len (`int`, *optional*, default to 1024): + The length of the history to use for the scaling factor computation + amax_compute_algo (`str`, *optional*, default to "most_recent"): + The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`. + override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`): + Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. + optimization_level (`str`), one of `O1`, `O2`. (default is `O2`): + What level of 8-bit collective communication should be used with MS-AMP. 
In general: + * O1: Weight gradients and `all_reduce` communications are done in fp8, reducing GPU + memory usage and communication bandwidth + * O2: First-order optimizer states are in 8-bit, and second order states are in FP16. + Only available when using Adam or AdamW. This maintains accuracy and can potentially save the + highest memory. + * 03: Specifically for DeepSpeed, implements capabilities so weights and master weights of models + are stored in FP8. If `fp8` is selected and deepspeed is enabled, will be used by default. (Not + available currently). + """ + + backend: Backend = "MSAMP" + opt_level: OptLevel = "O2" + margin: int = 0 + interval: int = 1 + fp8_format: FP8Format = "E4M3" + amax_history_len: int = 1 + amax_compute_algo: AmaxComputeAlgorithm = "most_recent" + override_linear_precision: Tuple[bool, bool, bool] = (False, False, False) + + def __post_init__(self): + if self.backend.upper() not in get_args(Backend): + raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine).") + + self.backend = self.backend.upper() + # Check TE args + if self.backend == "TE": + self.fp8_format = self.fp8_format.upper() + if self.fp8_format not in get_args(FP8Format): + raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") + if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): + raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") + elif self.backend == "MSAMP": + if self.opt_level not in get_args(OptLevel): + raise ValueError(f"`optimization_level` must be one of {' or '.join(get_args(OptLevel))}") + + +class EnumWithContains(enum.EnumMeta): + "A metaclass that adds the ability to check if `self` contains an item with the `in` operator" + + def __contains__(cls, item): + try: + cls(item) + except ValueError: + return False + return True + + +class BaseEnum(enum.Enum, metaclass=EnumWithContains): + "An enum class that can get the value of an item with `str(Enum.key)`" + + def __str__(self): + return self.value + + @classmethod + def list(cls): + "Method to list all the possible items in `cls`" + return list(map(str, cls)) + + +class DeprecatedFieldDescriptor: + """ + Descriptor for deprecated fields in an enum class. + + Args: + field_name (`str`): + The name of the deprecated field. + replaced_with (`str`): + The name of the field that replaces the deprecated one. + """ + + def __init__(self, field_name, replaced_with): + self.field_name = field_name + self.replaced_with = replaced_with + + def __get__(self, instance, owner): + warnings.warn( + f"The `{self.field_name}` of `{owner}` is deprecated and will be removed in v1.0.0. " + f"Please use the `{self.replaced_with}` instead.", + FutureWarning, + ) + return getattr(owner, self.replaced_with) + + +class DistributedType(str, enum.Enum): + """ + Represents a type of distributed environment. + + Values: + + - **NO** -- Not a distributed environment, just a single process. + - **MULTI_CPU** -- Distributed on multiple CPU nodes. + - **MULTI_GPU** -- Distributed on multiple GPUs. + - **MULTI_MLU** -- Distributed on multiple MLUs. + - **MULTI_NPU** -- Distributed on multiple NPUs. + - **MULTI_XPU** -- Distributed on multiple XPUs. + - **DEEPSPEED** -- Using DeepSpeed. + - **XLA** -- Using TorchXLA. + - **TPU** -- This field will be deprecated in v0.27.0. Use XLA instead. + """ + + # Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box. 
+ NO = "NO" + MULTI_CPU = "MULTI_CPU" + MULTI_GPU = "MULTI_GPU" + MULTI_NPU = "MULTI_NPU" + MULTI_MLU = "MULTI_MLU" + MULTI_XPU = "MULTI_XPU" + DEEPSPEED = "DEEPSPEED" + FSDP = "FSDP" + XLA = "XLA" + MEGATRON_LM = "MEGATRON_LM" + TPU = DeprecatedFieldDescriptor("TPU", "XLA") + + +class SageMakerDistributedType(str, enum.Enum): + """ + Represents a type of distributed environment. + + Values: + + - **NO** -- Not a distributed environment, just a single process. + - **DATA_PARALLEL** -- using sagemaker distributed data parallelism. + - **MODEL_PARALLEL** -- using sagemaker distributed model parallelism. + """ + + # Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box. + NO = "NO" + DATA_PARALLEL = "DATA_PARALLEL" + MODEL_PARALLEL = "MODEL_PARALLEL" + + +class ComputeEnvironment(str, enum.Enum): + """ + Represents a type of the compute environment. + + Values: + + - **LOCAL_MACHINE** -- private/custom cluster hardware. + - **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment. + """ + + # Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box. + LOCAL_MACHINE = "LOCAL_MACHINE" + AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER" + + +class DynamoBackend(str, BaseEnum): + """ + Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html). + + Values: + + - **NO** -- Do not use torch dynamo. + - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo + issues. + - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's + extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups. + - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton + kernels. [Read + more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747) + - **AOT_TS_NVFUSER** -- nvFuser with AotAutograd/TorchScript. [Read + more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) + - **NVPRIMS_NVFUSER** -- nvFuser with PrimTorch. [Read + more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) + - **CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757) + - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read + more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html) + - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read + more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst) + - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/) + - **TENSORRT** -- Uses ONNXRT to run TensorRT for inference optimizations. [Read + more](https://github.com/onnx/onnx-tensorrt) + - **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read + more](https://github.com/intel/intel-extension-for-pytorch). + - **TVM** -- Uses Apach TVM for inference optimizations. [Read more](https://tvm.apache.org/) + + """ + + # Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box. 
+ NO = "NO" + EAGER = "EAGER" + AOT_EAGER = "AOT_EAGER" + INDUCTOR = "INDUCTOR" + AOT_TS_NVFUSER = "AOT_TS_NVFUSER" + NVPRIMS_NVFUSER = "NVPRIMS_NVFUSER" + CUDAGRAPHS = "CUDAGRAPHS" + OFI = "OFI" + FX2TRT = "FX2TRT" + ONNXRT = "ONNXRT" + TENSORRT = "TENSORRT" + IPEX = "IPEX" + TVM = "TVM" + + +class LoggerType(BaseEnum): + """Represents a type of supported experiment tracker + + Values: + + - **ALL** -- all available trackers in the environment that are supported + - **TENSORBOARD** -- TensorBoard as an experiment tracker + - **WANDB** -- wandb as an experiment tracker + - **COMETML** -- comet_ml as an experiment tracker + - **DVCLIVE** -- dvclive as an experiment tracker + """ + + ALL = "all" + AIM = "aim" + TENSORBOARD = "tensorboard" + WANDB = "wandb" + COMETML = "comet_ml" + MLFLOW = "mlflow" + CLEARML = "clearml" + DVCLIVE = "dvclive" + + +class PrecisionType(BaseEnum): + """Represents a type of precision used on floating point values + + Values: + + - **NO** -- using full precision (FP32) + - **FP16** -- using half precision + - **BF16** -- using brain floating point precision + """ + + NO = "no" + FP8 = "fp8" + FP16 = "fp16" + BF16 = "bf16" + + +class RNGType(BaseEnum): + TORCH = "torch" + CUDA = "cuda" + MLU = "mlu" + NPU = "npu" + XLA = "xla" + XPU = "xpu" + GENERATOR = "generator" + + +class CustomDtype(enum.Enum): + r""" + An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`. + """ + + FP8 = "fp8" + INT4 = "int4" + INT2 = "int2" + + +# data classes + + +@dataclass +class TensorInformation: + shape: torch.Size + dtype: torch.dtype + + +@dataclass +class DataLoaderConfiguration: + """ + Configuration for dataloader-related items when calling `accelerator.prepare`. + """ + + split_batches: bool = field( + default=False, + metadata={ + "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If" + " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a" + " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set" + " in your script multiplied by the number of processes." + }, + ) + dispatch_batches: bool = field( + default=None, + metadata={ + "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" + " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" + " underlying dataset is an `IterableDataslet`, `False` otherwise." + }, + ) + even_batches: bool = field( + default=True, + metadata={ + "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the" + " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among" + " all workers." + }, + ) + use_seedable_sampler: bool = field( + default=False, + metadata={ + "help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])." + "Ensures training results are fully reproducable using a different sampling technique. " + "While seed-to-seed results may differ, on average the differences are neglible when using" + "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results." + }, + ) + + +@dataclass +class ProjectConfiguration: + """ + Configuration for the Accelerator object based on inner-project needs. 
+ """ + + project_dir: str = field(default=None, metadata={"help": "A path to a directory for storing data."}) + logging_dir: str = field( + default=None, + metadata={ + "help": "A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`." + }, + ) + automatic_checkpoint_naming: bool = field( + default=False, + metadata={"help": "Whether saved states should be automatically iteratively named."}, + ) + + total_limit: int = field( + default=None, + metadata={"help": "The maximum number of total saved states to keep."}, + ) + + iteration: int = field( + default=0, + metadata={"help": "The current save iteration."}, + ) + + save_on_each_node: bool = field( + default=False, + metadata={ + "help": ( + "When doing multi-node distributed training, whether to save models and checkpoints on each node, or" + " only on the main one" + ) + }, + ) + + def set_directories(self, project_dir: str = None): + "Sets `self.project_dir` and `self.logging_dir` to the appropriate values." + self.project_dir = project_dir + if self.logging_dir is None: + self.logging_dir = project_dir + + def __post_init__(self): + self.set_directories(self.project_dir) + + +@dataclass +class GradientAccumulationPlugin(KwargsHandler): + """ + A plugin to configure gradient accumulation behavior. You can only pass one of `gradient_accumulation_plugin` or + `gradient_accumulation_steps` to [`Accelerator`]. Passing both raises an error. + + Parameters: + num_steps (`int`): + The number of steps to accumulate gradients for. + adjust_scheduler (`bool`, *optional*, defaults to `True`): + Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be + `True` if the used scheduler was not adjusted for gradient accumulation. + sync_with_dataloader (`bool`, *optional*, defaults to `True`): + Whether to synchronize setting the gradients when at the end of the dataloader. + sync_each_batch (`bool`, *optional*): + Whether to synchronize setting the gradients at each data batch. Seting to `True` may reduce memory + requirements when using gradient accumulation with distributed training, at expense of speed. + + Example: + + ```python + from accelerate.utils import GradientAccumulationPlugin + + gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2) + accelerator = Accelerator(gradient_accumulation_plugin=gradient_accumulation_plugin) + ``` + """ + + num_steps: int = field(default=None, metadata={"help": "The number of steps to accumulate gradients for."}) + adjust_scheduler: bool = field( + default=True, + metadata={ + "help": "Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be `True` if the used scheduler was not adjusted for gradient accumulation." + }, + ) + sync_with_dataloader: bool = field( + default=True, + metadata={ + "help": "Whether to synchronize setting the gradients when at the end of the dataloader. Should only be set to `False` if you know what you're doing." + }, + ) + sync_each_batch: bool = field( + default=False, + metadata={ + "help": "Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory requirements when using gradient accumulation with distributed training, at expense of speed." 
+ }, + ) + + +@dataclass +class TorchDynamoPlugin(KwargsHandler): + """ + This plugin is used to compile a model with PyTorch 2.0 + """ + + backend: DynamoBackend = field( + default=None, + metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"}, + ) + mode: str = field( + default=None, metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"} + ) + fullgraph: bool = field(default=None, metadata={"help": "Whether it is ok to break model into several subgraphs"}) + dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"}) + options: Any = field(default=None, metadata={"help": "A dictionary of options to pass to the backend."}) + disable: bool = field(default=False, metadata={"help": "Turn torch.compile() into a no-op for testing"}) + + def __post_init__(self): + prefix = "ACCELERATE_DYNAMO_" + if self.backend is None: + self.backend = os.environ.get(prefix + "BACKEND", "no") + self.backend = DynamoBackend(self.backend.upper()) + if self.mode is None: + self.mode = os.environ.get(prefix + "MODE", "default") + if self.fullgraph is None: + self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1 + if self.dynamic is None: + self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1 + + def to_dict(self): + dynamo_config = copy.deepcopy(self.__dict__) + dynamo_config["backend"] = dynamo_config["backend"].value.lower() + return dynamo_config + + +@dataclass +class DeepSpeedPlugin: + """ + This plugin is used to integrate DeepSpeed. + """ + + hf_ds_config: Any = field( + default=None, + metadata={ + "help": "path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`." + }, + ) + gradient_accumulation_steps: int = field( + default=None, + metadata={ + "help": "Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the `Accelerator` directly." + }, + ) + gradient_clipping: float = field(default=None, metadata={"help": "Enable gradient clipping with value"}) + zero_stage: int = field( + default=None, + metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"}, + ) + is_train_batch_min: str = field( + default=True, + metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"}, + ) + offload_optimizer_device: bool = field( + default=None, + metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."}, + ) + offload_param_device: bool = field( + default=None, + metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."}, + ) + offload_optimizer_nvme_path: str = field( + default=None, + metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."}, + ) + offload_param_nvme_path: str = field( + default=None, + metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."}, + ) + zero3_init_flag: bool = field( + default=None, + metadata={ + "help": "Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models." + "Only applicable with ZeRO Stage-3." + }, + ) + zero3_save_16bit_model: bool = field( + default=None, + metadata={"help": "Flag to indicate whether to save 16-bit model. 
Only applicable with ZeRO Stage-3."}, + ) + + def __post_init__(self): + from .deepspeed import HfDeepSpeedConfig + + if self.gradient_accumulation_steps is None: + gas = os.environ.get("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", "auto") + self.gradient_accumulation_steps = int(gas) if gas.isdigit() else gas + + if self.gradient_clipping is None: + gradient_clipping = os.environ.get("ACCELERATE_GRADIENT_CLIPPING", "none") + if gradient_clipping != "none": + self.gradient_clipping = float(gradient_clipping) + + if self.zero_stage is None: + self.zero_stage = int(os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", 2)) + + if self.offload_optimizer_device is None: + self.offload_optimizer_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none") + + if self.offload_param_device is None: + self.offload_param_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", "none") + + if self.offload_optimizer_nvme_path is None: + self.offload_optimizer_nvme_path = os.environ.get( + "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", "none" + ) + + if self.offload_param_nvme_path is None: + self.offload_param_nvme_path = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", "none") + + if self.zero3_save_16bit_model is None: + self.zero3_save_16bit_model = ( + os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true" + ) + + if self.hf_ds_config is None: + self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none") + if ( + isinstance(self.hf_ds_config, dict) + or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != "none") + or isinstance(self.hf_ds_config, HfDeepSpeedConfig) + ): + if not isinstance(self.hf_ds_config, HfDeepSpeedConfig): + self.hf_ds_config = HfDeepSpeedConfig(self.hf_ds_config) + if "gradient_accumulation_steps" not in self.hf_ds_config.config: + self.hf_ds_config.config["gradient_accumulation_steps"] = 1 + if "zero_optimization" not in self.hf_ds_config.config: + raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config.") + + self._deepspeed_config_checks() + plugin_to_config_mapping = { + "gradient_accumulation_steps": "gradient_accumulation_steps", + "gradient_clipping": "gradient_clipping", + "zero_stage": "zero_optimization.stage", + "offload_optimizer_device": "zero_optimization.offload_optimizer.device", + "offload_param_device": "zero_optimization.offload_param.device", + "offload_param_nvme_path": "zero_optimization.offload_param.nvme_path", + "offload_optimizer_nvme_path": "zero_optimization.offload_optimizer.nvme_path", + "zero3_save_16bit_model": "zero_optimization.stage3_gather_16bit_weights_on_model_save", + } + kwargs = {v: getattr(self, k) for k, v in plugin_to_config_mapping.items() if getattr(self, k) is not None} + for key in kwargs.keys(): + self.fill_match(key, **kwargs, must_match=False) + self.hf_ds_config.set_stage_and_offload() + + # filling the missing values in the class attributes from the DeepSpeed config + # when using the DeepSpeed config file. 
+ for key, value in plugin_to_config_mapping.items(): + config_value = self.hf_ds_config.get_value(value) + if config_value is not None and config_value != "auto": + setattr(self, key, config_value) + else: + config = { + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "gradient_accumulation_steps": self.gradient_accumulation_steps, + "zero_optimization": { + "stage": self.zero_stage, + "offload_optimizer": { + "device": self.offload_optimizer_device, + "nvme_path": self.offload_optimizer_nvme_path + if self.offload_optimizer_device == "nvme" + else None, + }, + "offload_param": { + "device": self.offload_param_device, + "nvme_path": self.offload_param_nvme_path if self.offload_param_device == "nvme" else None, + }, + "stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model, + }, + } + if self.gradient_clipping: + config["gradient_clipping"] = self.gradient_clipping + self.hf_ds_config = HfDeepSpeedConfig(config) + + self.deepspeed_config = self.hf_ds_config.config + self.deepspeed_config["steps_per_print"] = float("inf") # this will stop deepspeed from logging @ stdout + if self.zero3_init_flag is None: + self.zero3_init_flag = ( + str_to_bool(os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", str(self.hf_ds_config.is_zero3()))) == 1 + ) + if self.zero3_init_flag and not self.hf_ds_config.is_zero3(): + warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.") + self.zero3_init_flag = False + + def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs): + mismatches = [] if mismatches is None else mismatches + config, ds_key = self.hf_ds_config.find_config_node(ds_key_long) + if config is None: + return + + if config.get(ds_key) == "auto": + if ds_key_long in kwargs: + config[ds_key] = kwargs[ds_key_long] + return + else: + raise ValueError( + f"`{ds_key_long}` not found in kwargs. " + f"Please specify `{ds_key_long}` without `auto` (set to correct value) in the DeepSpeed config file or " + "pass it in kwargs." + ) + + if not must_match: + return + + ds_val = config.get(ds_key) + if ds_val is not None and ds_key_long in kwargs: + if ds_val != kwargs[ds_key_long]: + mismatches.append(f"- ds {ds_key_long}={ds_val} vs arg {ds_key_long}={kwargs[ds_key_long]}") + + def is_auto(self, ds_key_long): + val = self.hf_ds_config.get_value(ds_key_long) + if val is None: + return False + else: + return val == "auto" + + def get_value(self, ds_key_long, default=None): + return self.hf_ds_config.get_value(ds_key_long, default) + + def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must_match=True, **kwargs): + """Process the DeepSpeed config with the values from the kwargs.""" + mismatches = [] if mismatches is None else mismatches + if config is None: + config = self.deepspeed_config + for key, value in config.items(): + if isinstance(value, dict): + self.deepspeed_config_process( + prefix=prefix + key + ".", mismatches=mismatches, config=value, must_match=must_match, **kwargs + ) + else: + self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs) + if len(mismatches) > 0 and prefix == "": + mismatches_msg = "\n".join(mismatches) + raise ValueError( + "Please correct the following DeepSpeed config values that mismatch kwargs " + f" values:\n{mismatches_msg}\nThe easiest method is to set these DeepSpeed config values to 'auto'." 
+ ) + + def set_mixed_precision(self, mixed_precision): + ds_config = self.deepspeed_config + kwargs = { + "fp16.enabled": mixed_precision == "fp16", + "bf16.enabled": mixed_precision == "bf16", + } + if mixed_precision == "fp16": + if "fp16" not in ds_config: + ds_config["fp16"] = {"enabled": True, "auto_cast": True} + elif mixed_precision == "bf16": + if "bf16" not in ds_config: + ds_config["bf16"] = {"enabled": True} + + if mixed_precision != "no": + diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16" + if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true": + raise ValueError( + f"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{diff_dtype}` is set in the DeepSpeed config file." + ) + for dtype in ["fp16", "bf16"]: + if dtype not in ds_config: + ds_config[dtype] = {"enabled": False} + self.fill_match("fp16.enabled", must_match=False, **kwargs) + self.fill_match("bf16.enabled", must_match=False, **kwargs) + + def set_deepspeed_weakref(self): + from .imports import is_transformers_available + + if self.zero3_init_flag: + if not is_transformers_available(): + raise Exception( + "When `zero3_init_flag` is set, it requires Transformers to be installed. " + "Please run `pip install transformers`." + ) + ds_config = copy.deepcopy(self.deepspeed_config) + if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto": + ds_config["gradient_accumulation_steps"] = 1 + if ( + "train_micro_batch_size_per_gpu" not in ds_config + or ds_config["train_micro_batch_size_per_gpu"] == "auto" + ): + ds_config["train_micro_batch_size_per_gpu"] = 1 + if ds_config.get("train_batch_size", None) == "auto": + del ds_config["train_batch_size"] + + if compare_versions("transformers", "<", "4.33"): + from transformers.deepspeed import HfDeepSpeedConfig + else: + from transformers.integrations import HfDeepSpeedConfig + + self.dschf = HfDeepSpeedConfig(ds_config) # keep this object alive # noqa + + def is_zero3_init_enabled(self): + return self.zero3_init_flag + + @contextmanager + def zero3_init_context_manager(self, enable=False): + old = self.zero3_init_flag + if old == enable: + yield + else: + self.zero3_init_flag = enable + self.dschf = None + self.set_deepspeed_weakref() + yield + self.zero3_init_flag = old + self.dschf = None + self.set_deepspeed_weakref() + + def _deepspeed_config_checks(self): + env_variable_names_to_ignore = [ + "ACCELERATE_GRADIENT_ACCUMULATION_STEPS", + "ACCELERATE_GRADIENT_CLIPPING", + "ACCELERATE_DEEPSPEED_ZERO_STAGE", + "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", + "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", + "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", + "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", + "ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", + "ACCELERATE_MIXED_PRECISION", + ] + env_variable_names_to_ignore = [ + name.replace("ACCELERATE_", "").replace("DEEPSPEED_", "").lower() for name in env_variable_names_to_ignore + ] + + deepspeed_fields_from_accelerate_config = os.environ.get("ACCELERATE_CONFIG_DS_FIELDS", "").split(",") + + if any(name in env_variable_names_to_ignore for name in deepspeed_fields_from_accelerate_config): + raise ValueError( + f"When using `deepspeed_config_file`, the following accelerate config variables will be ignored: {env_variable_names_to_ignore}.\n" + "Please specify them appropriately in the DeepSpeed config file.\n" + "If you are using an accelerate config file, remove others config variables mentioned in the above specified list.\n" + 
"The easiest method is to create a new config following the questionnaire via `accelerate config`.\n" + "It will only ask for the necessary config variables when using `deepspeed_config_file`." + ) + + +@dataclass +class FullyShardedDataParallelPlugin: + """ + This plugin is used to enable fully sharded data parallelism. + """ + + sharding_strategy: "typing.Any" = field( + default=None, + metadata={ + "help": "FSDP Sharding Strategy of type `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`" + }, + ) + backward_prefetch: "typing.Any" = field( + default=None, + metadata={ + "help": "FSDP Backward Prefetch of type `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`" + }, + ) + mixed_precision_policy: "typing.Any" = field( + default=None, + metadata={ + "help": "A config to enable mixed precision training with FullyShardedDataParallel. " + "The 3 flags that are set are `param_dtype`, `reduce_dtype`, `buffer_dtype`. " + "Each flag expects `torch.dtype` as the value. " + "It is of type `torch.distributed.fsdp.fully_sharded_data_parallel.MixedPrecision`." + }, + ) + auto_wrap_policy: Optional[Callable] = field( + default=None, + metadata={"help": "A callable specifying a policy to recursively wrap layers with FSDP"}, + ) + cpu_offload: "typing.Any" = field( + default=None, + metadata={ + "help": "Decides Whether to offload parameters and gradients to CPU. " + "It is of type `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload`." + }, + ) + ignored_modules: Optional[Iterable[torch.nn.Module]] = field( + default=None, + metadata={"help": "A list of modules to ignore for FSDP."}, + ) + state_dict_type: "typing.Any" = field( + default=None, + metadata={ + "help": "FSDP State Dict Type of type `torch.distributed.fsdp.fully_sharded_data_parallel.StateDictType`" + }, + ) + state_dict_config: "typing.Any" = field( + default=None, + metadata={ + "help": "FSDP State Dict Config of type `torch.distributed.fsdp.fully_sharded_data_parallel.StateDictConfig`" + }, + ) + optim_state_dict_config: "typing.Any" = field( + default=None, + metadata={ + "help": "FSDP Optimizer State Dict Config of type `torch.distributed.fsdp.fully_sharded_data_parallel.OptimStateDictConfig`" + }, + ) + limit_all_gathers: bool = field( + default=True, + metadata={ + "help": "If False, then FSDP allows the CPU thread to schedule all-gathers " + "without any extra synchronization. If True, then FSDP explicitly synchronizes the CPU thread to prevent " + "too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. " + "Enabling this can help lower the number of CUDA malloc retries." + }, + ) + use_orig_params: bool = field( + default=True, + metadata={ + "help": "If `True`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters. " + "Useful in cases such as parameter-efficient fine-tuning. " + "Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019). " + "This also enables multiple optimizer param groups. This should be `True` when creating an optimizer object before preparing/wrapping the model with FSDP." + }, + ) + param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field( + default=None, + metadata={ + "help": "A Callable[torch.nn.Module] -> None that specifies how modules " + "that are currently on the meta device should be initialized onto an actual device." 
+ }, + ) + sync_module_states: bool = field( + default=True, + metadata={ + "help": "If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0 " + "to ensure they are the same across all ranks after initialization" + }, + ) + forward_prefetch: bool = field( + default=False, + metadata={ + "help": "If True, then FSDP explicitly prefetches the next upcoming " + "all-gather while executing in the forward pass. only use with Static graphs." + }, + ) + activation_checkpointing: bool = field( + default=False, + metadata={ + "help": "If True, activation checkpointing is a technique to reduce memory usage by clearing activations of " + "certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time " + "for reduced memory usage." + }, + ) + + def __post_init__(self): + from torch.distributed.fsdp.fully_sharded_data_parallel import BackwardPrefetch, CPUOffload, ShardingStrategy + + prefix = "FSDP_" + if self.sharding_strategy is None: + sharding_strategy = os.environ.get(prefix + "SHARDING_STRATEGY", "FULL_SHARD") + sharding_strategy = ( + FSDP_SHARDING_STRATEGY.index(sharding_strategy) + 1 + if not sharding_strategy.isdigit() + else int(sharding_strategy) + ) + self.sharding_strategy = ShardingStrategy(sharding_strategy) + + if self.cpu_offload is None: + if str_to_bool(os.environ.get(prefix + "OFFLOAD_PARAMS", "False")) == 1: + self.cpu_offload = CPUOffload(offload_params=True) + else: + self.cpu_offload = CPUOffload(offload_params=False) + + if self.backward_prefetch is None: + prefetch_policy = os.environ.get(prefix + "BACKWARD_PREFETCH", "NO_PREFETCH") + if prefetch_policy != FSDP_BACKWARD_PREFETCH[-1]: + self.backward_prefetch = BackwardPrefetch(FSDP_BACKWARD_PREFETCH.index(prefetch_policy) + 1) + + if self.state_dict_type is None: + state_dict_type_policy = os.environ.get(prefix + "STATE_DICT_TYPE", "FULL_STATE_DICT") + self.set_state_dict_type(state_dict_type_policy) + self.use_orig_params = str_to_bool(os.environ.get(prefix + "USE_ORIG_PARAMS", "False")) == 1 + self.sync_module_states = str_to_bool(os.environ.get(prefix + "SYNC_MODULE_STATES", "True")) == 1 + self.forward_prefetch = str_to_bool(os.environ.get(prefix + "FORWARD_PREFETCH", "False")) == 1 + self.activation_checkpointing = str_to_bool(os.environ.get(prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1 + + if self.sync_module_states: + if is_npu_available(): + device = torch.npu.current_device() + elif is_cuda_available(): + device = torch.cuda.current_device() + elif is_xpu_available(): + device = torch.xpu.current_device() + else: + raise RuntimeError( + "There are currently no available devices found, must be one of 'XPU', 'CUDA', or 'NPU'." + ) + self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False) + + @staticmethod + def get_module_class_from_name(module, name): + """ + Gets a class from a module by its name. + + Args: + module (`torch.nn.Module`): The module to get the class from. + name (`str`): The name of the class. 
+ """ + modules_children = list(module.children()) + if module.__class__.__name__ == name: + return module.__class__ + elif len(modules_children) == 0: + return + else: + for child_module in modules_children: + module_class = FullyShardedDataParallelPlugin.get_module_class_from_name(child_module, name) + if module_class is not None: + return module_class + + def set_auto_wrap_policy(self, model): + from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy + + default_transformer_cls_names_to_wrap = ( + ",".join(model._no_split_modules) if getattr(model, "_no_split_modules", None) is not None else "" + ) + if self.auto_wrap_policy is None: + auto_wrap_policy = os.environ.get("FSDP_AUTO_WRAP_POLICY", "NO_WRAP") + if auto_wrap_policy == FSDP_AUTO_WRAP_POLICY[0]: + transformer_cls_names_to_wrap = os.environ.get( + "FSDP_TRANSFORMER_CLS_TO_WRAP", default_transformer_cls_names_to_wrap + ).split(",") + transformer_cls_to_wrap = set() + for layer_class in transformer_cls_names_to_wrap: + transformer_cls = FullyShardedDataParallelPlugin.get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) + + self.auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + elif auto_wrap_policy == FSDP_AUTO_WRAP_POLICY[1]: + min_num_params = int(os.environ.get("FSDP_MIN_NUM_PARAMS", 0)) + if min_num_params > 0: + self.auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=min_num_params + ) + + def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False): + if isinstance(mixed_precision, str): + if mixed_precision == "fp16": + dtype = torch.float16 + elif mixed_precision == "bf16": + dtype = torch.bfloat16 + elif mixed_precision == "fp32": + dtype = torch.float32 + else: + raise ValueError(f"Unknown mixed precision value: {mixed_precision}") + else: + dtype = mixed_precision + + buffer_dtype = torch.float32 if buffer_autocast else dtype + from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision + + if self.mixed_precision_policy is None or override: + self.mixed_precision_policy = MixedPrecision( + param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=buffer_dtype + ) + + def set_state_dict_type(self, state_dict_type_policy): + from torch.distributed.fsdp.fully_sharded_data_parallel import ( + FullOptimStateDictConfig, + FullStateDictConfig, + StateDictType, + ) + + self.state_dict_type = StateDictType(FSDP_STATE_DICT_TYPE.index(state_dict_type_policy) + 1) + + if self.state_dict_type == StateDictType.FULL_STATE_DICT: + if self.state_dict_config is None: + self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + if self.optim_state_dict_config is None: + self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) + + +@dataclass +class MegatronLMPlugin: + """ + Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. Also to enable selective + activation recomputation and optimized fused kernels. 
+ """ + + tp_degree: int = field(default=None, metadata={"help": "tensor parallelism degree."}) + pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."}) + num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."}) + gradient_clipping: float = field( + default=None, metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"} + ) + sequence_parallelism: bool = field( + default=None, + metadata={"help": "enable sequence parallelism"}, + ) + recompute_activations: bool = field( + default=None, + metadata={"help": "enable selective activation recomputation"}, + ) + use_distributed_optimizer: bool = field( + default=None, + metadata={"help": "enable distributed optimizer"}, + ) + pipeline_model_parallel_split_rank: int = field( + default=None, metadata={"help": "Rank where encoder and decoder should be split."} + ) + num_layers_per_virtual_pipeline_stage: int = field( + default=None, metadata={"help": "Number of layers per virtual pipeline stage."} + ) + is_train_batch_min: str = field( + default=True, + metadata={"help": "If both train & eval dataloaders are specified, this will decide the micro_batch_size"}, + ) + train_iters: int = field( + default=None, + metadata={ + "help": "Total number of iterations to train over all training runs. " + "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`" + }, + ) + train_samples: int = field( + default=None, + metadata={ + "help": "Total number of samples to train over all training runs. " + "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`" + }, + ) + weight_decay_incr_style: str = field( + default="constant", + metadata={"help": 'Weight decay increment function. choices=["constant", "linear", "cosine"]. '}, + ) + start_weight_decay: float = field( + default=None, + metadata={"help": "Initial weight decay coefficient for L2 regularization."}, + ) + end_weight_decay: float = field( + default=None, + metadata={"help": "End of run weight decay coefficient for L2 regularization."}, + ) + lr_decay_style: str = field( + default="linear", + metadata={"help": "Learning rate decay function. choices=['constant', 'linear', 'cosine']."}, + ) + lr_decay_iters: int = field( + default=None, + metadata={"help": "Number of iterations for learning rate decay. If None defaults to `train_iters`."}, + ) + lr_decay_samples: int = field( + default=None, + metadata={"help": "Number of samples for learning rate decay. If None defaults to `train_samples`."}, + ) + lr_warmup_iters: int = field( + default=None, + metadata={"help": "number of iterations to linearly warmup learning rate over."}, + ) + lr_warmup_samples: int = field( + default=None, + metadata={"help": "number of samples to linearly warmup learning rate over."}, + ) + lr_warmup_fraction: float = field( + default=None, + metadata={"help": "fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over."}, + ) + min_lr: float = field( + default=0, + metadata={"help": "Minumum value for learning rate. The scheduler clip values below this threshold."}, + ) + consumed_samples: List[int] = field( + default=None, + metadata={ + "help": "Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call." 
+ }, + ) + no_wd_decay_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to disable weight decay."}) + scale_lr_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to scale learning rate."}) + lr_mult: float = field(default=1.0, metadata={"help": "Learning rate multiplier."}) + megatron_dataset_flag: bool = field( + default=False, + metadata={"help": "Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format."}, + ) + seq_length: int = field( + default=None, + metadata={"help": "Maximum sequence length to process."}, + ) + encoder_seq_length: int = field( + default=None, + metadata={"help": "Maximum sequence length to process for the encoder."}, + ) + decoder_seq_length: int = field( + default=None, + metadata={"help": "Maximum sequence length to process for the decoder."}, + ) + tensorboard_dir: str = field( + default=None, + metadata={"help": "Path to save tensorboard logs."}, + ) + set_all_logging_options: bool = field( + default=False, + metadata={"help": "Whether to set all logging options."}, + ) + eval_iters: int = field( + default=100, metadata={"help": "Number of iterations to run for evaluation validation/test for."} + ) + eval_interval: int = field( + default=1000, metadata={"help": "Interval between running evaluation on validation set."} + ) + return_logits: bool = field( + default=False, + metadata={"help": "Whether to return logits from the model."}, + ) + + # custom train step args + custom_train_step_class: Optional[Any] = field( + default=None, + metadata={"help": "Custom train step class."}, + ) + custom_train_step_kwargs: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "Custom train step kwargs."}, + ) + + # custom model args + custom_model_provider_function: Optional[Callable] = field( + default=None, + metadata={"help": "Custom model provider function."}, + ) + custom_prepare_model_function: Optional[Callable] = field( + default=None, + metadata={"help": "Custom prepare model function."}, + ) + + # remaining args such as enabling Alibi/ROPE positional embeddings, + # wandb logging, Multi-Query Attention, etc. + other_megatron_args: Optional[Dict[str, Any]] = field( + default=None, + metadata={"help": "Other Megatron-LM arguments. 
Please refer Megatron-LM"}, + ) + + def __post_init__(self): + prefix = "MEGATRON_LM_" + if self.tp_degree is None: + self.tp_degree = int(os.environ.get(prefix + "TP_DEGREE", 1)) + if self.pp_degree is None: + self.pp_degree = int(os.environ.get(prefix + "PP_DEGREE", 1)) + if self.num_micro_batches is None: + self.num_micro_batches = int(os.environ.get(prefix + "NUM_MICRO_BATCHES", 1)) + if self.gradient_clipping is None: + self.gradient_clipping = float(os.environ.get(prefix + "GRADIENT_CLIPPING", 1.0)) + if self.recompute_activations is None: + self.recompute_activations = str_to_bool(os.environ.get(prefix + "RECOMPUTE_ACTIVATIONS", "False")) == 1 + if self.use_distributed_optimizer is None: + self.use_distributed_optimizer = ( + str_to_bool(os.environ.get(prefix + "USE_DISTRIBUTED_OPTIMIZER", "False")) == 1 + ) + if self.sequence_parallelism is None: + self.sequence_parallelism = str_to_bool(os.environ.get(prefix + "SEQUENCE_PARALLELISM", "False")) == 1 + + if self.pp_degree > 1 or self.use_distributed_optimizer: + self.DDP_impl = "local" + else: + self.DDP_impl = "torch" + + if self.consumed_samples is not None: + if len(self.consumed_samples) == 1: + self.consumed_samples.extend([0, 0]) + elif len(self.consumed_samples) == 2: + self.consumed_samples.append(0) + + self.megatron_lm_default_args = { + "tensor_model_parallel_size": self.tp_degree, + "pipeline_model_parallel_size": self.pp_degree, + "pipeline_model_parallel_split_rank": self.pipeline_model_parallel_split_rank, + "num_layers_per_virtual_pipeline_stage": self.num_layers_per_virtual_pipeline_stage, + "DDP_impl": self.DDP_impl, + "use_distributed_optimizer": self.use_distributed_optimizer, + "sequence_parallel": self.sequence_parallelism, + "clip_grad": self.gradient_clipping, + "num_micro_batches": self.num_micro_batches, + "consumed_samples": self.consumed_samples, + "no_wd_decay_cond": self.no_wd_decay_cond, + "scale_lr_cond": self.scale_lr_cond, + "lr_mult": self.lr_mult, + "megatron_dataset_flag": self.megatron_dataset_flag, + "eval_iters": self.eval_iters, + "eval_interval": self.eval_interval, + } + if self.recompute_activations: + self.megatron_lm_default_args["recompute_granularity"] = "selective" + if self.tensorboard_dir is not None: + self.megatron_lm_default_args["tensorboard_dir"] = self.tensorboard_dir + if self.set_all_logging_options: + self.set_tensorboard_logging_options() + if self.other_megatron_args is not None: + self.megatron_lm_default_args.update(self.other_megatron_args) + + def set_network_size_args(self, model, batch_data=None): + # Check if the model is either BERT, GPT or T5 else raise error + # set 'num_layers', 'hidden_size', 'num_attention_heads', 'max_position_embeddings' + if "megatron-bert" in model.config.model_type.lower(): + model_type_name = "bert" + num_layers = model.config.num_hidden_layers + hidden_size = model.config.hidden_size + num_attention_heads = model.config.num_attention_heads + max_position_embeddings = model.config.max_position_embeddings + num_labels = model.config.num_labels + orig_vocab_size = model.config.vocab_size + if "maskedlm" in model.__class__.__name__.lower(): + pretraining_flag = True + if self.seq_length is not None: + if self.encoder_seq_length is not None: + warnings.warn("Both `seq_length` and `encoder_seq_length` are set. 
Using `encoder_seq_length`.") + self.seq_length = self.encoder_seq_length + elif self.encoder_seq_length is not None: + self.seq_length = self.encoder_seq_length + elif batch_data is not None: + self.seq_length = batch_data["input_ids"].shape[1] + else: + self.seq_length = max_position_embeddings + self.megatron_lm_default_args["seq_length"] = self.seq_length + elif "gpt2" in model.config.model_type.lower(): + model_type_name = "gpt" + num_layers = model.config.n_layer + hidden_size = model.config.n_embd + num_attention_heads = model.config.n_head + max_position_embeddings = model.config.n_positions + orig_vocab_size = model.config.vocab_size + pretraining_flag = True + if self.seq_length is not None: + if self.decoder_seq_length is not None: + warnings.warn("Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.") + self.seq_length = self.decoder_seq_length + elif self.decoder_seq_length is not None: + self.seq_length = self.decoder_seq_length + elif batch_data is not None: + self.seq_length = batch_data["input_ids"].shape[1] + else: + self.seq_length = max_position_embeddings + self.megatron_lm_default_args["seq_length"] = self.seq_length + self.megatron_lm_default_args["return_logits"] = self.return_logits + self.megatron_lm_default_args["tokenizer_type"] = "GPT2BPETokenizer" + elif "t5" in model.config.model_type.lower(): + model_type_name = "t5" + num_layers = model.config.num_layers + hidden_size = model.config.d_model + num_attention_heads = model.config.num_heads + max_position_embeddings = model.config.n_positions if hasattr(model.config, "n_positions") else 1024 + orig_vocab_size = model.config.vocab_size + pretraining_flag = True + if self.encoder_seq_length is None: + if batch_data is not None: + self.encoder_seq_length = batch_data["input_ids"].shape[1] + else: + self.encoder_seq_length = max_position_embeddings + if self.decoder_seq_length is None: + if batch_data is not None: + self.decoder_seq_length = batch_data["labels"].shape[1] + else: + self.decoder_seq_length = max_position_embeddings + + self.megatron_lm_default_args["encoder_seq_length"] = self.encoder_seq_length + self.megatron_lm_default_args["decoder_seq_length"] = self.decoder_seq_length + else: + raise ValueError( + "🤗 Accelerate Megatron-LM integration supports only BERT, GPT and T5 model. " + "Please check the model you are using is one of those." 
+ ) + + self.megatron_lm_default_args["model_type_name"] = model_type_name + self.megatron_lm_default_args["num_layers"] = num_layers + self.megatron_lm_default_args["hidden_size"] = hidden_size + self.megatron_lm_default_args["num_attention_heads"] = num_attention_heads + self.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings + self.megatron_lm_default_args["pretraining_flag"] = pretraining_flag + self.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size + self.megatron_lm_default_args["model_return_dict"] = model.config.return_dict + if model_type_name == "bert": + self.megatron_lm_default_args["num_labels"] = num_labels + + def set_mixed_precision(self, mixed_precision): + if mixed_precision == "fp16": + self.megatron_lm_default_args["fp16"] = True + elif mixed_precision == "bf16": + self.megatron_lm_default_args["bf16"] = True + self.DDP_impl = "local" + self.megatron_lm_default_args["DDP_impl"] = self.DDP_impl + + def set_training_args(self, micro_batch_size, dp_degree): + self.data_parallel_size = dp_degree + self.micro_batch_size = micro_batch_size + self.global_batch_size = dp_degree * micro_batch_size * self.num_micro_batches + self.megatron_lm_default_args["data_parallel_size"] = self.data_parallel_size + self.megatron_lm_default_args["micro_batch_size"] = self.micro_batch_size + self.megatron_lm_default_args["global_batch_size"] = self.global_batch_size + + def set_optimizer_type(self, optimizer): + optimizer_name = optimizer.__class__.__name__.lower() + if "adam" in optimizer_name: + self.megatron_lm_default_args["optimizer"] = "adam" + self.megatron_lm_default_args["adam_beta1"] = optimizer.defaults["betas"][0] + self.megatron_lm_default_args["adam_beta2"] = optimizer.defaults["betas"][1] + self.megatron_lm_default_args["adam_eps"] = optimizer.defaults["eps"] + elif "sgd" in optimizer_name: + self.megatron_lm_default_args["optimizer"] = "sgd" + self.megatron_lm_default_args["sgd_momentum"] = optimizer.defaults["momentum"] + else: + raise ValueError(f"Optimizer {optimizer_name} is not supported by Megatron-LM") + + self.megatron_lm_default_args["lr"] = optimizer.defaults["lr"] + self.megatron_lm_default_args["weight_decay"] = optimizer.defaults["weight_decay"] + + def set_scheduler_args(self, scheduler): + if self.train_iters is None: + self.train_iters = scheduler.total_num_steps // self.megatron_lm_default_args["data_parallel_size"] + if self.train_samples is not None: + self.train_samples = None + warnings.warn( + "Ignoring `train_samples` as `train_iters` based on scheduler is being used for training." + ) + if self.lr_warmup_iters is None: + self.lr_warmup_iters = scheduler.warmup_num_steps // self.megatron_lm_default_args["data_parallel_size"] + if self.lr_warmup_samples is not None: + warnings.warn( + "Ignoring `lr_warmup_samples` as `lr_warmup_iters` based on scheduler is being used for training." 
+ ) + self.lr_warmup_samples = 0 + + self.megatron_lm_default_args["train_iters"] = self.train_iters + self.megatron_lm_default_args["lr_warmup_iters"] = self.lr_warmup_iters + self.megatron_lm_default_args["train_samples"] = self.train_samples + self.megatron_lm_default_args["lr_warmup_samples"] = self.lr_warmup_samples + self.megatron_lm_default_args["lr_decay_iters"] = self.lr_decay_iters + self.megatron_lm_default_args["lr_decay_samples"] = self.lr_decay_samples + self.megatron_lm_default_args["lr_warmup_fraction"] = self.lr_warmup_fraction + self.megatron_lm_default_args["lr_decay_style"] = self.lr_decay_style + self.megatron_lm_default_args["weight_decay_incr_style"] = self.weight_decay_incr_style + self.megatron_lm_default_args["start_weight_decay"] = self.start_weight_decay + self.megatron_lm_default_args["end_weight_decay"] = self.end_weight_decay + self.megatron_lm_default_args["min_lr"] = self.min_lr + + def set_tensorboard_logging_options(self): + from megatron.arguments import _add_logging_args + + parser = argparse.ArgumentParser() + parser = _add_logging_args(parser) + logging_args = parser.parse_known_args() + self.dataset_args = vars(logging_args[0]) + for key, value in self.dataset_args.items(): + if key.startswith("log_"): + self.megatron_lm_default_args[key] = True + elif key.startswith("no_log_"): + self.megatron_lm_default_args[key.replace("no_", "")] = True + + +@dataclass +class BnbQuantizationConfig: + """ + A plugin to enable BitsAndBytes 4bit and 8bit quantization + """ + + load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."}) + + llm_int8_threshold: float = field( + default=6.0, metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"} + ) + + load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."}) + + bnb_4bit_quant_type: str = field( + default="fp4", + metadata={ + "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}." + }, + ) + + bnb_4bit_use_double_quant: bool = field( + default=False, + metadata={ + "help": "enable nested quantization where the quantization constants from the first quantization are quantized again." + }, + ) + + bnb_4bit_compute_dtype: bool = field( + default="fp16", + metadata={ + "help": "This sets the computational type which might be different than the input time. For example, inputs might be " + "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}." + }, + ) + + torch_dtype: torch.dtype = field( + default=None, + metadata={ + "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value" + "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model " + }, + ) + + skip_modules: List[str] = field( + default=None, + metadata={ + "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`." + }, + ) + + keep_in_fp32_modules: List[str] = field( + default=None, + metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."}, + ) + + def __post_init__(self): + """ + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. 
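+
+        For example (illustrative), `BnbQuantizationConfig(load_in_4bit=True, bnb_4bit_compute_dtype="bf16")` is
+        normalized below to `bnb_4bit_compute_dtype=torch.bfloat16`, `target_dtype=CustomDtype.INT4`, and a
+        `torch_dtype` equal to the compute dtype.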
+ """ + if not isinstance(self.load_in_8bit, bool): + raise ValueError("load_in_8bit must be a boolean") + + if not isinstance(self.load_in_4bit, bool): + raise ValueError("load_in_4bit must be a boolean") + + if self.load_in_4bit and self.load_in_8bit: + raise ValueError("load_in_4bit and load_in_8 can't be both True") + + if not self.load_in_4bit and not self.load_in_8bit: + raise ValueError("load_in_4bit and load_in_8 can't be both False") + + if not isinstance(self.llm_int8_threshold, (int, float)): + raise ValueError("llm_int8_threshold must be a float or an int") + + if not isinstance(self.bnb_4bit_quant_type, str): + raise ValueError("bnb_4bit_quant_type must be a string") + elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]: + raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}") + + if not isinstance(self.bnb_4bit_use_double_quant, bool): + raise ValueError("bnb_4bit_use_double_quant must be a boolean") + + if isinstance(self.bnb_4bit_compute_dtype, str): + if self.bnb_4bit_compute_dtype == "fp32": + self.bnb_4bit_compute_dtype = torch.float32 + elif self.bnb_4bit_compute_dtype == "fp16": + self.bnb_4bit_compute_dtype = torch.float16 + elif self.bnb_4bit_compute_dtype == "bf16": + self.bnb_4bit_compute_dtype = torch.bfloat16 + else: + raise ValueError( + f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}" + ) + elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + if self.skip_modules is not None and not isinstance(self.skip_modules, list): + raise ValueError("skip_modules must be a list of strings") + + if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list): + raise ValueError("keep_in_fp_32_modules must be a list of strings") + + if self.load_in_4bit: + self.target_dtype = CustomDtype.INT4 + + if self.load_in_8bit: + self.target_dtype = torch.int8 + + if self.load_in_4bit and self.llm_int8_threshold != 6.0: + warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit") + + if isinstance(self.torch_dtype, str): + if self.torch_dtype == "fp32": + self.torch_dtype = torch.float32 + elif self.torch_dtype == "fp16": + self.torch_dtype = torch.float16 + elif self.torch_dtype == "bf16": + self.torch_dtype = torch.bfloat16 + else: + raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}") + if self.load_in_8bit and self.torch_dtype is None: + self.torch_dtype = torch.float16 + + if self.load_in_4bit and self.torch_dtype is None: + self.torch_dtype = self.bnb_4bit_compute_dtype + + if not isinstance(self.torch_dtype, torch.dtype): + raise ValueError("torch_dtype must be a torch.dtype") diff --git a/llm/Lib/site-packages/accelerate/utils/deepspeed.py b/llm/Lib/site-packages/accelerate/utils/deepspeed.py new file mode 100644 index 0000000000000000000000000000000000000000..fe5a63fc7314d42f68baae41cf56f9abc94237a0 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/deepspeed.py @@ -0,0 +1,271 @@ +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import os +from copy import deepcopy + +from ..optimizer import AcceleratedOptimizer +from ..scheduler import AcceleratedScheduler + + +class HfDeepSpeedConfig: + """ + This object contains a DeepSpeed configuration dictionary and can be quickly queried for things like zero stage. + + A `weakref` of this object is stored in the module's globals to be able to access the config from areas where + things like the Trainer object is not available (e.g. `from_pretrained` and `_get_resized_embeddings`). Therefore + it's important that this object remains alive while the program is still running. + + [`Trainer`] uses the `HfTrainerDeepSpeedConfig` subclass instead. That subclass has logic to sync the configuration + with values of [`TrainingArguments`] by replacing special placeholder values: `"auto"`. Without this special logic + the DeepSpeed configuration is not modified in any way. + + Args: + config_file_or_dict (`Union[str, Dict]`): path to DeepSpeed config file or dict. + + """ + + def __init__(self, config_file_or_dict): + if isinstance(config_file_or_dict, dict): + # Don't modify user's data should they want to reuse it (e.g. in tests), because once we + # modified it, it will not be accepted here again, since `auto` values would have been overridden + config = deepcopy(config_file_or_dict) + elif os.path.exists(config_file_or_dict): + with open(config_file_or_dict, encoding="utf-8") as f: + config = json.load(f) + else: + try: + config_decoded = base64.urlsafe_b64decode(config_file_or_dict).decode("utf-8") + config = json.loads(config_decoded) + except (UnicodeDecodeError, AttributeError, ValueError): + raise ValueError( + f"Expected a string path to an existing deepspeed config, or a dictionary, or a base64 encoded string. Received: {config_file_or_dict}" + ) + + self.config = config + + self.set_stage_and_offload() + + def set_stage_and_offload(self): + # zero stage - this is done as early as possible, before model is created, to allow + # ``is_deepspeed_zero3_enabled`` query and getting to the early deepspeed config object + # during ``zero.Init()`` which needs to know the dtype, and some other hparams. 
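+        # For example (illustrative), a config containing
+        # {"zero_optimization": {"stage": 3, "offload_param": {"device": "cpu"}}} resolves to `_stage == 3` and
+        # `_offload == True` in the checks below.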
+ self._stage = self.get_value("zero_optimization.stage", -1) + + # offload + self._offload = False + if self.is_zero2() or self.is_zero3(): + offload_devices_valid = set(["cpu", "nvme"]) + offload_devices = set( + [ + self.get_value("zero_optimization.offload_optimizer.device"), + self.get_value("zero_optimization.offload_param.device"), + ] + ) + if len(offload_devices & offload_devices_valid) > 0: + self._offload = True + + def find_config_node(self, ds_key_long): + config = self.config + + # find the config node of interest if it exists + nodes = ds_key_long.split(".") + ds_key = nodes.pop() + for node in nodes: + config = config.get(node) + if config is None: + return None, ds_key + + return config, ds_key + + def get_value(self, ds_key_long, default=None): + """ + Returns the set value or `default` if no value is set + """ + config, ds_key = self.find_config_node(ds_key_long) + if config is None: + return default + return config.get(ds_key, default) + + def del_config_sub_tree(self, ds_key_long, must_exist=False): + """ + Deletes a sub-section of the config file if it's found. + + Unless `must_exist` is `True` the section doesn't have to exist. + """ + config = self.config + + # find the config node of interest if it exists + nodes = ds_key_long.split(".") + for node in nodes: + parent_config = config + config = config.get(node) + if config is None: + if must_exist: + raise ValueError(f"Can't find {ds_key_long} entry in the config: {self.config}") + else: + return + + # if found remove it + if parent_config is not None: + parent_config.pop(node) + + def is_true(self, ds_key_long): + """ + Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very + specific question of whether the value is set to `True` (and it's not set to `False`` or isn't set). + + """ + value = self.get_value(ds_key_long) + return False if value is None else bool(value) + + def is_false(self, ds_key_long): + """ + Returns `True`/``False` only if the value is set, always `False` otherwise. So use this method to ask the very + specific question of whether the value is set to `False` (and it's not set to `True`` or isn't set). + """ + value = self.get_value(ds_key_long) + return False if value is None else not bool(value) + + def is_zero2(self): + return self._stage == 2 + + def is_zero3(self): + return self._stage == 3 + + def is_offload(self): + return self._offload + + +class DeepSpeedEngineWrapper: + """ + Internal wrapper for deepspeed.runtime.engine.DeepSpeedEngine. This is used to follow conventional training loop. + + Args: + engine (deepspeed.runtime.engine.DeepSpeedEngine): deepspeed engine to wrap + """ + + def __init__(self, engine): + self.engine = engine + + def backward(self, loss, **kwargs): + # runs backpropagation and handles mixed precision + self.engine.backward(loss, **kwargs) + + # Deepspeed's `engine.step` performs the following operations: + # - gradient accumulation check + # - gradient clipping + # - optimizer step + # - zero grad + # - checking overflow + # - lr_scheduler step (only if engine.lr_scheduler is not None) + self.engine.step() + # and this plugin overrides the above calls with no-ops when Accelerate runs under + # Deepspeed, but allows normal functionality for non-Deepspeed cases thus enabling a simple + # training loop that works transparently under many training regimes. + + +class DeepSpeedOptimizerWrapper(AcceleratedOptimizer): + """ + Internal wrapper around a deepspeed optimizer. 
+ + Args: + optimizer (`torch.optim.optimizer.Optimizer`): + The optimizer to wrap. + """ + + def __init__(self, optimizer): + super().__init__(optimizer, device_placement=False, scaler=None) + self.__has_overflow__ = hasattr(self.optimizer, "overflow") + + def zero_grad(self, set_to_none=None): + pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed + + def step(self): + pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed + + @property + def step_was_skipped(self): + """Whether or not the optimizer step was done, or skipped because of gradient overflow.""" + if self.__has_overflow__: + return self.optimizer.overflow + return False + + +class DeepSpeedSchedulerWrapper(AcceleratedScheduler): + """ + Internal wrapper around a deepspeed scheduler. + + Args: + scheduler (`torch.optim.lr_scheduler.LambdaLR`): + The scheduler to wrap. + optimizers (one or a list of `torch.optim.Optimizer`): + """ + + def __init__(self, scheduler, optimizers): + super().__init__(scheduler, optimizers) + + def step(self): + pass # `accelerator.backward(loss)` is doing that automatically. Therefore, its implementation is not needed + + +class DummyOptim: + """ + Dummy optimizer presents model parameters or param groups, this is primarily used to follow conventional training + loop when optimizer config is specified in the deepspeed config file. + + Args: + lr (float): + Learning rate. + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + weight_decay (float): + Weight decay. + **kwargs (additional keyword arguments, *optional*): + Other arguments. + """ + + def __init__(self, params, lr=0.001, weight_decay=0, **kwargs): + self.params = params + self.lr = lr + self.weight_decay = weight_decay + self.kwargs = kwargs + + +class DummyScheduler: + """ + Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training + loop when scheduler config is specified in the deepspeed config file. + + Args: + optimizer (`torch.optim.optimizer.Optimizer`): + The optimizer to wrap. + total_num_steps (int, *optional*): + Total number of steps. + warmup_num_steps (int, *optional*): + Number of steps for warmup. + lr_scheduler_callable (callable, *optional*): + A callable function that creates an LR Scheduler. It accepts only one argument `optimizer`. + **kwargs (additional keyword arguments, *optional*): + Other arguments. + """ + + def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, lr_scheduler_callable=None, **kwargs): + self.optimizer = optimizer + self.total_num_steps = total_num_steps + self.warmup_num_steps = warmup_num_steps + self.lr_scheduler_callable = lr_scheduler_callable + self.kwargs = kwargs diff --git a/llm/Lib/site-packages/accelerate/utils/environment.py b/llm/Lib/site-packages/accelerate/utils/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdbd323632769146188cac1e91d08ab1e2ba617 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/environment.py @@ -0,0 +1,274 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import math
+import os
+import platform
+import subprocess
+import sys
+from dataclasses import dataclass, field
+from functools import lru_cache
+from shutil import which
+from typing import List, Optional
+
+import torch
+from packaging.version import parse
+
+
+logger = logging.getLogger(__name__)
+
+
+def convert_dict_to_env_variables(current_env: dict):
+    """
+    Verifies that all keys and values in `current_env` do not contain illegal keys or values, and returns a list of
+    strings as the result.
+
+    Example:
+    ```python
+    >>> from accelerate.utils.environment import convert_dict_to_env_variables
+
+    >>> env = {"ACCELERATE_DEBUG_MODE": "1", "BAD_ENV_NAME": "<mything", "OTHER_ENV": "2"}
+    >>> valid_env_items = convert_dict_to_env_variables(env)
+    >>> print(valid_env_items)
+    ["ACCELERATE_DEBUG_MODE=1\n", "OTHER_ENV=2\n"]
+    ```
+    """
+    forbidden_chars = [";", "\n", "<", ">", " "]
+    valid_env_items = []
+    for key, value in current_env.items():
+        if all(char not in (key + value) for char in forbidden_chars) and len(key) >= 1 and len(value) >= 1:
+            valid_env_items.append(f"{key}={value}\n")
+        else:
+            logger.warning(f"WARNING: Skipping {key}={value} as it contains forbidden characters or missing values.")
+    return valid_env_items
+
+
+def str_to_bool(value) -> int:
+    """
+    Converts a string representation of truth to `True` (1) or `False` (0).
+
+    True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False values are `n`, `no`, `f`, `false`, `off`, and `0`;
+    """
+    value = value.lower()
+    if value in ("y", "yes", "t", "true", "on", "1"):
+        return 1
+    elif value in ("n", "no", "f", "false", "off", "0"):
+        return 0
+    else:
+        raise ValueError(f"invalid truth value {value}")
+
+
+def get_int_from_env(env_keys, default):
+    """Returns the first positive env value found in the `env_keys` list or the default."""
+    for e in env_keys:
+        val = int(os.environ.get(e, -1))
+        if val >= 0:
+            return val
+    return default
+
+
+def parse_flag_from_env(key, default=False):
+    """Returns truthy value for `key` from the env if available else the default."""
+    value = os.environ.get(key, str(default))
+    return str_to_bool(value) == 1  # As its name indicates `str_to_bool` actually returns an int...
+
+
+def parse_choice_from_env(key, default="no"):
+    value = os.environ.get(key, str(default))
+    return value
+
+
+def are_libraries_initialized(*library_names: str) -> List[str]:
+    """
+    Checks if any of `library_names` are imported in the environment. Will return any names that are.
+    """
+    return [lib_name for lib_name in library_names if lib_name in sys.modules.keys()]
+
+
+def _nvidia_smi():
+    """
+    Returns the right nvidia-smi command based on the system.
+    """
+    if platform.system() == "Windows":
+        # If platform is Windows and nvidia-smi can't be found in path
+        # try from systemd drive with default installation path
+        command = which("nvidia-smi")
+        if command is None:
+            command = "%s\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe" % os.environ["systemdrive"]
+    else:
+        command = "nvidia-smi"
+    return command
+
+
+def get_gpu_info():
+    """
+    Gets GPU count and names using `nvidia-smi` instead of torch to not initialize CUDA.
+ + Largely based on the `gputil` library. + """ + # Returns as list of `n` GPUs and their names + output = subprocess.check_output( + [_nvidia_smi(), "--query-gpu=count,name", "--format=csv,noheader"], universal_newlines=True + ) + output = output.strip() + gpus = output.split(os.linesep) + # Get names from output + gpu_count = len(gpus) + gpu_names = [gpu.split(",")[1].strip() for gpu in gpus] + return gpu_names, gpu_count + + +def get_driver_version(): + """ + Returns the driver version + + In the case of multiple GPUs, will return the first. + """ + output = subprocess.check_output( + [_nvidia_smi(), "--query-gpu=driver_version", "--format=csv,noheader"], universal_newlines=True + ) + output = output.strip() + return output.split(os.linesep)[0] + + +def check_cuda_p2p_ib_support(): + """ + Checks if the devices being used have issues with P2P and IB communications, namely any consumer GPU hardware after + the 3090. + + Noteably uses `nvidia-smi` instead of torch to not initialize CUDA. + """ + try: + device_names, device_count = get_gpu_info() + # As new consumer GPUs get released, add them to `unsupported_devices`` + unsupported_devices = {"RTX 40"} + if device_count > 1: + if any( + unsupported_device in device_name + for device_name in device_names + for unsupported_device in unsupported_devices + ): + # Check if they have the right driver version + acceptable_driver_version = "550.40.07" + current_driver_version = get_driver_version() + if parse(current_driver_version) < parse(acceptable_driver_version): + return False + return True + except Exception: + pass + return True + + +def check_fp8_capability(): + """ + Checks if all the current GPUs available support FP8. + + Notably must initialize `torch.cuda` to check. + """ + cuda_device_capacity = torch.cuda.get_device_capability() + return cuda_device_capacity >= (8, 9) + + +@dataclass +class CPUInformation: + """ + Stores information about the CPU in a distributed environment. It contains the following attributes: + - rank: The rank of the current process. + - world_size: The total number of processes in the world. + - local_rank: The rank of the current process on the local node. + - local_world_size: The total number of processes on the local node. + """ + + rank: int = field(default=0, metadata={"help": "The rank of the current process."}) + world_size: int = field(default=1, metadata={"help": "The total number of processes in the world."}) + local_rank: int = field(default=0, metadata={"help": "The rank of the current process on the local node."}) + local_world_size: int = field(default=1, metadata={"help": "The total number of processes on the local node."}) + + +def get_cpu_distributed_information() -> CPUInformation: + """ + Returns various information about the environment in relation to CPU distributed training as a `CPUInformation` + dataclass. 
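+
+    Example (illustrative; the values depend on launcher-provided variables such as `RANK` and `WORLD_SIZE`):
+
+    ```python
+    info = get_cpu_distributed_information()
+    print(info.rank, info.world_size, info.local_rank, info.local_world_size)
+    ```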
+ """ + information = {} + information["rank"] = get_int_from_env(["RANK", "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK"], 0) + information["world_size"] = get_int_from_env( + ["WORLD_SIZE", "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE"], 1 + ) + information["local_rank"] = get_int_from_env( + ["LOCAL_RANK", "MPI_LOCALRANKID", "OMPI_COMM_WORLD_LOCAL_RANK", "MV2_COMM_WORLD_LOCAL_RANK"], 0 + ) + information["local_world_size"] = get_int_from_env( + ["LOCAL_WORLD_SIZE", "MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"], + 1, + ) + return CPUInformation(**information) + + +def override_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None: + """ + Overrides whatever NUMA affinity is set for the current process. This is very taxing and requires recalculating the + affinity to set, ideally you should use `utils.environment.set_numa_affinity` instead. + + Args: + local_process_index (int): + The index of the current process on the current server. + verbose (bool, *optional*): + Whether to log out the assignment of each CPU. If `ACCELERATE_DEBUG_MODE` is enabled, will default to True. + """ + if verbose is None: + verbose = parse_flag_from_env("ACCELERATE_DEBUG_MODE", False) + if torch.cuda.is_available(): + from accelerate.utils import is_pynvml_available + + if not is_pynvml_available(): + raise ImportError( + "To set CPU affinity on CUDA GPUs the `pynvml` package must be available. (`pip install pynvml`)" + ) + import pynvml as nvml + + # The below code is based on https://github.com/NVIDIA/DeepLearningExamples/blob/master/TensorFlow2/LanguageModeling/BERT/gpu_affinity.py + nvml.nvmlInit() + num_elements = math.ceil(os.cpu_count() / 64) + handle = nvml.nvmlDeviceGetHandleByIndex(local_process_index) + affinity_string = "" + for j in nvml.nvmlDeviceGetCpuAffinity(handle, num_elements): + # assume nvml returns list of 64 bit ints + affinity_string = f"{j:064b}{affinity_string}" + affinity_list = [int(x) for x in affinity_string] + affinity_list.reverse() # so core 0 is the 0th element + affinity_to_set = [i for i, e in enumerate(affinity_list) if e != 0] + os.sched_setaffinity(0, affinity_to_set) + if verbose: + cpu_cores = os.sched_getaffinity(0) + logger.info(f"Assigning {len(cpu_cores)} cpu cores to process {local_process_index}: {cpu_cores}") + + +@lru_cache +def set_numa_affinity(local_process_index: int, verbose: Optional[bool] = None) -> None: + """ + Assigns the current process to a specific NUMA node. Ideally most efficient when having at least 2 cpus per node. + + This result is cached between calls. If you want to override it, please use + `accelerate.utils.environment.override_numa_afifnity`. + + Args: + local_process_index (int): + The index of the current process on the current server. + verbose (bool, *optional*): + Whether to print the new cpu cores assignment for each process. If `ACCELERATE_DEBUG_MODE` is enabled, will + default to True. + """ + override_numa_affinity(local_process_index=local_process_index, verbose=verbose) diff --git a/llm/Lib/site-packages/accelerate/utils/fsdp_utils.py b/llm/Lib/site-packages/accelerate/utils/fsdp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..01bb54b262b7f00b4bfb0933fc5fe94b24146097 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/fsdp_utils.py @@ -0,0 +1,209 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os + +import torch + +from ..logging import get_logger +from .constants import FSDP_MODEL_NAME, FSDP_PYTORCH_VERSION, OPTIMIZER_NAME +from .imports import is_torch_distributed_available +from .modeling import is_peft_model +from .versions import is_torch_version + + +if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available(): + import torch.distributed.checkpoint as dist_cp + from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner, DefaultSavePlanner + from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + + +logger = get_logger(__name__) + + +def _get_model_state_dict(model, adapter_only=False): + if adapter_only and is_peft_model(model): + from peft import get_peft_model_state_dict + + return get_peft_model_state_dict(model, adapter_name=model.active_adapter) + else: + return model.state_dict() + + +def _set_model_state_dict(model, state_dict, adapter_only=False): + if adapter_only and is_peft_model(model): + from peft import set_peft_model_state_dict + + return set_peft_model_state_dict(model, state_dict, adapter_name=model.active_adapter) + else: + return model.load_state_dict(state_dict) + + +def save_fsdp_model(fsdp_plugin, accelerator, model, output_dir, model_index=0, adapter_only=False): + os.makedirs(output_dir, exist_ok=True) + + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT + # so, only enable it when num_processes>1 + is_multi_process = accelerator.num_processes > 1 + fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process + fsdp_plugin.state_dict_config.rank0_only = is_multi_process + + with FSDP.state_dict_type( + model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config + ): + state_dict = _get_model_state_dict(model, adapter_only=adapter_only) + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin" + output_model_file = os.path.join(output_dir, weights_name) + if accelerator.process_index == 0: + logger.info(f"Saving model to {output_model_file}") + torch.save(state_dict, output_model_file) + logger.info(f"Model saved to {output_model_file}") + elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT: + weights_name = ( + f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin" + if model_index == 0 + else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin" + ) + output_model_file = os.path.join(output_dir, weights_name) + logger.info(f"Saving model to {output_model_file}") + torch.save(state_dict, output_model_file) + logger.info(f"Model saved to 
{output_model_file}") + elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT: + ckpt_dir = os.path.join(output_dir, f"{FSDP_MODEL_NAME}_{model_index}") + os.makedirs(ckpt_dir, exist_ok=True) + logger.info(f"Saving model to {ckpt_dir}") + state_dict = {"model": state_dict} + + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=dist_cp.FileSystemWriter(ckpt_dir), + planner=DefaultSavePlanner(), + ) + logger.info(f"Model saved to {ckpt_dir}") + + +def load_fsdp_model(fsdp_plugin, accelerator, model, input_dir, model_index=0, adapter_only=False): + accelerator.wait_for_everyone() + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + # FSDP raises error when single GPU is used with `offload_to_cpu=True` for FULL_STATE_DICT + # so, only enable it when num_processes>1 + is_multi_process = accelerator.num_processes > 1 + fsdp_plugin.state_dict_config.offload_to_cpu = is_multi_process + fsdp_plugin.state_dict_config.rank0_only = is_multi_process + with FSDP.state_dict_type( + model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config + ): + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + if type(model) != FSDP and accelerator.process_index != 0: + if not fsdp_plugin.sync_module_states: + raise ValueError( + "Set the `sync_module_states` flag to `True` so that model states are synced across processes when " + "initializing FSDP object" + ) + return + weights_name = f"{FSDP_MODEL_NAME}.bin" if model_index == 0 else f"{FSDP_MODEL_NAME}_{model_index}.bin" + input_model_file = os.path.join(input_dir, weights_name) + logger.info(f"Loading model from {input_model_file}") + state_dict = torch.load(input_model_file) + logger.info(f"Model loaded from {input_model_file}") + elif fsdp_plugin.state_dict_type == StateDictType.LOCAL_STATE_DICT: + weights_name = ( + f"{FSDP_MODEL_NAME}_rank{accelerator.process_index}.bin" + if model_index == 0 + else f"{FSDP_MODEL_NAME}_{model_index}_rank{accelerator.process_index}.bin" + ) + input_model_file = os.path.join(input_dir, weights_name) + logger.info(f"Loading model from {input_model_file}") + state_dict = torch.load(input_model_file) + logger.info(f"Model loaded from {input_model_file}") + elif fsdp_plugin.state_dict_type == StateDictType.SHARDED_STATE_DICT: + ckpt_dir = ( + os.path.join(input_dir, f"{FSDP_MODEL_NAME}_{model_index}") + if f"{FSDP_MODEL_NAME}" not in input_dir + else input_dir + ) + logger.info(f"Loading model from {ckpt_dir}") + state_dict = {"model": _get_model_state_dict(model, adapter_only=adapter_only)} + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader=dist_cp.FileSystemReader(ckpt_dir), + planner=DefaultLoadPlanner(), + ) + state_dict = state_dict["model"] + logger.info(f"Model loaded from {ckpt_dir}") + load_result = _set_model_state_dict(model, state_dict, adapter_only=adapter_only) + return load_result + + +def save_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, output_dir, optimizer_index=0): + os.makedirs(output_dir, exist_ok=True) + with FSDP.state_dict_type( + model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config + ): + optim_state = FSDP.optim_state_dict(model, optimizer) + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + if accelerator.process_index == 0: + optim_state_name = ( + f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin" + ) + output_optimizer_file = os.path.join(output_dir, 
optim_state_name) + logger.info(f"Saving Optimizer state to {output_optimizer_file}") + torch.save(optim_state, output_optimizer_file) + logger.info(f"Optimizer state saved in {output_optimizer_file}") + else: + ckpt_dir = os.path.join(output_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") + os.makedirs(ckpt_dir, exist_ok=True) + logger.info(f"Saving Optimizer state to {ckpt_dir}") + dist_cp.save_state_dict( + state_dict={"optimizer": optim_state}, + storage_writer=dist_cp.FileSystemWriter(ckpt_dir), + planner=DefaultSavePlanner(), + ) + logger.info(f"Optimizer state saved in {ckpt_dir}") + + +def load_fsdp_optimizer(fsdp_plugin, accelerator, optimizer, model, input_dir, optimizer_index=0, adapter_only=False): + accelerator.wait_for_everyone() + with FSDP.state_dict_type( + model, fsdp_plugin.state_dict_type, fsdp_plugin.state_dict_config, fsdp_plugin.optim_state_dict_config + ): + if fsdp_plugin.state_dict_type == StateDictType.FULL_STATE_DICT: + optim_state = None + if accelerator.process_index == 0 or not fsdp_plugin.optim_state_dict_config.rank0_only: + optimizer_name = ( + f"{OPTIMIZER_NAME}.bin" if optimizer_index == 0 else f"{OPTIMIZER_NAME}_{optimizer_index}.bin" + ) + input_optimizer_file = os.path.join(input_dir, optimizer_name) + logger.info(f"Loading Optimizer state from {input_optimizer_file}") + optim_state = torch.load(input_optimizer_file) + logger.info(f"Optimizer state loaded from {input_optimizer_file}") + else: + ckpt_dir = ( + os.path.join(input_dir, f"{OPTIMIZER_NAME}_{optimizer_index}") + if f"{OPTIMIZER_NAME}" not in input_dir + else input_dir + ) + logger.info(f"Loading Optimizer from {ckpt_dir}") + optim_state = load_sharded_optimizer_state_dict( + model_state_dict=_get_model_state_dict(model, adapter_only=adapter_only), + optimizer_key="optimizer", + storage_reader=dist_cp.FileSystemReader(ckpt_dir), + ) + optim_state = optim_state["optimizer"] + logger.info(f"Optimizer loaded from {ckpt_dir}") + flattened_osd = FSDP.optim_state_dict_to_load(model=model, optim=optimizer, optim_state_dict=optim_state) + optimizer.load_state_dict(flattened_osd) diff --git a/llm/Lib/site-packages/accelerate/utils/imports.py b/llm/Lib/site-packages/accelerate/utils/imports.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef57c05d46cd2db4a21854f477c47048f71c2e0 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/imports.py @@ -0,0 +1,385 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import importlib.metadata +import os +import warnings +from functools import lru_cache + +import torch +from packaging import version +from packaging.version import parse + +from .environment import parse_flag_from_env, str_to_bool +from .versions import compare_versions, is_torch_version + + +# Try to run Torch native job in an environment with TorchXLA installed by setting this value to 0. 
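+# For example (shell usage; the script name is illustrative): `USE_TORCH_XLA=0 python train.py` keeps the run on
+# native PyTorch even when `torch_xla` is importable.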
+USE_TORCH_XLA = parse_flag_from_env("USE_TORCH_XLA", default=True) + +_torch_xla_available = False +if USE_TORCH_XLA: + try: + import torch_xla.core.xla_model as xm # noqa: F401 + import torch_xla.runtime + + _torch_xla_available = True + except ImportError: + pass + +# Keep it for is_tpu_available. It will be removed along with is_tpu_available. +_tpu_available = _torch_xla_available + +# Cache this result has it's a C FFI call which can be pretty time-consuming +_torch_distributed_available = torch.distributed.is_available() + + +def _is_package_available(pkg_name, metadata_name=None): + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + if package_exists: + try: + # Some libraries have different names in the metadata + _ = importlib.metadata.metadata(pkg_name if metadata_name is None else metadata_name) + return True + except importlib.metadata.PackageNotFoundError: + return False + + +def is_torch_distributed_available() -> bool: + return _torch_distributed_available + + +def is_ccl_available(): + try: + pass + except ImportError: + print( + "Intel(R) oneCCL Bindings for PyTorch* is required to run DDP on Intel(R) GPUs, but it is not" + " detected. If you see \"ValueError: Invalid backend: 'ccl'\" error, please install Intel(R) oneCCL" + " Bindings for PyTorch*." + ) + return ( + importlib.util.find_spec("torch_ccl") is not None + or importlib.util.find_spec("oneccl_bindings_for_pytorch") is not None + ) + + +def get_ccl_version(): + return importlib.metadata.version("oneccl_bind_pt") + + +def is_pynvml_available(): + return _is_package_available("pynvml") + + +def is_msamp_available(): + return _is_package_available("msamp", "ms-amp") + + +def is_transformer_engine_available(): + return _is_package_available("transformer_engine") + + +def is_fp8_available(): + return is_msamp_available() or is_transformer_engine_available() + + +def is_cuda_available(): + """ + Checks if `cuda` is available via an `nvml-based` check which won't trigger the drivers and leave cuda + uninitialized. + """ + pytorch_nvml_based_cuda_check_previous_value = os.environ.get("PYTORCH_NVML_BASED_CUDA_CHECK") + try: + os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = str(1) + available = torch.cuda.is_available() + finally: + if pytorch_nvml_based_cuda_check_previous_value: + os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = pytorch_nvml_based_cuda_check_previous_value + else: + os.environ.pop("PYTORCH_NVML_BASED_CUDA_CHECK", None) + + return available + + +@lru_cache +def is_tpu_available(check_device=True): + "Checks if `torch_xla` is installed and potentially if a TPU is in the environment" + warnings.warn( + "`is_tpu_available` is deprecated and will be removed in v0.27.0. " + "Please use the `is_torch_xla_available` instead.", + FutureWarning, + ) + # Due to bugs on the amp series GPUs, we disable torch-xla on them + if is_cuda_available(): + return False + if check_device: + if _tpu_available: + try: + # Will raise a RuntimeError if no XLA configuration is found + _ = xm.xla_device() + return True + except RuntimeError: + return False + return _tpu_available + + +@lru_cache +def is_torch_xla_available(check_is_tpu=False, check_is_gpu=False): + """ + Check if `torch_xla` is available. To train a native pytorch job in an environment with torch xla installed, set + the USE_TORCH_XLA to false. 
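+
+    Example (illustrative):
+
+    ```python
+    if is_torch_xla_available(check_is_tpu=True):
+        ...  # TPU-specific setup
+    ```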
+ """ + assert not (check_is_tpu and check_is_gpu), "The check_is_tpu and check_is_gpu cannot both be true." + + if not _torch_xla_available: + return False + elif check_is_gpu: + return torch_xla.runtime.device_type() in ["GPU", "CUDA"] + elif check_is_tpu: + return torch_xla.runtime.device_type() == "TPU" + + return True + + +def is_deepspeed_available(): + if is_mlu_available(): + return _is_package_available("deepspeed", metadata_name="deepspeed-mlu") + return _is_package_available("deepspeed") + + +def is_pippy_available(): + package_exists = _is_package_available("pippy", "torchpippy") + if package_exists: + pippy_version = version.parse(importlib.metadata.version("torchpippy")) + return compare_versions(pippy_version, ">", "0.1.1") + return False + + +def is_bf16_available(ignore_tpu=False): + "Checks if bf16 is supported, optionally ignoring the TPU" + if is_torch_xla_available(check_is_tpu=True): + return not ignore_tpu + if is_cuda_available(): + return torch.cuda.is_bf16_supported() + return True + + +def is_4bit_bnb_available(): + package_exists = _is_package_available("bitsandbytes") + if package_exists: + bnb_version = version.parse(importlib.metadata.version("bitsandbytes")) + return compare_versions(bnb_version, ">=", "0.39.0") + return False + + +def is_8bit_bnb_available(): + package_exists = _is_package_available("bitsandbytes") + if package_exists: + bnb_version = version.parse(importlib.metadata.version("bitsandbytes")) + return compare_versions(bnb_version, ">=", "0.37.2") + return False + + +def is_bnb_available(): + return _is_package_available("bitsandbytes") + + +def is_megatron_lm_available(): + if str_to_bool(os.environ.get("ACCELERATE_USE_MEGATRON_LM", "False")) == 1: + package_exists = importlib.util.find_spec("megatron") is not None + if package_exists: + try: + megatron_version = parse(importlib.metadata.version("megatron-lm")) + return compare_versions(megatron_version, ">=", "2.2.0") + except Exception as e: + warnings.warn(f"Parse Megatron version failed. Exception:{e}") + return False + + +def is_transformers_available(): + return _is_package_available("transformers") + + +def is_datasets_available(): + return _is_package_available("datasets") + + +def is_peft_available(): + return _is_package_available("peft") + + +def is_timm_available(): + return _is_package_available("timm") + + +def is_aim_available(): + package_exists = _is_package_available("aim") + if package_exists: + aim_version = version.parse(importlib.metadata.version("aim")) + return compare_versions(aim_version, "<", "4.0.0") + return False + + +def is_tensorboard_available(): + return _is_package_available("tensorboard") or _is_package_available("tensorboardX") + + +def is_wandb_available(): + return _is_package_available("wandb") + + +def is_comet_ml_available(): + return _is_package_available("comet_ml") + + +def is_boto3_available(): + return _is_package_available("boto3") + + +def is_rich_available(): + if _is_package_available("rich"): + if "ACCELERATE_DISABLE_RICH" in os.environ: + warnings.warn( + "`ACCELERATE_DISABLE_RICH` is deprecated and will be removed in v0.22.0 and deactivated by default. Please use `ACCELERATE_ENABLE_RICH` if you wish to use `rich`." 
+ ) + return not parse_flag_from_env("ACCELERATE_DISABLE_RICH", False) + return parse_flag_from_env("ACCELERATE_ENABLE_RICH", False) + return False + + +def is_sagemaker_available(): + return _is_package_available("sagemaker") + + +def is_tqdm_available(): + return _is_package_available("tqdm") + + +def is_clearml_available(): + return _is_package_available("clearml") + + +def is_pandas_available(): + return _is_package_available("pandas") + + +def is_mlflow_available(): + if _is_package_available("mlflow"): + return True + + if importlib.util.find_spec("mlflow") is not None: + try: + _ = importlib.metadata.metadata("mlflow-skinny") + return True + except importlib.metadata.PackageNotFoundError: + return False + return False + + +def is_mps_available(): + return is_torch_version(">=", "1.12") and torch.backends.mps.is_available() and torch.backends.mps.is_built() + + +def is_ipex_available(): + def get_major_and_minor_from_version(full_version): + return str(version.parse(full_version).major) + "." + str(version.parse(full_version).minor) + + _torch_version = importlib.metadata.version("torch") + if importlib.util.find_spec("intel_extension_for_pytorch") is None: + return False + _ipex_version = "N/A" + try: + _ipex_version = importlib.metadata.version("intel_extension_for_pytorch") + except importlib.metadata.PackageNotFoundError: + return False + torch_major_and_minor = get_major_and_minor_from_version(_torch_version) + ipex_major_and_minor = get_major_and_minor_from_version(_ipex_version) + if torch_major_and_minor != ipex_major_and_minor: + warnings.warn( + f"Intel Extension for PyTorch {ipex_major_and_minor} needs to work with PyTorch {ipex_major_and_minor}.*," + f" but PyTorch {_torch_version} is found. Please switch to the matching version and run again." 
+ ) + return False + return True + + +@lru_cache +def is_mlu_available(check_device=False): + "Checks if `torch_mlu` is installed and potentially if a MLU is in the environment" + if importlib.util.find_spec("torch_mlu") is None: + return False + + import torch + import torch_mlu # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no MLU is found + _ = torch.mlu.device_count() + return torch.mlu.is_available() + except RuntimeError: + return False + return hasattr(torch, "mlu") and torch.mlu.is_available() + + +@lru_cache +def is_npu_available(check_device=False): + "Checks if `torch_npu` is installed and potentially if a NPU is in the environment" + if importlib.util.find_spec("torch") is None or importlib.util.find_spec("torch_npu") is None: + return False + + import torch + import torch_npu # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no NPU is found + _ = torch.npu.device_count() + return torch.npu.is_available() + except RuntimeError: + return False + return hasattr(torch, "npu") and torch.npu.is_available() + + +@lru_cache +def is_xpu_available(check_device=False): + "check if user disables it explicitly" + if not parse_flag_from_env("ACCELERATE_USE_XPU", default=True): + return False + "Checks if `intel_extension_for_pytorch` is installed and potentially if a XPU is in the environment" + if is_ipex_available(): + import torch + + if is_torch_version("<=", "1.12"): + return False + else: + return False + + import intel_extension_for_pytorch # noqa: F401 + + if check_device: + try: + # Will raise a RuntimeError if no XPU is found + _ = torch.xpu.device_count() + return torch.xpu.is_available() + except RuntimeError: + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +def is_dvclive_available(): + return _is_package_available("dvclive") diff --git a/llm/Lib/site-packages/accelerate/utils/launch.py b/llm/Lib/site-packages/accelerate/utils/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..dc074270bf8f8f66f9a963c6990de0dd05e766ce --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/launch.py @@ -0,0 +1,624 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
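+# Note on the availability helpers imported below from ..utils: is_xpu_available /
+# is_mlu_available / is_npu_available decide which device-visibility variable receives
+# the `--gpu_ids` selection (ZE_AFFINITY_MASK, MLU_VISIBLE_DEVICES,
+# ASCEND_RT_VISIBLE_DEVICES, or CUDA_VISIBLE_DEVICES as the CUDA fallback).
+# Illustrative example, assuming a plain CUDA machine:
+#   `accelerate launch --gpu_ids 0,1 train.py` ends up exporting
+#   CUDA_VISIBLE_DEVICES=0,1 into the environment of the launched processes.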
+ +import argparse +import os +import subprocess +import sys +import warnings +from ast import literal_eval +from shutil import which +from typing import Any, Dict, List, Tuple + +import torch + +from ..commands.config.config_args import SageMakerConfig +from ..utils import ( + DynamoBackend, + PrecisionType, + is_ipex_available, + is_mlu_available, + is_npu_available, + is_torch_xla_available, + is_xpu_available, +) +from ..utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS +from ..utils.other import is_port_in_use, merge_dicts +from .dataclasses import DistributedType, SageMakerDistributedType + + +def _filter_args(args, parser, default_args=[]): + """ + Filters out all `accelerate` specific args + """ + new_args, _ = parser.parse_known_args(default_args) + for key, value in vars(args).items(): + if key in vars(new_args).keys(): + setattr(new_args, key, value) + return new_args + + +def _get_mpirun_args(): + """ + Determines the executable and argument names for mpirun, based on the type of install. The supported MPI programs + are: OpenMPI, Intel MPI, or MVAPICH. + + Returns: Program name and arg names for hostfile, num processes, and processes per node + """ + # Find the MPI program name + mpi_apps = [x for x in ["mpirun", "mpiexec"] if which(x)] + + if len(mpi_apps) == 0: + raise OSError("mpirun or mpiexec were not found. Ensure that Intel MPI, Open MPI, or MVAPICH are installed.") + + # Call the app with the --version flag to determine which MPI app is installed + mpi_app = mpi_apps[0] + mpirun_version = subprocess.check_output([mpi_app, "--version"]) + + if b"Open MPI" in mpirun_version: + return mpi_app, "--hostfile", "-n", "--npernode" + else: + # Intel MPI and MVAPICH both use the same arg names + return mpi_app, "-f", "-n", "-ppn" + + +def prepare_simple_launcher_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict[str, str]]: + """ + Prepares and returns the command list and an environment with the correct simple launcher environment variables. 
+ """ + cmd = [] + if args.no_python and args.module: + raise ValueError("--module and --no_python cannot be used together") + + if args.mpirun_hostfile is not None: + mpi_app_name, hostfile_arg, num_proc_arg, proc_per_node_arg = _get_mpirun_args() + mpirun_ccl = getattr(args, "mpirun_ccl", None) + num_machines = args.num_machines + num_processes = getattr(args, "num_processes", None) + nproc_per_node = str(num_processes // num_machines) if num_processes and num_machines else "1" + cmd += [mpi_app_name, hostfile_arg, args.mpirun_hostfile, proc_per_node_arg, nproc_per_node] + if num_processes: + cmd += [num_proc_arg, str(num_processes)] + if not args.no_python: + cmd.append(sys.executable) + if args.module: + cmd.append("-m") + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + + current_env = os.environ.copy() + current_env["ACCELERATE_USE_CPU"] = str(args.cpu or args.use_cpu) + if args.debug: + current_env["ACCELERATE_DEBUG_MODE"] = "true" + if args.gpu_ids != "all" and args.gpu_ids is not None: + if is_xpu_available(): + current_env["ZE_AFFINITY_MASK"] = args.gpu_ids + elif is_mlu_available(): + current_env["MLU_VISIBLE_DEVICES"] = args.gpu_ids + elif is_npu_available(): + current_env["ASCEND_RT_VISIBLE_DEVICES"] = args.gpu_ids + else: + current_env["CUDA_VISIBLE_DEVICES"] = args.gpu_ids + if args.num_machines > 1: + current_env["MASTER_ADDR"] = args.main_process_ip + current_env["MASTER_PORT"] = str(args.main_process_port) + + if args.mpirun_hostfile is not None: + current_env["CCL_WORKER_COUNT"] = mpirun_ccl + elif args.num_processes > 1: + current_env["MASTER_ADDR"] = args.main_process_ip if args.main_process_ip is not None else "127.0.0.1" + current_env["MASTER_PORT"] = str(args.main_process_port) if args.main_process_port is not None else "29500" + + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError( + f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}." + ) + current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value + current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode + current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph) + current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic) + + current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) + if is_ipex_available(): + current_env["ACCELERATE_USE_IPEX"] = str(args.ipex).lower() + current_env["ACCELERATE_USE_XPU"] = str(args.use_xpu).lower() + if args.enable_cpu_affinity: + current_env["ACCELERATE_CPU_AFFINITY"] = "1" + return cmd, current_env + + +def prepare_multi_gpu_env(args: argparse.Namespace) -> Dict[str, str]: + """ + Prepares and returns an environment with the correct multi-GPU environment variables. 
+ """ + num_processes = args.num_processes + num_machines = args.num_machines + main_process_ip = args.main_process_ip + main_process_port = args.main_process_port + if num_machines > 1: + args.nproc_per_node = str(num_processes // num_machines) + args.nnodes = str(num_machines) + args.node_rank = int(args.machine_rank) + if getattr(args, "same_network", False): + args.master_addr = str(main_process_ip) + args.master_port = str(main_process_port) + else: + args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}" + else: + args.nproc_per_node = str(num_processes) + if main_process_port is not None: + args.master_port = str(main_process_port) + + if main_process_port is None: + main_process_port = 29500 + + # only need to check port availability in main process, in case we have to start multiple launchers on the same machine + # for some reasons like splitting log files. + need_port_check = num_machines <= 1 or int(args.machine_rank) == 0 + if need_port_check and is_port_in_use(main_process_port): + raise ConnectionError( + f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. " + "Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)" + " and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`." + ) + + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + args.module = True + elif args.no_python: + args.no_python = True + + current_env = os.environ.copy() + if args.debug: + current_env["ACCELERATE_DEBUG_MODE"] = "true" + gpu_ids = getattr(args, "gpu_ids", "all") + if gpu_ids != "all" and args.gpu_ids is not None: + if is_xpu_available(): + current_env["ZE_AFFINITY_MASK"] = gpu_ids + elif is_mlu_available(): + current_env["MLU_VISIBLE_DEVICES"] = gpu_ids + elif is_npu_available(): + current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids + else: + current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + mixed_precision = args.mixed_precision.lower() + try: + mixed_precision = PrecisionType(mixed_precision) + except ValueError: + raise ValueError(f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}.") + + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError( + f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}." 
+ ) + current_env["ACCELERATE_DYNAMO_BACKEND"] = dynamo_backend.value + current_env["ACCELERATE_DYNAMO_MODE"] = args.dynamo_mode + current_env["ACCELERATE_DYNAMO_USE_FULLGRAPH"] = str(args.dynamo_use_fullgraph) + current_env["ACCELERATE_DYNAMO_USE_DYNAMIC"] = str(args.dynamo_use_dynamic) + + if args.use_fsdp: + current_env["ACCELERATE_USE_FSDP"] = "true" + if args.fsdp_cpu_ram_efficient_loading and not args.fsdp_sync_module_states: + raise ValueError("When using `--fsdp_cpu_ram_efficient_loading` set `--fsdp_sync_module_states` to `True`") + + current_env["FSDP_SHARDING_STRATEGY"] = str(args.fsdp_sharding_strategy) + current_env["FSDP_OFFLOAD_PARAMS"] = str(args.fsdp_offload_params).lower() + current_env["FSDP_MIN_NUM_PARAMS"] = str(args.fsdp_min_num_params) + if args.fsdp_auto_wrap_policy is not None: + current_env["FSDP_AUTO_WRAP_POLICY"] = str(args.fsdp_auto_wrap_policy) + if args.fsdp_transformer_layer_cls_to_wrap is not None: + current_env["FSDP_TRANSFORMER_CLS_TO_WRAP"] = str(args.fsdp_transformer_layer_cls_to_wrap) + if args.fsdp_backward_prefetch_policy is not None: + warnings.warn( + "`fsdp_backward_prefetch_policy` is deprecated and will be removed in version 0.27.0 of 🤗 Accelerate. Use" + " `fsdp_backward_prefetch` instead", + FutureWarning, + ) + args.fsdp_backward_prefetch = args.fsdp_backward_prefetch_policy + if args.fsdp_backward_prefetch is not None: + current_env["FSDP_BACKWARD_PREFETCH"] = str(args.fsdp_backward_prefetch) + if args.fsdp_state_dict_type is not None: + current_env["FSDP_STATE_DICT_TYPE"] = str(args.fsdp_state_dict_type) + current_env["FSDP_FORWARD_PREFETCH"] = str(args.fsdp_forward_prefetch).lower() + current_env["FSDP_USE_ORIG_PARAMS"] = str(args.fsdp_use_orig_params).lower() + current_env["FSDP_CPU_RAM_EFFICIENT_LOADING"] = str(args.fsdp_cpu_ram_efficient_loading).lower() + current_env["FSDP_SYNC_MODULE_STATES"] = str(args.fsdp_sync_module_states).lower() + + if args.use_megatron_lm: + prefix = "MEGATRON_LM_" + current_env["ACCELERATE_USE_MEGATRON_LM"] = "true" + current_env[prefix + "TP_DEGREE"] = str(args.megatron_lm_tp_degree) + current_env[prefix + "PP_DEGREE"] = str(args.megatron_lm_pp_degree) + current_env[prefix + "GRADIENT_CLIPPING"] = str(args.megatron_lm_gradient_clipping) + if args.megatron_lm_num_micro_batches is not None: + current_env[prefix + "NUM_MICRO_BATCHES"] = str(args.megatron_lm_num_micro_batches) + if args.megatron_lm_sequence_parallelism is not None: + current_env[prefix + "SEQUENCE_PARALLELISM"] = str(args.megatron_lm_sequence_parallelism) + if args.megatron_lm_recompute_activations is not None: + current_env[prefix + "RECOMPUTE_ACTIVATIONS"] = str(args.megatron_lm_recompute_activations) + if args.megatron_lm_use_distributed_optimizer is not None: + current_env[prefix + "USE_DISTRIBUTED_OPTIMIZER"] = str(args.megatron_lm_use_distributed_optimizer) + + current_env["OMP_NUM_THREADS"] = str(args.num_cpu_threads_per_process) + if args.enable_cpu_affinity: + current_env["ACCELERATE_CPU_AFFINITY"] = "1" + return current_env + + +def prepare_deepspeed_cmd_env(args: argparse.Namespace) -> Tuple[List[str], Dict[str, str]]: + """ + Prepares and returns the command list and an environment with the correct DeepSpeed environment variables. 
+ """ + num_processes = args.num_processes + num_machines = args.num_machines + main_process_ip = args.main_process_ip + main_process_port = args.main_process_port + cmd = None + + # make sure launcher is not None + if args.deepspeed_multinode_launcher is None: + # set to default pdsh + args.deepspeed_multinode_launcher = DEEPSPEED_MULTINODE_LAUNCHERS[0] + + if num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]: + cmd = ["deepspeed", "--no_local_rank"] + cmd.extend(["--hostfile", str(args.deepspeed_hostfile), "--launcher", str(args.deepspeed_multinode_launcher)]) + if args.deepspeed_exclusion_filter is not None: + cmd.extend( + [ + "--exclude", + str(args.deepspeed_exclusion_filter), + ] + ) + elif args.deepspeed_inclusion_filter is not None: + cmd.extend( + [ + "--include", + str(args.deepspeed_inclusion_filter), + ] + ) + else: + cmd.extend(["--num_gpus", str(args.num_processes // args.num_machines)]) + if main_process_ip: + cmd.extend(["--master_addr", str(main_process_ip)]) + cmd.extend(["--master_port", str(main_process_port)]) + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + cmd.append("--module") + elif args.no_python: + cmd.append("--no_python") + cmd.append(args.training_script) + cmd.extend(args.training_script_args) + elif num_machines > 1 and args.deepspeed_multinode_launcher == DEEPSPEED_MULTINODE_LAUNCHERS[1]: + args.nproc_per_node = str(num_processes // num_machines) + args.nnodes = str(num_machines) + args.node_rank = int(args.machine_rank) + if getattr(args, "same_network", False): + args.master_addr = str(main_process_ip) + args.master_port = str(main_process_port) + else: + args.rdzv_endpoint = f"{main_process_ip}:{main_process_port}" + else: + args.nproc_per_node = str(num_processes) + if main_process_port is not None: + args.master_port = str(main_process_port) + + if main_process_port is None: + main_process_port = 29500 + + # only need to check port availability in main process, in case we have to start multiple launchers on the same machine + # for some reasons like splitting log files. + need_port_check = num_machines <= 1 or int(args.machine_rank) == 0 + if need_port_check and is_port_in_use(main_process_port): + raise ConnectionError( + f"Tried to launch distributed communication on port `{main_process_port}`, but another process is utilizing it. " + "Please specify a different port (such as using the `--main_process_port` flag or specifying a different `main_process_port` in your config file)" + " and rerun your script. To automatically use the next open port (on a single node), you can set this to `0`." + ) + + if args.module and args.no_python: + raise ValueError("--module and --no_python cannot be used together") + elif args.module: + args.module = True + elif args.no_python: + args.no_python = True + + current_env = os.environ.copy() + if args.debug: + current_env["ACCELERATE_DEBUG_MODE"] = "true" + gpu_ids = getattr(args, "gpu_ids", "all") + if gpu_ids != "all" and args.gpu_ids is not None: + if is_xpu_available(): + current_env["ZE_AFFINITY_MASK"] = gpu_ids + elif is_mlu_available(): + current_env["MLU_VISIBLE_DEVICES"] = gpu_ids + elif is_npu_available(): + current_env["ASCEND_RT_VISIBLE_DEVICES"] = gpu_ids + else: + current_env["CUDA_VISIBLE_DEVICES"] = gpu_ids + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. 
Choose between {PrecisionType.list()}." + ) + + current_env["PYTHONPATH"] = env_var_path_add("PYTHONPATH", os.path.abspath(".")) + current_env["ACCELERATE_MIXED_PRECISION"] = str(mixed_precision) + current_env["ACCELERATE_CONFIG_DS_FIELDS"] = str(args.deepspeed_fields_from_accelerate_config).lower() + current_env["ACCELERATE_USE_DEEPSPEED"] = "true" + if args.zero_stage is not None: + current_env["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(args.zero_stage) + if args.gradient_accumulation_steps is not None: + current_env["ACCELERATE_GRADIENT_ACCUMULATION_STEPS"] = str(args.gradient_accumulation_steps) + if args.gradient_clipping is not None: + current_env["ACCELERATE_GRADIENT_CLIPPING"] = str(args.gradient_clipping).lower() + if args.offload_optimizer_device is not None: + current_env["ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE"] = str(args.offload_optimizer_device).lower() + if args.offload_param_device is not None: + current_env["ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE"] = str(args.offload_param_device).lower() + if args.zero3_init_flag is not None: + current_env["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = str(args.zero3_init_flag).lower() + if args.zero3_save_16bit_model is not None: + current_env["ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL"] = str(args.zero3_save_16bit_model).lower() + if args.deepspeed_config_file is not None: + current_env["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = str(args.deepspeed_config_file) + if args.enable_cpu_affinity: + current_env["ACCELERATE_CPU_AFFINITY"] = "1" + return cmd, current_env + + +def prepare_tpu( + args: argparse.Namespace, current_env: Dict[str, str], pod: bool = False +) -> Tuple[argparse.Namespace, Dict[str, str]]: + """ + Prepares and returns an environment with the correct TPU environment variables. + """ + if args.mixed_precision == "bf16" and is_torch_xla_available(check_is_tpu=True): + if args.downcast_bf16: + current_env["XLA_DOWNCAST_BF16"] = "1" + else: + current_env["XLA_USE_BF16"] = "1" + if args.debug: + current_env["ACCELERATE_DEBUG_MODE"] = "true" + if pod: + # Take explicit args and set them up for XLA + args.vm = args.tpu_vm + args.tpu = args.tpu_name + return args, current_env + + +def _convert_nargs_to_dict(nargs: List[str]) -> Dict[str, str]: + if len(nargs) < 0: + return {} + # helper function to infer type for argsparser + + def _infer_type(s): + try: + s = float(s) + + if s // 1 == s: + return int(s) + return s + except ValueError: + return s + + parser = argparse.ArgumentParser() + _, unknown = parser.parse_known_args(nargs) + for index, argument in enumerate(unknown): + if argument.startswith(("-", "--")): + action = None + if index + 1 < len(unknown): # checks if next index would be in list + if unknown[index + 1].startswith(("-", "--")): # checks if next element is an key + # raise an error if element is store_true or store_false + raise ValueError( + "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. Please define explicit types" + ) + else: # raise an error if last element is store_true or store_false + raise ValueError( + "SageMaker doesn’t support argparse actions for `store_true` or `store_false`. 
Please define explicit types" + ) + # adds argument to parser based on action_store true + if action is None: + parser.add_argument(argument, type=_infer_type) + else: + parser.add_argument(argument, action=action) + + return { + key: (literal_eval(value) if value in ("True", "False") else value) + for key, value in parser.parse_args(nargs).__dict__.items() + } + + +def prepare_sagemager_args_inputs( + sagemaker_config: SageMakerConfig, args: argparse.Namespace +) -> Tuple[argparse.Namespace, Dict[str, Any]]: + # configure environment + print("Configuring Amazon SageMaker environment") + os.environ["AWS_DEFAULT_REGION"] = sagemaker_config.region + + # configure credentials + if sagemaker_config.profile is not None: + os.environ["AWS_PROFILE"] = sagemaker_config.profile + elif args.aws_access_key_id is not None and args.aws_secret_access_key is not None: + os.environ["AWS_ACCESS_KEY_ID"] = args.aws_access_key_id + os.environ["AWS_SECRET_ACCESS_KEY"] = args.aws_secret_access_key + else: + raise OSError("You need to provide an aws_access_key_id and aws_secret_access_key when not using aws_profile") + + # extract needed arguments + source_dir = os.path.dirname(args.training_script) + if not source_dir: # checks if string is empty + source_dir = "." + entry_point = os.path.basename(args.training_script) + if not entry_point.endswith(".py"): + raise ValueError(f'Your training script should be a python script and not "{entry_point}"') + + print("Converting Arguments to Hyperparameters") + hyperparameters = _convert_nargs_to_dict(args.training_script_args) + + try: + mixed_precision = PrecisionType(args.mixed_precision.lower()) + except ValueError: + raise ValueError( + f"Unknown mixed_precision mode: {args.mixed_precision.lower()}. Choose between {PrecisionType.list()}." + ) + + try: + dynamo_backend = DynamoBackend(args.dynamo_backend.upper()) + except ValueError: + raise ValueError( + f"Unknown dynamo backend: {args.dynamo_backend.upper()}. Choose between {DynamoBackend.list()}." 
+ ) + + # Environment variables to be set for use during training job + environment = { + "ACCELERATE_USE_SAGEMAKER": "true", + "ACCELERATE_MIXED_PRECISION": str(mixed_precision), + "ACCELERATE_DYNAMO_BACKEND": dynamo_backend.value, + "ACCELERATE_DYNAMO_MODE": args.dynamo_mode, + "ACCELERATE_DYNAMO_USE_FULLGRAPH": str(args.dynamo_use_fullgraph), + "ACCELERATE_DYNAMO_USE_DYNAMIC": str(args.dynamo_use_dynamic), + "ACCELERATE_SAGEMAKER_DISTRIBUTED_TYPE": sagemaker_config.distributed_type.value, + } + # configure distribution set up + distribution = None + if sagemaker_config.distributed_type == SageMakerDistributedType.DATA_PARALLEL: + distribution = {"smdistributed": {"dataparallel": {"enabled": True}}} + + # configure sagemaker inputs + sagemaker_inputs = None + if sagemaker_config.sagemaker_inputs_file is not None: + print(f"Loading SageMaker Inputs from {sagemaker_config.sagemaker_inputs_file} file") + sagemaker_inputs = {} + with open(sagemaker_config.sagemaker_inputs_file) as file: + for i, line in enumerate(file): + if i == 0: + continue + l = line.split("\t") + sagemaker_inputs[l[0]] = l[1].strip() + print(f"Loaded SageMaker Inputs: {sagemaker_inputs}") + + # configure sagemaker metrics + sagemaker_metrics = None + if sagemaker_config.sagemaker_metrics_file is not None: + print(f"Loading SageMaker Metrics from {sagemaker_config.sagemaker_metrics_file} file") + sagemaker_metrics = [] + with open(sagemaker_config.sagemaker_metrics_file) as file: + for i, line in enumerate(file): + if i == 0: + continue + l = line.split("\t") + metric_dict = { + "Name": l[0], + "Regex": l[1].strip(), + } + sagemaker_metrics.append(metric_dict) + print(f"Loaded SageMaker Metrics: {sagemaker_metrics}") + + # configure session + print("Creating Estimator") + args = { + "image_uri": sagemaker_config.image_uri, + "entry_point": entry_point, + "source_dir": source_dir, + "role": sagemaker_config.iam_role_name, + "transformers_version": sagemaker_config.transformers_version, + "pytorch_version": sagemaker_config.pytorch_version, + "py_version": sagemaker_config.py_version, + "base_job_name": sagemaker_config.base_job_name, + "instance_count": sagemaker_config.num_machines, + "instance_type": sagemaker_config.ec2_instance_type, + "debugger_hook_config": False, + "distribution": distribution, + "hyperparameters": hyperparameters, + "environment": environment, + "metric_definitions": sagemaker_metrics, + } + + if sagemaker_config.additional_args is not None: + args = merge_dicts(sagemaker_config.additional_args, args) + return args, sagemaker_inputs + + +def env_var_path_add(env_var_name, path_to_add): + """ + Extends a path-based environment variable's value with a new path and returns the updated value. It's up to the + caller to set it in os.environ. + """ + paths = [p for p in os.environ.get(env_var_name, "").split(":") if len(p) > 0] + paths.append(str(path_to_add)) + return ":".join(paths) + + +class PrepareForLaunch: + """ + Prepare a function that will launched in a distributed setup. + + Args: + launcher (`Callable`): + The function to launch. + distributed_type ([`~state.DistributedType`]): + The distributed type to prepare for. + debug (`bool`, *optional*, defaults to `False`): + Whether or not this is a debug launch. 
+ """ + + def __init__(self, launcher, distributed_type="NO", debug=False): + self.launcher = launcher + self.distributed_type = DistributedType(distributed_type) + self.debug = debug + + def __call__(self, index, *args): + if self.debug: + world_size = int(os.environ.get("WORLD_SIZE")) + rdv_file = os.environ.get("ACCELERATE_DEBUG_RDV_FILE") + torch.distributed.init_process_group( + "gloo", + rank=index, + store=torch.distributed.FileStore(rdv_file, world_size), + world_size=world_size, + ) + elif self.distributed_type in ( + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.MULTI_CPU, + ): + # Prepare the environment for torch.distributed + os.environ["LOCAL_RANK"] = str(index) + nproc = int(os.environ.get("NPROC", 1)) + node_rank = int(os.environ.get("NODE_RANK", 0)) + os.environ["RANK"] = str(nproc * node_rank + index) + + os.environ["FORK_LAUNCHED"] = str(1) + self.launcher(*args) diff --git a/llm/Lib/site-packages/accelerate/utils/megatron_lm.py b/llm/Lib/site-packages/accelerate/utils/megatron_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..ff1eb199307fd44cb815644e8bb34d08e2d6adb6 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/megatron_lm.py @@ -0,0 +1,1435 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
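+# Usage sketch for the PrepareForLaunch wrapper defined in launch.py above (the call
+# pattern mirrors accelerate's notebook launcher; `training_function` and
+# `num_processes` are placeholder names):
+#
+#   import torch.multiprocessing as mp
+#   from accelerate.utils import PrepareForLaunch
+#
+#   launcher = PrepareForLaunch(training_function, distributed_type="MULTI_GPU")
+#   mp.start_processes(launcher, args=(), nprocs=num_processes, start_method="fork")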
+ +import argparse +import math +from abc import ABC +from functools import partial + +import torch +import torch.nn.functional as F +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP + +from ..optimizer import AcceleratedOptimizer +from ..scheduler import AcceleratedScheduler +from .imports import is_megatron_lm_available, is_transformers_available +from .operations import recursively_apply, send_to_device + + +if is_transformers_available(): + from transformers.modeling_outputs import ( + CausalLMOutputWithCrossAttentions, + Seq2SeqLMOutput, + SequenceClassifierOutput, + ) + + +if is_megatron_lm_available(): + from megatron import ( + get_args, + get_num_microbatches, + get_tensorboard_writer, + get_timers, + get_tokenizer, + mpu, + print_rank_0, + print_rank_last, + ) + from megatron.arguments import _add_data_args, _add_validation_args, parse_args, validate_args + from megatron.checkpointing import load_args_from_checkpoint, load_checkpoint, save_checkpoint + from megatron.data.data_samplers import MegatronPretrainingRandomSampler, MegatronPretrainingSampler + from megatron.global_vars import set_global_variables + from megatron.initialize import ( + _compile_dependencies, + _init_autoresume, + _set_random_seed, + set_jit_fusion_options, + write_args_to_tensorboard, + ) + from megatron.model import BertModel, Float16Module, GPTModel, ModelType, T5Model + from megatron.model import DistributedDataParallel as LocalDDP + from megatron.model.classification import Classification + from megatron.optimizer import get_megatron_optimizer + from megatron.schedules import get_forward_backward_func + from megatron.text_generation.communication import broadcast_int_list, broadcast_tensor + from megatron.text_generation.generation import ( + beam_search_and_return_on_first_stage, + generate_tokens_probs_and_return_on_first_stage, + ) + from megatron.tokenizer.tokenizer import _vocab_size_with_padding + from megatron.training import get_model, get_optimizer_param_scheduler, training_log + from megatron.utils import ( + average_losses_across_data_parallel_group, + calc_params_l2_norm, + get_ltor_masks_and_position_ids, + unwrap_model, + ) + + +# model utilities +def model_provider_func(pre_process=True, post_process=True, add_encoder=True, add_decoder=True): + """Build the model.""" + args = get_args() + mode = "pre-training" if args.pretraining_flag else "fine-tuning" + if args.rank == 0: + print(f"Building {args.model_type_name} model in the {mode} mode.") + print( + "The Megatron LM model weights are initialized at random in `accelerator.prepare`. " + "Please use `accelerator.load_checkpoint` to load a pre-trained checkpoint matching the distributed setup." 
+ ) + if args.model_type_name == "bert": + if args.pretraining_flag: + num_tokentypes = 2 if args.bert_binary_head else 0 + model = BertModel( + num_tokentypes=num_tokentypes, + add_binary_head=args.bert_binary_head, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + else: + model = Classification( + num_classes=args.num_labels, num_tokentypes=2, pre_process=pre_process, post_process=post_process + ) + elif args.model_type_name == "gpt": + model = GPTModel(num_tokentypes=0, parallel_output=True, pre_process=pre_process, post_process=post_process) + elif args.model_type_name == "t5": + model = T5Model( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + add_encoder=add_encoder, + add_decoder=add_decoder, + ) + else: + raise ValueError(f"Unsupported model type: {args.model_type_name}") + return model + + +def prepare_model(accelerator): + accelerator.print("Preparing model") + args = get_args() + if accelerator.state.megatron_lm_plugin.custom_prepare_model_function is not None: + if accelerator.state.megatron_lm_plugin.custom_model_provider_function is None: + raise ValueError( + "You must provide a `custom_model_provider_function` when using a `custom_prepare_model_function`." + ) + custom_model_provider_func = accelerator.state.megatron_lm_plugin.custom_model_provider_function + model = accelerator.state.megatron_lm_plugin.custom_prepare_model_function(custom_model_provider_func) + else: + if args.model_type_name in ("bert", "gpt"): + model_type = ModelType.encoder_or_decoder + elif args.model_type_name == "t5": + model_type = ModelType.encoder_and_decoder + if args.pipeline_model_parallel_split_rank is None and args.pipeline_model_parallel_size > 1: + args.pipeline_model_parallel_split_rank = args.pipeline_model_parallel_size // 2 + model = get_model(model_provider_func, model_type) + return model + + +# dataloader utilities +class MegatronLMDummyDataLoader: + """ + Dummy dataloader presents model parameters or param groups, this is primarily used to follow conventional training + + Args: + **dataset_kwargs: Megatron data arguments. 
+ """ + + def __init__(self, **dataset_kwargs): + parser = argparse.ArgumentParser() + parser = _add_data_args(parser) + parser = _add_validation_args(parser) + data_args = parser.parse_known_args() + self.dataset_args = vars(data_args[0]) + self.dataset_args.update(dataset_kwargs) + self.dataset_args["megatron_dataset_flag"] = True + + def set_megatron_data_args(self): + args = get_args() + for key, value in self.dataset_args.items(): + setattr(args, key, value) + + def get_train_valid_test_datasets_provider(self): + def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + dataset_args = { + "data_prefix": args.data_path, + "data_impl": args.data_impl, + "splits_string": args.split, + "train_valid_test_num_samples": train_val_test_num_samples, + "skip_warmup": (not args.mmap_warmup), + "seed": args.seed, + } + if args.model_type_name == "bert": + dataset_args.update( + { + "max_seq_length": args.seq_length, + "masked_lm_prob": args.mask_prob, + "short_seq_prob": args.short_seq_prob, + "binary_head": args.bert_binary_head, + } + ) + elif args.model_type_name == "gpt": + dataset_args.update( + { + "seq_length": args.seq_length, + } + ) + elif args.model_type_name == "t5": + dataset_args.update( + { + "max_seq_length": args.encoder_seq_length, + "max_seq_length_dec": args.decoder_seq_length, + "masked_lm_prob": args.mask_prob, + "short_seq_prob": args.short_seq_prob, + "dataset_type": "t5", + } + ) + else: + raise ValueError(f"Unsupported model type: {args.model_type_name}") + if args.model_type_name == "gpt": + from megatron.data.gpt_dataset import build_train_valid_test_datasets + else: + from megatron.data.dataset_utils import build_train_valid_test_datasets + train_ds, valid_ds, test_ds = build_train_valid_test_datasets(**dataset_args) + return train_ds, valid_ds, test_ds + + return train_valid_test_datasets_provider + + def build_pretraining_data_loader(self, dataset, consumed_samples): + if dataset is None: + return None + args = get_args() + micro_batch_size = args.micro_batch_size * args.num_micro_batches + + # Megatron sampler + if args.dataloader_type == "single": + batch_sampler = MegatronPretrainingSampler( + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + ) + elif args.dataloader_type == "cyclic": + batch_sampler = MegatronPretrainingRandomSampler( + dataset, + total_samples=len(dataset), + consumed_samples=consumed_samples, + micro_batch_size=micro_batch_size, + data_parallel_rank=mpu.get_data_parallel_rank(), + data_parallel_size=mpu.get_data_parallel_world_size(), + data_sharding=args.data_sharding, + ) + else: + raise Exception(f"{args.dataloader_type} dataloader type is not supported.") + + # Torch dataloader. + return torch.utils.data.DataLoader( + dataset, batch_sampler=batch_sampler, num_workers=args.num_workers, pin_memory=True + ) + + def build_train_valid_test_data_iterators(self): + def cyclic_iter(iter): + while True: + yield from iter + + args = get_args() + + (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None) + + print_rank_0("> building train, validation, and test datasets ...") + + # Backward compatibility, assume fixed batch size. 
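+        # When resuming from an iteration-based checkpoint, the consumed-sample counters
+        # are reconstructed as iteration * global_batch_size (and, for validation, from
+        # how many evaluation passes of eval_iters batches have already run).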
+ if args.iteration > 0 and args.consumed_train_samples == 0: + assert args.train_samples is None, "only backward compatiblity support for iteration-based training" + args.consumed_train_samples = args.iteration * args.global_batch_size + if args.iteration > 0 and args.consumed_valid_samples == 0: + if args.train_samples is None: + args.consumed_valid_samples = ( + (args.iteration // args.eval_interval) * args.eval_iters * args.global_batch_size + ) + + # Data loader only on rank 0 of each model parallel group. + if mpu.get_tensor_model_parallel_rank() == 0: + # Number of train/valid/test samples. + if args.train_samples: + train_samples = args.train_samples + else: + train_samples = args.train_iters * args.global_batch_size + eval_iters = (args.train_iters // args.eval_interval + 1) * args.eval_iters + test_iters = args.eval_iters + train_val_test_num_samples = [ + train_samples, + eval_iters * args.global_batch_size, + test_iters * args.global_batch_size, + ] + print_rank_0(" > datasets target sizes (minimum size):") + print_rank_0(f" train: {train_val_test_num_samples[0]}") + print_rank_0(f" validation: {train_val_test_num_samples[1]}") + print_rank_0(f" test: {train_val_test_num_samples[2]}") + + # Build the datasets. + train_valid_test_datasets_provider = self.get_train_valid_test_datasets_provider() + train_ds, valid_ds, test_ds = train_valid_test_datasets_provider(train_val_test_num_samples) + + # Build dataloders. + train_dataloader = self.build_pretraining_data_loader(train_ds, args.consumed_train_samples) + valid_dataloader = self.build_pretraining_data_loader(valid_ds, args.consumed_valid_samples) + test_dataloader = self.build_pretraining_data_loader(test_ds, 0) + + # Flags to know if we need to do training/validation/testing. + do_train = train_dataloader is not None and args.train_iters > 0 + do_valid = valid_dataloader is not None and args.eval_iters > 0 + do_test = test_dataloader is not None and args.eval_iters > 0 + # Need to broadcast num_tokens and num_type_tokens. + flags = torch.cuda.LongTensor([int(do_train), int(do_valid), int(do_test)]) + else: + flags = torch.cuda.LongTensor([0, 0, 0]) + + # Broadcast num tokens. + torch.distributed.broadcast( + flags, mpu.get_tensor_model_parallel_src_rank(), group=mpu.get_tensor_model_parallel_group() + ) + args.do_train = flags[0].item() + args.do_valid = flags[1].item() + args.do_test = flags[2].item() + + # Build iterators. 
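+        # "single" hands back a plain iterator (one pass over the dataloader), while
+        # "cyclic" wraps it in cyclic_iter above so iteration restarts transparently
+        # once the underlying dataloader is exhausted.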
+ dl_type = args.dataloader_type + assert dl_type in ["single", "cyclic"] + + if train_dataloader is not None: + train_data_iterator = ( + iter(train_dataloader) if dl_type == "single" else iter(cyclic_iter(train_dataloader)) + ) + else: + train_data_iterator = None + + if valid_dataloader is not None: + valid_data_iterator = ( + iter(valid_dataloader) if dl_type == "single" else iter(cyclic_iter(valid_dataloader)) + ) + else: + valid_data_iterator = None + + if test_dataloader is not None: + test_data_iterator = iter(test_dataloader) if dl_type == "single" else iter(cyclic_iter(test_dataloader)) + else: + test_data_iterator = None + + return train_data_iterator, valid_data_iterator, test_data_iterator + + +def prepare_data_loader(accelerator, dataloader): + accelerator.print("Preparing dataloader") + args = get_args() + if not args.megatron_dataset_flag: + from ..data_loader import _PYTORCH_DATALOADER_KWARGS, prepare_data_loader + + args = get_args() + micro_batch_size = args.micro_batch_size * args.num_micro_batches + kwargs = {k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k]) for k in _PYTORCH_DATALOADER_KWARGS} + if kwargs["batch_size"] is None: + if isinstance(kwargs["sampler"], torch.utils.data.BatchSampler): + kwargs["sampler"].batch_size = micro_batch_size + else: + del kwargs["sampler"] + del kwargs["shuffle"] + del kwargs["batch_size"] + kwargs["batch_sampler"].batch_size = micro_batch_size + else: + del kwargs["batch_sampler"] + kwargs["batch_size"] = micro_batch_size + + dataloader = torch.utils.data.DataLoader(dataloader.dataset, **kwargs) + return prepare_data_loader( + dataloader, + accelerator.device, + num_processes=mpu.get_data_parallel_world_size(), + process_index=mpu.get_data_parallel_rank(), + split_batches=accelerator.split_batches, + put_on_device=True, + rng_types=accelerator.rng_types.copy(), + dispatch_batches=accelerator.dispatch_batches, + ) + else: + if args.consumed_samples is not None: + ( + args.consumed_train_samples, + args.consumed_valid_samples, + args.consumed_test_samples, + ) = args.consumed_samples + else: + args.consumed_train_samples, args.consumed_valid_samples, args.consumed_test_samples = 0, 0, 0 + ( + train_data_iterator, + valid_data_iterator, + test_data_iterator, + ) = dataloader.build_train_valid_test_data_iterators() + return train_data_iterator, valid_data_iterator, test_data_iterator + + +# optimizer utilities +class MegatronLMOptimizerWrapper(AcceleratedOptimizer): + def __init__(self, optimizer): + super().__init__(optimizer, device_placement=False, scaler=None) + + def zero_grad(self, set_to_none=None): + pass # `model(**batch)` is doing that automatically. Therefore, it's implementation is not needed + + def step(self): + pass # `model(**batch)` is doing that automatically. Therefore, it's implementation is not needed + + @property + def step_was_skipped(self): + """Whether or not the optimizer step was done, or skipped because of gradient overflow.""" + return self.optimizer.skipped_iter + + +def prepare_optimizer(accelerator, model): + accelerator.print("Preparing optimizer") + args = get_args() + optimizer = get_megatron_optimizer(model, args.no_wd_decay_cond, args.scale_lr_cond, args.lr_mult) + return optimizer + + +# scheduler utilities +class MegatronLMDummyScheduler: + """ + Dummy scheduler presents model parameters or param groups, this is primarily used to follow conventional training + loop when scheduler config is specified in the deepspeed config file. 
+ + Args: + optimizer (`torch.optim.optimizer.Optimizer`): + The optimizer to wrap. + total_num_steps (int): + Total number of steps. + warmup_num_steps (int): + Number of steps for warmup. + **kwargs (additional keyword arguments, *optional*): + Other arguments. + """ + + def __init__(self, optimizer, total_num_steps=None, warmup_num_steps=0, **kwargs): + self.optimizer = optimizer + self.total_num_steps = total_num_steps + self.warmup_num_steps = warmup_num_steps + self.kwargs = kwargs + + +class MegatronLMSchedulerWrapper(AcceleratedScheduler): + def __init__(self, scheduler, optimizers): + super().__init__(scheduler, optimizers) + + def step(self, *args, **kwargs): + return # `model(**batch)` is doing that automatically. Therefore, it's implementation is not needed + + +def prepare_scheduler(accelerator, optimizer, scheduler): + accelerator.print("Preparing scheduler") + scheduler = get_optimizer_param_scheduler(optimizer) + return scheduler + + +class AbstractTrainStep(ABC): + """Abstract class for batching, forward pass and loss handler.""" + + def __init__(self, name): + super().__init__() + self.name = name + + def get_batch_func(self): + pass + + def get_forward_step_func(self): + pass + + def get_loss_func(self): + pass + + +class BertTrainStep(AbstractTrainStep): + """ + Bert train step class. + + Args: + args (`argparse.Namespace`): Megatron-LM arguments. + """ + + def __init__(self, args): + super().__init__("BertTrainStep") + self.get_batch = self.get_batch_func(args.megatron_dataset_flag) + self.loss_func = self.get_loss_func(args.pretraining_flag, args.num_labels) + self.forward_step = self.get_forward_step_func(args.pretraining_flag, args.bert_binary_head) + if not args.model_return_dict: + self.model_output_class = None + else: + self.model_output_class = SequenceClassifierOutput + + def get_batch_func(self, megatron_dataset_flag): + def get_batch_megatron(data_iterator): + """Build the batch.""" + + # Items and their type. + keys = ["text", "types", "labels", "is_random", "loss_mask", "padding_mask"] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens = data_b["text"].long() + types = data_b["types"].long() + sentence_order = data_b["is_random"].long() + loss_mask = data_b["loss_mask"].float() + lm_labels = data_b["labels"].long() + padding_mask = data_b["padding_mask"].long() + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + def get_batch_transformer(data_iterator): + """Build the batch.""" + data = next(data_iterator) + data = send_to_device(data, torch.cuda.current_device()) + + # Unpack. 
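+            # Map the Transformers-style batch (input_ids, attention_mask, optional
+            # token_type_ids / labels / next_sentence_label) onto the tensors Megatron's
+            # BERT step expects; loss_mask marks the positions whose label is not -100.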
+ tokens = data["input_ids"].long() + padding_mask = data["attention_mask"].long() + if "token_type_ids" in data: + types = data["token_type_ids"].long() + else: + types = None + if "labels" in data: + lm_labels = data["labels"].long() + loss_mask = (data["labels"] != -100).to(torch.float) + else: + lm_labels = None + loss_mask = None + if "next_sentence_label" in data: + sentence_order = data["next_sentence_label"].long() + else: + sentence_order = None + + return tokens, types, sentence_order, loss_mask, lm_labels, padding_mask + + if megatron_dataset_flag: + return get_batch_megatron + else: + return get_batch_transformer + + def get_loss_func(self, pretraining_flag, num_labels): + def loss_func_pretrain(loss_mask, sentence_order, output_tensor): + lm_loss_, sop_logits = output_tensor + + lm_loss_ = lm_loss_.float() + loss_mask = loss_mask.float() + lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() + + if sop_logits is not None: + sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(), sentence_order.view(-1), ignore_index=-1) + sop_loss = sop_loss.float() + loss = lm_loss + sop_loss + averaged_losses = average_losses_across_data_parallel_group([lm_loss, sop_loss]) + return loss, {"lm loss": averaged_losses[0], "sop loss": averaged_losses[1]} + + else: + loss = lm_loss + averaged_losses = average_losses_across_data_parallel_group([lm_loss]) + return loss, {"lm loss": averaged_losses[0]} + + def loss_func_finetune(labels, logits): + if num_labels == 1: + # We are doing regression + loss_fct = MSELoss() + loss = loss_fct(logits.view(-1), labels.view(-1)) + elif self.num_labels > 1 and (labels.dtype in (torch.long, torch.int)): + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, num_labels), labels.view(-1)) + else: + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + averaged_losses = average_losses_across_data_parallel_group([loss]) + return loss, {"loss": averaged_losses[0]} + + if pretraining_flag: + return loss_func_pretrain + else: + return loss_func_finetune + + def get_forward_step_func(self, pretraining_flag, bert_binary_head): + def forward_step(data_iterator, model): + """Forward step.""" + tokens, types, sentence_order, loss_mask, labels, padding_mask = self.get_batch(data_iterator) + if not bert_binary_head: + types = None + # Forward pass through the model. + if pretraining_flag: + output_tensor = model(tokens, padding_mask, tokentype_ids=types, lm_labels=labels) + return output_tensor, partial(self.loss_func, loss_mask, sentence_order) + else: + logits = model(tokens, padding_mask, tokentype_ids=types) + return logits, partial(self.loss_func, labels) + + return forward_step + + +class GPTTrainStep(AbstractTrainStep): + """ + GPT train step class. + + Args: + args (`argparse.Namespace`): Megatron-LM arguments. 
+ """ + + def __init__(self, args): + super().__init__("GPTTrainStep") + self.get_batch = self.get_batch_func(args.megatron_dataset_flag) + self.loss_func = self.get_loss_func() + self.forward_step = self.get_forward_step_func() + self.eod_token = args.padded_vocab_size - 1 + if args.vocab_file is not None: + tokenizer = get_tokenizer() + self.eod_token = tokenizer.eod + self.reset_position_ids = args.reset_position_ids + self.reset_attention_mask = args.reset_attention_mask + self.eod_mask_loss = args.eod_mask_loss + if not args.model_return_dict: + self.model_output_class = None + else: + self.model_output_class = CausalLMOutputWithCrossAttentions + + def get_batch_func(self, megatron_dataset_flag): + def get_batch_megatron(data_iterator): + """Generate a batch""" + # Items and their type. + keys = ["text"] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b["text"].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, self.eod_mask_loss + ) + + return tokens, labels, loss_mask, attention_mask, position_ids + + def get_batch_transformer(data_iterator): + data = next(data_iterator) + data = {"input_ids": data["input_ids"]} + data = send_to_device(data, torch.cuda.current_device()) + + tokens_ = data["input_ids"].long() + padding = torch.zeros((tokens_.shape[0], 1), dtype=tokens_.dtype, device=tokens_.device) + self.eod_token + tokens_ = torch.concat([tokens_, padding], dim=1) + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, self.eod_token, self.reset_position_ids, self.reset_attention_mask, True + ) + return tokens, labels, loss_mask, attention_mask, position_ids + + if megatron_dataset_flag: + return get_batch_megatron + else: + return get_batch_transformer + + def get_loss_func(self): + args = get_args() + + def loss_func(loss_mask, output_tensor): + if args.return_logits: + losses, logits = output_tensor + else: + losses = output_tensor + losses = losses.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + output_dict = {"lm loss": averaged_loss[0]} + if args.return_logits: + output_dict.update({"logits": logits}) + return loss, output_dict + + return loss_func + + def get_forward_step_func(self): + def forward_step(data_iterator, model): + """Forward step.""" + # Get the batch. + tokens, labels, loss_mask, attention_mask, position_ids = self.get_batch(data_iterator) + output_tensor = model(tokens, position_ids, attention_mask, labels=labels) + + return output_tensor, partial(self.loss_func, loss_mask) + + return forward_step + + +class T5TrainStep(AbstractTrainStep): + """ + T5 train step class. + + Args: + args (`argparse.Namespace`): Megatron-LM arguments. 
+ """ + + def __init__(self, args): + super().__init__("T5TrainStep") + self.get_batch = self.get_batch_func(args.megatron_dataset_flag) + self.loss_func = self.get_loss_func() + self.forward_step = self.get_forward_step_func() + if not args.model_return_dict: + self.model_output_class = None + else: + self.model_output_class = Seq2SeqLMOutput + + @staticmethod + def attn_mask_postprocess(attention_mask): + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = attention_mask.unsqueeze(2) + # [b, s, s] + attention_mask_bss = attention_mask_b1s * attention_mask_bs1 + # Convert attention mask to binary: + extended_attention_mask = attention_mask_bss < 0.5 + return extended_attention_mask + + @staticmethod + def get_decoder_mask(seq_length, device): + attention_mask = torch.tril(torch.ones((1, seq_length, seq_length), device=device)) + attention_mask = attention_mask < 0.5 + return attention_mask + + @staticmethod + def get_enc_dec_mask(attention_mask, dec_seq_length, device): + batch_size, _ = attention_mask.shape + # We create a 3D attention mask from a 2D tensor mask. + # [b, 1, s] + attention_mask_b1s = attention_mask.unsqueeze(1) + # [b, s, 1] + attention_mask_bs1 = torch.ones((batch_size, dec_seq_length, 1), device=device) + attention_mask_bss = attention_mask_bs1 * attention_mask_b1s + extended_attention_mask = attention_mask_bss < 0.5 + return extended_attention_mask + + def get_batch_func(self, megatron_dataset_flag): + def get_batch_megatron(data_iterator): + """Build the batch.""" + + keys = ["text_enc", "text_dec", "labels", "loss_mask", "enc_mask", "dec_mask", "enc_dec_mask"] + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. 
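+            # The broadcast batch provides encoder/decoder token ids, labels and a loss
+            # mask; the enc/dec/enc-dec masks arrive as 0/1 tensors and are turned into
+            # boolean "masked-out" tensors by the `< 0.5` comparisons below.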
+ tokens_enc = data_b["text_enc"].long() + tokens_dec = data_b["text_dec"].long() + labels = data_b["labels"].long() + loss_mask = data_b["loss_mask"].float() + + enc_mask = data_b["enc_mask"] < 0.5 + dec_mask = data_b["dec_mask"] < 0.5 + enc_dec_mask = data_b["enc_dec_mask"] < 0.5 + + return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask + + def get_batch_transformer(data_iterator): + """Build the batch.""" + data = next(data_iterator) + data = send_to_device(data, torch.cuda.current_device()) + + tokens_enc = data["input_ids"].long() + labels = data["labels"].long() + loss_mask = (labels != -100).to(torch.float) + if "decoder_input_ids" in data: + tokens_dec = data["decoder_input_ids"].long() + else: + tokens_dec = labels.new_zeros(labels.shape, device=labels.device, dtype=torch.long) + tokens_dec[..., 1:] = labels[..., :-1].clone() + tokens_dec[..., 0] = 0 + tokens_dec.masked_fill_(tokens_dec == -100, 0) + enc_mask = T5TrainStep.attn_mask_postprocess(data["attention_mask"].long()) + dec_mask = T5TrainStep.get_decoder_mask(tokens_dec.shape[1], tokens_dec.device) + enc_dec_mask = T5TrainStep.get_enc_dec_mask( + data["attention_mask"].long(), tokens_dec.shape[1], tokens_dec.device + ) + + return tokens_enc, tokens_dec, loss_mask, labels, enc_mask, dec_mask, enc_dec_mask + + if megatron_dataset_flag: + return get_batch_megatron + else: + return get_batch_transformer + + def get_loss_func(self): + def loss_func(loss_mask, output_tensor): + lm_loss_ = output_tensor.float() + lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum() + + loss = lm_loss + averaged_losses = average_losses_across_data_parallel_group([lm_loss]) + + return loss, {"lm loss": averaged_losses[0]} + + return loss_func + + def get_forward_step_func(self): + def forward_step(data_iterator, model): + """Forward step.""" + # Get the batch. + tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask = self.get_batch( + data_iterator + ) + # Forward model lm_labels + output_tensor = model( + tokens_enc, tokens_dec, enc_mask, dec_mask, enc_dec_mask, tokentype_ids=None, lm_labels=lm_labels + ) + + return output_tensor, partial(self.loss_func, loss_mask) + + return forward_step + + +# intialize megatron setup +def initialize(accelerator, extra_args_provider=None, args_defaults={}): + accelerator.print("Initializing Megatron-LM") + assert torch.cuda.is_available(), "Megatron requires CUDA." + + # Parse arguments + args = parse_args(extra_args_provider, ignore_unknown_args=True) + + # Set defaults + for key, value in args_defaults.items(): + if getattr(args, key, None) is not None: + if args.rank == 0: + print( + f"WARNING: overriding default arguments for " f"{key}:{getattr(args, key)} with {key}:{value}", + flush=True, + ) + setattr(args, key, value) + + if args.use_checkpoint_args or args_defaults.get("use_checkpoint_args", False): + assert args.load is not None, "--use-checkpoints-args requires --load argument" + load_args_from_checkpoint(args) + + validate_args(args) + + # set global args, build tokenizer, and set adlr-autoresume, + # tensorboard-writer, and timers. + set_global_variables(args) + + # torch.distributed initialization + def finish_mpu_init(): + args = get_args() + # Pytorch distributed. 
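+            # Note: torch.distributed itself is assumed to be initialized already (e.g. by the
+            # Accelerate launcher); this helper only reads the rank/world size from the existing
+            # process group and derives the local device index as rank % device_count below.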
+ device_count = torch.cuda.device_count() + args.rank = torch.distributed.get_rank() + args.world_size = torch.distributed.get_world_size() + if device_count > 0: + device = args.rank % device_count + if args.local_rank is not None: + assert args.local_rank == device, "expected local-rank to be the same as rank % device-count." + else: + args.local_rank = device + + # Set the tensor model-parallel, pipeline model-parallel, and + # data-parallel communicators. + if mpu.model_parallel_is_initialized(): + print("model parallel is already initialized") + else: + mpu.initialize_model_parallel( + args.tensor_model_parallel_size, + args.pipeline_model_parallel_size, + args.virtual_pipeline_model_parallel_size, + args.pipeline_model_parallel_split_rank, + ) + + # Random seeds for reproducibility. + if args.rank == 0: + print(f"> setting random seeds to {args.seed} ...") + _set_random_seed(args.seed, args.data_parallel_random_init) + + args = get_args() + + # Megatron's MPU is the master. Complete initialization right away. + finish_mpu_init() + + # Autoresume. + _init_autoresume() + + # Compile dependencies. + _compile_dependencies() + + # Set pytorch JIT layer fusion options and warmup JIT functions. + set_jit_fusion_options() + args = get_args() + args.padded_vocab_size = _vocab_size_with_padding(args.orig_vocab_size, args) + if args.model_type_name == "bert" and args.pretraining_flag and args.num_labels == 2: + args.bert_binary_head = True + else: + args.bert_binary_head = False + args.iteration = 0 + + +class MegatronEngine(torch.nn.Module): + """ + Megatron-LM model wrapper + + Args: + accelerator (:class:`~accelerate.Accelerator`): The accelerator object to use. + model: Megatron-LM model + optimizer: Megatron-LM optimizer + lr_scheduler: Megatron-LM lr scheduler + """ + + def __init__(self, accelerator, model, optimizer, scheduler): + super().__init__() + self.module = model + self.base_model = model[0] + self.optimizer = optimizer + self.scheduler = scheduler + args = get_args() + if accelerator.state.megatron_lm_plugin.custom_train_step_class is not None: + self.train_step_handler = accelerator.state.megatron_lm_plugin.custom_train_step_class( + args, **accelerator.state.megatron_lm_plugin.custom_train_step_kwargs + ) + elif args.model_type_name == "bert": + self.train_step_handler = BertTrainStep(args) + elif args.model_type_name == "gpt": + self.train_step_handler = GPTTrainStep(args) + elif args.model_type_name == "t5": + self.train_step_handler = T5TrainStep(args) + else: + raise ValueError(f"Unsupported model type: {args.model_type_name}") + self.optimizer.skipped_iter = False + + # Tracking loss. + self.total_loss_dict = {} + self.eval_total_loss_dict = {} + self.iteration = 0 + self.report_memory_flag = True + if args.tensorboard_dir is not None: + write_args_to_tensorboard() + + def train(self): + for model_module in self.module: + model_module.train() + self.log_eval_results() + + def eval(self): + for model_module in self.module: + model_module.eval() + + def train_step(self, **batch_data): + """ + Training step for Megatron-LM + + Args: + batch_data (:obj:`dict`): The batch data to train on. 
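+
+        Example (an illustrative sketch only; `engine` stands for the `MegatronEngine` instance and
+        the batch keys are an assumption that depends on the dataloader/collator in use):
+
+        ```python
+        # loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = engine.train_step(input_ids=input_ids)
+        ```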
+ """ + + args = get_args() + timers = get_timers() + + if len(batch_data) > 0: + data_chunks = [] + if args.num_micro_batches > 1: + for i in range(0, args.num_micro_batches): + data_chunks.append( + { + k: v[i * args.micro_batch_size : (i + 1) * args.micro_batch_size] + for k, v in batch_data.items() + } + ) + else: + data_chunks = [batch_data] + + if len(self.module) > 1: + batch_data_iterator = ( + [iter(data_chunks) for _ in range(len(self.module))] + if len(batch_data) > 0 + else [None] * len(self.module) + ) + else: + batch_data_iterator = iter(data_chunks) if len(batch_data) > 0 else None + + # Set grad to zero. + if args.DDP_impl == "local" and args.use_contiguous_buffers_in_local_ddp: + for partition in self.module: + partition.zero_grad_buffer() + self.optimizer.zero_grad() + + # Forward pass. + forward_backward_func = get_forward_backward_func() + losses_reduced = forward_backward_func( + self.train_step_handler.forward_step, + batch_data_iterator, + self.module, + self.optimizer, + None, + forward_only=False, + ) + + # Empty unused memory. + if args.empty_unused_memory_level >= 1: + torch.cuda.empty_cache() + + # Reduce gradients. + timers("backward-reduce-model-grads").start() + self.optimizer.reduce_model_grads(args, timers) + timers("backward-reduce-model-grads").stop() + + # Update parameters. + timers("optimizer").start() + update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step(args, timers) + timers("optimizer").stop() + + # Gather params. + if update_successful: + timers("backward-gather-model-params").start() + self.optimizer.gather_model_params(args, timers) + timers("backward-gather-model-params").stop() + + # Update learning rate. + if update_successful: + if self.scheduler is not None: + increment = get_num_microbatches() * args.micro_batch_size * args.data_parallel_size + self.scheduler.step(increment=increment) + skipped_iter = 0 + else: + skipped_iter = 1 + + self.optimizer.skipped_iter = not update_successful + + # Empty unused memory. + if args.empty_unused_memory_level >= 2: + torch.cuda.empty_cache() + + args.consumed_train_samples += ( + mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches() + ) + + if mpu.is_pipeline_last_stage(ignore_virtual=True): + # Average loss across microbatches. + loss_reduced = {} + for key in losses_reduced[0]: + losses_reduced_for_key = [x[key] for x in losses_reduced] + if len(losses_reduced_for_key[0].shape) == 0: + loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key) + else: + loss_reduced[key] = torch.concat(losses_reduced_for_key) + return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad + return {}, skipped_iter, grad_norm, num_zeros_in_grad + + def eval_step(self, **batch_data): + """ + Evaluation step for Megatron-LM + + Args: + batch_data (:obj:`dict`): The batch data to evaluate on. 
+        """
+
+        args = get_args()
+        data_chunks = []
+        if args.num_micro_batches > 1:
+            for i in range(0, args.num_micro_batches):
+                data_chunks.append(
+                    {k: v[i * args.micro_batch_size : (i + 1) * args.micro_batch_size] for k, v in batch_data.items()}
+                )
+        else:
+            data_chunks = [batch_data]
+
+        if len(self.module) > 1:
+            batch_data_iterator = [iter(data_chunks) for _ in range(len(self.module))]
+        else:
+            batch_data_iterator = iter(data_chunks)
+        forward_backward_func = get_forward_backward_func()
+        loss_dicts = forward_backward_func(
+            self.train_step_handler.forward_step,
+            batch_data_iterator,
+            self.module,
+            optimizer=None,
+            timers=None,
+            forward_only=True,
+        )
+        # Empty unused memory.
+        if args.empty_unused_memory_level >= 1:
+            torch.cuda.empty_cache()
+
+        args.consumed_valid_samples += (
+            mpu.get_data_parallel_world_size() * args.micro_batch_size * get_num_microbatches()
+        )
+
+        if mpu.is_pipeline_last_stage(ignore_virtual=True):
+            # Average loss across microbatches.
+            loss_reduced = {}
+            for key in loss_dicts[0]:
+                losses_reduced_for_key = [x[key] for x in loss_dicts]
+                if len(losses_reduced_for_key[0].shape) == 0:
+                    loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
+                else:
+                    loss_reduced[key] = torch.concat(losses_reduced_for_key)
+            return loss_reduced
+        else:
+            return {}
+
+    def forward(self, **batch_data):
+        # During training, we use train_step().
+        # model(**batch_data) performs the following operations by delegating them to `self.train_step`:
+        # 1. Prepare **batch_data for Tensor, Pipeline and Model Parallelism
+        # 2. Set grad to zero.
+        # 3. Forward pass and backward pass using Pipeline Parallelism
+        # 4. Empty unused memory.
+        # 5. Reduce gradients.
+        # 6. Update parameters.
+        # 7. Gather params when using Distributed Optimizer (Data Parallelism).
+        # 8. Update learning rate if scheduler is specified.
+        # 9. Empty unused memory.
+        # 10. Average loss across microbatches and across DP ranks.
+        #
+        # During evaluation, we use eval_step().
+        args = get_args()
+        if self.module[0].training:
+            loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = self.train_step(**batch_data)
+            self.iteration += 1
+            if args.tensorboard_dir is not None:
+                # Logging.
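+                # The values below (loss scale, optional parameter L2 norm) are gathered purely for
+                # reporting: they are handed to Megatron-LM's `training_log`, which writes the
+                # TensorBoard scalars and the periodic console/memory report.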
+ loss_scale = self.optimizer.get_loss_scale().item() + params_norm = None + if args.log_params_norm: + params_norm = calc_params_l2_norm(self.model) + self.report_memory_flag = training_log( + loss_dict, + self.total_loss_dict, + self.optimizer.param_groups[0]["lr"], + self.iteration, + loss_scale, + self.report_memory_flag, + skipped_iter, + grad_norm, + params_norm, + num_zeros_in_grad, + ) + else: + loss_dict = self.eval_step(**batch_data) + if args.tensorboard_dir is not None: + for key in loss_dict: + self.eval_total_loss_dict[key] = ( + self.eval_total_loss_dict.get(key, torch.cuda.FloatTensor([0.0])) + loss_dict[key] + ) + self.eval_total_loss_dict[key + "_num_iters"] = self.eval_total_loss_dict.get( + key + "_num_iters", torch.cuda.FloatTensor([0.0]) + ) + torch.cuda.FloatTensor([1.0]) + + loss = torch.tensor(0.0, device=args.local_rank) + for key in loss_dict: + if len(loss_dict[key].shape) == 0: + loss += loss_dict[key] + + logits = None + if "logits" in loss_dict: + logits = loss_dict["logits"] + # loss = reduce(loss) + if self.train_step_handler.model_output_class is not None: + return self.train_step_handler.model_output_class(loss=loss, logits=logits) + return loss + + def log_eval_results(self): + args = get_args() + if args.tensorboard_dir is None or self.iteration == 0: + return + args = get_args() + writer = get_tensorboard_writer() + string = f"validation loss at iteration {self.iteration} | " + for key in self.eval_total_loss_dict: + if key.endswith("_num_iters"): + continue + value = self.eval_total_loss_dict[key] / self.eval_total_loss_dict[key + "_num_iters"] + string += f"{key} value: {value} | " + ppl = math.exp(min(20, value.item())) + if args.pretraining_flag: + string += f"{key} PPL: {ppl} | " + if writer: + writer.add_scalar(f"{key} validation", value.item(), self.iteration) + if args.pretraining_flag: + writer.add_scalar(f"{key} validation ppl", ppl, self.iteration) + + length = len(string) + 1 + print_rank_last("-" * length) + print_rank_last(string) + print_rank_last("-" * length) + self.eval_total_loss_dict = {} + + def save_checkpoint(self, output_dir): + self.log_eval_results() + args = get_args() + args.save = output_dir + torch.distributed.barrier() + save_checkpoint(self.iteration, self.module, self.optimizer, self.scheduler) + torch.distributed.barrier() + + def load_checkpoint(self, input_dir): + args = get_args() + args.load = input_dir + args.consumed_train_samples = 0 + args.consumed_valid_samples = 0 + torch.distributed.barrier() + iteration = load_checkpoint(self.module, self.optimizer, self.scheduler) + torch.distributed.barrier() + self.iteration = iteration + if args.fp16 and self.iteration == 0: + self.optimizer.reload_model_params() + + def megatron_generate( + self, + inputs, + attention_mask=None, + max_length=None, + max_new_tokens=None, + num_beams=None, + temperature=None, + top_k=None, + top_p=None, + length_penalty=None, + **kwargs, + ): + """ + Generate method for GPT2 model. This method is used for inference. Supports both greedy and beam search along + with sampling. Refer the Megatron-LM repo for more details + + Args: + inputs (torch.Tensor): input ids + attention_mask (torch.Tensor, optional): attention mask. Defaults to None. + max_length (int, optional): max length of the generated sequence. Defaults to None. + Either this or max_new_tokens should be provided. + max_new_tokens (int, optional): max number of tokens to be generated. Defaults to None. + Either this or max_length should be provided. 
+ num_beams (int, optional): number of beams to use for beam search. Defaults to None. + temperature (float, optional): temperature for sampling. Defaults to 1.0. + top_k (int, optional): top k tokens to consider for sampling. Defaults to 0.0. + top_p (float, optional): tokens in top p probability are considered for sampling. Defaults to 0.0. + length_penalty (float, optional): length penalty for beam search. Defaults to None. + kwargs: additional key-value arguments + """ + + # checking if required arguments are passed + args = get_args() + if args.model_type_name != "gpt": + raise NotImplementedError("Generate method is not implemented for this model") + + if args.data_parallel_size > 1: + raise ValueError("Generate method requires data parallelism to be 1") + + if args.sequence_parallel: + raise ValueError("Generate method requires sequence parallelism to be False") + + if args.recompute_granularity is not None: + raise ValueError("Checkpoint activations cannot be set for inference") + + if args.vocab_file is None: + raise ValueError("Vocab file is required for inference") + + # Prepare inputs + if max_length is None and max_new_tokens is None: + raise ValueError("`max_length` or `max_new_tokens` are required for inference") + + if temperature is None: + temperature = 1.0 + elif not (0.0 < temperature <= 100.0): + raise ValueError("temperature must be a positive number less than or equal to 100.0") + + if top_k is None: + top_k = 0 + elif not (0 <= top_k <= 1000): + raise ValueError("top_k must be a positive number less than or equal to 1000") + + if top_p is None: + top_p = 0.0 + elif top_p > 0.0 and top_k > 0.0: + raise ValueError("top_p and top_k sampling cannot be set together") + else: + if not (0.0 <= top_p <= 1.0): + raise ValueError("top_p must be less than or equal to 1.0") + + top_p_decay = kwargs.get("top_p_decay", 0.0) + if not (0.0 <= top_p_decay <= 1.0): + raise ValueError("top_p_decay must be less than or equal to 1.0") + + top_p_bound = kwargs.get("top_p_bound", 0.0) + if not (0.0 <= top_p_bound <= 1.0): + raise ValueError("top_p_bound must be less than or equal to 1.0") + + add_BOS = kwargs.get("add_BOS", False) + if not (isinstance(add_BOS, bool)): + raise ValueError("add_BOS must be a boolean") + + beam_width = num_beams + if beam_width is not None: + if not isinstance(beam_width, int): + raise ValueError("beam_width must be an integer") + if beam_width < 1: + raise ValueError("beam_width must be greater than 0") + if inputs.shape[0] > 1: + return "When doing beam_search, batch size must be 1" + + tokenizer = get_tokenizer() + + stop_token = kwargs.get("stop_token", tokenizer.eod) + if stop_token is not None: + if not isinstance(stop_token, int): + raise ValueError("stop_token must be an integer") + + if length_penalty is None: + length_penalty = 1.0 + + sizes_list = None + prompts_tokens_tensor = None + prompts_length_tensor = None + if torch.distributed.get_rank() == 0: + # Get the prompts length. 
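+            # Only rank 0 builds the padded prompt tensor; its sizes and contents are broadcast to
+            # the other ranks further below. Prompts are padded with the EOD token and the total
+            # length is rounded up to a multiple of 4 so that fused kernels can be used.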
+            if attention_mask is None:
+                prompts_length_tensor = torch.cuda.LongTensor([inputs.shape[1]] * inputs.shape[0])
+            else:
+                prompts_length_tensor = attention_mask.sum(axis=-1).cuda()
+
+            if max_new_tokens is None:
+                max_new_tokens = max_length - inputs.shape[1]
+            if max_new_tokens <= 0:
+                raise ValueError("max_new_tokens must be greater than 0")
+
+            if add_BOS:
+                max_length = max_new_tokens + inputs.shape[1] + 1
+                # making sure that `max_length` is a multiple of 4 to leverage fused kernels
+                max_length = 4 * math.ceil(max_length / 4)
+                max_new_tokens = max_length - (inputs.shape[1] + 1)
+                padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
+                prompts_tokens_tensor = torch.concat(
+                    [torch.unsqueeze(padding[:, 0], axis=-1), inputs.cuda(), padding], axis=-1
+                )
+            else:
+                # making sure that `max_length` is a multiple of 4 to leverage fused kernels
+                max_length = max_new_tokens + inputs.shape[1]
+                max_length = 4 * math.ceil(max_length / 4)
+                max_new_tokens = max_length - inputs.shape[1]
+                padding = torch.cuda.LongTensor([[tokenizer.eod] * max_new_tokens] * inputs.shape[0])
+                prompts_tokens_tensor = torch.concat([inputs.cuda(), padding], axis=-1)
+
+            # We need the sizes of these tensors for the broadcast
+            sizes_list = [
+                prompts_tokens_tensor.size(0),  # Batch size
+                prompts_tokens_tensor.size(1),
+            ]  # Sequence length
+
+        # First, broadcast the sizes.
+        sizes_tensor = broadcast_int_list(2, int_list=sizes_list, rank=0)
+
+        # Now that we have the sizes, we can broadcast the tokens
+        # and length tensors.
+        sizes = sizes_tensor.tolist()
+        context_tokens_tensor = broadcast_tensor(sizes, torch.int64, tensor=prompts_tokens_tensor, rank=0)
+        context_length_tensor = broadcast_tensor(sizes[0], torch.int64, tensor=prompts_length_tensor, rank=0)
+
+        # Run the inference
+        random_seed = kwargs.get("random_seed", 0)
+        torch.random.manual_seed(random_seed)
+        unwrapped_model = unwrap_model(self.base_model, (torchDDP, LocalDDP, Float16Module))
+        if beam_width is not None:
+            tokens, _ = beam_search_and_return_on_first_stage(
+                unwrapped_model,
+                context_tokens_tensor,
+                context_length_tensor,
+                beam_width,
+                stop_token=stop_token,
+                num_return_gen=1,
+                length_penalty=length_penalty,
+            )
+        else:
+            tokens, _, _ = generate_tokens_probs_and_return_on_first_stage(
+                unwrapped_model,
+                context_tokens_tensor,
+                context_length_tensor,
+                return_output_log_probs=False,
+                top_k=top_k,
+                top_p=top_p,
+                top_p_decay=top_p_decay,
+                top_p_bound=top_p_bound,
+                temperature=temperature,
+                use_eod_token_for_early_termination=True,
+            )
+        return tokens
+
+
+# other utilities
+def avg_losses_across_data_parallel_group(losses):
+    """
+    Average losses across the data parallel group.
+
+    Args:
+        losses (List[Tensor]): List of losses to average across the data parallel group.
+    """
+
+    return average_losses_across_data_parallel_group(losses)
+
+
+def gather_across_data_parallel_groups(tensor):
+    """
+    Recursively gather the tensors in a nested list/tuple/dictionary of tensors from data parallel ranks.
+
+    Args:
+        tensor (nested list/tuple/dictionary of `torch.Tensor`):
+            The data to gather across data parallel ranks.
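+
+    Example (an illustrative sketch; it assumes the Megatron-LM data parallel group has been
+    initialized, e.g. inside a training loop prepared with the Megatron-LM plugin):
+
+    ```python
+    # A scalar loss becomes a 1D tensor with one entry per data parallel rank:
+    # all_losses = gather_across_data_parallel_groups(loss.detach())
+    ```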
+ + """ + + def _gpu_gather_one(tensor): + if tensor.ndim == 0: + tensor = tensor.clone()[None] + output_tensors = [ + torch.empty_like(tensor) + for _ in range(torch.distributed.get_world_size(group=mpu.get_data_parallel_group())) + ] + torch.distributed.all_gather(output_tensors, tensor, group=mpu.get_data_parallel_group()) + return torch.cat(output_tensors, dim=0) + + return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) diff --git a/llm/Lib/site-packages/accelerate/utils/memory.py b/llm/Lib/site-packages/accelerate/utils/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..0141bf5f60430fa521de6cf196ac511a50790bb3 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/memory.py @@ -0,0 +1,158 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A collection of utilities for ensuring that training can always occur. Heavily influenced by the +[toma](https://github.com/BlackHC/toma) library. +""" + +import functools +import gc +import inspect + +import torch + +from .imports import is_mlu_available, is_mps_available, is_npu_available, is_xpu_available + + +def release_memory(*objects): + """ + Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`. + Returned objects should be reassigned to the same variables. + + Args: + objects (`Iterable`): + An iterable of objects + Returns: + A list of `None` objects to replace `objects` + + Example: + + ```python + >>> import torch + >>> from accelerate.utils import release_memory + + >>> a = torch.ones(1000, 1000).cuda() + >>> b = torch.ones(1000, 1000).cuda() + >>> a, b = release_memory(a, b) + ``` + """ + if not isinstance(objects, list): + objects = list(objects) + for i in range(len(objects)): + objects[i] = None + gc.collect() + if is_xpu_available(): + torch.xpu.empty_cache() + elif is_mlu_available(): + torch.mlu.empty_cache() + elif is_npu_available(): + torch.npu.empty_cache() + elif is_mps_available(): + torch.mps.empty_cache() + else: + torch.cuda.empty_cache() + return objects + + +def should_reduce_batch_size(exception: Exception) -> bool: + """ + Checks if `exception` relates to CUDA out-of-memory, CUDNN not supported, or CPU out-of-memory + + Args: + exception (`Exception`): + An exception + """ + _statements = [ + "CUDA out of memory.", # CUDA OOM + "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.", # CUDNN SNAFU + "DefaultCPUAllocator: can't allocate memory", # CPU OOM + ] + if isinstance(exception, RuntimeError) and len(exception.args) == 1: + return any(err in exception.args[0] for err in _statements) + return False + + +def find_executable_batch_size(function: callable = None, starting_batch_size: int = 128): + """ + A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or + CUDNN, the batch size is cut in half and passed to `function` + + `function` must take in a `batch_size` parameter as its first argument. 
+ + Args: + function (`callable`, *optional*): + A function to wrap + starting_batch_size (`int`, *optional*): + The batch size to try and fit into memory + + Example: + + ```python + >>> from accelerate.utils import find_executable_batch_size + + + >>> @find_executable_batch_size(starting_batch_size=128) + ... def train(batch_size, model, optimizer): + ... ... + + + >>> train(model, optimizer) + ``` + """ + if function is None: + return functools.partial(find_executable_batch_size, starting_batch_size=starting_batch_size) + + batch_size = starting_batch_size + + def decorator(*args, **kwargs): + nonlocal batch_size + gc.collect() + if is_xpu_available(): + torch.xpu.empty_cache() + elif is_mlu_available(): + torch.mlu.empty_cache() + elif is_npu_available(): + torch.npu.empty_cache() + else: + torch.cuda.empty_cache() + params = list(inspect.signature(function).parameters.keys()) + # Guard against user error + if len(params) < (len(args) + 1): + arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])]) + raise TypeError( + f"Batch size was passed into `{function.__name__}` as the first argument when called." + f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`" + ) + while True: + if batch_size == 0: + raise RuntimeError("No executable batch size found, reached zero.") + try: + return function(batch_size, *args, **kwargs) + except Exception as e: + if should_reduce_batch_size(e): + gc.collect() + if is_xpu_available(): + torch.xpu.empty_cache() + elif is_mlu_available(): + torch.mlu.empty_cache() + elif is_npu_available(): + torch.npu.empty_cache() + else: + torch.cuda.empty_cache() + batch_size //= 2 + else: + raise + + return decorator diff --git a/llm/Lib/site-packages/accelerate/utils/modeling.py b/llm/Lib/site-packages/accelerate/utils/modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..d1d7f2bdf9984fd4fcfc098eddf1efe265d05464 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/modeling.py @@ -0,0 +1,1800 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import contextlib +import gc +import importlib +import inspect +import json +import logging +import os +import re +import shutil +import tempfile +import warnings +from collections import OrderedDict, defaultdict +from typing import Dict, List, Optional, Tuple, Union + +import packaging +import torch +import torch.nn as nn + +from ..state import AcceleratorState +from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME +from .dataclasses import AutocastKwargs, CustomDtype, DistributedType +from .imports import ( + is_mlu_available, + is_mps_available, + is_npu_available, + is_peft_available, + is_torch_xla_available, + is_xpu_available, +) +from .offload import load_offloaded_weight, offload_weight, save_offload_index +from .tqdm import is_tqdm_available, tqdm +from .versions import compare_versions + + +if is_npu_available(check_device=False): + import torch_npu # noqa: F401 + +if is_mlu_available(check_device=False): + import torch_mlu # noqa: F401 + +from safetensors import safe_open +from safetensors.torch import load_file as safe_load_file + + +WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json" + +logger = logging.getLogger(__name__) + + +def is_peft_model(model): + from .other import extract_model_from_parallel + + if is_peft_available(): + from peft import PeftModel + + return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel) + + +def check_device_same(first_device, second_device): + """ + Utility method to check if two `torch` devices are similar. When dealing with CUDA devices, torch throws `False` + for `torch.device("cuda") == torch.device("cuda:0")` whereas they should be the same + + Args: + first_device (`torch.device`): + First device to check + second_device (`torch.device`): + Second device to check + """ + if first_device.type != second_device.type: + return False + + if first_device.type == "cuda" and first_device.index is None: + # In case the first_device is a cuda device and have + # the index attribute set to `None`, default it to `0` + first_device = torch.device("cuda", index=0) + + if second_device.type == "cuda" and second_device.index is None: + # In case the second_device is a cuda device and have + # the index attribute set to `None`, default it to `0` + second_device = torch.device("cuda", index=0) + + return first_device == second_device + + +def convert_file_size_to_int(size: Union[int, str]): + """ + Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes). + + Args: + size (`int` or `str`): The size to convert. Will be directly returned if an `int`. + + Example: + + ```py + >>> convert_file_size_to_int("1MiB") + 1048576 + ``` + """ + mem_size = -1 + err_msg = ( + f"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB')." 
+ ) + try: + if isinstance(size, int): + mem_size = size + elif size.upper().endswith("GIB"): + mem_size = int(float(size[:-3]) * (2**30)) + elif size.upper().endswith("MIB"): + mem_size = int(float(size[:-3]) * (2**20)) + elif size.upper().endswith("KIB"): + mem_size = int(float(size[:-3]) * (2**10)) + elif size.upper().endswith("GB"): + int_size = int(float(size[:-2]) * (10**9)) + mem_size = int_size // 8 if size.endswith("b") else int_size + elif size.upper().endswith("MB"): + int_size = int(float(size[:-2]) * (10**6)) + mem_size = int_size // 8 if size.endswith("b") else int_size + elif size.upper().endswith("KB"): + int_size = int(float(size[:-2]) * (10**3)) + mem_size = int_size // 8 if size.endswith("b") else int_size + except ValueError: + raise ValueError(err_msg) + + if mem_size < 0: + raise ValueError(err_msg) + return mem_size + + +def dtype_byte_size(dtype: torch.dtype): + """ + Returns the size (in bytes) occupied by one parameter of type `dtype`. + + Example: + + ```py + >>> dtype_byte_size(torch.float32) + 4 + ``` + """ + if dtype == torch.bool: + return 1 / 8 + elif dtype == CustomDtype.INT2: + return 1 / 4 + elif dtype == CustomDtype.INT4: + return 1 / 2 + elif dtype == CustomDtype.FP8: + return 1 + bit_search = re.search(r"[^\d](\d+)$", str(dtype)) + if bit_search is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + bit_size = int(bit_search.groups()[0]) + return bit_size // 8 + + +def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]: + """ + Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For + example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is + guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with + non-overlapping lifetimes may have the same id. + """ + _SIZE = { + torch.int64: 8, + torch.float32: 4, + torch.int32: 4, + torch.bfloat16: 2, + torch.float16: 2, + torch.int16: 2, + torch.uint8: 1, + torch.int8: 1, + torch.bool: 1, + torch.float64: 8, + } + try: + storage_ptr = tensor.untyped_storage().data_ptr() + storage_size = tensor.untyped_storage().nbytes() + except Exception: + # Fallback for torch==1.10 + try: + storage_ptr = tensor.storage().data_ptr() + storage_size = tensor.storage().size() * _SIZE[tensor.dtype] + except NotImplementedError: + # Fallback for meta storage + storage_ptr = 0 + # On torch >=2.0 this is the tensor size + storage_size = tensor.nelement() * _SIZE[tensor.dtype] + + return tensor.device, storage_ptr, storage_size + + +def shard_checkpoint( + state_dict: Dict[str, torch.Tensor], max_shard_size: Union[int, str] = "10GB", weights_name: str = WEIGHTS_NAME +): + """ + Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a + given size. + + The sub-checkpoints are determined by iterating through the `state_dict` in the order of its keys, so there is no + optimization made to make each sub-checkpoint as close as possible to the maximum size passed. For example, if the + limit is 10GB and we have weights of sizes [6GB, 6GB, 2GB, 6GB, 2GB, 2GB] they will get sharded as [6GB], [6+2GB], + [6+2+2GB] and not [6+2+2GB], [6+2GB], [6GB]. + + + + If one of the model's weight is bigger that `max_sahrd_size`, it will end up in its own sub-checkpoint which will + have a size greater than `max_shard_size`. 
+ + + + Args: + state_dict (`Dict[str, torch.Tensor]`): The state dictionary of a model to save. + max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`): + The maximum size of each sub-checkpoint. If expressed as a string, needs to be digits followed by a unit + (like `"5MB"`). + weights_name (`str`, *optional*, defaults to `"pytorch_model.bin"`): + The name of the model save file. + """ + max_shard_size = convert_file_size_to_int(max_shard_size) + + sharded_state_dicts = [{}] + last_block_size = 0 + total_size = 0 + storage_id_to_block = {} + + for key, weight in state_dict.items(): + # when bnb serialization is used the weights in the state dict can be strings + # check: https://github.com/huggingface/transformers/pull/24416 for more details + if isinstance(weight, str): + continue + else: + storage_id = id_tensor_storage(weight) + + # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block` + if storage_id in storage_id_to_block: + block_id = storage_id_to_block[storage_id] + sharded_state_dicts[block_id][key] = weight + continue + + weight_size = weight.numel() * dtype_byte_size(weight.dtype) + + # If this weight is going to tip up over the maximal size, we split. + if last_block_size + weight_size > max_shard_size: + sharded_state_dicts.append({}) + last_block_size = 0 + + sharded_state_dicts[-1][key] = weight + last_block_size += weight_size + total_size += weight_size + storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1 + + # If we only have one shard, we return it + if len(sharded_state_dicts) == 1: + return {weights_name: sharded_state_dicts[0]}, None + + # Otherwise, let's build the index + weight_map = {} + shards = {} + for idx, shard in enumerate(sharded_state_dicts): + shard_file = weights_name.replace(".bin", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.bin") + shard_file = shard_file.replace( + ".safetensors", f"-{idx + 1:05d}-of-{len(sharded_state_dicts):05d}.safetensors" + ) + shards[shard_file] = shard + for key in shard.keys(): + weight_map[key] = shard_file + + # Add the metadata + metadata = {"total_size": total_size} + index = {"metadata": metadata, "weight_map": weight_map} + return shards, index + + +def set_module_tensor_to_device( + module: nn.Module, + tensor_name: str, + device: Union[int, str, torch.device], + value: Optional[torch.Tensor] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + fp16_statistics: Optional[torch.HalfTensor] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, +): + """ + A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). + + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + tensor_name (`str`): + The full name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`, *optional*): + The value of the tensor (useful when going from the meta device to any other device). + dtype (`torch.dtype`, *optional*): + If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to + the dtype of the existing parameter in the model. + fp16_statistics (`torch.HalfTensor`, *optional*): + The list of fp16 statistics to set on the module, used for 8 bit model serialization. 
+        tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`):
+            A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
+            execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
+            device for all others, instead of duplicating memory.
+    """
+    # Recurse if needed
+    if "." in tensor_name:
+        splits = tensor_name.split(".")
+        for split in splits[:-1]:
+            new_module = getattr(module, split)
+            if new_module is None:
+                raise ValueError(f"{module} has no attribute {split}.")
+            module = new_module
+        tensor_name = splits[-1]
+
+    if tensor_name not in module._parameters and tensor_name not in module._buffers:
+        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+    is_buffer = tensor_name in module._buffers
+    old_value = getattr(module, tensor_name)
+
+    # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group
+    # and one of the weights in the tied group has already been dispatched to the device: avoid reallocating memory
+    # on the device and just copy the pointer.
+    if (
+        value is not None
+        and tied_params_map is not None
+        and value.data_ptr() in tied_params_map
+        and device in tied_params_map[value.data_ptr()]
+    ):
+        module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]
+        return
+    elif (
+        tied_params_map is not None
+        and old_value.data_ptr() in tied_params_map
+        and device in tied_params_map[old_value.data_ptr()]
+    ):
+        module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device]
+        return
+
+    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
+        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put on {device}.")
+
+    if value is not None:
+        if old_value.shape != value.shape:
+            raise ValueError(
+                f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.'
+            )
+
+        if dtype is None:
+            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
+            value = value.to(old_value.dtype)
+        elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
+            value = value.to(dtype)
+
+    param = module._parameters[tensor_name] if tensor_name in module._parameters else None
+    param_cls = type(param)
+
+    device_quantization = None
+    with torch.no_grad():
+        # Leave it on cpu first before moving it to cuda. This also covers the case where the current
+        # device is meta: we don't want to put it on cpu because there is no data.
+        if (
+            param is not None
+            and param.device.type != "cuda"
+            and torch.device(device).type == "cuda"
+            and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
+        ):
+            device_quantization = device
+            device = "cpu"
+        # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
+ if is_npu_available() and isinstance(device, int): + device = f"npu:{device}" + elif is_mlu_available() and isinstance(device, int): + device = f"mlu:{device}" + if is_xpu_available() and isinstance(device, int): + device = f"xpu:{device}" + if value is None: + new_value = old_value.to(device) + if dtype is not None and device in ["meta", torch.device("meta")]: + if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + new_value = new_value.to(dtype) + + if not is_buffer: + module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + if device_quantization is not None: + device = device_quantization + if is_buffer: + module._buffers[tensor_name] = new_value + elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device): + param_cls = type(module._parameters[tensor_name]) + kwargs = module._parameters[tensor_name].__dict__ + if param_cls.__name__ in ["Int8Params", "FP4Params"]: + if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32: + # downcast to fp16 if any - needed for 8bit serialization + new_value = new_value.to(torch.float16) + # quantize module that are going to stay on the cpu so that we offload quantized weights + if device == "cpu" and param_cls.__name__ == "Int8Params": + new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu") + new_value.CB = new_value.CB.to("cpu") + new_value.SCB = new_value.SCB.to("cpu") + else: + new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device) + elif param_cls.__name__ in ["QTensor", "QBitsTensor"]: + new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device) + else: + new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + module._parameters[tensor_name].SCB = fp16_statistics.to(device) + del fp16_statistics + # as we put the weight to meta, it doesn't have SCB attr anymore. 
+        # make sure that it is not a meta weight
+        if (
+            module.__class__.__name__ == "Linear8bitLt"
+            and getattr(module.weight, "SCB", None) is None
+            and str(module.weight.device) != "meta"
+        ):
+            # quantize only if necessary
+            device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
+            if not getattr(module.weight, "SCB", None) and device_index is not None:
+                if module.bias is not None and module.bias.device.type != "meta":
+                    # if a bias exists, we need to wait until the bias is set on the correct device
+                    module = module.cuda(device_index)
+                elif module.bias is None:
+                    # if no bias exists, we can quantize right away
+                    module = module.cuda(device_index)
+        elif module.__class__.__name__ == "Linear4bit" and getattr(module.weight, "quant_state", None) is None:
+            # quantize only if necessary
+            device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
+            if not getattr(module.weight, "quant_state", None) and device_index is not None:
+                module.weight = module.weight.cuda(device_index)
+    # clean pre and post forward hooks
+    if is_npu_available():
+        torch.npu.empty_cache()
+    elif is_mlu_available():
+        torch.mlu.empty_cache()
+    elif is_xpu_available():
+        torch.xpu.empty_cache()
+    else:
+        torch.cuda.empty_cache()
+
+    # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been
+    # allocated on the device, in order to avoid duplicating memory (see above).
+    if (
+        tied_params_map is not None
+        and old_value.data_ptr() in tied_params_map
+        and device not in tied_params_map[old_value.data_ptr()]
+    ):
+        tied_params_map[old_value.data_ptr()][device] = new_value
+    elif (
+        value is not None
+        and tied_params_map is not None
+        and value.data_ptr() in tied_params_map
+        and device not in tied_params_map[value.data_ptr()]
+    ):
+        tied_params_map[value.data_ptr()][device] = new_value
+
+
+def named_module_tensors(
+    module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
+):
+    """
+    A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
+    it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.
+
+    Args:
+        module (`torch.nn.Module`):
+            The module we want the tensors on.
+        include_buffers (`bool`, *optional*, defaults to `True`):
+            Whether or not to include the buffers in the result.
+        recurse (`bool`, *optional*, defaults to `False`):
+            Whether or not to go look in every submodule or just return the direct parameters and buffers.
+        remove_non_persistent (`bool`, *optional*, defaults to `False`):
+            Whether or not to remove the non persistent buffers from the buffers. Useful only when
+            `include_buffers=True`.
+    """
+    yield from module.named_parameters(recurse=recurse)
+
+    if include_buffers:
+        non_persistent_buffers = set()
+        if remove_non_persistent:
+            non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)
+        for named_buffer in module.named_buffers(recurse=recurse):
+            name, _ = named_buffer
+            if name not in non_persistent_buffers:
+                yield named_buffer
+
+
+def get_non_persistent_buffers(module: nn.Module, recurse: bool = False):
+    """
+    Gather all non persistent buffers of a given module into a set.
+
+    Args:
+        module (`nn.Module`):
+            The module we want the non persistent buffers on.
+ recurse (`bool`, *optional*, defaults to `False`): + Whether or not to go look in every submodule or just return the direct non persistent buffers. + """ + + non_persistent_buffers_set = module._non_persistent_buffers_set + if recurse: + for _, m in module.named_modules(): + non_persistent_buffers_set |= m._non_persistent_buffers_set + + return non_persistent_buffers_set + + +class FindTiedParametersResult(list): + """ + This is a subclass of a list to handle backward compatibility for Transformers. Do not rely on the fact this is not + a list or on the `values` method as in the future this will be removed. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def values(self): + # TODO: at the next Transformers release (4.28.0) issue a deprecation warning here. + return sum([x[1:] for x in self], []) + + +def check_tied_parameters_in_config(model: nn.Module): + """ + Check if there is any indication in the given model that some weights should be tied. + + Args: + model (`torch.nn.Module`): The model to inspect + + Returns: + bool: True if the model needs to have tied weights + """ + + # based on model.tie_weights() method + has_tied_word_embedding = False + has_tied_encoder_decoder = False + has_tied_module = False + + if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]: + has_tied_word_embedding = ( + hasattr(model, "config") + and getattr(model.config, "tie_word_embeddings", False) + and model.get_output_embeddings() + ) + has_tied_encoder_decoder = ( + hasattr(model, "config") + and getattr(model.config, "is_encoder_decoder", False) + and getattr(model.config, "tie_encoder_decoder", False) + ) + has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules()) + + return any([has_tied_word_embedding, has_tied_encoder_decoder, has_tied_module]) + + +def _get_param_device(param, device_map): + if param in device_map: + return device_map[param] + parent_param = ".".join(param.split(".")[:-1]) + if parent_param == param: + raise ValueError(f"The `device_map` does not contain the module {param}.") + else: + return _get_param_device(parent_param, device_map) + + +def check_tied_parameters_on_same_device(tied_params, device_map): + """ + Check if tied parameters are on the same device + + Args: + tied_params (`List[List[str]]`): + A list of lists of parameter names being all tied together. + + device_map (`Dict[str, Union[int, str, torch.device]]`): + A map that specifies where each submodule should go. + + """ + for tie_param in tied_params: + tie_param_devices = {} + for param in tie_param: + tie_param_devices[param] = _get_param_device(param, device_map) + if len(set(tie_param_devices.values())) > 1: + logger.warn( + f"Tied parameters are on different devices: {tie_param_devices}. " + "Please modify your custom device map or set `device_map='auto'`. " + ) + + +def find_tied_parameters(model: nn.Module, **kwargs): + """ + Find the tied parameters in a given model. + + + + The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore + them. + + + + Args: + model (`torch.nn.Module`): The model to inspect. + + Returns: + List[List[str]]: A list of lists of parameter names being all tied together. 
+ + Example: + + ```py + >>> from collections import OrderedDict + >>> import torch.nn as nn + + >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))])) + >>> model.linear2.weight = model.linear1.weight + >>> find_tied_parameters(model) + [['linear1.weight', 'linear2.weight']] + ``` + """ + # Initialize result and named_parameters before recursing. + named_parameters = kwargs.get("named_parameters", None) + prefix = kwargs.get("prefix", "") + result = kwargs.get("result", {}) + + if named_parameters is None: + named_parameters = {n: p for n, p in model.named_parameters()} + else: + # A tied parameter will not be in the full `named_parameters` seen above but will be in the `named_parameters` + # of the submodule it belongs to. So while recursing we track the names that are not in the initial + # `named_parameters`. + for name, parameter in model.named_parameters(): + full_name = name if prefix == "" else f"{prefix}.{name}" + if full_name not in named_parameters: + # When we find one, it has to be one of the existing parameters. + for new_name, new_param in named_parameters.items(): + if new_param is parameter: + if new_name not in result: + result[new_name] = [] + result[new_name].append(full_name) + + # Once we have treated direct parameters, we move to the child modules. + for name, child in model.named_children(): + child_name = name if prefix == "" else f"{prefix}.{name}" + find_tied_parameters(child, named_parameters=named_parameters, prefix=child_name, result=result) + + return FindTiedParametersResult([sorted([weight] + list(set(tied))) for weight, tied in result.items()]) + + +def retie_parameters(model, tied_params): + """ + Reties tied parameters in a given model if the link was broken (for instance when adding hooks). + + Args: + model (`torch.nn.Module`): + The model in which to retie parameters. + tied_params (`List[List[str]]`): + A mapping parameter name to tied parameter name as obtained by `find_tied_parameters`. + """ + for tied_group in tied_params: + param_to_tie = None + # two loops : the first one to set param_to_tie , the second one to change the values of tied_group + for param_name in tied_group: + module = model + splits = param_name.split(".") + for split in splits[:-1]: + module = getattr(module, split) + param = getattr(module, splits[-1]) + if param_to_tie is None and param.device != torch.device("meta"): + param_to_tie = param + break + if param_to_tie is not None: + for param_name in tied_group: + module = model + splits = param_name.split(".") + for split in splits[:-1]: + module = getattr(module, split) + setattr(module, splits[-1], param_to_tie) + + +def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype: + """ + Just does torch.dtype(dtype) if necessary. + """ + if isinstance(dtype, str): + # We accept "torch.float16" or just "float16" + dtype = dtype.replace("torch.", "") + dtype = getattr(torch, dtype) + return dtype + + +def compute_module_sizes( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, + buffers_only: bool = False, +): + """ + Compute the size of each submodule of a given model. 
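+
+    Example (an illustrative sketch):
+
+    ```py
+    >>> import torch.nn as nn
+    >>> sizes = compute_module_sizes(nn.Linear(4, 4))  # 20 float32 parameters in total
+    >>> sizes[""]  # the empty key accumulates the size of the whole model, in bytes
+    80
+    ```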
+ """ + if dtype is not None: + dtype = _get_proper_dtype(dtype) + dtype_size = dtype_byte_size(dtype) + if special_dtypes is not None: + special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()} + special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()} + module_sizes = defaultdict(int) + + module_list = [] + + if not buffers_only: + module_list = named_module_tensors(model, recurse=True) + else: + module_list = model.named_buffers(recurse=True) + + for name, tensor in module_list: + if special_dtypes is not None and name in special_dtypes: + size = tensor.numel() * special_dtypes_size[name] + elif dtype is None: + size = tensor.numel() * dtype_byte_size(tensor.dtype) + elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + # According to the code in set_module_tensor_to_device, these types won't be converted + # so use their original size here + size = tensor.numel() * dtype_byte_size(tensor.dtype) + else: + size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype)) + name_parts = name.split(".") + for idx in range(len(name_parts) + 1): + module_sizes[".".join(name_parts[:idx])] += size + + return module_sizes + + +def compute_module_total_buffer_size( + model: nn.Module, + dtype: Optional[Union[str, torch.device]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, +): + """ + Compute the total size of buffers in each submodule of a given model. + """ + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True) + return module_sizes.get("", 0) + + +def get_max_layer_size( + modules: List[Tuple[str, torch.nn.Module]], module_sizes: Dict[str, int], no_split_module_classes: List[str] +): + """ + Utility function that will scan a list of named modules and return the maximum size used by one full layer. The + definition of a layer being: + - a module with no direct children (just parameters and buffers) + - a module whose class name is in the list `no_split_module_classes` + + Args: + modules (`List[Tuple[str, torch.nn.Module]]`): + The list of named modules where we want to determine the maximum layer size. + module_sizes (`Dict[str, int]`): + A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`). + no_split_module_classes (`List[str]`): + A list of class names for layers we don't want to be split. + + Returns: + `Tuple[int, List[str]]`: The maximum size of a layer with the list of layer names realizing that maximum size. + """ + max_size = 0 + layer_names = [] + modules_to_treat = modules.copy() + while len(modules_to_treat) > 0: + module_name, module = modules_to_treat.pop(0) + modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else [] + if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: + # No splitting this one so we compare to the max_size + size = module_sizes[module_name] + if size > max_size: + max_size = size + layer_names = [module_name] + elif size == max_size: + layer_names.append(module_name) + else: + modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat + return max_size, layer_names + + +def get_max_memory(max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None): + """ + Get the maximum memory available if nothing is passed, converts string to int otherwise. 
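+
+    Example (an illustrative sketch; the device index and the sizes are arbitrary):
+
+    ```py
+    >>> max_memory = get_max_memory({0: "10GiB", "cpu": "30GiB"})
+    >>> max_memory[0] == 10 * 2**30
+    True
+    ```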
+ """ + import psutil + + if max_memory is None: + if not (torch.cuda.is_available() or is_npu_available() or is_mlu_available() or is_xpu_available()): + max_memory = {} + + else: + # Make sure CUDA is initialized on each GPU to have the right memory info. + if is_npu_available(): + for i in range(torch.npu.device_count()): + _ = torch.tensor(0, device=torch.device("npu", i)) + max_memory = {i: torch.npu.mem_get_info(i)[0] for i in range(torch.npu.device_count())} + elif is_mlu_available(): + for i in range(torch.mlu.device_count()): + _ = torch.tensor(0, device=torch.device("mlu", i)) + max_memory = {i: torch.mlu.mem_get_info(i)[0] for i in range(torch.mlu.device_count())} + elif is_xpu_available(): + for i in range(torch.xpu.device_count()): + _ = torch.tensor(0, device=torch.device("xpu", i)) + max_memory = {i: torch.xpu.max_memory_allocated(i) for i in range(torch.xpu.device_count())} + else: + for i in range(torch.cuda.device_count()): + _ = torch.tensor([0], device=i) + max_memory = {i: torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())} + # allocate everything in the mps device as the RAM is shared + if is_mps_available(): + max_memory["mps"] = psutil.virtual_memory().available + else: + max_memory["cpu"] = psutil.virtual_memory().available + return max_memory + + for key in max_memory: + if isinstance(max_memory[key], str): + max_memory[key] = convert_file_size_to_int(max_memory[key]) + + # Need to sort the device by type to make sure that we allocate the gpu first. + # As gpu/npu/xpu are represented by int, we need to sort them first. + gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)] + gpu_devices.sort() + # check if gpu/npu/xpu devices are available and if not, throw a warning + if is_npu_available(): + num_devices = torch.npu.device_count() + elif is_mlu_available(): + num_devices = torch.mlu.device_count() + elif is_xpu_available(): + num_devices = torch.xpu.device_count() + else: + num_devices = torch.cuda.device_count() + for device in gpu_devices: + if device >= num_devices or device < 0: + logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}") + # Add the other devices in the preset order if they are available + all_devices = gpu_devices + [k for k in ["mps", "cpu", "disk"] if k in max_memory.keys()] + # Raise an error if a device is not recognized + for k in max_memory.keys(): + if k not in all_devices: + raise ValueError( + f"Device {k} is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'" + ) + max_memory = {k: max_memory[k] for k in all_devices} + + return max_memory + + +def clean_device_map(device_map: Dict[str, Union[int, str, torch.device]], module_name: str = ""): + """ + Cleans a device_map by grouping all submodules that go on the same device together. + """ + # Get the value of the current module and if there is only one split across several keys, regroup it. + prefix = "" if module_name == "" else f"{module_name}." 
+ values = [v for k, v in device_map.items() if k.startswith(prefix)] + if len(set(values)) == 1 and len(values) > 1: + for k in [k for k in device_map if k.startswith(prefix)]: + del device_map[k] + device_map[module_name] = values[0] + + # Recurse over the children + children_modules = [k for k in device_map.keys() if k.startswith(prefix) and len(k) > len(module_name)] + idx = len(module_name.split(".")) + 1 if len(module_name) > 0 else 1 + children_modules = set(".".join(k.split(".")[:idx]) for k in children_modules) + for child in children_modules: + clean_device_map(device_map, module_name=child) + + return device_map + + +def load_offloaded_weights(model, index, offload_folder): + """ + Loads the weights from the offload folder into the model. + + Args: + model (`torch.nn.Module`): + The model to load the weights into. + index (`dict`): + A dictionary containing the parameter name and its metadata for each parameter that was offloaded from the + model. + offload_folder (`str`): + The folder where the offloaded weights are stored. + """ + if index is None or len(index) == 0: + # Nothing to do + return + for param_name, metadata in index.items(): + if "SCB" in param_name: + continue + fp16_statistics = None + if "weight" in param_name and param_name.replace("weight", "SCB") in index.keys(): + weight_name = param_name.replace("weight", "SCB") + fp16_statistics = load_offloaded_weight( + os.path.join(offload_folder, f"{weight_name}.dat"), index[weight_name] + ) + tensor_file = os.path.join(offload_folder, f"{param_name}.dat") + weight = load_offloaded_weight(tensor_file, metadata) + set_module_tensor_to_device(model, param_name, "cpu", value=weight, fp16_statistics=fp16_statistics) + + +def get_balanced_memory( + model: nn.Module, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[List[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None, + low_zero: bool = False, +): + """ + Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU. + + + + All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the + meta device (as it would if initialized within the `init_empty_weights` context manager). + + + + Args: + model (`torch.nn.Module`): + The model to analyze. + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + Example: `max_memory={0: "1GB"}`. + no_split_module_classes (`List[str]`, *optional*): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + dtype (`str` or `torch.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*): + If provided, special dtypes to consider for some specific weights (will override dtype used as default for + all weights). + low_zero (`bool`, *optional*): + Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the + Transformers generate function). 
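+
+ Example (a minimal sketch; `MyModel` and the `"Block"` class name are placeholders, and the resulting byte values depend on the hardware):
+
+ from accelerate import init_empty_weights
+ from accelerate.utils import get_balanced_memory, infer_auto_device_map
+
+ with init_empty_weights():
+ model = MyModel()
+ max_memory = get_balanced_memory(model, no_split_module_classes=["Block"])
+ device_map = infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["Block"])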
+ """ + # Get default / clean up max_memory + user_not_set_max_memory = max_memory is None + max_memory = get_max_memory(max_memory) + + if is_npu_available(): + num_devices = len([d for d in max_memory if torch.device(d).type == "npu" and max_memory[d] > 0]) + elif is_mlu_available(): + num_devices = len([d for d in max_memory if torch.device(d).type == "mlu" and max_memory[d] > 0]) + elif is_xpu_available(): + num_devices = len( + [ + d + for d in max_memory + if ( + d != "cpu" + and (torch.device(d).type == "xpu" or torch.xpu.get_device_properties(d).dev_type == "gpu") + ) + and max_memory[d] > 0 + ] + ) + else: + num_devices = len([d for d in max_memory if torch.device(d).type == "cuda" and max_memory[d] > 0]) + + if num_devices == 0: + return max_memory + + if num_devices == 1: + # We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer + low_zero = False + # If user just asked us to handle memory usage, we should avoid OOM + if user_not_set_max_memory: + for key in max_memory.keys(): + if isinstance(key, int): + max_memory[key] *= 0.9 # 90% is a good compromise + logger.info( + f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. " + "You can set `max_memory` in to a higher value to use more memory (at your own risk)." + ) + break # only one device + + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) + per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices) + + # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get + # slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to + # add which is the biggest of: + # - the size of no split block (if applicable) + # - the mean of the layer sizes + if no_split_module_classes is None: + no_split_module_classes = [] + elif not isinstance(no_split_module_classes, (list, tuple)): + no_split_module_classes = [no_split_module_classes] + + # Identify the size of the no_split_block modules + if len(no_split_module_classes) > 0: + no_split_children = {} + for name, size in module_sizes.items(): + if name == "": + continue + submodule = model + for submodule_name in name.split("."): + submodule = getattr(submodule, submodule_name) + class_name = submodule.__class__.__name__ + if class_name in no_split_module_classes and class_name not in no_split_children: + no_split_children[class_name] = size + + if set(no_split_children.keys()) == set(no_split_module_classes): + break + buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0 + else: + buffer = 0 + + # Compute mean of final modules. In the first dict of module sizes, leaves are the parameters + leaves = [n for n in module_sizes if len([p for p in module_sizes if n == "" or p.startswith(n + ".")]) == 0] + module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves} + # Once removed, leaves are the final modules. 
+ leaves = [n for n in module_sizes if len([p for p in module_sizes if n == "" or p.startswith(n + ".")]) == 0] + mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1)) + buffer = int(1.25 * max(buffer, mean_leaves)) + per_gpu += buffer + + # Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them) + gpus_idx_list = list( + sorted( + device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0 + ) + ) + # The last device is left with max_memory just in case the buffer is not enough. + for idx in gpus_idx_list[:-1]: + max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx]) + + if low_zero: + min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)])) + max_memory[0] = min(min_zero, max_memory[0]) + + return max_memory + + +def calculate_maximum_sizes(model: torch.nn.Module): + "Computes the total size of the model and its largest layer" + sizes = compute_module_sizes(model) + # `transformers` models store this information for us + no_split_modules = getattr(model, "_no_split_modules", None) + if no_split_modules is None: + no_split_modules = [] + + modules_to_treat = ( + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) + ) + largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules) + total_size = sizes[""] + return total_size, largest_layer + + +def infer_auto_device_map( + model: nn.Module, + max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None, + no_split_module_classes: Optional[List[str]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + special_dtypes: Optional[Dict[str, Union[str, torch.dtype]]] = None, + verbose: bool = False, + clean_result: bool = True, + offload_buffers: bool = False, +): + """ + Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk, + such that: + - we don't exceed the memory available of any of the GPU. + - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that + has the largest size. + - if offload to the CPU is needed,we don't exceed the RAM available on the CPU. + - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk + that has the largest size. + + + + All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the + meta device (as it would if initialized within the `init_empty_weights` context manager). + + + + Args: + model (`torch.nn.Module`): + The model to analyze. + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset. + Example: `max_memory={0: "1GB"}`. + no_split_module_classes (`List[str]`, *optional*): + A list of layer class names that should never be split across device (for instance any layer that has a + residual connection). + dtype (`str` or `torch.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*): + If provided, special dtypes to consider for some specific weights (will override dtype used as default for + all weights). 
+ verbose (`bool`, *optional*, defaults to `False`): + Whether or not to provide debugging statements as the function builds the device_map. + clean_result (`bool`, *optional*, defaults to `True`): + Clean the resulting device_map by grouping all submodules that go on the same device together. + offload_buffers (`bool`, *optional*, defaults to `False`): + In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as + well as the parameters. + """ + # Get default / clean up max_memory + max_memory = get_max_memory(max_memory) + if no_split_module_classes is None: + no_split_module_classes = [] + elif not isinstance(no_split_module_classes, (list, tuple)): + no_split_module_classes = [no_split_module_classes] + + devices = list(max_memory.keys()) + if "disk" not in devices: + devices.append("disk") + gpus = [device for device in devices if device not in ["cpu", "disk"]] + + # Devices that need to keep space for a potential offloaded layer. + if "mps" in gpus: + main_devices = ["mps"] + elif len(gpus) > 0: + main_devices = [gpus[0], "cpu"] + else: + main_devices = ["cpu"] + + module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) + tied_parameters = find_tied_parameters(model) + + if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: + logger.warn( + "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." + ) + + device_map = OrderedDict() + current_device = 0 + current_memory_used = 0 + device_memory_used = {} + device_buffer_sizes = {} + + # Direct submodules and parameters + modules_to_treat = ( + list(model.named_parameters(recurse=False)) + + list(model.named_children()) + + list(model.named_buffers(recurse=False)) + ) + # Initialize maximum largest layer, to know which space to keep in memory + max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes) + + # Ready ? This is going to be a bit messy. + while len(modules_to_treat) > 0: + name, module = modules_to_treat.pop(0) + if verbose: + print(f"\nTreating module {name}.") + # Max size in the remaining layers may have changed since we took one, so we maybe update it. + max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")] + if len(max_layer_names) == 0: + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + # Assess size needed + module_size = module_sizes[name] + + # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module + # and the other is not. + # Note: If we are currently processing the name `compute.weight`, an other parameter named e.g. `compute.weight_submodule.parameter` + # needs to be considered outside the current module, hence the check with additional dots. + tied_param_goups = [ + tied_group + for tied_group in tied_parameters + if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) + ] + + if verbose and len(tied_param_goups) > 0: + print(f" Found the relevant tied param groups {tied_param_goups}") + + # Then we keep track of all the parameters that are tied to the current module, but not in the current module + tied_params = sum( + [[p for p in tied_group if name + "." 
not in p + "."] for tied_group in tied_param_goups], [] + ) + + if verbose and len(tied_params) > 0: + print(f" So those parameters need to be taken into account {tied_params}") + + device = devices[current_device] + current_max_size = max_memory[device] if device != "disk" else None + current_memory_reserved = 0 + # Reduce max size available by the largest layer. + if devices[current_device] in main_devices: + current_max_size = current_max_size - max_layer_size + current_memory_reserved = max_layer_size + # Case 1 -> We're too big! + if current_max_size is not None and current_memory_used + module_size > current_max_size: + # Split or not split? + modules_children = ( + [] + if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor) + else list(module.named_children()) + ) + if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} (space available " + f"{current_max_size - current_memory_used}, module size {module_size})." + ) + if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: + # -> no split, we go to the next device + if verbose: + print("This module cannot be split, going to the next device.") + + device_memory_used[device] = current_memory_used + current_memory_reserved + current_device += 1 + modules_to_treat = [(name, module)] + modules_to_treat + current_memory_used = 0 + else: + # -> split, we replace the module studied by its children + parameters + if verbose: + print(f"Splitting {name}.") + modules_children = list(module.named_parameters(recurse=False)) + modules_children + modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat + # Update the max layer size. + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + + # Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters. + elif len(tied_params) > 0: + # First locate all tied modules + tied_module_names = [] + tied_modules = [] + for tied_param in tied_params: + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] + tied_module_names.append(modules_to_treat[tied_module_index][0]) + tied_modules.append(modules_to_treat[tied_module_index][1]) + if verbose: + print( + f" It looks like {name} is going to fit on {devices[current_device]} but we have tied " + f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}" + ) + + # Let's see if it all fits first + module_size_with_ties = module_size + for tied_param, tied_module_name in zip(tied_params, tied_module_names): + module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] + + if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size: + # We really really fit! + if verbose: + print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.") + current_memory_used += module_size_with_ties + device_map[name] = devices[current_device] + for tied_module_name in tied_module_names: + if tied_module_name in [m[0] for m in modules_to_treat]: + # The module may have been removed by a previous iteration of this loop. 
+ tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][ + 0 + ] + modules_to_treat.pop(tied_module_index) + device_map[tied_module_name] = devices[current_device] + + if not offload_buffers and isinstance(module, nn.Module): + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes + ) + device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size + + else: + # We don't fit with the tied modules. Next question is: can we split one of the tied modules to make it + # smaller or do we need to go on the next device? + if verbose: + print( + f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " + f"available {current_max_size - current_memory_used}, needed size {module_size_with_ties})." + ) + split_happened = False + for tied_module_name, tied_module in zip(tied_module_names, tied_modules): + tied_module_children = list(tied_module.named_children()) + if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes: + # can't break this one. + continue + + if verbose: + print(f"Splitting {tied_module_name}.") + tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children + tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children] + tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0] + + modules_to_treat = ( + [(name, module)] + + modules_to_treat[:tied_module_index] + + tied_module_children + + modules_to_treat[tied_module_index + 1 :] + ) + # Update the max layer size. + max_layer_size, max_layer_names = get_max_layer_size( + [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)], + module_sizes, + no_split_module_classes, + ) + split_happened = True + break + + if not split_happened: + # If the tied module is not split, we go to the next device + if verbose: + print("None of the tied module can be split, going to the next device.") + + device_memory_used[device] = current_memory_used + current_memory_reserved + current_device += 1 + modules_to_treat = [(name, module)] + modules_to_treat + current_memory_used = 0 + + else: + if verbose: + if current_max_size is None: + print(f"Putting {name} (size={module_size}) on {devices[current_device]}.") + else: + print( + f"Putting {name} (size={module_size}) on {devices[current_device]} " + f"(available={current_max_size - current_memory_used})." 
+ ) + current_memory_used += module_size + device_memory_used[device] = current_memory_used + current_memory_reserved + device_map[name] = devices[current_device] + + if not offload_buffers and isinstance(module, nn.Module): + current_buffer_size = compute_module_total_buffer_size( + module, dtype=dtype, special_dtypes=special_dtypes + ) + device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size + + if clean_result: + device_map = clean_device_map(device_map) + + non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0) + if non_gpu_buffer_size > 0 and not offload_buffers: + is_buffer_fit_any_gpu = False + for gpu_device, gpu_max_memory in max_memory.items(): + if gpu_device == "cpu" or gpu_device == "disk": + continue + + if not is_buffer_fit_any_gpu: + gpu_memory_used = device_memory_used.get(gpu_device, 0) + + if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used: + is_buffer_fit_any_gpu = True + + if len(gpus) > 0 and not is_buffer_fit_any_gpu: + warnings.warn( + f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does " + f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using " + f"offload_buffers=True." + ) + + return device_map + + +def check_device_map(model: nn.Module, device_map: Dict[str, Union[int, str, torch.device]]): + """ + Checks a device map covers everything in a given model. + + Args: + model (`torch.nn.Module`): The model to check the device map against. + device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check. + """ + all_model_tensors = [name for name, _ in model.state_dict().items()] + for module_name in device_map.keys(): + if module_name == "": + all_model_tensors.clear() + break + else: + all_model_tensors = [ + name + for name in all_model_tensors + if not name == module_name and not name.startswith(module_name + ".") + ] + if len(all_model_tensors) > 0: + non_covered_params = ", ".join(all_model_tensors) + raise ValueError( + f"The device_map provided does not give any device for the following parameters: {non_covered_params}" + ) + + +def load_state_dict(checkpoint_file, device_map=None): + """ + Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the + weights can be fast-loaded directly on the GPU. + + Args: + checkpoint_file (`str`): The path to the checkpoint to load. + device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + """ + if checkpoint_file.endswith(".safetensors"): + with safe_open(checkpoint_file, framework="pt") as f: + metadata = f.metadata() + weight_names = f.keys() + + if metadata is None: + logger.warn( + f"The safetensors archive passed at {checkpoint_file} does not contain metadata. " + "Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata." + ) + metadata = {"format": "pt"} + + if metadata.get("format") not in ["pt", "tf", "flax"]: + raise OSError( + f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " + "you save your model with the `save_pretrained` method." 
+ ) + elif metadata["format"] != "pt": + raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need a the pt format.") + if device_map is None: + return safe_load_file(checkpoint_file) + else: + # if we only have one device we can load everything directly + if len(set(device_map.values())) == 1: + return safe_load_file(checkpoint_file, device=list(device_map.values())[0]) + + devices = list(set(device_map.values()) - {"disk"}) + # cpu device should always exist as fallback option + if "cpu" not in devices: + devices.append("cpu") + + # For each device, get the weights that go there + device_weights = {device: [] for device in devices} + for module_name, device in device_map.items(): + if device in devices: + device_weights[device].extend( + [k for k in weight_names if k == module_name or k.startswith(module_name + ".")] + ) + + # all weights that haven't defined a device should be loaded on CPU + device_weights["cpu"].extend([k for k in weight_names if k not in sum(device_weights.values(), [])]) + tensors = {} + if is_tqdm_available(): + progress_bar = tqdm( + main_process_only=False, + total=sum([len(device_weights[device]) for device in devices]), + unit="w", + smoothing=0, + leave=False, + ) + else: + progress_bar = None + for device in devices: + target_device = device + + if is_xpu_available(): + current_safetensors_version = packaging.version.parse(importlib.metadata.version("safetensors")) + + if compare_versions(current_safetensors_version, "<", "0.4.2"): + raise ModuleNotFoundError( + f"You need at least safetensors 0.4.2 for Intel GPU, while you have {current_safetensors_version}" + ) + + if isinstance(device, int): + target_device = f"xpu:{device}" + + with safe_open(checkpoint_file, framework="pt", device=target_device) as f: + for key in device_weights[device]: + if progress_bar is not None: + progress_bar.set_postfix(dev=device, refresh=False) + progress_bar.set_description(key) + tensors[key] = f.get_tensor(key) + if progress_bar is not None: + progress_bar.update() + if progress_bar is not None: + progress_bar.close() + + return tensors + else: + return torch.load(checkpoint_file, map_location=torch.device("cpu")) + + +def get_state_dict_offloaded_model(model: nn.Module): + """ + Returns the state dictionary for an offloaded model via iterative onloading + + Args: + model (`torch.nn.Module`): + The offloaded model we want to save + """ + from ..hooks import AlignDevicesHook + + state_dict = {} + placeholders = set() + for name, module in model.named_modules(): + if name == "": + continue + if hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload: + original_device = module._hf_hook.execution_device + # assign hook execution device to cpu + module._hf_hook.execution_device = "cpu" + # onload meta tensors to execution device + try: + module._hf_hook.pre_forward(module) + except MemoryError: + raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None + module_state_dict = module.state_dict() + # offload meta tensors from cpu + module._hf_hook.post_forward(module, torch.tensor([])) + # re-assign hook to original execution device + module._hf_hook.execution_device = original_device + else: + module_state_dict = module.state_dict() + + for key in module_state_dict: + # ignore placeholder parameters that are still on the meta device + if module_state_dict[key].device == torch.device("meta"): + placeholders.add(name + f".{key}") + continue + params = 
module_state_dict[key] + state_dict[name + f".{key}"] = params + for key in placeholders.copy(): + if key in state_dict: + placeholders.remove(key) + if placeholders: + logger.warning(f"The following tensors were not saved because they were still on meta device: {placeholders}") + + return state_dict + + +def load_checkpoint_in_model( + model: nn.Module, + checkpoint: Union[str, os.PathLike], + device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None, + offload_folder: Optional[Union[str, os.PathLike]] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + offload_state_dict: bool = False, + offload_buffers: bool = False, + keep_in_fp32_modules: List[str] = None, + offload_8bit_bnb: bool = False, + strict: bool = False, +): + """ + Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are + loaded. + + + + Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To + group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`]. + + + + Args: + model (`torch.nn.Module`): + The model in which we want to load a checkpoint. + checkpoint (`str` or `os.PathLike`): + The folder checkpoint to load. It can be: + - a path to a file containing a whole model state dict + - a path to a `.json` file containing the index to a sharded checkpoint + - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint. + - a path to a folder containing a unique pytorch_model.bin or a model.safetensors file. + device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer + name, once a given module name is inside, every submodule of it will be sent to the same device. + offload_folder (`str` or `os.PathLike`, *optional*): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + dtype (`str` or `torch.dtype`, *optional*): + If provided, the weights will be converted to that type when loaded. + offload_state_dict (`bool`, *optional*, defaults to `False`): + If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if + the weight of the CPU state dict + the biggest shard does not fit. + offload_buffers (`bool`, *optional*, defaults to `False`): + Whether or not to include the buffers in the weights offloaded to disk. + keep_in_fp32_modules(`List[str]`, *optional*): + A list of the modules that we keep in `torch.float32` dtype. + offload_8bit_bnb (`bool`, *optional*): + Whether or not to enable offload of 8-bit modules on cpu/disk. + strict (`bool`, *optional*, defaults to `False`): + Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's + state_dict. + + """ + if offload_8bit_bnb: + from .bnb import quantize_and_offload_8bit + + tied_params = find_tied_parameters(model) + + if check_tied_parameters_in_config(model) and len(tied_params) == 0: + logger.warn( + "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." + ) + if device_map is not None: + check_tied_parameters_on_same_device(tied_params, device_map) + + if offload_folder is None and device_map is not None and "disk" in device_map.values(): + raise ValueError( + "At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`." 
+ ) + elif offload_folder is not None and device_map is not None and "disk" in device_map.values(): + os.makedirs(offload_folder, exist_ok=True) + + if isinstance(dtype, str): + # We accept "torch.float16" or just "float16" + dtype = dtype.replace("torch.", "") + dtype = getattr(torch, dtype) + + checkpoint_files = None + index_filename = None + if os.path.isfile(checkpoint): + if str(checkpoint).endswith(".json"): + index_filename = checkpoint + else: + checkpoint_files = [checkpoint] + elif os.path.isdir(checkpoint): + # check if the whole state dict is present + potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME] + potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME] + if len(potential_state_bin) == 1: + checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])] + elif len(potential_state_safetensor) == 1: + checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])] + else: + # otherwise check for sharded checkpoints + potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")] + if len(potential_index) == 0: + raise ValueError( + f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file" + ) + elif len(potential_index) == 1: + index_filename = os.path.join(checkpoint, potential_index[0]) + else: + raise ValueError( + f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones." + ) + else: + raise ValueError( + "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded " + f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}." + ) + + if index_filename is not None: + checkpoint_folder = os.path.split(index_filename)[0] + with open(index_filename) as f: + index = json.loads(f.read()) + + if "weight_map" in index: + index = index["weight_map"] + checkpoint_files = sorted(list(set(index.values()))) + checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files] + + # Logic for missing/unexepected keys goes here. + + offload_index = {} + if offload_state_dict: + state_dict_folder = tempfile.mkdtemp() + state_dict_index = {} + + unexpected_keys = set() + model_keys = set(model.state_dict().keys()) + buffer_names = [name for name, _ in model.named_buffers()] + for checkpoint_file in checkpoint_files: + loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map) + if device_map is None: + model.load_state_dict(loaded_checkpoint, strict=strict) + unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys) + else: + for param_name, param in loaded_checkpoint.items(): + # skip SCB parameter (for 8-bit serialization) + if "SCB" in param_name: + continue + + if param_name not in model_keys: + unexpected_keys.add(param_name) + if not strict: + continue # Skip loading this parameter. + + module_name = param_name + + while len(module_name) > 0 and module_name not in device_map: + module_name = ".".join(module_name.split(".")[:-1]) + if module_name == "" and "" not in device_map: + # TODO: group all errors and raise at the end. + raise ValueError(f"{param_name} doesn't have any device set.") + param_device = device_map[module_name] + new_dtype = dtype + if dtype is not None and torch.is_floating_point(param): + if keep_in_fp32_modules is not None and dtype == torch.float16: + proceed = False + for key in keep_in_fp32_modules: + if ((key in param_name) and (key + "." 
in param_name)) or key == param_name: + proceed = True + break + if proceed: + new_dtype = torch.float32 + + if "weight" in param_name and param_name.replace("weight", "SCB") in loaded_checkpoint.keys(): + if param.dtype == torch.int8: + fp16_statistics = loaded_checkpoint[param_name.replace("weight", "SCB")] + else: + fp16_statistics = None + + if param_device == "disk": + if offload_buffers or param_name not in buffer_names: + if new_dtype is None: + new_dtype = param.dtype + if offload_8bit_bnb: + quantize_and_offload_8bit( + model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics + ) + continue + else: + set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype) + offload_weight(param, param_name, offload_folder, index=offload_index) + elif param_device == "cpu" and offload_state_dict: + if new_dtype is None: + new_dtype = param.dtype + if offload_8bit_bnb: + quantize_and_offload_8bit( + model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics + ) + else: + set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype) + offload_weight(param, param_name, state_dict_folder, index=state_dict_index) + else: + set_module_tensor_to_device( + model, + param_name, + param_device, + value=param, + dtype=new_dtype, + fp16_statistics=fp16_statistics, + ) + + # Force Python to clean up. + del loaded_checkpoint + gc.collect() + + if not strict and len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {checkpoint} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint." + ) + + save_offload_index(offload_index, offload_folder) + + # Load back offloaded state dict on CPU + if offload_state_dict: + load_offloaded_weights(model, state_dict_index, state_dict_folder) + shutil.rmtree(state_dict_folder) + + retie_parameters(model, tied_params) + + +def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwargs: AutocastKwargs = None): + """ + Return a context manager for autocasting mixed precision + + Args: + native_amp (`bool`, *optional*, defaults to False): + Whether mixed precision is actually enabled. + cache_enabled (`bool`, *optional*, defaults to True): + Whether the weight cache inside autocast should be enabled. 
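+
+ Example (a hedged sketch; assumes an `Accelerator` was configured with `mixed_precision="fp16"`, and `model` / `inputs` are placeholders):
+
+ autocast_ctx = get_mixed_precision_context_manager(native_amp=True)
+ with autocast_ctx:
+ outputs = model(inputs) # forward pass runs under torch.autocast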
+ """ + state = AcceleratorState() + if autocast_kwargs is None: + autocast_kwargs = {} + else: + autocast_kwargs = autocast_kwargs.to_kwargs() + if native_amp: + device_type = ( + "cuda" + if (state.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_gpu=True)) + else state.device.type + ) + if state.mixed_precision == "fp16": + return torch.autocast(device_type=device_type, dtype=torch.float16, **autocast_kwargs) + elif state.mixed_precision == "bf16" and state.distributed_type in [ + DistributedType.NO, + DistributedType.MULTI_CPU, + DistributedType.MULTI_GPU, + DistributedType.MULTI_MLU, + DistributedType.MULTI_NPU, + DistributedType.MULTI_XPU, + DistributedType.FSDP, + DistributedType.XLA, + ]: + return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs) + else: + return torch.autocast(device_type=device_type, **autocast_kwargs) + else: + return contextlib.nullcontext() diff --git a/llm/Lib/site-packages/accelerate/utils/offload.py b/llm/Lib/site-packages/accelerate/utils/offload.py new file mode 100644 index 0000000000000000000000000000000000000000..d064847ca21bde644b443de315b239414aa2fd51 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/offload.py @@ -0,0 +1,213 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from collections.abc import Mapping +from typing import Dict, List, Optional, Union + +import numpy as np +import torch +from safetensors import safe_open + + +def offload_weight(weight, weight_name, offload_folder, index=None): + dtype = None + # Check the string instead of the dtype to be compatible with versions of PyTorch that don't have bfloat16. + if str(weight.dtype) == "torch.bfloat16": + # Need to reinterpret the underlined data as int16 since NumPy does not handle bfloat16s. 
+ weight = weight.view(torch.int16) + dtype = "bfloat16" + array = weight.cpu().numpy() + tensor_file = os.path.join(offload_folder, f"{weight_name}.dat") + if index is not None: + if dtype is None: + dtype = str(array.dtype) + index[weight_name] = {"dtype": dtype, "shape": list(array.shape)} + if array.ndim == 0: + array = array[None] + file_array = np.memmap(tensor_file, dtype=array.dtype, mode="w+", shape=array.shape) + file_array[:] = array[:] + file_array.flush() + return index + + +def load_offloaded_weight(weight_file, weight_info): + shape = tuple(weight_info["shape"]) + if shape == (): + # NumPy memory-mapped arrays can't have 0 dims so it was saved as 1d tensor + shape = (1,) + + dtype = weight_info["dtype"] + if dtype == "bfloat16": + # NumPy does not support bfloat16 so this was saved as a int16 + dtype = "int16" + + weight = np.memmap(weight_file, dtype=dtype, shape=shape, mode="r") + + if len(weight_info["shape"]) == 0: + weight = weight[0] + weight = torch.tensor(weight) + if weight_info["dtype"] == "bfloat16": + weight = weight.view(torch.bfloat16) + + return weight + + +def save_offload_index(index, offload_folder): + if index is None or len(index) == 0: + # Nothing to save + return + + offload_index_file = os.path.join(offload_folder, "index.json") + if os.path.isfile(offload_index_file): + with open(offload_index_file, encoding="utf-8") as f: + current_index = json.load(f) + else: + current_index = {} + current_index.update(index) + + with open(offload_index_file, "w", encoding="utf-8") as f: + json.dump(current_index, f, indent=2) + + +def offload_state_dict(save_dir: Union[str, os.PathLike], state_dict: Dict[str, torch.Tensor]): + """ + Offload a state dict in a given folder. + + Args: + save_dir (`str` or `os.PathLike`): + The directory in which to offload the state dict. + state_dict (`Dict[str, torch.Tensor]`): + The dictionary of tensors to offload. + """ + os.makedirs(save_dir, exist_ok=True) + index = {} + for name, parameter in state_dict.items(): + index = offload_weight(parameter, name, save_dir, index=index) + + # Update index + save_offload_index(index, save_dir) + + +class PrefixedDataset(Mapping): + """ + Will access keys in a given dataset by adding a prefix. + + Args: + dataset (`Mapping`): Any map with string keys. + prefix (`str`): A prefix to add when trying to access any element in the underlying dataset. + """ + + def __init__(self, dataset: Mapping, prefix: str): + self.dataset = dataset + self.prefix = prefix + + def __getitem__(self, key): + return self.dataset[f"{self.prefix}{key}"] + + def __iter__(self): + return iter([key for key in self.dataset if key.startswith(self.prefix)]) + + def __len__(self): + return len(self.dataset) + + +class OffloadedWeightsLoader(Mapping): + """ + A collection that loads weights stored in a given state dict or memory-mapped on disk. + + Args: + state_dict (`Dict[str, torch.Tensor]`, *optional*): + A dictionary parameter name to tensor. + save_folder (`str` or `os.PathLike`, *optional*): + The directory in which the weights are stored (by `offload_state_dict` for instance). + index (`Dict`, *optional*): + A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default + to the index saved in `save_folder`. 
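+
+ Example (illustrative; assumes weights were previously offloaded with `offload_state_dict` to a folder named `"offload_dir"`, and `"linear.weight"` is a placeholder key):
+
+ weights = OffloadedWeightsLoader(save_folder="offload_dir")
+ tensor = weights["linear.weight"] # loaded lazily from the on-disk memmap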
+ """ + + def __init__( + self, + state_dict: Dict[str, torch.Tensor] = None, + save_folder: Optional[Union[str, os.PathLike]] = None, + index: Mapping = None, + device=None, + ): + if state_dict is None and save_folder is None and index is None: + raise ValueError("Need either a `state_dict`, a `save_folder` or an `index` containing offloaded weights.") + + self.state_dict = {} if state_dict is None else state_dict + self.save_folder = save_folder + if index is None and save_folder is not None: + with open(os.path.join(save_folder, "index.json")) as f: + index = json.load(f) + self.index = {} if index is None else index + self.all_keys = list(self.state_dict.keys()) + self.all_keys.extend([key for key in self.index if key not in self.all_keys]) + self.device = device + + def __getitem__(self, key: str): + # State dict gets priority + if key in self.state_dict: + return self.state_dict[key] + weight_info = self.index[key] + if weight_info.get("safetensors_file") is not None: + device = "cpu" if self.device is None else self.device + tensor = None + try: + with safe_open(weight_info["safetensors_file"], framework="pt", device=device) as f: + tensor = f.get_tensor(weight_info.get("weight_name", key)) + except TypeError: + # if failed to get_tensor on the device, such as bf16 on mps, try to load it on CPU first + with safe_open(weight_info["safetensors_file"], framework="pt", device="cpu") as f: + tensor = f.get_tensor(weight_info.get("weight_name", key)) + + if "dtype" in weight_info: + tensor = tensor.to(getattr(torch, weight_info["dtype"])) + + if tensor.device != torch.device(device): + tensor = tensor.to(device) + return tensor + + weight_file = os.path.join(self.save_folder, f"{key}.dat") + return load_offloaded_weight(weight_file, weight_info) + + def __iter__(self): + return iter(self.all_keys) + + def __len__(self): + return len(self.all_keys) + + +def extract_submodules_state_dict(state_dict: Dict[str, torch.Tensor], submodule_names: List[str]): + """ + Extract the sub state-dict corresponding to a list of given submodules. + + Args: + state_dict (`Dict[str, torch.Tensor]`): The state dict to extract from. + submodule_names (`List[str]`): The list of submodule names we want to extract. + """ + result = {} + for module_name in submodule_names: + # We want to catch module_name parameter (module_name.xxx) or potentially module_name, but not any of the + # submodules that could being like module_name (transformers.h.1 and transformers.h.10 for instance) + result.update( + { + key: param + for key, param in state_dict.items() + if key == module_name or key.startswith(module_name + ".") + } + ) + return result diff --git a/llm/Lib/site-packages/accelerate/utils/operations.py b/llm/Lib/site-packages/accelerate/utils/operations.py new file mode 100644 index 0000000000000000000000000000000000000000..e2456a10710c9c0707245a00cf3630297d6d579f --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/operations.py @@ -0,0 +1,851 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +""" +A set of basic tensor ops compatible with tpu, gpu, and multigpu +""" + +import pickle +import warnings +from functools import update_wrapper, wraps +from typing import Any, Mapping + +import torch + +from ..state import PartialState +from .constants import TORCH_DISTRIBUTED_OPERATION_TYPES +from .dataclasses import DistributedType, TensorInformation +from .imports import ( + is_npu_available, + is_torch_distributed_available, + is_torch_version, + is_torch_xla_available, + is_xpu_available, +) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + +if is_torch_distributed_available(): + from torch.distributed import ReduceOp + + +def is_torch_tensor(tensor): + return isinstance(tensor, torch.Tensor) + + +def is_torch_xpu_tensor(tensor): + return isinstance( + tensor, + torch.xpu.FloatTensor, + torch.xpu.ByteTensor, + torch.xpu.IntTensor, + torch.xpu.LongTensor, + torch.xpu.HalfTensor, + torch.xpu.DoubleTensor, + torch.xpu.BFloat16Tensor, + ) + + +def is_tensor_information(tensor_info): + return isinstance(tensor_info, TensorInformation) + + +def is_namedtuple(data): + """ + Checks if `data` is a `namedtuple` or not. Can have false positives, but only if a user is trying to mimic a + `namedtuple` perfectly. + """ + return isinstance(data, tuple) and hasattr(data, "_asdict") and hasattr(data, "_fields") + + +def honor_type(obj, generator): + """ + Cast a generator to the same type as obj (list, tuple, or namedtuple) + """ + # Some objects may not be able to instantiate from a generator directly + if is_namedtuple(obj): + return type(obj)(*list(generator)) + else: + return type(obj)(generator) + + +def recursively_apply(func, data, *args, test_type=is_torch_tensor, error_on_other_type=False, **kwargs): + """ + Recursively apply a function on a data structure that is a nested list/tuple/dictionary of a given base type. + + Args: + func (`callable`): + The function to recursively apply. + data (nested list/tuple/dictionary of `main_type`): + The data on which to apply `func` + *args: + Positional arguments that will be passed to `func` when applied on the unpacked data. + main_type (`type`, *optional*, defaults to `torch.Tensor`): + The base type of the objects to which apply `func`. + error_on_other_type (`bool`, *optional*, defaults to `False`): + Whether to return an error or not if after unpacking `data`, we get on an object that is not of type + `main_type`. If `False`, the function will leave objects of types different than `main_type` unchanged. + **kwargs (additional keyword arguments, *optional*): + Keyword arguments that will be passed to `func` when applied on the unpacked data. + + Returns: + The same data structure as `data` with `func` applied to every object of type `main_type`. + """ + if isinstance(data, (tuple, list)): + return honor_type( + data, + ( + recursively_apply( + func, o, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs + ) + for o in data + ), + ) + elif isinstance(data, Mapping): + return type(data)( + { + k: recursively_apply( + func, v, *args, test_type=test_type, error_on_other_type=error_on_other_type, **kwargs + ) + for k, v in data.items() + } + ) + elif test_type(data): + return func(data, *args, **kwargs) + elif error_on_other_type: + raise TypeError( + f"Unsupported types ({type(data)}) passed to `{func.__name__}`. 
Only nested list/tuple/dicts of " + f"objects that are valid for `{test_type.__name__}` should be passed." + ) + return data + + +def send_to_device(tensor, device, non_blocking=False, skip_keys=None): + """ + Recursively sends the elements in a nested list/tuple/dictionary of tensors to a given device. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to send to a given device. + device (`torch.device`): + The device to send the data to. + + Returns: + The same data structure as `tensor` with all tensors sent to the proper device. + """ + if is_torch_tensor(tensor) or hasattr(tensor, "to"): + # `torch.Tensor.to("npu")` could not find context when called for the first time (see this [issue](https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue)). + if device == "npu": + device = "npu:0" + if device == "xpu": + device = "xpu:0" + # TODO: torch_mlu LongTensor.to() has bugs, we will fix this later. + if is_torch_tensor(tensor) and tensor.device.type in ["mlu"] and tensor.dtype in [torch.int64]: + tensor = tensor.cpu() + try: + return tensor.to(device, non_blocking=non_blocking) + except TypeError: # .to() doesn't accept non_blocking as kwarg + return tensor.to(device) + except AssertionError as error: + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + # This call is inside the try-block since is_npu_available is not supported by torch.compile. + if is_npu_available(): + if isinstance(device, int): + device = f"npu:{device}" + else: + raise error + except Exception as error: + if is_xpu_available(): + if isinstance(device, int): + device = f"xpu:{device}" + else: + raise error + try: + return tensor.to(device, non_blocking=non_blocking) + except TypeError: # .to() doesn't accept non_blocking as kwarg + return tensor.to(device) + elif isinstance(tensor, (tuple, list)): + return honor_type( + tensor, (send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) for t in tensor) + ) + elif isinstance(tensor, Mapping): + if isinstance(skip_keys, str): + skip_keys = [skip_keys] + elif skip_keys is None: + skip_keys = [] + return type(tensor)( + { + k: t if k in skip_keys else send_to_device(t, device, non_blocking=non_blocking, skip_keys=skip_keys) + for k, t in tensor.items() + } + ) + else: + return tensor + + +def get_data_structure(data): + """ + Recursively gathers the information needed to rebuild a nested list/tuple/dictionary of tensors. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): + The data to send to analyze. + + Returns: + The same data structure as `data` with [`~utils.TensorInformation`] instead of tensors. + """ + + def _get_data_structure(tensor): + return TensorInformation(shape=tensor.shape, dtype=tensor.dtype) + + return recursively_apply(_get_data_structure, data) + + +def get_shape(data): + """ + Recursively gathers the shape of a nested list/tuple/dictionary of tensors as a list. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): + The data to send to analyze. + + Returns: + The same data structure as `data` with lists of tensor shapes instead of tensors. + """ + + def _get_shape(tensor): + return list(tensor.shape) + + return recursively_apply(_get_shape, data) + + +def initialize_tensors(data_structure): + """ + Recursively initializes tensors from a nested list/tuple/dictionary of [`~utils.TensorInformation`]. + + Returns: + The same data structure as `data` with tensors instead of [`~utils.TensorInformation`]. 
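+
+ Example (a small sketch; the shape is arbitrary):
+
+ info = get_data_structure({"x": torch.zeros(2, 3)}) # {"x": TensorInformation(shape=torch.Size([2, 3]), dtype=torch.float32)}
+ empty = initialize_tensors(info) # {"x": uninitialized tensor of shape (2, 3)}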
+ """ + + def _initialize_tensor(tensor_info): + return torch.empty(*tensor_info.shape, dtype=tensor_info.dtype) + + return recursively_apply(_initialize_tensor, data_structure, test_type=is_tensor_information) + + +def find_batch_size(data): + """ + Recursively finds the batch size in a nested list/tuple/dictionary of lists of tensors. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size. + + Returns: + `int`: The batch size. + """ + if isinstance(data, (tuple, list, Mapping)) and (len(data) == 0): + raise ValueError(f"Cannot find the batch size from empty {type(data)}.") + + if isinstance(data, (tuple, list)): + return find_batch_size(data[0]) + elif isinstance(data, Mapping): + for k in data.keys(): + return find_batch_size(data[k]) + elif not isinstance(data, torch.Tensor): + raise TypeError(f"Can only find the batch size of tensors but got {type(data)}.") + return data.shape[0] + + +def ignorant_find_batch_size(data): + """ + Same as [`utils.operations.find_batch_size`] except will ignore if `ValueError` and `TypeErrors` are raised + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to find the batch size. + + Returns: + `int`: The batch size. + """ + try: + return find_batch_size(data) + except (ValueError, TypeError): + pass + return None + + +def listify(data): + """ + Recursively finds tensors in a nested list/tuple/dictionary and converts them to a list of numbers. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): The data from which to convert to regular numbers. + + Returns: + The same data structure as `data` with lists of numbers instead of `torch.Tensor`. + """ + + def _convert_to_list(tensor): + tensor = tensor.detach().cpu() + if tensor.dtype == torch.bfloat16: + # As of Numpy 1.21.4, NumPy does not support bfloat16 (see + # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ). + # Until Numpy adds bfloat16, we must convert float32. 
+ tensor = tensor.to(torch.float32) + return tensor.tolist() + + return recursively_apply(_convert_to_list, data) + + +def _tpu_gather(tensor): + def _tpu_gather_one(tensor): + if tensor.ndim == 0: + tensor = tensor.clone()[None] + + # Can only gather contiguous tensors + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + return xm.all_gather(tensor) + + res = recursively_apply(_tpu_gather_one, tensor, error_on_other_type=True) + xm.mark_step() + return res + + +def _gpu_gather(tensor): + state = PartialState() + if is_torch_version(">=", "1.13"): + gather_op = torch.distributed.all_gather_into_tensor + else: + gather_op = torch.distributed._all_gather_base + + def _gpu_gather_one(tensor): + if tensor.ndim == 0: + tensor = tensor.clone()[None] + + # Can only gather contiguous tensors + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + + if state.backend is not None and state.backend != "gloo": + # We use `empty` as `all_gather_into_tensor` slightly + # differs from `all_gather` for better efficiency, + # and we rely on the number of items in the tensor + # rather than its direct shape + output_tensors = torch.empty( + state.num_processes * tensor.numel(), + dtype=tensor.dtype, + device=state.device, + ) + gather_op(output_tensors, tensor) + return output_tensors.view(-1, *tensor.size()[1:]) + else: + # a backend of `None` is always CPU + # also gloo does not support `all_gather_into_tensor`, + # which will result in a larger memory overhead for the op + output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)] + torch.distributed.all_gather(output_tensors, tensor) + return torch.cat(output_tensors, dim=0) + + return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True) + + +class DistributedOperationException(Exception): + """ + An exception class for distributed operations. Raised if the operation cannot be performed due to the shape of the + tensors. + """ + + pass + + +def verify_operation(function): + """ + Verifies that `tensor` is the same shape across all processes. Only ran if `PartialState().debug` is `True`. + """ + + @wraps(function) + def wrapper(*args, **kwargs): + if PartialState().distributed_type == DistributedType.NO or not PartialState().debug: + return function(*args, **kwargs) + operation = f"{function.__module__}.{function.__name__}" + if "tensor" in kwargs: + tensor = kwargs["tensor"] + else: + tensor = args[0] + if PartialState().device.type != find_device(tensor).type: + raise DistributedOperationException( + f"One or more of the tensors passed to {operation} were not on the {tensor.device.type} while the `Accelerator` is configured for {PartialState().device.type}. " + f"Please move it to the {PartialState().device.type} before calling {operation}." + ) + shapes = get_shape(tensor) + output = gather_object([shapes]) + if output[0] is not None: + are_same = output.count(output[0]) == len(output) + if not are_same: + process_shape_str = "\n - ".join([f"Process {i}: {shape}" for i, shape in enumerate(output)]) + raise DistributedOperationException( + f"Cannot apply desired operation due to shape mismatches. " + "All shapes across devices must be valid." + f"\n\nOperation: `{operation}`\nInput shapes:\n - {process_shape_str}" + ) + return function(*args, **kwargs) + + return wrapper + + +def chained_operation(function): + """ + Checks that `verify_operation` failed and if so reports a more helpful error chaining the existing + `DistributedOperationException`. 
+ """ + + @wraps(function) + def wrapper(*args, **kwargs): + try: + return function(*args, **kwargs) + except DistributedOperationException as e: + operation = f"{function.__module__}.{function.__name__}" + raise DistributedOperationException( + f"Error found while calling `{operation}`. Please see the earlier error for more details." + ) from e + + return wrapper + + +@verify_operation +def gather(tensor): + """ + Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to gather. + + Returns: + The same data structure as `tensor` with all tensors sent to the proper device. + """ + if PartialState().distributed_type == DistributedType.XLA: + return _tpu_gather(tensor) + elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES: + return _gpu_gather(tensor) + else: + return tensor + + +def _gpu_gather_object(object: Any): + output_objects = [None for _ in range(PartialState().num_processes)] + torch.distributed.all_gather_object(output_objects, object) + # all_gather_object returns a list of lists, so we need to flatten it + return [x for y in output_objects for x in y] + + +def gather_object(object: Any): + """ + Recursively gather object in a nested list/tuple/dictionary of objects from all devices. + + Args: + object (nested list/tuple/dictionary of picklable object): + The data to gather. + + Returns: + The same data structure as `object` with all the objects sent to every device. + """ + if PartialState().distributed_type == DistributedType.XLA: + raise NotImplementedError("gather objects in TPU is not supported") + elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES: + return _gpu_gather_object(object) + else: + return object + + +def _gpu_broadcast(data, src=0): + def _gpu_broadcast_one(tensor, src=0): + torch.distributed.broadcast(tensor, src=src) + return tensor + + return recursively_apply(_gpu_broadcast_one, data, error_on_other_type=True, src=src) + + +def _tpu_broadcast(tensor, src=0, name="broadcast tensor"): + if isinstance(tensor, (list, tuple)): + return honor_type(tensor, (_tpu_broadcast(t, name=f"{name}_{i}") for i, t in enumerate(tensor))) + elif isinstance(tensor, Mapping): + return type(tensor)({k: _tpu_broadcast(v, name=f"{name}_{k}") for k, v in tensor.items()}) + return xm.mesh_reduce(name, tensor, lambda x: x[src]) + + +TENSOR_TYPE_TO_INT = { + torch.float: 1, + torch.double: 2, + torch.half: 3, + torch.bfloat16: 4, + torch.uint8: 5, + torch.int8: 6, + torch.int16: 7, + torch.int32: 8, + torch.int64: 9, + torch.bool: 10, +} + +TENSOR_INT_TO_DTYPE = {v: k for k, v in TENSOR_TYPE_TO_INT.items()} + + +def gather_tensor_shape(tensor): + """ + Grabs the shape of `tensor` only available on one process and returns a tensor of its shape + """ + # Allocate 80 bytes to store the shape + max_tensor_dimension = 2**20 + state = PartialState() + base_tensor = torch.empty(max_tensor_dimension, dtype=torch.int, device=state.device) + + # Since PyTorch can't just send a tensor to another GPU without + # knowing its size, we store the size of the tensor with data + # in an allocation + if tensor is not None: + shape = tensor.shape + tensor_dtype = TENSOR_TYPE_TO_INT[tensor.dtype] + base_tensor[: len(shape) + 1] = torch.tensor(list(shape) + [tensor_dtype], dtype=int) + # Perform a reduction to copy the size data onto all GPUs + base_tensor = reduce(base_tensor, reduction="sum") + base_tensor = base_tensor[base_tensor.nonzero()] + # The 
last non-zero data contains the coded dtype the source tensor is + dtype = int(base_tensor[-1:][0]) + base_tensor = base_tensor[:-1] + return base_tensor, dtype + + +def copy_tensor_to_devices(tensor=None) -> torch.Tensor: + """ + Copys a tensor that only exists on a single device and broadcasts it to other devices. Differs from `broadcast` as + each worker doesn't need to know its shape when used (and tensor can be `None`) + + Args: + tensor (`torch.tensor`): + The tensor that should be sent to all devices. Must only have it be defined on a single device, the rest + should be `None`. + """ + state = PartialState() + shape, dtype = gather_tensor_shape(tensor) + if tensor is None: + tensor = torch.zeros(shape, dtype=TENSOR_INT_TO_DTYPE[dtype]).to(state.device) + return reduce(tensor, reduction="sum") + + +@verify_operation +def broadcast(tensor, from_process: int = 0): + """ + Recursively broadcast tensor in a nested list/tuple/dictionary of tensors to all devices. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to gather. + from_process (`int`, *optional*, defaults to 0): + The process from which to send the data + + Returns: + The same data structure as `tensor` with all tensors broadcasted to the proper device. + """ + if PartialState().distributed_type == DistributedType.XLA: + return _tpu_broadcast(tensor, src=from_process, name="accelerate.utils.broadcast") + elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES: + return _gpu_broadcast(tensor, src=from_process) + else: + return tensor + + +def broadcast_object_list(object_list, from_process: int = 0): + """ + Broadcast a list of picklable objects form one process to the others. + + Args: + object_list (list of picklable objects): + The list of objects to broadcast. This list will be modified inplace. + from_process (`int`, *optional*, defaults to 0): + The process from which to send the data. + + Returns: + The same list containing the objects from process 0. + """ + if PartialState().distributed_type == DistributedType.XLA: + for i, obj in enumerate(object_list): + object_list[i] = xm.mesh_reduce("accelerate.utils.broadcast_object_list", obj, lambda x: x[from_process]) + elif PartialState().distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES: + torch.distributed.broadcast_object_list(object_list, src=from_process) + return object_list + + +def slice_tensors(data, tensor_slice, process_index=None, num_processes=None): + """ + Recursively takes a slice in a nested list/tuple/dictionary of tensors. + + Args: + data (nested list/tuple/dictionary of `torch.Tensor`): + The data to slice. + tensor_slice (`slice`): + The slice to take. + + Returns: + The same data structure as `data` with all the tensors slices. + """ + + def _slice_tensor(tensor, tensor_slice): + return tensor[tensor_slice] + + return recursively_apply(_slice_tensor, data, tensor_slice) + + +def concatenate(data, dim=0): + """ + Recursively concatenate the tensors in a nested list/tuple/dictionary of lists of tensors with the same shape. + + Args: + data (nested list/tuple/dictionary of lists of tensors `torch.Tensor`): + The data to concatenate. + dim (`int`, *optional*, defaults to 0): + The dimension on which to concatenate. + + Returns: + The same data structure as `data` with all the tensors concatenated. 
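To make the collective helpers above concrete, a minimal sketch assuming a script started with `accelerate launch` on one or more processes (the tensor contents and the `rank-…` strings are illustrative, not part of the library):

```python
import torch

from accelerate import PartialState
from accelerate.utils import broadcast, broadcast_object_list, gather, gather_object

state = PartialState()

# Every rank contributes a (2, 3) tensor; all ranks receive the (2 * num_processes, 3) result.
metrics = torch.full((2, 3), float(state.process_index), device=state.device)
all_metrics = gather(metrics)

# Arbitrary picklable objects can be gathered too (not supported on XLA).
all_names = gather_object([f"rank-{state.process_index}"])

# broadcast / broadcast_object_list keep only what `from_process` holds and copy it everywhere.
step = broadcast(torch.tensor([state.process_index], device=state.device), from_process=0)
config = broadcast_object_list([{"seed": 42 if state.is_main_process else None}], from_process=0)[0]
```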
+ """ + if isinstance(data[0], (tuple, list)): + return honor_type(data[0], (concatenate([d[i] for d in data], dim=dim) for i in range(len(data[0])))) + elif isinstance(data[0], Mapping): + return type(data[0])({k: concatenate([d[k] for d in data], dim=dim) for k in data[0].keys()}) + elif not isinstance(data[0], torch.Tensor): + raise TypeError(f"Can only concatenate tensors but got {type(data[0])}") + return torch.cat(data, dim=dim) + + +class CannotPadNestedTensorWarning(UserWarning): + pass + + +@chained_operation +def pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False): + """ + Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so they + can safely be gathered. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to gather. + dim (`int`, *optional*, defaults to 0): + The dimension on which to pad. + pad_index (`int`, *optional*, defaults to 0): + The value with which to pad. + pad_first (`bool`, *optional*, defaults to `False`): + Whether to pad at the beginning or the end. + """ + + def _pad_across_processes(tensor, dim=0, pad_index=0, pad_first=False): + if getattr(tensor, "is_nested", False): + warnings.warn( + "Cannot pad nested tensors without more information. Leaving unprocessed.", + CannotPadNestedTensorWarning, + ) + return tensor + if dim >= len(tensor.shape): + return tensor + + # Gather all sizes + size = torch.tensor(tensor.shape, device=tensor.device)[None] + sizes = gather(size).cpu() + # Then pad to the maximum size + max_size = max(s[dim] for s in sizes) + if max_size == tensor.shape[dim]: + return tensor + + old_size = tensor.shape + new_size = list(old_size) + new_size[dim] = max_size + new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index + if pad_first: + indices = tuple( + slice(max_size - old_size[dim], max_size) if i == dim else slice(None) for i in range(len(new_size)) + ) + else: + indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size))) + new_tensor[indices] = tensor + return new_tensor + + return recursively_apply( + _pad_across_processes, tensor, error_on_other_type=True, dim=dim, pad_index=pad_index, pad_first=pad_first + ) + + +def pad_input_tensors(tensor, batch_size, num_processes, dim=0): + """ + Takes a `tensor` of arbitrary size and pads it so that it can work given `num_processes` needed dimensions. + + New tensors are just the last input repeated. 
+ + E.g.: + Tensor: ([3,4,4]) Num processes: 4 Expected result shape: ([4,4,4]) + + """ + + def _pad_input_tensors(tensor, batch_size, num_processes, dim=0): + remainder = batch_size // num_processes + last_inputs = batch_size - (remainder * num_processes) + if batch_size // num_processes == 0: + to_pad = num_processes - batch_size + else: + to_pad = num_processes - (batch_size // num_processes) + # In the rare case that `to_pad` is negative, + # we need to pad the last inputs - the found `to_pad` + if last_inputs > to_pad & to_pad < 1: + to_pad = last_inputs - to_pad + old_size = tensor.shape + new_size = list(old_size) + new_size[0] = batch_size + to_pad + new_tensor = tensor.new_zeros(tuple(new_size)) + indices = tuple(slice(0, old_size[dim]) if i == dim else slice(None) for i in range(len(new_size))) + new_tensor[indices] = tensor + return new_tensor + + return recursively_apply( + _pad_input_tensors, + tensor, + error_on_other_type=True, + batch_size=batch_size, + num_processes=num_processes, + dim=dim, + ) + + +@verify_operation +def reduce(tensor, reduction="mean", scale=1.0): + """ + Recursively reduce the tensors in a nested list/tuple/dictionary of lists of tensors across all processes by the + mean of a given operation. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to reduce. + reduction (`str`, *optional*, defaults to `"mean"`): + A reduction method. Can be of "mean", "sum", or "none" + scale (`float`, *optional*): + A default scaling value to be applied after the reduce, only valied on XLA. + + Returns: + The same data structure as `data` with all the tensors reduced. + """ + + def _reduce_across_processes(tensor, reduction="mean", scale=1.0): + state = PartialState() + cloned_tensor = tensor.clone() + if state.distributed_type == DistributedType.NO: + return cloned_tensor + if state.distributed_type == DistributedType.XLA: + # Some processes may have different HLO graphs than other + # processes, for example in the breakpoint API + # accelerator.set_trigger(). Use mark_step to make HLOs + # the same on all processes. + xm.mark_step() + xm.all_reduce(xm.REDUCE_SUM, [cloned_tensor], scale) + xm.mark_step() + elif state.distributed_type.value in TORCH_DISTRIBUTED_OPERATION_TYPES: + torch.distributed.all_reduce(cloned_tensor, ReduceOp.SUM) + if reduction == "mean": + cloned_tensor /= state.num_processes + return cloned_tensor + + return recursively_apply( + _reduce_across_processes, tensor, error_on_other_type=True, reduction=reduction, scale=scale + ) + + +def convert_to_fp32(tensor): + """ + Recursively converts the elements nested list/tuple/dictionary of tensors in FP16/BF16 precision to FP32. + + Args: + tensor (nested list/tuple/dictionary of `torch.Tensor`): + The data to convert from FP16/BF16 to FP32. + + Returns: + The same data structure as `tensor` with all tensors that were in FP16/BF16 precision converted to FP32. + """ + + def _convert_to_fp32(tensor): + return tensor.float() + + def _is_fp16_bf16_tensor(tensor): + return (is_torch_tensor(tensor) or hasattr(tensor, "dtype")) and tensor.dtype in ( + torch.float16, + torch.bfloat16, + ) + + return recursively_apply(_convert_to_fp32, tensor, test_type=_is_fp16_bf16_tensor) + + +class ConvertOutputsToFp32: + """ + Decorator to apply to a function outputing tensors (like a model forward pass) that ensures the outputs in FP16 + precision will be convert back to FP32. + + Args: + model_forward (`Callable`): + The function which outputs we want to treat. 
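And a short sketch of `reduce` and `convert_to_fp32` for metric handling (the `loss` value is synthetic; runs under `accelerate launch` or on a single process):

```python
import torch

from accelerate import PartialState
from accelerate.utils import convert_to_fp32, reduce

state = PartialState()

loss = torch.tensor(float(state.process_index), device=state.device)
avg_loss = reduce(loss, reduction="mean")  # the same averaged value on every process
total = reduce(loss, reduction="sum")      # summed across processes, not divided

# Half-precision outputs can be promoted back to FP32 before logging or metric math.
probs_fp32 = convert_to_fp32(torch.rand(4).half())
```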
+ + Returns: + The same function as `model_forward` but with converted outputs. + """ + + def __init__(self, model_forward): + self.model_forward = model_forward + update_wrapper(self, model_forward) + + def __call__(self, *args, **kwargs): + return convert_to_fp32(self.model_forward(*args, **kwargs)) + + def __getstate__(self): + raise pickle.PicklingError( + "Cannot pickle a prepared model with automatic mixed precision, please unwrap the model with `Accelerator.unwrap_model(model)` before pickling it." + ) + + +def convert_outputs_to_fp32(model_forward): + model_forward = ConvertOutputsToFp32(model_forward) + + def forward(*args, **kwargs): + return model_forward(*args, **kwargs) + + # To act like a decorator so that it can be popped when doing `extract_model_from_parallel` + forward.__wrapped__ = model_forward + + return forward + + +def find_device(data): + """ + Finds the device on which a nested dict/list/tuple of tensors lies (assuming they are all on the same device). + + Args: + (nested list/tuple/dictionary of `torch.Tensor`): The data we want to know the device of. + """ + if isinstance(data, Mapping): + for obj in data.values(): + device = find_device(obj) + if device is not None: + return device + elif isinstance(data, (tuple, list)): + for obj in data: + device = find_device(obj) + if device is not None: + return device + elif isinstance(data, torch.Tensor): + return data.device diff --git a/llm/Lib/site-packages/accelerate/utils/other.py b/llm/Lib/site-packages/accelerate/utils/other.py new file mode 100644 index 0000000000000000000000000000000000000000..a313d08685be25707109c4973b346cdb0a4af90b --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/other.py @@ -0,0 +1,366 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import collections +import os +import platform +import re +import socket +from contextlib import contextmanager +from functools import partial, reduce +from types import MethodType +from typing import OrderedDict + +import torch +from packaging.version import Version +from safetensors.torch import save_file as safe_save_file + +from ..commands.config.default import write_basic_config # noqa: F401 +from ..logging import get_logger +from ..state import PartialState +from .constants import FSDP_PYTORCH_VERSION +from .dataclasses import DistributedType +from .imports import is_deepspeed_available, is_torch_distributed_available, is_torch_xla_available +from .modeling import id_tensor_storage +from .transformer_engine import convert_model +from .versions import is_torch_version + + +logger = get_logger(__name__) + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +def is_compiled_module(module): + """ + Check whether the module was compiled with torch.compile() + """ + if is_torch_version("<", "2.0.0") or not hasattr(torch, "_dynamo"): + return False + return isinstance(module, torch._dynamo.eval_frame.OptimizedModule) + + +def extract_model_from_parallel(model, keep_fp32_wrapper: bool = True, recursive: bool = False): + """ + Extract a model from its distributed containers. + + Args: + model (`torch.nn.Module`): + The model to extract. + keep_fp32_wrapper (`bool`, *optional*): + Whether to remove mixed precision hooks from the model. + recursive (`bool`, *optional*, defaults to `False`): + Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers + recursively, not just the top-level distributed containers. + + Returns: + `torch.nn.Module`: The extracted model. + """ + options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel) + + is_compiled = is_compiled_module(model) + if is_compiled: + compiled_model = model + model = model._orig_mod + + if is_deepspeed_available(): + from deepspeed import DeepSpeedEngine + + options += (DeepSpeedEngine,) + + if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available(): + from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP + + options += (FSDP,) + + while isinstance(model, options): + model = model.module + + if recursive: + # This is needed in cases such as using FSDPv2 on XLA + def _recursive_unwrap(module): + # Wrapped modules are standardly wrapped as `module`, similar to the cases earlier + # with DDP, DataParallel, DeepSpeed, and FSDP + if hasattr(module, "module"): + unwrapped_module = _recursive_unwrap(module.module) + else: + unwrapped_module = module + # Next unwrap child sublayers recursively + for name, child in unwrapped_module.named_children(): + setattr(unwrapped_module, name, _recursive_unwrap(child)) + return unwrapped_module + + # Start with top-level + model = _recursive_unwrap(model) + + if not keep_fp32_wrapper: + forward = model.forward + original_forward = model.__dict__.pop("_original_forward", None) + if original_forward is not None: + while hasattr(forward, "__wrapped__"): + forward = forward.__wrapped__ + if forward == original_forward: + break + model.forward = MethodType(forward, model) + if getattr(model, "_converted_to_transformer_engine", False): + convert_model(model, to_transformer_engine=False) + + if is_compiled: + compiled_model._orig_mod = model + model = compiled_model + + return model + + +def wait_for_everyone(): + """ + Introduces a blocking point in the 
script, making sure all processes have reached this point before continuing. + + + + Make sure all processes will reach this instruction otherwise one of your processes will hang forever. + + + """ + PartialState().wait_for_everyone() + + +def clean_state_dict_for_safetensors(state_dict: dict): + """ + Cleans the state dictionary from a model and removes tensor aliasing if present. + + Args: + state_dict (`dict`): + The state dictionary from a model + """ + ptrs = collections.defaultdict(list) + # When bnb serialization is used, weights in state dict can be strings + for name, tensor in state_dict.items(): + if not isinstance(tensor, str): + ptrs[id_tensor_storage(tensor)].append(name) + + # These are all pointers of tensors with shared memory + shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} + warn_names = set() + for names in shared_ptrs.values(): + # When not all duplicates have been cleaned, we still remove those keys but put a clear warning. + # If the link between tensors was done at runtime then `from_pretrained` will not get + # the key back leading to random tensor. A proper warning will be shown + # during reload (if applicable), but since the file is not necessarily compatible with + # the config, better show a proper warning. + found_names = [name for name in names if name in state_dict] + warn_names.update(found_names[1:]) + for name in found_names[1:]: + del state_dict[name] + if len(warn_names) > 0: + logger.warning( + f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading", + ) + state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()} + return state_dict + + +def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False): + """ + Save the data to disk. Use in place of `torch.save()`. + + Args: + obj: + The data to save + f: + The file (or file-like object) to use to save the data + save_on_each_node (`bool`, *optional*, defaults to `False`): + Whether to only save on the global main process + safe_serialization (`bool`, *optional*, defaults to `False`): + Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`). + """ + # When TorchXLA is enabled, it's necessary to transfer all data to the CPU before saving. + # Another issue arises with `id_tensor_storage`, which treats all XLA tensors as identical. + # If tensors remain on XLA, calling `clean_state_dict_for_safetensors` will result in only + # one XLA tensor remaining. + if PartialState().distributed_type == DistributedType.XLA: + obj = xm._maybe_convert_to_cpu(obj) + # Check if it's a model and remove duplicates + if safe_serialization: + save_func = partial(safe_save_file, metadata={"format": "pt"}) + if isinstance(obj, OrderedDict): + obj = clean_state_dict_for_safetensors(obj) + else: + save_func = torch.save + + if PartialState().is_main_process and not save_on_each_node: + save_func(obj, f) + elif PartialState().is_local_main_process and save_on_each_node: + save_func(obj, f) + + +@contextmanager +def clear_environment(): + """ + A context manager that will temporarily clear environment variables. + + When this context exits, the previous environment variables will be back. + + Example: + + ```python + >>> import os + >>> from accelerate.utils import clear_environment + + >>> os.environ["FOO"] = "bar" + >>> with clear_environment(): + ... print(os.environ) + ... os.environ["FOO"] = "new_bar" + ... 
print(os.environ["FOO"]) + {} + new_bar + + >>> print(os.environ["FOO"]) + bar + ``` + """ + _old_os_environ = os.environ.copy() + os.environ.clear() + + try: + yield + finally: + os.environ.clear() # clear any added keys, + os.environ.update(_old_os_environ) # then restore previous environment + + +@contextmanager +def patch_environment(**kwargs): + """ + A context manager that will add each keyword argument passed to `os.environ` and remove them when exiting. + + Will convert the values in `kwargs` to strings and upper-case all the keys. + + Example: + + ```python + >>> import os + >>> from accelerate.utils import patch_environment + + >>> with patch_environment(FOO="bar"): + ... print(os.environ["FOO"]) # prints "bar" + >>> print(os.environ["FOO"]) # raises KeyError + ``` + """ + existing_vars = {} + for key, value in kwargs.items(): + key = key.upper() + if key in os.environ: + existing_vars[key] = os.environ[key] + os.environ[key] = str(value) + + try: + yield + finally: + for key in kwargs: + key = key.upper() + if key in existing_vars: + # restore previous value + os.environ[key] = existing_vars[key] + else: + os.environ.pop(key, None) + + +def get_pretty_name(obj): + """ + Gets a pretty name from `obj`. + """ + if not hasattr(obj, "__qualname__") and not hasattr(obj, "__name__"): + obj = getattr(obj, "__class__", obj) + if hasattr(obj, "__qualname__"): + return obj.__qualname__ + if hasattr(obj, "__name__"): + return obj.__name__ + return str(obj) + + +def merge_dicts(source, destination): + """ + Recursively merges two dictionaries. + + Args: + source (`dict`): The dictionary to merge into `destination`. + destination (`dict`): The dictionary to merge `source` into. + """ + for key, value in source.items(): + if isinstance(value, dict): + node = destination.setdefault(key, {}) + merge_dicts(value, node) + else: + destination[key] = value + + return destination + + +def is_port_in_use(port: int = None) -> bool: + """ + Checks if a port is in use on `localhost`. Useful for checking if multiple `accelerate launch` commands have been + run and need to see if the port is already in use. + """ + if port is None: + port = 29500 + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + +def convert_bytes(size): + "Converts `size` from bytes to the largest possible unit" + for x in ["bytes", "KB", "MB", "GB", "TB"]: + if size < 1024.0: + return f"{round(size, 2)} {x}" + size /= 1024.0 + + return f"{round(size, 2)} PB" + + +def check_os_kernel(): + """Warns if the kernel version is below the recommended minimum on Linux.""" + # see issue #1929 + info = platform.uname() + system = info.system + if system != "Linux": + return + + _, version, *_ = re.split(r"(\d+\.\d+\.\d+)", info.release) + min_version = "5.5.0" + if Version(version) < Version(min_version): + msg = ( + f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can " + "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher." + ) + logger.warning(msg, main_process_only=True) + + +def recursive_getattr(obj, attr: str): + """ + Recursive `getattr`. + + Args: + obj: + A class instance holding the attribute. + attr (`str`): + The attribute that is to be retrieved, e.g. 'attribute1.attribute2'. 
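A few of the small helpers above in action (imports taken straight from this module; a single process is enough):

```python
import torch.nn as nn

from accelerate.utils.other import convert_bytes, get_pretty_name, merge_dicts, recursive_getattr

convert_bytes(1_500_000)                       # '1.43 MB'
merge_dicts({"a": {"b": 1}}, {"a": {"c": 2}})  # {'a': {'c': 2, 'b': 1}} -- merged into the destination
get_pretty_name(nn.Linear)                     # 'Linear'

model = nn.Sequential(nn.Linear(2, 2))
recursive_getattr(model, "0.weight").shape     # torch.Size([2, 2])
```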
+ """ + + def _getattr(obj, attr): + return getattr(obj, attr) + + return reduce(_getattr, [obj] + attr.split(".")) diff --git a/llm/Lib/site-packages/accelerate/utils/random.py b/llm/Lib/site-packages/accelerate/utils/random.py new file mode 100644 index 0000000000000000000000000000000000000000..f21312289a77bce9143c985292a14185f35f5938 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/random.py @@ -0,0 +1,122 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import random +from typing import List, Optional, Union + +import numpy as np +import torch + +from ..state import AcceleratorState +from .constants import CUDA_DISTRIBUTED_TYPES +from .dataclasses import DistributedType, RNGType +from .imports import is_mlu_available, is_npu_available, is_torch_xla_available, is_xpu_available + + +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + +def set_seed(seed: int, device_specific: bool = False, deterministic: bool = False): + """ + Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`. + + Args: + seed (`int`): + The seed to set. + device_specific (`bool`, *optional*, defaults to `False`): + Whether to differ the seed on each device slightly with `self.process_index`. + deterministic (`bool`, *optional*, defaults to `False`): + Whether to use deterministic algorithms where available. Can slow down training. + """ + if device_specific: + seed += AcceleratorState().process_index + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if is_xpu_available(): + torch.xpu.manual_seed_all(seed) + elif is_npu_available(): + torch.npu.manual_seed_all(seed) + elif is_mlu_available(): + torch.mlu.manual_seed_all(seed) + else: + torch.cuda.manual_seed_all(seed) + # ^^ safe to call this function even if cuda is not available + if is_torch_xla_available(): + xm.set_rng_state(seed) + + if deterministic: + torch.use_deterministic_algorithms(True) + + +def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None): + # Get the proper rng state + if rng_type == RNGType.TORCH: + rng_state = torch.get_rng_state() + elif rng_type == RNGType.CUDA: + rng_state = torch.cuda.get_rng_state() + elif rng_type == RNGType.XLA: + assert is_torch_xla_available(), "Can't synchronize XLA seeds as torch_xla is unavailable." + rng_state = torch.tensor(xm.get_rng_state()) + elif rng_type == RNGType.NPU: + assert is_npu_available(), "Can't synchronize NPU seeds on an environment without NPUs." + rng_state = torch.npu.get_rng_state() + elif rng_type == RNGType.MLU: + assert is_mlu_available(), "Can't synchronize MLU seeds on an environment without MLUs." + rng_state = torch.mlu.get_rng_state() + elif rng_type == RNGType.XPU: + assert is_xpu_available(), "Can't synchronize XPU seeds on an environment without XPUs." 
+ rng_state = torch.xpu.get_rng_state() + elif rng_type == RNGType.GENERATOR: + assert generator is not None, "Need a generator to synchronize its seed." + rng_state = generator.get_state() + + # Broadcast the rng state from device 0 to other devices + state = AcceleratorState() + if state.distributed_type == DistributedType.XLA: + rng_state = rng_state.to(xm.xla_device()) + xm.collective_broadcast([rng_state]) + xm.mark_step() + rng_state = rng_state.cpu() + elif ( + state.distributed_type in CUDA_DISTRIBUTED_TYPES + or state.distributed_type == DistributedType.MULTI_MLU + or state.distributed_type == DistributedType.MULTI_NPU + or state.distributed_type == DistributedType.MULTI_XPU + ): + rng_state = rng_state.to(state.device) + torch.distributed.broadcast(rng_state, 0) + rng_state = rng_state.cpu() + elif state.distributed_type == DistributedType.MULTI_CPU: + torch.distributed.broadcast(rng_state, 0) + + # Set the broadcast rng state + if rng_type == RNGType.TORCH: + torch.set_rng_state(rng_state) + elif rng_type == RNGType.CUDA: + torch.cuda.set_rng_state(rng_state) + elif rng_type == RNGType.NPU: + torch.npu.set_rng_state(rng_state) + elif rng_type == RNGType.XPU: + torch.xpu.set_rng_state(rng_state) + elif rng_type == RNGType.XLA: + xm.set_rng_state(rng_state.item()) + elif rng_type == RNGType.GENERATOR: + generator.set_state(rng_state) + + +def synchronize_rng_states(rng_types: List[Union[str, RNGType]], generator: Optional[torch.Generator] = None): + for rng_type in rng_types: + synchronize_rng_state(RNGType(rng_type), generator=generator) diff --git a/llm/Lib/site-packages/accelerate/utils/rich.py b/llm/Lib/site-packages/accelerate/utils/rich.py new file mode 100644 index 0000000000000000000000000000000000000000..2d48661b7fcef92ef1168b74cc275c6d3ccc67a1 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/rich.py @@ -0,0 +1,24 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .imports import is_rich_available + + +if is_rich_available(): + from rich.traceback import install + + install(show_locals=False) + +else: + raise ModuleNotFoundError("To use the rich extension, install rich with `pip install rich`") diff --git a/llm/Lib/site-packages/accelerate/utils/torch_xla.py b/llm/Lib/site-packages/accelerate/utils/torch_xla.py new file mode 100644 index 0000000000000000000000000000000000000000..140133926c2f88d39c70f5a9f46a08f88bed36da --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/torch_xla.py @@ -0,0 +1,51 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
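For the RNG utilities just added, a minimal usage sketch (assuming the script runs via `accelerate launch` or otherwise sets up an `Accelerator`; `device_specific` only changes anything with more than one process):

```python
from accelerate.utils import set_seed

# One seed for `random`, NumPy and torch (plus XLA when available) ...
set_seed(42)

# ... or offset the seed by the process index so e.g. data augmentation
# differs across ranks while remaining reproducible.
set_seed(42, device_specific=True)
```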
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.metadata +import subprocess +import sys + + +def install_xla(upgrade: bool = False): + """ + Helper function to install appropriate xla wheels based on the `torch` version in Google Colaboratory. + + Args: + upgrade (`bool`, *optional*, defaults to `False`): + Whether to upgrade `torch` and install the latest `torch_xla` wheels. + + Example: + + ```python + >>> from accelerate.utils import install_xla + + >>> install_xla(upgrade=True) + ``` + """ + in_colab = False + if "IPython" in sys.modules: + in_colab = "google.colab" in str(sys.modules["IPython"].get_ipython()) + + if in_colab: + if upgrade: + torch_install_cmd = ["pip", "install", "-U", "torch"] + subprocess.run(torch_install_cmd, check=True) + # get the current version of torch + torch_version = importlib.metadata.version("torch") + torch_version_trunc = torch_version[: torch_version.rindex(".")] + xla_wheel = f"https://storage.googleapis.com/tpu-pytorch/wheels/colab/torch_xla-{torch_version_trunc}-cp37-cp37m-linux_x86_64.whl" + xla_install_cmd = ["pip", "install", xla_wheel] + subprocess.run(xla_install_cmd, check=True) + else: + raise RuntimeError("`install_xla` utility works only on google colab.") diff --git a/llm/Lib/site-packages/accelerate/utils/tqdm.py b/llm/Lib/site-packages/accelerate/utils/tqdm.py new file mode 100644 index 0000000000000000000000000000000000000000..940a8bb04aced0c898ba1926bacc4b60b72d6f54 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/tqdm.py @@ -0,0 +1,37 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .imports import is_tqdm_available + + +if is_tqdm_available(): + from tqdm.auto import tqdm as _tqdm + +from ..state import PartialState + + +def tqdm(main_process_only: bool = True, *args, **kwargs): + """ + Wrapper around `tqdm.tqdm` that optionally displays only on the main process. + + Args: + main_process_only (`bool`, *optional*): + Whether to display the progress bar only on the main process + """ + if not is_tqdm_available(): + raise ImportError("Accelerate's `tqdm` module requires `tqdm` to be installed. 
Please run `pip install tqdm`.") + disable = False + if main_process_only: + disable = PartialState().local_process_index != 0 + return _tqdm(*args, **kwargs, disable=disable) diff --git a/llm/Lib/site-packages/accelerate/utils/transformer_engine.py b/llm/Lib/site-packages/accelerate/utils/transformer_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..211a9f553ca22ac4938969416d07a9b139918b60 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/transformer_engine.py @@ -0,0 +1,84 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn as nn + +from .imports import is_fp8_available + + +if is_fp8_available(): + import transformer_engine.pytorch as te + + +def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True): + """ + Recursively converts the linear and layernorm layers of a model to their `transformers_engine` counterpart. + """ + if not is_fp8_available(): + raise ImportError("Using `convert_model` requires transformer_engine to be installed.") + for name, module in model.named_children(): + if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear: + # Return early if the linear layer weights are not multiples of 16 + if any(p % 16 != 0 for p in module.weight.shape): + return + has_bias = module.bias is not None + te_module = te.Linear( + module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype + ) + te_module.weight.copy_(module.weight) + if has_bias: + te_module.bias.copy_(module.bias) + + setattr(model, name, te_module) + elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln: + te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype) + te_module.weight.copy_(module.weight) + te_module.bias.copy_(module.bias) + + setattr(model, name, te_module) + elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear: + has_bias = module.bias is not None + new_module = nn.Linear( + module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype + ) + new_module.weight.copy_(module.weight) + if has_bias: + new_module.bias.copy_(module.bias) + + setattr(model, name, new_module) + elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln: + new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype) + new_module.weight.copy_(module.weight) + new_module.bias.copy_(module.bias) + + setattr(model, name, new_module) + else: + convert_model( + module, + to_transformer_engine=to_transformer_engine, + _convert_linear=_convert_linear, + _convert_ln=_convert_ln, + ) + + +def has_transformer_engine_layers(model): + """ + Returns whether a given model has some `transformer_engine` layer or not. 
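Note that the `tqdm` wrapper above takes `main_process_only` as its *first* positional argument, so the iterable comes second (or as a keyword). A small usage sketch with a stand-in iterable:

```python
from accelerate.utils.tqdm import tqdm

dataset = range(1_000)  # hypothetical iterable

# Progress bar drawn only on the local main process; other ranks get disable=True.
for sample in tqdm(True, dataset, desc="processing"):
    pass

# Equivalent call using keywords only:
for sample in tqdm(main_process_only=True, iterable=dataset, desc="processing"):
    pass
```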
+ """ + if not is_fp8_available(): + raise ImportError("Using `has_transformer_engine_layers` requires transformer_engine to be installed.") + for m in model.modules(): + if isinstance(m, (te.LayerNorm, te.Linear, te.TransformerLayer)): + return True + return False diff --git a/llm/Lib/site-packages/accelerate/utils/versions.py b/llm/Lib/site-packages/accelerate/utils/versions.py new file mode 100644 index 0000000000000000000000000000000000000000..985c918f0e057bacc70c372f6906071bb73db577 --- /dev/null +++ b/llm/Lib/site-packages/accelerate/utils/versions.py @@ -0,0 +1,56 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib.metadata +from typing import Union + +from packaging.version import Version, parse + +from .constants import STR_OPERATION_TO_FUNC + + +torch_version = parse(importlib.metadata.version("torch")) + + +def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str): + """ + Compares a library version to some requirement using a given operation. + + Args: + library_or_version (`str` or `packaging.version.Version`): + A library name or a version to check. + operation (`str`): + A string representation of an operator, such as `">"` or `"<="`. + requirement_version (`str`): + The version to compare the library version against + """ + if operation not in STR_OPERATION_TO_FUNC.keys(): + raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}") + operation = STR_OPERATION_TO_FUNC[operation] + if isinstance(library_or_version, str): + library_or_version = parse(importlib.metadata.version(library_or_version)) + return operation(library_or_version, parse(requirement_version)) + + +def is_torch_version(operation: str, version: str): + """ + Compares the current PyTorch version to a given reference with an operation. + + Args: + operation (`str`): + A string representation of an operator, such as `">"` or `"<="` + version (`str`): + A string version of PyTorch + """ + return compare_versions(torch_version, operation, version) diff --git a/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/INSTALLER b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/LICENSE.txt b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..aa2bc24043c1dddac90e74eba3ff92b848515898 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/LICENSE.txt @@ -0,0 +1,13 @@ + Copyright aio-libs contributors. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. 
+ You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/METADATA b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..b120f8f97d986a42c0ff7422c57b4fdd47451573 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/METADATA @@ -0,0 +1,245 @@ +Metadata-Version: 2.1 +Name: aiohttp +Version: 3.9.5 +Summary: Async http client/server framework (asyncio) +Home-page: https://github.com/aio-libs/aiohttp +Maintainer: aiohttp team +Maintainer-email: team@aiohttp.org +License: Apache 2 +Project-URL: Chat: Matrix, https://matrix.to/#/#aio-libs:matrix.org +Project-URL: Chat: Matrix Space, https://matrix.to/#/#aio-libs-space:matrix.org +Project-URL: CI: GitHub Actions, https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI +Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/aiohttp +Project-URL: Docs: Changelog, https://docs.aiohttp.org/en/stable/changes.html +Project-URL: Docs: RTD, https://docs.aiohttp.org +Project-URL: GitHub: issues, https://github.com/aio-libs/aiohttp/issues +Project-URL: GitHub: repo, https://github.com/aio-libs/aiohttp +Classifier: Development Status :: 5 - Production/Stable +Classifier: Framework :: AsyncIO +Classifier: Intended Audience :: Developers +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Operating System :: POSIX +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: Microsoft :: Windows +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Topic :: Internet :: WWW/HTTP +Requires-Python: >=3.8 +Description-Content-Type: text/x-rst +License-File: LICENSE.txt +Requires-Dist: aiosignal >=1.1.2 +Requires-Dist: attrs >=17.3.0 +Requires-Dist: frozenlist >=1.1.1 +Requires-Dist: multidict <7.0,>=4.5 +Requires-Dist: yarl <2.0,>=1.0 +Requires-Dist: async-timeout <5.0,>=4.0 ; python_version < "3.11" +Provides-Extra: speedups +Requires-Dist: brotlicffi ; (platform_python_implementation != "CPython") and extra == 'speedups' +Requires-Dist: Brotli ; (platform_python_implementation == "CPython") and extra == 'speedups' +Requires-Dist: aiodns ; (sys_platform == "linux" or sys_platform == "darwin") and extra == 'speedups' + +================================== +Async http client/server framework +================================== + +.. image:: https://raw.githubusercontent.com/aio-libs/aiohttp/master/docs/aiohttp-plain.svg + :height: 64px + :width: 64px + :alt: aiohttp logo + +| + +.. image:: https://github.com/aio-libs/aiohttp/workflows/CI/badge.svg + :target: https://github.com/aio-libs/aiohttp/actions?query=workflow%3ACI + :alt: GitHub Actions status for master branch + +.. 
image:: https://codecov.io/gh/aio-libs/aiohttp/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aio-libs/aiohttp + :alt: codecov.io status for master branch + +.. image:: https://badge.fury.io/py/aiohttp.svg + :target: https://pypi.org/project/aiohttp + :alt: Latest PyPI package version + +.. image:: https://readthedocs.org/projects/aiohttp/badge/?version=latest + :target: https://docs.aiohttp.org/ + :alt: Latest Read The Docs + +.. image:: https://img.shields.io/matrix/aio-libs:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs:matrix.org + :alt: Matrix Room — #aio-libs:matrix.org + +.. image:: https://img.shields.io/matrix/aio-libs-space:matrix.org?label=Discuss%20on%20Matrix%20at%20%23aio-libs-space%3Amatrix.org&logo=matrix&server_fqdn=matrix.org&style=flat + :target: https://matrix.to/#/%23aio-libs-space:matrix.org + :alt: Matrix Space — #aio-libs-space:matrix.org + + +Key Features +============ + +- Supports both client and server side of HTTP protocol. +- Supports both client and server Web-Sockets out-of-the-box and avoids + Callback Hell. +- Provides Web-server with middleware and pluggable routing. + + +Getting started +=============== + +Client +------ + +To get something from the web: + +.. code-block:: python + + import aiohttp + import asyncio + + async def main(): + + async with aiohttp.ClientSession() as session: + async with session.get('http://python.org') as response: + + print("Status:", response.status) + print("Content-type:", response.headers['content-type']) + + html = await response.text() + print("Body:", html[:15], "...") + + asyncio.run(main()) + +This prints: + +.. code-block:: + + Status: 200 + Content-type: text/html; charset=utf-8 + Body: ... + +Coming from `requests `_ ? Read `why we need so many lines `_. + +Server +------ + +An example using a simple server: + +.. code-block:: python + + # examples/server_simple.py + from aiohttp import web + + async def handle(request): + name = request.match_info.get('name', "Anonymous") + text = "Hello, " + name + return web.Response(text=text) + + async def wshandle(request): + ws = web.WebSocketResponse() + await ws.prepare(request) + + async for msg in ws: + if msg.type == web.WSMsgType.text: + await ws.send_str("Hello, {}".format(msg.data)) + elif msg.type == web.WSMsgType.binary: + await ws.send_bytes(msg.data) + elif msg.type == web.WSMsgType.close: + break + + return ws + + + app = web.Application() + app.add_routes([web.get('/', handle), + web.get('/echo', wshandle), + web.get('/{name}', handle)]) + + if __name__ == '__main__': + web.run_app(app) + + +Documentation +============= + +https://aiohttp.readthedocs.io/ + + +Demos +===== + +https://github.com/aio-libs/aiohttp-demos + + +External links +============== + +* `Third party libraries + `_ +* `Built with aiohttp + `_ +* `Powered by aiohttp + `_ + +Feel free to make a Pull Request for adding your link to these pages! + + +Communication channels +====================== + +*aio-libs Discussions*: https://github.com/aio-libs/aiohttp/discussions + +*gitter chat* https://gitter.im/aio-libs/Lobby + +We support `Stack Overflow +`_. +Please add *aiohttp* tag to your question there. + +Requirements +============ + +- async-timeout_ +- attrs_ +- multidict_ +- yarl_ +- frozenlist_ + +Optionally you may install the aiodns_ library (highly recommended for sake of speed). + +.. _aiodns: https://pypi.python.org/pypi/aiodns +.. 
_attrs: https://github.com/python-attrs/attrs +.. _multidict: https://pypi.python.org/pypi/multidict +.. _frozenlist: https://pypi.org/project/frozenlist/ +.. _yarl: https://pypi.python.org/pypi/yarl +.. _async-timeout: https://pypi.python.org/pypi/async_timeout + +License +======= + +``aiohttp`` is offered under the Apache 2 license. + + +Keepsafe +======== + +The aiohttp community would like to thank Keepsafe +(https://www.getkeepsafe.com) for its support in the early days of +the project. + + +Source code +=========== + +The latest developer version is available in a GitHub repository: +https://github.com/aio-libs/aiohttp + +Benchmarks +========== + +If you are interested in efficiency, the AsyncIO community maintains a +list of benchmarks on the official wiki: +https://github.com/python/asyncio/wiki/Benchmarks diff --git a/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/RECORD b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..54c1f40b4208f0f7a6ddd80da22ebabe8e1fe075 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/RECORD @@ -0,0 +1,119 @@ +aiohttp-3.9.5.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +aiohttp-3.9.5.dist-info/LICENSE.txt,sha256=wUk-nxDVnR-6n53ygAjhVX4zz5-6yM4SY6ozk5goA94,601 +aiohttp-3.9.5.dist-info/METADATA,sha256=kKiQL2qSh9OPRTR0p18PDA_eIr3RVDY5GCUc6EwN9Ig,7704 +aiohttp-3.9.5.dist-info/RECORD,, +aiohttp-3.9.5.dist-info/WHEEL,sha256=nSybvzWlmdJnHiUQSY-d7V1ycwEVUTqXiTvr2eshg44,102 +aiohttp-3.9.5.dist-info/top_level.txt,sha256=iv-JIaacmTl-hSho3QmphcKnbRRYx1st47yjz_178Ro,8 +aiohttp/.hash/_cparser.pxd.hash,sha256=dVGMrCmyJM_owqoRLPezK095md0X5R319koTuhUN6DQ,64 +aiohttp/.hash/_find_header.pxd.hash,sha256=W5qRPWDc55gArGZkriI5tztmQHkrdwR6NdQfRQfTxIg,64 +aiohttp/.hash/_helpers.pyi.hash,sha256=bAsxbXsjcZ5gbj1c561GYcRtQ5REXxrCihR-HN0XKPk,64 +aiohttp/.hash/_helpers.pyx.hash,sha256=-DfrN0XUqBhyb8bp2fJQVb1Lo9S1S-psob-7MJBM18c,64 +aiohttp/.hash/_http_parser.pyx.hash,sha256=oAd3pEXdaK_NpQ_aHW-X0zqg5DXUxWP0cWVBGdFu814,64 +aiohttp/.hash/_http_writer.pyx.hash,sha256=z39c0hUcdud-ZCon2d9bWpxrFMVdW1dvjtCgxW4RDnI,64 +aiohttp/.hash/_websocket.pyx.hash,sha256=90x5ulhWiFtw2wAri2_82Zas5i3iEkJ-flYJK9Xx-SY,64 +aiohttp/.hash/hdrs.py.hash,sha256=QBHPUkJcp8iPZv3ENUbevgpJzljxoP2qwkBeX3nQ82o,64 +aiohttp/__init__.py,sha256=iyyddK6RtQIJpHQL71YaQ9-ti3S5HUeeaTpIaLlfKVs,8002 +aiohttp/__pycache__/__init__.cpython-311.pyc,, +aiohttp/__pycache__/abc.cpython-311.pyc,, +aiohttp/__pycache__/base_protocol.cpython-311.pyc,, +aiohttp/__pycache__/client.cpython-311.pyc,, +aiohttp/__pycache__/client_exceptions.cpython-311.pyc,, +aiohttp/__pycache__/client_proto.cpython-311.pyc,, +aiohttp/__pycache__/client_reqrep.cpython-311.pyc,, +aiohttp/__pycache__/client_ws.cpython-311.pyc,, +aiohttp/__pycache__/compression_utils.cpython-311.pyc,, +aiohttp/__pycache__/connector.cpython-311.pyc,, +aiohttp/__pycache__/cookiejar.cpython-311.pyc,, +aiohttp/__pycache__/formdata.cpython-311.pyc,, +aiohttp/__pycache__/hdrs.cpython-311.pyc,, +aiohttp/__pycache__/helpers.cpython-311.pyc,, +aiohttp/__pycache__/http.cpython-311.pyc,, +aiohttp/__pycache__/http_exceptions.cpython-311.pyc,, +aiohttp/__pycache__/http_parser.cpython-311.pyc,, +aiohttp/__pycache__/http_websocket.cpython-311.pyc,, +aiohttp/__pycache__/http_writer.cpython-311.pyc,, +aiohttp/__pycache__/locks.cpython-311.pyc,, +aiohttp/__pycache__/log.cpython-311.pyc,, +aiohttp/__pycache__/multipart.cpython-311.pyc,, +aiohttp/__pycache__/payload.cpython-311.pyc,, 
+aiohttp/__pycache__/payload_streamer.cpython-311.pyc,, +aiohttp/__pycache__/pytest_plugin.cpython-311.pyc,, +aiohttp/__pycache__/resolver.cpython-311.pyc,, +aiohttp/__pycache__/streams.cpython-311.pyc,, +aiohttp/__pycache__/tcp_helpers.cpython-311.pyc,, +aiohttp/__pycache__/test_utils.cpython-311.pyc,, +aiohttp/__pycache__/tracing.cpython-311.pyc,, +aiohttp/__pycache__/typedefs.cpython-311.pyc,, +aiohttp/__pycache__/web.cpython-311.pyc,, +aiohttp/__pycache__/web_app.cpython-311.pyc,, +aiohttp/__pycache__/web_exceptions.cpython-311.pyc,, +aiohttp/__pycache__/web_fileresponse.cpython-311.pyc,, +aiohttp/__pycache__/web_log.cpython-311.pyc,, +aiohttp/__pycache__/web_middlewares.cpython-311.pyc,, +aiohttp/__pycache__/web_protocol.cpython-311.pyc,, +aiohttp/__pycache__/web_request.cpython-311.pyc,, +aiohttp/__pycache__/web_response.cpython-311.pyc,, +aiohttp/__pycache__/web_routedef.cpython-311.pyc,, +aiohttp/__pycache__/web_runner.cpython-311.pyc,, +aiohttp/__pycache__/web_server.cpython-311.pyc,, +aiohttp/__pycache__/web_urldispatcher.cpython-311.pyc,, +aiohttp/__pycache__/web_ws.cpython-311.pyc,, +aiohttp/__pycache__/worker.cpython-311.pyc,, +aiohttp/_cparser.pxd,sha256=W6-cu0SyHhOEPeb475NvxagQ1Jz9pWqyZJvwEqTLNs0,4476 +aiohttp/_find_header.pxd,sha256=BFUSmxhemBtblqxzjzH3x03FfxaWlTyuAIOz8YZ5_nM,70 +aiohttp/_headers.pxi,sha256=1MhCe6Un_KI1tpO85HnDfzVO94BhcirLanAOys5FIHA,2090 +aiohttp/_helpers.cp311-win_amd64.pyd,sha256=3R9fX40_j4Qp6P0DGVp3708xDQp6TnupZVP1NO8d-30,54272 +aiohttp/_helpers.pyi,sha256=2Hd5IC0Zf4YTEJ412suyyhsh1kVyVDv5g4stgyo2Ksc,208 +aiohttp/_helpers.pyx,sha256=tgl7fZh0QMT6cjf4jSJ8iaO6DdQD3GON2-SH4N5_ETg,1084 +aiohttp/_http_parser.cp311-win_amd64.pyd,sha256=CkSfUFHRcy_rS4wjSOdQR7vDjOmeb1sacMsk_M5Q7Uc,263168 +aiohttp/_http_parser.pyx,sha256=SAenRA9RzY5AaVWWns_s9DMbZlVHgKlEyLgPQvUcTb0,28963 +aiohttp/_http_writer.cp311-win_amd64.pyd,sha256=FkrGrf2qzPJRUm3Ir2ra6_zwR0bJxSRjTlmv71Oh-Cs,49152 +aiohttp/_http_writer.pyx,sha256=8CBLytO2rx1kdpWe9HYSznhLXdeZWyE-3xI7jaGasag,4738 +aiohttp/_websocket.cp311-win_amd64.pyd,sha256=uF98m_IwYr6Ne553zVRBb9t2jOr7EUwc2xn4s0mpN3w,36352 +aiohttp/_websocket.pyx,sha256=o9J7yi9c2-jTBjE3dUkXxhDWKvRWJz5GZfyLsgJQa38,1617 +aiohttp/abc.py,sha256=wpbkcMsLWB_r_sD35PesqbWFY3tZOvilOjqSqHDUaMQ,5709 +aiohttp/base_protocol.py,sha256=N4IzuMQLGHG7AJji9um-2gqQxKGo41P3aCZwrvegumc,2971 +aiohttp/client.py,sha256=VSQeFcR4eLEgXNRug-bWozUDC6bJNvTbz6DR39o0tQc,48865 +aiohttp/client_exceptions.py,sha256=BiReSs5jdjdmhB99vYNBcypsuRfyQU2UQXG0cVRAD2A,9757 +aiohttp/client_proto.py,sha256=9tmgb3DbBUQC8D9Vv-uElsKH28XsE_C8C6SD30rGBb0,10206 +aiohttp/client_reqrep.py,sha256=DMwNd2p84R9kHadxiEppU9PvefSSqhu9eEOScRPjlL4,41282 +aiohttp/client_ws.py,sha256=mi8iVYQR25Hi20AQing6T1BZBcO24NsQIKubhsR8izM,11325 +aiohttp/compression_utils.py,sha256=OIQOhFq_YssPx-SNCn9UkQTqntHhZKZakhwUQfaJSyA,5172 +aiohttp/connector.py,sha256=bwu10B9Bdgm59-Aghyjy4LRC4aHzL03F1K6dZy-BLuU,55307 +aiohttp/cookiejar.py,sha256=vgjrRISdZ5jwGfJC6T9QLuW-f56KfpGrtnrcbsUHkU0,14434 +aiohttp/formdata.py,sha256=XRu5kZi8MqEMoct2KGBrVetaHVthIRI__0o3UQfctzc,6703 +aiohttp/hdrs.py,sha256=_JN4MBE-UoBXGWGoSCKhIviTRc2IXS4fyk5nnuox0Ak,4721 +aiohttp/helpers.py,sha256=Zx9sbxHNJGVOyTzjBEcZ6kOZwn_6g8hD5cej4Jb9Yx0,32017 +aiohttp/http.py,sha256=DGKcwDbgIMpasv7s2jeKCRuixyj7W-RIrihRFjj0xcY,1914 +aiohttp/http_exceptions.py,sha256=GJYn38j4sI4KdUh993VnZlbgVHOUNI_Z9-ASDTjl5aU,2822 +aiohttp/http_parser.py,sha256=V-CWgfz6PI1QRM_7TjT8O3SNHTBi3lGzDCiQHLl7n74,37548 +aiohttp/http_websocket.py,sha256=eM8wkn12fCF2vV-Zm7ITxQsG7hOdRHeD8PF6NPmJpsQ,27472 
+aiohttp/http_writer.py,sha256=p8H39HhtilQEE90njvtJHc94Am95zjHNoS8T1JcNXJc,6131 +aiohttp/locks.py,sha256=vp1Z4zx0SvooSffw88dkZ-7qpk2CqRf5vWh2dpKagTA,1177 +aiohttp/log.py,sha256=zYUTvXsMQ9Sz1yNN8kXwd5Qxu49a1FzjZ_wQqriEc8M,333 +aiohttp/multipart.py,sha256=N3WiFX-46QvCijk3KtpHCJnuJRGWiHYnLJDiQ6Mf4QA,35952 +aiohttp/payload.py,sha256=Ap0E4_p1d9E2UTTIYmBe5HXz05Tt7TN5nZ2zfYDtkr0,14005 +aiohttp/payload_streamer.py,sha256=rBb3jAFcwAK1QOgbhya2y4zGjhT11oQrepdcffA1_jM,2162 +aiohttp/py.typed,sha256=3VVwXUAWVEVX7sDwyYDnW5ZdBC9_Z9AJAFfLCleUW0k,8 +aiohttp/pytest_plugin.py,sha256=fJxoTu3NI1wDjFIV1FDmx0oGdRU7r6baovTwJneR6J4,11986 +aiohttp/resolver.py,sha256=k5cVNWiiCHqKDGko7UZNu2y-j6SrU7vQBx-omwHPhso,5230 +aiohttp/streams.py,sha256=v9GkCFn9mynR1yGB1Xl6mwxEJXsOC4T-DzoL4IC0XOU,21812 +aiohttp/tcp_helpers.py,sha256=K-hhGh3jd6qCEnHJo8LvFyfJwBjh99UKI7A0aSRVhj4,998 +aiohttp/test_utils.py,sha256=q3NaqC9-lC-9smmXloLwfpB5tC_e20thC3LarO5lddk,21157 +aiohttp/tracing.py,sha256=0EccU7PYykvNwd75SNFcuQ1I57RdSu2jDYIvNK2BW6c,15603 +aiohttp/typedefs.py,sha256=x7HBHDU2IlRZZb7ketOdc2Js0MLx53agxk0UnLxNFw4,1525 +aiohttp/web.py,sha256=8rghTkpERz14vRQA4oyXfGZLWMIBMrrHXxcvtZZ_fAU,19879 +aiohttp/web_app.py,sha256=0bajIxV0xb5AB9TChRsoVP8ytTLiY_TU7zrQ66GHNus,18907 +aiohttp/web_exceptions.py,sha256=itNRhCMDJFhnMWftr5SyTsoqh-i0n9rzTj0sjcAEUjo,10812 +aiohttp/web_fileresponse.py,sha256=4ddU695Y4w1reMFqqpuxbJQ9vF0XEG7HsCDl930qsRs,11874 +aiohttp/web_log.py,sha256=w81HIudhfSxfodo2Fjkok7jWT56XXIrVMJN6ihYnLo0,8014 +aiohttp/web_middlewares.py,sha256=rYWtxDZ2AM3C2FvNuNyffpfmMfcHrxkSZMaTcsG1T_Q,4148 +aiohttp/web_protocol.py,sha256=6YU4IsJldwoi8SZQ3dp-fHi3TdljwW2uRRj4BgTdfSU,23758 +aiohttp/web_request.py,sha256=b67jMnKemnPjCX2MloT_U4IP5ILM6nWABItbK5gWR30,29887 +aiohttp/web_response.py,sha256=HNzl7eOOI1TR6Kfd30sQIinJXWWhVHM19YUdZvEbmFs,28677 +aiohttp/web_routedef.py,sha256=7ZribqwusXb1s0T2vLj1roFne1fdz_ZsudBwRyiwQxM,6348 +aiohttp/web_runner.py,sha256=Oled_t6ma3qQ6f9xOWv_8jgt0tl-dvhvEG09ms83rQk,12360 +aiohttp/web_server.py,sha256=kOlImrScEbvkGHG7i-N-7eqf55f2zC_J2BZcJanGGmU,2664 +aiohttp/web_urldispatcher.py,sha256=4URDODVtRTsEbjjJb-wX0FhoJ7T5nhrJFz8NpEP9gb4,41366 +aiohttp/web_ws.py,sha256=77sl9-DZGXZ01BaU0nJfTn9gQ9B8Ps0XY7kK8n3u_SU,19499 +aiohttp/worker.py,sha256=vDMxlk-Mo3rzN4yubw2-c8T6yg7PRY8Mv0NLuRm8lWw,8212 diff --git a/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/WHEEL b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..d111766f66e4cd0e45f15c92ed470f34c8f3434c --- /dev/null +++ b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.43.0) +Root-Is-Purelib: false +Tag: cp311-cp311-win_amd64 + diff --git a/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/top_level.txt b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..ee4ba4f3d739e094878215c84eb41ba85c80e4a8 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp-3.9.5.dist-info/top_level.txt @@ -0,0 +1 @@ +aiohttp diff --git a/llm/Lib/site-packages/aiohttp/.hash/_cparser.pxd.hash b/llm/Lib/site-packages/aiohttp/.hash/_cparser.pxd.hash new file mode 100644 index 0000000000000000000000000000000000000000..1dc9b9a49e79ff6252bfcef9d1a923925533f5b6 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/.hash/_cparser.pxd.hash @@ -0,0 +1 @@ +5baf9cbb44b21e13843de6f8ef936fc5a810d49cfda56ab2649bf012a4cb36cd \ No newline at end of file diff --git a/llm/Lib/site-packages/aiohttp/.hash/_find_header.pxd.hash 
b/llm/Lib/site-packages/aiohttp/.hash/_find_header.pxd.hash new file mode 100644 index 0000000000000000000000000000000000000000..ab9d476577b82ff8744b6acaf7b9b315d9057296 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/.hash/_find_header.pxd.hash @@ -0,0 +1 @@ +0455129b185e981b5b96ac738f31f7c74dc57f1696953cae0083b3f18679fe73 \ No newline at end of file diff --git a/llm/Lib/site-packages/aiohttp/.hash/_helpers.pyi.hash b/llm/Lib/site-packages/aiohttp/.hash/_helpers.pyi.hash new file mode 100644 index 0000000000000000000000000000000000000000..c304d37499171f02314aa886217d9973902af22f --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/.hash/_helpers.pyi.hash @@ -0,0 +1 @@ +d87779202d197f8613109e35dacbb2ca1b21d64572543bf9838b2d832a362ac7 \ No newline at end of file diff --git a/llm/Lib/site-packages/aiohttp/.hash/_helpers.pyx.hash b/llm/Lib/site-packages/aiohttp/.hash/_helpers.pyx.hash new file mode 100644 index 0000000000000000000000000000000000000000..8164dbb55e2cc59eb36e46334419e7501be1a10b --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/.hash/_helpers.pyx.hash @@ -0,0 +1 @@ +b6097b7d987440c4fa7237f88d227c89a3ba0dd403dc638ddbe487e0de7f1138 \ No newline at end of file diff --git a/llm/Lib/site-packages/aiohttp/.hash/_http_parser.pyx.hash b/llm/Lib/site-packages/aiohttp/.hash/_http_parser.pyx.hash new file mode 100644 index 0000000000000000000000000000000000000000..169d66d01b52acd26f1285dc0bb97f7d581946b7 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/.hash/_http_parser.pyx.hash @@ -0,0 +1 @@ +4807a7440f51cd8e406955969ecfecf4331b66554780a944c8b80f42f51c4dbd \ No newline at end of file diff --git a/llm/Lib/site-packages/aiohttp/.hash/_http_writer.pyx.hash b/llm/Lib/site-packages/aiohttp/.hash/_http_writer.pyx.hash new file mode 100644 index 0000000000000000000000000000000000000000..10d83475cc21f7ad1fe1340720add4ac47e00f3e --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/.hash/_http_writer.pyx.hash @@ -0,0 +1 @@ +f0204bcad3b6af1d6476959ef47612ce784b5dd7995b213edf123b8da19ab1a8 \ No newline at end of file diff --git a/llm/Lib/site-packages/aiohttp/.hash/_websocket.pyx.hash b/llm/Lib/site-packages/aiohttp/.hash/_websocket.pyx.hash new file mode 100644 index 0000000000000000000000000000000000000000..511f26f901f0e33c38202e3d5d11666d2e09b5cd --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/.hash/_websocket.pyx.hash @@ -0,0 +1 @@ +a3d27bca2f5cdbe8d3063137754917c610d62af456273e4665fc8bb202506b7f \ No newline at end of file diff --git a/llm/Lib/site-packages/aiohttp/.hash/hdrs.py.hash b/llm/Lib/site-packages/aiohttp/.hash/hdrs.py.hash new file mode 100644 index 0000000000000000000000000000000000000000..b032e6e9a8e9a22c93a3ea8cf81526c811c2391a --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/.hash/hdrs.py.hash @@ -0,0 +1 @@ +fc937830113e5280571961a84822a122f89345cd885d2e1fca4e679eea31d009 \ No newline at end of file diff --git a/llm/Lib/site-packages/aiohttp/__init__.py b/llm/Lib/site-packages/aiohttp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07a0b25a8d8b602d6d2ad646e6149510f1ba0082 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/__init__.py @@ -0,0 +1,240 @@ +__version__ = "3.9.5" + +from typing import TYPE_CHECKING, Tuple + +from . 
import hdrs as hdrs +from .client import ( + BaseConnector as BaseConnector, + ClientConnectionError as ClientConnectionError, + ClientConnectorCertificateError as ClientConnectorCertificateError, + ClientConnectorError as ClientConnectorError, + ClientConnectorSSLError as ClientConnectorSSLError, + ClientError as ClientError, + ClientHttpProxyError as ClientHttpProxyError, + ClientOSError as ClientOSError, + ClientPayloadError as ClientPayloadError, + ClientProxyConnectionError as ClientProxyConnectionError, + ClientRequest as ClientRequest, + ClientResponse as ClientResponse, + ClientResponseError as ClientResponseError, + ClientSession as ClientSession, + ClientSSLError as ClientSSLError, + ClientTimeout as ClientTimeout, + ClientWebSocketResponse as ClientWebSocketResponse, + ContentTypeError as ContentTypeError, + Fingerprint as Fingerprint, + InvalidURL as InvalidURL, + NamedPipeConnector as NamedPipeConnector, + RequestInfo as RequestInfo, + ServerConnectionError as ServerConnectionError, + ServerDisconnectedError as ServerDisconnectedError, + ServerFingerprintMismatch as ServerFingerprintMismatch, + ServerTimeoutError as ServerTimeoutError, + TCPConnector as TCPConnector, + TooManyRedirects as TooManyRedirects, + UnixConnector as UnixConnector, + WSServerHandshakeError as WSServerHandshakeError, + request as request, +) +from .cookiejar import CookieJar as CookieJar, DummyCookieJar as DummyCookieJar +from .formdata import FormData as FormData +from .helpers import BasicAuth, ChainMapProxy, ETag +from .http import ( + HttpVersion as HttpVersion, + HttpVersion10 as HttpVersion10, + HttpVersion11 as HttpVersion11, + WebSocketError as WebSocketError, + WSCloseCode as WSCloseCode, + WSMessage as WSMessage, + WSMsgType as WSMsgType, +) +from .multipart import ( + BadContentDispositionHeader as BadContentDispositionHeader, + BadContentDispositionParam as BadContentDispositionParam, + BodyPartReader as BodyPartReader, + MultipartReader as MultipartReader, + MultipartWriter as MultipartWriter, + content_disposition_filename as content_disposition_filename, + parse_content_disposition as parse_content_disposition, +) +from .payload import ( + PAYLOAD_REGISTRY as PAYLOAD_REGISTRY, + AsyncIterablePayload as AsyncIterablePayload, + BufferedReaderPayload as BufferedReaderPayload, + BytesIOPayload as BytesIOPayload, + BytesPayload as BytesPayload, + IOBasePayload as IOBasePayload, + JsonPayload as JsonPayload, + Payload as Payload, + StringIOPayload as StringIOPayload, + StringPayload as StringPayload, + TextIOPayload as TextIOPayload, + get_payload as get_payload, + payload_type as payload_type, +) +from .payload_streamer import streamer as streamer +from .resolver import ( + AsyncResolver as AsyncResolver, + DefaultResolver as DefaultResolver, + ThreadedResolver as ThreadedResolver, +) +from .streams import ( + EMPTY_PAYLOAD as EMPTY_PAYLOAD, + DataQueue as DataQueue, + EofStream as EofStream, + FlowControlDataQueue as FlowControlDataQueue, + StreamReader as StreamReader, +) +from .tracing import ( + TraceConfig as TraceConfig, + TraceConnectionCreateEndParams as TraceConnectionCreateEndParams, + TraceConnectionCreateStartParams as TraceConnectionCreateStartParams, + TraceConnectionQueuedEndParams as TraceConnectionQueuedEndParams, + TraceConnectionQueuedStartParams as TraceConnectionQueuedStartParams, + TraceConnectionReuseconnParams as TraceConnectionReuseconnParams, + TraceDnsCacheHitParams as TraceDnsCacheHitParams, + TraceDnsCacheMissParams as TraceDnsCacheMissParams, + 
TraceDnsResolveHostEndParams as TraceDnsResolveHostEndParams, + TraceDnsResolveHostStartParams as TraceDnsResolveHostStartParams, + TraceRequestChunkSentParams as TraceRequestChunkSentParams, + TraceRequestEndParams as TraceRequestEndParams, + TraceRequestExceptionParams as TraceRequestExceptionParams, + TraceRequestRedirectParams as TraceRequestRedirectParams, + TraceRequestStartParams as TraceRequestStartParams, + TraceResponseChunkReceivedParams as TraceResponseChunkReceivedParams, +) + +if TYPE_CHECKING: + # At runtime these are lazy-loaded at the bottom of the file. + from .worker import ( + GunicornUVLoopWebWorker as GunicornUVLoopWebWorker, + GunicornWebWorker as GunicornWebWorker, + ) + +__all__: Tuple[str, ...] = ( + "hdrs", + # client + "BaseConnector", + "ClientConnectionError", + "ClientConnectorCertificateError", + "ClientConnectorError", + "ClientConnectorSSLError", + "ClientError", + "ClientHttpProxyError", + "ClientOSError", + "ClientPayloadError", + "ClientProxyConnectionError", + "ClientResponse", + "ClientRequest", + "ClientResponseError", + "ClientSSLError", + "ClientSession", + "ClientTimeout", + "ClientWebSocketResponse", + "ContentTypeError", + "Fingerprint", + "InvalidURL", + "RequestInfo", + "ServerConnectionError", + "ServerDisconnectedError", + "ServerFingerprintMismatch", + "ServerTimeoutError", + "TCPConnector", + "TooManyRedirects", + "UnixConnector", + "NamedPipeConnector", + "WSServerHandshakeError", + "request", + # cookiejar + "CookieJar", + "DummyCookieJar", + # formdata + "FormData", + # helpers + "BasicAuth", + "ChainMapProxy", + "ETag", + # http + "HttpVersion", + "HttpVersion10", + "HttpVersion11", + "WSMsgType", + "WSCloseCode", + "WSMessage", + "WebSocketError", + # multipart + "BadContentDispositionHeader", + "BadContentDispositionParam", + "BodyPartReader", + "MultipartReader", + "MultipartWriter", + "content_disposition_filename", + "parse_content_disposition", + # payload + "AsyncIterablePayload", + "BufferedReaderPayload", + "BytesIOPayload", + "BytesPayload", + "IOBasePayload", + "JsonPayload", + "PAYLOAD_REGISTRY", + "Payload", + "StringIOPayload", + "StringPayload", + "TextIOPayload", + "get_payload", + "payload_type", + # payload_streamer + "streamer", + # resolver + "AsyncResolver", + "DefaultResolver", + "ThreadedResolver", + # streams + "DataQueue", + "EMPTY_PAYLOAD", + "EofStream", + "FlowControlDataQueue", + "StreamReader", + # tracing + "TraceConfig", + "TraceConnectionCreateEndParams", + "TraceConnectionCreateStartParams", + "TraceConnectionQueuedEndParams", + "TraceConnectionQueuedStartParams", + "TraceConnectionReuseconnParams", + "TraceDnsCacheHitParams", + "TraceDnsCacheMissParams", + "TraceDnsResolveHostEndParams", + "TraceDnsResolveHostStartParams", + "TraceRequestChunkSentParams", + "TraceRequestEndParams", + "TraceRequestExceptionParams", + "TraceRequestRedirectParams", + "TraceRequestStartParams", + "TraceResponseChunkReceivedParams", + # workers (imported lazily with __getattr__) + "GunicornUVLoopWebWorker", + "GunicornWebWorker", +) + + +def __dir__() -> Tuple[str, ...]: + return __all__ + ("__author__", "__doc__") + + +def __getattr__(name: str) -> object: + global GunicornUVLoopWebWorker, GunicornWebWorker + + # Importing gunicorn takes a long time (>100ms), so only import if actually needed. 
+ if name in ("GunicornUVLoopWebWorker", "GunicornWebWorker"): + try: + from .worker import GunicornUVLoopWebWorker as guv, GunicornWebWorker as gw + except ImportError: + return None + + GunicornUVLoopWebWorker = guv # type: ignore[misc] + GunicornWebWorker = gw # type: ignore[misc] + return guv if name == "GunicornUVLoopWebWorker" else gw + + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dbf5fc2608102ed065ae01f466f5d664a593262 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/abc.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/abc.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f074f8f8db437bbf93058d287c60ebe89f423332 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/abc.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/base_protocol.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/base_protocol.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9199dd055b29ed96b1763b96da179e37824906e3 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/base_protocol.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/client.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/client.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec1aec15e293783438554af7ccecddc0e45e657b Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/client.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/client_exceptions.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/client_exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce8003bb1d152afe802f61298b1435b50a67ffac Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/client_exceptions.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/client_proto.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/client_proto.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20fb5337d315ac89197d6db52171bed891b02a24 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/client_proto.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/client_reqrep.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/client_reqrep.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a39e7496d8b89bf72d21c7dff02e66e32e5133a Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/client_reqrep.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/client_ws.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/client_ws.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..687dd4f0910d19d21ab5e5a0cbcec695489203ab Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/client_ws.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/compression_utils.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/compression_utils.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..191384c946a2f0b28518d65772619fe97d2a117f Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/compression_utils.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/connector.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/connector.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03028c1cd173ffba56aef37cededd87a11fcc501 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/connector.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/cookiejar.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/cookiejar.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70eaca988f63bee5828b94efbf147a01b01cc23f Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/cookiejar.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/formdata.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/formdata.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c34141e8f70ece297cbd6970129f68a3dacd5d6f Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/formdata.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/hdrs.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/hdrs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a935795177de1ba395a5d11f23ba3de65042054a Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/hdrs.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/helpers.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a16a5f4dae942cab772d5efb3e9c49758e074f8 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/helpers.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/http.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/http.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab6fe41a81b632a303a6e1633328b4ea770d0646 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/http.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/http_exceptions.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/http_exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..169367442a78aba631d6fba6d285dde2ac49d940 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/http_exceptions.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/http_parser.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/http_parser.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96b515f1670973212a1541e53b6a8d3c6dd94bf4 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/http_parser.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/http_websocket.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/http_websocket.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0b020eb60de82e5c186e2ca69ac5865eab225fd Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/http_websocket.cpython-311.pyc differ diff --git 
a/llm/Lib/site-packages/aiohttp/__pycache__/http_writer.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/http_writer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cb29ef827b51c54b12072c05984a3ef3c763c8f Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/http_writer.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/locks.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/locks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f98679d0fa245784b1ad73ebfa4c9e17a0b65cf Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/locks.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/log.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/log.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d8703f5f46d5d41cff4129a3808f9c36da63db9 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/log.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/multipart.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/multipart.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14c953c790081b953dcedb02db324647e8669f33 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/multipart.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/payload.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/payload.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af9373ff543f3a08c3f1e6ab207f3f644771d63a Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/payload.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/payload_streamer.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/payload_streamer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c56594407980bff6ddaa69ba112a5639cc8428b Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/payload_streamer.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/pytest_plugin.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/pytest_plugin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..331143df54a2eeb79f5cb4ccb4ccb465e2fbfdb2 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/pytest_plugin.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/resolver.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/resolver.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ee50ca157811e7de8d9ca80e755c2c4f2ec2dd2 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/resolver.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/streams.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/streams.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d26edb932839889f0cf3c5c3e0fed5b1e0cc4b1 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/streams.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/tcp_helpers.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/tcp_helpers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dbdb81d7683c2219562bb2e35debe554dc6cc64 Binary files 
/dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/tcp_helpers.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/test_utils.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/test_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..831a900818c2b5472e0269e1a7af2059124de6b5 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/test_utils.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/tracing.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/tracing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3dd1a8015be4534a16a982a5c41bdea59374c2aa Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/tracing.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/typedefs.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/typedefs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6feaa9400d884792dd67d54eb5ad508e22ff9c6c Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/typedefs.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..988ff714747358a9b17fb256a8741bb9ba37672a Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_app.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_app.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..071ea467073a999b78c9c7faa7e28e456f941aa0 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_app.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_exceptions.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_exceptions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1287d812d166801d9d3f49a80945f4905643a76b Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_exceptions.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_fileresponse.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_fileresponse.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd431017cfbcb05963e499e2369d38de5465c303 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_fileresponse.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_log.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_log.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3dc224770dc11d76be32b20f988d4cba818e188 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_log.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_middlewares.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_middlewares.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb7a1b4aeece52354926fe5847653e82655d80c7 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_middlewares.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_protocol.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_protocol.cpython-311.pyc 
new file mode 100644 index 0000000000000000000000000000000000000000..f029074b79eb19fbb8c9050a60e402ec6ab4d760 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_protocol.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_request.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_request.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9d7e144b9df6cabf64baf2a47b64416bef95968 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_request.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_response.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_response.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dbc0d0b88fc8eb3b8c1d34bd87580e32b0d6082 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_response.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_routedef.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_routedef.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f695aa7bb7c32942d190be4375ff871f2bd3519b Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_routedef.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_runner.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bdb0811ff7b8c13345f4fdc9662b333c101f26b Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_runner.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_server.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_server.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0509266ab91a14ee7d515447f6db96d6398c5bb3 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_server.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_urldispatcher.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_urldispatcher.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0347ae582b952e3556fcc19bbca7bc70542bd07f Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_urldispatcher.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/web_ws.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/web_ws.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..970329a54477986f24f59f27506b96e437a0a6fc Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/web_ws.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/__pycache__/worker.cpython-311.pyc b/llm/Lib/site-packages/aiohttp/__pycache__/worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb64d7b48301dbdba4a3aa2e9ff52e7ee4588a17 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/__pycache__/worker.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiohttp/_cparser.pxd b/llm/Lib/site-packages/aiohttp/_cparser.pxd new file mode 100644 index 0000000000000000000000000000000000000000..a4783837fc157e784d41e0c9093bfb40df488aa7 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/_cparser.pxd @@ -0,0 +1,158 @@ +from libc.stdint cimport int32_t, uint8_t, uint16_t, uint64_t + + +cdef extern from 
"../vendor/llhttp/build/llhttp.h": + + struct llhttp__internal_s: + int32_t _index + void* _span_pos0 + void* _span_cb0 + int32_t error + const char* reason + const char* error_pos + void* data + void* _current + uint64_t content_length + uint8_t type + uint8_t method + uint8_t http_major + uint8_t http_minor + uint8_t header_state + uint8_t lenient_flags + uint8_t upgrade + uint8_t finish + uint16_t flags + uint16_t status_code + void* settings + + ctypedef llhttp__internal_s llhttp__internal_t + ctypedef llhttp__internal_t llhttp_t + + ctypedef int (*llhttp_data_cb)(llhttp_t*, const char *at, size_t length) except -1 + ctypedef int (*llhttp_cb)(llhttp_t*) except -1 + + struct llhttp_settings_s: + llhttp_cb on_message_begin + llhttp_data_cb on_url + llhttp_data_cb on_status + llhttp_data_cb on_header_field + llhttp_data_cb on_header_value + llhttp_cb on_headers_complete + llhttp_data_cb on_body + llhttp_cb on_message_complete + llhttp_cb on_chunk_header + llhttp_cb on_chunk_complete + + llhttp_cb on_url_complete + llhttp_cb on_status_complete + llhttp_cb on_header_field_complete + llhttp_cb on_header_value_complete + + ctypedef llhttp_settings_s llhttp_settings_t + + enum llhttp_errno: + HPE_OK, + HPE_INTERNAL, + HPE_STRICT, + HPE_LF_EXPECTED, + HPE_UNEXPECTED_CONTENT_LENGTH, + HPE_CLOSED_CONNECTION, + HPE_INVALID_METHOD, + HPE_INVALID_URL, + HPE_INVALID_CONSTANT, + HPE_INVALID_VERSION, + HPE_INVALID_HEADER_TOKEN, + HPE_INVALID_CONTENT_LENGTH, + HPE_INVALID_CHUNK_SIZE, + HPE_INVALID_STATUS, + HPE_INVALID_EOF_STATE, + HPE_INVALID_TRANSFER_ENCODING, + HPE_CB_MESSAGE_BEGIN, + HPE_CB_HEADERS_COMPLETE, + HPE_CB_MESSAGE_COMPLETE, + HPE_CB_CHUNK_HEADER, + HPE_CB_CHUNK_COMPLETE, + HPE_PAUSED, + HPE_PAUSED_UPGRADE, + HPE_USER + + ctypedef llhttp_errno llhttp_errno_t + + enum llhttp_flags: + F_CHUNKED, + F_CONTENT_LENGTH + + enum llhttp_type: + HTTP_REQUEST, + HTTP_RESPONSE, + HTTP_BOTH + + enum llhttp_method: + HTTP_DELETE, + HTTP_GET, + HTTP_HEAD, + HTTP_POST, + HTTP_PUT, + HTTP_CONNECT, + HTTP_OPTIONS, + HTTP_TRACE, + HTTP_COPY, + HTTP_LOCK, + HTTP_MKCOL, + HTTP_MOVE, + HTTP_PROPFIND, + HTTP_PROPPATCH, + HTTP_SEARCH, + HTTP_UNLOCK, + HTTP_BIND, + HTTP_REBIND, + HTTP_UNBIND, + HTTP_ACL, + HTTP_REPORT, + HTTP_MKACTIVITY, + HTTP_CHECKOUT, + HTTP_MERGE, + HTTP_MSEARCH, + HTTP_NOTIFY, + HTTP_SUBSCRIBE, + HTTP_UNSUBSCRIBE, + HTTP_PATCH, + HTTP_PURGE, + HTTP_MKCALENDAR, + HTTP_LINK, + HTTP_UNLINK, + HTTP_SOURCE, + HTTP_PRI, + HTTP_DESCRIBE, + HTTP_ANNOUNCE, + HTTP_SETUP, + HTTP_PLAY, + HTTP_PAUSE, + HTTP_TEARDOWN, + HTTP_GET_PARAMETER, + HTTP_SET_PARAMETER, + HTTP_REDIRECT, + HTTP_RECORD, + HTTP_FLUSH + + ctypedef llhttp_method llhttp_method_t; + + void llhttp_settings_init(llhttp_settings_t* settings) + void llhttp_init(llhttp_t* parser, llhttp_type type, + const llhttp_settings_t* settings) + + llhttp_errno_t llhttp_execute(llhttp_t* parser, const char* data, size_t len) + + int llhttp_should_keep_alive(const llhttp_t* parser) + + void llhttp_resume_after_upgrade(llhttp_t* parser) + + llhttp_errno_t llhttp_get_errno(const llhttp_t* parser) + const char* llhttp_get_error_reason(const llhttp_t* parser) + const char* llhttp_get_error_pos(const llhttp_t* parser) + + const char* llhttp_method_name(llhttp_method_t method) + + void llhttp_set_lenient_headers(llhttp_t* parser, int enabled) + void llhttp_set_lenient_optional_cr_before_lf(llhttp_t* parser, int enabled) + void llhttp_set_lenient_spaces_after_chunk_size(llhttp_t* parser, int enabled) diff --git a/llm/Lib/site-packages/aiohttp/_find_header.pxd 
b/llm/Lib/site-packages/aiohttp/_find_header.pxd new file mode 100644 index 0000000000000000000000000000000000000000..9305cb61da27a827c280b3c7ea3ca65e96c8793b --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/_find_header.pxd @@ -0,0 +1,2 @@ +cdef extern from "_find_header.h": + int find_header(char *, int) diff --git a/llm/Lib/site-packages/aiohttp/_headers.pxi b/llm/Lib/site-packages/aiohttp/_headers.pxi new file mode 100644 index 0000000000000000000000000000000000000000..45c50f15e8c5e77996bf29315c13027b7b8f66ef --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/_headers.pxi @@ -0,0 +1,83 @@ +# The file is autogenerated from aiohttp/hdrs.py +# Run ./tools/gen.py to update it after the origin changing. + +from . import hdrs +cdef tuple headers = ( + hdrs.ACCEPT, + hdrs.ACCEPT_CHARSET, + hdrs.ACCEPT_ENCODING, + hdrs.ACCEPT_LANGUAGE, + hdrs.ACCEPT_RANGES, + hdrs.ACCESS_CONTROL_ALLOW_CREDENTIALS, + hdrs.ACCESS_CONTROL_ALLOW_HEADERS, + hdrs.ACCESS_CONTROL_ALLOW_METHODS, + hdrs.ACCESS_CONTROL_ALLOW_ORIGIN, + hdrs.ACCESS_CONTROL_EXPOSE_HEADERS, + hdrs.ACCESS_CONTROL_MAX_AGE, + hdrs.ACCESS_CONTROL_REQUEST_HEADERS, + hdrs.ACCESS_CONTROL_REQUEST_METHOD, + hdrs.AGE, + hdrs.ALLOW, + hdrs.AUTHORIZATION, + hdrs.CACHE_CONTROL, + hdrs.CONNECTION, + hdrs.CONTENT_DISPOSITION, + hdrs.CONTENT_ENCODING, + hdrs.CONTENT_LANGUAGE, + hdrs.CONTENT_LENGTH, + hdrs.CONTENT_LOCATION, + hdrs.CONTENT_MD5, + hdrs.CONTENT_RANGE, + hdrs.CONTENT_TRANSFER_ENCODING, + hdrs.CONTENT_TYPE, + hdrs.COOKIE, + hdrs.DATE, + hdrs.DESTINATION, + hdrs.DIGEST, + hdrs.ETAG, + hdrs.EXPECT, + hdrs.EXPIRES, + hdrs.FORWARDED, + hdrs.FROM, + hdrs.HOST, + hdrs.IF_MATCH, + hdrs.IF_MODIFIED_SINCE, + hdrs.IF_NONE_MATCH, + hdrs.IF_RANGE, + hdrs.IF_UNMODIFIED_SINCE, + hdrs.KEEP_ALIVE, + hdrs.LAST_EVENT_ID, + hdrs.LAST_MODIFIED, + hdrs.LINK, + hdrs.LOCATION, + hdrs.MAX_FORWARDS, + hdrs.ORIGIN, + hdrs.PRAGMA, + hdrs.PROXY_AUTHENTICATE, + hdrs.PROXY_AUTHORIZATION, + hdrs.RANGE, + hdrs.REFERER, + hdrs.RETRY_AFTER, + hdrs.SEC_WEBSOCKET_ACCEPT, + hdrs.SEC_WEBSOCKET_EXTENSIONS, + hdrs.SEC_WEBSOCKET_KEY, + hdrs.SEC_WEBSOCKET_KEY1, + hdrs.SEC_WEBSOCKET_PROTOCOL, + hdrs.SEC_WEBSOCKET_VERSION, + hdrs.SERVER, + hdrs.SET_COOKIE, + hdrs.TE, + hdrs.TRAILER, + hdrs.TRANSFER_ENCODING, + hdrs.URI, + hdrs.UPGRADE, + hdrs.USER_AGENT, + hdrs.VARY, + hdrs.VIA, + hdrs.WWW_AUTHENTICATE, + hdrs.WANT_DIGEST, + hdrs.WARNING, + hdrs.X_FORWARDED_FOR, + hdrs.X_FORWARDED_HOST, + hdrs.X_FORWARDED_PROTO, +) diff --git a/llm/Lib/site-packages/aiohttp/_helpers.cp311-win_amd64.pyd b/llm/Lib/site-packages/aiohttp/_helpers.cp311-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..983e43c027c946ca7f9990aa0ea5e9977cf77d29 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/_helpers.cp311-win_amd64.pyd differ diff --git a/llm/Lib/site-packages/aiohttp/_helpers.pyi b/llm/Lib/site-packages/aiohttp/_helpers.pyi new file mode 100644 index 0000000000000000000000000000000000000000..6dda40d032ab4c1f6b85c9d6d52b79c020875e67 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/_helpers.pyi @@ -0,0 +1,6 @@ +from typing import Any + +class reify: + def __init__(self, wrapped: Any) -> None: ... + def __get__(self, inst: Any, owner: Any) -> Any: ... + def __set__(self, inst: Any, value: Any) -> None: ... 
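The `_helpers.pyi` stub above only declares the interface of aiohttp's `reify` descriptor; the Cython `_helpers.pyx` source added next carries the actual implementation. As a rough, non-normative sketch of the same caching behaviour in plain Python — the `Request` and `content_type` names below are invented for illustration, and the sketch assumes the host object provides a `_cache` dict, as aiohttp's request objects do:

    class reify:
        """Cache the result of a method on first access (illustrative sketch).

        Works like @property, but stores the computed value in the host
        instance's ``_cache`` dict so the wrapped method runs only once.
        """

        def __init__(self, wrapped):
            self.wrapped = wrapped
            self.name = wrapped.__name__

        def __get__(self, inst, owner):
            if inst is None:              # accessed on the class itself
                return self
            try:
                return inst._cache[self.name]
            except KeyError:
                val = self.wrapped(inst)
                inst._cache[self.name] = val
                return val

        def __set__(self, inst, value):
            raise AttributeError("reified property is read-only")


    class Request:                        # hypothetical host class for the demo
        def __init__(self, headers):
            self._cache = {}              # reify stores computed values here
            self.headers = headers

        @reify
        def content_type(self):
            print("computed once")
            return self.headers.get("Content-Type", "application/octet-stream")


    req = Request({"Content-Type": "text/html"})
    assert req.content_type == "text/html"    # wrapped method runs, prints once
    assert req.content_type == "text/html"    # second access served from _cache
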
diff --git a/llm/Lib/site-packages/aiohttp/_helpers.pyx b/llm/Lib/site-packages/aiohttp/_helpers.pyx new file mode 100644 index 0000000000000000000000000000000000000000..cccb332cea6ed9eeab585a0c632e7a67f700973a --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/_helpers.pyx @@ -0,0 +1,35 @@ +cdef class reify: + """Use as a class method decorator. It operates almost exactly like + the Python `@property` decorator, but it puts the result of the + method it decorates into the instance dict after the first call, + effectively replacing the function it decorates with an instance + variable. It is, in Python parlance, a data descriptor. + + """ + + cdef object wrapped + cdef object name + + def __init__(self, wrapped): + self.wrapped = wrapped + self.name = wrapped.__name__ + + @property + def __doc__(self): + return self.wrapped.__doc__ + + def __get__(self, inst, owner): + try: + try: + return inst._cache[self.name] + except KeyError: + val = self.wrapped(inst) + inst._cache[self.name] = val + return val + except AttributeError: + if inst is None: + return self + raise + + def __set__(self, inst, value): + raise AttributeError("reified property is read-only") diff --git a/llm/Lib/site-packages/aiohttp/_http_parser.cp311-win_amd64.pyd b/llm/Lib/site-packages/aiohttp/_http_parser.cp311-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..f30bacadde3c4d4ab5314329ff375a9b73cd8a74 Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/_http_parser.cp311-win_amd64.pyd differ diff --git a/llm/Lib/site-packages/aiohttp/_http_parser.pyx b/llm/Lib/site-packages/aiohttp/_http_parser.pyx new file mode 100644 index 0000000000000000000000000000000000000000..17c6fcdcf92269a02d600b61042fff905c935d68 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/_http_parser.pyx @@ -0,0 +1,838 @@ +#cython: language_level=3 +# +# Based on https://github.com/MagicStack/httptools +# + +from cpython cimport ( + Py_buffer, + PyBUF_SIMPLE, + PyBuffer_Release, + PyBytes_AsString, + PyBytes_AsStringAndSize, + PyObject_GetBuffer, +) +from cpython.mem cimport PyMem_Free, PyMem_Malloc +from libc.limits cimport ULLONG_MAX +from libc.string cimport memcpy + +from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy +from yarl import URL as _URL + +from aiohttp import hdrs +from aiohttp.helpers import DEBUG, set_exception + +from .http_exceptions import ( + BadHttpMessage, + BadStatusLine, + ContentLengthError, + InvalidHeader, + InvalidURLError, + LineTooLong, + PayloadEncodingError, + TransferEncodingError, +) +from .http_parser import DeflateBuffer as _DeflateBuffer +from .http_writer import ( + HttpVersion as _HttpVersion, + HttpVersion10 as _HttpVersion10, + HttpVersion11 as _HttpVersion11, +) +from .streams import EMPTY_PAYLOAD as _EMPTY_PAYLOAD, StreamReader as _StreamReader + +cimport cython + +from aiohttp cimport _cparser as cparser + +include "_headers.pxi" + +from aiohttp cimport _find_header + +DEF DEFAULT_FREELIST_SIZE = 250 + +cdef extern from "Python.h": + int PyByteArray_Resize(object, Py_ssize_t) except -1 + Py_ssize_t PyByteArray_Size(object) except -1 + char* PyByteArray_AsString(object) + +__all__ = ('HttpRequestParser', 'HttpResponseParser', + 'RawRequestMessage', 'RawResponseMessage') + +cdef object URL = _URL +cdef object URL_build = URL.build +cdef object CIMultiDict = _CIMultiDict +cdef object CIMultiDictProxy = _CIMultiDictProxy +cdef object HttpVersion = _HttpVersion +cdef object HttpVersion10 = _HttpVersion10 +cdef object HttpVersion11 = 
_HttpVersion11 +cdef object SEC_WEBSOCKET_KEY1 = hdrs.SEC_WEBSOCKET_KEY1 +cdef object CONTENT_ENCODING = hdrs.CONTENT_ENCODING +cdef object EMPTY_PAYLOAD = _EMPTY_PAYLOAD +cdef object StreamReader = _StreamReader +cdef object DeflateBuffer = _DeflateBuffer + + +cdef inline object extend(object buf, const char* at, size_t length): + cdef Py_ssize_t s + cdef char* ptr + s = PyByteArray_Size(buf) + PyByteArray_Resize(buf, s + length) + ptr = PyByteArray_AsString(buf) + memcpy(ptr + s, at, length) + + +DEF METHODS_COUNT = 46; + +cdef list _http_method = [] + +for i in range(METHODS_COUNT): + _http_method.append( + cparser.llhttp_method_name( i).decode('ascii')) + + +cdef inline str http_method_str(int i): + if i < METHODS_COUNT: + return _http_method[i] + else: + return "" + +cdef inline object find_header(bytes raw_header): + cdef Py_ssize_t size + cdef char *buf + cdef int idx + PyBytes_AsStringAndSize(raw_header, &buf, &size) + idx = _find_header.find_header(buf, size) + if idx == -1: + return raw_header.decode('utf-8', 'surrogateescape') + return headers[idx] + + +@cython.freelist(DEFAULT_FREELIST_SIZE) +cdef class RawRequestMessage: + cdef readonly str method + cdef readonly str path + cdef readonly object version # HttpVersion + cdef readonly object headers # CIMultiDict + cdef readonly object raw_headers # tuple + cdef readonly object should_close + cdef readonly object compression + cdef readonly object upgrade + cdef readonly object chunked + cdef readonly object url # yarl.URL + + def __init__(self, method, path, version, headers, raw_headers, + should_close, compression, upgrade, chunked, url): + self.method = method + self.path = path + self.version = version + self.headers = headers + self.raw_headers = raw_headers + self.should_close = should_close + self.compression = compression + self.upgrade = upgrade + self.chunked = chunked + self.url = url + + def __repr__(self): + info = [] + info.append(("method", self.method)) + info.append(("path", self.path)) + info.append(("version", self.version)) + info.append(("headers", self.headers)) + info.append(("raw_headers", self.raw_headers)) + info.append(("should_close", self.should_close)) + info.append(("compression", self.compression)) + info.append(("upgrade", self.upgrade)) + info.append(("chunked", self.chunked)) + info.append(("url", self.url)) + sinfo = ', '.join(name + '=' + repr(val) for name, val in info) + return '' + + def _replace(self, **dct): + cdef RawRequestMessage ret + ret = _new_request_message(self.method, + self.path, + self.version, + self.headers, + self.raw_headers, + self.should_close, + self.compression, + self.upgrade, + self.chunked, + self.url) + if "method" in dct: + ret.method = dct["method"] + if "path" in dct: + ret.path = dct["path"] + if "version" in dct: + ret.version = dct["version"] + if "headers" in dct: + ret.headers = dct["headers"] + if "raw_headers" in dct: + ret.raw_headers = dct["raw_headers"] + if "should_close" in dct: + ret.should_close = dct["should_close"] + if "compression" in dct: + ret.compression = dct["compression"] + if "upgrade" in dct: + ret.upgrade = dct["upgrade"] + if "chunked" in dct: + ret.chunked = dct["chunked"] + if "url" in dct: + ret.url = dct["url"] + return ret + +cdef _new_request_message(str method, + str path, + object version, + object headers, + object raw_headers, + bint should_close, + object compression, + bint upgrade, + bint chunked, + object url): + cdef RawRequestMessage ret + ret = RawRequestMessage.__new__(RawRequestMessage) + ret.method = method + 
ret.path = path + ret.version = version + ret.headers = headers + ret.raw_headers = raw_headers + ret.should_close = should_close + ret.compression = compression + ret.upgrade = upgrade + ret.chunked = chunked + ret.url = url + return ret + + +@cython.freelist(DEFAULT_FREELIST_SIZE) +cdef class RawResponseMessage: + cdef readonly object version # HttpVersion + cdef readonly int code + cdef readonly str reason + cdef readonly object headers # CIMultiDict + cdef readonly object raw_headers # tuple + cdef readonly object should_close + cdef readonly object compression + cdef readonly object upgrade + cdef readonly object chunked + + def __init__(self, version, code, reason, headers, raw_headers, + should_close, compression, upgrade, chunked): + self.version = version + self.code = code + self.reason = reason + self.headers = headers + self.raw_headers = raw_headers + self.should_close = should_close + self.compression = compression + self.upgrade = upgrade + self.chunked = chunked + + def __repr__(self): + info = [] + info.append(("version", self.version)) + info.append(("code", self.code)) + info.append(("reason", self.reason)) + info.append(("headers", self.headers)) + info.append(("raw_headers", self.raw_headers)) + info.append(("should_close", self.should_close)) + info.append(("compression", self.compression)) + info.append(("upgrade", self.upgrade)) + info.append(("chunked", self.chunked)) + sinfo = ', '.join(name + '=' + repr(val) for name, val in info) + return '' + + +cdef _new_response_message(object version, + int code, + str reason, + object headers, + object raw_headers, + bint should_close, + object compression, + bint upgrade, + bint chunked): + cdef RawResponseMessage ret + ret = RawResponseMessage.__new__(RawResponseMessage) + ret.version = version + ret.code = code + ret.reason = reason + ret.headers = headers + ret.raw_headers = raw_headers + ret.should_close = should_close + ret.compression = compression + ret.upgrade = upgrade + ret.chunked = chunked + return ret + + +@cython.internal +cdef class HttpParser: + + cdef: + cparser.llhttp_t* _cparser + cparser.llhttp_settings_t* _csettings + + bytearray _raw_name + bytearray _raw_value + bint _has_value + + object _protocol + object _loop + object _timer + + size_t _max_line_size + size_t _max_field_size + size_t _max_headers + bint _response_with_body + bint _read_until_eof + + bint _started + object _url + bytearray _buf + str _path + str _reason + object _headers + list _raw_headers + bint _upgraded + list _messages + object _payload + bint _payload_error + object _payload_exception + object _last_error + bint _auto_decompress + int _limit + + str _content_encoding + + Py_buffer py_buf + + def __cinit__(self): + self._cparser = \ + PyMem_Malloc(sizeof(cparser.llhttp_t)) + if self._cparser is NULL: + raise MemoryError() + + self._csettings = \ + PyMem_Malloc(sizeof(cparser.llhttp_settings_t)) + if self._csettings is NULL: + raise MemoryError() + + def __dealloc__(self): + PyMem_Free(self._cparser) + PyMem_Free(self._csettings) + + cdef _init( + self, cparser.llhttp_type mode, + object protocol, object loop, int limit, + object timer=None, + size_t max_line_size=8190, size_t max_headers=32768, + size_t max_field_size=8190, payload_exception=None, + bint response_with_body=True, bint read_until_eof=False, + bint auto_decompress=True, + ): + cparser.llhttp_settings_init(self._csettings) + cparser.llhttp_init(self._cparser, mode, self._csettings) + self._cparser.data = self + self._cparser.content_length = 0 + + self._protocol 
= protocol + self._loop = loop + self._timer = timer + + self._buf = bytearray() + self._payload = None + self._payload_error = 0 + self._payload_exception = payload_exception + self._messages = [] + + self._raw_name = bytearray() + self._raw_value = bytearray() + self._has_value = False + + self._max_line_size = max_line_size + self._max_headers = max_headers + self._max_field_size = max_field_size + self._response_with_body = response_with_body + self._read_until_eof = read_until_eof + self._upgraded = False + self._auto_decompress = auto_decompress + self._content_encoding = None + + self._csettings.on_url = cb_on_url + self._csettings.on_status = cb_on_status + self._csettings.on_header_field = cb_on_header_field + self._csettings.on_header_value = cb_on_header_value + self._csettings.on_headers_complete = cb_on_headers_complete + self._csettings.on_body = cb_on_body + self._csettings.on_message_begin = cb_on_message_begin + self._csettings.on_message_complete = cb_on_message_complete + self._csettings.on_chunk_header = cb_on_chunk_header + self._csettings.on_chunk_complete = cb_on_chunk_complete + + self._last_error = None + self._limit = limit + + cdef _process_header(self): + if self._raw_name: + raw_name = bytes(self._raw_name) + raw_value = bytes(self._raw_value) + + name = find_header(raw_name) + value = raw_value.decode('utf-8', 'surrogateescape') + + self._headers.add(name, value) + + if name is CONTENT_ENCODING: + self._content_encoding = value + + PyByteArray_Resize(self._raw_name, 0) + PyByteArray_Resize(self._raw_value, 0) + self._has_value = False + self._raw_headers.append((raw_name, raw_value)) + + cdef _on_header_field(self, char* at, size_t length): + cdef Py_ssize_t size + cdef char *buf + if self._has_value: + self._process_header() + + size = PyByteArray_Size(self._raw_name) + PyByteArray_Resize(self._raw_name, size + length) + buf = PyByteArray_AsString(self._raw_name) + memcpy(buf + size, at, length) + + cdef _on_header_value(self, char* at, size_t length): + cdef Py_ssize_t size + cdef char *buf + + size = PyByteArray_Size(self._raw_value) + PyByteArray_Resize(self._raw_value, size + length) + buf = PyByteArray_AsString(self._raw_value) + memcpy(buf + size, at, length) + self._has_value = True + + cdef _on_headers_complete(self): + self._process_header() + + method = http_method_str(self._cparser.method) + should_close = not cparser.llhttp_should_keep_alive(self._cparser) + upgrade = self._cparser.upgrade + chunked = self._cparser.flags & cparser.F_CHUNKED + + raw_headers = tuple(self._raw_headers) + headers = CIMultiDictProxy(self._headers) + + if upgrade or self._cparser.method == cparser.HTTP_CONNECT: + self._upgraded = True + + # do not support old websocket spec + if SEC_WEBSOCKET_KEY1 in headers: + raise InvalidHeader(SEC_WEBSOCKET_KEY1) + + encoding = None + enc = self._content_encoding + if enc is not None: + self._content_encoding = None + enc = enc.lower() + if enc in ('gzip', 'deflate', 'br'): + encoding = enc + + if self._cparser.type == cparser.HTTP_REQUEST: + msg = _new_request_message( + method, self._path, + self.http_version(), headers, raw_headers, + should_close, encoding, upgrade, chunked, self._url) + else: + msg = _new_response_message( + self.http_version(), self._cparser.status_code, self._reason, + headers, raw_headers, should_close, encoding, + upgrade, chunked) + + if ( + ULLONG_MAX > self._cparser.content_length > 0 or chunked or + self._cparser.method == cparser.HTTP_CONNECT or + (self._cparser.status_code >= 199 and + 
self._cparser.content_length == 0 and + self._read_until_eof) + ): + payload = StreamReader( + self._protocol, timer=self._timer, loop=self._loop, + limit=self._limit) + else: + payload = EMPTY_PAYLOAD + + self._payload = payload + if encoding is not None and self._auto_decompress: + self._payload = DeflateBuffer(payload, encoding) + + if not self._response_with_body: + payload = EMPTY_PAYLOAD + + self._messages.append((msg, payload)) + + cdef _on_message_complete(self): + self._payload.feed_eof() + self._payload = None + + cdef _on_chunk_header(self): + self._payload.begin_http_chunk_receiving() + + cdef _on_chunk_complete(self): + self._payload.end_http_chunk_receiving() + + cdef object _on_status_complete(self): + pass + + cdef inline http_version(self): + cdef cparser.llhttp_t* parser = self._cparser + + if parser.http_major == 1: + if parser.http_minor == 0: + return HttpVersion10 + elif parser.http_minor == 1: + return HttpVersion11 + + return HttpVersion(parser.http_major, parser.http_minor) + + ### Public API ### + + def feed_eof(self): + cdef bytes desc + + if self._payload is not None: + if self._cparser.flags & cparser.F_CHUNKED: + raise TransferEncodingError( + "Not enough data for satisfy transfer length header.") + elif self._cparser.flags & cparser.F_CONTENT_LENGTH: + raise ContentLengthError( + "Not enough data for satisfy content length header.") + elif cparser.llhttp_get_errno(self._cparser) != cparser.HPE_OK: + desc = cparser.llhttp_get_error_reason(self._cparser) + raise PayloadEncodingError(desc.decode('latin-1')) + else: + self._payload.feed_eof() + elif self._started: + self._on_headers_complete() + if self._messages: + return self._messages[-1][0] + + def feed_data(self, data): + cdef: + size_t data_len + size_t nb + cdef cparser.llhttp_errno_t errno + + PyObject_GetBuffer(data, &self.py_buf, PyBUF_SIMPLE) + data_len = self.py_buf.len + + errno = cparser.llhttp_execute( + self._cparser, + self.py_buf.buf, + data_len) + + if errno is cparser.HPE_PAUSED_UPGRADE: + cparser.llhttp_resume_after_upgrade(self._cparser) + + nb = cparser.llhttp_get_error_pos(self._cparser) - self.py_buf.buf + + PyBuffer_Release(&self.py_buf) + + if errno not in (cparser.HPE_OK, cparser.HPE_PAUSED_UPGRADE): + if self._payload_error == 0: + if self._last_error is not None: + ex = self._last_error + self._last_error = None + else: + after = cparser.llhttp_get_error_pos(self._cparser) + before = data[:after - self.py_buf.buf] + after_b = after.split(b"\r\n", 1)[0] + before = before.rsplit(b"\r\n", 1)[-1] + data = before + after_b + pointer = " " * (len(repr(before))-1) + "^" + ex = parser_error_from_errno(self._cparser, data, pointer) + self._payload = None + raise ex + + if self._messages: + messages = self._messages + self._messages = [] + else: + messages = () + + if self._upgraded: + return messages, True, data[nb:] + else: + return messages, False, b'' + + def set_upgraded(self, val): + self._upgraded = val + + +cdef class HttpRequestParser(HttpParser): + + def __init__( + self, protocol, loop, int limit, timer=None, + size_t max_line_size=8190, size_t max_headers=32768, + size_t max_field_size=8190, payload_exception=None, + bint response_with_body=True, bint read_until_eof=False, + bint auto_decompress=True, + ): + self._init(cparser.HTTP_REQUEST, protocol, loop, limit, timer, + max_line_size, max_headers, max_field_size, + payload_exception, response_with_body, read_until_eof, + auto_decompress) + + cdef object _on_status_complete(self): + cdef int idx1, idx2 + if not self._buf: + return 
+ self._path = self._buf.decode('utf-8', 'surrogateescape') + try: + idx3 = len(self._path) + if self._cparser.method == cparser.HTTP_CONNECT: + # authority-form, + # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3 + self._url = URL.build(authority=self._path, encoded=True) + elif idx3 > 1 and self._path[0] == '/': + # origin-form, + # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1 + idx1 = self._path.find("?") + if idx1 == -1: + query = "" + idx2 = self._path.find("#") + if idx2 == -1: + path = self._path + fragment = "" + else: + path = self._path[0: idx2] + fragment = self._path[idx2+1:] + + else: + path = self._path[0:idx1] + idx1 += 1 + idx2 = self._path.find("#", idx1+1) + if idx2 == -1: + query = self._path[idx1:] + fragment = "" + else: + query = self._path[idx1: idx2] + fragment = self._path[idx2+1:] + + self._url = URL.build( + path=path, + query_string=query, + fragment=fragment, + encoded=True, + ) + else: + # absolute-form for proxy maybe, + # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2 + self._url = URL(self._path, encoded=True) + finally: + PyByteArray_Resize(self._buf, 0) + + +cdef class HttpResponseParser(HttpParser): + + def __init__( + self, protocol, loop, int limit, timer=None, + size_t max_line_size=8190, size_t max_headers=32768, + size_t max_field_size=8190, payload_exception=None, + bint response_with_body=True, bint read_until_eof=False, + bint auto_decompress=True + ): + self._init(cparser.HTTP_RESPONSE, protocol, loop, limit, timer, + max_line_size, max_headers, max_field_size, + payload_exception, response_with_body, read_until_eof, + auto_decompress) + # Use strict parsing on dev mode, so users are warned about broken servers. + if not DEBUG: + cparser.llhttp_set_lenient_headers(self._cparser, 1) + cparser.llhttp_set_lenient_optional_cr_before_lf(self._cparser, 1) + cparser.llhttp_set_lenient_spaces_after_chunk_size(self._cparser, 1) + + cdef object _on_status_complete(self): + if self._buf: + self._reason = self._buf.decode('utf-8', 'surrogateescape') + PyByteArray_Resize(self._buf, 0) + else: + self._reason = self._reason or '' + +cdef int cb_on_message_begin(cparser.llhttp_t* parser) except -1: + cdef HttpParser pyparser = parser.data + + pyparser._started = True + pyparser._headers = CIMultiDict() + pyparser._raw_headers = [] + PyByteArray_Resize(pyparser._buf, 0) + pyparser._path = None + pyparser._reason = None + return 0 + + +cdef int cb_on_url(cparser.llhttp_t* parser, + const char *at, size_t length) except -1: + cdef HttpParser pyparser = parser.data + try: + if length > pyparser._max_line_size: + raise LineTooLong( + 'Status line is too long', pyparser._max_line_size, length) + extend(pyparser._buf, at, length) + except BaseException as ex: + pyparser._last_error = ex + return -1 + else: + return 0 + + +cdef int cb_on_status(cparser.llhttp_t* parser, + const char *at, size_t length) except -1: + cdef HttpParser pyparser = parser.data + cdef str reason + try: + if length > pyparser._max_line_size: + raise LineTooLong( + 'Status line is too long', pyparser._max_line_size, length) + extend(pyparser._buf, at, length) + except BaseException as ex: + pyparser._last_error = ex + return -1 + else: + return 0 + + +cdef int cb_on_header_field(cparser.llhttp_t* parser, + const char *at, size_t length) except -1: + cdef HttpParser pyparser = parser.data + cdef Py_ssize_t size + try: + pyparser._on_status_complete() + size = len(pyparser._raw_name) + length + if size > pyparser._max_field_size: + raise LineTooLong( + 
'Header name is too long', pyparser._max_field_size, size) + pyparser._on_header_field(at, length) + except BaseException as ex: + pyparser._last_error = ex + return -1 + else: + return 0 + + +cdef int cb_on_header_value(cparser.llhttp_t* parser, + const char *at, size_t length) except -1: + cdef HttpParser pyparser = parser.data + cdef Py_ssize_t size + try: + size = len(pyparser._raw_value) + length + if size > pyparser._max_field_size: + raise LineTooLong( + 'Header value is too long', pyparser._max_field_size, size) + pyparser._on_header_value(at, length) + except BaseException as ex: + pyparser._last_error = ex + return -1 + else: + return 0 + + +cdef int cb_on_headers_complete(cparser.llhttp_t* parser) except -1: + cdef HttpParser pyparser = parser.data + try: + pyparser._on_status_complete() + pyparser._on_headers_complete() + except BaseException as exc: + pyparser._last_error = exc + return -1 + else: + if ( + pyparser._cparser.upgrade or + pyparser._cparser.method == cparser.HTTP_CONNECT + ): + return 2 + else: + return 0 + + +cdef int cb_on_body(cparser.llhttp_t* parser, + const char *at, size_t length) except -1: + cdef HttpParser pyparser = parser.data + cdef bytes body = at[:length] + try: + pyparser._payload.feed_data(body, length) + except BaseException as underlying_exc: + reraised_exc = underlying_exc + if pyparser._payload_exception is not None: + reraised_exc = pyparser._payload_exception(str(underlying_exc)) + + set_exception(pyparser._payload, reraised_exc, underlying_exc) + + pyparser._payload_error = 1 + return -1 + else: + return 0 + + +cdef int cb_on_message_complete(cparser.llhttp_t* parser) except -1: + cdef HttpParser pyparser = parser.data + try: + pyparser._started = False + pyparser._on_message_complete() + except BaseException as exc: + pyparser._last_error = exc + return -1 + else: + return 0 + + +cdef int cb_on_chunk_header(cparser.llhttp_t* parser) except -1: + cdef HttpParser pyparser = parser.data + try: + pyparser._on_chunk_header() + except BaseException as exc: + pyparser._last_error = exc + return -1 + else: + return 0 + + +cdef int cb_on_chunk_complete(cparser.llhttp_t* parser) except -1: + cdef HttpParser pyparser = parser.data + try: + pyparser._on_chunk_complete() + except BaseException as exc: + pyparser._last_error = exc + return -1 + else: + return 0 + + +cdef parser_error_from_errno(cparser.llhttp_t* parser, data, pointer): + cdef cparser.llhttp_errno_t errno = cparser.llhttp_get_errno(parser) + cdef bytes desc = cparser.llhttp_get_error_reason(parser) + + err_msg = "{}:\n\n {!r}\n {}".format(desc.decode("latin-1"), data, pointer) + + if errno in {cparser.HPE_CB_MESSAGE_BEGIN, + cparser.HPE_CB_HEADERS_COMPLETE, + cparser.HPE_CB_MESSAGE_COMPLETE, + cparser.HPE_CB_CHUNK_HEADER, + cparser.HPE_CB_CHUNK_COMPLETE, + cparser.HPE_INVALID_CONSTANT, + cparser.HPE_INVALID_HEADER_TOKEN, + cparser.HPE_INVALID_CONTENT_LENGTH, + cparser.HPE_INVALID_CHUNK_SIZE, + cparser.HPE_INVALID_EOF_STATE, + cparser.HPE_INVALID_TRANSFER_ENCODING}: + return BadHttpMessage(err_msg) + elif errno in {cparser.HPE_INVALID_STATUS, + cparser.HPE_INVALID_METHOD, + cparser.HPE_INVALID_VERSION}: + return BadStatusLine(error=err_msg) + elif errno == cparser.HPE_INVALID_URL: + return InvalidURLError(err_msg) + + return BadHttpMessage(err_msg) diff --git a/llm/Lib/site-packages/aiohttp/_http_writer.cp311-win_amd64.pyd b/llm/Lib/site-packages/aiohttp/_http_writer.cp311-win_amd64.pyd new file mode 100644 index 
0000000000000000000000000000000000000000..d6281942b1580120cddceddc0a01a422f7f60bcd Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/_http_writer.cp311-win_amd64.pyd differ diff --git a/llm/Lib/site-packages/aiohttp/_http_writer.pyx b/llm/Lib/site-packages/aiohttp/_http_writer.pyx new file mode 100644 index 0000000000000000000000000000000000000000..2f9e9f91a5179188467bc3d1c28f6e43e78060a5 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/_http_writer.pyx @@ -0,0 +1,163 @@ +from cpython.bytes cimport PyBytes_FromStringAndSize +from cpython.exc cimport PyErr_NoMemory +from cpython.mem cimport PyMem_Free, PyMem_Malloc, PyMem_Realloc +from cpython.object cimport PyObject_Str +from libc.stdint cimport uint8_t, uint64_t +from libc.string cimport memcpy + +from multidict import istr + +DEF BUF_SIZE = 16 * 1024 # 16KiB +cdef char BUFFER[BUF_SIZE] + +cdef object _istr = istr + + +# ----------------- writer --------------------------- + +cdef struct Writer: + char *buf + Py_ssize_t size + Py_ssize_t pos + + +cdef inline void _init_writer(Writer* writer): + writer.buf = &BUFFER[0] + writer.size = BUF_SIZE + writer.pos = 0 + + +cdef inline void _release_writer(Writer* writer): + if writer.buf != BUFFER: + PyMem_Free(writer.buf) + + +cdef inline int _write_byte(Writer* writer, uint8_t ch): + cdef char * buf + cdef Py_ssize_t size + + if writer.pos == writer.size: + # reallocate + size = writer.size + BUF_SIZE + if writer.buf == BUFFER: + buf = PyMem_Malloc(size) + if buf == NULL: + PyErr_NoMemory() + return -1 + memcpy(buf, writer.buf, writer.size) + else: + buf = PyMem_Realloc(writer.buf, size) + if buf == NULL: + PyErr_NoMemory() + return -1 + writer.buf = buf + writer.size = size + writer.buf[writer.pos] = ch + writer.pos += 1 + return 0 + + +cdef inline int _write_utf8(Writer* writer, Py_UCS4 symbol): + cdef uint64_t utf = symbol + + if utf < 0x80: + return _write_byte(writer, utf) + elif utf < 0x800: + if _write_byte(writer, (0xc0 | (utf >> 6))) < 0: + return -1 + return _write_byte(writer, (0x80 | (utf & 0x3f))) + elif 0xD800 <= utf <= 0xDFFF: + # surogate pair, ignored + return 0 + elif utf < 0x10000: + if _write_byte(writer, (0xe0 | (utf >> 12))) < 0: + return -1 + if _write_byte(writer, (0x80 | ((utf >> 6) & 0x3f))) < 0: + return -1 + return _write_byte(writer, (0x80 | (utf & 0x3f))) + elif utf > 0x10FFFF: + # symbol is too large + return 0 + else: + if _write_byte(writer, (0xf0 | (utf >> 18))) < 0: + return -1 + if _write_byte(writer, + (0x80 | ((utf >> 12) & 0x3f))) < 0: + return -1 + if _write_byte(writer, + (0x80 | ((utf >> 6) & 0x3f))) < 0: + return -1 + return _write_byte(writer, (0x80 | (utf & 0x3f))) + + +cdef inline int _write_str(Writer* writer, str s): + cdef Py_UCS4 ch + for ch in s: + if _write_utf8(writer, ch) < 0: + return -1 + + +# --------------- _serialize_headers ---------------------- + +cdef str to_str(object s): + typ = type(s) + if typ is str: + return s + elif typ is _istr: + return PyObject_Str(s) + elif not isinstance(s, str): + raise TypeError("Cannot serialize non-str key {!r}".format(s)) + else: + return str(s) + + +cdef void _safe_header(str string) except *: + if "\r" in string or "\n" in string: + raise ValueError( + "Newline or carriage return character detected in HTTP status message or " + "header. This is a potential security issue." 
+ ) + + +def _serialize_headers(str status_line, headers): + cdef Writer writer + cdef object key + cdef object val + cdef bytes ret + + _init_writer(&writer) + + for key, val in headers.items(): + _safe_header(to_str(key)) + _safe_header(to_str(val)) + + try: + if _write_str(&writer, status_line) < 0: + raise + if _write_byte(&writer, b'\r') < 0: + raise + if _write_byte(&writer, b'\n') < 0: + raise + + for key, val in headers.items(): + if _write_str(&writer, to_str(key)) < 0: + raise + if _write_byte(&writer, b':') < 0: + raise + if _write_byte(&writer, b' ') < 0: + raise + if _write_str(&writer, to_str(val)) < 0: + raise + if _write_byte(&writer, b'\r') < 0: + raise + if _write_byte(&writer, b'\n') < 0: + raise + + if _write_byte(&writer, b'\r') < 0: + raise + if _write_byte(&writer, b'\n') < 0: + raise + + return PyBytes_FromStringAndSize(writer.buf, writer.pos) + finally: + _release_writer(&writer) diff --git a/llm/Lib/site-packages/aiohttp/_websocket.cp311-win_amd64.pyd b/llm/Lib/site-packages/aiohttp/_websocket.cp311-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..2f3afa485af44292e1966a566effc43dd84e2a0c Binary files /dev/null and b/llm/Lib/site-packages/aiohttp/_websocket.cp311-win_amd64.pyd differ diff --git a/llm/Lib/site-packages/aiohttp/_websocket.pyx b/llm/Lib/site-packages/aiohttp/_websocket.pyx new file mode 100644 index 0000000000000000000000000000000000000000..6f036b24fc171e771d084aa2049f1b73964e3e9f --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/_websocket.pyx @@ -0,0 +1,56 @@ +from cpython cimport PyBytes_AsString + + +#from cpython cimport PyByteArray_AsString # cython still not exports that +cdef extern from "Python.h": + char* PyByteArray_AsString(bytearray ba) except NULL + +from libc.stdint cimport uint32_t, uint64_t, uintmax_t + + +def _websocket_mask_cython(object mask, object data): + """Note, this function mutates its `data` argument + """ + cdef: + Py_ssize_t data_len, i + # bit operations on signed integers are implementation-specific + unsigned char * in_buf + const unsigned char * mask_buf + uint32_t uint32_msk + uint64_t uint64_msk + + assert len(mask) == 4 + + if not isinstance(mask, bytes): + mask = bytes(mask) + + if isinstance(data, bytearray): + data = data + else: + data = bytearray(data) + + data_len = len(data) + in_buf = PyByteArray_AsString(data) + mask_buf = PyBytes_AsString(mask) + uint32_msk = (mask_buf)[0] + + # TODO: align in_data ptr to achieve even faster speeds + # does it need in python ?! 
malloc() always aligns to sizeof(long) bytes + + if sizeof(size_t) >= 8: + uint64_msk = uint32_msk + uint64_msk = (uint64_msk << 32) | uint32_msk + + while data_len >= 8: + (in_buf)[0] ^= uint64_msk + in_buf += 8 + data_len -= 8 + + + while data_len >= 4: + (in_buf)[0] ^= uint32_msk + in_buf += 4 + data_len -= 4 + + for i in range(0, data_len): + in_buf[i] ^= mask_buf[i] diff --git a/llm/Lib/site-packages/aiohttp/abc.py b/llm/Lib/site-packages/aiohttp/abc.py new file mode 100644 index 0000000000000000000000000000000000000000..e539bc3bbc9d670fa9e546339783d8fa46d7c55f --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/abc.py @@ -0,0 +1,209 @@ +import asyncio +import logging +from abc import ABC, abstractmethod +from collections.abc import Sized +from http.cookies import BaseCookie, Morsel +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Generator, + Iterable, + List, + Optional, + Tuple, +) + +from multidict import CIMultiDict +from yarl import URL + +from .helpers import get_running_loop +from .typedefs import LooseCookies + +if TYPE_CHECKING: + from .web_app import Application + from .web_exceptions import HTTPException + from .web_request import BaseRequest, Request + from .web_response import StreamResponse +else: + BaseRequest = Request = Application = StreamResponse = None + HTTPException = None + + +class AbstractRouter(ABC): + def __init__(self) -> None: + self._frozen = False + + def post_init(self, app: Application) -> None: + """Post init stage. + + Not an abstract method for sake of backward compatibility, + but if the router wants to be aware of the application + it can override this. + """ + + @property + def frozen(self) -> bool: + return self._frozen + + def freeze(self) -> None: + """Freeze router.""" + self._frozen = True + + @abstractmethod + async def resolve(self, request: Request) -> "AbstractMatchInfo": + """Return MATCH_INFO for given request""" + + +class AbstractMatchInfo(ABC): + @property # pragma: no branch + @abstractmethod + def handler(self) -> Callable[[Request], Awaitable[StreamResponse]]: + """Execute matched request handler""" + + @property + @abstractmethod + def expect_handler( + self, + ) -> Callable[[Request], Awaitable[Optional[StreamResponse]]]: + """Expect handler for 100-continue processing""" + + @property # pragma: no branch + @abstractmethod + def http_exception(self) -> Optional[HTTPException]: + """HTTPException instance raised on router's resolving, or None""" + + @abstractmethod # pragma: no branch + def get_info(self) -> Dict[str, Any]: + """Return a dict with additional info useful for introspection""" + + @property # pragma: no branch + @abstractmethod + def apps(self) -> Tuple[Application, ...]: + """Stack of nested applications. + + Top level application is left-most element. + + """ + + @abstractmethod + def add_app(self, app: Application) -> None: + """Add application to the nested apps stack.""" + + @abstractmethod + def freeze(self) -> None: + """Freeze the match info. + + The method is called after route resolution. + + After the call .add_app() is forbidden. 
+ + """ + + +class AbstractView(ABC): + """Abstract class based view.""" + + def __init__(self, request: Request) -> None: + self._request = request + + @property + def request(self) -> Request: + """Request instance.""" + return self._request + + @abstractmethod + def __await__(self) -> Generator[Any, None, StreamResponse]: + """Execute the view handler.""" + + +class AbstractResolver(ABC): + """Abstract DNS resolver.""" + + @abstractmethod + async def resolve(self, host: str, port: int, family: int) -> List[Dict[str, Any]]: + """Return IP address for given hostname""" + + @abstractmethod + async def close(self) -> None: + """Release resolver""" + + +if TYPE_CHECKING: + IterableBase = Iterable[Morsel[str]] +else: + IterableBase = Iterable + + +ClearCookiePredicate = Callable[["Morsel[str]"], bool] + + +class AbstractCookieJar(Sized, IterableBase): + """Abstract Cookie Jar.""" + + def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + self._loop = get_running_loop(loop) + + @abstractmethod + def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: + """Clear all cookies if no predicate is passed.""" + + @abstractmethod + def clear_domain(self, domain: str) -> None: + """Clear all cookies for domain and all subdomains.""" + + @abstractmethod + def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: + """Update cookies.""" + + @abstractmethod + def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": + """Return the jar's cookies filtered by their attributes.""" + + +class AbstractStreamWriter(ABC): + """Abstract stream writer.""" + + buffer_size = 0 + output_size = 0 + length: Optional[int] = 0 + + @abstractmethod + async def write(self, chunk: bytes) -> None: + """Write chunk into stream.""" + + @abstractmethod + async def write_eof(self, chunk: bytes = b"") -> None: + """Write last chunk.""" + + @abstractmethod + async def drain(self) -> None: + """Flush the write buffer.""" + + @abstractmethod + def enable_compression(self, encoding: str = "deflate") -> None: + """Enable HTTP body compression""" + + @abstractmethod + def enable_chunking(self) -> None: + """Enable HTTP chunked mode""" + + @abstractmethod + async def write_headers( + self, status_line: str, headers: "CIMultiDict[str]" + ) -> None: + """Write HTTP headers""" + + +class AbstractAccessLogger(ABC): + """Abstract writer to access log.""" + + def __init__(self, logger: logging.Logger, log_format: str) -> None: + self.logger = logger + self.log_format = log_format + + @abstractmethod + def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None: + """Emit log to logger.""" diff --git a/llm/Lib/site-packages/aiohttp/base_protocol.py b/llm/Lib/site-packages/aiohttp/base_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..889203322c88ac14af8e58bc39de7199dd190525 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/base_protocol.py @@ -0,0 +1,95 @@ +import asyncio +from typing import Optional, cast + +from .helpers import set_exception +from .tcp_helpers import tcp_nodelay + + +class BaseProtocol(asyncio.Protocol): + __slots__ = ( + "_loop", + "_paused", + "_drain_waiter", + "_connection_lost", + "_reading_paused", + "transport", + ) + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self._loop: asyncio.AbstractEventLoop = loop + self._paused = False + self._drain_waiter: Optional[asyncio.Future[None]] = None + self._reading_paused = False + + self.transport: Optional[asyncio.Transport] = 
None + + @property + def connected(self) -> bool: + """Return True if the connection is open.""" + return self.transport is not None + + def pause_writing(self) -> None: + assert not self._paused + self._paused = True + + def resume_writing(self) -> None: + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def pause_reading(self) -> None: + if not self._reading_paused and self.transport is not None: + try: + self.transport.pause_reading() + except (AttributeError, NotImplementedError, RuntimeError): + pass + self._reading_paused = True + + def resume_reading(self) -> None: + if self._reading_paused and self.transport is not None: + try: + self.transport.resume_reading() + except (AttributeError, NotImplementedError, RuntimeError): + pass + self._reading_paused = False + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + tr = cast(asyncio.Transport, transport) + tcp_nodelay(tr, True) + self.transport = tr + + def connection_lost(self, exc: Optional[BaseException]) -> None: + # Wake up the writer if currently paused. + self.transport = None + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + set_exception( + waiter, + ConnectionError("Connection lost"), + exc, + ) + + async def _drain_helper(self) -> None: + if not self.connected: + raise ConnectionResetError("Connection lost") + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + waiter = self._loop.create_future() + self._drain_waiter = waiter + await asyncio.shield(waiter) diff --git a/llm/Lib/site-packages/aiohttp/client.py b/llm/Lib/site-packages/aiohttp/client.py new file mode 100644 index 0000000000000000000000000000000000000000..c4923c8d78af9cccca333cb2ddde9408a5a1c2b5 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/client.py @@ -0,0 +1,1366 @@ +"""HTTP Client for asyncio.""" + +import asyncio +import base64 +import hashlib +import json +import os +import sys +import traceback +import warnings +from contextlib import suppress +from types import SimpleNamespace, TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Coroutine, + Final, + FrozenSet, + Generator, + Generic, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +import attr +from multidict import CIMultiDict, MultiDict, MultiDictProxy, istr +from yarl import URL + +from . 
import hdrs, http, payload +from .abc import AbstractCookieJar +from .client_exceptions import ( + ClientConnectionError as ClientConnectionError, + ClientConnectorCertificateError as ClientConnectorCertificateError, + ClientConnectorError as ClientConnectorError, + ClientConnectorSSLError as ClientConnectorSSLError, + ClientError as ClientError, + ClientHttpProxyError as ClientHttpProxyError, + ClientOSError as ClientOSError, + ClientPayloadError as ClientPayloadError, + ClientProxyConnectionError as ClientProxyConnectionError, + ClientResponseError as ClientResponseError, + ClientSSLError as ClientSSLError, + ContentTypeError as ContentTypeError, + InvalidURL as InvalidURL, + ServerConnectionError as ServerConnectionError, + ServerDisconnectedError as ServerDisconnectedError, + ServerFingerprintMismatch as ServerFingerprintMismatch, + ServerTimeoutError as ServerTimeoutError, + TooManyRedirects as TooManyRedirects, + WSServerHandshakeError as WSServerHandshakeError, +) +from .client_reqrep import ( + ClientRequest as ClientRequest, + ClientResponse as ClientResponse, + Fingerprint as Fingerprint, + RequestInfo as RequestInfo, + _merge_ssl_params, +) +from .client_ws import ClientWebSocketResponse as ClientWebSocketResponse +from .connector import ( + BaseConnector as BaseConnector, + NamedPipeConnector as NamedPipeConnector, + TCPConnector as TCPConnector, + UnixConnector as UnixConnector, +) +from .cookiejar import CookieJar +from .helpers import ( + _SENTINEL, + DEBUG, + BasicAuth, + TimeoutHandle, + ceil_timeout, + get_env_proxy_for_url, + get_running_loop, + method_must_be_empty_body, + sentinel, + strip_auth_from_url, +) +from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter +from .http_websocket import WSHandshakeError, WSMessage, ws_ext_gen, ws_ext_parse +from .streams import FlowControlDataQueue +from .tracing import Trace, TraceConfig +from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, StrOrURL + +__all__ = ( + # client_exceptions + "ClientConnectionError", + "ClientConnectorCertificateError", + "ClientConnectorError", + "ClientConnectorSSLError", + "ClientError", + "ClientHttpProxyError", + "ClientOSError", + "ClientPayloadError", + "ClientProxyConnectionError", + "ClientResponseError", + "ClientSSLError", + "ContentTypeError", + "InvalidURL", + "ServerConnectionError", + "ServerDisconnectedError", + "ServerFingerprintMismatch", + "ServerTimeoutError", + "TooManyRedirects", + "WSServerHandshakeError", + # client_reqrep + "ClientRequest", + "ClientResponse", + "Fingerprint", + "RequestInfo", + # connector + "BaseConnector", + "TCPConnector", + "UnixConnector", + "NamedPipeConnector", + # client_ws + "ClientWebSocketResponse", + # client + "ClientSession", + "ClientTimeout", + "request", +) + + +if TYPE_CHECKING: + from ssl import SSLContext +else: + SSLContext = None + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ClientTimeout: + total: Optional[float] = None + connect: Optional[float] = None + sock_read: Optional[float] = None + sock_connect: Optional[float] = None + ceil_threshold: float = 5 + + # pool_queue_timeout: Optional[float] = None + # dns_resolution_timeout: Optional[float] = None + # socket_connect_timeout: Optional[float] = None + # connection_acquiring_timeout: Optional[float] = None + # new_connection_timeout: Optional[float] = None + # http_header_timeout: Optional[float] = None + # response_body_timeout: Optional[float] = None + + # to create a timeout specific for a single request, either + # - create a completely 
new one to overwrite the default + # - or use http://www.attrs.org/en/stable/api.html#attr.evolve + # to overwrite the defaults + + +# 5 Minute default read timeout +DEFAULT_TIMEOUT: Final[ClientTimeout] = ClientTimeout(total=5 * 60) + +_RetType = TypeVar("_RetType") +_CharsetResolver = Callable[[ClientResponse, bytes], str] + + +class ClientSession: + """First-class interface for making HTTP requests.""" + + ATTRS = frozenset( + [ + "_base_url", + "_source_traceback", + "_connector", + "requote_redirect_url", + "_loop", + "_cookie_jar", + "_connector_owner", + "_default_auth", + "_version", + "_json_serialize", + "_requote_redirect_url", + "_timeout", + "_raise_for_status", + "_auto_decompress", + "_trust_env", + "_default_headers", + "_skip_auto_headers", + "_request_class", + "_response_class", + "_ws_response_class", + "_trace_configs", + "_read_bufsize", + "_max_line_size", + "_max_field_size", + "_resolve_charset", + ] + ) + + _source_traceback: Optional[traceback.StackSummary] = None + _connector: Optional[BaseConnector] = None + + def __init__( + self, + base_url: Optional[StrOrURL] = None, + *, + connector: Optional[BaseConnector] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + cookies: Optional[LooseCookies] = None, + headers: Optional[LooseHeaders] = None, + skip_auto_headers: Optional[Iterable[str]] = None, + auth: Optional[BasicAuth] = None, + json_serialize: JSONEncoder = json.dumps, + request_class: Type[ClientRequest] = ClientRequest, + response_class: Type[ClientResponse] = ClientResponse, + ws_response_class: Type[ClientWebSocketResponse] = ClientWebSocketResponse, + version: HttpVersion = http.HttpVersion11, + cookie_jar: Optional[AbstractCookieJar] = None, + connector_owner: bool = True, + raise_for_status: Union[ + bool, Callable[[ClientResponse], Awaitable[None]] + ] = False, + read_timeout: Union[float, _SENTINEL] = sentinel, + conn_timeout: Optional[float] = None, + timeout: Union[object, ClientTimeout] = sentinel, + auto_decompress: bool = True, + trust_env: bool = False, + requote_redirect_url: bool = True, + trace_configs: Optional[List[TraceConfig]] = None, + read_bufsize: int = 2**16, + max_line_size: int = 8190, + max_field_size: int = 8190, + fallback_charset_resolver: _CharsetResolver = lambda r, b: "utf-8", + ) -> None: + # We initialise _connector to None immediately, as it's referenced in __del__() + # and could cause issues if an exception occurs during initialisation. 
+ self._connector: Optional[BaseConnector] = None + if timeout is sentinel or timeout is None: + self._timeout = DEFAULT_TIMEOUT + if read_timeout is not sentinel: + warnings.warn( + "read_timeout is deprecated, " "use timeout argument instead", + DeprecationWarning, + stacklevel=2, + ) + self._timeout = attr.evolve(self._timeout, total=read_timeout) + if conn_timeout is not None: + self._timeout = attr.evolve(self._timeout, connect=conn_timeout) + warnings.warn( + "conn_timeout is deprecated, " "use timeout argument instead", + DeprecationWarning, + stacklevel=2, + ) + else: + if not isinstance(timeout, ClientTimeout): + raise ValueError( + f"timeout parameter cannot be of {type(timeout)} type, " + "please use 'timeout=ClientTimeout(...)'", + ) + self._timeout = timeout + if read_timeout is not sentinel: + raise ValueError( + "read_timeout and timeout parameters " + "conflict, please setup " + "timeout.read" + ) + if conn_timeout is not None: + raise ValueError( + "conn_timeout and timeout parameters " + "conflict, please setup " + "timeout.connect" + ) + if loop is None: + if connector is not None: + loop = connector._loop + + loop = get_running_loop(loop) + + if base_url is None or isinstance(base_url, URL): + self._base_url: Optional[URL] = base_url + else: + self._base_url = URL(base_url) + assert ( + self._base_url.origin() == self._base_url + ), "Only absolute URLs without path part are supported" + + if connector is None: + connector = TCPConnector(loop=loop) + + if connector._loop is not loop: + raise RuntimeError("Session and connector has to use same event loop") + + self._loop = loop + + if loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + if cookie_jar is None: + cookie_jar = CookieJar(loop=loop) + self._cookie_jar = cookie_jar + + if cookies is not None: + self._cookie_jar.update_cookies(cookies) + + self._connector = connector + self._connector_owner = connector_owner + self._default_auth = auth + self._version = version + self._json_serialize = json_serialize + self._raise_for_status = raise_for_status + self._auto_decompress = auto_decompress + self._trust_env = trust_env + self._requote_redirect_url = requote_redirect_url + self._read_bufsize = read_bufsize + self._max_line_size = max_line_size + self._max_field_size = max_field_size + + # Convert to list of tuples + if headers: + real_headers: CIMultiDict[str] = CIMultiDict(headers) + else: + real_headers = CIMultiDict() + self._default_headers: CIMultiDict[str] = real_headers + if skip_auto_headers is not None: + self._skip_auto_headers = frozenset(istr(i) for i in skip_auto_headers) + else: + self._skip_auto_headers = frozenset() + + self._request_class = request_class + self._response_class = response_class + self._ws_response_class = ws_response_class + + self._trace_configs = trace_configs or [] + for trace_config in self._trace_configs: + trace_config.freeze() + + self._resolve_charset = fallback_charset_resolver + + def __init_subclass__(cls: Type["ClientSession"]) -> None: + warnings.warn( + "Inheritance class {} from ClientSession " + "is discouraged".format(cls.__name__), + DeprecationWarning, + stacklevel=2, + ) + + if DEBUG: + + def __setattr__(self, name: str, val: Any) -> None: + if name not in self.ATTRS: + warnings.warn( + "Setting custom ClientSession.{} attribute " + "is discouraged".format(name), + DeprecationWarning, + stacklevel=2, + ) + super().__setattr__(name, val) + + def __del__(self, _warnings: Any = warnings) -> None: + if not self.closed: + kwargs = 
{"source": self} + _warnings.warn( + f"Unclosed client session {self!r}", ResourceWarning, **kwargs + ) + context = {"client_session": self, "message": "Unclosed client session"} + if self._source_traceback is not None: + context["source_traceback"] = self._source_traceback + self._loop.call_exception_handler(context) + + def request( + self, method: str, url: StrOrURL, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP request.""" + return _RequestContextManager(self._request(method, url, **kwargs)) + + def _build_url(self, str_or_url: StrOrURL) -> URL: + url = URL(str_or_url) + if self._base_url is None: + return url + else: + assert not url.is_absolute() and url.path.startswith("/") + return self._base_url.join(url) + + async def _request( + self, + method: str, + str_or_url: StrOrURL, + *, + params: Optional[Mapping[str, str]] = None, + data: Any = None, + json: Any = None, + cookies: Optional[LooseCookies] = None, + headers: Optional[LooseHeaders] = None, + skip_auto_headers: Optional[Iterable[str]] = None, + auth: Optional[BasicAuth] = None, + allow_redirects: bool = True, + max_redirects: int = 10, + compress: Optional[str] = None, + chunked: Optional[bool] = None, + expect100: bool = False, + raise_for_status: Union[ + None, bool, Callable[[ClientResponse], Awaitable[None]] + ] = None, + read_until_eof: bool = True, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + timeout: Union[ClientTimeout, _SENTINEL] = sentinel, + verify_ssl: Optional[bool] = None, + fingerprint: Optional[bytes] = None, + ssl_context: Optional[SSLContext] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, + server_hostname: Optional[str] = None, + proxy_headers: Optional[LooseHeaders] = None, + trace_request_ctx: Optional[SimpleNamespace] = None, + read_bufsize: Optional[int] = None, + auto_decompress: Optional[bool] = None, + max_line_size: Optional[int] = None, + max_field_size: Optional[int] = None, + ) -> ClientResponse: + + # NOTE: timeout clamps existing connect and read timeouts. We cannot + # set the default to None because we need to detect if the user wants + # to use the existing timeouts by setting timeout to None. 
+ + if self.closed: + raise RuntimeError("Session is closed") + + ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) + + if data is not None and json is not None: + raise ValueError( + "data and json parameters can not be used at the same time" + ) + elif json is not None: + data = payload.JsonPayload(json, dumps=self._json_serialize) + + if not isinstance(chunked, bool) and chunked is not None: + warnings.warn("Chunk size is deprecated #1615", DeprecationWarning) + + redirects = 0 + history = [] + version = self._version + params = params or {} + + # Merge with default headers and transform to CIMultiDict + headers = self._prepare_headers(headers) + proxy_headers = self._prepare_headers(proxy_headers) + + try: + url = self._build_url(str_or_url) + except ValueError as e: + raise InvalidURL(str_or_url) from e + + skip_headers = set(self._skip_auto_headers) + if skip_auto_headers is not None: + for i in skip_auto_headers: + skip_headers.add(istr(i)) + + if proxy is not None: + try: + proxy = URL(proxy) + except ValueError as e: + raise InvalidURL(proxy) from e + + if timeout is sentinel: + real_timeout: ClientTimeout = self._timeout + else: + if not isinstance(timeout, ClientTimeout): + real_timeout = ClientTimeout(total=timeout) + else: + real_timeout = timeout + # timeout is cumulative for all request operations + # (request, redirects, responses, data consuming) + tm = TimeoutHandle( + self._loop, real_timeout.total, ceil_threshold=real_timeout.ceil_threshold + ) + handle = tm.start() + + if read_bufsize is None: + read_bufsize = self._read_bufsize + + if auto_decompress is None: + auto_decompress = self._auto_decompress + + if max_line_size is None: + max_line_size = self._max_line_size + + if max_field_size is None: + max_field_size = self._max_field_size + + traces = [ + Trace( + self, + trace_config, + trace_config.trace_config_ctx(trace_request_ctx=trace_request_ctx), + ) + for trace_config in self._trace_configs + ] + + for trace in traces: + await trace.send_request_start(method, url.update_query(params), headers) + + timer = tm.timer() + try: + with timer: + while True: + url, auth_from_url = strip_auth_from_url(url) + if auth and auth_from_url: + raise ValueError( + "Cannot combine AUTH argument with " + "credentials encoded in URL" + ) + + if auth is None: + auth = auth_from_url + if auth is None: + auth = self._default_auth + # It would be confusing if we support explicit + # Authorization header with auth argument + if ( + headers is not None + and auth is not None + and hdrs.AUTHORIZATION in headers + ): + raise ValueError( + "Cannot combine AUTHORIZATION header " + "with AUTH argument or credentials " + "encoded in URL" + ) + + all_cookies = self._cookie_jar.filter_cookies(url) + + if cookies is not None: + tmp_cookie_jar = CookieJar() + tmp_cookie_jar.update_cookies(cookies) + req_cookies = tmp_cookie_jar.filter_cookies(url) + if req_cookies: + all_cookies.load(req_cookies) + + if proxy is not None: + proxy = URL(proxy) + elif self._trust_env: + with suppress(LookupError): + proxy, proxy_auth = get_env_proxy_for_url(url) + + req = self._request_class( + method, + url, + params=params, + headers=headers, + skip_auto_headers=skip_headers, + data=data, + cookies=all_cookies, + auth=auth, + version=version, + compress=compress, + chunked=chunked, + expect100=expect100, + loop=self._loop, + response_class=self._response_class, + proxy=proxy, + proxy_auth=proxy_auth, + timer=timer, + session=self, + ssl=ssl if ssl is not None else True, + 
server_hostname=server_hostname, + proxy_headers=proxy_headers, + traces=traces, + trust_env=self.trust_env, + ) + + # connection timeout + try: + async with ceil_timeout( + real_timeout.connect, + ceil_threshold=real_timeout.ceil_threshold, + ): + assert self._connector is not None + conn = await self._connector.connect( + req, traces=traces, timeout=real_timeout + ) + except asyncio.TimeoutError as exc: + raise ServerTimeoutError( + "Connection timeout " "to host {}".format(url) + ) from exc + + assert conn.transport is not None + + assert conn.protocol is not None + conn.protocol.set_response_params( + timer=timer, + skip_payload=method_must_be_empty_body(method), + read_until_eof=read_until_eof, + auto_decompress=auto_decompress, + read_timeout=real_timeout.sock_read, + read_bufsize=read_bufsize, + timeout_ceil_threshold=self._connector._timeout_ceil_threshold, + max_line_size=max_line_size, + max_field_size=max_field_size, + ) + + try: + try: + resp = await req.send(conn) + try: + await resp.start(conn) + except BaseException: + resp.close() + raise + except BaseException: + conn.close() + raise + except ClientError: + raise + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + raise ClientOSError(*exc.args) from exc + + self._cookie_jar.update_cookies(resp.cookies, resp.url) + + # redirects + if resp.status in (301, 302, 303, 307, 308) and allow_redirects: + + for trace in traces: + await trace.send_request_redirect( + method, url.update_query(params), headers, resp + ) + + redirects += 1 + history.append(resp) + if max_redirects and redirects >= max_redirects: + resp.close() + raise TooManyRedirects( + history[0].request_info, tuple(history) + ) + + # For 301 and 302, mimic IE, now changed in RFC + # https://github.com/kennethreitz/requests/pull/269 + if (resp.status == 303 and resp.method != hdrs.METH_HEAD) or ( + resp.status in (301, 302) and resp.method == hdrs.METH_POST + ): + method = hdrs.METH_GET + data = None + if headers.get(hdrs.CONTENT_LENGTH): + headers.pop(hdrs.CONTENT_LENGTH) + + r_url = resp.headers.get(hdrs.LOCATION) or resp.headers.get( + hdrs.URI + ) + if r_url is None: + # see github.com/aio-libs/aiohttp/issues/2022 + break + else: + # reading from correct redirection + # response is forbidden + resp.release() + + try: + parsed_url = URL( + r_url, encoded=not self._requote_redirect_url + ) + + except ValueError as e: + raise InvalidURL(r_url) from e + + scheme = parsed_url.scheme + if scheme not in ("http", "https", ""): + resp.close() + raise ValueError("Can redirect only to http or https") + elif not scheme: + parsed_url = url.join(parsed_url) + + if url.origin() != parsed_url.origin(): + auth = None + headers.pop(hdrs.AUTHORIZATION, None) + + url = parsed_url + params = {} + resp.release() + continue + + break + + # check response status + if raise_for_status is None: + raise_for_status = self._raise_for_status + + if raise_for_status is None: + pass + elif callable(raise_for_status): + await raise_for_status(resp) + elif raise_for_status: + resp.raise_for_status() + + # register connection + if handle is not None: + if resp.connection is not None: + resp.connection.add_callback(handle.cancel) + else: + handle.cancel() + + resp._history = tuple(history) + + for trace in traces: + await trace.send_request_end( + method, url.update_query(params), headers, resp + ) + return resp + + except BaseException as e: + # cleanup timer + tm.close() + if handle: + handle.cancel() + handle = None + + for trace in traces: + await 
trace.send_request_exception( + method, url.update_query(params), headers, e + ) + raise + + def ws_connect( + self, + url: StrOrURL, + *, + method: str = hdrs.METH_GET, + protocols: Iterable[str] = (), + timeout: float = 10.0, + receive_timeout: Optional[float] = None, + autoclose: bool = True, + autoping: bool = True, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[LooseHeaders] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + ssl: Union[SSLContext, bool, None, Fingerprint] = True, + verify_ssl: Optional[bool] = None, + fingerprint: Optional[bytes] = None, + ssl_context: Optional[SSLContext] = None, + proxy_headers: Optional[LooseHeaders] = None, + compress: int = 0, + max_msg_size: int = 4 * 1024 * 1024, + ) -> "_WSRequestContextManager": + """Initiate websocket connection.""" + return _WSRequestContextManager( + self._ws_connect( + url, + method=method, + protocols=protocols, + timeout=timeout, + receive_timeout=receive_timeout, + autoclose=autoclose, + autoping=autoping, + heartbeat=heartbeat, + auth=auth, + origin=origin, + params=params, + headers=headers, + proxy=proxy, + proxy_auth=proxy_auth, + ssl=ssl, + verify_ssl=verify_ssl, + fingerprint=fingerprint, + ssl_context=ssl_context, + proxy_headers=proxy_headers, + compress=compress, + max_msg_size=max_msg_size, + ) + ) + + async def _ws_connect( + self, + url: StrOrURL, + *, + method: str = hdrs.METH_GET, + protocols: Iterable[str] = (), + timeout: float = 10.0, + receive_timeout: Optional[float] = None, + autoclose: bool = True, + autoping: bool = True, + heartbeat: Optional[float] = None, + auth: Optional[BasicAuth] = None, + origin: Optional[str] = None, + params: Optional[Mapping[str, str]] = None, + headers: Optional[LooseHeaders] = None, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + ssl: Optional[Union[SSLContext, bool, Fingerprint]] = True, + verify_ssl: Optional[bool] = None, + fingerprint: Optional[bytes] = None, + ssl_context: Optional[SSLContext] = None, + proxy_headers: Optional[LooseHeaders] = None, + compress: int = 0, + max_msg_size: int = 4 * 1024 * 1024, + ) -> ClientWebSocketResponse: + + if headers is None: + real_headers: CIMultiDict[str] = CIMultiDict() + else: + real_headers = CIMultiDict(headers) + + default_headers = { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "Upgrade", + hdrs.SEC_WEBSOCKET_VERSION: "13", + } + + for key, value in default_headers.items(): + real_headers.setdefault(key, value) + + sec_key = base64.b64encode(os.urandom(16)) + real_headers[hdrs.SEC_WEBSOCKET_KEY] = sec_key.decode() + + if protocols: + real_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ",".join(protocols) + if origin is not None: + real_headers[hdrs.ORIGIN] = origin + if compress: + extstr = ws_ext_gen(compress=compress) + real_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = extstr + + # For the sake of backward compatibility, if user passes in None, convert it to True + if ssl is None: + ssl = True + ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) + + # send request + resp = await self.request( + method, + url, + params=params, + headers=real_headers, + read_until_eof=False, + auth=auth, + proxy=proxy, + proxy_auth=proxy_auth, + ssl=ssl, + proxy_headers=proxy_headers, + ) + + try: + # check handshake + if resp.status != 101: + raise WSServerHandshakeError( + resp.request_info, + resp.history, + message="Invalid response 
status", + status=resp.status, + headers=resp.headers, + ) + + if resp.headers.get(hdrs.UPGRADE, "").lower() != "websocket": + raise WSServerHandshakeError( + resp.request_info, + resp.history, + message="Invalid upgrade header", + status=resp.status, + headers=resp.headers, + ) + + if resp.headers.get(hdrs.CONNECTION, "").lower() != "upgrade": + raise WSServerHandshakeError( + resp.request_info, + resp.history, + message="Invalid connection header", + status=resp.status, + headers=resp.headers, + ) + + # key calculation + r_key = resp.headers.get(hdrs.SEC_WEBSOCKET_ACCEPT, "") + match = base64.b64encode(hashlib.sha1(sec_key + WS_KEY).digest()).decode() + if r_key != match: + raise WSServerHandshakeError( + resp.request_info, + resp.history, + message="Invalid challenge response", + status=resp.status, + headers=resp.headers, + ) + + # websocket protocol + protocol = None + if protocols and hdrs.SEC_WEBSOCKET_PROTOCOL in resp.headers: + resp_protocols = [ + proto.strip() + for proto in resp.headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",") + ] + + for proto in resp_protocols: + if proto in protocols: + protocol = proto + break + + # websocket compress + notakeover = False + if compress: + compress_hdrs = resp.headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS) + if compress_hdrs: + try: + compress, notakeover = ws_ext_parse(compress_hdrs) + except WSHandshakeError as exc: + raise WSServerHandshakeError( + resp.request_info, + resp.history, + message=exc.args[0], + status=resp.status, + headers=resp.headers, + ) from exc + else: + compress = 0 + notakeover = False + + conn = resp.connection + assert conn is not None + conn_proto = conn.protocol + assert conn_proto is not None + transport = conn.transport + assert transport is not None + reader: FlowControlDataQueue[WSMessage] = FlowControlDataQueue( + conn_proto, 2**16, loop=self._loop + ) + conn_proto.set_parser(WebSocketReader(reader, max_msg_size), reader) + writer = WebSocketWriter( + conn_proto, + transport, + use_mask=True, + compress=compress, + notakeover=notakeover, + ) + except BaseException: + resp.close() + raise + else: + return self._ws_response_class( + reader, + writer, + protocol, + resp, + timeout, + autoclose, + autoping, + self._loop, + receive_timeout=receive_timeout, + heartbeat=heartbeat, + compress=compress, + client_notakeover=notakeover, + ) + + def _prepare_headers(self, headers: Optional[LooseHeaders]) -> "CIMultiDict[str]": + """Add default headers and transform it to CIMultiDict""" + # Convert headers to MultiDict + result = CIMultiDict(self._default_headers) + if headers: + if not isinstance(headers, (MultiDictProxy, MultiDict)): + headers = CIMultiDict(headers) + added_names: Set[str] = set() + for key, value in headers.items(): + if key in added_names: + result.add(key, value) + else: + result[key] = value + added_names.add(key) + return result + + def get( + self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP GET request.""" + return _RequestContextManager( + self._request(hdrs.METH_GET, url, allow_redirects=allow_redirects, **kwargs) + ) + + def options( + self, url: StrOrURL, *, allow_redirects: bool = True, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP OPTIONS request.""" + return _RequestContextManager( + self._request( + hdrs.METH_OPTIONS, url, allow_redirects=allow_redirects, **kwargs + ) + ) + + def head( + self, url: StrOrURL, *, allow_redirects: bool = False, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP 
HEAD request.""" + return _RequestContextManager( + self._request( + hdrs.METH_HEAD, url, allow_redirects=allow_redirects, **kwargs + ) + ) + + def post( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP POST request.""" + return _RequestContextManager( + self._request(hdrs.METH_POST, url, data=data, **kwargs) + ) + + def put( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP PUT request.""" + return _RequestContextManager( + self._request(hdrs.METH_PUT, url, data=data, **kwargs) + ) + + def patch( + self, url: StrOrURL, *, data: Any = None, **kwargs: Any + ) -> "_RequestContextManager": + """Perform HTTP PATCH request.""" + return _RequestContextManager( + self._request(hdrs.METH_PATCH, url, data=data, **kwargs) + ) + + def delete(self, url: StrOrURL, **kwargs: Any) -> "_RequestContextManager": + """Perform HTTP DELETE request.""" + return _RequestContextManager(self._request(hdrs.METH_DELETE, url, **kwargs)) + + async def close(self) -> None: + """Close underlying connector. + + Release all acquired resources. + """ + if not self.closed: + if self._connector is not None and self._connector_owner: + await self._connector.close() + self._connector = None + + @property + def closed(self) -> bool: + """Is client session closed. + + A readonly property. + """ + return self._connector is None or self._connector.closed + + @property + def connector(self) -> Optional[BaseConnector]: + """Connector instance used for the session.""" + return self._connector + + @property + def cookie_jar(self) -> AbstractCookieJar: + """The session cookies.""" + return self._cookie_jar + + @property + def version(self) -> Tuple[int, int]: + """The session HTTP protocol version.""" + return self._version + + @property + def requote_redirect_url(self) -> bool: + """Do URL requoting on redirection handling.""" + return self._requote_redirect_url + + @requote_redirect_url.setter + def requote_redirect_url(self, val: bool) -> None: + """Do URL requoting on redirection handling.""" + warnings.warn( + "session.requote_redirect_url modification " "is deprecated #2778", + DeprecationWarning, + stacklevel=2, + ) + self._requote_redirect_url = val + + @property + def loop(self) -> asyncio.AbstractEventLoop: + """Session's loop.""" + warnings.warn( + "client.loop property is deprecated", DeprecationWarning, stacklevel=2 + ) + return self._loop + + @property + def timeout(self) -> ClientTimeout: + """Timeout for the session.""" + return self._timeout + + @property + def headers(self) -> "CIMultiDict[str]": + """The default headers of the client session.""" + return self._default_headers + + @property + def skip_auto_headers(self) -> FrozenSet[istr]: + """Headers for which autogeneration should be skipped""" + return self._skip_auto_headers + + @property + def auth(self) -> Optional[BasicAuth]: + """An object that represents HTTP Basic Authorization""" + return self._default_auth + + @property + def json_serialize(self) -> JSONEncoder: + """Json serializer callable""" + return self._json_serialize + + @property + def connector_owner(self) -> bool: + """Should connector be closed on session closing""" + return self._connector_owner + + @property + def raise_for_status( + self, + ) -> Union[bool, Callable[[ClientResponse], Awaitable[None]]]: + """Should `ClientResponse.raise_for_status()` be called for each response.""" + return self._raise_for_status + + @property + def auto_decompress(self) -> bool: + """Should 
the body response be automatically decompressed.""" + return self._auto_decompress + + @property + def trust_env(self) -> bool: + """ + Should proxies information from environment or netrc be trusted. + + Information is from HTTP_PROXY / HTTPS_PROXY environment variables + or ~/.netrc file if present. + """ + return self._trust_env + + @property + def trace_configs(self) -> List[TraceConfig]: + """A list of TraceConfig instances used for client tracing""" + return self._trace_configs + + def detach(self) -> None: + """Detach connector from session without closing the former. + + Session is switched to closed state anyway. + """ + self._connector = None + + def __enter__(self) -> None: + raise TypeError("Use async with instead") + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + # __exit__ should exist in pair with __enter__ but never executed + pass # pragma: no cover + + async def __aenter__(self) -> "ClientSession": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await self.close() + + +class _BaseRequestContextManager(Coroutine[Any, Any, _RetType], Generic[_RetType]): + + __slots__ = ("_coro", "_resp") + + def __init__(self, coro: Coroutine["asyncio.Future[Any]", None, _RetType]) -> None: + self._coro = coro + + def send(self, arg: None) -> "asyncio.Future[Any]": + return self._coro.send(arg) + + def throw(self, *args: Any, **kwargs: Any) -> "asyncio.Future[Any]": + return self._coro.throw(*args, **kwargs) + + def close(self) -> None: + return self._coro.close() + + def __await__(self) -> Generator[Any, None, _RetType]: + ret = self._coro.__await__() + return ret + + def __iter__(self) -> Generator[Any, None, _RetType]: + return self.__await__() + + async def __aenter__(self) -> _RetType: + self._resp = await self._coro + return self._resp + + +class _RequestContextManager(_BaseRequestContextManager[ClientResponse]): + __slots__ = () + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + # We're basing behavior on the exception as it can be caused by + # user code unrelated to the status of the connection. If you + # would like to close a connection you must do that + # explicitly. Otherwise connection error handling should kick in + # and close/recycle the connection as required. 
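From user code, the release-on-exit behaviour described in this comment is what lets `async with session.get(...)` hand connections back to the pool rather than tearing them down. A minimal illustrative sketch, not part of the vendored file, assuming the placeholder URL responds normally:

import asyncio
import aiohttp

async def fetch_twice(url: str) -> None:
    async with aiohttp.ClientSession() as session:
        # Leaving each block releases the response; with the body fully read,
        # the keep-alive connection can typically be reused by the next request.
        async with session.get(url) as resp:
            await resp.read()
        async with session.get(url) as resp:
            await resp.read()

asyncio.run(fetch_twice("https://example.com"))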
+ self._resp.release() + await self._resp.wait_for_close() + + +class _WSRequestContextManager(_BaseRequestContextManager[ClientWebSocketResponse]): + __slots__ = () + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + await self._resp.close() + + +class _SessionRequestContextManager: + + __slots__ = ("_coro", "_resp", "_session") + + def __init__( + self, + coro: Coroutine["asyncio.Future[Any]", None, ClientResponse], + session: ClientSession, + ) -> None: + self._coro = coro + self._resp: Optional[ClientResponse] = None + self._session = session + + async def __aenter__(self) -> ClientResponse: + try: + self._resp = await self._coro + except BaseException: + await self._session.close() + raise + else: + return self._resp + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + assert self._resp is not None + self._resp.close() + await self._session.close() + + +def request( + method: str, + url: StrOrURL, + *, + params: Optional[Mapping[str, str]] = None, + data: Any = None, + json: Any = None, + headers: Optional[LooseHeaders] = None, + skip_auto_headers: Optional[Iterable[str]] = None, + auth: Optional[BasicAuth] = None, + allow_redirects: bool = True, + max_redirects: int = 10, + compress: Optional[str] = None, + chunked: Optional[bool] = None, + expect100: bool = False, + raise_for_status: Optional[bool] = None, + read_until_eof: bool = True, + proxy: Optional[StrOrURL] = None, + proxy_auth: Optional[BasicAuth] = None, + timeout: Union[ClientTimeout, object] = sentinel, + cookies: Optional[LooseCookies] = None, + version: HttpVersion = http.HttpVersion11, + connector: Optional[BaseConnector] = None, + read_bufsize: Optional[int] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + max_line_size: int = 8190, + max_field_size: int = 8190, +) -> _SessionRequestContextManager: + """Constructs and sends a request. + + Returns response object. + method - HTTP method + url - request url + params - (optional) Dictionary or bytes to be sent in the query + string of the new request + data - (optional) Dictionary, bytes, or file-like object to + send in the body of the request + json - (optional) Any json compatible python object + headers - (optional) Dictionary of HTTP Headers to send with + the request + cookies - (optional) Dict object to send with the request + auth - (optional) BasicAuth named tuple represent HTTP Basic Auth + auth - aiohttp.helpers.BasicAuth + allow_redirects - (optional) If set to False, do not follow + redirects + version - Request HTTP version. + compress - Set to True if request has to be compressed + with deflate encoding. + chunked - Set to chunk size for chunked transfer encoding. + expect100 - Expect 100-continue response from server. + connector - BaseConnector sub-class instance to support + connection pooling. + read_until_eof - Read response until eof if response + does not have Content-Length header. + loop - Optional event loop. + timeout - Optional ClientTimeout settings structure, 5min + total timeout by default. 
+ Usage:: + >>> import aiohttp + >>> resp = await aiohttp.request('GET', 'http://python.org/') + >>> resp + + >>> data = await resp.read() + """ + connector_owner = False + if connector is None: + connector_owner = True + connector = TCPConnector(loop=loop, force_close=True) + + session = ClientSession( + loop=loop, + cookies=cookies, + version=version, + timeout=timeout, + connector=connector, + connector_owner=connector_owner, + ) + + return _SessionRequestContextManager( + session._request( + method, + url, + params=params, + data=data, + json=json, + headers=headers, + skip_auto_headers=skip_auto_headers, + auth=auth, + allow_redirects=allow_redirects, + max_redirects=max_redirects, + compress=compress, + chunked=chunked, + expect100=expect100, + raise_for_status=raise_for_status, + read_until_eof=read_until_eof, + proxy=proxy, + proxy_auth=proxy_auth, + read_bufsize=read_bufsize, + max_line_size=max_line_size, + max_field_size=max_field_size, + ), + session, + ) diff --git a/llm/Lib/site-packages/aiohttp/client_exceptions.py b/llm/Lib/site-packages/aiohttp/client_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..8af265b59a63e7c7ac03849ceabdc04b1e2b03cb --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/client_exceptions.py @@ -0,0 +1,346 @@ +"""HTTP related errors.""" + +import asyncio +import warnings +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union + +from .http_parser import RawResponseMessage +from .typedefs import LooseHeaders + +try: + import ssl + + SSLContext = ssl.SSLContext +except ImportError: # pragma: no cover + ssl = SSLContext = None # type: ignore[assignment] + + +if TYPE_CHECKING: + from .client_reqrep import ClientResponse, ConnectionKey, Fingerprint, RequestInfo +else: + RequestInfo = ClientResponse = ConnectionKey = None + +__all__ = ( + "ClientError", + "ClientConnectionError", + "ClientOSError", + "ClientConnectorError", + "ClientProxyConnectionError", + "ClientSSLError", + "ClientConnectorSSLError", + "ClientConnectorCertificateError", + "ServerConnectionError", + "ServerTimeoutError", + "ServerDisconnectedError", + "ServerFingerprintMismatch", + "ClientResponseError", + "ClientHttpProxyError", + "WSServerHandshakeError", + "ContentTypeError", + "ClientPayloadError", + "InvalidURL", +) + + +class ClientError(Exception): + """Base class for client connection errors.""" + + +class ClientResponseError(ClientError): + """Base class for exceptions that occur after getting a response. + + request_info: An instance of RequestInfo. + history: A sequence of responses, if redirects occurred. + status: HTTP status code. + message: Error message. + headers: Response headers. 
+ """ + + def __init__( + self, + request_info: RequestInfo, + history: Tuple[ClientResponse, ...], + *, + code: Optional[int] = None, + status: Optional[int] = None, + message: str = "", + headers: Optional[LooseHeaders] = None, + ) -> None: + self.request_info = request_info + if code is not None: + if status is not None: + raise ValueError( + "Both code and status arguments are provided; " + "code is deprecated, use status instead" + ) + warnings.warn( + "code argument is deprecated, use status instead", + DeprecationWarning, + stacklevel=2, + ) + if status is not None: + self.status = status + elif code is not None: + self.status = code + else: + self.status = 0 + self.message = message + self.headers = headers + self.history = history + self.args = (request_info, history) + + def __str__(self) -> str: + return "{}, message={!r}, url={!r}".format( + self.status, + self.message, + self.request_info.real_url, + ) + + def __repr__(self) -> str: + args = f"{self.request_info!r}, {self.history!r}" + if self.status != 0: + args += f", status={self.status!r}" + if self.message != "": + args += f", message={self.message!r}" + if self.headers is not None: + args += f", headers={self.headers!r}" + return f"{type(self).__name__}({args})" + + @property + def code(self) -> int: + warnings.warn( + "code property is deprecated, use status instead", + DeprecationWarning, + stacklevel=2, + ) + return self.status + + @code.setter + def code(self, value: int) -> None: + warnings.warn( + "code property is deprecated, use status instead", + DeprecationWarning, + stacklevel=2, + ) + self.status = value + + +class ContentTypeError(ClientResponseError): + """ContentType found is not valid.""" + + +class WSServerHandshakeError(ClientResponseError): + """websocket server handshake error.""" + + +class ClientHttpProxyError(ClientResponseError): + """HTTP proxy error. + + Raised in :class:`aiohttp.connector.TCPConnector` if + proxy responds with status other than ``200 OK`` + on ``CONNECT`` request. + """ + + +class TooManyRedirects(ClientResponseError): + """Client was redirected too many times.""" + + +class ClientConnectionError(ClientError): + """Base class for client socket errors.""" + + +class ClientOSError(ClientConnectionError, OSError): + """OSError error.""" + + +class ClientConnectorError(ClientOSError): + """Client connector error. + + Raised in :class:`aiohttp.connector.TCPConnector` if + a connection can not be established. + """ + + def __init__(self, connection_key: ConnectionKey, os_error: OSError) -> None: + self._conn_key = connection_key + self._os_error = os_error + super().__init__(os_error.errno, os_error.strerror) + self.args = (connection_key, os_error) + + @property + def os_error(self) -> OSError: + return self._os_error + + @property + def host(self) -> str: + return self._conn_key.host + + @property + def port(self) -> Optional[int]: + return self._conn_key.port + + @property + def ssl(self) -> Union[SSLContext, bool, "Fingerprint"]: + return self._conn_key.ssl + + def __str__(self) -> str: + return "Cannot connect to host {0.host}:{0.port} ssl:{1} [{2}]".format( + self, "default" if self.ssl is True else self.ssl, self.strerror + ) + + # OSError.__reduce__ does too much black magick + __reduce__ = BaseException.__reduce__ + + +class ClientProxyConnectionError(ClientConnectorError): + """Proxy connection error. + + Raised in :class:`aiohttp.connector.TCPConnector` if + connection to proxy can not be established. 
+ """ + + +class UnixClientConnectorError(ClientConnectorError): + """Unix connector error. + + Raised in :py:class:`aiohttp.connector.UnixConnector` + if connection to unix socket can not be established. + """ + + def __init__( + self, path: str, connection_key: ConnectionKey, os_error: OSError + ) -> None: + self._path = path + super().__init__(connection_key, os_error) + + @property + def path(self) -> str: + return self._path + + def __str__(self) -> str: + return "Cannot connect to unix socket {0.path} ssl:{1} [{2}]".format( + self, "default" if self.ssl is True else self.ssl, self.strerror + ) + + +class ServerConnectionError(ClientConnectionError): + """Server connection errors.""" + + +class ServerDisconnectedError(ServerConnectionError): + """Server disconnected.""" + + def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None: + if message is None: + message = "Server disconnected" + + self.args = (message,) + self.message = message + + +class ServerTimeoutError(ServerConnectionError, asyncio.TimeoutError): + """Server timeout error.""" + + +class ServerFingerprintMismatch(ServerConnectionError): + """SSL certificate does not match expected fingerprint.""" + + def __init__(self, expected: bytes, got: bytes, host: str, port: int) -> None: + self.expected = expected + self.got = got + self.host = host + self.port = port + self.args = (expected, got, host, port) + + def __repr__(self) -> str: + return "<{} expected={!r} got={!r} host={!r} port={!r}>".format( + self.__class__.__name__, self.expected, self.got, self.host, self.port + ) + + +class ClientPayloadError(ClientError): + """Response payload error.""" + + +class InvalidURL(ClientError, ValueError): + """Invalid URL. + + URL used for fetching is malformed, e.g. it doesn't contains host + part. 
+ """ + + # Derive from ValueError for backward compatibility + + def __init__(self, url: Any) -> None: + # The type of url is not yarl.URL because the exception can be raised + # on URL(url) call + super().__init__(url) + + @property + def url(self) -> Any: + return self.args[0] + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.url}>" + + +class ClientSSLError(ClientConnectorError): + """Base error for ssl.*Errors.""" + + +if ssl is not None: + cert_errors = (ssl.CertificateError,) + cert_errors_bases = ( + ClientSSLError, + ssl.CertificateError, + ) + + ssl_errors = (ssl.SSLError,) + ssl_error_bases = (ClientSSLError, ssl.SSLError) +else: # pragma: no cover + cert_errors = tuple() + cert_errors_bases = ( + ClientSSLError, + ValueError, + ) + + ssl_errors = tuple() + ssl_error_bases = (ClientSSLError,) + + +class ClientConnectorSSLError(*ssl_error_bases): # type: ignore[misc] + """Response ssl error.""" + + +class ClientConnectorCertificateError(*cert_errors_bases): # type: ignore[misc] + """Response certificate error.""" + + def __init__( + self, connection_key: ConnectionKey, certificate_error: Exception + ) -> None: + self._conn_key = connection_key + self._certificate_error = certificate_error + self.args = (connection_key, certificate_error) + + @property + def certificate_error(self) -> Exception: + return self._certificate_error + + @property + def host(self) -> str: + return self._conn_key.host + + @property + def port(self) -> Optional[int]: + return self._conn_key.port + + @property + def ssl(self) -> bool: + return self._conn_key.is_ssl + + def __str__(self) -> str: + return ( + "Cannot connect to host {0.host}:{0.port} ssl:{0.ssl} " + "[{0.certificate_error.__class__.__name__}: " + "{0.certificate_error.args}]".format(self) + ) diff --git a/llm/Lib/site-packages/aiohttp/client_proto.py b/llm/Lib/site-packages/aiohttp/client_proto.py new file mode 100644 index 0000000000000000000000000000000000000000..59cf155ccba4b65dcaeff3eacc47b1888b835ce2 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/client_proto.py @@ -0,0 +1,296 @@ +import asyncio +from contextlib import suppress +from typing import Any, Optional, Tuple + +from .base_protocol import BaseProtocol +from .client_exceptions import ( + ClientOSError, + ClientPayloadError, + ServerDisconnectedError, + ServerTimeoutError, +) +from .helpers import ( + _EXC_SENTINEL, + BaseTimerContext, + set_exception, + status_code_must_be_empty_body, +) +from .http import HttpResponseParser, RawResponseMessage +from .http_exceptions import HttpProcessingError +from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader + + +class ResponseHandler(BaseProtocol, DataQueue[Tuple[RawResponseMessage, StreamReader]]): + """Helper class to adapt between Protocol and StreamReader.""" + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + BaseProtocol.__init__(self, loop=loop) + DataQueue.__init__(self, loop) + + self._should_close = False + + self._payload: Optional[StreamReader] = None + self._skip_payload = False + self._payload_parser = None + + self._timer = None + + self._tail = b"" + self._upgraded = False + self._parser: Optional[HttpResponseParser] = None + + self._read_timeout: Optional[float] = None + self._read_timeout_handle: Optional[asyncio.TimerHandle] = None + + self._timeout_ceil_threshold: Optional[float] = 5 + + @property + def upgraded(self) -> bool: + return self._upgraded + + @property + def should_close(self) -> bool: + if self._payload is not None and not self._payload.is_eof() or 
self._upgraded: + return True + + return ( + self._should_close + or self._upgraded + or self.exception() is not None + or self._payload_parser is not None + or len(self) > 0 + or bool(self._tail) + ) + + def force_close(self) -> None: + self._should_close = True + + def close(self) -> None: + transport = self.transport + if transport is not None: + transport.close() + self.transport = None + self._payload = None + self._drop_timeout() + + def is_connected(self) -> bool: + return self.transport is not None and not self.transport.is_closing() + + def connection_lost(self, exc: Optional[BaseException]) -> None: + self._drop_timeout() + + original_connection_error = exc + reraised_exc = original_connection_error + + connection_closed_cleanly = original_connection_error is None + + if self._payload_parser is not None: + with suppress(Exception): # FIXME: log this somehow? + self._payload_parser.feed_eof() + + uncompleted = None + if self._parser is not None: + try: + uncompleted = self._parser.feed_eof() + except Exception as underlying_exc: + if self._payload is not None: + client_payload_exc_msg = ( + f"Response payload is not completed: {underlying_exc !r}" + ) + if not connection_closed_cleanly: + client_payload_exc_msg = ( + f"{client_payload_exc_msg !s}. " + f"{original_connection_error !r}" + ) + set_exception( + self._payload, + ClientPayloadError(client_payload_exc_msg), + underlying_exc, + ) + + if not self.is_eof(): + if isinstance(original_connection_error, OSError): + reraised_exc = ClientOSError(*original_connection_error.args) + if connection_closed_cleanly: + reraised_exc = ServerDisconnectedError(uncompleted) + # assigns self._should_close to True as side effect, + # we do it anyway below + underlying_non_eof_exc = ( + _EXC_SENTINEL + if connection_closed_cleanly + else original_connection_error + ) + assert underlying_non_eof_exc is not None + assert reraised_exc is not None + self.set_exception(reraised_exc, underlying_non_eof_exc) + + self._should_close = True + self._parser = None + self._payload = None + self._payload_parser = None + self._reading_paused = False + + super().connection_lost(reraised_exc) + + def eof_received(self) -> None: + # should call parser.feed_eof() most likely + self._drop_timeout() + + def pause_reading(self) -> None: + super().pause_reading() + self._drop_timeout() + + def resume_reading(self) -> None: + super().resume_reading() + self._reschedule_timeout() + + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: + self._should_close = True + self._drop_timeout() + super().set_exception(exc, exc_cause) + + def set_parser(self, parser: Any, payload: Any) -> None: + # TODO: actual types are: + # parser: WebSocketReader + # payload: FlowControlDataQueue + # but they are not generi enough + # Need an ABC for both types + self._payload = payload + self._payload_parser = parser + + self._drop_timeout() + + if self._tail: + data, self._tail = self._tail, b"" + self.data_received(data) + + def set_response_params( + self, + *, + timer: Optional[BaseTimerContext] = None, + skip_payload: bool = False, + read_until_eof: bool = False, + auto_decompress: bool = True, + read_timeout: Optional[float] = None, + read_bufsize: int = 2**16, + timeout_ceil_threshold: float = 5, + max_line_size: int = 8190, + max_field_size: int = 8190, + ) -> None: + self._skip_payload = skip_payload + + self._read_timeout = read_timeout + + self._timeout_ceil_threshold = timeout_ceil_threshold + + self._parser = 
HttpResponseParser( + self, + self._loop, + read_bufsize, + timer=timer, + payload_exception=ClientPayloadError, + response_with_body=not skip_payload, + read_until_eof=read_until_eof, + auto_decompress=auto_decompress, + max_line_size=max_line_size, + max_field_size=max_field_size, + ) + + if self._tail: + data, self._tail = self._tail, b"" + self.data_received(data) + + def _drop_timeout(self) -> None: + if self._read_timeout_handle is not None: + self._read_timeout_handle.cancel() + self._read_timeout_handle = None + + def _reschedule_timeout(self) -> None: + timeout = self._read_timeout + if self._read_timeout_handle is not None: + self._read_timeout_handle.cancel() + + if timeout: + self._read_timeout_handle = self._loop.call_later( + timeout, self._on_read_timeout + ) + else: + self._read_timeout_handle = None + + def start_timeout(self) -> None: + self._reschedule_timeout() + + def _on_read_timeout(self) -> None: + exc = ServerTimeoutError("Timeout on reading data from socket") + self.set_exception(exc) + if self._payload is not None: + set_exception(self._payload, exc) + + def data_received(self, data: bytes) -> None: + self._reschedule_timeout() + + if not data: + return + + # custom payload parser + if self._payload_parser is not None: + eof, tail = self._payload_parser.feed_data(data) + if eof: + self._payload = None + self._payload_parser = None + + if tail: + self.data_received(tail) + return + else: + if self._upgraded or self._parser is None: + # i.e. websocket connection, websocket parser is not set yet + self._tail += data + else: + # parse http messages + try: + messages, upgraded, tail = self._parser.feed_data(data) + except BaseException as underlying_exc: + if self.transport is not None: + # connection.release() could be called BEFORE + # data_received(), the transport is already + # closed in this case + self.transport.close() + # should_close is True after the call + self.set_exception(HttpProcessingError(), underlying_exc) + return + + self._upgraded = upgraded + + payload: Optional[StreamReader] = None + for message, payload in messages: + if message.should_close: + self._should_close = True + + self._payload = payload + + if self._skip_payload or status_code_must_be_empty_body( + message.code + ): + self.feed_data((message, EMPTY_PAYLOAD), 0) + else: + self.feed_data((message, payload), 0) + if payload is not None: + # new message(s) was processed + # register timeout handler unsubscribing + # either on end-of-stream or immediately for + # EMPTY_PAYLOAD + if payload is not EMPTY_PAYLOAD: + payload.on_eof(self._drop_timeout) + else: + self._drop_timeout() + + if tail: + if upgraded: + self.data_received(tail) + else: + self._tail = tail diff --git a/llm/Lib/site-packages/aiohttp/client_reqrep.py b/llm/Lib/site-packages/aiohttp/client_reqrep.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb7e51b7a886c5435e79ed8cd5ad777b91aebad --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/client_reqrep.py @@ -0,0 +1,1207 @@ +import asyncio +import codecs +import contextlib +import functools +import io +import re +import sys +import traceback +import warnings +from hashlib import md5, sha1, sha256 +from http.cookies import CookieError, Morsel, SimpleCookie +from types import MappingProxyType, TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + Type, + Union, + cast, +) + +import attr +from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy +from yarl 
import URL + +from . import hdrs, helpers, http, multipart, payload +from .abc import AbstractStreamWriter +from .client_exceptions import ( + ClientConnectionError, + ClientOSError, + ClientResponseError, + ContentTypeError, + InvalidURL, + ServerFingerprintMismatch, +) +from .compression_utils import HAS_BROTLI +from .formdata import FormData +from .helpers import ( + BaseTimerContext, + BasicAuth, + HeadersMixin, + TimerNoop, + basicauth_from_netrc, + netrc_from_env, + noop, + reify, + set_exception, + set_result, +) +from .http import ( + SERVER_SOFTWARE, + HttpVersion, + HttpVersion10, + HttpVersion11, + StreamWriter, +) +from .log import client_logger +from .streams import StreamReader +from .typedefs import ( + DEFAULT_JSON_DECODER, + JSONDecoder, + LooseCookies, + LooseHeaders, + RawHeaders, +) + +try: + import ssl + from ssl import SSLContext +except ImportError: # pragma: no cover + ssl = None # type: ignore[assignment] + SSLContext = object # type: ignore[misc,assignment] + + +__all__ = ("ClientRequest", "ClientResponse", "RequestInfo", "Fingerprint") + + +if TYPE_CHECKING: + from .client import ClientSession + from .connector import Connection + from .tracing import Trace + + +_CONTAINS_CONTROL_CHAR_RE = re.compile(r"[^-!#$%&'*+.^_`|~0-9a-zA-Z]") +json_re = re.compile(r"^application/(?:[\w.+-]+?\+)?json") + + +def _gen_default_accept_encoding() -> str: + return "gzip, deflate, br" if HAS_BROTLI else "gzip, deflate" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ContentDisposition: + type: Optional[str] + parameters: "MappingProxyType[str, str]" + filename: Optional[str] + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class RequestInfo: + url: URL + method: str + headers: "CIMultiDictProxy[str]" + real_url: URL = attr.ib() + + @real_url.default + def real_url_default(self) -> URL: + return self.url + + +class Fingerprint: + HASHFUNC_BY_DIGESTLEN = { + 16: md5, + 20: sha1, + 32: sha256, + } + + def __init__(self, fingerprint: bytes) -> None: + digestlen = len(fingerprint) + hashfunc = self.HASHFUNC_BY_DIGESTLEN.get(digestlen) + if not hashfunc: + raise ValueError("fingerprint has invalid length") + elif hashfunc is md5 or hashfunc is sha1: + raise ValueError( + "md5 and sha1 are insecure and " "not supported. Use sha256." 
+ ) + self._hashfunc = hashfunc + self._fingerprint = fingerprint + + @property + def fingerprint(self) -> bytes: + return self._fingerprint + + def check(self, transport: asyncio.Transport) -> None: + if not transport.get_extra_info("sslcontext"): + return + sslobj = transport.get_extra_info("ssl_object") + cert = sslobj.getpeercert(binary_form=True) + got = self._hashfunc(cert).digest() + if got != self._fingerprint: + host, port, *_ = transport.get_extra_info("peername") + raise ServerFingerprintMismatch(self._fingerprint, got, host, port) + + +if ssl is not None: + SSL_ALLOWED_TYPES = (ssl.SSLContext, bool, Fingerprint, type(None)) +else: # pragma: no cover + SSL_ALLOWED_TYPES = (bool, type(None)) + + +def _merge_ssl_params( + ssl: Union["SSLContext", bool, Fingerprint], + verify_ssl: Optional[bool], + ssl_context: Optional["SSLContext"], + fingerprint: Optional[bytes], +) -> Union["SSLContext", bool, Fingerprint]: + if ssl is None: + ssl = True # Double check for backwards compatibility + if verify_ssl is not None and not verify_ssl: + warnings.warn( + "verify_ssl is deprecated, use ssl=False instead", + DeprecationWarning, + stacklevel=3, + ) + if ssl is not True: + raise ValueError( + "verify_ssl, ssl_context, fingerprint and ssl " + "parameters are mutually exclusive" + ) + else: + ssl = False + if ssl_context is not None: + warnings.warn( + "ssl_context is deprecated, use ssl=context instead", + DeprecationWarning, + stacklevel=3, + ) + if ssl is not True: + raise ValueError( + "verify_ssl, ssl_context, fingerprint and ssl " + "parameters are mutually exclusive" + ) + else: + ssl = ssl_context + if fingerprint is not None: + warnings.warn( + "fingerprint is deprecated, " "use ssl=Fingerprint(fingerprint) instead", + DeprecationWarning, + stacklevel=3, + ) + if ssl is not True: + raise ValueError( + "verify_ssl, ssl_context, fingerprint and ssl " + "parameters are mutually exclusive" + ) + else: + ssl = Fingerprint(fingerprint) + if not isinstance(ssl, SSL_ALLOWED_TYPES): + raise TypeError( + "ssl should be SSLContext, bool, Fingerprint or None, " + "got {!r} instead.".format(ssl) + ) + return ssl + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class ConnectionKey: + # the key should contain an information about used proxy / TLS + # to prevent reusing wrong connections from a pool + host: str + port: Optional[int] + is_ssl: bool + ssl: Union[SSLContext, bool, Fingerprint] + proxy: Optional[URL] + proxy_auth: Optional[BasicAuth] + proxy_headers_hash: Optional[int] # hash(CIMultiDict) + + +def _is_expected_content_type( + response_content_type: str, expected_content_type: str +) -> bool: + if expected_content_type == "application/json": + return json_re.match(response_content_type) is not None + return expected_content_type in response_content_type + + +class ClientRequest: + GET_METHODS = { + hdrs.METH_GET, + hdrs.METH_HEAD, + hdrs.METH_OPTIONS, + hdrs.METH_TRACE, + } + POST_METHODS = {hdrs.METH_PATCH, hdrs.METH_POST, hdrs.METH_PUT} + ALL_METHODS = GET_METHODS.union(POST_METHODS).union({hdrs.METH_DELETE}) + + DEFAULT_HEADERS = { + hdrs.ACCEPT: "*/*", + hdrs.ACCEPT_ENCODING: _gen_default_accept_encoding(), + } + + body = b"" + auth = None + response = None + + __writer = None # async task for streaming data + _continue = None # waiter future for '100 Continue' response + + # N.B. + # Adding __del__ method with self._writer closing doesn't make sense + # because _writer is instance method, thus it keeps a reference to self. 
+ # Until writer has finished finalizer will not be called. + + def __init__( + self, + method: str, + url: URL, + *, + params: Optional[Mapping[str, str]] = None, + headers: Optional[LooseHeaders] = None, + skip_auto_headers: Iterable[str] = frozenset(), + data: Any = None, + cookies: Optional[LooseCookies] = None, + auth: Optional[BasicAuth] = None, + version: http.HttpVersion = http.HttpVersion11, + compress: Optional[str] = None, + chunked: Optional[bool] = None, + expect100: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + response_class: Optional[Type["ClientResponse"]] = None, + proxy: Optional[URL] = None, + proxy_auth: Optional[BasicAuth] = None, + timer: Optional[BaseTimerContext] = None, + session: Optional["ClientSession"] = None, + ssl: Union[SSLContext, bool, Fingerprint] = True, + proxy_headers: Optional[LooseHeaders] = None, + traces: Optional[List["Trace"]] = None, + trust_env: bool = False, + server_hostname: Optional[str] = None, + ): + if loop is None: + loop = asyncio.get_event_loop() + + match = _CONTAINS_CONTROL_CHAR_RE.search(method) + if match: + raise ValueError( + f"Method cannot contain non-token characters {method!r} " + "(found at least {match.group()!r})" + ) + + assert isinstance(url, URL), url + assert isinstance(proxy, (URL, type(None))), proxy + # FIXME: session is None in tests only, need to fix tests + # assert session is not None + self._session = cast("ClientSession", session) + if params: + q = MultiDict(url.query) + url2 = url.with_query(params) + q.extend(url2.query) + url = url.with_query(q) + self.original_url = url + self.url = url.with_fragment(None) + self.method = method.upper() + self.chunked = chunked + self.compress = compress + self.loop = loop + self.length = None + if response_class is None: + real_response_class = ClientResponse + else: + real_response_class = response_class + self.response_class: Type[ClientResponse] = real_response_class + self._timer = timer if timer is not None else TimerNoop() + self._ssl = ssl if ssl is not None else True + self.server_hostname = server_hostname + + if loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + self.update_version(version) + self.update_host(url) + self.update_headers(headers) + self.update_auto_headers(skip_auto_headers) + self.update_cookies(cookies) + self.update_content_encoding(data) + self.update_auth(auth, trust_env) + self.update_proxy(proxy, proxy_auth, proxy_headers) + + self.update_body_from_data(data) + if data is not None or self.method not in self.GET_METHODS: + self.update_transfer_encoding() + self.update_expect_continue(expect100) + if traces is None: + traces = [] + self._traces = traces + + def __reset_writer(self, _: object = None) -> None: + self.__writer = None + + @property + def _writer(self) -> Optional["asyncio.Task[None]"]: + return self.__writer + + @_writer.setter + def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: + if self.__writer is not None: + self.__writer.remove_done_callback(self.__reset_writer) + self.__writer = writer + if writer is not None: + writer.add_done_callback(self.__reset_writer) + + def is_ssl(self) -> bool: + return self.url.scheme in ("https", "wss") + + @property + def ssl(self) -> Union["SSLContext", bool, Fingerprint]: + return self._ssl + + @property + def connection_key(self) -> ConnectionKey: + proxy_headers = self.proxy_headers + if proxy_headers: + h: Optional[int] = hash(tuple((k, v) for k, v in proxy_headers.items())) + else: + h = None + return 
ConnectionKey( + self.host, + self.port, + self.is_ssl(), + self.ssl, + self.proxy, + self.proxy_auth, + h, + ) + + @property + def host(self) -> str: + ret = self.url.raw_host + assert ret is not None + return ret + + @property + def port(self) -> Optional[int]: + return self.url.port + + @property + def request_info(self) -> RequestInfo: + headers: CIMultiDictProxy[str] = CIMultiDictProxy(self.headers) + return RequestInfo(self.url, self.method, headers, self.original_url) + + def update_host(self, url: URL) -> None: + """Update destination host, port and connection type (ssl).""" + # get host/port + if not url.raw_host: + raise InvalidURL(url) + + # basic auth info + username, password = url.user, url.password + if username: + self.auth = helpers.BasicAuth(username, password or "") + + def update_version(self, version: Union[http.HttpVersion, str]) -> None: + """Convert request version to two elements tuple. + + parser HTTP version '1.1' => (1, 1) + """ + if isinstance(version, str): + v = [part.strip() for part in version.split(".", 1)] + try: + version = http.HttpVersion(int(v[0]), int(v[1])) + except ValueError: + raise ValueError( + f"Can not parse http version number: {version}" + ) from None + self.version = version + + def update_headers(self, headers: Optional[LooseHeaders]) -> None: + """Update request headers.""" + self.headers: CIMultiDict[str] = CIMultiDict() + + # add host + netloc = cast(str, self.url.raw_host) + if helpers.is_ipv6_address(netloc): + netloc = f"[{netloc}]" + # See https://github.com/aio-libs/aiohttp/issues/3636. + netloc = netloc.rstrip(".") + if self.url.port is not None and not self.url.is_default_port(): + netloc += ":" + str(self.url.port) + self.headers[hdrs.HOST] = netloc + + if headers: + if isinstance(headers, (dict, MultiDictProxy, MultiDict)): + headers = headers.items() # type: ignore[assignment] + + for key, value in headers: # type: ignore[misc] + # A special case for Host header + if key.lower() == "host": + self.headers[key] = value + else: + self.headers.add(key, value) + + def update_auto_headers(self, skip_auto_headers: Iterable[str]) -> None: + self.skip_auto_headers = CIMultiDict( + (hdr, None) for hdr in sorted(skip_auto_headers) + ) + used_headers = self.headers.copy() + used_headers.extend(self.skip_auto_headers) # type: ignore[arg-type] + + for hdr, val in self.DEFAULT_HEADERS.items(): + if hdr not in used_headers: + self.headers.add(hdr, val) + + if hdrs.USER_AGENT not in used_headers: + self.headers[hdrs.USER_AGENT] = SERVER_SOFTWARE + + def update_cookies(self, cookies: Optional[LooseCookies]) -> None: + """Update request cookies header.""" + if not cookies: + return + + c = SimpleCookie() + if hdrs.COOKIE in self.headers: + c.load(self.headers.get(hdrs.COOKIE, "")) + del self.headers[hdrs.COOKIE] + + if isinstance(cookies, Mapping): + iter_cookies = cookies.items() + else: + iter_cookies = cookies # type: ignore[assignment] + for name, value in iter_cookies: + if isinstance(value, Morsel): + # Preserve coded_value + mrsl_val = value.get(value.key, Morsel()) + mrsl_val.set(value.key, value.value, value.coded_value) + c[name] = mrsl_val + else: + c[name] = value # type: ignore[assignment] + + self.headers[hdrs.COOKIE] = c.output(header="", sep=";").strip() + + def update_content_encoding(self, data: Any) -> None: + """Set request content encoding.""" + if data is None: + return + + enc = self.headers.get(hdrs.CONTENT_ENCODING, "").lower() + if enc: + if self.compress: + raise ValueError( + "compress can not be set " "if 
Content-Encoding header is set" + ) + elif self.compress: + if not isinstance(self.compress, str): + self.compress = "deflate" + self.headers[hdrs.CONTENT_ENCODING] = self.compress + self.chunked = True # enable chunked, no need to deal with length + + def update_transfer_encoding(self) -> None: + """Analyze transfer-encoding header.""" + te = self.headers.get(hdrs.TRANSFER_ENCODING, "").lower() + + if "chunked" in te: + if self.chunked: + raise ValueError( + "chunked can not be set " + 'if "Transfer-Encoding: chunked" header is set' + ) + + elif self.chunked: + if hdrs.CONTENT_LENGTH in self.headers: + raise ValueError( + "chunked can not be set " "if Content-Length header is set" + ) + + self.headers[hdrs.TRANSFER_ENCODING] = "chunked" + else: + if hdrs.CONTENT_LENGTH not in self.headers: + self.headers[hdrs.CONTENT_LENGTH] = str(len(self.body)) + + def update_auth(self, auth: Optional[BasicAuth], trust_env: bool = False) -> None: + """Set basic auth.""" + if auth is None: + auth = self.auth + if auth is None and trust_env and self.url.host is not None: + netrc_obj = netrc_from_env() + with contextlib.suppress(LookupError): + auth = basicauth_from_netrc(netrc_obj, self.url.host) + if auth is None: + return + + if not isinstance(auth, helpers.BasicAuth): + raise TypeError("BasicAuth() tuple is required instead") + + self.headers[hdrs.AUTHORIZATION] = auth.encode() + + def update_body_from_data(self, body: Any) -> None: + if body is None: + return + + # FormData + if isinstance(body, FormData): + body = body() + + try: + body = payload.PAYLOAD_REGISTRY.get(body, disposition=None) + except payload.LookupError: + body = FormData(body)() + + self.body = body + + # enable chunked encoding if needed + if not self.chunked: + if hdrs.CONTENT_LENGTH not in self.headers: + size = body.size + if size is None: + self.chunked = True + else: + if hdrs.CONTENT_LENGTH not in self.headers: + self.headers[hdrs.CONTENT_LENGTH] = str(size) + + # copy payload headers + assert body.headers + for (key, value) in body.headers.items(): + if key in self.headers: + continue + if key in self.skip_auto_headers: + continue + self.headers[key] = value + + def update_expect_continue(self, expect: bool = False) -> None: + if expect: + self.headers[hdrs.EXPECT] = "100-continue" + elif self.headers.get(hdrs.EXPECT, "").lower() == "100-continue": + expect = True + + if expect: + self._continue = self.loop.create_future() + + def update_proxy( + self, + proxy: Optional[URL], + proxy_auth: Optional[BasicAuth], + proxy_headers: Optional[LooseHeaders], + ) -> None: + if proxy_auth and not isinstance(proxy_auth, helpers.BasicAuth): + raise ValueError("proxy_auth must be None or BasicAuth() tuple") + self.proxy = proxy + self.proxy_auth = proxy_auth + self.proxy_headers = proxy_headers + + def keep_alive(self) -> bool: + if self.version < HttpVersion10: + # keep alive not supported at all + return False + if self.version == HttpVersion10: + if self.headers.get(hdrs.CONNECTION) == "keep-alive": + return True + else: # no headers means we close for Http 1.0 + return False + elif self.headers.get(hdrs.CONNECTION) == "close": + return False + + return True + + async def write_bytes( + self, writer: AbstractStreamWriter, conn: "Connection" + ) -> None: + """Support coroutines that yields bytes objects.""" + # 100 response + if self._continue is not None: + try: + await writer.drain() + await self._continue + except asyncio.CancelledError: + return + + protocol = conn.protocol + assert protocol is not None + try: + if 
isinstance(self.body, payload.Payload): + await self.body.write(writer) + else: + if isinstance(self.body, (bytes, bytearray)): + self.body = (self.body,) # type: ignore[assignment] + + for chunk in self.body: + await writer.write(chunk) # type: ignore[arg-type] + except OSError as underlying_exc: + reraised_exc = underlying_exc + + exc_is_not_timeout = underlying_exc.errno is not None or not isinstance( + underlying_exc, asyncio.TimeoutError + ) + if exc_is_not_timeout: + reraised_exc = ClientOSError( + underlying_exc.errno, + f"Can not write request body for {self.url !s}", + ) + + set_exception(protocol, reraised_exc, underlying_exc) + except asyncio.CancelledError: + await writer.write_eof() + except Exception as underlying_exc: + set_exception( + protocol, + ClientConnectionError( + f"Failed to send bytes into the underlying connection {conn !s}", + ), + underlying_exc, + ) + else: + await writer.write_eof() + protocol.start_timeout() + + async def send(self, conn: "Connection") -> "ClientResponse": + # Specify request target: + # - CONNECT request must send authority form URI + # - not CONNECT proxy must send absolute form URI + # - most common is origin form URI + if self.method == hdrs.METH_CONNECT: + connect_host = self.url.raw_host + assert connect_host is not None + if helpers.is_ipv6_address(connect_host): + connect_host = f"[{connect_host}]" + path = f"{connect_host}:{self.url.port}" + elif self.proxy and not self.is_ssl(): + path = str(self.url) + else: + path = self.url.raw_path + if self.url.raw_query_string: + path += "?" + self.url.raw_query_string + + protocol = conn.protocol + assert protocol is not None + writer = StreamWriter( + protocol, + self.loop, + on_chunk_sent=functools.partial( + self._on_chunk_request_sent, self.method, self.url + ), + on_headers_sent=functools.partial( + self._on_headers_request_sent, self.method, self.url + ), + ) + + if self.compress: + writer.enable_compression(self.compress) + + if self.chunked is not None: + writer.enable_chunking() + + # set default content-type + if ( + self.method in self.POST_METHODS + and hdrs.CONTENT_TYPE not in self.skip_auto_headers + and hdrs.CONTENT_TYPE not in self.headers + ): + self.headers[hdrs.CONTENT_TYPE] = "application/octet-stream" + + # set the connection header + connection = self.headers.get(hdrs.CONNECTION) + if not connection: + if self.keep_alive(): + if self.version == HttpVersion10: + connection = "keep-alive" + else: + if self.version == HttpVersion11: + connection = "close" + + if connection is not None: + self.headers[hdrs.CONNECTION] = connection + + # status + headers + status_line = "{0} {1} HTTP/{v.major}.{v.minor}".format( + self.method, path, v=self.version + ) + await writer.write_headers(status_line, self.headers) + + self._writer = self.loop.create_task(self.write_bytes(writer, conn)) + + response_class = self.response_class + assert response_class is not None + self.response = response_class( + self.method, + self.original_url, + writer=self._writer, + continue100=self._continue, + timer=self._timer, + request_info=self.request_info, + traces=self._traces, + loop=self.loop, + session=self._session, + ) + return self.response + + async def close(self) -> None: + if self._writer is not None: + with contextlib.suppress(asyncio.CancelledError): + await self._writer + + def terminate(self) -> None: + if self._writer is not None: + if not self.loop.is_closed(): + self._writer.cancel() + self._writer.remove_done_callback(self.__reset_writer) + self._writer = None + + async def 
_on_chunk_request_sent(self, method: str, url: URL, chunk: bytes) -> None: + for trace in self._traces: + await trace.send_request_chunk_sent(method, url, chunk) + + async def _on_headers_request_sent( + self, method: str, url: URL, headers: "CIMultiDict[str]" + ) -> None: + for trace in self._traces: + await trace.send_request_headers(method, url, headers) + + +class ClientResponse(HeadersMixin): + + # Some of these attributes are None when created, + # but will be set by the start() method. + # As the end user will likely never see the None values, we cheat the types below. + # from the Status-Line of the response + version: Optional[HttpVersion] = None # HTTP-Version + status: int = None # type: ignore[assignment] # Status-Code + reason: Optional[str] = None # Reason-Phrase + + content: StreamReader = None # type: ignore[assignment] # Payload stream + _headers: CIMultiDictProxy[str] = None # type: ignore[assignment] + _raw_headers: RawHeaders = None # type: ignore[assignment] + + _connection = None # current connection + _source_traceback: Optional[traceback.StackSummary] = None + # set up by ClientRequest after ClientResponse object creation + # post-init stage allows to not change ctor signature + _closed = True # to allow __del__ for non-initialized properly response + _released = False + __writer = None + + def __init__( + self, + method: str, + url: URL, + *, + writer: "asyncio.Task[None]", + continue100: Optional["asyncio.Future[bool]"], + timer: BaseTimerContext, + request_info: RequestInfo, + traces: List["Trace"], + loop: asyncio.AbstractEventLoop, + session: "ClientSession", + ) -> None: + assert isinstance(url, URL) + + self.method = method + self.cookies = SimpleCookie() + + self._real_url = url + self._url = url.with_fragment(None) + self._body: Any = None + self._writer: Optional[asyncio.Task[None]] = writer + self._continue = continue100 # None by default + self._closed = True + self._history: Tuple[ClientResponse, ...] = () + self._request_info = request_info + self._timer = timer if timer is not None else TimerNoop() + self._cache: Dict[str, Any] = {} + self._traces = traces + self._loop = loop + # store a reference to session #1985 + self._session: Optional[ClientSession] = session + # Save reference to _resolve_charset, so that get_encoding() will still + # work after the response has finished reading the body. + if session is None: + # TODO: Fix session=None in tests (see ClientRequest.__init__). 
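+            # With no session to supply a charset resolver, fall back to a resolver that always returns UTF-8.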
+ self._resolve_charset: Callable[ + ["ClientResponse", bytes], str + ] = lambda *_: "utf-8" + else: + self._resolve_charset = session._resolve_charset + if loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + def __reset_writer(self, _: object = None) -> None: + self.__writer = None + + @property + def _writer(self) -> Optional["asyncio.Task[None]"]: + return self.__writer + + @_writer.setter + def _writer(self, writer: Optional["asyncio.Task[None]"]) -> None: + if self.__writer is not None: + self.__writer.remove_done_callback(self.__reset_writer) + self.__writer = writer + if writer is not None: + writer.add_done_callback(self.__reset_writer) + + @reify + def url(self) -> URL: + return self._url + + @reify + def url_obj(self) -> URL: + warnings.warn("Deprecated, use .url #1654", DeprecationWarning, stacklevel=2) + return self._url + + @reify + def real_url(self) -> URL: + return self._real_url + + @reify + def host(self) -> str: + assert self._url.host is not None + return self._url.host + + @reify + def headers(self) -> "CIMultiDictProxy[str]": + return self._headers + + @reify + def raw_headers(self) -> RawHeaders: + return self._raw_headers + + @reify + def request_info(self) -> RequestInfo: + return self._request_info + + @reify + def content_disposition(self) -> Optional[ContentDisposition]: + raw = self._headers.get(hdrs.CONTENT_DISPOSITION) + if raw is None: + return None + disposition_type, params_dct = multipart.parse_content_disposition(raw) + params = MappingProxyType(params_dct) + filename = multipart.content_disposition_filename(params) + return ContentDisposition(disposition_type, params, filename) + + def __del__(self, _warnings: Any = warnings) -> None: + if self._closed: + return + + if self._connection is not None: + self._connection.release() + self._cleanup_writer() + + if self._loop.get_debug(): + kwargs = {"source": self} + _warnings.warn(f"Unclosed response {self!r}", ResourceWarning, **kwargs) + context = {"client_response": self, "message": "Unclosed response"} + if self._source_traceback: + context["source_traceback"] = self._source_traceback + self._loop.call_exception_handler(context) + + def __repr__(self) -> str: + out = io.StringIO() + ascii_encodable_url = str(self.url) + if self.reason: + ascii_encodable_reason = self.reason.encode( + "ascii", "backslashreplace" + ).decode("ascii") + else: + ascii_encodable_reason = "None" + print( + "".format( + ascii_encodable_url, self.status, ascii_encodable_reason + ), + file=out, + ) + print(self.headers, file=out) + return out.getvalue() + + @property + def connection(self) -> Optional["Connection"]: + return self._connection + + @reify + def history(self) -> Tuple["ClientResponse", ...]: + """A sequence of of responses, if redirects occurred.""" + return self._history + + @reify + def links(self) -> "MultiDictProxy[MultiDictProxy[Union[str, URL]]]": + links_str = ", ".join(self.headers.getall("link", [])) + + if not links_str: + return MultiDictProxy(MultiDict()) + + links: MultiDict[MultiDictProxy[Union[str, URL]]] = MultiDict() + + for val in re.split(r",(?=\s*<)", links_str): + match = re.match(r"\s*<(.*)>(.*)", val) + if match is None: # pragma: no cover + # the check exists to suppress mypy error + continue + url, params_str = match.groups() + params = params_str.split(";")[1:] + + link: MultiDict[Union[str, URL]] = MultiDict() + + for param in params: + match = re.match(r"^\s*(\S*)\s*=\s*(['\"]?)(.*?)(\2)\s*$", param, re.M) + if match is None: # pragma: no cover + # the 
check exists to suppress mypy error + continue + key, _, value, _ = match.groups() + + link.add(key, value) + + key = link.get("rel", url) + + link.add("url", self.url.join(URL(url))) + + links.add(str(key), MultiDictProxy(link)) + + return MultiDictProxy(links) + + async def start(self, connection: "Connection") -> "ClientResponse": + """Start response processing.""" + self._closed = False + self._protocol = connection.protocol + self._connection = connection + + with self._timer: + while True: + # read response + try: + protocol = self._protocol + message, payload = await protocol.read() # type: ignore[union-attr] + except http.HttpProcessingError as exc: + raise ClientResponseError( + self.request_info, + self.history, + status=exc.code, + message=exc.message, + headers=exc.headers, + ) from exc + + if message.code < 100 or message.code > 199 or message.code == 101: + break + + if self._continue is not None: + set_result(self._continue, True) + self._continue = None + + # payload eof handler + payload.on_eof(self._response_eof) + + # response status + self.version = message.version + self.status = message.code + self.reason = message.reason + + # headers + self._headers = message.headers # type is CIMultiDictProxy + self._raw_headers = message.raw_headers # type is Tuple[bytes, bytes] + + # payload + self.content = payload + + # cookies + for hdr in self.headers.getall(hdrs.SET_COOKIE, ()): + try: + self.cookies.load(hdr) + except CookieError as exc: + client_logger.warning("Can not load response cookies: %s", exc) + return self + + def _response_eof(self) -> None: + if self._closed: + return + + # protocol could be None because connection could be detached + protocol = self._connection and self._connection.protocol + if protocol is not None and protocol.upgraded: + return + + self._closed = True + self._cleanup_writer() + self._release_connection() + + @property + def closed(self) -> bool: + return self._closed + + def close(self) -> None: + if not self._released: + self._notify_content() + + self._closed = True + if self._loop is None or self._loop.is_closed(): + return + + self._cleanup_writer() + if self._connection is not None: + self._connection.close() + self._connection = None + + def release(self) -> Any: + if not self._released: + self._notify_content() + + self._closed = True + + self._cleanup_writer() + self._release_connection() + return noop() + + @property + def ok(self) -> bool: + """Returns ``True`` if ``status`` is less than ``400``, ``False`` if not. + + This is **not** a check for ``200 OK`` but a check that the response + status is under 400. 
+ """ + return 400 > self.status + + def raise_for_status(self) -> None: + if not self.ok: + # reason should always be not None for a started response + assert self.reason is not None + self.release() + raise ClientResponseError( + self.request_info, + self.history, + status=self.status, + message=self.reason, + headers=self.headers, + ) + + def _release_connection(self) -> None: + if self._connection is not None: + if self._writer is None: + self._connection.release() + self._connection = None + else: + self._writer.add_done_callback(lambda f: self._release_connection()) + + async def _wait_released(self) -> None: + if self._writer is not None: + await self._writer + self._release_connection() + + def _cleanup_writer(self) -> None: + if self._writer is not None: + self._writer.cancel() + self._session = None + + def _notify_content(self) -> None: + content = self.content + if content and content.exception() is None: + set_exception(content, ClientConnectionError("Connection closed")) + self._released = True + + async def wait_for_close(self) -> None: + if self._writer is not None: + await self._writer + self.release() + + async def read(self) -> bytes: + """Read response payload.""" + if self._body is None: + try: + self._body = await self.content.read() + for trace in self._traces: + await trace.send_response_chunk_received( + self.method, self.url, self._body + ) + except BaseException: + self.close() + raise + elif self._released: # Response explicitly released + raise ClientConnectionError("Connection closed") + + protocol = self._connection and self._connection.protocol + if protocol is None or not protocol.upgraded: + await self._wait_released() # Underlying connection released + return self._body # type: ignore[no-any-return] + + def get_encoding(self) -> str: + ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() + mimetype = helpers.parse_mimetype(ctype) + + encoding = mimetype.parameters.get("charset") + if encoding: + with contextlib.suppress(LookupError): + return codecs.lookup(encoding).name + + if mimetype.type == "application" and ( + mimetype.subtype == "json" or mimetype.subtype == "rdap" + ): + # RFC 7159 states that the default encoding is UTF-8. 
+ # RFC 7483 defines application/rdap+json + return "utf-8" + + if self._body is None: + raise RuntimeError( + "Cannot compute fallback encoding of a not yet read body" + ) + + return self._resolve_charset(self, self._body) + + async def text(self, encoding: Optional[str] = None, errors: str = "strict") -> str: + """Read response payload and decode.""" + if self._body is None: + await self.read() + + if encoding is None: + encoding = self.get_encoding() + + return self._body.decode( # type: ignore[no-any-return,union-attr] + encoding, errors=errors + ) + + async def json( + self, + *, + encoding: Optional[str] = None, + loads: JSONDecoder = DEFAULT_JSON_DECODER, + content_type: Optional[str] = "application/json", + ) -> Any: + """Read and decodes JSON response.""" + if self._body is None: + await self.read() + + if content_type: + ctype = self.headers.get(hdrs.CONTENT_TYPE, "").lower() + if not _is_expected_content_type(ctype, content_type): + raise ContentTypeError( + self.request_info, + self.history, + message=( + "Attempt to decode JSON with " "unexpected mimetype: %s" % ctype + ), + headers=self.headers, + ) + + stripped = self._body.strip() # type: ignore[union-attr] + if not stripped: + return None + + if encoding is None: + encoding = self.get_encoding() + + return loads(stripped.decode(encoding)) + + async def __aenter__(self) -> "ClientResponse": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + # similar to _RequestContextManager, we do not need to check + # for exceptions, response object can close connection + # if state is broken + self.release() + await self.wait_for_close() diff --git a/llm/Lib/site-packages/aiohttp/client_ws.py b/llm/Lib/site-packages/aiohttp/client_ws.py new file mode 100644 index 0000000000000000000000000000000000000000..d213aa37519372f110ff3784feac2af526ff3b1f --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/client_ws.py @@ -0,0 +1,315 @@ +"""WebSocket client for asyncio.""" + +import asyncio +import sys +from typing import Any, Optional, cast + +from .client_exceptions import ClientError +from .client_reqrep import ClientResponse +from .helpers import call_later, set_result +from .http import ( + WS_CLOSED_MESSAGE, + WS_CLOSING_MESSAGE, + WebSocketError, + WSCloseCode, + WSMessage, + WSMsgType, +) +from .http_websocket import WebSocketWriter # WSMessage +from .streams import EofStream, FlowControlDataQueue +from .typedefs import ( + DEFAULT_JSON_DECODER, + DEFAULT_JSON_ENCODER, + JSONDecoder, + JSONEncoder, +) + +if sys.version_info >= (3, 11): + import asyncio as async_timeout +else: + import async_timeout + + +class ClientWebSocketResponse: + def __init__( + self, + reader: "FlowControlDataQueue[WSMessage]", + writer: WebSocketWriter, + protocol: Optional[str], + response: ClientResponse, + timeout: float, + autoclose: bool, + autoping: bool, + loop: asyncio.AbstractEventLoop, + *, + receive_timeout: Optional[float] = None, + heartbeat: Optional[float] = None, + compress: int = 0, + client_notakeover: bool = False, + ) -> None: + self._response = response + self._conn = response.connection + + self._writer = writer + self._reader = reader + self._protocol = protocol + self._closed = False + self._closing = False + self._close_code: Optional[int] = None + self._timeout = timeout + self._receive_timeout = receive_timeout + self._autoclose = autoclose + self._autoping = autoping + self._heartbeat = heartbeat + 
self._heartbeat_cb: Optional[asyncio.TimerHandle] = None + if heartbeat is not None: + self._pong_heartbeat = heartbeat / 2.0 + self._pong_response_cb: Optional[asyncio.TimerHandle] = None + self._loop = loop + self._waiting: Optional[asyncio.Future[bool]] = None + self._exception: Optional[BaseException] = None + self._compress = compress + self._client_notakeover = client_notakeover + + self._reset_heartbeat() + + def _cancel_heartbeat(self) -> None: + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None + + if self._heartbeat_cb is not None: + self._heartbeat_cb.cancel() + self._heartbeat_cb = None + + def _reset_heartbeat(self) -> None: + self._cancel_heartbeat() + + if self._heartbeat is not None: + self._heartbeat_cb = call_later( + self._send_heartbeat, + self._heartbeat, + self._loop, + timeout_ceil_threshold=self._conn._connector._timeout_ceil_threshold + if self._conn is not None + else 5, + ) + + def _send_heartbeat(self) -> None: + if self._heartbeat is not None and not self._closed: + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + self._loop.create_task(self._writer.ping()) + + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = call_later( + self._pong_not_received, + self._pong_heartbeat, + self._loop, + timeout_ceil_threshold=self._conn._connector._timeout_ceil_threshold + if self._conn is not None + else 5, + ) + + def _pong_not_received(self) -> None: + if not self._closed: + self._closed = True + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + self._exception = asyncio.TimeoutError() + self._response.close() + + @property + def closed(self) -> bool: + return self._closed + + @property + def close_code(self) -> Optional[int]: + return self._close_code + + @property + def protocol(self) -> Optional[str]: + return self._protocol + + @property + def compress(self) -> int: + return self._compress + + @property + def client_notakeover(self) -> bool: + return self._client_notakeover + + def get_extra_info(self, name: str, default: Any = None) -> Any: + """extra info from connection transport""" + conn = self._response.connection + if conn is None: + return default + transport = conn.transport + if transport is None: + return default + return transport.get_extra_info(name, default) + + def exception(self) -> Optional[BaseException]: + return self._exception + + async def ping(self, message: bytes = b"") -> None: + await self._writer.ping(message) + + async def pong(self, message: bytes = b"") -> None: + await self._writer.pong(message) + + async def send_str(self, data: str, compress: Optional[int] = None) -> None: + if not isinstance(data, str): + raise TypeError("data argument must be str (%r)" % type(data)) + await self._writer.send(data, binary=False, compress=compress) + + async def send_bytes(self, data: bytes, compress: Optional[int] = None) -> None: + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("data argument must be byte-ish (%r)" % type(data)) + await self._writer.send(data, binary=True, compress=compress) + + async def send_json( + self, + data: Any, + compress: Optional[int] = None, + *, + dumps: JSONEncoder = DEFAULT_JSON_ENCODER, + ) -> None: + await self.send_str(dumps(data), compress=compress) + + async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool: + # we need to break `receive()` cycle first, + # `close()` 
may be called from different task + if self._waiting is not None and not self._closing: + self._closing = True + self._reader.feed_data(WS_CLOSING_MESSAGE, 0) + await self._waiting + + if not self._closed: + self._cancel_heartbeat() + self._closed = True + try: + await self._writer.close(code, message) + except asyncio.CancelledError: + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + self._response.close() + raise + except Exception as exc: + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + self._exception = exc + self._response.close() + return True + + if self._close_code: + self._response.close() + return True + + while True: + try: + async with async_timeout.timeout(self._timeout): + msg = await self._reader.read() + except asyncio.CancelledError: + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + self._response.close() + raise + except Exception as exc: + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + self._exception = exc + self._response.close() + return True + + if msg.type == WSMsgType.CLOSE: + self._close_code = msg.data + self._response.close() + return True + else: + return False + + async def receive(self, timeout: Optional[float] = None) -> WSMessage: + while True: + if self._waiting is not None: + raise RuntimeError("Concurrent call to receive() is not allowed") + + if self._closed: + return WS_CLOSED_MESSAGE + elif self._closing: + await self.close() + return WS_CLOSED_MESSAGE + + try: + self._waiting = self._loop.create_future() + try: + async with async_timeout.timeout(timeout or self._receive_timeout): + msg = await self._reader.read() + self._reset_heartbeat() + finally: + waiter = self._waiting + self._waiting = None + set_result(waiter, True) + except (asyncio.CancelledError, asyncio.TimeoutError): + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + raise + except EofStream: + self._close_code = WSCloseCode.OK + await self.close() + return WSMessage(WSMsgType.CLOSED, None, None) + except ClientError: + self._closed = True + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + return WS_CLOSED_MESSAGE + except WebSocketError as exc: + self._close_code = exc.code + await self.close(code=exc.code) + return WSMessage(WSMsgType.ERROR, exc, None) + except Exception as exc: + self._exception = exc + self._closing = True + self._close_code = WSCloseCode.ABNORMAL_CLOSURE + await self.close() + return WSMessage(WSMsgType.ERROR, exc, None) + + if msg.type == WSMsgType.CLOSE: + self._closing = True + self._close_code = msg.data + if not self._closed and self._autoclose: + await self.close() + elif msg.type == WSMsgType.CLOSING: + self._closing = True + elif msg.type == WSMsgType.PING and self._autoping: + await self.pong(msg.data) + continue + elif msg.type == WSMsgType.PONG and self._autoping: + continue + + return msg + + async def receive_str(self, *, timeout: Optional[float] = None) -> str: + msg = await self.receive(timeout) + if msg.type != WSMsgType.TEXT: + raise TypeError(f"Received message {msg.type}:{msg.data!r} is not str") + return cast(str, msg.data) + + async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: + msg = await self.receive(timeout) + if msg.type != WSMsgType.BINARY: + raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes") + return cast(bytes, msg.data) + + async def receive_json( + self, + *, + loads: JSONDecoder = DEFAULT_JSON_DECODER, + timeout: Optional[float] = None, + ) -> Any: + data = await self.receive_str(timeout=timeout) + return loads(data) + + def __aiter__(self) -> "ClientWebSocketResponse": + 
return self + + async def __anext__(self) -> WSMessage: + msg = await self.receive() + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + raise StopAsyncIteration + return msg diff --git a/llm/Lib/site-packages/aiohttp/compression_utils.py b/llm/Lib/site-packages/aiohttp/compression_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cc2a13e3aae8df7619e84efc1ec2044255121fb3 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/compression_utils.py @@ -0,0 +1,157 @@ +import asyncio +import zlib +from concurrent.futures import Executor +from typing import Optional, cast + +try: + try: + import brotlicffi as brotli + except ImportError: + import brotli + + HAS_BROTLI = True +except ImportError: # pragma: no cover + HAS_BROTLI = False + +MAX_SYNC_CHUNK_SIZE = 1024 + + +def encoding_to_mode( + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, +) -> int: + if encoding == "gzip": + return 16 + zlib.MAX_WBITS + + return -zlib.MAX_WBITS if suppress_deflate_header else zlib.MAX_WBITS + + +class ZlibBaseHandler: + def __init__( + self, + mode: int, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + self._mode = mode + self._executor = executor + self._max_sync_chunk_size = max_sync_chunk_size + + +class ZLibCompressor(ZlibBaseHandler): + def __init__( + self, + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, + level: Optional[int] = None, + wbits: Optional[int] = None, + strategy: int = zlib.Z_DEFAULT_STRATEGY, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + super().__init__( + mode=encoding_to_mode(encoding, suppress_deflate_header) + if wbits is None + else wbits, + executor=executor, + max_sync_chunk_size=max_sync_chunk_size, + ) + if level is None: + self._compressor = zlib.compressobj(wbits=self._mode, strategy=strategy) + else: + self._compressor = zlib.compressobj( + wbits=self._mode, strategy=strategy, level=level + ) + self._compress_lock = asyncio.Lock() + + def compress_sync(self, data: bytes) -> bytes: + return self._compressor.compress(data) + + async def compress(self, data: bytes) -> bytes: + async with self._compress_lock: + # To ensure the stream is consistent in the event + # there are multiple writers, we need to lock + # the compressor so that only one writer can + # compress at a time. 
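+            # Payloads above _max_sync_chunk_size are compressed in the executor so the event loop is not blocked.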
+ if ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ): + return await asyncio.get_event_loop().run_in_executor( + self._executor, self.compress_sync, data + ) + return self.compress_sync(data) + + def flush(self, mode: int = zlib.Z_FINISH) -> bytes: + return self._compressor.flush(mode) + + +class ZLibDecompressor(ZlibBaseHandler): + def __init__( + self, + encoding: Optional[str] = None, + suppress_deflate_header: bool = False, + executor: Optional[Executor] = None, + max_sync_chunk_size: Optional[int] = MAX_SYNC_CHUNK_SIZE, + ): + super().__init__( + mode=encoding_to_mode(encoding, suppress_deflate_header), + executor=executor, + max_sync_chunk_size=max_sync_chunk_size, + ) + self._decompressor = zlib.decompressobj(wbits=self._mode) + + def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes: + return self._decompressor.decompress(data, max_length) + + async def decompress(self, data: bytes, max_length: int = 0) -> bytes: + if ( + self._max_sync_chunk_size is not None + and len(data) > self._max_sync_chunk_size + ): + return await asyncio.get_event_loop().run_in_executor( + self._executor, self.decompress_sync, data, max_length + ) + return self.decompress_sync(data, max_length) + + def flush(self, length: int = 0) -> bytes: + return ( + self._decompressor.flush(length) + if length > 0 + else self._decompressor.flush() + ) + + @property + def eof(self) -> bool: + return self._decompressor.eof + + @property + def unconsumed_tail(self) -> bytes: + return self._decompressor.unconsumed_tail + + @property + def unused_data(self) -> bytes: + return self._decompressor.unused_data + + +class BrotliDecompressor: + # Supports both 'brotlipy' and 'Brotli' packages + # since they share an import name. The top branches + # are for 'brotlipy' and bottom branches for 'Brotli' + def __init__(self) -> None: + if not HAS_BROTLI: + raise RuntimeError( + "The brotli decompression is not available. " + "Please install `Brotli` module" + ) + self._obj = brotli.Decompressor() + + def decompress_sync(self, data: bytes) -> bytes: + if hasattr(self._obj, "decompress"): + return cast(bytes, self._obj.decompress(data)) + return cast(bytes, self._obj.process(data)) + + def flush(self) -> bytes: + if hasattr(self._obj, "flush"): + return cast(bytes, self._obj.flush()) + return b"" diff --git a/llm/Lib/site-packages/aiohttp/connector.py b/llm/Lib/site-packages/aiohttp/connector.py new file mode 100644 index 0000000000000000000000000000000000000000..b2ce0ef5b4d14642d9f9c08783d35ddbbccb556a --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/connector.py @@ -0,0 +1,1511 @@ +import asyncio +import functools +import random +import sys +import traceback +import warnings +from collections import defaultdict, deque +from contextlib import suppress +from http import HTTPStatus +from http.cookies import SimpleCookie +from itertools import cycle, islice +from time import monotonic +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + DefaultDict, + Dict, + Iterator, + List, + Literal, + Optional, + Set, + Tuple, + Type, + Union, + cast, +) + +import attr + +from . 
import hdrs, helpers +from .abc import AbstractResolver +from .client_exceptions import ( + ClientConnectionError, + ClientConnectorCertificateError, + ClientConnectorError, + ClientConnectorSSLError, + ClientHttpProxyError, + ClientProxyConnectionError, + ServerFingerprintMismatch, + UnixClientConnectorError, + cert_errors, + ssl_errors, +) +from .client_proto import ResponseHandler +from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params +from .helpers import ceil_timeout, get_running_loop, is_ip_address, noop, sentinel +from .locks import EventResultOrError +from .resolver import DefaultResolver + +try: + import ssl + + SSLContext = ssl.SSLContext +except ImportError: # pragma: no cover + ssl = None # type: ignore[assignment] + SSLContext = object # type: ignore[misc,assignment] + + +__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector") + + +if TYPE_CHECKING: + from .client import ClientTimeout + from .client_reqrep import ConnectionKey + from .tracing import Trace + + +class _DeprecationWaiter: + __slots__ = ("_awaitable", "_awaited") + + def __init__(self, awaitable: Awaitable[Any]) -> None: + self._awaitable = awaitable + self._awaited = False + + def __await__(self) -> Any: + self._awaited = True + return self._awaitable.__await__() + + def __del__(self) -> None: + if not self._awaited: + warnings.warn( + "Connector.close() is a coroutine, " + "please use await connector.close()", + DeprecationWarning, + ) + + +class Connection: + + _source_traceback = None + _transport = None + + def __init__( + self, + connector: "BaseConnector", + key: "ConnectionKey", + protocol: ResponseHandler, + loop: asyncio.AbstractEventLoop, + ) -> None: + self._key = key + self._connector = connector + self._loop = loop + self._protocol: Optional[ResponseHandler] = protocol + self._callbacks: List[Callable[[], None]] = [] + + if loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + def __repr__(self) -> str: + return f"Connection<{self._key}>" + + def __del__(self, _warnings: Any = warnings) -> None: + if self._protocol is not None: + kwargs = {"source": self} + _warnings.warn(f"Unclosed connection {self!r}", ResourceWarning, **kwargs) + if self._loop.is_closed(): + return + + self._connector._release(self._key, self._protocol, should_close=True) + + context = {"client_connection": self, "message": "Unclosed connection"} + if self._source_traceback is not None: + context["source_traceback"] = self._source_traceback + self._loop.call_exception_handler(context) + + def __bool__(self) -> Literal[True]: + """Force subclasses to not be falsy, to make checks simpler.""" + return True + + @property + def loop(self) -> asyncio.AbstractEventLoop: + warnings.warn( + "connector.loop property is deprecated", DeprecationWarning, stacklevel=2 + ) + return self._loop + + @property + def transport(self) -> Optional[asyncio.Transport]: + if self._protocol is None: + return None + return self._protocol.transport + + @property + def protocol(self) -> Optional[ResponseHandler]: + return self._protocol + + def add_callback(self, callback: Callable[[], None]) -> None: + if callback is not None: + self._callbacks.append(callback) + + def _notify_release(self) -> None: + callbacks, self._callbacks = self._callbacks[:], [] + + for cb in callbacks: + with suppress(Exception): + cb() + + def close(self) -> None: + self._notify_release() + + if self._protocol is not None: + self._connector._release(self._key, self._protocol, should_close=True) + 
self._protocol = None + + def release(self) -> None: + self._notify_release() + + if self._protocol is not None: + self._connector._release( + self._key, self._protocol, should_close=self._protocol.should_close + ) + self._protocol = None + + @property + def closed(self) -> bool: + return self._protocol is None or not self._protocol.is_connected() + + +class _TransportPlaceholder: + """placeholder for BaseConnector.connect function""" + + def close(self) -> None: + pass + + +class BaseConnector: + """Base connector class. + + keepalive_timeout - (optional) Keep-alive timeout. + force_close - Set to True to force close and do reconnect + after each request (and between redirects). + limit - The total number of simultaneous connections. + limit_per_host - Number of simultaneous connections to one host. + enable_cleanup_closed - Enables clean-up closed ssl transports. + Disabled by default. + timeout_ceil_threshold - Trigger ceiling of timeout values when + it's above timeout_ceil_threshold. + loop - Optional event loop. + """ + + _closed = True # prevent AttributeError in __del__ if ctor was failed + _source_traceback = None + + # abort transport after 2 seconds (cleanup broken connections) + _cleanup_closed_period = 2.0 + + def __init__( + self, + *, + keepalive_timeout: Union[object, None, float] = sentinel, + force_close: bool = False, + limit: int = 100, + limit_per_host: int = 0, + enable_cleanup_closed: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + timeout_ceil_threshold: float = 5, + ) -> None: + + if force_close: + if keepalive_timeout is not None and keepalive_timeout is not sentinel: + raise ValueError( + "keepalive_timeout cannot " "be set if force_close is True" + ) + else: + if keepalive_timeout is sentinel: + keepalive_timeout = 15.0 + + loop = get_running_loop(loop) + self._timeout_ceil_threshold = timeout_ceil_threshold + + self._closed = False + if loop.get_debug(): + self._source_traceback = traceback.extract_stack(sys._getframe(1)) + + self._conns: Dict[ConnectionKey, List[Tuple[ResponseHandler, float]]] = {} + self._limit = limit + self._limit_per_host = limit_per_host + self._acquired: Set[ResponseHandler] = set() + self._acquired_per_host: DefaultDict[ + ConnectionKey, Set[ResponseHandler] + ] = defaultdict(set) + self._keepalive_timeout = cast(float, keepalive_timeout) + self._force_close = force_close + + # {host_key: FIFO list of waiters} + self._waiters = defaultdict(deque) # type: ignore[var-annotated] + + self._loop = loop + self._factory = functools.partial(ResponseHandler, loop=loop) + + self.cookies = SimpleCookie() + + # start keep-alive connection cleanup task + self._cleanup_handle: Optional[asyncio.TimerHandle] = None + + # start cleanup closed transports task + self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None + self._cleanup_closed_disabled = not enable_cleanup_closed + self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = [] + self._cleanup_closed() + + def __del__(self, _warnings: Any = warnings) -> None: + if self._closed: + return + if not self._conns: + return + + conns = [repr(c) for c in self._conns.values()] + + self._close() + + kwargs = {"source": self} + _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, **kwargs) + context = { + "connector": self, + "connections": conns, + "message": "Unclosed connector", + } + if self._source_traceback is not None: + context["source_traceback"] = self._source_traceback + self._loop.call_exception_handler(context) + + def __enter__(self) -> 
"BaseConnector": + warnings.warn( + '"with Connector():" is deprecated, ' + 'use "async with Connector():" instead', + DeprecationWarning, + ) + return self + + def __exit__(self, *exc: Any) -> None: + self._close() + + async def __aenter__(self) -> "BaseConnector": + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]] = None, + exc_value: Optional[BaseException] = None, + exc_traceback: Optional[TracebackType] = None, + ) -> None: + await self.close() + + @property + def force_close(self) -> bool: + """Ultimately close connection on releasing if True.""" + return self._force_close + + @property + def limit(self) -> int: + """The total number for simultaneous connections. + + If limit is 0 the connector has no limit. + The default limit size is 100. + """ + return self._limit + + @property + def limit_per_host(self) -> int: + """The limit for simultaneous connections to the same endpoint. + + Endpoints are the same if they are have equal + (host, port, is_ssl) triple. + """ + return self._limit_per_host + + def _cleanup(self) -> None: + """Cleanup unused transports.""" + if self._cleanup_handle: + self._cleanup_handle.cancel() + # _cleanup_handle should be unset, otherwise _release() will not + # recreate it ever! + self._cleanup_handle = None + + now = self._loop.time() + timeout = self._keepalive_timeout + + if self._conns: + connections = {} + deadline = now - timeout + for key, conns in self._conns.items(): + alive = [] + for proto, use_time in conns: + if proto.is_connected(): + if use_time - deadline < 0: + transport = proto.transport + proto.close() + if key.is_ssl and not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transport) + else: + alive.append((proto, use_time)) + else: + transport = proto.transport + proto.close() + if key.is_ssl and not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transport) + + if alive: + connections[key] = alive + + self._conns = connections + + if self._conns: + self._cleanup_handle = helpers.weakref_handle( + self, + "_cleanup", + timeout, + self._loop, + timeout_ceil_threshold=self._timeout_ceil_threshold, + ) + + def _drop_acquired_per_host( + self, key: "ConnectionKey", val: ResponseHandler + ) -> None: + acquired_per_host = self._acquired_per_host + if key not in acquired_per_host: + return + conns = acquired_per_host[key] + conns.remove(val) + if not conns: + del self._acquired_per_host[key] + + def _cleanup_closed(self) -> None: + """Double confirmation for transport close. + + Some broken ssl servers may leave socket open without proper close. 
+ """ + if self._cleanup_closed_handle: + self._cleanup_closed_handle.cancel() + + for transport in self._cleanup_closed_transports: + if transport is not None: + transport.abort() + + self._cleanup_closed_transports = [] + + if not self._cleanup_closed_disabled: + self._cleanup_closed_handle = helpers.weakref_handle( + self, + "_cleanup_closed", + self._cleanup_closed_period, + self._loop, + timeout_ceil_threshold=self._timeout_ceil_threshold, + ) + + def close(self) -> Awaitable[None]: + """Close all opened transports.""" + self._close() + return _DeprecationWaiter(noop()) + + def _close(self) -> None: + if self._closed: + return + + self._closed = True + + try: + if self._loop.is_closed(): + return + + # cancel cleanup task + if self._cleanup_handle: + self._cleanup_handle.cancel() + + # cancel cleanup close task + if self._cleanup_closed_handle: + self._cleanup_closed_handle.cancel() + + for data in self._conns.values(): + for proto, t0 in data: + proto.close() + + for proto in self._acquired: + proto.close() + + for transport in self._cleanup_closed_transports: + if transport is not None: + transport.abort() + + finally: + self._conns.clear() + self._acquired.clear() + self._waiters.clear() + self._cleanup_handle = None + self._cleanup_closed_transports.clear() + self._cleanup_closed_handle = None + + @property + def closed(self) -> bool: + """Is connector closed. + + A readonly property. + """ + return self._closed + + def _available_connections(self, key: "ConnectionKey") -> int: + """ + Return number of available connections. + + The limit, limit_per_host and the connection key are taken into account. + + If it returns less than 1 means that there are no connections + available. + """ + if self._limit: + # total calc available connections + available = self._limit - len(self._acquired) + + # check limit per host + if ( + self._limit_per_host + and available > 0 + and key in self._acquired_per_host + ): + acquired = self._acquired_per_host.get(key) + assert acquired is not None + available = self._limit_per_host - len(acquired) + + elif self._limit_per_host and key in self._acquired_per_host: + # check limit per host + acquired = self._acquired_per_host.get(key) + assert acquired is not None + available = self._limit_per_host - len(acquired) + else: + available = 1 + + return available + + async def connect( + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + ) -> Connection: + """Get from pool or create new connection.""" + key = req.connection_key + available = self._available_connections(key) + + # Wait if there are no available connections or if there are/were + # waiters (i.e. don't steal connection from a waiter about to wake up) + if available <= 0 or key in self._waiters: + fut = self._loop.create_future() + + # This connection will now count towards the limit. 
+ self._waiters[key].append(fut) + + if traces: + for trace in traces: + await trace.send_connection_queued_start() + + try: + await fut + except BaseException as e: + if key in self._waiters: + # remove a waiter even if it was cancelled, normally it's + # removed when it's notified + try: + self._waiters[key].remove(fut) + except ValueError: # fut may no longer be in list + pass + + raise e + finally: + if key in self._waiters and not self._waiters[key]: + del self._waiters[key] + + if traces: + for trace in traces: + await trace.send_connection_queued_end() + + proto = self._get(key) + if proto is None: + placeholder = cast(ResponseHandler, _TransportPlaceholder()) + self._acquired.add(placeholder) + self._acquired_per_host[key].add(placeholder) + + if traces: + for trace in traces: + await trace.send_connection_create_start() + + try: + proto = await self._create_connection(req, traces, timeout) + if self._closed: + proto.close() + raise ClientConnectionError("Connector is closed.") + except BaseException: + if not self._closed: + self._acquired.remove(placeholder) + self._drop_acquired_per_host(key, placeholder) + self._release_waiter() + raise + else: + if not self._closed: + self._acquired.remove(placeholder) + self._drop_acquired_per_host(key, placeholder) + + if traces: + for trace in traces: + await trace.send_connection_create_end() + else: + if traces: + # Acquire the connection to prevent race conditions with limits + placeholder = cast(ResponseHandler, _TransportPlaceholder()) + self._acquired.add(placeholder) + self._acquired_per_host[key].add(placeholder) + for trace in traces: + await trace.send_connection_reuseconn() + self._acquired.remove(placeholder) + self._drop_acquired_per_host(key, placeholder) + + self._acquired.add(proto) + self._acquired_per_host[key].add(proto) + return Connection(self, key, proto, self._loop) + + def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]: + try: + conns = self._conns[key] + except KeyError: + return None + + t1 = self._loop.time() + while conns: + proto, t0 = conns.pop() + if proto.is_connected(): + if t1 - t0 > self._keepalive_timeout: + transport = proto.transport + proto.close() + # only for SSL transports + if key.is_ssl and not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transport) + else: + if not conns: + # The very last connection was reclaimed: drop the key + del self._conns[key] + return proto + else: + transport = proto.transport + proto.close() + if key.is_ssl and not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transport) + + # No more connections: drop the key + del self._conns[key] + return None + + def _release_waiter(self) -> None: + """ + Iterates over all waiters until one to be released is found. + + The one to be released is not finished and + belongs to a host that has available connections. + """ + if not self._waiters: + return + + # Having the dict keys ordered this avoids to iterate + # at the same order at each call. 
+ queues = list(self._waiters.keys()) + random.shuffle(queues) + + for key in queues: + if self._available_connections(key) < 1: + continue + + waiters = self._waiters[key] + while waiters: + waiter = waiters.popleft() + if not waiter.done(): + waiter.set_result(None) + return + + def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None: + if self._closed: + # acquired connection is already released on connector closing + return + + try: + self._acquired.remove(proto) + self._drop_acquired_per_host(key, proto) + except KeyError: # pragma: no cover + # this may be result of undetermenistic order of objects + # finalization due garbage collection. + pass + else: + self._release_waiter() + + def _release( + self, + key: "ConnectionKey", + protocol: ResponseHandler, + *, + should_close: bool = False, + ) -> None: + if self._closed: + # acquired connection is already released on connector closing + return + + self._release_acquired(key, protocol) + + if self._force_close: + should_close = True + + if should_close or protocol.should_close: + transport = protocol.transport + protocol.close() + + if key.is_ssl and not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transport) + else: + conns = self._conns.get(key) + if conns is None: + conns = self._conns[key] = [] + conns.append((protocol, self._loop.time())) + + if self._cleanup_handle is None: + self._cleanup_handle = helpers.weakref_handle( + self, + "_cleanup", + self._keepalive_timeout, + self._loop, + timeout_ceil_threshold=self._timeout_ceil_threshold, + ) + + async def _create_connection( + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: + raise NotImplementedError() + + +class _DNSCacheTable: + def __init__(self, ttl: Optional[float] = None) -> None: + self._addrs_rr: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]] = {} + self._timestamps: Dict[Tuple[str, int], float] = {} + self._ttl = ttl + + def __contains__(self, host: object) -> bool: + return host in self._addrs_rr + + def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None: + self._addrs_rr[key] = (cycle(addrs), len(addrs)) + + if self._ttl is not None: + self._timestamps[key] = monotonic() + + def remove(self, key: Tuple[str, int]) -> None: + self._addrs_rr.pop(key, None) + + if self._ttl is not None: + self._timestamps.pop(key, None) + + def clear(self) -> None: + self._addrs_rr.clear() + self._timestamps.clear() + + def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]: + loop, length = self._addrs_rr[key] + addrs = list(islice(loop, length)) + # Consume one more element to shift internal state of `cycle` + next(loop) + return addrs + + def expired(self, key: Tuple[str, int]) -> bool: + if self._ttl is None: + return False + + return self._timestamps[key] + self._ttl < monotonic() + + +class TCPConnector(BaseConnector): + """TCP connector. + + verify_ssl - Set to True to check ssl certifications. + fingerprint - Pass the binary sha256 + digest of the expected certificate in DER format to verify + that the certificate the server presents matches. See also + https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning + resolver - Enable DNS lookups and use this + resolver + use_dns_cache - Use memory cache for DNS lookups. + ttl_dns_cache - Max seconds having cached a DNS entry, None forever. 
+ family - socket address family + local_addr - local tuple of (host, port) to bind socket to + + keepalive_timeout - (optional) Keep-alive timeout. + force_close - Set to True to force close and do reconnect + after each request (and between redirects). + limit - The total number of simultaneous connections. + limit_per_host - Number of simultaneous connections to one host. + enable_cleanup_closed - Enables clean-up closed ssl transports. + Disabled by default. + loop - Optional event loop. + """ + + def __init__( + self, + *, + verify_ssl: bool = True, + fingerprint: Optional[bytes] = None, + use_dns_cache: bool = True, + ttl_dns_cache: Optional[int] = 10, + family: int = 0, + ssl_context: Optional[SSLContext] = None, + ssl: Union[bool, Fingerprint, SSLContext] = True, + local_addr: Optional[Tuple[str, int]] = None, + resolver: Optional[AbstractResolver] = None, + keepalive_timeout: Union[None, float, object] = sentinel, + force_close: bool = False, + limit: int = 100, + limit_per_host: int = 0, + enable_cleanup_closed: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + timeout_ceil_threshold: float = 5, + ): + super().__init__( + keepalive_timeout=keepalive_timeout, + force_close=force_close, + limit=limit, + limit_per_host=limit_per_host, + enable_cleanup_closed=enable_cleanup_closed, + loop=loop, + timeout_ceil_threshold=timeout_ceil_threshold, + ) + + self._ssl = _merge_ssl_params(ssl, verify_ssl, ssl_context, fingerprint) + if resolver is None: + resolver = DefaultResolver(loop=self._loop) + self._resolver = resolver + + self._use_dns_cache = use_dns_cache + self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache) + self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {} + self._family = family + self._local_addr = local_addr + + def close(self) -> Awaitable[None]: + """Close all ongoing DNS calls.""" + for ev in self._throttle_dns_events.values(): + ev.cancel() + + return super().close() + + @property + def family(self) -> int: + """Socket family like AF_INET.""" + return self._family + + @property + def use_dns_cache(self) -> bool: + """True if local DNS caching is enabled.""" + return self._use_dns_cache + + def clear_dns_cache( + self, host: Optional[str] = None, port: Optional[int] = None + ) -> None: + """Remove specified host/port or clear all dns local cache.""" + if host is not None and port is not None: + self._cached_hosts.remove((host, port)) + elif host is not None or port is not None: + raise ValueError("either both host and port " "or none of them are allowed") + else: + self._cached_hosts.clear() + + async def _resolve_host( + self, host: str, port: int, traces: Optional[List["Trace"]] = None + ) -> List[Dict[str, Any]]: + """Resolve host and return list of addresses.""" + if is_ip_address(host): + return [ + { + "hostname": host, + "host": host, + "port": port, + "family": self._family, + "proto": 0, + "flags": 0, + } + ] + + if not self._use_dns_cache: + + if traces: + for trace in traces: + await trace.send_dns_resolvehost_start(host) + + res = await self._resolver.resolve(host, port, family=self._family) + + if traces: + for trace in traces: + await trace.send_dns_resolvehost_end(host) + + return res + + key = (host, port) + if key in self._cached_hosts and not self._cached_hosts.expired(key): + # get result early, before any await (#4014) + result = self._cached_hosts.next_addrs(key) + + if traces: + for trace in traces: + await trace.send_dns_cache_hit(host) + return result + + # + # If multiple connectors are resolving the same 
host, we wait + # for the first one to resolve and then use the result for all of them. + # We use a throttle event to ensure that we only resolve the host once + # and then use the result for all the waiters. + # + # In this case we need to create a task to ensure that we can shield + # the task from cancellation as cancelling this lookup should not cancel + # the underlying lookup or else the cancel event will get broadcast to + # all the waiters across all connections. + # + resolved_host_task = asyncio.create_task( + self._resolve_host_with_throttle(key, host, port, traces) + ) + try: + return await asyncio.shield(resolved_host_task) + except asyncio.CancelledError: + + def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None: + with suppress(Exception, asyncio.CancelledError): + fut.result() + + resolved_host_task.add_done_callback(drop_exception) + raise + + async def _resolve_host_with_throttle( + self, + key: Tuple[str, int], + host: str, + port: int, + traces: Optional[List["Trace"]], + ) -> List[Dict[str, Any]]: + """Resolve host with a dns events throttle.""" + if key in self._throttle_dns_events: + # get event early, before any await (#4014) + event = self._throttle_dns_events[key] + if traces: + for trace in traces: + await trace.send_dns_cache_hit(host) + await event.wait() + else: + # update dict early, before any await (#4014) + self._throttle_dns_events[key] = EventResultOrError(self._loop) + if traces: + for trace in traces: + await trace.send_dns_cache_miss(host) + try: + + if traces: + for trace in traces: + await trace.send_dns_resolvehost_start(host) + + addrs = await self._resolver.resolve(host, port, family=self._family) + if traces: + for trace in traces: + await trace.send_dns_resolvehost_end(host) + + self._cached_hosts.add(key, addrs) + self._throttle_dns_events[key].set() + except BaseException as e: + # any DNS exception, independently of the implementation + # is set for the waiters to raise the same exception. + self._throttle_dns_events[key].set(exc=e) + raise + finally: + self._throttle_dns_events.pop(key) + + return self._cached_hosts.next_addrs(key) + + async def _create_connection( + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: + """Create connection. + + Has same keyword arguments as BaseEventLoop.create_connection. + """ + if req.proxy: + _, proto = await self._create_proxy_connection(req, traces, timeout) + else: + _, proto = await self._create_direct_connection(req, traces, timeout) + + return proto + + @staticmethod + @functools.lru_cache(None) + def _make_ssl_context(verified: bool) -> SSLContext: + if verified: + return ssl.create_default_context() + else: + sslcontext = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + sslcontext.options |= ssl.OP_NO_SSLv2 + sslcontext.options |= ssl.OP_NO_SSLv3 + sslcontext.check_hostname = False + sslcontext.verify_mode = ssl.CERT_NONE + try: + sslcontext.options |= ssl.OP_NO_COMPRESSION + except AttributeError as attr_err: + warnings.warn( + "{!s}: The Python interpreter is compiled " + "against OpenSSL < 1.0.0. Ref: " + "https://docs.python.org/3/library/ssl.html" + "#ssl.OP_NO_COMPRESSION".format(attr_err), + ) + sslcontext.set_default_verify_paths() + return sslcontext + + def _get_ssl_context(self, req: ClientRequest) -> Optional[SSLContext]: + """Logic to get the correct SSL context + + 0. if req.ssl is false, return None + + 1. if ssl_context is specified in req, use it + 2. if _ssl_context is specified in self, use it + 3. otherwise: + 1. 
if verify_ssl is not specified in req, use self.ssl_context + (will generate a default context according to self.verify_ssl) + 2. if verify_ssl is True in req, generate a default SSL context + 3. if verify_ssl is False in req, generate a SSL context that + won't verify + """ + if req.is_ssl(): + if ssl is None: # pragma: no cover + raise RuntimeError("SSL is not supported.") + sslcontext = req.ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not True: + # not verified or fingerprinted + return self._make_ssl_context(False) + sslcontext = self._ssl + if isinstance(sslcontext, ssl.SSLContext): + return sslcontext + if sslcontext is not True: + # not verified or fingerprinted + return self._make_ssl_context(False) + return self._make_ssl_context(True) + else: + return None + + def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: + ret = req.ssl + if isinstance(ret, Fingerprint): + return ret + ret = self._ssl + if isinstance(ret, Fingerprint): + return ret + return None + + async def _wrap_create_connection( + self, + *args: Any, + req: ClientRequest, + timeout: "ClientTimeout", + client_error: Type[Exception] = ClientConnectorError, + **kwargs: Any, + ) -> Tuple[asyncio.Transport, ResponseHandler]: + try: + async with ceil_timeout( + timeout.sock_connect, ceil_threshold=timeout.ceil_threshold + ): + return await self._loop.create_connection(*args, **kwargs) + except cert_errors as exc: + raise ClientConnectorCertificateError(req.connection_key, exc) from exc + except ssl_errors as exc: + raise ClientConnectorSSLError(req.connection_key, exc) from exc + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + raise client_error(req.connection_key, exc) from exc + + def _fail_on_no_start_tls(self, req: "ClientRequest") -> None: + """Raise a :py:exc:`RuntimeError` on missing ``start_tls()``. + + It is necessary for TLS-in-TLS so that it is possible to + send HTTPS queries through HTTPS proxies. + + This doesn't affect regular HTTP requests, though. + """ + if not req.is_ssl(): + return + + proxy_url = req.proxy + assert proxy_url is not None + if proxy_url.scheme != "https": + return + + self._check_loop_for_start_tls() + + def _check_loop_for_start_tls(self) -> None: + try: + self._loop.start_tls + except AttributeError as attr_exc: + raise RuntimeError( + "An HTTPS request is being sent through an HTTPS proxy. " + "This needs support for TLS in TLS but it is not implemented " + "in your runtime for the stdlib asyncio.\n\n" + "Please upgrade to Python 3.11 or higher. For more details, " + "please see:\n" + "* https://bugs.python.org/issue37179\n" + "* https://github.com/python/cpython/pull/28073\n" + "* https://docs.aiohttp.org/en/stable/" + "client_advanced.html#proxy-support\n" + "* https://github.com/aio-libs/aiohttp/discussions/6044\n", + ) from attr_exc + + def _loop_supports_start_tls(self) -> bool: + try: + self._check_loop_for_start_tls() + except RuntimeError: + return False + else: + return True + + def _warn_about_tls_in_tls( + self, + underlying_transport: asyncio.Transport, + req: ClientRequest, + ) -> None: + """Issue a warning if the requested URL has HTTPS scheme.""" + if req.request_info.url.scheme != "https": + return + + asyncio_supports_tls_in_tls = getattr( + underlying_transport, + "_start_tls_compatible", + False, + ) + + if asyncio_supports_tls_in_tls: + return + + warnings.warn( + "An HTTPS request is being sent through an HTTPS proxy. 
" + "This support for TLS in TLS is known to be disabled " + "in the stdlib asyncio (Python <3.11). This is why you'll probably see " + "an error in the log below.\n\n" + "It is possible to enable it via monkeypatching. " + "For more details, see:\n" + "* https://bugs.python.org/issue37179\n" + "* https://github.com/python/cpython/pull/28073\n\n" + "You can temporarily patch this as follows:\n" + "* https://docs.aiohttp.org/en/stable/client_advanced.html#proxy-support\n" + "* https://github.com/aio-libs/aiohttp/discussions/6044\n", + RuntimeWarning, + source=self, + # Why `4`? At least 3 of the calls in the stack originate + # from the methods in this class. + stacklevel=3, + ) + + async def _start_tls_connection( + self, + underlying_transport: asyncio.Transport, + req: ClientRequest, + timeout: "ClientTimeout", + client_error: Type[Exception] = ClientConnectorError, + ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: + """Wrap the raw TCP transport with TLS.""" + tls_proto = self._factory() # Create a brand new proto for TLS + + # Safety of the `cast()` call here is based on the fact that + # internally `_get_ssl_context()` only returns `None` when + # `req.is_ssl()` evaluates to `False` which is never gonna happen + # in this code path. Of course, it's rather fragile + # maintainability-wise but this is to be solved separately. + sslcontext = cast(ssl.SSLContext, self._get_ssl_context(req)) + + try: + async with ceil_timeout( + timeout.sock_connect, ceil_threshold=timeout.ceil_threshold + ): + try: + tls_transport = await self._loop.start_tls( + underlying_transport, + tls_proto, + sslcontext, + server_hostname=req.server_hostname or req.host, + ssl_handshake_timeout=timeout.total, + ) + except BaseException: + # We need to close the underlying transport since + # `start_tls()` probably failed before it had a + # chance to do this: + underlying_transport.close() + raise + except cert_errors as exc: + raise ClientConnectorCertificateError(req.connection_key, exc) from exc + except ssl_errors as exc: + raise ClientConnectorSSLError(req.connection_key, exc) from exc + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + raise client_error(req.connection_key, exc) from exc + except TypeError as type_err: + # Example cause looks like this: + # TypeError: transport is not supported by start_tls() + + raise ClientConnectionError( + "Cannot initialize a TLS-in-TLS connection to host " + f"{req.host!s}:{req.port:d} through an underlying connection " + f"to an HTTPS proxy {req.proxy!s} ssl:{req.ssl or 'default'} " + f"[{type_err!s}]" + ) from type_err + else: + if tls_transport is None: + msg = "Failed to start TLS (possibly caused by closing transport)" + raise client_error(req.connection_key, OSError(msg)) + tls_proto.connection_made( + tls_transport + ) # Kick the state machine of the new TLS protocol + + return tls_transport, tls_proto + + async def _create_direct_connection( + self, + req: ClientRequest, + traces: List["Trace"], + timeout: "ClientTimeout", + *, + client_error: Type[Exception] = ClientConnectorError, + ) -> Tuple[asyncio.Transport, ResponseHandler]: + sslcontext = self._get_ssl_context(req) + fingerprint = self._get_fingerprint(req) + + host = req.url.raw_host + assert host is not None + # Replace multiple trailing dots with a single one. + # A trailing dot is only present for fully-qualified domain names. + # See https://github.com/aio-libs/aiohttp/pull/7364. + if host.endswith(".."): + host = host.rstrip(".") + "." 
+ port = req.port + assert port is not None + try: + # Cancelling this lookup should not cancel the underlying lookup + # or else the cancel event will get broadcast to all the waiters + # across all connections. + hosts = await self._resolve_host(host, port, traces=traces) + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + # in case of proxy it is not ClientProxyConnectionError + # it is problem of resolving proxy ip itself + raise ClientConnectorError(req.connection_key, exc) from exc + + last_exc: Optional[Exception] = None + + for hinfo in hosts: + host = hinfo["host"] + port = hinfo["port"] + + # Strip trailing dots, certificates contain FQDN without dots. + # See https://github.com/aio-libs/aiohttp/issues/3636 + server_hostname = ( + (req.server_hostname or hinfo["hostname"]).rstrip(".") + if sslcontext + else None + ) + + try: + transp, proto = await self._wrap_create_connection( + self._factory, + host, + port, + timeout=timeout, + ssl=sslcontext, + family=hinfo["family"], + proto=hinfo["proto"], + flags=hinfo["flags"], + server_hostname=server_hostname, + local_addr=self._local_addr, + req=req, + client_error=client_error, + ) + except ClientConnectorError as exc: + last_exc = exc + continue + + if req.is_ssl() and fingerprint: + try: + fingerprint.check(transp) + except ServerFingerprintMismatch as exc: + transp.close() + if not self._cleanup_closed_disabled: + self._cleanup_closed_transports.append(transp) + last_exc = exc + continue + + return transp, proto + else: + assert last_exc is not None + raise last_exc + + async def _create_proxy_connection( + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + ) -> Tuple[asyncio.BaseTransport, ResponseHandler]: + self._fail_on_no_start_tls(req) + runtime_has_start_tls = self._loop_supports_start_tls() + + headers: Dict[str, str] = {} + if req.proxy_headers is not None: + headers = req.proxy_headers # type: ignore[assignment] + headers[hdrs.HOST] = req.headers[hdrs.HOST] + + url = req.proxy + assert url is not None + proxy_req = ClientRequest( + hdrs.METH_GET, + url, + headers=headers, + auth=req.proxy_auth, + loop=self._loop, + ssl=req.ssl, + ) + + # create connection to proxy server + transport, proto = await self._create_direct_connection( + proxy_req, [], timeout, client_error=ClientProxyConnectionError + ) + + # Many HTTP proxies has buggy keepalive support. Let's not + # reuse connection but close it after processing every + # response. 
+ proto.force_close() + + auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None) + if auth is not None: + if not req.is_ssl(): + req.headers[hdrs.PROXY_AUTHORIZATION] = auth + else: + proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth + + if req.is_ssl(): + if runtime_has_start_tls: + self._warn_about_tls_in_tls(transport, req) + + # For HTTPS requests over HTTP proxy + # we must notify proxy to tunnel connection + # so we send CONNECT command: + # CONNECT www.python.org:443 HTTP/1.1 + # Host: www.python.org + # + # next we must do TLS handshake and so on + # to do this we must wrap raw socket into secure one + # asyncio handles this perfectly + proxy_req.method = hdrs.METH_CONNECT + proxy_req.url = req.url + key = attr.evolve( + req.connection_key, proxy=None, proxy_auth=None, proxy_headers_hash=None + ) + conn = Connection(self, key, proto, self._loop) + proxy_resp = await proxy_req.send(conn) + try: + protocol = conn._protocol + assert protocol is not None + + # read_until_eof=True will ensure the connection isn't closed + # once the response is received and processed allowing + # START_TLS to work on the connection below. + protocol.set_response_params( + read_until_eof=runtime_has_start_tls, + timeout_ceil_threshold=self._timeout_ceil_threshold, + ) + resp = await proxy_resp.start(conn) + except BaseException: + proxy_resp.close() + conn.close() + raise + else: + conn._protocol = None + conn._transport = None + try: + if resp.status != 200: + message = resp.reason + if message is None: + message = HTTPStatus(resp.status).phrase + raise ClientHttpProxyError( + proxy_resp.request_info, + resp.history, + status=resp.status, + message=message, + headers=resp.headers, + ) + if not runtime_has_start_tls: + rawsock = transport.get_extra_info("socket", default=None) + if rawsock is None: + raise RuntimeError( + "Transport does not expose socket instance" + ) + # Duplicate the socket, so now we can close proxy transport + rawsock = rawsock.dup() + except BaseException: + # It shouldn't be closed in `finally` because it's fed to + # `loop.start_tls()` and the docs say not to touch it after + # passing there. + transport.close() + raise + finally: + if not runtime_has_start_tls: + transport.close() + + if not runtime_has_start_tls: + # HTTP proxy with support for upgrade to HTTPS + sslcontext = self._get_ssl_context(req) + return await self._wrap_create_connection( + self._factory, + timeout=timeout, + ssl=sslcontext, + sock=rawsock, + server_hostname=req.host, + req=req, + ) + + return await self._start_tls_connection( + # Access the old transport for the last time before it's + # closed and forgotten forever: + transport, + req=req, + timeout=timeout, + ) + finally: + proxy_resp.close() + + return transport, proto + + +class UnixConnector(BaseConnector): + """Unix socket connector. + + path - Unix socket path. + keepalive_timeout - (optional) Keep-alive timeout. + force_close - Set to True to force close and do reconnect + after each request (and between redirects). + limit - The total number of simultaneous connections. + limit_per_host - Number of simultaneous connections to one host. + loop - Optional event loop. 
+ """ + + def __init__( + self, + path: str, + force_close: bool = False, + keepalive_timeout: Union[object, float, None] = sentinel, + limit: int = 100, + limit_per_host: int = 0, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + super().__init__( + force_close=force_close, + keepalive_timeout=keepalive_timeout, + limit=limit, + limit_per_host=limit_per_host, + loop=loop, + ) + self._path = path + + @property + def path(self) -> str: + """Path to unix socket.""" + return self._path + + async def _create_connection( + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: + try: + async with ceil_timeout( + timeout.sock_connect, ceil_threshold=timeout.ceil_threshold + ): + _, proto = await self._loop.create_unix_connection( + self._factory, self._path + ) + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + raise UnixClientConnectorError(self.path, req.connection_key, exc) from exc + + return proto + + +class NamedPipeConnector(BaseConnector): + """Named pipe connector. + + Only supported by the proactor event loop. + See also: https://docs.python.org/3/library/asyncio-eventloop.html + + path - Windows named pipe path. + keepalive_timeout - (optional) Keep-alive timeout. + force_close - Set to True to force close and do reconnect + after each request (and between redirects). + limit - The total number of simultaneous connections. + limit_per_host - Number of simultaneous connections to one host. + loop - Optional event loop. + """ + + def __init__( + self, + path: str, + force_close: bool = False, + keepalive_timeout: Union[object, float, None] = sentinel, + limit: int = 100, + limit_per_host: int = 0, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + super().__init__( + force_close=force_close, + keepalive_timeout=keepalive_timeout, + limit=limit, + limit_per_host=limit_per_host, + loop=loop, + ) + if not isinstance( + self._loop, asyncio.ProactorEventLoop # type: ignore[attr-defined] + ): + raise RuntimeError( + "Named Pipes only available in proactor " "loop under windows" + ) + self._path = path + + @property + def path(self) -> str: + """Path to the named pipe.""" + return self._path + + async def _create_connection( + self, req: ClientRequest, traces: List["Trace"], timeout: "ClientTimeout" + ) -> ResponseHandler: + try: + async with ceil_timeout( + timeout.sock_connect, ceil_threshold=timeout.ceil_threshold + ): + _, proto = await self._loop.create_pipe_connection( # type: ignore[attr-defined] + self._factory, self._path + ) + # the drain is required so that the connection_made is called + # and transport is set otherwise it is not set before the + # `assert conn.transport is not None` + # in client.py's _request method + await asyncio.sleep(0) + # other option is to manually set transport like + # `proto.transport = trans` + except OSError as exc: + if exc.errno is None and isinstance(exc, asyncio.TimeoutError): + raise + raise ClientConnectorError(req.connection_key, exc) from exc + + return cast(ResponseHandler, proto) diff --git a/llm/Lib/site-packages/aiohttp/cookiejar.py b/llm/Lib/site-packages/aiohttp/cookiejar.py new file mode 100644 index 0000000000000000000000000000000000000000..5d59717335e9760b252df2b5551d532723d8dcc3 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/cookiejar.py @@ -0,0 +1,419 @@ +import asyncio +import calendar +import contextlib +import datetime +import os # noqa +import pathlib +import pickle +import re +import time +from 
collections import defaultdict +from http.cookies import BaseCookie, Morsel, SimpleCookie +from math import ceil +from typing import ( # noqa + DefaultDict, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Set, + Tuple, + Union, + cast, +) + +from yarl import URL + +from .abc import AbstractCookieJar, ClearCookiePredicate +from .helpers import is_ip_address +from .typedefs import LooseCookies, PathLike, StrOrURL + +__all__ = ("CookieJar", "DummyCookieJar") + + +CookieItem = Union[str, "Morsel[str]"] + + +class CookieJar(AbstractCookieJar): + """Implements cookie storage adhering to RFC 6265.""" + + DATE_TOKENS_RE = re.compile( + r"[\x09\x20-\x2F\x3B-\x40\x5B-\x60\x7B-\x7E]*" + r"(?P[\x00-\x08\x0A-\x1F\d:a-zA-Z\x7F-\xFF]+)" + ) + + DATE_HMS_TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})") + + DATE_DAY_OF_MONTH_RE = re.compile(r"(\d{1,2})") + + DATE_MONTH_RE = re.compile( + "(jan)|(feb)|(mar)|(apr)|(may)|(jun)|(jul)|" "(aug)|(sep)|(oct)|(nov)|(dec)", + re.I, + ) + + DATE_YEAR_RE = re.compile(r"(\d{2,4})") + + # calendar.timegm() fails for timestamps after datetime.datetime.max + # Minus one as a loss of precision occurs when timestamp() is called. + MAX_TIME = ( + int(datetime.datetime.max.replace(tzinfo=datetime.timezone.utc).timestamp()) - 1 + ) + try: + calendar.timegm(time.gmtime(MAX_TIME)) + except (OSError, ValueError): + # Hit the maximum representable time on Windows + # https://learn.microsoft.com/en-us/cpp/c-runtime-library/reference/localtime-localtime32-localtime64 + # Throws ValueError on PyPy 3.8 and 3.9, OSError elsewhere + MAX_TIME = calendar.timegm((3000, 12, 31, 23, 59, 59, -1, -1, -1)) + except OverflowError: + # #4515: datetime.max may not be representable on 32-bit platforms + MAX_TIME = 2**31 - 1 + # Avoid minuses in the future, 3x faster + SUB_MAX_TIME = MAX_TIME - 1 + + def __init__( + self, + *, + unsafe: bool = False, + quote_cookie: bool = True, + treat_as_secure_origin: Union[StrOrURL, List[StrOrURL], None] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + super().__init__(loop=loop) + self._cookies: DefaultDict[Tuple[str, str], SimpleCookie] = defaultdict( + SimpleCookie + ) + self._host_only_cookies: Set[Tuple[str, str]] = set() + self._unsafe = unsafe + self._quote_cookie = quote_cookie + if treat_as_secure_origin is None: + treat_as_secure_origin = [] + elif isinstance(treat_as_secure_origin, URL): + treat_as_secure_origin = [treat_as_secure_origin.origin()] + elif isinstance(treat_as_secure_origin, str): + treat_as_secure_origin = [URL(treat_as_secure_origin).origin()] + else: + treat_as_secure_origin = [ + URL(url).origin() if isinstance(url, str) else url.origin() + for url in treat_as_secure_origin + ] + self._treat_as_secure_origin = treat_as_secure_origin + self._next_expiration: float = ceil(time.time()) + self._expirations: Dict[Tuple[str, str, str], float] = {} + + def save(self, file_path: PathLike) -> None: + file_path = pathlib.Path(file_path) + with file_path.open(mode="wb") as f: + pickle.dump(self._cookies, f, pickle.HIGHEST_PROTOCOL) + + def load(self, file_path: PathLike) -> None: + file_path = pathlib.Path(file_path) + with file_path.open(mode="rb") as f: + self._cookies = pickle.load(f) + + def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: + if predicate is None: + self._next_expiration = ceil(time.time()) + self._cookies.clear() + self._host_only_cookies.clear() + self._expirations.clear() + return + + to_del = [] + now = time.time() + for (domain, path), cookie in 
self._cookies.items(): + for name, morsel in cookie.items(): + key = (domain, path, name) + if ( + key in self._expirations and self._expirations[key] <= now + ) or predicate(morsel): + to_del.append(key) + + for domain, path, name in to_del: + self._host_only_cookies.discard((domain, name)) + key = (domain, path, name) + if key in self._expirations: + del self._expirations[(domain, path, name)] + self._cookies[(domain, path)].pop(name, None) + + self._next_expiration = ( + min(*self._expirations.values(), self.SUB_MAX_TIME) + 1 + if self._expirations + else self.MAX_TIME + ) + + def clear_domain(self, domain: str) -> None: + self.clear(lambda x: self._is_domain_match(domain, x["domain"])) + + def __iter__(self) -> "Iterator[Morsel[str]]": + self._do_expiration() + for val in self._cookies.values(): + yield from val.values() + + def __len__(self) -> int: + return sum(1 for i in self) + + def _do_expiration(self) -> None: + self.clear(lambda x: False) + + def _expire_cookie(self, when: float, domain: str, path: str, name: str) -> None: + self._next_expiration = min(self._next_expiration, when) + self._expirations[(domain, path, name)] = when + + def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: + """Update cookies.""" + hostname = response_url.raw_host + + if not self._unsafe and is_ip_address(hostname): + # Don't accept cookies from IPs + return + + if isinstance(cookies, Mapping): + cookies = cookies.items() + + for name, cookie in cookies: + if not isinstance(cookie, Morsel): + tmp = SimpleCookie() + tmp[name] = cookie # type: ignore[assignment] + cookie = tmp[name] + + domain = cookie["domain"] + + # ignore domains with trailing dots + if domain.endswith("."): + domain = "" + del cookie["domain"] + + if not domain and hostname is not None: + # Set the cookie's domain to the response hostname + # and set its host-only-flag + self._host_only_cookies.add((hostname, name)) + domain = cookie["domain"] = hostname + + if domain.startswith("."): + # Remove leading dot + domain = domain[1:] + cookie["domain"] = domain + + if hostname and not self._is_domain_match(domain, hostname): + # Setting cookies for different domains is not allowed + continue + + path = cookie["path"] + if not path or not path.startswith("/"): + # Set the cookie's path to the response path + path = response_url.path + if not path.startswith("/"): + path = "/" + else: + # Cut everything from the last slash to the end + path = "/" + path[1 : path.rfind("/")] + cookie["path"] = path + + max_age = cookie["max-age"] + if max_age: + try: + delta_seconds = int(max_age) + max_age_expiration = min(time.time() + delta_seconds, self.MAX_TIME) + self._expire_cookie(max_age_expiration, domain, path, name) + except ValueError: + cookie["max-age"] = "" + + else: + expires = cookie["expires"] + if expires: + expire_time = self._parse_date(expires) + if expire_time: + self._expire_cookie(expire_time, domain, path, name) + else: + cookie["expires"] = "" + + self._cookies[(domain, path)][name] = cookie + + self._do_expiration() + + def filter_cookies(self, request_url: URL = URL()) -> "BaseCookie[str]": + """Returns this jar's cookies filtered by their attributes.""" + filtered: Union[SimpleCookie, "BaseCookie[str]"] = ( + SimpleCookie() if self._quote_cookie else BaseCookie() + ) + if not self._cookies: + # Skip do_expiration() if there are no cookies. + return filtered + self._do_expiration() + if not self._cookies: + # Skip rest of function if no non-expired cookies. 
+ return filtered + request_url = URL(request_url) + hostname = request_url.raw_host or "" + + is_not_secure = request_url.scheme not in ("https", "wss") + if is_not_secure and self._treat_as_secure_origin: + request_origin = URL() + with contextlib.suppress(ValueError): + request_origin = request_url.origin() + is_not_secure = request_origin not in self._treat_as_secure_origin + + # Point 2: https://www.rfc-editor.org/rfc/rfc6265.html#section-5.4 + for cookie in sorted(self, key=lambda c: len(c["path"])): + name = cookie.key + domain = cookie["domain"] + + # Send shared cookies + if not domain: + filtered[name] = cookie.value + continue + + if not self._unsafe and is_ip_address(hostname): + continue + + if (domain, name) in self._host_only_cookies: + if domain != hostname: + continue + elif not self._is_domain_match(domain, hostname): + continue + + if not self._is_path_match(request_url.path, cookie["path"]): + continue + + if is_not_secure and cookie["secure"]: + continue + + # It's critical we use the Morsel so the coded_value + # (based on cookie version) is preserved + mrsl_val = cast("Morsel[str]", cookie.get(cookie.key, Morsel())) + mrsl_val.set(cookie.key, cookie.value, cookie.coded_value) + filtered[name] = mrsl_val + + return filtered + + @staticmethod + def _is_domain_match(domain: str, hostname: str) -> bool: + """Implements domain matching adhering to RFC 6265.""" + if hostname == domain: + return True + + if not hostname.endswith(domain): + return False + + non_matching = hostname[: -len(domain)] + + if not non_matching.endswith("."): + return False + + return not is_ip_address(hostname) + + @staticmethod + def _is_path_match(req_path: str, cookie_path: str) -> bool: + """Implements path matching adhering to RFC 6265.""" + if not req_path.startswith("/"): + req_path = "/" + + if req_path == cookie_path: + return True + + if not req_path.startswith(cookie_path): + return False + + if cookie_path.endswith("/"): + return True + + non_matching = req_path[len(cookie_path) :] + + return non_matching.startswith("/") + + @classmethod + def _parse_date(cls, date_str: str) -> Optional[int]: + """Implements date string parsing adhering to RFC 6265.""" + if not date_str: + return None + + found_time = False + found_day = False + found_month = False + found_year = False + + hour = minute = second = 0 + day = 0 + month = 0 + year = 0 + + for token_match in cls.DATE_TOKENS_RE.finditer(date_str): + + token = token_match.group("token") + + if not found_time: + time_match = cls.DATE_HMS_TIME_RE.match(token) + if time_match: + found_time = True + hour, minute, second = (int(s) for s in time_match.groups()) + continue + + if not found_day: + day_match = cls.DATE_DAY_OF_MONTH_RE.match(token) + if day_match: + found_day = True + day = int(day_match.group()) + continue + + if not found_month: + month_match = cls.DATE_MONTH_RE.match(token) + if month_match: + found_month = True + assert month_match.lastindex is not None + month = month_match.lastindex + continue + + if not found_year: + year_match = cls.DATE_YEAR_RE.match(token) + if year_match: + found_year = True + year = int(year_match.group()) + + if 70 <= year <= 99: + year += 1900 + elif 0 <= year <= 69: + year += 2000 + + if False in (found_day, found_month, found_year, found_time): + return None + + if not 1 <= day <= 31: + return None + + if year < 1601 or hour > 23 or minute > 59 or second > 59: + return None + + return calendar.timegm((year, month, day, hour, minute, second, -1, -1, -1)) + + +class DummyCookieJar(AbstractCookieJar): + 
"""Implements a dummy cookie storage. + + It can be used with the ClientSession when no cookie processing is needed. + + """ + + def __init__(self, *, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(loop=loop) + + def __iter__(self) -> "Iterator[Morsel[str]]": + while False: + yield None + + def __len__(self) -> int: + return 0 + + def clear(self, predicate: Optional[ClearCookiePredicate] = None) -> None: + pass + + def clear_domain(self, domain: str) -> None: + pass + + def update_cookies(self, cookies: LooseCookies, response_url: URL = URL()) -> None: + pass + + def filter_cookies(self, request_url: URL) -> "BaseCookie[str]": + return SimpleCookie() diff --git a/llm/Lib/site-packages/aiohttp/formdata.py b/llm/Lib/site-packages/aiohttp/formdata.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf78a82d124548d26915969f5b9c75398cf29a9 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/formdata.py @@ -0,0 +1,182 @@ +import io +import warnings +from typing import Any, Iterable, List, Optional +from urllib.parse import urlencode + +from multidict import MultiDict, MultiDictProxy + +from . import hdrs, multipart, payload +from .helpers import guess_filename +from .payload import Payload + +__all__ = ("FormData",) + + +class FormData: + """Helper class for form body generation. + + Supports multipart/form-data and application/x-www-form-urlencoded. + """ + + def __init__( + self, + fields: Iterable[Any] = (), + quote_fields: bool = True, + charset: Optional[str] = None, + ) -> None: + self._writer = multipart.MultipartWriter("form-data") + self._fields: List[Any] = [] + self._is_multipart = False + self._is_processed = False + self._quote_fields = quote_fields + self._charset = charset + + if isinstance(fields, dict): + fields = list(fields.items()) + elif not isinstance(fields, (list, tuple)): + fields = (fields,) + self.add_fields(*fields) + + @property + def is_multipart(self) -> bool: + return self._is_multipart + + def add_field( + self, + name: str, + value: Any, + *, + content_type: Optional[str] = None, + filename: Optional[str] = None, + content_transfer_encoding: Optional[str] = None, + ) -> None: + + if isinstance(value, io.IOBase): + self._is_multipart = True + elif isinstance(value, (bytes, bytearray, memoryview)): + msg = ( + "In v4, passing bytes will no longer create a file field. " + "Please explicitly use the filename parameter or pass a BytesIO object." + ) + if filename is None and content_transfer_encoding is None: + warnings.warn(msg, DeprecationWarning) + filename = name + + type_options: MultiDict[str] = MultiDict({"name": name}) + if filename is not None and not isinstance(filename, str): + raise TypeError( + "filename must be an instance of str. " "Got: %s" % filename + ) + if filename is None and isinstance(value, io.IOBase): + filename = guess_filename(value, name) + if filename is not None: + type_options["filename"] = filename + self._is_multipart = True + + headers = {} + if content_type is not None: + if not isinstance(content_type, str): + raise TypeError( + "content_type must be an instance of str. " "Got: %s" % content_type + ) + headers[hdrs.CONTENT_TYPE] = content_type + self._is_multipart = True + if content_transfer_encoding is not None: + if not isinstance(content_transfer_encoding, str): + raise TypeError( + "content_transfer_encoding must be an instance" + " of str. Got: %s" % content_transfer_encoding + ) + msg = ( + "content_transfer_encoding is deprecated. 
" + "To maintain compatibility with v4 please pass a BytesPayload." + ) + warnings.warn(msg, DeprecationWarning) + self._is_multipart = True + + self._fields.append((type_options, headers, value)) + + def add_fields(self, *fields: Any) -> None: + to_add = list(fields) + + while to_add: + rec = to_add.pop(0) + + if isinstance(rec, io.IOBase): + k = guess_filename(rec, "unknown") + self.add_field(k, rec) # type: ignore[arg-type] + + elif isinstance(rec, (MultiDictProxy, MultiDict)): + to_add.extend(rec.items()) + + elif isinstance(rec, (list, tuple)) and len(rec) == 2: + k, fp = rec + self.add_field(k, fp) # type: ignore[arg-type] + + else: + raise TypeError( + "Only io.IOBase, multidict and (name, file) " + "pairs allowed, use .add_field() for passing " + "more complex parameters, got {!r}".format(rec) + ) + + def _gen_form_urlencoded(self) -> payload.BytesPayload: + # form data (x-www-form-urlencoded) + data = [] + for type_options, _, value in self._fields: + data.append((type_options["name"], value)) + + charset = self._charset if self._charset is not None else "utf-8" + + if charset == "utf-8": + content_type = "application/x-www-form-urlencoded" + else: + content_type = "application/x-www-form-urlencoded; " "charset=%s" % charset + + return payload.BytesPayload( + urlencode(data, doseq=True, encoding=charset).encode(), + content_type=content_type, + ) + + def _gen_form_data(self) -> multipart.MultipartWriter: + """Encode a list of fields using the multipart/form-data MIME format""" + if self._is_processed: + raise RuntimeError("Form data has been processed already") + for dispparams, headers, value in self._fields: + try: + if hdrs.CONTENT_TYPE in headers: + part = payload.get_payload( + value, + content_type=headers[hdrs.CONTENT_TYPE], + headers=headers, + encoding=self._charset, + ) + else: + part = payload.get_payload( + value, headers=headers, encoding=self._charset + ) + except Exception as exc: + raise TypeError( + "Can not serialize value type: %r\n " + "headers: %r\n value: %r" % (type(value), headers, value) + ) from exc + + if dispparams: + part.set_content_disposition( + "form-data", quote_fields=self._quote_fields, **dispparams + ) + # FIXME cgi.FieldStorage doesn't likes body parts with + # Content-Length which were sent via chunked transfer encoding + assert part.headers is not None + part.headers.popall(hdrs.CONTENT_LENGTH, None) + + self._writer.append_payload(part) + + self._is_processed = True + return self._writer + + def __call__(self) -> Payload: + if self._is_multipart: + return self._gen_form_data() + else: + return self._gen_form_urlencoded() diff --git a/llm/Lib/site-packages/aiohttp/hdrs.py b/llm/Lib/site-packages/aiohttp/hdrs.py new file mode 100644 index 0000000000000000000000000000000000000000..078abfff86be97186724b5bdf0e03613c34d9bbc --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/hdrs.py @@ -0,0 +1,108 @@ +"""HTTP Headers constants.""" + +# After changing the file content call ./tools/gen.py +# to regenerate the headers parser +from typing import Final, Set + +from multidict import istr + +METH_ANY: Final[str] = "*" +METH_CONNECT: Final[str] = "CONNECT" +METH_HEAD: Final[str] = "HEAD" +METH_GET: Final[str] = "GET" +METH_DELETE: Final[str] = "DELETE" +METH_OPTIONS: Final[str] = "OPTIONS" +METH_PATCH: Final[str] = "PATCH" +METH_POST: Final[str] = "POST" +METH_PUT: Final[str] = "PUT" +METH_TRACE: Final[str] = "TRACE" + +METH_ALL: Final[Set[str]] = { + METH_CONNECT, + METH_HEAD, + METH_GET, + METH_DELETE, + METH_OPTIONS, + METH_PATCH, + METH_POST, + 
METH_PUT, + METH_TRACE, +} + +ACCEPT: Final[istr] = istr("Accept") +ACCEPT_CHARSET: Final[istr] = istr("Accept-Charset") +ACCEPT_ENCODING: Final[istr] = istr("Accept-Encoding") +ACCEPT_LANGUAGE: Final[istr] = istr("Accept-Language") +ACCEPT_RANGES: Final[istr] = istr("Accept-Ranges") +ACCESS_CONTROL_MAX_AGE: Final[istr] = istr("Access-Control-Max-Age") +ACCESS_CONTROL_ALLOW_CREDENTIALS: Final[istr] = istr("Access-Control-Allow-Credentials") +ACCESS_CONTROL_ALLOW_HEADERS: Final[istr] = istr("Access-Control-Allow-Headers") +ACCESS_CONTROL_ALLOW_METHODS: Final[istr] = istr("Access-Control-Allow-Methods") +ACCESS_CONTROL_ALLOW_ORIGIN: Final[istr] = istr("Access-Control-Allow-Origin") +ACCESS_CONTROL_EXPOSE_HEADERS: Final[istr] = istr("Access-Control-Expose-Headers") +ACCESS_CONTROL_REQUEST_HEADERS: Final[istr] = istr("Access-Control-Request-Headers") +ACCESS_CONTROL_REQUEST_METHOD: Final[istr] = istr("Access-Control-Request-Method") +AGE: Final[istr] = istr("Age") +ALLOW: Final[istr] = istr("Allow") +AUTHORIZATION: Final[istr] = istr("Authorization") +CACHE_CONTROL: Final[istr] = istr("Cache-Control") +CONNECTION: Final[istr] = istr("Connection") +CONTENT_DISPOSITION: Final[istr] = istr("Content-Disposition") +CONTENT_ENCODING: Final[istr] = istr("Content-Encoding") +CONTENT_LANGUAGE: Final[istr] = istr("Content-Language") +CONTENT_LENGTH: Final[istr] = istr("Content-Length") +CONTENT_LOCATION: Final[istr] = istr("Content-Location") +CONTENT_MD5: Final[istr] = istr("Content-MD5") +CONTENT_RANGE: Final[istr] = istr("Content-Range") +CONTENT_TRANSFER_ENCODING: Final[istr] = istr("Content-Transfer-Encoding") +CONTENT_TYPE: Final[istr] = istr("Content-Type") +COOKIE: Final[istr] = istr("Cookie") +DATE: Final[istr] = istr("Date") +DESTINATION: Final[istr] = istr("Destination") +DIGEST: Final[istr] = istr("Digest") +ETAG: Final[istr] = istr("Etag") +EXPECT: Final[istr] = istr("Expect") +EXPIRES: Final[istr] = istr("Expires") +FORWARDED: Final[istr] = istr("Forwarded") +FROM: Final[istr] = istr("From") +HOST: Final[istr] = istr("Host") +IF_MATCH: Final[istr] = istr("If-Match") +IF_MODIFIED_SINCE: Final[istr] = istr("If-Modified-Since") +IF_NONE_MATCH: Final[istr] = istr("If-None-Match") +IF_RANGE: Final[istr] = istr("If-Range") +IF_UNMODIFIED_SINCE: Final[istr] = istr("If-Unmodified-Since") +KEEP_ALIVE: Final[istr] = istr("Keep-Alive") +LAST_EVENT_ID: Final[istr] = istr("Last-Event-ID") +LAST_MODIFIED: Final[istr] = istr("Last-Modified") +LINK: Final[istr] = istr("Link") +LOCATION: Final[istr] = istr("Location") +MAX_FORWARDS: Final[istr] = istr("Max-Forwards") +ORIGIN: Final[istr] = istr("Origin") +PRAGMA: Final[istr] = istr("Pragma") +PROXY_AUTHENTICATE: Final[istr] = istr("Proxy-Authenticate") +PROXY_AUTHORIZATION: Final[istr] = istr("Proxy-Authorization") +RANGE: Final[istr] = istr("Range") +REFERER: Final[istr] = istr("Referer") +RETRY_AFTER: Final[istr] = istr("Retry-After") +SEC_WEBSOCKET_ACCEPT: Final[istr] = istr("Sec-WebSocket-Accept") +SEC_WEBSOCKET_VERSION: Final[istr] = istr("Sec-WebSocket-Version") +SEC_WEBSOCKET_PROTOCOL: Final[istr] = istr("Sec-WebSocket-Protocol") +SEC_WEBSOCKET_EXTENSIONS: Final[istr] = istr("Sec-WebSocket-Extensions") +SEC_WEBSOCKET_KEY: Final[istr] = istr("Sec-WebSocket-Key") +SEC_WEBSOCKET_KEY1: Final[istr] = istr("Sec-WebSocket-Key1") +SERVER: Final[istr] = istr("Server") +SET_COOKIE: Final[istr] = istr("Set-Cookie") +TE: Final[istr] = istr("TE") +TRAILER: Final[istr] = istr("Trailer") +TRANSFER_ENCODING: Final[istr] = istr("Transfer-Encoding") +UPGRADE: 
Final[istr] = istr("Upgrade") +URI: Final[istr] = istr("URI") +USER_AGENT: Final[istr] = istr("User-Agent") +VARY: Final[istr] = istr("Vary") +VIA: Final[istr] = istr("Via") +WANT_DIGEST: Final[istr] = istr("Want-Digest") +WARNING: Final[istr] = istr("Warning") +WWW_AUTHENTICATE: Final[istr] = istr("WWW-Authenticate") +X_FORWARDED_FOR: Final[istr] = istr("X-Forwarded-For") +X_FORWARDED_HOST: Final[istr] = istr("X-Forwarded-Host") +X_FORWARDED_PROTO: Final[istr] = istr("X-Forwarded-Proto") diff --git a/llm/Lib/site-packages/aiohttp/helpers.py b/llm/Lib/site-packages/aiohttp/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..f2d729f1bcecaeec317b715141303f99e53046cc --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/helpers.py @@ -0,0 +1,1029 @@ +"""Various helper functions""" + +import asyncio +import base64 +import binascii +import contextlib +import datetime +import enum +import functools +import inspect +import netrc +import os +import platform +import re +import sys +import time +import warnings +import weakref +from collections import namedtuple +from contextlib import suppress +from email.parser import HeaderParser +from email.utils import parsedate +from math import ceil +from pathlib import Path +from types import TracebackType +from typing import ( + Any, + Callable, + ContextManager, + Dict, + Generator, + Generic, + Iterable, + Iterator, + List, + Mapping, + Optional, + Pattern, + Protocol, + Tuple, + Type, + TypeVar, + Union, + get_args, + overload, +) +from urllib.parse import quote +from urllib.request import getproxies, proxy_bypass + +import attr +from multidict import MultiDict, MultiDictProxy, MultiMapping +from yarl import URL + +from . import hdrs +from .log import client_logger, internal_logger + +if sys.version_info >= (3, 11): + import asyncio as async_timeout +else: + import async_timeout + +__all__ = ("BasicAuth", "ChainMapProxy", "ETag") + +IS_MACOS = platform.system() == "Darwin" +IS_WINDOWS = platform.system() == "Windows" + +PY_310 = sys.version_info >= (3, 10) +PY_311 = sys.version_info >= (3, 11) + + +_T = TypeVar("_T") +_S = TypeVar("_S") + +_SENTINEL = enum.Enum("_SENTINEL", "sentinel") +sentinel = _SENTINEL.sentinel + +NO_EXTENSIONS = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS")) + +DEBUG = sys.flags.dev_mode or ( + not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG")) +) + + +CHAR = {chr(i) for i in range(0, 128)} +CTL = {chr(i) for i in range(0, 32)} | { + chr(127), +} +SEPARATORS = { + "(", + ")", + "<", + ">", + "@", + ",", + ";", + ":", + "\\", + '"', + "/", + "[", + "]", + "?", + "=", + "{", + "}", + " ", + chr(9), +} +TOKEN = CHAR ^ CTL ^ SEPARATORS + + +class noop: + def __await__(self) -> Generator[None, None, None]: + yield + + +class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])): + """Http basic authentication helper.""" + + def __new__( + cls, login: str, password: str = "", encoding: str = "latin1" + ) -> "BasicAuth": + if login is None: + raise ValueError("None is not allowed as login value") + + if password is None: + raise ValueError("None is not allowed as password value") + + if ":" in login: + raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)') + + return super().__new__(cls, login, password, encoding) + + @classmethod + def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth": + """Create a BasicAuth object from an Authorization HTTP header.""" + try: + auth_type, encoded_credentials = auth_header.split(" ", 1) + except 
ValueError: + raise ValueError("Could not parse authorization header.") + + if auth_type.lower() != "basic": + raise ValueError("Unknown authorization method %s" % auth_type) + + try: + decoded = base64.b64decode( + encoded_credentials.encode("ascii"), validate=True + ).decode(encoding) + except binascii.Error: + raise ValueError("Invalid base64 encoding.") + + try: + # RFC 2617 HTTP Authentication + # https://www.ietf.org/rfc/rfc2617.txt + # the colon must be present, but the username and password may be + # otherwise blank. + username, password = decoded.split(":", 1) + except ValueError: + raise ValueError("Invalid credentials.") + + return cls(username, password, encoding=encoding) + + @classmethod + def from_url(cls, url: URL, *, encoding: str = "latin1") -> Optional["BasicAuth"]: + """Create BasicAuth from url.""" + if not isinstance(url, URL): + raise TypeError("url should be yarl.URL instance") + if url.user is None: + return None + return cls(url.user, url.password or "", encoding=encoding) + + def encode(self) -> str: + """Encode credentials.""" + creds = (f"{self.login}:{self.password}").encode(self.encoding) + return "Basic %s" % base64.b64encode(creds).decode(self.encoding) + + +def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: + auth = BasicAuth.from_url(url) + if auth is None: + return url, None + else: + return url.with_user(None), auth + + +def netrc_from_env() -> Optional[netrc.netrc]: + """Load netrc from file. + + Attempt to load it from the path specified by the env-var + NETRC or in the default location in the user's home directory. + + Returns None if it couldn't be found or fails to parse. + """ + netrc_env = os.environ.get("NETRC") + + if netrc_env is not None: + netrc_path = Path(netrc_env) + else: + try: + home_dir = Path.home() + except RuntimeError as e: # pragma: no cover + # if pathlib can't resolve home, it may raise a RuntimeError + client_logger.debug( + "Could not resolve home directory when " + "trying to look for .netrc file: %s", + e, + ) + return None + + netrc_path = home_dir / ("_netrc" if IS_WINDOWS else ".netrc") + + try: + return netrc.netrc(str(netrc_path)) + except netrc.NetrcParseError as e: + client_logger.warning("Could not parse .netrc file: %s", e) + except OSError as e: + netrc_exists = False + with contextlib.suppress(OSError): + netrc_exists = netrc_path.is_file() + # we couldn't read the file (doesn't exist, permissions, etc.) + if netrc_env or netrc_exists: + # only warn if the environment wanted us to load it, + # or it appears like the default file does actually exist + client_logger.warning("Could not read .netrc file: %s", e) + + return None + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ProxyInfo: + proxy: URL + proxy_auth: Optional[BasicAuth] + + +def basicauth_from_netrc(netrc_obj: Optional[netrc.netrc], host: str) -> BasicAuth: + """ + Return :py:class:`~aiohttp.BasicAuth` credentials for ``host`` from ``netrc_obj``. + + :raises LookupError: if ``netrc_obj`` is :py:data:`None` or if no + entry is found for the ``host``. + """ + if netrc_obj is None: + raise LookupError("No .netrc file found") + auth_from_netrc = netrc_obj.authenticators(host) + + if auth_from_netrc is None: + raise LookupError(f"No entry for {host!s} found in the `.netrc` file.") + login, account, password = auth_from_netrc + + # TODO(PY311): username = login or account + # Up to python 3.10, account could be None if not specified, + # and login will be empty string if not specified. 
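# --- editor's aside (illustrative only; not part of this diff or upstream aiohttp) ---
# Quick sketch of the BasicAuth helper defined earlier in this file, assuming the
# vendored aiohttp is on the import path; the credentials are placeholder values.
#
#     >>> from aiohttp import BasicAuth
#     >>> BasicAuth("user", "pass").encode()
#     'Basic dXNlcjpwYXNz'
#     >>> BasicAuth.decode("Basic dXNlcjpwYXNz")
#     BasicAuth(login='user', password='pass', encoding='latin1')
# -------------------------------------------------------------------------------------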
From 3.11, + # login and account will be empty string if not specified. + username = login if (login or account is None) else account + + # TODO(PY311): Remove this, as password will be empty string + # if not specified + if password is None: + password = "" + + return BasicAuth(username, password) + + +def proxies_from_env() -> Dict[str, ProxyInfo]: + proxy_urls = { + k: URL(v) + for k, v in getproxies().items() + if k in ("http", "https", "ws", "wss") + } + netrc_obj = netrc_from_env() + stripped = {k: strip_auth_from_url(v) for k, v in proxy_urls.items()} + ret = {} + for proto, val in stripped.items(): + proxy, auth = val + if proxy.scheme in ("https", "wss"): + client_logger.warning( + "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy + ) + continue + if netrc_obj and auth is None: + if proxy.host is not None: + try: + auth = basicauth_from_netrc(netrc_obj, proxy.host) + except LookupError: + auth = None + ret[proto] = ProxyInfo(proxy, auth) + return ret + + +def current_task( + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> "Optional[asyncio.Task[Any]]": + return asyncio.current_task(loop=loop) + + +def get_running_loop( + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> asyncio.AbstractEventLoop: + if loop is None: + loop = asyncio.get_event_loop() + if not loop.is_running(): + warnings.warn( + "The object should be created within an async function", + DeprecationWarning, + stacklevel=3, + ) + if loop.get_debug(): + internal_logger.warning( + "The object should be created within an async function", stack_info=True + ) + return loop + + +def isasyncgenfunction(obj: Any) -> bool: + func = getattr(inspect, "isasyncgenfunction", None) + if func is not None: + return func(obj) # type: ignore[no-any-return] + else: + return False + + +def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]: + """Get a permitted proxy for the given URL from the env.""" + if url.host is not None and proxy_bypass(url.host): + raise LookupError(f"Proxying is disallowed for `{url.host!r}`") + + proxies_in_env = proxies_from_env() + try: + proxy_info = proxies_in_env[url.scheme] + except KeyError: + raise LookupError(f"No proxies found for `{url!s}` in the env") + else: + return proxy_info.proxy, proxy_info.proxy_auth + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class MimeType: + type: str + subtype: str + suffix: str + parameters: "MultiDictProxy[str]" + + +@functools.lru_cache(maxsize=56) +def parse_mimetype(mimetype: str) -> MimeType: + """Parses a MIME type into its components. + + mimetype is a MIME type string. + + Returns a MimeType object. 
+ + Example: + + >>> parse_mimetype('text/html; charset=utf-8') + MimeType(type='text', subtype='html', suffix='', + parameters={'charset': 'utf-8'}) + + """ + if not mimetype: + return MimeType( + type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict()) + ) + + parts = mimetype.split(";") + params: MultiDict[str] = MultiDict() + for item in parts[1:]: + if not item: + continue + key, _, value = item.partition("=") + params.add(key.lower().strip(), value.strip(' "')) + + fulltype = parts[0].strip().lower() + if fulltype == "*": + fulltype = "*/*" + + mtype, _, stype = fulltype.partition("/") + stype, _, suffix = stype.partition("+") + + return MimeType( + type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params) + ) + + +def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]: + name = getattr(obj, "name", None) + if name and isinstance(name, str) and name[0] != "<" and name[-1] != ">": + return Path(name).name + return default + + +not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]") +QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"} + + +def quoted_string(content: str) -> str: + """Return 7-bit content as quoted-string. + + Format content into a quoted-string as defined in RFC5322 for + Internet Message Format. Notice that this is not the 8-bit HTTP + format, but the 7-bit email format. Content must be in usascii or + a ValueError is raised. + """ + if not (QCONTENT > set(content)): + raise ValueError(f"bad content for quoted-string {content!r}") + return not_qtext_re.sub(lambda x: "\\" + x.group(0), content) + + +def content_disposition_header( + disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str +) -> str: + """Sets ``Content-Disposition`` header for MIME. + + This is the MIME payload Content-Disposition header from RFC 2183 + and RFC 7579 section 4.2, not the HTTP Content-Disposition from + RFC 6266. + + disptype is a disposition type: inline, attachment, form-data. + Should be valid extension token (see RFC 2183) + + quote_fields performs value quoting to 7-bit MIME headers + according to RFC 7578. Set to quote_fields to False if recipient + can take 8-bit file names and field values. + + _charset specifies the charset to use when quote_fields is True. + + params is a dict with disposition params. + """ + if not disptype or not (TOKEN > set(disptype)): + raise ValueError("bad content disposition type {!r}" "".format(disptype)) + + value = disptype + if params: + lparams = [] + for key, val in params.items(): + if not key or not (TOKEN > set(key)): + raise ValueError( + "bad content disposition parameter" " {!r}={!r}".format(key, val) + ) + if quote_fields: + if key.lower() == "filename": + qval = quote(val, "", encoding=_charset) + lparams.append((key, '"%s"' % qval)) + else: + try: + qval = quoted_string(val) + except ValueError: + qval = "".join( + (_charset, "''", quote(val, "", encoding=_charset)) + ) + lparams.append((key + "*", qval)) + else: + lparams.append((key, '"%s"' % qval)) + else: + qval = val.replace("\\", "\\\\").replace('"', '\\"') + lparams.append((key, '"%s"' % qval)) + sparams = "; ".join("=".join(pair) for pair in lparams) + value = "; ".join((value, sparams)) + return value + + +class _TSelf(Protocol, Generic[_T]): + _cache: Dict[str, _T] + + +class reify(Generic[_T]): + """Use as a class method decorator. 
+ + It operates almost exactly like + the Python `@property` decorator, but it puts the result of the + method it decorates into the instance dict after the first call, + effectively replacing the function it decorates with an instance + variable. It is, in Python parlance, a data descriptor. + """ + + def __init__(self, wrapped: Callable[..., _T]) -> None: + self.wrapped = wrapped + self.__doc__ = wrapped.__doc__ + self.name = wrapped.__name__ + + def __get__(self, inst: _TSelf[_T], owner: Optional[Type[Any]] = None) -> _T: + try: + try: + return inst._cache[self.name] + except KeyError: + val = self.wrapped(inst) + inst._cache[self.name] = val + return val + except AttributeError: + if inst is None: + return self + raise + + def __set__(self, inst: _TSelf[_T], value: _T) -> None: + raise AttributeError("reified property is read-only") + + +reify_py = reify + +try: + from ._helpers import reify as reify_c + + if not NO_EXTENSIONS: + reify = reify_c # type: ignore[misc,assignment] +except ImportError: + pass + +_ipv4_pattern = ( + r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}" + r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$" +) +_ipv6_pattern = ( + r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}" + r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)" + r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})" + r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}" + r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}" + r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)" + r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}" + r":|:(:[A-F0-9]{1,4}){7})$" +) +_ipv4_regex = re.compile(_ipv4_pattern) +_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE) +_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii")) +_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE) + + +def _is_ip_address( + regex: Pattern[str], regexb: Pattern[bytes], host: Optional[Union[str, bytes]] +) -> bool: + if host is None: + return False + if isinstance(host, str): + return bool(regex.match(host)) + elif isinstance(host, (bytes, bytearray, memoryview)): + return bool(regexb.match(host)) + else: + raise TypeError(f"{host} [{type(host)}] is not a str or bytes") + + +is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb) +is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb) + + +def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool: + return is_ipv4_address(host) or is_ipv6_address(host) + + +_cached_current_datetime: Optional[int] = None +_cached_formatted_datetime = "" + + +def rfc822_formatted_time() -> str: + global _cached_current_datetime + global _cached_formatted_datetime + + now = int(time.time()) + if now != _cached_current_datetime: + # Weekday and month names for HTTP date/time formatting; + # always English! + # Tuples are constants stored in codeobject! 
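# --- editor's aside (illustrative only; not part of this diff or upstream aiohttp) ---
# Minimal sketch of the `reify` descriptor defined earlier in helpers.py: the owning
# class is expected to provide a `_cache` dict, where the first computed value is
# memoized so later attribute access skips the wrapped function. `Box` is a made-up
# example class, not an aiohttp type.
#
#     >>> from aiohttp.helpers import reify
#     >>> class Box:
#     ...     def __init__(self):
#     ...         self._cache = {}
#     ...     @reify
#     ...     def answer(self):
#     ...         print("computed")
#     ...         return 42
#     >>> b = Box()
#     >>> b.answer
#     computed
#     42
#     >>> b.answer          # second access is served from b._cache
#     42
# -------------------------------------------------------------------------------------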
+ _weekdayname = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun") + _monthname = ( + "", # Dummy so we can use 1-based month numbers + "Jan", + "Feb", + "Mar", + "Apr", + "May", + "Jun", + "Jul", + "Aug", + "Sep", + "Oct", + "Nov", + "Dec", + ) + + year, month, day, hh, mm, ss, wd, *tail = time.gmtime(now) + _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( + _weekdayname[wd], + day, + _monthname[month], + year, + hh, + mm, + ss, + ) + _cached_current_datetime = now + return _cached_formatted_datetime + + +def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None: + ref, name = info + ob = ref() + if ob is not None: + with suppress(Exception): + getattr(ob, name)() + + +def weakref_handle( + ob: object, + name: str, + timeout: float, + loop: asyncio.AbstractEventLoop, + timeout_ceil_threshold: float = 5, +) -> Optional[asyncio.TimerHandle]: + if timeout is not None and timeout > 0: + when = loop.time() + timeout + if timeout >= timeout_ceil_threshold: + when = ceil(when) + + return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name)) + return None + + +def call_later( + cb: Callable[[], Any], + timeout: float, + loop: asyncio.AbstractEventLoop, + timeout_ceil_threshold: float = 5, +) -> Optional[asyncio.TimerHandle]: + if timeout is not None and timeout > 0: + when = loop.time() + timeout + if timeout > timeout_ceil_threshold: + when = ceil(when) + return loop.call_at(when, cb) + return None + + +class TimeoutHandle: + """Timeout handle""" + + def __init__( + self, + loop: asyncio.AbstractEventLoop, + timeout: Optional[float], + ceil_threshold: float = 5, + ) -> None: + self._timeout = timeout + self._loop = loop + self._ceil_threshold = ceil_threshold + self._callbacks: List[ + Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]] + ] = [] + + def register( + self, callback: Callable[..., None], *args: Any, **kwargs: Any + ) -> None: + self._callbacks.append((callback, args, kwargs)) + + def close(self) -> None: + self._callbacks.clear() + + def start(self) -> Optional[asyncio.Handle]: + timeout = self._timeout + if timeout is not None and timeout > 0: + when = self._loop.time() + timeout + if timeout >= self._ceil_threshold: + when = ceil(when) + return self._loop.call_at(when, self.__call__) + else: + return None + + def timer(self) -> "BaseTimerContext": + if self._timeout is not None and self._timeout > 0: + timer = TimerContext(self._loop) + self.register(timer.timeout) + return timer + else: + return TimerNoop() + + def __call__(self) -> None: + for cb, args, kwargs in self._callbacks: + with suppress(Exception): + cb(*args, **kwargs) + + self._callbacks.clear() + + +class BaseTimerContext(ContextManager["BaseTimerContext"]): + def assert_timeout(self) -> None: + """Raise TimeoutError if timeout has been exceeded.""" + + +class TimerNoop(BaseTimerContext): + def __enter__(self) -> BaseTimerContext: + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + return + + +class TimerContext(BaseTimerContext): + """Low resolution timeout context manager""" + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._tasks: List[asyncio.Task[Any]] = [] + self._cancelled = False + + def assert_timeout(self) -> None: + """Raise TimeoutError if timer has already been cancelled.""" + if self._cancelled: + raise asyncio.TimeoutError from None + + def __enter__(self) -> BaseTimerContext: + task = 
current_task(loop=self._loop) + + if task is None: + raise RuntimeError( + "Timeout context manager should be used " "inside a task" + ) + + if self._cancelled: + raise asyncio.TimeoutError from None + + self._tasks.append(task) + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> Optional[bool]: + if self._tasks: + self._tasks.pop() + + if exc_type is asyncio.CancelledError and self._cancelled: + raise asyncio.TimeoutError from None + return None + + def timeout(self) -> None: + if not self._cancelled: + for task in set(self._tasks): + task.cancel() + + self._cancelled = True + + +def ceil_timeout( + delay: Optional[float], ceil_threshold: float = 5 +) -> async_timeout.Timeout: + if delay is None or delay <= 0: + return async_timeout.timeout(None) + + loop = get_running_loop() + now = loop.time() + when = now + delay + if delay > ceil_threshold: + when = ceil(when) + return async_timeout.timeout_at(when) + + +class HeadersMixin: + ATTRS = frozenset(["_content_type", "_content_dict", "_stored_content_type"]) + + _headers: MultiMapping[str] + + _content_type: Optional[str] = None + _content_dict: Optional[Dict[str, str]] = None + _stored_content_type: Union[str, None, _SENTINEL] = sentinel + + def _parse_content_type(self, raw: Optional[str]) -> None: + self._stored_content_type = raw + if raw is None: + # default value according to RFC 2616 + self._content_type = "application/octet-stream" + self._content_dict = {} + else: + msg = HeaderParser().parsestr("Content-Type: " + raw) + self._content_type = msg.get_content_type() + params = msg.get_params(()) + self._content_dict = dict(params[1:]) # First element is content type again + + @property + def content_type(self) -> str: + """The value of content part for Content-Type HTTP header.""" + raw = self._headers.get(hdrs.CONTENT_TYPE) + if self._stored_content_type != raw: + self._parse_content_type(raw) + return self._content_type # type: ignore[return-value] + + @property + def charset(self) -> Optional[str]: + """The value of charset part for Content-Type HTTP header.""" + raw = self._headers.get(hdrs.CONTENT_TYPE) + if self._stored_content_type != raw: + self._parse_content_type(raw) + return self._content_dict.get("charset") # type: ignore[union-attr] + + @property + def content_length(self) -> Optional[int]: + """The value of Content-Length HTTP header.""" + content_length = self._headers.get(hdrs.CONTENT_LENGTH) + + if content_length is not None: + return int(content_length) + else: + return None + + +def set_result(fut: "asyncio.Future[_T]", result: _T) -> None: + if not fut.done(): + fut.set_result(result) + + +_EXC_SENTINEL = BaseException() + + +class ErrorableProtocol(Protocol): + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = ..., + ) -> None: + ... # pragma: no cover + + +def set_exception( + fut: "asyncio.Future[_T] | ErrorableProtocol", + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, +) -> None: + """Set future exception. + + If the future is marked as complete, this function is a no-op. + + :param exc_cause: An exception that is a direct cause of ``exc``. + Only set if provided. 
+ """ + if asyncio.isfuture(fut) and fut.done(): + return + + exc_is_sentinel = exc_cause is _EXC_SENTINEL + exc_causes_itself = exc is exc_cause + if not exc_is_sentinel and not exc_causes_itself: + exc.__cause__ = exc_cause + + fut.set_exception(exc) + + +@functools.total_ordering +class AppKey(Generic[_T]): + """Keys for static typing support in Application.""" + + __slots__ = ("_name", "_t", "__orig_class__") + + # This may be set by Python when instantiating with a generic type. We need to + # support this, in order to support types that are not concrete classes, + # like Iterable, which can't be passed as the second parameter to __init__. + __orig_class__: Type[object] + + def __init__(self, name: str, t: Optional[Type[_T]] = None): + # Prefix with module name to help deduplicate key names. + frame = inspect.currentframe() + while frame: + if frame.f_code.co_name == "": + module: str = frame.f_globals["__name__"] + break + frame = frame.f_back + + self._name = module + "." + name + self._t = t + + def __lt__(self, other: object) -> bool: + if isinstance(other, AppKey): + return self._name < other._name + return True # Order AppKey above other types. + + def __repr__(self) -> str: + t = self._t + if t is None: + with suppress(AttributeError): + # Set to type arg. + t = get_args(self.__orig_class__)[0] + + if t is None: + t_repr = "<>" + elif isinstance(t, type): + if t.__module__ == "builtins": + t_repr = t.__qualname__ + else: + t_repr = f"{t.__module__}.{t.__qualname__}" + else: + t_repr = repr(t) + return f"" + + +class ChainMapProxy(Mapping[Union[str, AppKey[Any]], Any]): + __slots__ = ("_maps",) + + def __init__(self, maps: Iterable[Mapping[Union[str, AppKey[Any]], Any]]) -> None: + self._maps = tuple(maps) + + def __init_subclass__(cls) -> None: + raise TypeError( + "Inheritance class {} from ChainMapProxy " + "is forbidden".format(cls.__name__) + ) + + @overload # type: ignore[override] + def __getitem__(self, key: AppKey[_T]) -> _T: + ... + + @overload + def __getitem__(self, key: str) -> Any: + ... + + def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: + for mapping in self._maps: + try: + return mapping[key] + except KeyError: + pass + raise KeyError(key) + + @overload # type: ignore[override] + def get(self, key: AppKey[_T], default: _S) -> Union[_T, _S]: + ... + + @overload + def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: + ... + + @overload + def get(self, key: str, default: Any = ...) -> Any: + ... 
+ + def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: + try: + return self[key] + except KeyError: + return default + + def __len__(self) -> int: + # reuses stored hash values if possible + return len(set().union(*self._maps)) + + def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]: + d: Dict[Union[str, AppKey[Any]], Any] = {} + for mapping in reversed(self._maps): + # reuses stored hash values if possible + d.update(mapping) + return iter(d) + + def __contains__(self, key: object) -> bool: + return any(key in m for m in self._maps) + + def __bool__(self) -> bool: + return any(self._maps) + + def __repr__(self) -> str: + content = ", ".join(map(repr, self._maps)) + return f"ChainMapProxy({content})" + + +# https://tools.ietf.org/html/rfc7232#section-2.3 +_ETAGC = r"[!\x23-\x7E\x80-\xff]+" +_ETAGC_RE = re.compile(_ETAGC) +_QUOTED_ETAG = rf'(W/)?"({_ETAGC})"' +QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG) +LIST_QUOTED_ETAG_RE = re.compile(rf"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)") + +ETAG_ANY = "*" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class ETag: + value: str + is_weak: bool = False + + +def validate_etag_value(value: str) -> None: + if value != ETAG_ANY and not _ETAGC_RE.fullmatch(value): + raise ValueError( + f"Value {value!r} is not a valid etag. Maybe it contains '\"'?" + ) + + +def parse_http_date(date_str: Optional[str]) -> Optional[datetime.datetime]: + """Process a date string, return a datetime object""" + if date_str is not None: + timetuple = parsedate(date_str) + if timetuple is not None: + with suppress(ValueError): + return datetime.datetime(*timetuple[:6], tzinfo=datetime.timezone.utc) + return None + + +def must_be_empty_body(method: str, code: int) -> bool: + """Check if a request must return an empty body.""" + return ( + status_code_must_be_empty_body(code) + or method_must_be_empty_body(method) + or (200 <= code < 300 and method.upper() == hdrs.METH_CONNECT) + ) + + +def method_must_be_empty_body(method: str) -> bool: + """Check if a method must return an empty body.""" + # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1 + # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.2 + return method.upper() == hdrs.METH_HEAD + + +def status_code_must_be_empty_body(code: int) -> bool: + """Check if a status code must return an empty body.""" + # https://datatracker.ietf.org/doc/html/rfc9112#section-6.3-2.1 + return code in {204, 304} or 100 <= code < 200 + + +def should_remove_content_length(method: str, code: int) -> bool: + """Check if a Content-Length header should be removed. + + This should always be a subset of must_be_empty_body + """ + # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-8 + # https://www.rfc-editor.org/rfc/rfc9110.html#section-15.4.5-4 + return ( + code in {204, 304} + or 100 <= code < 200 + or (200 <= code < 300 and method.upper() == hdrs.METH_CONNECT) + ) diff --git a/llm/Lib/site-packages/aiohttp/http.py b/llm/Lib/site-packages/aiohttp/http.py new file mode 100644 index 0000000000000000000000000000000000000000..ccc95f568c4ae4f56ffc965f36aa625a98f7ef9c --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/http.py @@ -0,0 +1,72 @@ +import sys +from http import HTTPStatus +from typing import Mapping, Tuple + +from . 
import __version__ +from .http_exceptions import HttpProcessingError as HttpProcessingError +from .http_parser import ( + HeadersParser as HeadersParser, + HttpParser as HttpParser, + HttpRequestParser as HttpRequestParser, + HttpResponseParser as HttpResponseParser, + RawRequestMessage as RawRequestMessage, + RawResponseMessage as RawResponseMessage, +) +from .http_websocket import ( + WS_CLOSED_MESSAGE as WS_CLOSED_MESSAGE, + WS_CLOSING_MESSAGE as WS_CLOSING_MESSAGE, + WS_KEY as WS_KEY, + WebSocketError as WebSocketError, + WebSocketReader as WebSocketReader, + WebSocketWriter as WebSocketWriter, + WSCloseCode as WSCloseCode, + WSMessage as WSMessage, + WSMsgType as WSMsgType, + ws_ext_gen as ws_ext_gen, + ws_ext_parse as ws_ext_parse, +) +from .http_writer import ( + HttpVersion as HttpVersion, + HttpVersion10 as HttpVersion10, + HttpVersion11 as HttpVersion11, + StreamWriter as StreamWriter, +) + +__all__ = ( + "HttpProcessingError", + "RESPONSES", + "SERVER_SOFTWARE", + # .http_writer + "StreamWriter", + "HttpVersion", + "HttpVersion10", + "HttpVersion11", + # .http_parser + "HeadersParser", + "HttpParser", + "HttpRequestParser", + "HttpResponseParser", + "RawRequestMessage", + "RawResponseMessage", + # .http_websocket + "WS_CLOSED_MESSAGE", + "WS_CLOSING_MESSAGE", + "WS_KEY", + "WebSocketReader", + "WebSocketWriter", + "ws_ext_gen", + "ws_ext_parse", + "WSMessage", + "WebSocketError", + "WSMsgType", + "WSCloseCode", +) + + +SERVER_SOFTWARE: str = "Python/{0[0]}.{0[1]} aiohttp/{1}".format( + sys.version_info, __version__ +) + +RESPONSES: Mapping[int, Tuple[str, str]] = { + v: (v.phrase, v.description) for v in HTTPStatus.__members__.values() +} diff --git a/llm/Lib/site-packages/aiohttp/http_exceptions.py b/llm/Lib/site-packages/aiohttp/http_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..e8abca29fc22220e30eba52900a44fee9b220b6e --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/http_exceptions.py @@ -0,0 +1,106 @@ +"""Low-level http related exceptions.""" + + +from textwrap import indent +from typing import Optional, Union + +from .typedefs import _CIMultiDict + +__all__ = ("HttpProcessingError",) + + +class HttpProcessingError(Exception): + """HTTP error. + + Shortcut for raising HTTP errors with custom code, message and headers. + + code: HTTP Error code. + message: (optional) Error message. 
+ headers: (optional) Headers to be sent in response, a list of pairs + """ + + code = 0 + message = "" + headers = None + + def __init__( + self, + *, + code: Optional[int] = None, + message: str = "", + headers: Optional[_CIMultiDict] = None, + ) -> None: + if code is not None: + self.code = code + self.headers = headers + self.message = message + + def __str__(self) -> str: + msg = indent(self.message, " ") + return f"{self.code}, message:\n{msg}" + + def __repr__(self) -> str: + return f"<{self.__class__.__name__}: {self.code}, message={self.message!r}>" + + +class BadHttpMessage(HttpProcessingError): + + code = 400 + message = "Bad Request" + + def __init__(self, message: str, *, headers: Optional[_CIMultiDict] = None) -> None: + super().__init__(message=message, headers=headers) + self.args = (message,) + + +class HttpBadRequest(BadHttpMessage): + + code = 400 + message = "Bad Request" + + +class PayloadEncodingError(BadHttpMessage): + """Base class for payload errors""" + + +class ContentEncodingError(PayloadEncodingError): + """Content encoding error.""" + + +class TransferEncodingError(PayloadEncodingError): + """transfer encoding error.""" + + +class ContentLengthError(PayloadEncodingError): + """Not enough data for satisfy content length header.""" + + +class LineTooLong(BadHttpMessage): + def __init__( + self, line: str, limit: str = "Unknown", actual_size: str = "Unknown" + ) -> None: + super().__init__( + f"Got more than {limit} bytes ({actual_size}) when reading {line}." + ) + self.args = (line, limit, actual_size) + + +class InvalidHeader(BadHttpMessage): + def __init__(self, hdr: Union[bytes, str]) -> None: + hdr_s = hdr.decode(errors="backslashreplace") if isinstance(hdr, bytes) else hdr + super().__init__(f"Invalid HTTP header: {hdr!r}") + self.hdr = hdr_s + self.args = (hdr,) + + +class BadStatusLine(BadHttpMessage): + def __init__(self, line: str = "", error: Optional[str] = None) -> None: + if not isinstance(line, str): + line = repr(line) + super().__init__(error or f"Bad status line {line!r}") + self.args = (line,) + self.line = line + + +class InvalidURLError(BadHttpMessage): + pass diff --git a/llm/Lib/site-packages/aiohttp/http_parser.py b/llm/Lib/site-packages/aiohttp/http_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..7ea7fc3e5447936fb88a202e709cb3c0640ed300 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/http_parser.py @@ -0,0 +1,1041 @@ +import abc +import asyncio +import re +import string +from contextlib import suppress +from enum import IntEnum +from typing import ( + Any, + ClassVar, + Final, + Generic, + List, + Literal, + NamedTuple, + Optional, + Pattern, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from multidict import CIMultiDict, CIMultiDictProxy, istr +from yarl import URL + +from . 
import hdrs +from .base_protocol import BaseProtocol +from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor +from .helpers import ( + _EXC_SENTINEL, + DEBUG, + NO_EXTENSIONS, + BaseTimerContext, + method_must_be_empty_body, + set_exception, + status_code_must_be_empty_body, +) +from .http_exceptions import ( + BadHttpMessage, + BadStatusLine, + ContentEncodingError, + ContentLengthError, + InvalidHeader, + InvalidURLError, + LineTooLong, + TransferEncodingError, +) +from .http_writer import HttpVersion, HttpVersion10 +from .log import internal_logger +from .streams import EMPTY_PAYLOAD, StreamReader +from .typedefs import RawHeaders + +__all__ = ( + "HeadersParser", + "HttpParser", + "HttpRequestParser", + "HttpResponseParser", + "RawRequestMessage", + "RawResponseMessage", +) + +_SEP = Literal[b"\r\n", b"\n"] + +ASCIISET: Final[Set[str]] = set(string.printable) + +# See https://www.rfc-editor.org/rfc/rfc9110.html#name-overview +# and https://www.rfc-editor.org/rfc/rfc9110.html#name-tokens +# +# method = token +# tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / +# "^" / "_" / "`" / "|" / "~" / DIGIT / ALPHA +# token = 1*tchar +_TCHAR_SPECIALS: Final[str] = re.escape("!#$%&'*+-.^_`|~") +TOKENRE: Final[Pattern[str]] = re.compile(f"[0-9A-Za-z{_TCHAR_SPECIALS}]+") +VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d)\.(\d)", re.ASCII) +DIGITS: Final[Pattern[str]] = re.compile(r"\d+", re.ASCII) +HEXDIGITS: Final[Pattern[bytes]] = re.compile(rb"[0-9a-fA-F]+") + + +class RawRequestMessage(NamedTuple): + method: str + path: str + version: HttpVersion + headers: "CIMultiDictProxy[str]" + raw_headers: RawHeaders + should_close: bool + compression: Optional[str] + upgrade: bool + chunked: bool + url: URL + + +class RawResponseMessage(NamedTuple): + version: HttpVersion + code: int + reason: str + headers: CIMultiDictProxy[str] + raw_headers: RawHeaders + should_close: bool + compression: Optional[str] + upgrade: bool + chunked: bool + + +_MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage) + + +class ParseState(IntEnum): + + PARSE_NONE = 0 + PARSE_LENGTH = 1 + PARSE_CHUNKED = 2 + PARSE_UNTIL_EOF = 3 + + +class ChunkState(IntEnum): + PARSE_CHUNKED_SIZE = 0 + PARSE_CHUNKED_CHUNK = 1 + PARSE_CHUNKED_CHUNK_EOF = 2 + PARSE_MAYBE_TRAILERS = 3 + PARSE_TRAILERS = 4 + + +class HeadersParser: + def __init__( + self, + max_line_size: int = 8190, + max_headers: int = 32768, + max_field_size: int = 8190, + lax: bool = False, + ) -> None: + self.max_line_size = max_line_size + self.max_headers = max_headers + self.max_field_size = max_field_size + self._lax = lax + + def parse_headers( + self, lines: List[bytes] + ) -> Tuple["CIMultiDictProxy[str]", RawHeaders]: + headers: CIMultiDict[str] = CIMultiDict() + # note: "raw" does not mean inclusion of OWS before/after the field value + raw_headers = [] + + lines_idx = 1 + line = lines[1] + line_count = len(lines) + + while line: + # Parse initial header name : value pair. 
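# --- editor's aside (illustrative only; not part of this diff or upstream aiohttp) ---
# Worked example for HeadersParser.parse_headers() as defined in this file: lines[0]
# is the start line (skipped), header lines follow, and an empty line ends parsing.
# The request bytes below are placeholders.
#
#     >>> from aiohttp.http_parser import HeadersParser
#     >>> headers, raw = HeadersParser().parse_headers(
#     ...     [b"GET / HTTP/1.1", b"Host: example.com", b"Accept: */*", b""])
#     >>> headers["Host"], raw[0]
#     ('example.com', (b'Host', b'example.com'))
# -------------------------------------------------------------------------------------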
+ try: + bname, bvalue = line.split(b":", 1) + except ValueError: + raise InvalidHeader(line) from None + + if len(bname) == 0: + raise InvalidHeader(bname) + + # https://www.rfc-editor.org/rfc/rfc9112.html#section-5.1-2 + if {bname[0], bname[-1]} & {32, 9}: # {" ", "\t"} + raise InvalidHeader(line) + + bvalue = bvalue.lstrip(b" \t") + if len(bname) > self.max_field_size: + raise LineTooLong( + "request header name {}".format( + bname.decode("utf8", "backslashreplace") + ), + str(self.max_field_size), + str(len(bname)), + ) + name = bname.decode("utf-8", "surrogateescape") + if not TOKENRE.fullmatch(name): + raise InvalidHeader(bname) + + header_length = len(bvalue) + + # next line + lines_idx += 1 + line = lines[lines_idx] + + # consume continuation lines + continuation = self._lax and line and line[0] in (32, 9) # (' ', '\t') + + # Deprecated: https://www.rfc-editor.org/rfc/rfc9112.html#name-obsolete-line-folding + if continuation: + bvalue_lst = [bvalue] + while continuation: + header_length += len(line) + if header_length > self.max_field_size: + raise LineTooLong( + "request header field {}".format( + bname.decode("utf8", "backslashreplace") + ), + str(self.max_field_size), + str(header_length), + ) + bvalue_lst.append(line) + + # next line + lines_idx += 1 + if lines_idx < line_count: + line = lines[lines_idx] + if line: + continuation = line[0] in (32, 9) # (' ', '\t') + else: + line = b"" + break + bvalue = b"".join(bvalue_lst) + else: + if header_length > self.max_field_size: + raise LineTooLong( + "request header field {}".format( + bname.decode("utf8", "backslashreplace") + ), + str(self.max_field_size), + str(header_length), + ) + + bvalue = bvalue.strip(b" \t") + value = bvalue.decode("utf-8", "surrogateescape") + + # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-5 + if "\n" in value or "\r" in value or "\x00" in value: + raise InvalidHeader(bvalue) + + headers.add(name, value) + raw_headers.append((bname, bvalue)) + + return (CIMultiDictProxy(headers), tuple(raw_headers)) + + +def _is_supported_upgrade(headers: CIMultiDictProxy[str]) -> bool: + """Check if the upgrade header is supported.""" + return headers.get(hdrs.UPGRADE, "").lower() in {"tcp", "websocket"} + + +class HttpParser(abc.ABC, Generic[_MsgT]): + lax: ClassVar[bool] = False + + def __init__( + self, + protocol: Optional[BaseProtocol] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + limit: int = 2**16, + max_line_size: int = 8190, + max_headers: int = 32768, + max_field_size: int = 8190, + timer: Optional[BaseTimerContext] = None, + code: Optional[int] = None, + method: Optional[str] = None, + readall: bool = False, + payload_exception: Optional[Type[BaseException]] = None, + response_with_body: bool = True, + read_until_eof: bool = False, + auto_decompress: bool = True, + ) -> None: + self.protocol = protocol + self.loop = loop + self.max_line_size = max_line_size + self.max_headers = max_headers + self.max_field_size = max_field_size + self.timer = timer + self.code = code + self.method = method + self.readall = readall + self.payload_exception = payload_exception + self.response_with_body = response_with_body + self.read_until_eof = read_until_eof + + self._lines: List[bytes] = [] + self._tail = b"" + self._upgraded = False + self._payload = None + self._payload_parser: Optional[HttpPayloadParser] = None + self._auto_decompress = auto_decompress + self._limit = limit + self._headers_parser = HeadersParser( + max_line_size, max_headers, max_field_size, self.lax + ) + + @abc.abstractmethod 
+ def parse_message(self, lines: List[bytes]) -> _MsgT: + pass + + def feed_eof(self) -> Optional[_MsgT]: + if self._payload_parser is not None: + self._payload_parser.feed_eof() + self._payload_parser = None + else: + # try to extract partial message + if self._tail: + self._lines.append(self._tail) + + if self._lines: + if self._lines[-1] != "\r\n": + self._lines.append(b"") + with suppress(Exception): + return self.parse_message(self._lines) + return None + + def feed_data( + self, + data: bytes, + SEP: _SEP = b"\r\n", + EMPTY: bytes = b"", + CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, + METH_CONNECT: str = hdrs.METH_CONNECT, + SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1, + ) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]: + + messages = [] + + if self._tail: + data, self._tail = self._tail + data, b"" + + data_len = len(data) + start_pos = 0 + loop = self.loop + + while start_pos < data_len: + + # read HTTP message (request/response line + headers), \r\n\r\n + # and split by lines + if self._payload_parser is None and not self._upgraded: + pos = data.find(SEP, start_pos) + # consume \r\n + if pos == start_pos and not self._lines: + start_pos = pos + len(SEP) + continue + + if pos >= start_pos: + # line found + line = data[start_pos:pos] + if SEP == b"\n": # For lax response parsing + line = line.rstrip(b"\r") + self._lines.append(line) + start_pos = pos + len(SEP) + + # \r\n\r\n found + if self._lines[-1] == EMPTY: + try: + msg: _MsgT = self.parse_message(self._lines) + finally: + self._lines.clear() + + def get_content_length() -> Optional[int]: + # payload length + length_hdr = msg.headers.get(CONTENT_LENGTH) + if length_hdr is None: + return None + + # Shouldn't allow +/- or other number formats. + # https://www.rfc-editor.org/rfc/rfc9110#section-8.6-2 + # msg.headers is already stripped of leading/trailing wsp + if not DIGITS.fullmatch(length_hdr): + raise InvalidHeader(CONTENT_LENGTH) + + return int(length_hdr) + + length = get_content_length() + # do not support old websocket spec + if SEC_WEBSOCKET_KEY1 in msg.headers: + raise InvalidHeader(SEC_WEBSOCKET_KEY1) + + self._upgraded = msg.upgrade and _is_supported_upgrade( + msg.headers + ) + + method = getattr(msg, "method", self.method) + # code is only present on responses + code = getattr(msg, "code", 0) + + assert self.protocol is not None + # calculate payload + empty_body = status_code_must_be_empty_body(code) or bool( + method and method_must_be_empty_body(method) + ) + if not empty_body and ( + ((length is not None and length > 0) or msg.chunked) + and not self._upgraded + ): + payload = StreamReader( + self.protocol, + timer=self.timer, + loop=loop, + limit=self._limit, + ) + payload_parser = HttpPayloadParser( + payload, + length=length, + chunked=msg.chunked, + method=method, + compression=msg.compression, + code=self.code, + readall=self.readall, + response_with_body=self.response_with_body, + auto_decompress=self._auto_decompress, + lax=self.lax, + ) + if not payload_parser.done: + self._payload_parser = payload_parser + elif method == METH_CONNECT: + assert isinstance(msg, RawRequestMessage) + payload = StreamReader( + self.protocol, + timer=self.timer, + loop=loop, + limit=self._limit, + ) + self._upgraded = True + self._payload_parser = HttpPayloadParser( + payload, + method=msg.method, + compression=msg.compression, + readall=True, + auto_decompress=self._auto_decompress, + lax=self.lax, + ) + elif not empty_body and length is None and self.read_until_eof: + payload = StreamReader( + self.protocol, 
+ timer=self.timer, + loop=loop, + limit=self._limit, + ) + payload_parser = HttpPayloadParser( + payload, + length=length, + chunked=msg.chunked, + method=method, + compression=msg.compression, + code=self.code, + readall=True, + response_with_body=self.response_with_body, + auto_decompress=self._auto_decompress, + lax=self.lax, + ) + if not payload_parser.done: + self._payload_parser = payload_parser + else: + payload = EMPTY_PAYLOAD + + messages.append((msg, payload)) + else: + self._tail = data[start_pos:] + data = EMPTY + break + + # no parser, just store + elif self._payload_parser is None and self._upgraded: + assert not self._lines + break + + # feed payload + elif data and start_pos < data_len: + assert not self._lines + assert self._payload_parser is not None + try: + eof, data = self._payload_parser.feed_data(data[start_pos:], SEP) + except BaseException as underlying_exc: + reraised_exc = underlying_exc + if self.payload_exception is not None: + reraised_exc = self.payload_exception(str(underlying_exc)) + + set_exception( + self._payload_parser.payload, + reraised_exc, + underlying_exc, + ) + + eof = True + data = b"" + + if eof: + start_pos = 0 + data_len = len(data) + self._payload_parser = None + continue + else: + break + + if data and start_pos < data_len: + data = data[start_pos:] + else: + data = EMPTY + + return messages, self._upgraded, data + + def parse_headers( + self, lines: List[bytes] + ) -> Tuple[ + "CIMultiDictProxy[str]", RawHeaders, Optional[bool], Optional[str], bool, bool + ]: + """Parses RFC 5322 headers from a stream. + + Line continuations are supported. Returns list of header name + and value pairs. Header name is in upper case. + """ + headers, raw_headers = self._headers_parser.parse_headers(lines) + close_conn = None + encoding = None + upgrade = False + chunked = False + + # https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-6 + # https://www.rfc-editor.org/rfc/rfc9110.html#name-collected-abnf + singletons = ( + hdrs.CONTENT_LENGTH, + hdrs.CONTENT_LOCATION, + hdrs.CONTENT_RANGE, + hdrs.CONTENT_TYPE, + hdrs.ETAG, + hdrs.HOST, + hdrs.MAX_FORWARDS, + hdrs.SERVER, + hdrs.TRANSFER_ENCODING, + hdrs.USER_AGENT, + ) + bad_hdr = next((h for h in singletons if len(headers.getall(h, ())) > 1), None) + if bad_hdr is not None: + raise BadHttpMessage(f"Duplicate '{bad_hdr}' header found.") + + # keep-alive + conn = headers.get(hdrs.CONNECTION) + if conn: + v = conn.lower() + if v == "close": + close_conn = True + elif v == "keep-alive": + close_conn = False + # https://www.rfc-editor.org/rfc/rfc9110.html#name-101-switching-protocols + elif v == "upgrade" and headers.get(hdrs.UPGRADE): + upgrade = True + + # encoding + enc = headers.get(hdrs.CONTENT_ENCODING) + if enc: + enc = enc.lower() + if enc in ("gzip", "deflate", "br"): + encoding = enc + + # chunking + te = headers.get(hdrs.TRANSFER_ENCODING) + if te is not None: + if "chunked" == te.lower(): + chunked = True + else: + raise BadHttpMessage("Request has invalid `Transfer-Encoding`") + + if hdrs.CONTENT_LENGTH in headers: + raise BadHttpMessage( + "Transfer-Encoding can't be present with Content-Length", + ) + + return (headers, raw_headers, close_conn, encoding, upgrade, chunked) + + def set_upgraded(self, val: bool) -> None: + """Set connection upgraded (to websocket) mode. + + :param bool val: new state. + """ + self._upgraded = val + + +class HttpRequestParser(HttpParser[RawRequestMessage]): + """Read request status line. 
+ + Exception .http_exceptions.BadStatusLine + could be raised in case of any errors in status line. + Returns RawRequestMessage. + """ + + def parse_message(self, lines: List[bytes]) -> RawRequestMessage: + # request line + line = lines[0].decode("utf-8", "surrogateescape") + try: + method, path, version = line.split(" ", maxsplit=2) + except ValueError: + raise BadStatusLine(line) from None + + if len(path) > self.max_line_size: + raise LineTooLong( + "Status line is too long", str(self.max_line_size), str(len(path)) + ) + + # method + if not TOKENRE.fullmatch(method): + raise BadStatusLine(method) + + # version + match = VERSRE.fullmatch(version) + if match is None: + raise BadStatusLine(line) + version_o = HttpVersion(int(match.group(1)), int(match.group(2))) + + if method == "CONNECT": + # authority-form, + # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.3 + url = URL.build(authority=path, encoded=True) + elif path.startswith("/"): + # origin-form, + # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.1 + path_part, _hash_separator, url_fragment = path.partition("#") + path_part, _question_mark_separator, qs_part = path_part.partition("?") + + # NOTE: `yarl.URL.build()` is used to mimic what the Cython-based + # NOTE: parser does, otherwise it results into the same + # NOTE: HTTP Request-Line input producing different + # NOTE: `yarl.URL()` objects + url = URL.build( + path=path_part, + query_string=qs_part, + fragment=url_fragment, + encoded=True, + ) + elif path == "*" and method == "OPTIONS": + # asterisk-form, + url = URL(path, encoded=True) + else: + # absolute-form for proxy maybe, + # https://datatracker.ietf.org/doc/html/rfc7230#section-5.3.2 + url = URL(path, encoded=True) + if url.scheme == "": + # not absolute-form + raise InvalidURLError( + path.encode(errors="surrogateescape").decode("latin1") + ) + + # read headers + ( + headers, + raw_headers, + close, + compression, + upgrade, + chunked, + ) = self.parse_headers(lines) + + if close is None: # then the headers weren't set in the request + if version_o <= HttpVersion10: # HTTP 1.0 must asks to not close + close = True + else: # HTTP 1.1 must ask to close. + close = False + + return RawRequestMessage( + method, + path, + version_o, + headers, + raw_headers, + close, + compression, + upgrade, + chunked, + url, + ) + + +class HttpResponseParser(HttpParser[RawResponseMessage]): + """Read response status line and headers. + + BadStatusLine could be raised in case of any errors in status line. + Returns RawResponseMessage. + """ + + # Lax mode should only be enabled on response parser. 
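# --- editor's note (clarification only; not part of this diff or upstream aiohttp) ---
# "lax" is keyed off Python dev/asyncio-debug mode: outside debug mode the response
# parser splits on bare b"\n" (see feed_data just below, with HttpParser.feed_data
# then stripping a trailing b"\r"), so slightly malformed status/header lines from
# servers still parse; in debug mode strict b"\r\n" separators are enforced.
# --------------------------------------------------------------------------------------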
+ lax = not DEBUG + + def feed_data( + self, + data: bytes, + SEP: Optional[_SEP] = None, + *args: Any, + **kwargs: Any, + ) -> Tuple[List[Tuple[RawResponseMessage, StreamReader]], bool, bytes]: + if SEP is None: + SEP = b"\r\n" if DEBUG else b"\n" + return super().feed_data(data, SEP, *args, **kwargs) + + def parse_message(self, lines: List[bytes]) -> RawResponseMessage: + line = lines[0].decode("utf-8", "surrogateescape") + try: + version, status = line.split(maxsplit=1) + except ValueError: + raise BadStatusLine(line) from None + + try: + status, reason = status.split(maxsplit=1) + except ValueError: + status = status.strip() + reason = "" + + if len(reason) > self.max_line_size: + raise LineTooLong( + "Status line is too long", str(self.max_line_size), str(len(reason)) + ) + + # version + match = VERSRE.fullmatch(version) + if match is None: + raise BadStatusLine(line) + version_o = HttpVersion(int(match.group(1)), int(match.group(2))) + + # The status code is a three-digit ASCII number, no padding + if len(status) != 3 or not DIGITS.fullmatch(status): + raise BadStatusLine(line) + status_i = int(status) + + # read headers + ( + headers, + raw_headers, + close, + compression, + upgrade, + chunked, + ) = self.parse_headers(lines) + + if close is None: + if version_o <= HttpVersion10: + close = True + # https://www.rfc-editor.org/rfc/rfc9112.html#name-message-body-length + elif 100 <= status_i < 200 or status_i in {204, 304}: + close = False + elif hdrs.CONTENT_LENGTH in headers or hdrs.TRANSFER_ENCODING in headers: + close = False + else: + # https://www.rfc-editor.org/rfc/rfc9112.html#section-6.3-2.8 + close = True + + return RawResponseMessage( + version_o, + status_i, + reason.strip(), + headers, + raw_headers, + close, + compression, + upgrade, + chunked, + ) + + +class HttpPayloadParser: + def __init__( + self, + payload: StreamReader, + length: Optional[int] = None, + chunked: bool = False, + compression: Optional[str] = None, + code: Optional[int] = None, + method: Optional[str] = None, + readall: bool = False, + response_with_body: bool = True, + auto_decompress: bool = True, + lax: bool = False, + ) -> None: + self._length = 0 + self._type = ParseState.PARSE_NONE + self._chunk = ChunkState.PARSE_CHUNKED_SIZE + self._chunk_size = 0 + self._chunk_tail = b"" + self._auto_decompress = auto_decompress + self._lax = lax + self.done = False + + # payload decompression wrapper + if response_with_body and compression and self._auto_decompress: + real_payload: Union[StreamReader, DeflateBuffer] = DeflateBuffer( + payload, compression + ) + else: + real_payload = payload + + # payload parser + if not response_with_body: + # don't parse payload if it's not expected to be received + self._type = ParseState.PARSE_NONE + real_payload.feed_eof() + self.done = True + + elif chunked: + self._type = ParseState.PARSE_CHUNKED + elif length is not None: + self._type = ParseState.PARSE_LENGTH + self._length = length + if self._length == 0: + real_payload.feed_eof() + self.done = True + else: + if readall and code != 204: + self._type = ParseState.PARSE_UNTIL_EOF + elif method in ("PUT", "POST"): + internal_logger.warning( # pragma: no cover + "Content-Length or Transfer-Encoding header is required" + ) + self._type = ParseState.PARSE_NONE + real_payload.feed_eof() + self.done = True + + self.payload = real_payload + + def feed_eof(self) -> None: + if self._type == ParseState.PARSE_UNTIL_EOF: + self.payload.feed_eof() + elif self._type == ParseState.PARSE_LENGTH: + raise ContentLengthError( + "Not 
enough data for satisfy content length header." + ) + elif self._type == ParseState.PARSE_CHUNKED: + raise TransferEncodingError( + "Not enough data for satisfy transfer length header." + ) + + def feed_data( + self, chunk: bytes, SEP: _SEP = b"\r\n", CHUNK_EXT: bytes = b";" + ) -> Tuple[bool, bytes]: + # Read specified amount of bytes + if self._type == ParseState.PARSE_LENGTH: + required = self._length + chunk_len = len(chunk) + + if required >= chunk_len: + self._length = required - chunk_len + self.payload.feed_data(chunk, chunk_len) + if self._length == 0: + self.payload.feed_eof() + return True, b"" + else: + self._length = 0 + self.payload.feed_data(chunk[:required], required) + self.payload.feed_eof() + return True, chunk[required:] + + # Chunked transfer encoding parser + elif self._type == ParseState.PARSE_CHUNKED: + if self._chunk_tail: + chunk = self._chunk_tail + chunk + self._chunk_tail = b"" + + while chunk: + + # read next chunk size + if self._chunk == ChunkState.PARSE_CHUNKED_SIZE: + pos = chunk.find(SEP) + if pos >= 0: + i = chunk.find(CHUNK_EXT, 0, pos) + if i >= 0: + size_b = chunk[:i] # strip chunk-extensions + else: + size_b = chunk[:pos] + + if self._lax: # Allow whitespace in lax mode. + size_b = size_b.strip() + + if not re.fullmatch(HEXDIGITS, size_b): + exc = TransferEncodingError( + chunk[:pos].decode("ascii", "surrogateescape") + ) + set_exception(self.payload, exc) + raise exc + size = int(bytes(size_b), 16) + + chunk = chunk[pos + len(SEP) :] + if size == 0: # eof marker + self._chunk = ChunkState.PARSE_MAYBE_TRAILERS + if self._lax and chunk.startswith(b"\r"): + chunk = chunk[1:] + else: + self._chunk = ChunkState.PARSE_CHUNKED_CHUNK + self._chunk_size = size + self.payload.begin_http_chunk_receiving() + else: + self._chunk_tail = chunk + return False, b"" + + # read chunk and feed buffer + if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK: + required = self._chunk_size + chunk_len = len(chunk) + + if required > chunk_len: + self._chunk_size = required - chunk_len + self.payload.feed_data(chunk, chunk_len) + return False, b"" + else: + self._chunk_size = 0 + self.payload.feed_data(chunk[:required], required) + chunk = chunk[required:] + if self._lax and chunk.startswith(b"\r"): + chunk = chunk[1:] + self._chunk = ChunkState.PARSE_CHUNKED_CHUNK_EOF + self.payload.end_http_chunk_receiving() + + # toss the CRLF at the end of the chunk + if self._chunk == ChunkState.PARSE_CHUNKED_CHUNK_EOF: + if chunk[: len(SEP)] == SEP: + chunk = chunk[len(SEP) :] + self._chunk = ChunkState.PARSE_CHUNKED_SIZE + else: + self._chunk_tail = chunk + return False, b"" + + # if stream does not contain trailer, after 0\r\n + # we should get another \r\n otherwise + # trailers needs to be skipped until \r\n\r\n + if self._chunk == ChunkState.PARSE_MAYBE_TRAILERS: + head = chunk[: len(SEP)] + if head == SEP: + # end of stream + self.payload.feed_eof() + return True, chunk[len(SEP) :] + # Both CR and LF, or only LF may not be received yet. It is + # expected that CRLF or LF will be shown at the very first + # byte next time, otherwise trailers should come. The last + # CRLF which marks the end of response might not be + # contained in the same TCP segment which delivered the + # size indicator. 
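# --- Editorial note: illustrative sketch, not part of the parser above. ---
# The state machine above walks a body shaped like this (RFC 9112, chunked
# transfer coding); the comment about trailers refers to the bytes after the
# "0" last-chunk, and the final CRLF may arrive in a later TCP segment than
# the "0\r\n" itself:
_chunked_wire = (
    b"4\r\n"             # chunk-size in hex       -> PARSE_CHUNKED_SIZE
    b"Wiki\r\n"          # chunk-data plus CRLF    -> PARSE_CHUNKED_CHUNK / _EOF
    b"0\r\n"             # last-chunk              -> PARSE_MAYBE_TRAILERS
    b"X-Trailer: 1\r\n"  # optional trailer field  -> PARSE_TRAILERS
    b"\r\n"              # end of the message
)
# A peer may split the stream anywhere, e.g. right before the terminating CRLF:
_segment_1, _segment_2 = _chunked_wire[:-2], _chunked_wire[-2:]
assert _segment_1 + _segment_2 == _chunked_wire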
+ if not head: + return False, b"" + if head == SEP[:1]: + self._chunk_tail = head + return False, b"" + self._chunk = ChunkState.PARSE_TRAILERS + + # read and discard trailer up to the CRLF terminator + if self._chunk == ChunkState.PARSE_TRAILERS: + pos = chunk.find(SEP) + if pos >= 0: + chunk = chunk[pos + len(SEP) :] + self._chunk = ChunkState.PARSE_MAYBE_TRAILERS + else: + self._chunk_tail = chunk + return False, b"" + + # Read all bytes until eof + elif self._type == ParseState.PARSE_UNTIL_EOF: + self.payload.feed_data(chunk, len(chunk)) + + return False, b"" + + +class DeflateBuffer: + """DeflateStream decompress stream and feed data into specified stream.""" + + decompressor: Any + + def __init__(self, out: StreamReader, encoding: Optional[str]) -> None: + self.out = out + self.size = 0 + self.encoding = encoding + self._started_decoding = False + + self.decompressor: Union[BrotliDecompressor, ZLibDecompressor] + if encoding == "br": + if not HAS_BROTLI: # pragma: no cover + raise ContentEncodingError( + "Can not decode content-encoding: brotli (br). " + "Please install `Brotli`" + ) + self.decompressor = BrotliDecompressor() + else: + self.decompressor = ZLibDecompressor(encoding=encoding) + + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: + set_exception(self.out, exc, exc_cause) + + def feed_data(self, chunk: bytes, size: int) -> None: + if not size: + return + + self.size += size + + # RFC1950 + # bits 0..3 = CM = 0b1000 = 8 = "deflate" + # bits 4..7 = CINFO = 1..7 = windows size. + if ( + not self._started_decoding + and self.encoding == "deflate" + and chunk[0] & 0xF != 8 + ): + # Change the decoder to decompress incorrectly compressed data + # Actually we should issue a warning about non-RFC-compliant data. 
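# --- Editorial note: illustrative sketch using the standard zlib module, not
# aiohttp's ZLibDecompressor. ---
# The first-byte check above sniffs for the RFC 1950 zlib header: its low
# nibble is CM == 8 ("deflate").  Some servers send raw DEFLATE for
# "Content-Encoding: deflate", which is why the decoder is swapped for one
# that suppresses the header:
import zlib as _zlib

_payload = b"hello world" * 10
_zlib_wrapped = _zlib.compress(_payload)              # RFC 1950 framing
_co = _zlib.compressobj(wbits=-15)
_raw_deflate = _co.compress(_payload) + _co.flush()   # no zlib header

assert _zlib_wrapped[0] & 0xF == 8
assert _zlib.decompress(_zlib_wrapped) == _payload
assert _zlib.decompress(_raw_deflate, wbits=-15) == _payload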
+ self.decompressor = ZLibDecompressor( + encoding=self.encoding, suppress_deflate_header=True + ) + + try: + chunk = self.decompressor.decompress_sync(chunk) + except Exception: + raise ContentEncodingError( + "Can not decode content-encoding: %s" % self.encoding + ) + + self._started_decoding = True + + if chunk: + self.out.feed_data(chunk, len(chunk)) + + def feed_eof(self) -> None: + chunk = self.decompressor.flush() + + if chunk or self.size > 0: + self.out.feed_data(chunk, len(chunk)) + if self.encoding == "deflate" and not self.decompressor.eof: + raise ContentEncodingError("deflate") + + self.out.feed_eof() + + def begin_http_chunk_receiving(self) -> None: + self.out.begin_http_chunk_receiving() + + def end_http_chunk_receiving(self) -> None: + self.out.end_http_chunk_receiving() + + +HttpRequestParserPy = HttpRequestParser +HttpResponseParserPy = HttpResponseParser +RawRequestMessagePy = RawRequestMessage +RawResponseMessagePy = RawResponseMessage + +try: + if not NO_EXTENSIONS: + from ._http_parser import ( # type: ignore[import-not-found,no-redef] + HttpRequestParser, + HttpResponseParser, + RawRequestMessage, + RawResponseMessage, + ) + + HttpRequestParserC = HttpRequestParser + HttpResponseParserC = HttpResponseParser + RawRequestMessageC = RawRequestMessage + RawResponseMessageC = RawResponseMessage +except ImportError: # pragma: no cover + pass diff --git a/llm/Lib/site-packages/aiohttp/http_websocket.py b/llm/Lib/site-packages/aiohttp/http_websocket.py new file mode 100644 index 0000000000000000000000000000000000000000..b4524c0f1fef8ff62e43051e85e6060237b57030 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/http_websocket.py @@ -0,0 +1,740 @@ +"""WebSocket protocol versions 13 and 8.""" + +import asyncio +import functools +import json +import random +import re +import sys +import zlib +from enum import IntEnum +from struct import Struct +from typing import ( + Any, + Callable, + Final, + List, + NamedTuple, + Optional, + Pattern, + Set, + Tuple, + Union, + cast, +) + +from .base_protocol import BaseProtocol +from .compression_utils import ZLibCompressor, ZLibDecompressor +from .helpers import NO_EXTENSIONS, set_exception +from .streams import DataQueue + +__all__ = ( + "WS_CLOSED_MESSAGE", + "WS_CLOSING_MESSAGE", + "WS_KEY", + "WebSocketReader", + "WebSocketWriter", + "WSMessage", + "WebSocketError", + "WSMsgType", + "WSCloseCode", +) + + +class WSCloseCode(IntEnum): + OK = 1000 + GOING_AWAY = 1001 + PROTOCOL_ERROR = 1002 + UNSUPPORTED_DATA = 1003 + ABNORMAL_CLOSURE = 1006 + INVALID_TEXT = 1007 + POLICY_VIOLATION = 1008 + MESSAGE_TOO_BIG = 1009 + MANDATORY_EXTENSION = 1010 + INTERNAL_ERROR = 1011 + SERVICE_RESTART = 1012 + TRY_AGAIN_LATER = 1013 + BAD_GATEWAY = 1014 + + +ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode} + +# For websockets, keeping latency low is extremely important as implementations +# generally expect to be able to send and receive messages quickly. We use a +# larger chunk size than the default to reduce the number of executor calls +# since the executor is a significant source of latency and overhead when +# the chunks are small. A size of 5KiB was chosen because it is also the +# same value python-zlib-ng choose to use as the threshold to release the GIL. 
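# --- Editorial note: illustrative sketch, not aiohttp's ZLibCompressor. ---
# The 5 KiB threshold described above is the usual latency trade-off: compress
# small payloads inline on the event loop, and push only larger ones to the
# default executor so the loop is not blocked.  Names here are hypothetical:
import asyncio as _asyncio
import zlib as _zlib_ws

_SYNC_LIMIT = 5 * 1024


async def _compress_maybe_offload(data: bytes) -> bytes:
    if len(data) <= _SYNC_LIMIT:
        return _zlib_ws.compress(data)     # cheap enough to run inline
    loop = _asyncio.get_running_loop()
    return await loop.run_in_executor(None, _zlib_ws.compress, data)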
+ +WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 5 * 1024 + + +class WSMsgType(IntEnum): + # websocket spec types + CONTINUATION = 0x0 + TEXT = 0x1 + BINARY = 0x2 + PING = 0x9 + PONG = 0xA + CLOSE = 0x8 + + # aiohttp specific types + CLOSING = 0x100 + CLOSED = 0x101 + ERROR = 0x102 + + text = TEXT + binary = BINARY + ping = PING + pong = PONG + close = CLOSE + closing = CLOSING + closed = CLOSED + error = ERROR + + +WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + + +UNPACK_LEN2 = Struct("!H").unpack_from +UNPACK_LEN3 = Struct("!Q").unpack_from +UNPACK_CLOSE_CODE = Struct("!H").unpack +PACK_LEN1 = Struct("!BB").pack +PACK_LEN2 = Struct("!BBH").pack +PACK_LEN3 = Struct("!BBQ").pack +PACK_CLOSE_CODE = Struct("!H").pack +MSG_SIZE: Final[int] = 2**14 +DEFAULT_LIMIT: Final[int] = 2**16 + + +class WSMessage(NamedTuple): + type: WSMsgType + # To type correctly, this would need some kind of tagged union for each type. + data: Any + extra: Optional[str] + + def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any: + """Return parsed JSON data. + + .. versionadded:: 0.22 + """ + return loads(self.data) + + +WS_CLOSED_MESSAGE = WSMessage(WSMsgType.CLOSED, None, None) +WS_CLOSING_MESSAGE = WSMessage(WSMsgType.CLOSING, None, None) + + +class WebSocketError(Exception): + """WebSocket protocol parser error.""" + + def __init__(self, code: int, message: str) -> None: + self.code = code + super().__init__(code, message) + + def __str__(self) -> str: + return cast(str, self.args[1]) + + +class WSHandshakeError(Exception): + """WebSocket protocol handshake error.""" + + +native_byteorder: Final[str] = sys.byteorder + + +# Used by _websocket_mask_python +@functools.lru_cache +def _xor_table() -> List[bytes]: + return [bytes(a ^ b for a in range(256)) for b in range(256)] + + +def _websocket_mask_python(mask: bytes, data: bytearray) -> None: + """Websocket masking function. + + `mask` is a `bytes` object of length 4; `data` is a `bytearray` + object of any length. The contents of `data` are masked with `mask`, + as specified in section 5.3 of RFC 6455. + + Note that this function mutates the `data` argument. + + This pure-python implementation may be replaced by an optimized + version when available. 
+ + """ + assert isinstance(data, bytearray), data + assert len(mask) == 4, mask + + if data: + _XOR_TABLE = _xor_table() + a, b, c, d = (_XOR_TABLE[n] for n in mask) + data[::4] = data[::4].translate(a) + data[1::4] = data[1::4].translate(b) + data[2::4] = data[2::4].translate(c) + data[3::4] = data[3::4].translate(d) + + +if NO_EXTENSIONS: # pragma: no cover + _websocket_mask = _websocket_mask_python +else: + try: + from ._websocket import _websocket_mask_cython # type: ignore[import-not-found] + + _websocket_mask = _websocket_mask_cython + except ImportError: # pragma: no cover + _websocket_mask = _websocket_mask_python + +_WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF]) + + +_WS_EXT_RE: Final[Pattern[str]] = re.compile( + r"^(?:;\s*(?:" + r"(server_no_context_takeover)|" + r"(client_no_context_takeover)|" + r"(server_max_window_bits(?:=(\d+))?)|" + r"(client_max_window_bits(?:=(\d+))?)))*$" +) + +_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?") + + +def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]: + if not extstr: + return 0, False + + compress = 0 + notakeover = False + for ext in _WS_EXT_RE_SPLIT.finditer(extstr): + defext = ext.group(1) + # Return compress = 15 when get `permessage-deflate` + if not defext: + compress = 15 + break + match = _WS_EXT_RE.match(defext) + if match: + compress = 15 + if isserver: + # Server never fail to detect compress handshake. + # Server does not need to send max wbit to client + if match.group(4): + compress = int(match.group(4)) + # Group3 must match if group4 matches + # Compress wbit 8 does not support in zlib + # If compress level not support, + # CONTINUE to next extension + if compress > 15 or compress < 9: + compress = 0 + continue + if match.group(1): + notakeover = True + # Ignore regex group 5 & 6 for client_max_window_bits + break + else: + if match.group(6): + compress = int(match.group(6)) + # Group5 must match if group6 matches + # Compress wbit 8 does not support in zlib + # If compress level not support, + # FAIL the parse progress + if compress > 15 or compress < 9: + raise WSHandshakeError("Invalid window size") + if match.group(2): + notakeover = True + # Ignore regex group 5 & 6 for client_max_window_bits + break + # Return Fail if client side and not match + elif not isserver: + raise WSHandshakeError("Extension for deflate not supported" + ext.group(1)) + + return compress, notakeover + + +def ws_ext_gen( + compress: int = 15, isserver: bool = False, server_notakeover: bool = False +) -> str: + # client_notakeover=False not used for server + # compress wbit 8 does not support in zlib + if compress < 9 or compress > 15: + raise ValueError( + "Compress wbits must between 9 and 15, " "zlib does not support wbits=8" + ) + enabledext = ["permessage-deflate"] + if not isserver: + enabledext.append("client_max_window_bits") + + if compress < 15: + enabledext.append("server_max_window_bits=" + str(compress)) + if server_notakeover: + enabledext.append("server_no_context_takeover") + # if client_notakeover: + # enabledext.append('client_no_context_takeover') + return "; ".join(enabledext) + + +class WSParserState(IntEnum): + READ_HEADER = 1 + READ_PAYLOAD_LENGTH = 2 + READ_PAYLOAD_MASK = 3 + READ_PAYLOAD = 4 + + +class WebSocketReader: + def __init__( + self, queue: DataQueue[WSMessage], max_msg_size: int, compress: bool = True + ) -> None: + self.queue = queue + self._max_msg_size = max_msg_size + + self._exc: Optional[BaseException] = None + 
self._partial = bytearray() + self._state = WSParserState.READ_HEADER + + self._opcode: Optional[int] = None + self._frame_fin = False + self._frame_opcode: Optional[int] = None + self._frame_payload = bytearray() + + self._tail = b"" + self._has_mask = False + self._frame_mask: Optional[bytes] = None + self._payload_length = 0 + self._payload_length_flag = 0 + self._compressed: Optional[bool] = None + self._decompressobj: Optional[ZLibDecompressor] = None + self._compress = compress + + def feed_eof(self) -> None: + self.queue.feed_eof() + + def feed_data(self, data: bytes) -> Tuple[bool, bytes]: + if self._exc: + return True, data + + try: + return self._feed_data(data) + except Exception as exc: + self._exc = exc + set_exception(self.queue, exc) + return True, b"" + + def _feed_data(self, data: bytes) -> Tuple[bool, bytes]: + for fin, opcode, payload, compressed in self.parse_frame(data): + if compressed and not self._decompressobj: + self._decompressobj = ZLibDecompressor(suppress_deflate_header=True) + if opcode == WSMsgType.CLOSE: + if len(payload) >= 2: + close_code = UNPACK_CLOSE_CODE(payload[:2])[0] + if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close code: {close_code}", + ) + try: + close_message = payload[2:].decode("utf-8") + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + msg = WSMessage(WSMsgType.CLOSE, close_code, close_message) + elif payload: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + f"Invalid close frame: {fin} {opcode} {payload!r}", + ) + else: + msg = WSMessage(WSMsgType.CLOSE, 0, "") + + self.queue.feed_data(msg, 0) + + elif opcode == WSMsgType.PING: + self.queue.feed_data( + WSMessage(WSMsgType.PING, payload, ""), len(payload) + ) + + elif opcode == WSMsgType.PONG: + self.queue.feed_data( + WSMessage(WSMsgType.PONG, payload, ""), len(payload) + ) + + elif ( + opcode not in (WSMsgType.TEXT, WSMsgType.BINARY) + and self._opcode is None + ): + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}" + ) + else: + # load text/binary + if not fin: + # got partial frame payload + if opcode != WSMsgType.CONTINUATION: + self._opcode = opcode + self._partial.extend(payload) + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size + ), + ) + else: + # previous frame was non finished + # we should get continuation opcode + if self._partial: + if opcode != WSMsgType.CONTINUATION: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "The opcode in non-fin frame is expected " + "to be zero, got {!r}".format(opcode), + ) + + if opcode == WSMsgType.CONTINUATION: + assert self._opcode is not None + opcode = self._opcode + self._opcode = None + + self._partial.extend(payload) + if self._max_msg_size and len(self._partial) >= self._max_msg_size: + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Message size {} exceeds limit {}".format( + len(self._partial), self._max_msg_size + ), + ) + + # Decompress process must to be done after all packets + # received. 
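# --- Editorial note: illustrative sketch using plain zlib, not the reader
# above. ---
# permessage-deflate (RFC 7692) strips the trailing 0x00 0x00 0xFF 0xFF that a
# Z_SYNC_FLUSH produces, so the receiver appends it again (as the code below
# does with _WS_DEFLATE_TRAILING) before inflating the reassembled message:
import zlib as _zlib_pmd

_TRAILING = bytes([0x00, 0x00, 0xFF, 0xFF])


def _deflate_message(data: bytes) -> bytes:
    co = _zlib_pmd.compressobj(wbits=-15)
    body = co.compress(data) + co.flush(_zlib_pmd.Z_SYNC_FLUSH)
    assert body.endswith(_TRAILING)
    return body[:-4]                       # what actually goes on the wire


def _inflate_message(wire: bytes) -> bytes:
    do = _zlib_pmd.decompressobj(wbits=-15)
    return do.decompress(wire + _TRAILING)


assert _inflate_message(_deflate_message(b"hello websocket")) == b"hello websocket"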
+ if compressed: + assert self._decompressobj is not None + self._partial.extend(_WS_DEFLATE_TRAILING) + payload_merged = self._decompressobj.decompress_sync( + self._partial, self._max_msg_size + ) + if self._decompressobj.unconsumed_tail: + left = len(self._decompressobj.unconsumed_tail) + raise WebSocketError( + WSCloseCode.MESSAGE_TOO_BIG, + "Decompressed message size {} exceeds limit {}".format( + self._max_msg_size + left, self._max_msg_size + ), + ) + else: + payload_merged = bytes(self._partial) + + self._partial.clear() + + if opcode == WSMsgType.TEXT: + try: + text = payload_merged.decode("utf-8") + self.queue.feed_data( + WSMessage(WSMsgType.TEXT, text, ""), len(text) + ) + except UnicodeDecodeError as exc: + raise WebSocketError( + WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message" + ) from exc + else: + self.queue.feed_data( + WSMessage(WSMsgType.BINARY, payload_merged, ""), + len(payload_merged), + ) + + return False, b"" + + def parse_frame( + self, buf: bytes + ) -> List[Tuple[bool, Optional[int], bytearray, Optional[bool]]]: + """Return the next frame from the socket.""" + frames = [] + if self._tail: + buf, self._tail = self._tail + buf, b"" + + start_pos = 0 + buf_length = len(buf) + + while True: + # read header + if self._state == WSParserState.READ_HEADER: + if buf_length - start_pos >= 2: + data = buf[start_pos : start_pos + 2] + start_pos += 2 + first_byte, second_byte = data + + fin = (first_byte >> 7) & 1 + rsv1 = (first_byte >> 6) & 1 + rsv2 = (first_byte >> 5) & 1 + rsv3 = (first_byte >> 4) & 1 + opcode = first_byte & 0xF + + # frame-fin = %x0 ; more frames of this message follow + # / %x1 ; final frame of this message + # frame-rsv1 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv2 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # frame-rsv3 = %x0 ; + # 1 bit, MUST be 0 unless negotiated otherwise + # + # Remove rsv1 from this test for deflate development + if rsv2 or rsv3 or (rsv1 and not self._compress): + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) + + if opcode > 0x7 and fin == 0: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received fragmented control frame", + ) + + has_mask = (second_byte >> 7) & 1 + length = second_byte & 0x7F + + # Control frames MUST have a payload + # length of 125 bytes or less + if opcode > 0x7 and length > 125: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Control frame payload cannot be " "larger than 125 bytes", + ) + + # Set compress status if last package is FIN + # OR set compress status if this is first fragment + # Raise error if not first fragment with rsv1 = 0x1 + if self._frame_fin or self._compressed is None: + self._compressed = True if rsv1 else False + elif rsv1: + raise WebSocketError( + WSCloseCode.PROTOCOL_ERROR, + "Received frame with non-zero reserved bits", + ) + + self._frame_fin = bool(fin) + self._frame_opcode = opcode + self._has_mask = bool(has_mask) + self._payload_length_flag = length + self._state = WSParserState.READ_PAYLOAD_LENGTH + else: + break + + # read payload length + if self._state == WSParserState.READ_PAYLOAD_LENGTH: + length = self._payload_length_flag + if length == 126: + if buf_length - start_pos >= 2: + data = buf[start_pos : start_pos + 2] + start_pos += 2 + length = UNPACK_LEN2(data)[0] + self._payload_length = length + self._state = ( + WSParserState.READ_PAYLOAD_MASK + if self._has_mask + else WSParserState.READ_PAYLOAD + ) + else: + break + elif length > 126: + if 
buf_length - start_pos >= 8: + data = buf[start_pos : start_pos + 8] + start_pos += 8 + length = UNPACK_LEN3(data)[0] + self._payload_length = length + self._state = ( + WSParserState.READ_PAYLOAD_MASK + if self._has_mask + else WSParserState.READ_PAYLOAD + ) + else: + break + else: + self._payload_length = length + self._state = ( + WSParserState.READ_PAYLOAD_MASK + if self._has_mask + else WSParserState.READ_PAYLOAD + ) + + # read payload mask + if self._state == WSParserState.READ_PAYLOAD_MASK: + if buf_length - start_pos >= 4: + self._frame_mask = buf[start_pos : start_pos + 4] + start_pos += 4 + self._state = WSParserState.READ_PAYLOAD + else: + break + + if self._state == WSParserState.READ_PAYLOAD: + length = self._payload_length + payload = self._frame_payload + + chunk_len = buf_length - start_pos + if length >= chunk_len: + self._payload_length = length - chunk_len + payload.extend(buf[start_pos:]) + start_pos = buf_length + else: + self._payload_length = 0 + payload.extend(buf[start_pos : start_pos + length]) + start_pos = start_pos + length + + if self._payload_length == 0: + if self._has_mask: + assert self._frame_mask is not None + _websocket_mask(self._frame_mask, payload) + + frames.append( + (self._frame_fin, self._frame_opcode, payload, self._compressed) + ) + + self._frame_payload = bytearray() + self._state = WSParserState.READ_HEADER + else: + break + + self._tail = buf[start_pos:] + + return frames + + +class WebSocketWriter: + def __init__( + self, + protocol: BaseProtocol, + transport: asyncio.Transport, + *, + use_mask: bool = False, + limit: int = DEFAULT_LIMIT, + random: random.Random = random.Random(), + compress: int = 0, + notakeover: bool = False, + ) -> None: + self.protocol = protocol + self.transport = transport + self.use_mask = use_mask + self.randrange = random.randrange + self.compress = compress + self.notakeover = notakeover + self._closing = False + self._limit = limit + self._output_size = 0 + self._compressobj: Any = None # actually compressobj + + async def _send_frame( + self, message: bytes, opcode: int, compress: Optional[int] = None + ) -> None: + """Send a frame over the websocket with message as its payload.""" + if self._closing and not (opcode & WSMsgType.CLOSE): + raise ConnectionResetError("Cannot write to closing transport") + + rsv = 0 + + # Only compress larger packets (disabled) + # Does small packet needs to be compressed? + # if self.compress and opcode < 8 and len(message) > 124: + if (compress or self.compress) and opcode < 8: + if compress: + # Do not set self._compress if compressing is for this frame + compressobj = self._make_compress_obj(compress) + else: # self.compress + if not self._compressobj: + self._compressobj = self._make_compress_obj(self.compress) + compressobj = self._compressobj + + message = await compressobj.compress(message) + # Its critical that we do not return control to the event + # loop until we have finished sending all the compressed + # data. Otherwise we could end up mixing compressed frames + # if there are multiple coroutines compressing data. 
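# --- Editorial note: illustrative sketch, not aiohttp's writer. ---
# The comment above relies on doing the flush and the transport write without
# awaiting in between.  A generic alternative, shown here with hypothetical
# names, is to serialise whole-frame sends behind an asyncio.Lock so two
# coroutines can never interleave their compressed fragments:
import asyncio as _asyncio_ws


class _SerialisedSender:
    def __init__(self, write_cb) -> None:
        self._write_cb = write_cb          # callable taking bytes
        self._lock = _asyncio_ws.Lock()

    async def send(self, frame: bytes) -> None:
        async with self._lock:             # one frame fully out, then the next
            self._write_cb(frame)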
+ message += compressobj.flush( + zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH + ) + if message.endswith(_WS_DEFLATE_TRAILING): + message = message[:-4] + rsv = rsv | 0x40 + + msg_length = len(message) + + use_mask = self.use_mask + if use_mask: + mask_bit = 0x80 + else: + mask_bit = 0 + + if msg_length < 126: + header = PACK_LEN1(0x80 | rsv | opcode, msg_length | mask_bit) + elif msg_length < (1 << 16): + header = PACK_LEN2(0x80 | rsv | opcode, 126 | mask_bit, msg_length) + else: + header = PACK_LEN3(0x80 | rsv | opcode, 127 | mask_bit, msg_length) + if use_mask: + mask_int = self.randrange(0, 0xFFFFFFFF) + mask = mask_int.to_bytes(4, "big") + message = bytearray(message) + _websocket_mask(mask, message) + self._write(header + mask + message) + self._output_size += len(header) + len(mask) + msg_length + else: + if msg_length > MSG_SIZE: + self._write(header) + self._write(message) + else: + self._write(header + message) + + self._output_size += len(header) + msg_length + + # It is safe to return control to the event loop when using compression + # after this point as we have already sent or buffered all the data. + + if self._output_size > self._limit: + self._output_size = 0 + await self.protocol._drain_helper() + + def _make_compress_obj(self, compress: int) -> ZLibCompressor: + return ZLibCompressor( + level=zlib.Z_BEST_SPEED, + wbits=-compress, + max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, + ) + + def _write(self, data: bytes) -> None: + if self.transport is None or self.transport.is_closing(): + raise ConnectionResetError("Cannot write to closing transport") + self.transport.write(data) + + async def pong(self, message: Union[bytes, str] = b"") -> None: + """Send pong message.""" + if isinstance(message, str): + message = message.encode("utf-8") + await self._send_frame(message, WSMsgType.PONG) + + async def ping(self, message: Union[bytes, str] = b"") -> None: + """Send ping message.""" + if isinstance(message, str): + message = message.encode("utf-8") + await self._send_frame(message, WSMsgType.PING) + + async def send( + self, + message: Union[str, bytes], + binary: bool = False, + compress: Optional[int] = None, + ) -> None: + """Send a frame over the websocket with message as its payload.""" + if isinstance(message, str): + message = message.encode("utf-8") + if binary: + await self._send_frame(message, WSMsgType.BINARY, compress) + else: + await self._send_frame(message, WSMsgType.TEXT, compress) + + async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None: + """Close the websocket, sending the specified code and message.""" + if isinstance(message, str): + message = message.encode("utf-8") + try: + await self._send_frame( + PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE + ) + finally: + self._closing = True diff --git a/llm/Lib/site-packages/aiohttp/http_writer.py b/llm/Lib/site-packages/aiohttp/http_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..84e9b49589517dd8468e375e2cb1c7fd575a75b9 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/http_writer.py @@ -0,0 +1,198 @@ +"""Http related parsers and protocol.""" + +import asyncio +import zlib +from typing import Any, Awaitable, Callable, NamedTuple, Optional, Union # noqa + +from multidict import CIMultiDict + +from .abc import AbstractStreamWriter +from .base_protocol import BaseProtocol +from .compression_utils import ZLibCompressor +from .helpers import NO_EXTENSIONS + +__all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", 
"HttpVersion11") + + +class HttpVersion(NamedTuple): + major: int + minor: int + + +HttpVersion10 = HttpVersion(1, 0) +HttpVersion11 = HttpVersion(1, 1) + + +_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] +_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]] + + +class StreamWriter(AbstractStreamWriter): + def __init__( + self, + protocol: BaseProtocol, + loop: asyncio.AbstractEventLoop, + on_chunk_sent: _T_OnChunkSent = None, + on_headers_sent: _T_OnHeadersSent = None, + ) -> None: + self._protocol = protocol + + self.loop = loop + self.length = None + self.chunked = False + self.buffer_size = 0 + self.output_size = 0 + + self._eof = False + self._compress: Optional[ZLibCompressor] = None + self._drain_waiter = None + + self._on_chunk_sent: _T_OnChunkSent = on_chunk_sent + self._on_headers_sent: _T_OnHeadersSent = on_headers_sent + + @property + def transport(self) -> Optional[asyncio.Transport]: + return self._protocol.transport + + @property + def protocol(self) -> BaseProtocol: + return self._protocol + + def enable_chunking(self) -> None: + self.chunked = True + + def enable_compression( + self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY + ) -> None: + self._compress = ZLibCompressor(encoding=encoding, strategy=strategy) + + def _write(self, chunk: bytes) -> None: + size = len(chunk) + self.buffer_size += size + self.output_size += size + transport = self.transport + if not self._protocol.connected or transport is None or transport.is_closing(): + raise ConnectionResetError("Cannot write to closing transport") + transport.write(chunk) + + async def write( + self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000 + ) -> None: + """Writes chunk of data to a stream. + + write_eof() indicates end of stream. + writer can't be used after write_eof() method being called. + write() return drain future. 
+ """ + if self._on_chunk_sent is not None: + await self._on_chunk_sent(chunk) + + if isinstance(chunk, memoryview): + if chunk.nbytes != len(chunk): + # just reshape it + chunk = chunk.cast("c") + + if self._compress is not None: + chunk = await self._compress.compress(chunk) + if not chunk: + return + + if self.length is not None: + chunk_len = len(chunk) + if self.length >= chunk_len: + self.length = self.length - chunk_len + else: + chunk = chunk[: self.length] + self.length = 0 + if not chunk: + return + + if chunk: + if self.chunked: + chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii") + chunk = chunk_len_pre + chunk + b"\r\n" + + self._write(chunk) + + if self.buffer_size > LIMIT and drain: + self.buffer_size = 0 + await self.drain() + + async def write_headers( + self, status_line: str, headers: "CIMultiDict[str]" + ) -> None: + """Write request/response status and headers.""" + if self._on_headers_sent is not None: + await self._on_headers_sent(headers) + + # status + headers + buf = _serialize_headers(status_line, headers) + self._write(buf) + + async def write_eof(self, chunk: bytes = b"") -> None: + if self._eof: + return + + if chunk and self._on_chunk_sent is not None: + await self._on_chunk_sent(chunk) + + if self._compress: + if chunk: + chunk = await self._compress.compress(chunk) + + chunk += self._compress.flush() + if chunk and self.chunked: + chunk_len = ("%x\r\n" % len(chunk)).encode("ascii") + chunk = chunk_len + chunk + b"\r\n0\r\n\r\n" + else: + if self.chunked: + if chunk: + chunk_len = ("%x\r\n" % len(chunk)).encode("ascii") + chunk = chunk_len + chunk + b"\r\n0\r\n\r\n" + else: + chunk = b"0\r\n\r\n" + + if chunk: + self._write(chunk) + + await self.drain() + + self._eof = True + + async def drain(self) -> None: + """Flush the write buffer. + + The intended use is to write + + await w.write(data) + await w.drain() + """ + if self._protocol.transport is not None: + await self._protocol._drain_helper() + + +def _safe_header(string: str) -> str: + if "\r" in string or "\n" in string: + raise ValueError( + "Newline or carriage return detected in headers. " + "Potential header injection attack." + ) + return string + + +def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes: + headers_gen = (_safe_header(k) + ": " + _safe_header(v) for k, v in headers.items()) + line = status_line + "\r\n" + "\r\n".join(headers_gen) + "\r\n\r\n" + return line.encode("utf-8") + + +_serialize_headers = _py_serialize_headers + +try: + import aiohttp._http_writer as _http_writer # type: ignore[import-not-found] + + _c_serialize_headers = _http_writer._serialize_headers + if not NO_EXTENSIONS: + _serialize_headers = _c_serialize_headers +except ImportError: + pass diff --git a/llm/Lib/site-packages/aiohttp/locks.py b/llm/Lib/site-packages/aiohttp/locks.py new file mode 100644 index 0000000000000000000000000000000000000000..fbfff2fa42cc29d377e29552092744899c10d640 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/locks.py @@ -0,0 +1,41 @@ +import asyncio +import collections +from typing import Any, Deque, Optional + + +class EventResultOrError: + """Event asyncio lock helper class. + + Wraps the Event asyncio lock allowing either to awake the + locked Tasks without any error or raising an exception. + + thanks to @vorpalsmith for the simple design. 
+ """ + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._exc: Optional[BaseException] = None + self._event = asyncio.Event() + self._waiters: Deque[asyncio.Future[Any]] = collections.deque() + + def set(self, exc: Optional[BaseException] = None) -> None: + self._exc = exc + self._event.set() + + async def wait(self) -> Any: + waiter = self._loop.create_task(self._event.wait()) + self._waiters.append(waiter) + try: + val = await waiter + finally: + self._waiters.remove(waiter) + + if self._exc is not None: + raise self._exc + + return val + + def cancel(self) -> None: + """Cancel all waiters""" + for waiter in self._waiters: + waiter.cancel() diff --git a/llm/Lib/site-packages/aiohttp/log.py b/llm/Lib/site-packages/aiohttp/log.py new file mode 100644 index 0000000000000000000000000000000000000000..d1314145b262658992db308631a13bac6948f7f8 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/log.py @@ -0,0 +1,8 @@ +import logging + +access_logger = logging.getLogger("aiohttp.access") +client_logger = logging.getLogger("aiohttp.client") +internal_logger = logging.getLogger("aiohttp.internal") +server_logger = logging.getLogger("aiohttp.server") +web_logger = logging.getLogger("aiohttp.web") +ws_logger = logging.getLogger("aiohttp.websocket") diff --git a/llm/Lib/site-packages/aiohttp/multipart.py b/llm/Lib/site-packages/aiohttp/multipart.py new file mode 100644 index 0000000000000000000000000000000000000000..660cdf136b941d981a76a61d0174e3c2939b1569 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/multipart.py @@ -0,0 +1,1015 @@ +import base64 +import binascii +import json +import re +import uuid +import warnings +import zlib +from collections import deque +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Deque, + Dict, + Iterator, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) +from urllib.parse import parse_qsl, unquote, urlencode + +from multidict import CIMultiDict, CIMultiDictProxy + +from .compression_utils import ZLibCompressor, ZLibDecompressor +from .hdrs import ( + CONTENT_DISPOSITION, + CONTENT_ENCODING, + CONTENT_LENGTH, + CONTENT_TRANSFER_ENCODING, + CONTENT_TYPE, +) +from .helpers import CHAR, TOKEN, parse_mimetype, reify +from .http import HeadersParser +from .payload import ( + JsonPayload, + LookupError, + Order, + Payload, + StringPayload, + get_payload, + payload_type, +) +from .streams import StreamReader + +__all__ = ( + "MultipartReader", + "MultipartWriter", + "BodyPartReader", + "BadContentDispositionHeader", + "BadContentDispositionParam", + "parse_content_disposition", + "content_disposition_filename", +) + + +if TYPE_CHECKING: + from .client_reqrep import ClientResponse + + +class BadContentDispositionHeader(RuntimeWarning): + pass + + +class BadContentDispositionParam(RuntimeWarning): + pass + + +def parse_content_disposition( + header: Optional[str], +) -> Tuple[Optional[str], Dict[str, str]]: + def is_token(string: str) -> bool: + return bool(string) and TOKEN >= set(string) + + def is_quoted(string: str) -> bool: + return string[0] == string[-1] == '"' + + def is_rfc5987(string: str) -> bool: + return is_token(string) and string.count("'") == 2 + + def is_extended_param(string: str) -> bool: + return string.endswith("*") + + def is_continuous_param(string: str) -> bool: + pos = string.find("*") + 1 + if not pos: + return False + substring = string[pos:-1] if string.endswith("*") else string[pos:] + return substring.isdigit() + + def 
unescape(text: str, *, chars: str = "".join(map(re.escape, CHAR))) -> str: + return re.sub(f"\\\\([{chars}])", "\\1", text) + + if not header: + return None, {} + + disptype, *parts = header.split(";") + if not is_token(disptype): + warnings.warn(BadContentDispositionHeader(header)) + return None, {} + + params: Dict[str, str] = {} + while parts: + item = parts.pop(0) + + if "=" not in item: + warnings.warn(BadContentDispositionHeader(header)) + return None, {} + + key, value = item.split("=", 1) + key = key.lower().strip() + value = value.lstrip() + + if key in params: + warnings.warn(BadContentDispositionHeader(header)) + return None, {} + + if not is_token(key): + warnings.warn(BadContentDispositionParam(item)) + continue + + elif is_continuous_param(key): + if is_quoted(value): + value = unescape(value[1:-1]) + elif not is_token(value): + warnings.warn(BadContentDispositionParam(item)) + continue + + elif is_extended_param(key): + if is_rfc5987(value): + encoding, _, value = value.split("'", 2) + encoding = encoding or "utf-8" + else: + warnings.warn(BadContentDispositionParam(item)) + continue + + try: + value = unquote(value, encoding, "strict") + except UnicodeDecodeError: # pragma: nocover + warnings.warn(BadContentDispositionParam(item)) + continue + + else: + failed = True + if is_quoted(value): + failed = False + value = unescape(value[1:-1].lstrip("\\/")) + elif is_token(value): + failed = False + elif parts: + # maybe just ; in filename, in any case this is just + # one case fix, for proper fix we need to redesign parser + _value = f"{value};{parts[0]}" + if is_quoted(_value): + parts.pop(0) + value = unescape(_value[1:-1].lstrip("\\/")) + failed = False + + if failed: + warnings.warn(BadContentDispositionHeader(header)) + return None, {} + + params[key] = value + + return disptype.lower(), params + + +def content_disposition_filename( + params: Mapping[str, str], name: str = "filename" +) -> Optional[str]: + name_suf = "%s*" % name + if not params: + return None + elif name_suf in params: + return params[name_suf] + elif name in params: + return params[name] + else: + parts = [] + fnparams = sorted( + (key, value) for key, value in params.items() if key.startswith(name_suf) + ) + for num, (key, value) in enumerate(fnparams): + _, tail = key.split("*", 1) + if tail.endswith("*"): + tail = tail[:-1] + if tail == str(num): + parts.append(value) + else: + break + if not parts: + return None + value = "".join(parts) + if "'" in value: + encoding, _, value = value.split("'", 2) + encoding = encoding or "utf-8" + return unquote(value, encoding, "strict") + return value + + +class MultipartResponseWrapper: + """Wrapper around the MultipartReader. + + It takes care about + underlying connection and close it when it needs in. 
+ """ + + def __init__( + self, + resp: "ClientResponse", + stream: "MultipartReader", + ) -> None: + self.resp = resp + self.stream = stream + + def __aiter__(self) -> "MultipartResponseWrapper": + return self + + async def __anext__( + self, + ) -> Union["MultipartReader", "BodyPartReader"]: + part = await self.next() + if part is None: + raise StopAsyncIteration + return part + + def at_eof(self) -> bool: + """Returns True when all response data had been read.""" + return self.resp.content.at_eof() + + async def next( + self, + ) -> Optional[Union["MultipartReader", "BodyPartReader"]]: + """Emits next multipart reader object.""" + item = await self.stream.next() + if self.stream.at_eof(): + await self.release() + return item + + async def release(self) -> None: + """Release the connection gracefully. + + All remaining content is read to the void. + """ + await self.resp.release() + + +class BodyPartReader: + """Multipart reader for single body part.""" + + chunk_size = 8192 + + def __init__( + self, + boundary: bytes, + headers: "CIMultiDictProxy[str]", + content: StreamReader, + *, + subtype: str = "mixed", + default_charset: Optional[str] = None, + ) -> None: + self.headers = headers + self._boundary = boundary + self._content = content + self._default_charset = default_charset + self._at_eof = False + self._is_form_data = subtype == "form-data" + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 + length = None if self._is_form_data else self.headers.get(CONTENT_LENGTH, None) + self._length = int(length) if length is not None else None + self._read_bytes = 0 + self._unread: Deque[bytes] = deque() + self._prev_chunk: Optional[bytes] = None + self._content_eof = 0 + self._cache: Dict[str, Any] = {} + + def __aiter__(self) -> AsyncIterator["BodyPartReader"]: + return self # type: ignore[return-value] + + async def __anext__(self) -> bytes: + part = await self.next() + if part is None: + raise StopAsyncIteration + return part + + async def next(self) -> Optional[bytes]: + item = await self.read() + if not item: + return None + return item + + async def read(self, *, decode: bool = False) -> bytes: + """Reads body part data. + + decode: Decodes data following by encoding + method from Content-Encoding header. If it missed + data remains untouched + """ + if self._at_eof: + return b"" + data = bytearray() + while not self._at_eof: + data.extend(await self.read_chunk(self.chunk_size)) + if decode: + return self.decode(data) + return data + + async def read_chunk(self, size: int = chunk_size) -> bytes: + """Reads body part content chunk of the specified size. + + size: chunk size + """ + if self._at_eof: + return b"" + if self._length: + chunk = await self._read_chunk_from_length(size) + else: + chunk = await self._read_chunk_from_stream(size) + + self._read_bytes += len(chunk) + if self._read_bytes == self._length: + self._at_eof = True + if self._at_eof: + clrf = await self._content.readline() + assert ( + b"\r\n" == clrf + ), "reader did not read all the data or it is malformed" + return chunk + + async def _read_chunk_from_length(self, size: int) -> bytes: + # Reads body part content chunk of the specified size. + # The body part must has Content-Length header with proper value. 
+ assert self._length is not None, "Content-Length required for chunked read" + chunk_size = min(size, self._length - self._read_bytes) + chunk = await self._content.read(chunk_size) + if self._content.at_eof(): + self._at_eof = True + return chunk + + async def _read_chunk_from_stream(self, size: int) -> bytes: + # Reads content chunk of body part with unknown length. + # The Content-Length header for body part is not necessary. + assert ( + size >= len(self._boundary) + 2 + ), "Chunk size must be greater or equal than boundary length + 2" + first_chunk = self._prev_chunk is None + if first_chunk: + self._prev_chunk = await self._content.read(size) + + chunk = await self._content.read(size) + self._content_eof += int(self._content.at_eof()) + assert self._content_eof < 3, "Reading after EOF" + assert self._prev_chunk is not None + window = self._prev_chunk + chunk + sub = b"\r\n" + self._boundary + if first_chunk: + idx = window.find(sub) + else: + idx = window.find(sub, max(0, len(self._prev_chunk) - len(sub))) + if idx >= 0: + # pushing boundary back to content + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=DeprecationWarning) + self._content.unread_data(window[idx:]) + if size > idx: + self._prev_chunk = self._prev_chunk[:idx] + chunk = window[len(self._prev_chunk) : idx] + if not chunk: + self._at_eof = True + result = self._prev_chunk + self._prev_chunk = chunk + return result + + async def readline(self) -> bytes: + """Reads body part by line by line.""" + if self._at_eof: + return b"" + + if self._unread: + line = self._unread.popleft() + else: + line = await self._content.readline() + + if line.startswith(self._boundary): + # the very last boundary may not come with \r\n, + # so set single rules for everyone + sline = line.rstrip(b"\r\n") + boundary = self._boundary + last_boundary = self._boundary + b"--" + # ensure that we read exactly the boundary, not something alike + if sline == boundary or sline == last_boundary: + self._at_eof = True + self._unread.append(line) + return b"" + else: + next_line = await self._content.readline() + if next_line.startswith(self._boundary): + line = line[:-2] # strip CRLF but only once + self._unread.append(next_line) + + return line + + async def release(self) -> None: + """Like read(), but reads all the data to the void.""" + if self._at_eof: + return + while not self._at_eof: + await self.read_chunk(self.chunk_size) + + async def text(self, *, encoding: Optional[str] = None) -> str: + """Like read(), but assumes that body part contains text data.""" + data = await self.read(decode=True) + # see https://www.w3.org/TR/html5/forms.html#multipart/form-data-encoding-algorithm + # and https://dvcs.w3.org/hg/xhr/raw-file/tip/Overview.html#dom-xmlhttprequest-send + encoding = encoding or self.get_charset(default="utf-8") + return data.decode(encoding) + + async def json(self, *, encoding: Optional[str] = None) -> Optional[Dict[str, Any]]: + """Like read(), but assumes that body parts contains JSON data.""" + data = await self.read(decode=True) + if not data: + return None + encoding = encoding or self.get_charset(default="utf-8") + return cast(Dict[str, Any], json.loads(data.decode(encoding))) + + async def form(self, *, encoding: Optional[str] = None) -> List[Tuple[str, str]]: + """Like read(), but assumes that body parts contain form urlencoded data.""" + data = await self.read(decode=True) + if not data: + return [] + if encoding is not None: + real_encoding = encoding + else: + real_encoding = 
self.get_charset(default="utf-8") + try: + decoded_data = data.rstrip().decode(real_encoding) + except UnicodeDecodeError: + raise ValueError("data cannot be decoded with %s encoding" % real_encoding) + + return parse_qsl( + decoded_data, + keep_blank_values=True, + encoding=real_encoding, + ) + + def at_eof(self) -> bool: + """Returns True if the boundary was reached or False otherwise.""" + return self._at_eof + + def decode(self, data: bytes) -> bytes: + """Decodes data. + + Decoding is done according the specified Content-Encoding + or Content-Transfer-Encoding headers value. + """ + if CONTENT_TRANSFER_ENCODING in self.headers: + data = self._decode_content_transfer(data) + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 + if not self._is_form_data and CONTENT_ENCODING in self.headers: + return self._decode_content(data) + return data + + def _decode_content(self, data: bytes) -> bytes: + encoding = self.headers.get(CONTENT_ENCODING, "").lower() + if encoding == "identity": + return data + if encoding in {"deflate", "gzip"}: + return ZLibDecompressor( + encoding=encoding, + suppress_deflate_header=True, + ).decompress_sync(data) + + raise RuntimeError(f"unknown content encoding: {encoding}") + + def _decode_content_transfer(self, data: bytes) -> bytes: + encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, "").lower() + + if encoding == "base64": + return base64.b64decode(data) + elif encoding == "quoted-printable": + return binascii.a2b_qp(data) + elif encoding in ("binary", "8bit", "7bit"): + return data + else: + raise RuntimeError( + "unknown content transfer encoding: {}" "".format(encoding) + ) + + def get_charset(self, default: str) -> str: + """Returns charset parameter from Content-Type header or default.""" + ctype = self.headers.get(CONTENT_TYPE, "") + mimetype = parse_mimetype(ctype) + return mimetype.parameters.get("charset", self._default_charset or default) + + @reify + def name(self) -> Optional[str]: + """Returns name specified in Content-Disposition header. + + If the header is missing or malformed, returns None. + """ + _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION)) + return content_disposition_filename(params, "name") + + @reify + def filename(self) -> Optional[str]: + """Returns filename specified in Content-Disposition header. + + Returns None if the header is missing or malformed. + """ + _, params = parse_content_disposition(self.headers.get(CONTENT_DISPOSITION)) + return content_disposition_filename(params, "filename") + + +@payload_type(BodyPartReader, order=Order.try_first) +class BodyPartReaderPayload(Payload): + def __init__(self, value: BodyPartReader, *args: Any, **kwargs: Any) -> None: + super().__init__(value, *args, **kwargs) + + params: Dict[str, str] = {} + if value.name is not None: + params["name"] = value.name + if value.filename is not None: + params["filename"] = value.filename + + if params: + self.set_content_disposition("attachment", True, **params) + + async def write(self, writer: Any) -> None: + field = self._value + chunk = await field.read_chunk(size=2**16) + while chunk: + await writer.write(field.decode(chunk)) + chunk = await field.read_chunk(size=2**16) + + +class MultipartReader: + """Multipart body reader.""" + + #: Response wrapper, used when multipart readers constructs from response. + response_wrapper_cls = MultipartResponseWrapper + #: Multipart reader class, used to handle multipart/* body parts. 
+ #: None points to type(self) + multipart_reader_cls = None + #: Body part reader class for non multipart/* content types. + part_reader_cls = BodyPartReader + + def __init__(self, headers: Mapping[str, str], content: StreamReader) -> None: + self._mimetype = parse_mimetype(headers[CONTENT_TYPE]) + assert self._mimetype.type == "multipart", "multipart/* content type expected" + if "boundary" not in self._mimetype.parameters: + raise ValueError( + "boundary missed for Content-Type: %s" % headers[CONTENT_TYPE] + ) + + self.headers = headers + self._boundary = ("--" + self._get_boundary()).encode() + self._content = content + self._default_charset: Optional[str] = None + self._last_part: Optional[Union["MultipartReader", BodyPartReader]] = None + self._at_eof = False + self._at_bof = True + self._unread: List[bytes] = [] + + def __aiter__( + self, + ) -> AsyncIterator["BodyPartReader"]: + return self # type: ignore[return-value] + + async def __anext__( + self, + ) -> Optional[Union["MultipartReader", BodyPartReader]]: + part = await self.next() + if part is None: + raise StopAsyncIteration + return part + + @classmethod + def from_response( + cls, + response: "ClientResponse", + ) -> MultipartResponseWrapper: + """Constructs reader instance from HTTP response. + + :param response: :class:`~aiohttp.client.ClientResponse` instance + """ + obj = cls.response_wrapper_cls( + response, cls(response.headers, response.content) + ) + return obj + + def at_eof(self) -> bool: + """Returns True if the final boundary was reached, false otherwise.""" + return self._at_eof + + async def next( + self, + ) -> Optional[Union["MultipartReader", BodyPartReader]]: + """Emits the next multipart body part.""" + # So, if we're at BOF, we need to skip till the boundary. + if self._at_eof: + return None + await self._maybe_release_last_part() + if self._at_bof: + await self._read_until_first_boundary() + self._at_bof = False + else: + await self._read_boundary() + if self._at_eof: # we just read the last boundary, nothing to do there + return None + + part = await self.fetch_next_part() + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.6 + if ( + self._last_part is None + and self._mimetype.subtype == "form-data" + and isinstance(part, BodyPartReader) + ): + _, params = parse_content_disposition(part.headers.get(CONTENT_DISPOSITION)) + if params.get("name") == "_charset_": + # Longest encoding in https://encoding.spec.whatwg.org/encodings.json + # is 19 characters, so 32 should be more than enough for any valid encoding. + charset = await part.read_chunk(32) + if len(charset) > 31: + raise RuntimeError("Invalid default charset") + self._default_charset = charset.strip().decode() + part = await self.fetch_next_part() + self._last_part = part + return self._last_part + + async def release(self) -> None: + """Reads all the body parts to the void till the final boundary.""" + while not self._at_eof: + item = await self.next() + if item is None: + break + await item.release() + + async def fetch_next_part( + self, + ) -> Union["MultipartReader", BodyPartReader]: + """Returns the next body part reader.""" + headers = await self._read_headers() + return self._get_part_reader(headers) + + def _get_part_reader( + self, + headers: "CIMultiDictProxy[str]", + ) -> Union["MultipartReader", BodyPartReader]: + """Dispatches the response by the `Content-Type` header. + + Returns a suitable reader instance. 
+ + :param dict headers: Response headers + """ + ctype = headers.get(CONTENT_TYPE, "") + mimetype = parse_mimetype(ctype) + + if mimetype.type == "multipart": + if self.multipart_reader_cls is None: + return type(self)(headers, self._content) + return self.multipart_reader_cls(headers, self._content) + else: + return self.part_reader_cls( + self._boundary, + headers, + self._content, + subtype=self._mimetype.subtype, + default_charset=self._default_charset, + ) + + def _get_boundary(self) -> str: + boundary = self._mimetype.parameters["boundary"] + if len(boundary) > 70: + raise ValueError("boundary %r is too long (70 chars max)" % boundary) + + return boundary + + async def _readline(self) -> bytes: + if self._unread: + return self._unread.pop() + return await self._content.readline() + + async def _read_until_first_boundary(self) -> None: + while True: + chunk = await self._readline() + if chunk == b"": + raise ValueError( + "Could not find starting boundary %r" % (self._boundary) + ) + chunk = chunk.rstrip() + if chunk == self._boundary: + return + elif chunk == self._boundary + b"--": + self._at_eof = True + return + + async def _read_boundary(self) -> None: + chunk = (await self._readline()).rstrip() + if chunk == self._boundary: + pass + elif chunk == self._boundary + b"--": + self._at_eof = True + epilogue = await self._readline() + next_line = await self._readline() + + # the epilogue is expected and then either the end of input or the + # parent multipart boundary, if the parent boundary is found then + # it should be marked as unread and handed to the parent for + # processing + if next_line[:2] == b"--": + self._unread.append(next_line) + # otherwise the request is likely missing an epilogue and both + # lines should be passed to the parent for processing + # (this handles the old behavior gracefully) + else: + self._unread.extend([next_line, epilogue]) + else: + raise ValueError(f"Invalid boundary {chunk!r}, expected {self._boundary!r}") + + async def _read_headers(self) -> "CIMultiDictProxy[str]": + lines = [b""] + while True: + chunk = await self._content.readline() + chunk = chunk.strip() + lines.append(chunk) + if not chunk: + break + parser = HeadersParser() + headers, raw_headers = parser.parse_headers(lines) + return headers + + async def _maybe_release_last_part(self) -> None: + """Ensures that the last read body part is read completely.""" + if self._last_part is not None: + if not self._last_part.at_eof(): + await self._last_part.release() + self._unread.extend(self._last_part._unread) + self._last_part = None + + +_Part = Tuple[Payload, str, str] + + +class MultipartWriter(Payload): + """Multipart body writer.""" + + def __init__(self, subtype: str = "mixed", boundary: Optional[str] = None) -> None: + boundary = boundary if boundary is not None else uuid.uuid4().hex + # The underlying Payload API demands a str (utf-8), not bytes, + # so we need to ensure we don't lose anything during conversion. + # As a result, require the boundary to be ASCII only. + # In both situations. 
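# --- Editorial note: illustrative sketch, not part of MultipartWriter. ---
# The ASCII-only boundary enforced just below ends up framing the body as in
# RFC 2046: every part is introduced by "--boundary" and the whole body is
# terminated by "--boundary--".
_demo_boundary = b"demo-boundary"
_demo_body = (
    b"--" + _demo_boundary + b"\r\n"
    b'Content-Disposition: form-data; name="field"\r\n'
    b"\r\n"
    b"value\r\n"
    b"--" + _demo_boundary + b"--\r\n"
)
assert _demo_body.count(b"--" + _demo_boundary) == 2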
+ + try: + self._boundary = boundary.encode("ascii") + except UnicodeEncodeError: + raise ValueError("boundary should contain ASCII only chars") from None + ctype = f"multipart/{subtype}; boundary={self._boundary_value}" + + super().__init__(None, content_type=ctype) + + self._parts: List[_Part] = [] + self._is_form_data = subtype == "form-data" + + def __enter__(self) -> "MultipartWriter": + return self + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + pass + + def __iter__(self) -> Iterator[_Part]: + return iter(self._parts) + + def __len__(self) -> int: + return len(self._parts) + + def __bool__(self) -> bool: + return True + + _valid_tchar_regex = re.compile(rb"\A[!#$%&'*+\-.^_`|~\w]+\Z") + _invalid_qdtext_char_regex = re.compile(rb"[\x00-\x08\x0A-\x1F\x7F]") + + @property + def _boundary_value(self) -> str: + """Wrap boundary parameter value in quotes, if necessary. + + Reads self.boundary and returns a unicode string. + """ + # Refer to RFCs 7231, 7230, 5234. + # + # parameter = token "=" ( token / quoted-string ) + # token = 1*tchar + # quoted-string = DQUOTE *( qdtext / quoted-pair ) DQUOTE + # qdtext = HTAB / SP / %x21 / %x23-5B / %x5D-7E / obs-text + # obs-text = %x80-FF + # quoted-pair = "\" ( HTAB / SP / VCHAR / obs-text ) + # tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" + # / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" + # / DIGIT / ALPHA + # ; any VCHAR, except delimiters + # VCHAR = %x21-7E + value = self._boundary + if re.match(self._valid_tchar_regex, value): + return value.decode("ascii") # cannot fail + + if re.search(self._invalid_qdtext_char_regex, value): + raise ValueError("boundary value contains invalid characters") + + # escape %x5C and %x22 + quoted_value_content = value.replace(b"\\", b"\\\\") + quoted_value_content = quoted_value_content.replace(b'"', b'\\"') + + return '"' + quoted_value_content.decode("ascii") + '"' + + @property + def boundary(self) -> str: + return self._boundary.decode("ascii") + + def append(self, obj: Any, headers: Optional[Mapping[str, str]] = None) -> Payload: + if headers is None: + headers = CIMultiDict() + + if isinstance(obj, Payload): + obj.headers.update(headers) + return self.append_payload(obj) + else: + try: + payload = get_payload(obj, headers=headers) + except LookupError: + raise TypeError("Cannot create payload from %r" % obj) + else: + return self.append_payload(payload) + + def append_payload(self, payload: Payload) -> Payload: + """Adds a new body part to multipart writer.""" + encoding: Optional[str] = None + te_encoding: Optional[str] = None + if self._is_form_data: + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.7 + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.8 + assert ( + not {CONTENT_ENCODING, CONTENT_LENGTH, CONTENT_TRANSFER_ENCODING} + & payload.headers.keys() + ) + # Set default Content-Disposition in case user doesn't create one + if CONTENT_DISPOSITION not in payload.headers: + name = f"section-{len(self._parts)}" + payload.set_content_disposition("form-data", name=name) + else: + # compression + encoding = payload.headers.get(CONTENT_ENCODING, "").lower() + if encoding and encoding not in ("deflate", "gzip", "identity"): + raise RuntimeError(f"unknown content encoding: {encoding}") + if encoding == "identity": + encoding = None + + # te encoding + te_encoding = payload.headers.get(CONTENT_TRANSFER_ENCODING, "").lower() + if te_encoding not in ("", "base64", 
"quoted-printable", "binary"): + raise RuntimeError(f"unknown content transfer encoding: {te_encoding}") + if te_encoding == "binary": + te_encoding = None + + # size + size = payload.size + if size is not None and not (encoding or te_encoding): + payload.headers[CONTENT_LENGTH] = str(size) + + self._parts.append((payload, encoding, te_encoding)) # type: ignore[arg-type] + return payload + + def append_json( + self, obj: Any, headers: Optional[Mapping[str, str]] = None + ) -> Payload: + """Helper to append JSON part.""" + if headers is None: + headers = CIMultiDict() + + return self.append_payload(JsonPayload(obj, headers=headers)) + + def append_form( + self, + obj: Union[Sequence[Tuple[str, str]], Mapping[str, str]], + headers: Optional[Mapping[str, str]] = None, + ) -> Payload: + """Helper to append form urlencoded part.""" + assert isinstance(obj, (Sequence, Mapping)) + + if headers is None: + headers = CIMultiDict() + + if isinstance(obj, Mapping): + obj = list(obj.items()) + data = urlencode(obj, doseq=True) + + return self.append_payload( + StringPayload( + data, headers=headers, content_type="application/x-www-form-urlencoded" + ) + ) + + @property + def size(self) -> Optional[int]: + """Size of the payload.""" + total = 0 + for part, encoding, te_encoding in self._parts: + if encoding or te_encoding or part.size is None: + return None + + total += int( + 2 + + len(self._boundary) + + 2 + + part.size # b'--'+self._boundary+b'\r\n' + + len(part._binary_headers) + + 2 # b'\r\n' + ) + + total += 2 + len(self._boundary) + 4 # b'--'+self._boundary+b'--\r\n' + return total + + async def write(self, writer: Any, close_boundary: bool = True) -> None: + """Write body.""" + for part, encoding, te_encoding in self._parts: + if self._is_form_data: + # https://datatracker.ietf.org/doc/html/rfc7578#section-4.2 + assert CONTENT_DISPOSITION in part.headers + assert "name=" in part.headers[CONTENT_DISPOSITION] + + await writer.write(b"--" + self._boundary + b"\r\n") + await writer.write(part._binary_headers) + + if encoding or te_encoding: + w = MultipartPayloadWriter(writer) + if encoding: + w.enable_compression(encoding) + if te_encoding: + w.enable_encoding(te_encoding) + await part.write(w) # type: ignore[arg-type] + await w.write_eof() + else: + await part.write(writer) + + await writer.write(b"\r\n") + + if close_boundary: + await writer.write(b"--" + self._boundary + b"--\r\n") + + +class MultipartPayloadWriter: + def __init__(self, writer: Any) -> None: + self._writer = writer + self._encoding: Optional[str] = None + self._compress: Optional[ZLibCompressor] = None + self._encoding_buffer: Optional[bytearray] = None + + def enable_encoding(self, encoding: str) -> None: + if encoding == "base64": + self._encoding = encoding + self._encoding_buffer = bytearray() + elif encoding == "quoted-printable": + self._encoding = "quoted-printable" + + def enable_compression( + self, encoding: str = "deflate", strategy: int = zlib.Z_DEFAULT_STRATEGY + ) -> None: + self._compress = ZLibCompressor( + encoding=encoding, + suppress_deflate_header=True, + strategy=strategy, + ) + + async def write_eof(self) -> None: + if self._compress is not None: + chunk = self._compress.flush() + if chunk: + self._compress = None + await self.write(chunk) + + if self._encoding == "base64": + if self._encoding_buffer: + await self._writer.write(base64.b64encode(self._encoding_buffer)) + + async def write(self, chunk: bytes) -> None: + if self._compress is not None: + if chunk: + chunk = await self._compress.compress(chunk) 
+ if not chunk: + return + + if self._encoding == "base64": + buf = self._encoding_buffer + assert buf is not None + buf.extend(chunk) + + if buf: + div, mod = divmod(len(buf), 3) + enc_chunk, self._encoding_buffer = (buf[: div * 3], buf[div * 3 :]) + if enc_chunk: + b64chunk = base64.b64encode(enc_chunk) + await self._writer.write(b64chunk) + elif self._encoding == "quoted-printable": + await self._writer.write(binascii.b2a_qp(chunk)) + else: + await self._writer.write(chunk) diff --git a/llm/Lib/site-packages/aiohttp/payload.py b/llm/Lib/site-packages/aiohttp/payload.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6668728805651ce61f25051cb8feb872fbe86d --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/payload.py @@ -0,0 +1,463 @@ +import asyncio +import enum +import io +import json +import mimetypes +import os +import warnings +from abc import ABC, abstractmethod +from itertools import chain +from typing import ( + IO, + TYPE_CHECKING, + Any, + ByteString, + Dict, + Final, + Iterable, + Optional, + TextIO, + Tuple, + Type, + Union, +) + +from multidict import CIMultiDict + +from . import hdrs +from .abc import AbstractStreamWriter +from .helpers import ( + _SENTINEL, + content_disposition_header, + guess_filename, + parse_mimetype, + sentinel, +) +from .streams import StreamReader +from .typedefs import JSONEncoder, _CIMultiDict + +__all__ = ( + "PAYLOAD_REGISTRY", + "get_payload", + "payload_type", + "Payload", + "BytesPayload", + "StringPayload", + "IOBasePayload", + "BytesIOPayload", + "BufferedReaderPayload", + "TextIOPayload", + "StringIOPayload", + "JsonPayload", + "AsyncIterablePayload", +) + +TOO_LARGE_BYTES_BODY: Final[int] = 2**20 # 1 MB + +if TYPE_CHECKING: + from typing import List + + +class LookupError(Exception): + pass + + +class Order(str, enum.Enum): + normal = "normal" + try_first = "try_first" + try_last = "try_last" + + +def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload": + return PAYLOAD_REGISTRY.get(data, *args, **kwargs) + + +def register_payload( + factory: Type["Payload"], type: Any, *, order: Order = Order.normal +) -> None: + PAYLOAD_REGISTRY.register(factory, type, order=order) + + +class payload_type: + def __init__(self, type: Any, *, order: Order = Order.normal) -> None: + self.type = type + self.order = order + + def __call__(self, factory: Type["Payload"]) -> Type["Payload"]: + register_payload(factory, self.type, order=self.order) + return factory + + +PayloadType = Type["Payload"] +_PayloadRegistryItem = Tuple[PayloadType, Any] + + +class PayloadRegistry: + """Payload registry. 
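+
+    A rough usage sketch (the module-level code below registers the standard
+    payload types on the global PAYLOAD_REGISTRY instance)::
+
+        registry = PayloadRegistry()
+        registry.register(BytesPayload, (bytes, bytearray, memoryview))
+        payload = registry.get(b"data")  # returns a BytesPayload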
+ + note: we need zope.interface for more efficient adapter search + """ + + def __init__(self) -> None: + self._first: List[_PayloadRegistryItem] = [] + self._normal: List[_PayloadRegistryItem] = [] + self._last: List[_PayloadRegistryItem] = [] + + def get( + self, + data: Any, + *args: Any, + _CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain, + **kwargs: Any, + ) -> "Payload": + if isinstance(data, Payload): + return data + for factory, type in _CHAIN(self._first, self._normal, self._last): + if isinstance(data, type): + return factory(data, *args, **kwargs) + + raise LookupError() + + def register( + self, factory: PayloadType, type: Any, *, order: Order = Order.normal + ) -> None: + if order is Order.try_first: + self._first.append((factory, type)) + elif order is Order.normal: + self._normal.append((factory, type)) + elif order is Order.try_last: + self._last.append((factory, type)) + else: + raise ValueError(f"Unsupported order {order!r}") + + +class Payload(ABC): + + _default_content_type: str = "application/octet-stream" + _size: Optional[int] = None + + def __init__( + self, + value: Any, + headers: Optional[ + Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]] + ] = None, + content_type: Union[str, None, _SENTINEL] = sentinel, + filename: Optional[str] = None, + encoding: Optional[str] = None, + **kwargs: Any, + ) -> None: + self._encoding = encoding + self._filename = filename + self._headers: _CIMultiDict = CIMultiDict() + self._value = value + if content_type is not sentinel and content_type is not None: + self._headers[hdrs.CONTENT_TYPE] = content_type + elif self._filename is not None: + content_type = mimetypes.guess_type(self._filename)[0] + if content_type is None: + content_type = self._default_content_type + self._headers[hdrs.CONTENT_TYPE] = content_type + else: + self._headers[hdrs.CONTENT_TYPE] = self._default_content_type + self._headers.update(headers or {}) + + @property + def size(self) -> Optional[int]: + """Size of the payload.""" + return self._size + + @property + def filename(self) -> Optional[str]: + """Filename of the payload.""" + return self._filename + + @property + def headers(self) -> _CIMultiDict: + """Custom item headers""" + return self._headers + + @property + def _binary_headers(self) -> bytes: + return ( + "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode( + "utf-8" + ) + + b"\r\n" + ) + + @property + def encoding(self) -> Optional[str]: + """Payload encoding""" + return self._encoding + + @property + def content_type(self) -> str: + """Content type""" + return self._headers[hdrs.CONTENT_TYPE] + + def set_content_disposition( + self, + disptype: str, + quote_fields: bool = True, + _charset: str = "utf-8", + **params: Any, + ) -> None: + """Sets ``Content-Disposition`` header.""" + self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header( + disptype, quote_fields=quote_fields, _charset=_charset, **params + ) + + @abstractmethod + async def write(self, writer: AbstractStreamWriter) -> None: + """Write payload. 
+ + writer is an AbstractStreamWriter instance: + """ + + +class BytesPayload(Payload): + def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None: + if not isinstance(value, (bytes, bytearray, memoryview)): + raise TypeError(f"value argument must be byte-ish, not {type(value)!r}") + + if "content_type" not in kwargs: + kwargs["content_type"] = "application/octet-stream" + + super().__init__(value, *args, **kwargs) + + if isinstance(value, memoryview): + self._size = value.nbytes + else: + self._size = len(value) + + if self._size > TOO_LARGE_BYTES_BODY: + kwargs = {"source": self} + warnings.warn( + "Sending a large body directly with raw bytes might" + " lock the event loop. You should probably pass an " + "io.BytesIO object instead", + ResourceWarning, + **kwargs, + ) + + async def write(self, writer: AbstractStreamWriter) -> None: + await writer.write(self._value) + + +class StringPayload(BytesPayload): + def __init__( + self, + value: str, + *args: Any, + encoding: Optional[str] = None, + content_type: Optional[str] = None, + **kwargs: Any, + ) -> None: + + if encoding is None: + if content_type is None: + real_encoding = "utf-8" + content_type = "text/plain; charset=utf-8" + else: + mimetype = parse_mimetype(content_type) + real_encoding = mimetype.parameters.get("charset", "utf-8") + else: + if content_type is None: + content_type = "text/plain; charset=%s" % encoding + real_encoding = encoding + + super().__init__( + value.encode(real_encoding), + encoding=real_encoding, + content_type=content_type, + *args, + **kwargs, + ) + + +class StringIOPayload(StringPayload): + def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None: + super().__init__(value.read(), *args, **kwargs) + + +class IOBasePayload(Payload): + _value: IO[Any] + + def __init__( + self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any + ) -> None: + if "filename" not in kwargs: + kwargs["filename"] = guess_filename(value) + + super().__init__(value, *args, **kwargs) + + if self._filename is not None and disposition is not None: + if hdrs.CONTENT_DISPOSITION not in self.headers: + self.set_content_disposition(disposition, filename=self._filename) + + async def write(self, writer: AbstractStreamWriter) -> None: + loop = asyncio.get_event_loop() + try: + chunk = await loop.run_in_executor(None, self._value.read, 2**16) + while chunk: + await writer.write(chunk) + chunk = await loop.run_in_executor(None, self._value.read, 2**16) + finally: + await loop.run_in_executor(None, self._value.close) + + +class TextIOPayload(IOBasePayload): + _value: TextIO + + def __init__( + self, + value: TextIO, + *args: Any, + encoding: Optional[str] = None, + content_type: Optional[str] = None, + **kwargs: Any, + ) -> None: + + if encoding is None: + if content_type is None: + encoding = "utf-8" + content_type = "text/plain; charset=utf-8" + else: + mimetype = parse_mimetype(content_type) + encoding = mimetype.parameters.get("charset", "utf-8") + else: + if content_type is None: + content_type = "text/plain; charset=%s" % encoding + + super().__init__( + value, + content_type=content_type, + encoding=encoding, + *args, + **kwargs, + ) + + @property + def size(self) -> Optional[int]: + try: + return os.fstat(self._value.fileno()).st_size - self._value.tell() + except OSError: + return None + + async def write(self, writer: AbstractStreamWriter) -> None: + loop = asyncio.get_event_loop() + try: + chunk = await loop.run_in_executor(None, self._value.read, 2**16) + while chunk: + data = ( 
+ chunk.encode(encoding=self._encoding) + if self._encoding + else chunk.encode() + ) + await writer.write(data) + chunk = await loop.run_in_executor(None, self._value.read, 2**16) + finally: + await loop.run_in_executor(None, self._value.close) + + +class BytesIOPayload(IOBasePayload): + @property + def size(self) -> int: + position = self._value.tell() + end = self._value.seek(0, os.SEEK_END) + self._value.seek(position) + return end - position + + +class BufferedReaderPayload(IOBasePayload): + @property + def size(self) -> Optional[int]: + try: + return os.fstat(self._value.fileno()).st_size - self._value.tell() + except OSError: + # data.fileno() is not supported, e.g. + # io.BufferedReader(io.BytesIO(b'data')) + return None + + +class JsonPayload(BytesPayload): + def __init__( + self, + value: Any, + encoding: str = "utf-8", + content_type: str = "application/json", + dumps: JSONEncoder = json.dumps, + *args: Any, + **kwargs: Any, + ) -> None: + + super().__init__( + dumps(value).encode(encoding), + content_type=content_type, + encoding=encoding, + *args, + **kwargs, + ) + + +if TYPE_CHECKING: + from typing import AsyncIterable, AsyncIterator + + _AsyncIterator = AsyncIterator[bytes] + _AsyncIterable = AsyncIterable[bytes] +else: + from collections.abc import AsyncIterable, AsyncIterator + + _AsyncIterator = AsyncIterator + _AsyncIterable = AsyncIterable + + +class AsyncIterablePayload(Payload): + + _iter: Optional[_AsyncIterator] = None + + def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None: + if not isinstance(value, AsyncIterable): + raise TypeError( + "value argument must support " + "collections.abc.AsyncIterable interface, " + "got {!r}".format(type(value)) + ) + + if "content_type" not in kwargs: + kwargs["content_type"] = "application/octet-stream" + + super().__init__(value, *args, **kwargs) + + self._iter = value.__aiter__() + + async def write(self, writer: AbstractStreamWriter) -> None: + if self._iter: + try: + # iter is not None check prevents rare cases + # when the case iterable is used twice + while True: + chunk = await self._iter.__anext__() + await writer.write(chunk) + except StopAsyncIteration: + self._iter = None + + +class StreamReaderPayload(AsyncIterablePayload): + def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None: + super().__init__(value.iter_any(), *args, **kwargs) + + +PAYLOAD_REGISTRY = PayloadRegistry() +PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview)) +PAYLOAD_REGISTRY.register(StringPayload, str) +PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO) +PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase) +PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO) +PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom)) +PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase) +PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader) +# try_last for giving a chance to more specialized async interables like +# multidict.BodyPartReaderPayload override the default +PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last) diff --git a/llm/Lib/site-packages/aiohttp/payload_streamer.py b/llm/Lib/site-packages/aiohttp/payload_streamer.py new file mode 100644 index 0000000000000000000000000000000000000000..534cda39f60aff26159df82f5d593c249d92e4f0 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/payload_streamer.py @@ -0,0 +1,75 @@ +""" +Payload implementation for coroutines as data provider. 
+ +As a simple case, you can upload data from file:: + + @aiohttp.streamer + async def file_sender(writer, file_name=None): + with open(file_name, 'rb') as f: + chunk = f.read(2**16) + while chunk: + await writer.write(chunk) + + chunk = f.read(2**16) + +Then you can use `file_sender` like this: + + async with session.post('http://httpbin.org/post', + data=file_sender(file_name='huge_file')) as resp: + print(await resp.text()) + +..note:: Coroutine must accept `writer` as first argument + +""" + +import types +import warnings +from typing import Any, Awaitable, Callable, Dict, Tuple + +from .abc import AbstractStreamWriter +from .payload import Payload, payload_type + +__all__ = ("streamer",) + + +class _stream_wrapper: + def __init__( + self, + coro: Callable[..., Awaitable[None]], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + ) -> None: + self.coro = types.coroutine(coro) + self.args = args + self.kwargs = kwargs + + async def __call__(self, writer: AbstractStreamWriter) -> None: + await self.coro(writer, *self.args, **self.kwargs) + + +class streamer: + def __init__(self, coro: Callable[..., Awaitable[None]]) -> None: + warnings.warn( + "@streamer is deprecated, use async generators instead", + DeprecationWarning, + stacklevel=2, + ) + self.coro = coro + + def __call__(self, *args: Any, **kwargs: Any) -> _stream_wrapper: + return _stream_wrapper(self.coro, args, kwargs) + + +@payload_type(_stream_wrapper) +class StreamWrapperPayload(Payload): + async def write(self, writer: AbstractStreamWriter) -> None: + await self._value(writer) + + +@payload_type(streamer) +class StreamPayload(StreamWrapperPayload): + def __init__(self, value: Any, *args: Any, **kwargs: Any) -> None: + super().__init__(value(), *args, **kwargs) + + async def write(self, writer: AbstractStreamWriter) -> None: + await self._value(writer) diff --git a/llm/Lib/site-packages/aiohttp/py.typed b/llm/Lib/site-packages/aiohttp/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..1d9f9a52a2f18944337009f77ba0d3cee2752144 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/py.typed @@ -0,0 +1 @@ +Marker diff --git a/llm/Lib/site-packages/aiohttp/pytest_plugin.py b/llm/Lib/site-packages/aiohttp/pytest_plugin.py new file mode 100644 index 0000000000000000000000000000000000000000..cf83f832532a782bf6b010c0a75668fac5154c89 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/pytest_plugin.py @@ -0,0 +1,381 @@ +import asyncio +import contextlib +import warnings +from typing import Any, Awaitable, Callable, Dict, Iterator, Optional, Type, Union + +import pytest + +from aiohttp.helpers import isasyncgenfunction +from aiohttp.web import Application + +from .test_utils import ( + BaseTestServer, + RawTestServer, + TestClient, + TestServer, + loop_context, + setup_test_loop, + teardown_test_loop, + unused_port as _unused_port, +) + +try: + import uvloop +except ImportError: # pragma: no cover + uvloop = None # type: ignore[assignment] + +AiohttpClient = Callable[[Union[Application, BaseTestServer]], Awaitable[TestClient]] +AiohttpRawServer = Callable[[Application], Awaitable[RawTestServer]] +AiohttpServer = Callable[[Application], Awaitable[TestServer]] + + +def pytest_addoption(parser): # type: ignore[no-untyped-def] + parser.addoption( + "--aiohttp-fast", + action="store_true", + default=False, + help="run tests faster by disabling extra checks", + ) + parser.addoption( + "--aiohttp-loop", + action="store", + default="pyloop", + help="run tests with specific loop: pyloop, uvloop or all", + ) + 
parser.addoption( + "--aiohttp-enable-loop-debug", + action="store_true", + default=False, + help="enable event loop debug mode", + ) + + +def pytest_fixture_setup(fixturedef): # type: ignore[no-untyped-def] + """Set up pytest fixture. + + Allow fixtures to be coroutines. Run coroutine fixtures in an event loop. + """ + func = fixturedef.func + + if isasyncgenfunction(func): + # async generator fixture + is_async_gen = True + elif asyncio.iscoroutinefunction(func): + # regular async fixture + is_async_gen = False + else: + # not an async fixture, nothing to do + return + + strip_request = False + if "request" not in fixturedef.argnames: + fixturedef.argnames += ("request",) + strip_request = True + + def wrapper(*args, **kwargs): # type: ignore[no-untyped-def] + request = kwargs["request"] + if strip_request: + del kwargs["request"] + + # if neither the fixture nor the test use the 'loop' fixture, + # 'getfixturevalue' will fail because the test is not parameterized + # (this can be removed someday if 'loop' is no longer parameterized) + if "loop" not in request.fixturenames: + raise Exception( + "Asynchronous fixtures must depend on the 'loop' fixture or " + "be used in tests depending from it." + ) + + _loop = request.getfixturevalue("loop") + + if is_async_gen: + # for async generators, we need to advance the generator once, + # then advance it again in a finalizer + gen = func(*args, **kwargs) + + def finalizer(): # type: ignore[no-untyped-def] + try: + return _loop.run_until_complete(gen.__anext__()) + except StopAsyncIteration: + pass + + request.addfinalizer(finalizer) + return _loop.run_until_complete(gen.__anext__()) + else: + return _loop.run_until_complete(func(*args, **kwargs)) + + fixturedef.func = wrapper + + +@pytest.fixture +def fast(request): # type: ignore[no-untyped-def] + """--fast config option""" + return request.config.getoption("--aiohttp-fast") + + +@pytest.fixture +def loop_debug(request): # type: ignore[no-untyped-def] + """--enable-loop-debug config option""" + return request.config.getoption("--aiohttp-enable-loop-debug") + + +@contextlib.contextmanager +def _runtime_warning_context(): # type: ignore[no-untyped-def] + """Context manager which checks for RuntimeWarnings. + + This exists specifically to + avoid "coroutine 'X' was never awaited" warnings being missed. + + If RuntimeWarnings occur in the context a RuntimeError is raised. + """ + with warnings.catch_warnings(record=True) as _warnings: + yield + rw = [ + "{w.filename}:{w.lineno}:{w.message}".format(w=w) + for w in _warnings + if w.category == RuntimeWarning + ] + if rw: + raise RuntimeError( + "{} Runtime Warning{},\n{}".format( + len(rw), "" if len(rw) == 1 else "s", "\n".join(rw) + ) + ) + + +@contextlib.contextmanager +def _passthrough_loop_context(loop, fast=False): # type: ignore[no-untyped-def] + """Passthrough loop context. + + Sets up and tears down a loop unless one is passed in via the loop + argument when it's passed straight through. 
+ """ + if loop: + # loop already exists, pass it straight through + yield loop + else: + # this shadows loop_context's standard behavior + loop = setup_test_loop() + yield loop + teardown_test_loop(loop, fast=fast) + + +def pytest_pycollect_makeitem(collector, name, obj): # type: ignore[no-untyped-def] + """Fix pytest collecting for coroutines.""" + if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj): + return list(collector._genfunctions(name, obj)) + + +def pytest_pyfunc_call(pyfuncitem): # type: ignore[no-untyped-def] + """Run coroutines in an event loop instead of a normal function call.""" + fast = pyfuncitem.config.getoption("--aiohttp-fast") + if asyncio.iscoroutinefunction(pyfuncitem.function): + existing_loop = pyfuncitem.funcargs.get( + "proactor_loop" + ) or pyfuncitem.funcargs.get("loop", None) + with _runtime_warning_context(): + with _passthrough_loop_context(existing_loop, fast=fast) as _loop: + testargs = { + arg: pyfuncitem.funcargs[arg] + for arg in pyfuncitem._fixtureinfo.argnames + } + _loop.run_until_complete(pyfuncitem.obj(**testargs)) + + return True + + +def pytest_generate_tests(metafunc): # type: ignore[no-untyped-def] + if "loop_factory" not in metafunc.fixturenames: + return + + loops = metafunc.config.option.aiohttp_loop + avail_factories: Dict[str, Type[asyncio.AbstractEventLoopPolicy]] + avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy} + + if uvloop is not None: # pragma: no cover + avail_factories["uvloop"] = uvloop.EventLoopPolicy + + if loops == "all": + loops = "pyloop,uvloop?" + + factories = {} # type: ignore[var-annotated] + for name in loops.split(","): + required = not name.endswith("?") + name = name.strip(" ?") + if name not in avail_factories: # pragma: no cover + if required: + raise ValueError( + "Unknown loop '%s', available loops: %s" + % (name, list(factories.keys())) + ) + else: + continue + factories[name] = avail_factories[name] + metafunc.parametrize( + "loop_factory", list(factories.values()), ids=list(factories.keys()) + ) + + +@pytest.fixture +def loop(loop_factory, fast, loop_debug): # type: ignore[no-untyped-def] + """Return an instance of the event loop.""" + policy = loop_factory() + asyncio.set_event_loop_policy(policy) + with loop_context(fast=fast) as _loop: + if loop_debug: + _loop.set_debug(True) # pragma: no cover + asyncio.set_event_loop(_loop) + yield _loop + + +@pytest.fixture +def proactor_loop(): # type: ignore[no-untyped-def] + policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore[attr-defined] + asyncio.set_event_loop_policy(policy) + + with loop_context(policy.new_event_loop) as _loop: + asyncio.set_event_loop(_loop) + yield _loop + + +@pytest.fixture +def unused_port(aiohttp_unused_port: Callable[[], int]) -> Callable[[], int]: + warnings.warn( + "Deprecated, use aiohttp_unused_port fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_unused_port + + +@pytest.fixture +def aiohttp_unused_port() -> Callable[[], int]: + """Return a port that is unused on the current host.""" + return _unused_port + + +@pytest.fixture +def aiohttp_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpServer]: + """Factory to create a TestServer instance, given an app. 
+ + aiohttp_server(app, **kwargs) + """ + servers = [] + + async def go(app, *, port=None, **kwargs): # type: ignore[no-untyped-def] + server = TestServer(app, port=port) + await server.start_server(loop=loop, **kwargs) + servers.append(server) + return server + + yield go + + async def finalize() -> None: + while servers: + await servers.pop().close() + + loop.run_until_complete(finalize()) + + +@pytest.fixture +def test_server(aiohttp_server): # type: ignore[no-untyped-def] # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_server fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_server + + +@pytest.fixture +def aiohttp_raw_server(loop: asyncio.AbstractEventLoop) -> Iterator[AiohttpRawServer]: + """Factory to create a RawTestServer instance, given a web handler. + + aiohttp_raw_server(handler, **kwargs) + """ + servers = [] + + async def go(handler, *, port=None, **kwargs): # type: ignore[no-untyped-def] + server = RawTestServer(handler, port=port) + await server.start_server(loop=loop, **kwargs) + servers.append(server) + return server + + yield go + + async def finalize() -> None: + while servers: + await servers.pop().close() + + loop.run_until_complete(finalize()) + + +@pytest.fixture +def raw_test_server( # type: ignore[no-untyped-def] # pragma: no cover + aiohttp_raw_server, +): + warnings.warn( + "Deprecated, use aiohttp_raw_server fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_raw_server + + +@pytest.fixture +def aiohttp_client( + loop: asyncio.AbstractEventLoop, +) -> Iterator[AiohttpClient]: + """Factory to create a TestClient instance. + + aiohttp_client(app, **kwargs) + aiohttp_client(server, **kwargs) + aiohttp_client(raw_server, **kwargs) + """ + clients = [] + + async def go( + __param: Union[Application, BaseTestServer], + *args: Any, + server_kwargs: Optional[Dict[str, Any]] = None, + **kwargs: Any + ) -> TestClient: + + if isinstance(__param, Callable) and not isinstance( # type: ignore[arg-type] + __param, (Application, BaseTestServer) + ): + __param = __param(loop, *args, **kwargs) + kwargs = {} + else: + assert not args, "args should be empty" + + if isinstance(__param, Application): + server_kwargs = server_kwargs or {} + server = TestServer(__param, loop=loop, **server_kwargs) + client = TestClient(server, loop=loop, **kwargs) + elif isinstance(__param, BaseTestServer): + client = TestClient(__param, loop=loop, **kwargs) + else: + raise ValueError("Unknown argument type: %r" % type(__param)) + + await client.start_server() + clients.append(client) + return client + + yield go + + async def finalize() -> None: + while clients: + await clients.pop().close() + + loop.run_until_complete(finalize()) + + +@pytest.fixture +def test_client(aiohttp_client): # type: ignore[no-untyped-def] # pragma: no cover + warnings.warn( + "Deprecated, use aiohttp_client fixture instead", + DeprecationWarning, + stacklevel=2, + ) + return aiohttp_client diff --git a/llm/Lib/site-packages/aiohttp/resolver.py b/llm/Lib/site-packages/aiohttp/resolver.py new file mode 100644 index 0000000000000000000000000000000000000000..a1e166134c7985c590df592f90bbd4a56171376b --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/resolver.py @@ -0,0 +1,160 @@ +import asyncio +import socket +from typing import Any, Dict, List, Optional, Type, Union + +from .abc import AbstractResolver +from .helpers import get_running_loop + +__all__ = ("ThreadedResolver", "AsyncResolver", "DefaultResolver") + +try: + import aiodns + + # aiodns_default = 
hasattr(aiodns.DNSResolver, 'gethostbyname') +except ImportError: # pragma: no cover + aiodns = None + +aiodns_default = False + + +class ThreadedResolver(AbstractResolver): + """Threaded resolver. + + Uses an Executor for synchronous getaddrinfo() calls. + concurrent.futures.ThreadPoolExecutor is used by default. + """ + + def __init__(self, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + self._loop = get_running_loop(loop) + + async def resolve( + self, hostname: str, port: int = 0, family: int = socket.AF_INET + ) -> List[Dict[str, Any]]: + infos = await self._loop.getaddrinfo( + hostname, + port, + type=socket.SOCK_STREAM, + family=family, + flags=socket.AI_ADDRCONFIG, + ) + + hosts = [] + for family, _, proto, _, address in infos: + if family == socket.AF_INET6: + if len(address) < 3: + # IPv6 is not supported by Python build, + # or IPv6 is not enabled in the host + continue + if address[3]: + # This is essential for link-local IPv6 addresses. + # LL IPv6 is a VERY rare case. Strictly speaking, we should use + # getnameinfo() unconditionally, but performance makes sense. + host, _port = socket.getnameinfo( + address, socket.NI_NUMERICHOST | socket.NI_NUMERICSERV + ) + port = int(_port) + else: + host, port = address[:2] + else: # IPv4 + assert family == socket.AF_INET + host, port = address # type: ignore[misc] + hosts.append( + { + "hostname": hostname, + "host": host, + "port": port, + "family": family, + "proto": proto, + "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV, + } + ) + + return hosts + + async def close(self) -> None: + pass + + +class AsyncResolver(AbstractResolver): + """Use the `aiodns` package to make asynchronous DNS lookups""" + + def __init__( + self, + loop: Optional[asyncio.AbstractEventLoop] = None, + *args: Any, + **kwargs: Any + ) -> None: + if aiodns is None: + raise RuntimeError("Resolver requires aiodns library") + + self._loop = get_running_loop(loop) + self._resolver = aiodns.DNSResolver(*args, loop=loop, **kwargs) + + if not hasattr(self._resolver, "gethostbyname"): + # aiodns 1.1 is not available, fallback to DNSResolver.query + self.resolve = self._resolve_with_query # type: ignore + + async def resolve( + self, host: str, port: int = 0, family: int = socket.AF_INET + ) -> List[Dict[str, Any]]: + try: + resp = await self._resolver.gethostbyname(host, family) + except aiodns.error.DNSError as exc: + msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed" + raise OSError(msg) from exc + hosts = [] + for address in resp.addresses: + hosts.append( + { + "hostname": host, + "host": address, + "port": port, + "family": family, + "proto": 0, + "flags": socket.AI_NUMERICHOST | socket.AI_NUMERICSERV, + } + ) + + if not hosts: + raise OSError("DNS lookup failed") + + return hosts + + async def _resolve_with_query( + self, host: str, port: int = 0, family: int = socket.AF_INET + ) -> List[Dict[str, Any]]: + if family == socket.AF_INET6: + qtype = "AAAA" + else: + qtype = "A" + + try: + resp = await self._resolver.query(host, qtype) + except aiodns.error.DNSError as exc: + msg = exc.args[1] if len(exc.args) >= 1 else "DNS lookup failed" + raise OSError(msg) from exc + + hosts = [] + for rr in resp: + hosts.append( + { + "hostname": host, + "host": rr.host, + "port": port, + "family": family, + "proto": 0, + "flags": socket.AI_NUMERICHOST, + } + ) + + if not hosts: + raise OSError("DNS lookup failed") + + return hosts + + async def close(self) -> None: + self._resolver.cancel() + + +_DefaultType = Type[Union[AsyncResolver, 
ThreadedResolver]] +DefaultResolver: _DefaultType = AsyncResolver if aiodns_default else ThreadedResolver diff --git a/llm/Lib/site-packages/aiohttp/streams.py b/llm/Lib/site-packages/aiohttp/streams.py new file mode 100644 index 0000000000000000000000000000000000000000..df9e94405f37af31f67e822d6a11f35b23464568 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/streams.py @@ -0,0 +1,684 @@ +import asyncio +import collections +import warnings +from typing import ( + Awaitable, + Callable, + Deque, + Final, + Generic, + List, + Optional, + Tuple, + TypeVar, +) + +from .base_protocol import BaseProtocol +from .helpers import ( + _EXC_SENTINEL, + BaseTimerContext, + TimerNoop, + set_exception, + set_result, +) +from .log import internal_logger + +__all__ = ( + "EMPTY_PAYLOAD", + "EofStream", + "StreamReader", + "DataQueue", + "FlowControlDataQueue", +) + +_T = TypeVar("_T") + + +class EofStream(Exception): + """eof stream indication.""" + + +class AsyncStreamIterator(Generic[_T]): + def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None: + self.read_func = read_func + + def __aiter__(self) -> "AsyncStreamIterator[_T]": + return self + + async def __anext__(self) -> _T: + try: + rv = await self.read_func() + except EofStream: + raise StopAsyncIteration + if rv == b"": + raise StopAsyncIteration + return rv + + +class ChunkTupleAsyncStreamIterator: + def __init__(self, stream: "StreamReader") -> None: + self._stream = stream + + def __aiter__(self) -> "ChunkTupleAsyncStreamIterator": + return self + + async def __anext__(self) -> Tuple[bytes, bool]: + rv = await self._stream.readchunk() + if rv == (b"", False): + raise StopAsyncIteration + return rv + + +class AsyncStreamReaderMixin: + def __aiter__(self) -> AsyncStreamIterator[bytes]: + return AsyncStreamIterator(self.readline) # type: ignore[attr-defined] + + def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]: + """Returns an asynchronous iterator that yields chunks of size n.""" + return AsyncStreamIterator(lambda: self.read(n)) # type: ignore[attr-defined] + + def iter_any(self) -> AsyncStreamIterator[bytes]: + """Yield all available data as soon as it is received.""" + return AsyncStreamIterator(self.readany) # type: ignore[attr-defined] + + def iter_chunks(self) -> ChunkTupleAsyncStreamIterator: + """Yield chunks of data as they are received by the server. + + The yielded objects are tuples + of (bytes, bool) as returned by the StreamReader.readchunk method. + """ + return ChunkTupleAsyncStreamIterator(self) # type: ignore[arg-type] + + +class StreamReader(AsyncStreamReaderMixin): + """An enhancement of asyncio.StreamReader. + + Supports asynchronous iteration by line, chunk or as available:: + + async for line in reader: + ... + async for chunk in reader.iter_chunked(1024): + ... + async for slice in reader.iter_any(): + ... 
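+        async for data, end_of_http_chunk in reader.iter_chunks():
+            ...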
+ + """ + + total_bytes = 0 + + def __init__( + self, + protocol: BaseProtocol, + limit: int, + *, + timer: Optional[BaseTimerContext] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + self._protocol = protocol + self._low_water = limit + self._high_water = limit * 2 + if loop is None: + loop = asyncio.get_event_loop() + self._loop = loop + self._size = 0 + self._cursor = 0 + self._http_chunk_splits: Optional[List[int]] = None + self._buffer: Deque[bytes] = collections.deque() + self._buffer_offset = 0 + self._eof = False + self._waiter: Optional[asyncio.Future[None]] = None + self._eof_waiter: Optional[asyncio.Future[None]] = None + self._exception: Optional[BaseException] = None + self._timer = TimerNoop() if timer is None else timer + self._eof_callbacks: List[Callable[[], None]] = [] + + def __repr__(self) -> str: + info = [self.__class__.__name__] + if self._size: + info.append("%d bytes" % self._size) + if self._eof: + info.append("eof") + if self._low_water != 2**16: # default limit + info.append("low=%d high=%d" % (self._low_water, self._high_water)) + if self._waiter: + info.append("w=%r" % self._waiter) + if self._exception: + info.append("e=%r" % self._exception) + return "<%s>" % " ".join(info) + + def get_read_buffer_limits(self) -> Tuple[int, int]: + return (self._low_water, self._high_water) + + def exception(self) -> Optional[BaseException]: + return self._exception + + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: + self._exception = exc + self._eof_callbacks.clear() + + waiter = self._waiter + if waiter is not None: + self._waiter = None + set_exception(waiter, exc, exc_cause) + + waiter = self._eof_waiter + if waiter is not None: + self._eof_waiter = None + set_exception(waiter, exc, exc_cause) + + def on_eof(self, callback: Callable[[], None]) -> None: + if self._eof: + try: + callback() + except Exception: + internal_logger.exception("Exception in eof callback") + else: + self._eof_callbacks.append(callback) + + def feed_eof(self) -> None: + self._eof = True + + waiter = self._waiter + if waiter is not None: + self._waiter = None + set_result(waiter, None) + + waiter = self._eof_waiter + if waiter is not None: + self._eof_waiter = None + set_result(waiter, None) + + for cb in self._eof_callbacks: + try: + cb() + except Exception: + internal_logger.exception("Exception in eof callback") + + self._eof_callbacks.clear() + + def is_eof(self) -> bool: + """Return True if 'feed_eof' was called.""" + return self._eof + + def at_eof(self) -> bool: + """Return True if the buffer is empty and 'feed_eof' was called.""" + return self._eof and not self._buffer + + async def wait_eof(self) -> None: + if self._eof: + return + + assert self._eof_waiter is None + self._eof_waiter = self._loop.create_future() + try: + await self._eof_waiter + finally: + self._eof_waiter = None + + def unread_data(self, data: bytes) -> None: + """rollback reading some data from stream, inserting it to buffer head.""" + warnings.warn( + "unread_data() is deprecated " + "and will be removed in future releases (#3260)", + DeprecationWarning, + stacklevel=2, + ) + if not data: + return + + if self._buffer_offset: + self._buffer[0] = self._buffer[0][self._buffer_offset :] + self._buffer_offset = 0 + self._size += len(data) + self._cursor -= len(data) + self._buffer.appendleft(data) + self._eof_counter = 0 + + # TODO: size is ignored, remove the param later + def feed_data(self, data: bytes, size: int = 0) -> None: + assert 
not self._eof, "feed_data after feed_eof" + + if not data: + return + + self._size += len(data) + self._buffer.append(data) + self.total_bytes += len(data) + + waiter = self._waiter + if waiter is not None: + self._waiter = None + set_result(waiter, None) + + if self._size > self._high_water and not self._protocol._reading_paused: + self._protocol.pause_reading() + + def begin_http_chunk_receiving(self) -> None: + if self._http_chunk_splits is None: + if self.total_bytes: + raise RuntimeError( + "Called begin_http_chunk_receiving when" "some data was already fed" + ) + self._http_chunk_splits = [] + + def end_http_chunk_receiving(self) -> None: + if self._http_chunk_splits is None: + raise RuntimeError( + "Called end_chunk_receiving without calling " + "begin_chunk_receiving first" + ) + + # self._http_chunk_splits contains logical byte offsets from start of + # the body transfer. Each offset is the offset of the end of a chunk. + # "Logical" means bytes, accessible for a user. + # If no chunks containing logical data were received, current position + # is difinitely zero. + pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0 + + if self.total_bytes == pos: + # We should not add empty chunks here. So we check for that. + # Note, when chunked + gzip is used, we can receive a chunk + # of compressed data, but that data may not be enough for gzip FSM + # to yield any uncompressed data. That's why current position may + # not change after receiving a chunk. + return + + self._http_chunk_splits.append(self.total_bytes) + + # wake up readchunk when end of http chunk received + waiter = self._waiter + if waiter is not None: + self._waiter = None + set_result(waiter, None) + + async def _wait(self, func_name: str) -> None: + # StreamReader uses a future to link the protocol feed_data() method + # to a read coroutine. Running two read coroutines at the same time + # would have an unexpected behaviour. It would not possible to know + # which coroutine would get the next data. + if self._waiter is not None: + raise RuntimeError( + "%s() called while another coroutine is " + "already waiting for incoming data" % func_name + ) + + waiter = self._waiter = self._loop.create_future() + try: + with self._timer: + await waiter + finally: + self._waiter = None + + async def readline(self) -> bytes: + return await self.readuntil() + + async def readuntil(self, separator: bytes = b"\n") -> bytes: + seplen = len(separator) + if seplen == 0: + raise ValueError("Separator should be at least one-byte string") + + if self._exception is not None: + raise self._exception + + chunk = b"" + chunk_size = 0 + not_enough = True + + while not_enough: + while self._buffer and not_enough: + offset = self._buffer_offset + ichar = self._buffer[0].find(separator, offset) + 1 + # Read from current offset to found separator or to the end. + data = self._read_nowait_chunk( + ichar - offset + seplen - 1 if ichar else -1 + ) + chunk += data + chunk_size += len(data) + if ichar: + not_enough = False + + if chunk_size > self._high_water: + raise ValueError("Chunk too big") + + if self._eof: + break + + if not_enough: + await self._wait("readuntil") + + return chunk + + async def read(self, n: int = -1) -> bytes: + if self._exception is not None: + raise self._exception + + # migration problem; with DataQueue you have to catch + # EofStream exception, so common way is to run payload.read() inside + # infinite loop. what can cause real infinite loop with StreamReader + # lets keep this code one major release. 
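+        # The __debug__ block below counts repeated reads on an exhausted
+        # stream (EOF reached, buffer empty) and logs a warning once that
+        # happens more than a few times, since it usually means a caller is
+        # looping on read() without checking for b"".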
+ if __debug__: + if self._eof and not self._buffer: + self._eof_counter = getattr(self, "_eof_counter", 0) + 1 + if self._eof_counter > 5: + internal_logger.warning( + "Multiple access to StreamReader in eof state, " + "might be infinite loop.", + stack_info=True, + ) + + if not n: + return b"" + + if n < 0: + # This used to just loop creating a new waiter hoping to + # collect everything in self._buffer, but that would + # deadlock if the subprocess sends more than self.limit + # bytes. So just call self.readany() until EOF. + blocks = [] + while True: + block = await self.readany() + if not block: + break + blocks.append(block) + return b"".join(blocks) + + # TODO: should be `if` instead of `while` + # because waiter maybe triggered on chunk end, + # without feeding any data + while not self._buffer and not self._eof: + await self._wait("read") + + return self._read_nowait(n) + + async def readany(self) -> bytes: + if self._exception is not None: + raise self._exception + + # TODO: should be `if` instead of `while` + # because waiter maybe triggered on chunk end, + # without feeding any data + while not self._buffer and not self._eof: + await self._wait("readany") + + return self._read_nowait(-1) + + async def readchunk(self) -> Tuple[bytes, bool]: + """Returns a tuple of (data, end_of_http_chunk). + + When chunked transfer + encoding is used, end_of_http_chunk is a boolean indicating if the end + of the data corresponds to the end of a HTTP chunk , otherwise it is + always False. + """ + while True: + if self._exception is not None: + raise self._exception + + while self._http_chunk_splits: + pos = self._http_chunk_splits.pop(0) + if pos == self._cursor: + return (b"", True) + if pos > self._cursor: + return (self._read_nowait(pos - self._cursor), True) + internal_logger.warning( + "Skipping HTTP chunk end due to data " + "consumption beyond chunk boundary" + ) + + if self._buffer: + return (self._read_nowait_chunk(-1), False) + # return (self._read_nowait(-1), False) + + if self._eof: + # Special case for signifying EOF. + # (b'', True) is not a final return value actually. + return (b"", False) + + await self._wait("readchunk") + + async def readexactly(self, n: int) -> bytes: + if self._exception is not None: + raise self._exception + + blocks: List[bytes] = [] + while n > 0: + block = await self.read(n) + if not block: + partial = b"".join(blocks) + raise asyncio.IncompleteReadError(partial, len(partial) + n) + blocks.append(block) + n -= len(block) + + return b"".join(blocks) + + def read_nowait(self, n: int = -1) -> bytes: + # default was changed to be consistent with .read(-1) + # + # I believe the most users don't know about the method and + # they are not affected. + if self._exception is not None: + raise self._exception + + if self._waiter and not self._waiter.done(): + raise RuntimeError( + "Called while some coroutine is waiting for incoming data." 
+ ) + + return self._read_nowait(n) + + def _read_nowait_chunk(self, n: int) -> bytes: + first_buffer = self._buffer[0] + offset = self._buffer_offset + if n != -1 and len(first_buffer) - offset > n: + data = first_buffer[offset : offset + n] + self._buffer_offset += n + + elif offset: + self._buffer.popleft() + data = first_buffer[offset:] + self._buffer_offset = 0 + + else: + data = self._buffer.popleft() + + self._size -= len(data) + self._cursor += len(data) + + chunk_splits = self._http_chunk_splits + # Prevent memory leak: drop useless chunk splits + while chunk_splits and chunk_splits[0] < self._cursor: + chunk_splits.pop(0) + + if self._size < self._low_water and self._protocol._reading_paused: + self._protocol.resume_reading() + return data + + def _read_nowait(self, n: int) -> bytes: + """Read not more than n bytes, or whole buffer if n == -1""" + self._timer.assert_timeout() + + chunks = [] + while self._buffer: + chunk = self._read_nowait_chunk(n) + chunks.append(chunk) + if n != -1: + n -= len(chunk) + if n == 0: + break + + return b"".join(chunks) if chunks else b"" + + +class EmptyStreamReader(StreamReader): # lgtm [py/missing-call-to-init] + def __init__(self) -> None: + self._read_eof_chunk = False + + def __repr__(self) -> str: + return "<%s>" % self.__class__.__name__ + + def exception(self) -> Optional[BaseException]: + return None + + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: + pass + + def on_eof(self, callback: Callable[[], None]) -> None: + try: + callback() + except Exception: + internal_logger.exception("Exception in eof callback") + + def feed_eof(self) -> None: + pass + + def is_eof(self) -> bool: + return True + + def at_eof(self) -> bool: + return True + + async def wait_eof(self) -> None: + return + + def feed_data(self, data: bytes, n: int = 0) -> None: + pass + + async def readline(self) -> bytes: + return b"" + + async def read(self, n: int = -1) -> bytes: + return b"" + + # TODO add async def readuntil + + async def readany(self) -> bytes: + return b"" + + async def readchunk(self) -> Tuple[bytes, bool]: + if not self._read_eof_chunk: + self._read_eof_chunk = True + return (b"", False) + + return (b"", True) + + async def readexactly(self, n: int) -> bytes: + raise asyncio.IncompleteReadError(b"", n) + + def read_nowait(self, n: int = -1) -> bytes: + return b"" + + +EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader() + + +class DataQueue(Generic[_T]): + """DataQueue is a general-purpose blocking queue with one reader.""" + + def __init__(self, loop: asyncio.AbstractEventLoop) -> None: + self._loop = loop + self._eof = False + self._waiter: Optional[asyncio.Future[None]] = None + self._exception: Optional[BaseException] = None + self._size = 0 + self._buffer: Deque[Tuple[_T, int]] = collections.deque() + + def __len__(self) -> int: + return len(self._buffer) + + def is_eof(self) -> bool: + return self._eof + + def at_eof(self) -> bool: + return self._eof and not self._buffer + + def exception(self) -> Optional[BaseException]: + return self._exception + + def set_exception( + self, + exc: BaseException, + exc_cause: BaseException = _EXC_SENTINEL, + ) -> None: + self._eof = True + self._exception = exc + + waiter = self._waiter + if waiter is not None: + self._waiter = None + set_exception(waiter, exc, exc_cause) + + def feed_data(self, data: _T, size: int = 0) -> None: + self._size += size + self._buffer.append((data, size)) + + waiter = self._waiter + if waiter is not None: + self._waiter 
= None + set_result(waiter, None) + + def feed_eof(self) -> None: + self._eof = True + + waiter = self._waiter + if waiter is not None: + self._waiter = None + set_result(waiter, None) + + async def read(self) -> _T: + if not self._buffer and not self._eof: + assert not self._waiter + self._waiter = self._loop.create_future() + try: + await self._waiter + except (asyncio.CancelledError, asyncio.TimeoutError): + self._waiter = None + raise + + if self._buffer: + data, size = self._buffer.popleft() + self._size -= size + return data + else: + if self._exception is not None: + raise self._exception + else: + raise EofStream + + def __aiter__(self) -> AsyncStreamIterator[_T]: + return AsyncStreamIterator(self.read) + + +class FlowControlDataQueue(DataQueue[_T]): + """FlowControlDataQueue resumes and pauses an underlying stream. + + It is a destination for parsed data. + """ + + def __init__( + self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop + ) -> None: + super().__init__(loop=loop) + + self._protocol = protocol + self._limit = limit * 2 + + def feed_data(self, data: _T, size: int = 0) -> None: + super().feed_data(data, size) + + if self._size > self._limit and not self._protocol._reading_paused: + self._protocol.pause_reading() + + async def read(self) -> _T: + try: + return await super().read() + finally: + if self._size < self._limit and self._protocol._reading_paused: + self._protocol.resume_reading() diff --git a/llm/Lib/site-packages/aiohttp/tcp_helpers.py b/llm/Lib/site-packages/aiohttp/tcp_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..8d39ba9e3c818281d45f8525d0cdf112871e2017 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/tcp_helpers.py @@ -0,0 +1,37 @@ +"""Helper methods to tune a TCP connection""" + +import asyncio +import socket +from contextlib import suppress +from typing import Optional # noqa + +__all__ = ("tcp_keepalive", "tcp_nodelay") + + +if hasattr(socket, "SO_KEEPALIVE"): + + def tcp_keepalive(transport: asyncio.Transport) -> None: + sock = transport.get_extra_info("socket") + if sock is not None: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + +else: + + def tcp_keepalive(transport: asyncio.Transport) -> None: # pragma: no cover + pass + + +def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None: + sock = transport.get_extra_info("socket") + + if sock is None: + return + + if sock.family not in (socket.AF_INET, socket.AF_INET6): + return + + value = bool(value) + + # socket may be closed already, on windows OSError get raised + with suppress(OSError): + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value) diff --git a/llm/Lib/site-packages/aiohttp/test_utils.py b/llm/Lib/site-packages/aiohttp/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1992253a6a233d4d87971527bceaae593484a3b5 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/test_utils.py @@ -0,0 +1,682 @@ +"""Utilities shared by tests.""" + +import asyncio +import contextlib +import gc +import inspect +import ipaddress +import os +import socket +import sys +import warnings +from abc import ABC, abstractmethod +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + List, + Optional, + Type, + Union, + cast, +) +from unittest import IsolatedAsyncioTestCase, mock + +from aiosignal import Signal +from multidict import CIMultiDict, CIMultiDictProxy +from yarl import URL + +import aiohttp +from aiohttp.client import _RequestContextManager, 
_WSRequestContextManager + +from . import ClientSession, hdrs +from .abc import AbstractCookieJar +from .client_reqrep import ClientResponse +from .client_ws import ClientWebSocketResponse +from .helpers import sentinel +from .http import HttpVersion, RawRequestMessage +from .typedefs import StrOrURL +from .web import ( + Application, + AppRunner, + BaseRunner, + Request, + Server, + ServerRunner, + SockSite, + UrlMappingMatchInfo, +) +from .web_protocol import _RequestHandler + +if TYPE_CHECKING: + from ssl import SSLContext +else: + SSLContext = None + +REUSE_ADDRESS = os.name == "posix" and sys.platform != "cygwin" + + +def get_unused_port_socket( + host: str, family: socket.AddressFamily = socket.AF_INET +) -> socket.socket: + return get_port_socket(host, 0, family) + + +def get_port_socket( + host: str, port: int, family: socket.AddressFamily +) -> socket.socket: + s = socket.socket(family, socket.SOCK_STREAM) + if REUSE_ADDRESS: + # Windows has different semantics for SO_REUSEADDR, + # so don't set it. Ref: + # https://docs.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind((host, port)) + return s + + +def unused_port() -> int: + """Return a port that is unused on the current host.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return cast(int, s.getsockname()[1]) + + +class BaseTestServer(ABC): + __test__ = False + + def __init__( + self, + *, + scheme: Union[str, object] = sentinel, + loop: Optional[asyncio.AbstractEventLoop] = None, + host: str = "127.0.0.1", + port: Optional[int] = None, + skip_url_asserts: bool = False, + socket_factory: Callable[ + [str, int, socket.AddressFamily], socket.socket + ] = get_port_socket, + **kwargs: Any, + ) -> None: + self._loop = loop + self.runner: Optional[BaseRunner] = None + self._root: Optional[URL] = None + self.host = host + self.port = port + self._closed = False + self.scheme = scheme + self.skip_url_asserts = skip_url_asserts + self.socket_factory = socket_factory + + async def start_server( + self, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any + ) -> None: + if self.runner: + return + self._loop = loop + self._ssl = kwargs.pop("ssl", None) + self.runner = await self._make_runner(handler_cancellation=True, **kwargs) + await self.runner.setup() + if not self.port: + self.port = 0 + try: + version = ipaddress.ip_address(self.host).version + except ValueError: + version = 4 + family = socket.AF_INET6 if version == 6 else socket.AF_INET + _sock = self.socket_factory(self.host, self.port, family) + self.host, self.port = _sock.getsockname()[:2] + site = SockSite(self.runner, sock=_sock, ssl_context=self._ssl) + await site.start() + server = site._server + assert server is not None + sockets = server.sockets # type: ignore[attr-defined] + assert sockets is not None + self.port = sockets[0].getsockname()[1] + if self.scheme is sentinel: + if self._ssl: + scheme = "https" + else: + scheme = "http" + self.scheme = scheme + self._root = URL(f"{self.scheme}://{self.host}:{self.port}") + + @abstractmethod # pragma: no cover + async def _make_runner(self, **kwargs: Any) -> BaseRunner: + pass + + def make_url(self, path: StrOrURL) -> URL: + assert self._root is not None + url = URL(path) + if not self.skip_url_asserts: + assert not url.is_absolute() + return self._root.join(url) + else: + return URL(str(self._root) + str(path)) + + @property + def started(self) -> bool: + return 
self.runner is not None + + @property + def closed(self) -> bool: + return self._closed + + @property + def handler(self) -> Server: + # for backward compatibility + # web.Server instance + runner = self.runner + assert runner is not None + assert runner.server is not None + return runner.server + + async def close(self) -> None: + """Close all fixtures created by the test client. + + After that point, the TestClient is no longer usable. + + This is an idempotent function: running close multiple times + will not have any additional effects. + + close is also run when the object is garbage collected, and on + exit when used as a context manager. + + """ + if self.started and not self.closed: + assert self.runner is not None + await self.runner.cleanup() + self._root = None + self.port = None + self._closed = True + + def __enter__(self) -> None: + raise TypeError("Use async with instead") + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + # __exit__ should exist in pair with __enter__ but never executed + pass # pragma: no cover + + async def __aenter__(self) -> "BaseTestServer": + await self.start_server(loop=self._loop) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + await self.close() + + +class TestServer(BaseTestServer): + def __init__( + self, + app: Application, + *, + scheme: Union[str, object] = sentinel, + host: str = "127.0.0.1", + port: Optional[int] = None, + **kwargs: Any, + ): + self.app = app + super().__init__(scheme=scheme, host=host, port=port, **kwargs) + + async def _make_runner(self, **kwargs: Any) -> BaseRunner: + return AppRunner(self.app, **kwargs) + + +class RawTestServer(BaseTestServer): + def __init__( + self, + handler: _RequestHandler, + *, + scheme: Union[str, object] = sentinel, + host: str = "127.0.0.1", + port: Optional[int] = None, + **kwargs: Any, + ) -> None: + self._handler = handler + super().__init__(scheme=scheme, host=host, port=port, **kwargs) + + async def _make_runner(self, debug: bool = True, **kwargs: Any) -> ServerRunner: + srv = Server(self._handler, loop=self._loop, debug=debug, **kwargs) + return ServerRunner(srv, debug=debug, **kwargs) + + +class TestClient: + """ + A test client implementation. + + To write functional tests for aiohttp based servers. 
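+
+    A typical pattern (sketch; ``app`` is an aiohttp.web.Application)::
+
+        async with TestClient(TestServer(app)) as client:
+            resp = await client.get("/")
+            assert resp.status == 200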
+ + """ + + __test__ = False + + def __init__( + self, + server: BaseTestServer, + *, + cookie_jar: Optional[AbstractCookieJar] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + **kwargs: Any, + ) -> None: + if not isinstance(server, BaseTestServer): + raise TypeError( + "server must be TestServer " "instance, found type: %r" % type(server) + ) + self._server = server + self._loop = loop + if cookie_jar is None: + cookie_jar = aiohttp.CookieJar(unsafe=True, loop=loop) + self._session = ClientSession(loop=loop, cookie_jar=cookie_jar, **kwargs) + self._closed = False + self._responses: List[ClientResponse] = [] + self._websockets: List[ClientWebSocketResponse] = [] + + async def start_server(self) -> None: + await self._server.start_server(loop=self._loop) + + @property + def host(self) -> str: + return self._server.host + + @property + def port(self) -> Optional[int]: + return self._server.port + + @property + def server(self) -> BaseTestServer: + return self._server + + @property + def app(self) -> Optional[Application]: + return cast(Optional[Application], getattr(self._server, "app", None)) + + @property + def session(self) -> ClientSession: + """An internal aiohttp.ClientSession. + + Unlike the methods on the TestClient, client session requests + do not automatically include the host in the url queried, and + will require an absolute path to the resource. + + """ + return self._session + + def make_url(self, path: StrOrURL) -> URL: + return self._server.make_url(path) + + async def _request( + self, method: str, path: StrOrURL, **kwargs: Any + ) -> ClientResponse: + resp = await self._session.request(method, self.make_url(path), **kwargs) + # save it to close later + self._responses.append(resp) + return resp + + def request( + self, method: str, path: StrOrURL, **kwargs: Any + ) -> _RequestContextManager: + """Routes a request to tested http server. + + The interface is identical to aiohttp.ClientSession.request, + except the loop kwarg is overridden by the instance used by the + test server. 
+ + """ + return _RequestContextManager(self._request(method, path, **kwargs)) + + def get(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP GET request.""" + return _RequestContextManager(self._request(hdrs.METH_GET, path, **kwargs)) + + def post(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP POST request.""" + return _RequestContextManager(self._request(hdrs.METH_POST, path, **kwargs)) + + def options(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP OPTIONS request.""" + return _RequestContextManager(self._request(hdrs.METH_OPTIONS, path, **kwargs)) + + def head(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP HEAD request.""" + return _RequestContextManager(self._request(hdrs.METH_HEAD, path, **kwargs)) + + def put(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PUT request.""" + return _RequestContextManager(self._request(hdrs.METH_PUT, path, **kwargs)) + + def patch(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PATCH request.""" + return _RequestContextManager(self._request(hdrs.METH_PATCH, path, **kwargs)) + + def delete(self, path: StrOrURL, **kwargs: Any) -> _RequestContextManager: + """Perform an HTTP PATCH request.""" + return _RequestContextManager(self._request(hdrs.METH_DELETE, path, **kwargs)) + + def ws_connect(self, path: StrOrURL, **kwargs: Any) -> _WSRequestContextManager: + """Initiate websocket connection. + + The api corresponds to aiohttp.ClientSession.ws_connect. + + """ + return _WSRequestContextManager(self._ws_connect(path, **kwargs)) + + async def _ws_connect( + self, path: StrOrURL, **kwargs: Any + ) -> ClientWebSocketResponse: + ws = await self._session.ws_connect(self.make_url(path), **kwargs) + self._websockets.append(ws) + return ws + + async def close(self) -> None: + """Close all fixtures created by the test client. + + After that point, the TestClient is no longer usable. + + This is an idempotent function: running close multiple times + will not have any additional effects. + + close is also run on exit when used as a(n) (asynchronous) + context manager. + + """ + if not self._closed: + for resp in self._responses: + resp.close() + for ws in self._websockets: + await ws.close() + await self._session.close() + await self._server.close() + self._closed = True + + def __enter__(self) -> None: + raise TypeError("Use async with instead") + + def __exit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + # __exit__ should exist in pair with __enter__ but never executed + pass # pragma: no cover + + async def __aenter__(self) -> "TestClient": + await self.start_server() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + await self.close() + + +class AioHTTPTestCase(IsolatedAsyncioTestCase): + """A base class to allow for unittest web applications using aiohttp. + + Provides the following: + + * self.client (aiohttp.test_utils.TestClient): an aiohttp test client. + * self.loop (asyncio.BaseEventLoop): the event loop in which the + application and server are running. 
+ * self.app (aiohttp.web.Application): the application returned by + self.get_application() + + Note that the TestClient's methods are asynchronous: you have to + execute function on the test client using asynchronous methods. + """ + + async def get_application(self) -> Application: + """Get application. + + This method should be overridden + to return the aiohttp.web.Application + object to test. + """ + return self.get_app() + + def get_app(self) -> Application: + """Obsolete method used to constructing web application. + + Use .get_application() coroutine instead. + """ + raise RuntimeError("Did you forget to define get_application()?") + + async def asyncSetUp(self) -> None: + self.loop = asyncio.get_running_loop() + return await self.setUpAsync() + + async def setUpAsync(self) -> None: + self.app = await self.get_application() + self.server = await self.get_server(self.app) + self.client = await self.get_client(self.server) + + await self.client.start_server() + + async def asyncTearDown(self) -> None: + return await self.tearDownAsync() + + async def tearDownAsync(self) -> None: + await self.client.close() + + async def get_server(self, app: Application) -> TestServer: + """Return a TestServer instance.""" + return TestServer(app, loop=self.loop) + + async def get_client(self, server: TestServer) -> TestClient: + """Return a TestClient instance.""" + return TestClient(server, loop=self.loop) + + +def unittest_run_loop(func: Any, *args: Any, **kwargs: Any) -> Any: + """ + A decorator dedicated to use with asynchronous AioHTTPTestCase test methods. + + In 3.8+, this does nothing. + """ + warnings.warn( + "Decorator `@unittest_run_loop` is no longer needed in aiohttp 3.8+", + DeprecationWarning, + stacklevel=2, + ) + return func + + +_LOOP_FACTORY = Callable[[], asyncio.AbstractEventLoop] + + +@contextlib.contextmanager +def loop_context( + loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, fast: bool = False +) -> Iterator[asyncio.AbstractEventLoop]: + """A contextmanager that creates an event_loop, for test purposes. + + Handles the creation and cleanup of a test loop. + """ + loop = setup_test_loop(loop_factory) + yield loop + teardown_test_loop(loop, fast=fast) + + +def setup_test_loop( + loop_factory: _LOOP_FACTORY = asyncio.new_event_loop, +) -> asyncio.AbstractEventLoop: + """Create and return an asyncio.BaseEventLoop instance. + + The caller should also call teardown_test_loop, + once they are done with the loop. 
+ """ + loop = loop_factory() + asyncio.set_event_loop(loop) + return loop + + +def teardown_test_loop(loop: asyncio.AbstractEventLoop, fast: bool = False) -> None: + """Teardown and cleanup an event_loop created by setup_test_loop.""" + closed = loop.is_closed() + if not closed: + loop.call_soon(loop.stop) + loop.run_forever() + loop.close() + + if not fast: + gc.collect() + + asyncio.set_event_loop(None) + + +def _create_app_mock() -> mock.MagicMock: + def get_dict(app: Any, key: str) -> Any: + return app.__app_dict[key] + + def set_dict(app: Any, key: str, value: Any) -> None: + app.__app_dict[key] = value + + app = mock.MagicMock(spec=Application) + app.__app_dict = {} + app.__getitem__ = get_dict + app.__setitem__ = set_dict + + app._debug = False + app.on_response_prepare = Signal(app) + app.on_response_prepare.freeze() + return app + + +def _create_transport(sslcontext: Optional[SSLContext] = None) -> mock.Mock: + transport = mock.Mock() + + def get_extra_info(key: str) -> Optional[SSLContext]: + if key == "sslcontext": + return sslcontext + else: + return None + + transport.get_extra_info.side_effect = get_extra_info + return transport + + +def make_mocked_request( + method: str, + path: str, + headers: Any = None, + *, + match_info: Any = sentinel, + version: HttpVersion = HttpVersion(1, 1), + closing: bool = False, + app: Any = None, + writer: Any = sentinel, + protocol: Any = sentinel, + transport: Any = sentinel, + payload: Any = sentinel, + sslcontext: Optional[SSLContext] = None, + client_max_size: int = 1024**2, + loop: Any = ..., +) -> Request: + """Creates mocked web.Request testing purposes. + + Useful in unit tests, when spinning full web server is overkill or + specific conditions and errors are hard to trigger. + """ + task = mock.Mock() + if loop is ...: + # no loop passed, try to get the current one if + # its is running as we need a real loop to create + # executor jobs to be able to do testing + # with a real executor + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = mock.Mock() + loop.create_future.return_value = () + + if version < HttpVersion(1, 1): + closing = True + + if headers: + headers = CIMultiDictProxy(CIMultiDict(headers)) + raw_hdrs = tuple( + (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items() + ) + else: + headers = CIMultiDictProxy(CIMultiDict()) + raw_hdrs = () + + chunked = "chunked" in headers.get(hdrs.TRANSFER_ENCODING, "").lower() + + message = RawRequestMessage( + method, + path, + version, + headers, + raw_hdrs, + closing, + None, + False, + chunked, + URL(path), + ) + if app is None: + app = _create_app_mock() + + if transport is sentinel: + transport = _create_transport(sslcontext) + + if protocol is sentinel: + protocol = mock.Mock() + protocol.transport = transport + + if writer is sentinel: + writer = mock.Mock() + writer.write_headers = make_mocked_coro(None) + writer.write = make_mocked_coro(None) + writer.write_eof = make_mocked_coro(None) + writer.drain = make_mocked_coro(None) + writer.transport = transport + + protocol.transport = transport + protocol.writer = writer + + if payload is sentinel: + payload = mock.Mock() + + req = Request( + message, payload, protocol, writer, task, loop, client_max_size=client_max_size + ) + + match_info = UrlMappingMatchInfo( + {} if match_info is sentinel else match_info, mock.Mock() + ) + match_info.add_app(app) + req._match_info = match_info + + return req + + +def make_mocked_coro( + return_value: Any = sentinel, raise_exception: Any = sentinel +) -> Any: 
+ """Creates a coroutine mock.""" + + async def mock_coro(*args: Any, **kwargs: Any) -> Any: + if raise_exception is not sentinel: + raise raise_exception + if not inspect.isawaitable(return_value): + return return_value + await return_value + + return mock.Mock(wraps=mock_coro) diff --git a/llm/Lib/site-packages/aiohttp/tracing.py b/llm/Lib/site-packages/aiohttp/tracing.py new file mode 100644 index 0000000000000000000000000000000000000000..6770a0dc361007846c123a1472f21e688cddd233 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/tracing.py @@ -0,0 +1,471 @@ +from types import SimpleNamespace +from typing import TYPE_CHECKING, Awaitable, Optional, Protocol, Type, TypeVar + +import attr +from aiosignal import Signal +from multidict import CIMultiDict +from yarl import URL + +from .client_reqrep import ClientResponse + +if TYPE_CHECKING: + from .client import ClientSession + + _ParamT_contra = TypeVar("_ParamT_contra", contravariant=True) + + class _SignalCallback(Protocol[_ParamT_contra]): + def __call__( + self, + __client_session: ClientSession, + __trace_config_ctx: SimpleNamespace, + __params: _ParamT_contra, + ) -> Awaitable[None]: + ... + + +__all__ = ( + "TraceConfig", + "TraceRequestStartParams", + "TraceRequestEndParams", + "TraceRequestExceptionParams", + "TraceConnectionQueuedStartParams", + "TraceConnectionQueuedEndParams", + "TraceConnectionCreateStartParams", + "TraceConnectionCreateEndParams", + "TraceConnectionReuseconnParams", + "TraceDnsResolveHostStartParams", + "TraceDnsResolveHostEndParams", + "TraceDnsCacheHitParams", + "TraceDnsCacheMissParams", + "TraceRequestRedirectParams", + "TraceRequestChunkSentParams", + "TraceResponseChunkReceivedParams", + "TraceRequestHeadersSentParams", +) + + +class TraceConfig: + """First-class used to trace requests launched via ClientSession objects.""" + + def __init__( + self, trace_config_ctx_factory: Type[SimpleNamespace] = SimpleNamespace + ) -> None: + self._on_request_start: Signal[ + _SignalCallback[TraceRequestStartParams] + ] = Signal(self) + self._on_request_chunk_sent: Signal[ + _SignalCallback[TraceRequestChunkSentParams] + ] = Signal(self) + self._on_response_chunk_received: Signal[ + _SignalCallback[TraceResponseChunkReceivedParams] + ] = Signal(self) + self._on_request_end: Signal[_SignalCallback[TraceRequestEndParams]] = Signal( + self + ) + self._on_request_exception: Signal[ + _SignalCallback[TraceRequestExceptionParams] + ] = Signal(self) + self._on_request_redirect: Signal[ + _SignalCallback[TraceRequestRedirectParams] + ] = Signal(self) + self._on_connection_queued_start: Signal[ + _SignalCallback[TraceConnectionQueuedStartParams] + ] = Signal(self) + self._on_connection_queued_end: Signal[ + _SignalCallback[TraceConnectionQueuedEndParams] + ] = Signal(self) + self._on_connection_create_start: Signal[ + _SignalCallback[TraceConnectionCreateStartParams] + ] = Signal(self) + self._on_connection_create_end: Signal[ + _SignalCallback[TraceConnectionCreateEndParams] + ] = Signal(self) + self._on_connection_reuseconn: Signal[ + _SignalCallback[TraceConnectionReuseconnParams] + ] = Signal(self) + self._on_dns_resolvehost_start: Signal[ + _SignalCallback[TraceDnsResolveHostStartParams] + ] = Signal(self) + self._on_dns_resolvehost_end: Signal[ + _SignalCallback[TraceDnsResolveHostEndParams] + ] = Signal(self) + self._on_dns_cache_hit: Signal[ + _SignalCallback[TraceDnsCacheHitParams] + ] = Signal(self) + self._on_dns_cache_miss: Signal[ + _SignalCallback[TraceDnsCacheMissParams] + ] = Signal(self) + 
self._on_request_headers_sent: Signal[ + _SignalCallback[TraceRequestHeadersSentParams] + ] = Signal(self) + + self._trace_config_ctx_factory = trace_config_ctx_factory + + def trace_config_ctx( + self, trace_request_ctx: Optional[SimpleNamespace] = None + ) -> SimpleNamespace: + """Return a new trace_config_ctx instance""" + return self._trace_config_ctx_factory(trace_request_ctx=trace_request_ctx) + + def freeze(self) -> None: + self._on_request_start.freeze() + self._on_request_chunk_sent.freeze() + self._on_response_chunk_received.freeze() + self._on_request_end.freeze() + self._on_request_exception.freeze() + self._on_request_redirect.freeze() + self._on_connection_queued_start.freeze() + self._on_connection_queued_end.freeze() + self._on_connection_create_start.freeze() + self._on_connection_create_end.freeze() + self._on_connection_reuseconn.freeze() + self._on_dns_resolvehost_start.freeze() + self._on_dns_resolvehost_end.freeze() + self._on_dns_cache_hit.freeze() + self._on_dns_cache_miss.freeze() + self._on_request_headers_sent.freeze() + + @property + def on_request_start(self) -> "Signal[_SignalCallback[TraceRequestStartParams]]": + return self._on_request_start + + @property + def on_request_chunk_sent( + self, + ) -> "Signal[_SignalCallback[TraceRequestChunkSentParams]]": + return self._on_request_chunk_sent + + @property + def on_response_chunk_received( + self, + ) -> "Signal[_SignalCallback[TraceResponseChunkReceivedParams]]": + return self._on_response_chunk_received + + @property + def on_request_end(self) -> "Signal[_SignalCallback[TraceRequestEndParams]]": + return self._on_request_end + + @property + def on_request_exception( + self, + ) -> "Signal[_SignalCallback[TraceRequestExceptionParams]]": + return self._on_request_exception + + @property + def on_request_redirect( + self, + ) -> "Signal[_SignalCallback[TraceRequestRedirectParams]]": + return self._on_request_redirect + + @property + def on_connection_queued_start( + self, + ) -> "Signal[_SignalCallback[TraceConnectionQueuedStartParams]]": + return self._on_connection_queued_start + + @property + def on_connection_queued_end( + self, + ) -> "Signal[_SignalCallback[TraceConnectionQueuedEndParams]]": + return self._on_connection_queued_end + + @property + def on_connection_create_start( + self, + ) -> "Signal[_SignalCallback[TraceConnectionCreateStartParams]]": + return self._on_connection_create_start + + @property + def on_connection_create_end( + self, + ) -> "Signal[_SignalCallback[TraceConnectionCreateEndParams]]": + return self._on_connection_create_end + + @property + def on_connection_reuseconn( + self, + ) -> "Signal[_SignalCallback[TraceConnectionReuseconnParams]]": + return self._on_connection_reuseconn + + @property + def on_dns_resolvehost_start( + self, + ) -> "Signal[_SignalCallback[TraceDnsResolveHostStartParams]]": + return self._on_dns_resolvehost_start + + @property + def on_dns_resolvehost_end( + self, + ) -> "Signal[_SignalCallback[TraceDnsResolveHostEndParams]]": + return self._on_dns_resolvehost_end + + @property + def on_dns_cache_hit(self) -> "Signal[_SignalCallback[TraceDnsCacheHitParams]]": + return self._on_dns_cache_hit + + @property + def on_dns_cache_miss(self) -> "Signal[_SignalCallback[TraceDnsCacheMissParams]]": + return self._on_dns_cache_miss + + @property + def on_request_headers_sent( + self, + ) -> "Signal[_SignalCallback[TraceRequestHeadersSentParams]]": + return self._on_request_headers_sent + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class 
TraceRequestStartParams: + """Parameters sent by the `on_request_start` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestChunkSentParams: + """Parameters sent by the `on_request_chunk_sent` signal""" + + method: str + url: URL + chunk: bytes + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceResponseChunkReceivedParams: + """Parameters sent by the `on_response_chunk_received` signal""" + + method: str + url: URL + chunk: bytes + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestEndParams: + """Parameters sent by the `on_request_end` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + response: ClientResponse + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestExceptionParams: + """Parameters sent by the `on_request_exception` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + exception: BaseException + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestRedirectParams: + """Parameters sent by the `on_request_redirect` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + response: ClientResponse + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceConnectionQueuedStartParams: + """Parameters sent by the `on_connection_queued_start` signal""" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceConnectionQueuedEndParams: + """Parameters sent by the `on_connection_queued_end` signal""" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceConnectionCreateStartParams: + """Parameters sent by the `on_connection_create_start` signal""" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceConnectionCreateEndParams: + """Parameters sent by the `on_connection_create_end` signal""" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceConnectionReuseconnParams: + """Parameters sent by the `on_connection_reuseconn` signal""" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceDnsResolveHostStartParams: + """Parameters sent by the `on_dns_resolvehost_start` signal""" + + host: str + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceDnsResolveHostEndParams: + """Parameters sent by the `on_dns_resolvehost_end` signal""" + + host: str + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceDnsCacheHitParams: + """Parameters sent by the `on_dns_cache_hit` signal""" + + host: str + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceDnsCacheMissParams: + """Parameters sent by the `on_dns_cache_miss` signal""" + + host: str + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class TraceRequestHeadersSentParams: + """Parameters sent by the `on_request_headers_sent` signal""" + + method: str + url: URL + headers: "CIMultiDict[str]" + + +class Trace: + """Internal dependency holder class. + + Used to keep together the main dependencies used + at the moment of send a signal. 
+ """ + + def __init__( + self, + session: "ClientSession", + trace_config: TraceConfig, + trace_config_ctx: SimpleNamespace, + ) -> None: + self._trace_config = trace_config + self._trace_config_ctx = trace_config_ctx + self._session = session + + async def send_request_start( + self, method: str, url: URL, headers: "CIMultiDict[str]" + ) -> None: + return await self._trace_config.on_request_start.send( + self._session, + self._trace_config_ctx, + TraceRequestStartParams(method, url, headers), + ) + + async def send_request_chunk_sent( + self, method: str, url: URL, chunk: bytes + ) -> None: + return await self._trace_config.on_request_chunk_sent.send( + self._session, + self._trace_config_ctx, + TraceRequestChunkSentParams(method, url, chunk), + ) + + async def send_response_chunk_received( + self, method: str, url: URL, chunk: bytes + ) -> None: + return await self._trace_config.on_response_chunk_received.send( + self._session, + self._trace_config_ctx, + TraceResponseChunkReceivedParams(method, url, chunk), + ) + + async def send_request_end( + self, + method: str, + url: URL, + headers: "CIMultiDict[str]", + response: ClientResponse, + ) -> None: + return await self._trace_config.on_request_end.send( + self._session, + self._trace_config_ctx, + TraceRequestEndParams(method, url, headers, response), + ) + + async def send_request_exception( + self, + method: str, + url: URL, + headers: "CIMultiDict[str]", + exception: BaseException, + ) -> None: + return await self._trace_config.on_request_exception.send( + self._session, + self._trace_config_ctx, + TraceRequestExceptionParams(method, url, headers, exception), + ) + + async def send_request_redirect( + self, + method: str, + url: URL, + headers: "CIMultiDict[str]", + response: ClientResponse, + ) -> None: + return await self._trace_config._on_request_redirect.send( + self._session, + self._trace_config_ctx, + TraceRequestRedirectParams(method, url, headers, response), + ) + + async def send_connection_queued_start(self) -> None: + return await self._trace_config.on_connection_queued_start.send( + self._session, self._trace_config_ctx, TraceConnectionQueuedStartParams() + ) + + async def send_connection_queued_end(self) -> None: + return await self._trace_config.on_connection_queued_end.send( + self._session, self._trace_config_ctx, TraceConnectionQueuedEndParams() + ) + + async def send_connection_create_start(self) -> None: + return await self._trace_config.on_connection_create_start.send( + self._session, self._trace_config_ctx, TraceConnectionCreateStartParams() + ) + + async def send_connection_create_end(self) -> None: + return await self._trace_config.on_connection_create_end.send( + self._session, self._trace_config_ctx, TraceConnectionCreateEndParams() + ) + + async def send_connection_reuseconn(self) -> None: + return await self._trace_config.on_connection_reuseconn.send( + self._session, self._trace_config_ctx, TraceConnectionReuseconnParams() + ) + + async def send_dns_resolvehost_start(self, host: str) -> None: + return await self._trace_config.on_dns_resolvehost_start.send( + self._session, self._trace_config_ctx, TraceDnsResolveHostStartParams(host) + ) + + async def send_dns_resolvehost_end(self, host: str) -> None: + return await self._trace_config.on_dns_resolvehost_end.send( + self._session, self._trace_config_ctx, TraceDnsResolveHostEndParams(host) + ) + + async def send_dns_cache_hit(self, host: str) -> None: + return await self._trace_config.on_dns_cache_hit.send( + self._session, self._trace_config_ctx, 
TraceDnsCacheHitParams(host) + ) + + async def send_dns_cache_miss(self, host: str) -> None: + return await self._trace_config.on_dns_cache_miss.send( + self._session, self._trace_config_ctx, TraceDnsCacheMissParams(host) + ) + + async def send_request_headers( + self, method: str, url: URL, headers: "CIMultiDict[str]" + ) -> None: + return await self._trace_config._on_request_headers_sent.send( + self._session, + self._trace_config_ctx, + TraceRequestHeadersSentParams(method, url, headers), + ) diff --git a/llm/Lib/site-packages/aiohttp/typedefs.py b/llm/Lib/site-packages/aiohttp/typedefs.py new file mode 100644 index 0000000000000000000000000000000000000000..b80976a25053ae612bddfff523b66b1b248501ba --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/typedefs.py @@ -0,0 +1,54 @@ +import json +import os +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Iterable, + Mapping, + Tuple, + Union, +) + +from multidict import CIMultiDict, CIMultiDictProxy, MultiDict, MultiDictProxy, istr +from yarl import URL + +DEFAULT_JSON_ENCODER = json.dumps +DEFAULT_JSON_DECODER = json.loads + +if TYPE_CHECKING: + _CIMultiDict = CIMultiDict[str] + _CIMultiDictProxy = CIMultiDictProxy[str] + _MultiDict = MultiDict[str] + _MultiDictProxy = MultiDictProxy[str] + from http.cookies import BaseCookie, Morsel + + from .web import Request, StreamResponse +else: + _CIMultiDict = CIMultiDict + _CIMultiDictProxy = CIMultiDictProxy + _MultiDict = MultiDict + _MultiDictProxy = MultiDictProxy + +Byteish = Union[bytes, bytearray, memoryview] +JSONEncoder = Callable[[Any], str] +JSONDecoder = Callable[[str], Any] +LooseHeaders = Union[Mapping[Union[str, istr], str], _CIMultiDict, _CIMultiDictProxy] +RawHeaders = Tuple[Tuple[bytes, bytes], ...] +StrOrURL = Union[str, URL] + +LooseCookiesMappings = Mapping[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]] +LooseCookiesIterables = Iterable[ + Tuple[str, Union[str, "BaseCookie[str]", "Morsel[Any]"]] +] +LooseCookies = Union[ + LooseCookiesMappings, + LooseCookiesIterables, + "BaseCookie[str]", +] + +Handler = Callable[["Request"], Awaitable["StreamResponse"]] +Middleware = Callable[["Request", Handler], Awaitable["StreamResponse"]] + +PathLike = Union[str, "os.PathLike[str]"] diff --git a/llm/Lib/site-packages/aiohttp/web.py b/llm/Lib/site-packages/aiohttp/web.py new file mode 100644 index 0000000000000000000000000000000000000000..95efc5a81b87d2a1324a110a4e73cad3c1226287 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web.py @@ -0,0 +1,616 @@ +import asyncio +import logging +import os +import socket +import sys +import warnings +from argparse import ArgumentParser +from collections.abc import Iterable +from contextlib import suppress +from functools import partial +from importlib import import_module +from typing import ( + Any, + Awaitable, + Callable, + Iterable as TypingIterable, + List, + Optional, + Set, + Type, + Union, + cast, +) +from weakref import WeakSet + +from .abc import AbstractAccessLogger +from .helpers import AppKey as AppKey +from .log import access_logger +from .typedefs import PathLike +from .web_app import Application as Application, CleanupError as CleanupError +from .web_exceptions import ( + HTTPAccepted as HTTPAccepted, + HTTPBadGateway as HTTPBadGateway, + HTTPBadRequest as HTTPBadRequest, + HTTPClientError as HTTPClientError, + HTTPConflict as HTTPConflict, + HTTPCreated as HTTPCreated, + HTTPError as HTTPError, + HTTPException as HTTPException, + HTTPExpectationFailed as HTTPExpectationFailed, + HTTPFailedDependency as 
HTTPFailedDependency, + HTTPForbidden as HTTPForbidden, + HTTPFound as HTTPFound, + HTTPGatewayTimeout as HTTPGatewayTimeout, + HTTPGone as HTTPGone, + HTTPInsufficientStorage as HTTPInsufficientStorage, + HTTPInternalServerError as HTTPInternalServerError, + HTTPLengthRequired as HTTPLengthRequired, + HTTPMethodNotAllowed as HTTPMethodNotAllowed, + HTTPMisdirectedRequest as HTTPMisdirectedRequest, + HTTPMove as HTTPMove, + HTTPMovedPermanently as HTTPMovedPermanently, + HTTPMultipleChoices as HTTPMultipleChoices, + HTTPNetworkAuthenticationRequired as HTTPNetworkAuthenticationRequired, + HTTPNoContent as HTTPNoContent, + HTTPNonAuthoritativeInformation as HTTPNonAuthoritativeInformation, + HTTPNotAcceptable as HTTPNotAcceptable, + HTTPNotExtended as HTTPNotExtended, + HTTPNotFound as HTTPNotFound, + HTTPNotImplemented as HTTPNotImplemented, + HTTPNotModified as HTTPNotModified, + HTTPOk as HTTPOk, + HTTPPartialContent as HTTPPartialContent, + HTTPPaymentRequired as HTTPPaymentRequired, + HTTPPermanentRedirect as HTTPPermanentRedirect, + HTTPPreconditionFailed as HTTPPreconditionFailed, + HTTPPreconditionRequired as HTTPPreconditionRequired, + HTTPProxyAuthenticationRequired as HTTPProxyAuthenticationRequired, + HTTPRedirection as HTTPRedirection, + HTTPRequestEntityTooLarge as HTTPRequestEntityTooLarge, + HTTPRequestHeaderFieldsTooLarge as HTTPRequestHeaderFieldsTooLarge, + HTTPRequestRangeNotSatisfiable as HTTPRequestRangeNotSatisfiable, + HTTPRequestTimeout as HTTPRequestTimeout, + HTTPRequestURITooLong as HTTPRequestURITooLong, + HTTPResetContent as HTTPResetContent, + HTTPSeeOther as HTTPSeeOther, + HTTPServerError as HTTPServerError, + HTTPServiceUnavailable as HTTPServiceUnavailable, + HTTPSuccessful as HTTPSuccessful, + HTTPTemporaryRedirect as HTTPTemporaryRedirect, + HTTPTooManyRequests as HTTPTooManyRequests, + HTTPUnauthorized as HTTPUnauthorized, + HTTPUnavailableForLegalReasons as HTTPUnavailableForLegalReasons, + HTTPUnprocessableEntity as HTTPUnprocessableEntity, + HTTPUnsupportedMediaType as HTTPUnsupportedMediaType, + HTTPUpgradeRequired as HTTPUpgradeRequired, + HTTPUseProxy as HTTPUseProxy, + HTTPVariantAlsoNegotiates as HTTPVariantAlsoNegotiates, + HTTPVersionNotSupported as HTTPVersionNotSupported, + NotAppKeyWarning as NotAppKeyWarning, +) +from .web_fileresponse import FileResponse as FileResponse +from .web_log import AccessLogger +from .web_middlewares import ( + middleware as middleware, + normalize_path_middleware as normalize_path_middleware, +) +from .web_protocol import ( + PayloadAccessError as PayloadAccessError, + RequestHandler as RequestHandler, + RequestPayloadError as RequestPayloadError, +) +from .web_request import ( + BaseRequest as BaseRequest, + FileField as FileField, + Request as Request, +) +from .web_response import ( + ContentCoding as ContentCoding, + Response as Response, + StreamResponse as StreamResponse, + json_response as json_response, +) +from .web_routedef import ( + AbstractRouteDef as AbstractRouteDef, + RouteDef as RouteDef, + RouteTableDef as RouteTableDef, + StaticDef as StaticDef, + delete as delete, + get as get, + head as head, + options as options, + patch as patch, + post as post, + put as put, + route as route, + static as static, + view as view, +) +from .web_runner import ( + AppRunner as AppRunner, + BaseRunner as BaseRunner, + BaseSite as BaseSite, + GracefulExit as GracefulExit, + NamedPipeSite as NamedPipeSite, + ServerRunner as ServerRunner, + SockSite as SockSite, + TCPSite as TCPSite, + UnixSite as UnixSite, +) 
+from .web_server import Server as Server +from .web_urldispatcher import ( + AbstractResource as AbstractResource, + AbstractRoute as AbstractRoute, + DynamicResource as DynamicResource, + PlainResource as PlainResource, + PrefixedSubAppResource as PrefixedSubAppResource, + Resource as Resource, + ResourceRoute as ResourceRoute, + StaticResource as StaticResource, + UrlDispatcher as UrlDispatcher, + UrlMappingMatchInfo as UrlMappingMatchInfo, + View as View, +) +from .web_ws import ( + WebSocketReady as WebSocketReady, + WebSocketResponse as WebSocketResponse, + WSMsgType as WSMsgType, +) + +__all__ = ( + # web_app + "AppKey", + "Application", + "CleanupError", + # web_exceptions + "NotAppKeyWarning", + "HTTPAccepted", + "HTTPBadGateway", + "HTTPBadRequest", + "HTTPClientError", + "HTTPConflict", + "HTTPCreated", + "HTTPError", + "HTTPException", + "HTTPExpectationFailed", + "HTTPFailedDependency", + "HTTPForbidden", + "HTTPFound", + "HTTPGatewayTimeout", + "HTTPGone", + "HTTPInsufficientStorage", + "HTTPInternalServerError", + "HTTPLengthRequired", + "HTTPMethodNotAllowed", + "HTTPMisdirectedRequest", + "HTTPMove", + "HTTPMovedPermanently", + "HTTPMultipleChoices", + "HTTPNetworkAuthenticationRequired", + "HTTPNoContent", + "HTTPNonAuthoritativeInformation", + "HTTPNotAcceptable", + "HTTPNotExtended", + "HTTPNotFound", + "HTTPNotImplemented", + "HTTPNotModified", + "HTTPOk", + "HTTPPartialContent", + "HTTPPaymentRequired", + "HTTPPermanentRedirect", + "HTTPPreconditionFailed", + "HTTPPreconditionRequired", + "HTTPProxyAuthenticationRequired", + "HTTPRedirection", + "HTTPRequestEntityTooLarge", + "HTTPRequestHeaderFieldsTooLarge", + "HTTPRequestRangeNotSatisfiable", + "HTTPRequestTimeout", + "HTTPRequestURITooLong", + "HTTPResetContent", + "HTTPSeeOther", + "HTTPServerError", + "HTTPServiceUnavailable", + "HTTPSuccessful", + "HTTPTemporaryRedirect", + "HTTPTooManyRequests", + "HTTPUnauthorized", + "HTTPUnavailableForLegalReasons", + "HTTPUnprocessableEntity", + "HTTPUnsupportedMediaType", + "HTTPUpgradeRequired", + "HTTPUseProxy", + "HTTPVariantAlsoNegotiates", + "HTTPVersionNotSupported", + # web_fileresponse + "FileResponse", + # web_middlewares + "middleware", + "normalize_path_middleware", + # web_protocol + "PayloadAccessError", + "RequestHandler", + "RequestPayloadError", + # web_request + "BaseRequest", + "FileField", + "Request", + # web_response + "ContentCoding", + "Response", + "StreamResponse", + "json_response", + # web_routedef + "AbstractRouteDef", + "RouteDef", + "RouteTableDef", + "StaticDef", + "delete", + "get", + "head", + "options", + "patch", + "post", + "put", + "route", + "static", + "view", + # web_runner + "AppRunner", + "BaseRunner", + "BaseSite", + "GracefulExit", + "ServerRunner", + "SockSite", + "TCPSite", + "UnixSite", + "NamedPipeSite", + # web_server + "Server", + # web_urldispatcher + "AbstractResource", + "AbstractRoute", + "DynamicResource", + "PlainResource", + "PrefixedSubAppResource", + "Resource", + "ResourceRoute", + "StaticResource", + "UrlDispatcher", + "UrlMappingMatchInfo", + "View", + # web_ws + "WebSocketReady", + "WebSocketResponse", + "WSMsgType", + # web + "run_app", +) + + +try: + from ssl import SSLContext +except ImportError: # pragma: no cover + SSLContext = Any # type: ignore[misc,assignment] + +# Only display warning when using -Wdefault, -We, -X dev or similar. 
+warnings.filterwarnings("ignore", category=NotAppKeyWarning, append=True) + +HostSequence = TypingIterable[str] + + +async def _run_app( + app: Union[Application, Awaitable[Application]], + *, + host: Optional[Union[str, HostSequence]] = None, + port: Optional[int] = None, + path: Union[PathLike, TypingIterable[PathLike], None] = None, + sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None, + shutdown_timeout: float = 60.0, + keepalive_timeout: float = 75.0, + ssl_context: Optional[SSLContext] = None, + print: Optional[Callable[..., None]] = print, + backlog: int = 128, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_format: str = AccessLogger.LOG_FORMAT, + access_log: Optional[logging.Logger] = access_logger, + handle_signals: bool = True, + reuse_address: Optional[bool] = None, + reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, +) -> None: + async def wait( + starting_tasks: "WeakSet[asyncio.Task[object]]", shutdown_timeout: float + ) -> None: + # Wait for pending tasks for a given time limit. + t = asyncio.current_task() + assert t is not None + starting_tasks.add(t) + with suppress(asyncio.TimeoutError): + await asyncio.wait_for(_wait(starting_tasks), timeout=shutdown_timeout) + + async def _wait(exclude: "WeakSet[asyncio.Task[object]]") -> None: + t = asyncio.current_task() + assert t is not None + exclude.add(t) + while tasks := asyncio.all_tasks().difference(exclude): + await asyncio.wait(tasks) + + # An internal function to actually do all dirty job for application running + if asyncio.iscoroutine(app): + app = await app + + app = cast(Application, app) + + runner = AppRunner( + app, + handle_signals=handle_signals, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + keepalive_timeout=keepalive_timeout, + shutdown_timeout=shutdown_timeout, + handler_cancellation=handler_cancellation, + ) + + await runner.setup() + # On shutdown we want to avoid waiting on tasks which run forever. + # It's very likely that all tasks which run forever will have been created by + # the time we have completed the application startup (in runner.setup()), + # so we just record all running tasks here and exclude them later. 
+ starting_tasks: "WeakSet[asyncio.Task[object]]" = WeakSet(asyncio.all_tasks()) + runner.shutdown_callback = partial(wait, starting_tasks, shutdown_timeout) + + sites: List[BaseSite] = [] + + try: + if host is not None: + if isinstance(host, (str, bytes, bytearray, memoryview)): + sites.append( + TCPSite( + runner, + host, + port, + ssl_context=ssl_context, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + ) + else: + for h in host: + sites.append( + TCPSite( + runner, + h, + port, + ssl_context=ssl_context, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + ) + elif path is None and sock is None or port is not None: + sites.append( + TCPSite( + runner, + port=port, + ssl_context=ssl_context, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + ) + + if path is not None: + if isinstance(path, (str, os.PathLike)): + sites.append( + UnixSite( + runner, + path, + ssl_context=ssl_context, + backlog=backlog, + ) + ) + else: + for p in path: + sites.append( + UnixSite( + runner, + p, + ssl_context=ssl_context, + backlog=backlog, + ) + ) + + if sock is not None: + if not isinstance(sock, Iterable): + sites.append( + SockSite( + runner, + sock, + ssl_context=ssl_context, + backlog=backlog, + ) + ) + else: + for s in sock: + sites.append( + SockSite( + runner, + s, + ssl_context=ssl_context, + backlog=backlog, + ) + ) + for site in sites: + await site.start() + + if print: # pragma: no branch + names = sorted(str(s.name) for s in runner.sites) + print( + "======== Running on {} ========\n" + "(Press CTRL+C to quit)".format(", ".join(names)) + ) + + # sleep forever by 1 hour intervals, + while True: + await asyncio.sleep(3600) + finally: + await runner.cleanup() + + +def _cancel_tasks( + to_cancel: Set["asyncio.Task[Any]"], loop: asyncio.AbstractEventLoop +) -> None: + if not to_cancel: + return + + for task in to_cancel: + task.cancel() + + loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) + + for task in to_cancel: + if task.cancelled(): + continue + if task.exception() is not None: + loop.call_exception_handler( + { + "message": "unhandled exception during asyncio.run() shutdown", + "exception": task.exception(), + "task": task, + } + ) + + +def run_app( + app: Union[Application, Awaitable[Application]], + *, + host: Optional[Union[str, HostSequence]] = None, + port: Optional[int] = None, + path: Union[PathLike, TypingIterable[PathLike], None] = None, + sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None, + shutdown_timeout: float = 60.0, + keepalive_timeout: float = 75.0, + ssl_context: Optional[SSLContext] = None, + print: Optional[Callable[..., None]] = print, + backlog: int = 128, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_format: str = AccessLogger.LOG_FORMAT, + access_log: Optional[logging.Logger] = access_logger, + handle_signals: bool = True, + reuse_address: Optional[bool] = None, + reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> None: + """Run an app locally""" + if loop is None: + loop = asyncio.new_event_loop() + + # Configure if and only if in debugging mode and using the default logger + if loop.get_debug() and access_log and access_log.name == "aiohttp.access": + if access_log.level == logging.NOTSET: + access_log.setLevel(logging.DEBUG) + if not access_log.hasHandlers(): + access_log.addHandler(logging.StreamHandler()) + + 
main_task = loop.create_task( + _run_app( + app, + host=host, + port=port, + path=path, + sock=sock, + shutdown_timeout=shutdown_timeout, + keepalive_timeout=keepalive_timeout, + ssl_context=ssl_context, + print=print, + backlog=backlog, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + handle_signals=handle_signals, + reuse_address=reuse_address, + reuse_port=reuse_port, + handler_cancellation=handler_cancellation, + ) + ) + + try: + asyncio.set_event_loop(loop) + loop.run_until_complete(main_task) + except (GracefulExit, KeyboardInterrupt): # pragma: no cover + pass + finally: + _cancel_tasks({main_task}, loop) + _cancel_tasks(asyncio.all_tasks(loop), loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() + + +def main(argv: List[str]) -> None: + arg_parser = ArgumentParser( + description="aiohttp.web Application server", prog="aiohttp.web" + ) + arg_parser.add_argument( + "entry_func", + help=( + "Callable returning the `aiohttp.web.Application` instance to " + "run. Should be specified in the 'module:function' syntax." + ), + metavar="entry-func", + ) + arg_parser.add_argument( + "-H", + "--hostname", + help="TCP/IP hostname to serve on (default: %(default)r)", + default="localhost", + ) + arg_parser.add_argument( + "-P", + "--port", + help="TCP/IP port to serve on (default: %(default)r)", + type=int, + default="8080", + ) + arg_parser.add_argument( + "-U", + "--path", + help="Unix file system path to serve on. Specifying a path will cause " + "hostname and port arguments to be ignored.", + ) + args, extra_argv = arg_parser.parse_known_args(argv) + + # Import logic + mod_str, _, func_str = args.entry_func.partition(":") + if not func_str or not mod_str: + arg_parser.error("'entry-func' not in 'module:function' syntax") + if mod_str.startswith("."): + arg_parser.error("relative module names not supported") + try: + module = import_module(mod_str) + except ImportError as ex: + arg_parser.error(f"unable to import {mod_str}: {ex}") + try: + func = getattr(module, func_str) + except AttributeError: + arg_parser.error(f"module {mod_str!r} has no attribute {func_str!r}") + + # Compatibility logic + if args.path is not None and not hasattr(socket, "AF_UNIX"): + arg_parser.error( + "file system paths not supported by your operating" " environment" + ) + + logging.basicConfig(level=logging.DEBUG) + + app = func(extra_argv) + run_app(app, host=args.hostname, port=args.port, path=args.path) + arg_parser.exit(message="Stopped\n") + + +if __name__ == "__main__": # pragma: no branch + main(sys.argv[1:]) # pragma: no cover diff --git a/llm/Lib/site-packages/aiohttp/web_app.py b/llm/Lib/site-packages/aiohttp/web_app.py new file mode 100644 index 0000000000000000000000000000000000000000..0da4021d38a67f79f4f725b7f7e4dc964d6766ec --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_app.py @@ -0,0 +1,596 @@ +import asyncio +import logging +import warnings +from functools import partial, update_wrapper +from typing import ( + TYPE_CHECKING, + Any, + AsyncIterator, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Mapping, + MutableMapping, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) + +from aiosignal import Signal +from frozenlist import FrozenList + +from . 
import hdrs +from .abc import ( + AbstractAccessLogger, + AbstractMatchInfo, + AbstractRouter, + AbstractStreamWriter, +) +from .helpers import DEBUG, AppKey +from .http_parser import RawRequestMessage +from .log import web_logger +from .streams import StreamReader +from .typedefs import Middleware +from .web_exceptions import NotAppKeyWarning +from .web_log import AccessLogger +from .web_middlewares import _fix_request_current_app +from .web_protocol import RequestHandler +from .web_request import Request +from .web_response import StreamResponse +from .web_routedef import AbstractRouteDef +from .web_server import Server +from .web_urldispatcher import ( + AbstractResource, + AbstractRoute, + Domain, + MaskDomain, + MatchedSubAppResource, + PrefixedSubAppResource, + UrlDispatcher, +) + +__all__ = ("Application", "CleanupError") + + +if TYPE_CHECKING: + _AppSignal = Signal[Callable[["Application"], Awaitable[None]]] + _RespPrepareSignal = Signal[Callable[[Request, StreamResponse], Awaitable[None]]] + _Middlewares = FrozenList[Middleware] + _MiddlewaresHandlers = Optional[Sequence[Tuple[Middleware, bool]]] + _Subapps = List["Application"] +else: + # No type checker mode, skip types + _AppSignal = Signal + _RespPrepareSignal = Signal + _Middlewares = FrozenList + _MiddlewaresHandlers = Optional[Sequence] + _Subapps = List + +_T = TypeVar("_T") +_U = TypeVar("_U") + + +class Application(MutableMapping[Union[str, AppKey[Any]], Any]): + ATTRS = frozenset( + [ + "logger", + "_debug", + "_router", + "_loop", + "_handler_args", + "_middlewares", + "_middlewares_handlers", + "_run_middlewares", + "_state", + "_frozen", + "_pre_frozen", + "_subapps", + "_on_response_prepare", + "_on_startup", + "_on_shutdown", + "_on_cleanup", + "_client_max_size", + "_cleanup_ctx", + ] + ) + + def __init__( + self, + *, + logger: logging.Logger = web_logger, + router: Optional[UrlDispatcher] = None, + middlewares: Iterable[Middleware] = (), + handler_args: Optional[Mapping[str, Any]] = None, + client_max_size: int = 1024**2, + loop: Optional[asyncio.AbstractEventLoop] = None, + debug: Any = ..., # mypy doesn't support ellipsis + ) -> None: + if router is None: + router = UrlDispatcher() + else: + warnings.warn( + "router argument is deprecated", DeprecationWarning, stacklevel=2 + ) + assert isinstance(router, AbstractRouter), router + + if loop is not None: + warnings.warn( + "loop argument is deprecated", DeprecationWarning, stacklevel=2 + ) + + if debug is not ...: + warnings.warn( + "debug argument is deprecated", DeprecationWarning, stacklevel=2 + ) + self._debug = debug + self._router: UrlDispatcher = router + self._loop = loop + self._handler_args = handler_args + self.logger = logger + + self._middlewares: _Middlewares = FrozenList(middlewares) + + # initialized on freezing + self._middlewares_handlers: _MiddlewaresHandlers = None + # initialized on freezing + self._run_middlewares: Optional[bool] = None + + self._state: Dict[Union[AppKey[Any], str], object] = {} + self._frozen = False + self._pre_frozen = False + self._subapps: _Subapps = [] + + self._on_response_prepare: _RespPrepareSignal = Signal(self) + self._on_startup: _AppSignal = Signal(self) + self._on_shutdown: _AppSignal = Signal(self) + self._on_cleanup: _AppSignal = Signal(self) + self._cleanup_ctx = CleanupContext() + self._on_startup.append(self._cleanup_ctx._on_startup) + self._on_cleanup.append(self._cleanup_ctx._on_cleanup) + self._client_max_size = client_max_size + + def __init_subclass__(cls: Type["Application"]) -> None: + 
warnings.warn( + "Inheritance class {} from web.Application " + "is discouraged".format(cls.__name__), + DeprecationWarning, + stacklevel=3, + ) + + if DEBUG: # pragma: no cover + + def __setattr__(self, name: str, val: Any) -> None: + if name not in self.ATTRS: + warnings.warn( + "Setting custom web.Application.{} attribute " + "is discouraged".format(name), + DeprecationWarning, + stacklevel=2, + ) + super().__setattr__(name, val) + + # MutableMapping API + + def __eq__(self, other: object) -> bool: + return self is other + + @overload # type: ignore[override] + def __getitem__(self, key: AppKey[_T]) -> _T: + ... + + @overload + def __getitem__(self, key: str) -> Any: + ... + + def __getitem__(self, key: Union[str, AppKey[_T]]) -> Any: + return self._state[key] + + def _check_frozen(self) -> None: + if self._frozen: + warnings.warn( + "Changing state of started or joined " "application is deprecated", + DeprecationWarning, + stacklevel=3, + ) + + @overload # type: ignore[override] + def __setitem__(self, key: AppKey[_T], value: _T) -> None: + ... + + @overload + def __setitem__(self, key: str, value: Any) -> None: + ... + + def __setitem__(self, key: Union[str, AppKey[_T]], value: Any) -> None: + self._check_frozen() + if not isinstance(key, AppKey): + warnings.warn( + "It is recommended to use web.AppKey instances for keys.\n" + + "https://docs.aiohttp.org/en/stable/web_advanced.html" + + "#application-s-config", + category=NotAppKeyWarning, + stacklevel=2, + ) + self._state[key] = value + + def __delitem__(self, key: Union[str, AppKey[_T]]) -> None: + self._check_frozen() + del self._state[key] + + def __len__(self) -> int: + return len(self._state) + + def __iter__(self) -> Iterator[Union[str, AppKey[Any]]]: + return iter(self._state) + + @overload # type: ignore[override] + def get(self, key: AppKey[_T], default: None = ...) -> Optional[_T]: + ... + + @overload + def get(self, key: AppKey[_T], default: _U) -> Union[_T, _U]: + ... + + @overload + def get(self, key: str, default: Any = ...) -> Any: + ... 
+ + def get(self, key: Union[str, AppKey[_T]], default: Any = None) -> Any: + return self._state.get(key, default) + + ######## + @property + def loop(self) -> asyncio.AbstractEventLoop: + # Technically the loop can be None + # but we mask it by explicit type cast + # to provide more convenient type annotation + warnings.warn("loop property is deprecated", DeprecationWarning, stacklevel=2) + return cast(asyncio.AbstractEventLoop, self._loop) + + def _set_loop(self, loop: Optional[asyncio.AbstractEventLoop]) -> None: + if loop is None: + loop = asyncio.get_event_loop() + if self._loop is not None and self._loop is not loop: + raise RuntimeError( + "web.Application instance initialized with different loop" + ) + + self._loop = loop + + # set loop debug + if self._debug is ...: + self._debug = loop.get_debug() + + # set loop to sub applications + for subapp in self._subapps: + subapp._set_loop(loop) + + @property + def pre_frozen(self) -> bool: + return self._pre_frozen + + def pre_freeze(self) -> None: + if self._pre_frozen: + return + + self._pre_frozen = True + self._middlewares.freeze() + self._router.freeze() + self._on_response_prepare.freeze() + self._cleanup_ctx.freeze() + self._on_startup.freeze() + self._on_shutdown.freeze() + self._on_cleanup.freeze() + self._middlewares_handlers = tuple(self._prepare_middleware()) + + # If current app and any subapp do not have middlewares avoid run all + # of the code footprint that it implies, which have a middleware + # hardcoded per app that sets up the current_app attribute. If no + # middlewares are configured the handler will receive the proper + # current_app without needing all of this code. + self._run_middlewares = True if self.middlewares else False + + for subapp in self._subapps: + subapp.pre_freeze() + self._run_middlewares = self._run_middlewares or subapp._run_middlewares + + @property + def frozen(self) -> bool: + return self._frozen + + def freeze(self) -> None: + if self._frozen: + return + + self.pre_freeze() + self._frozen = True + for subapp in self._subapps: + subapp.freeze() + + @property + def debug(self) -> bool: + warnings.warn("debug property is deprecated", DeprecationWarning, stacklevel=2) + return self._debug # type: ignore[no-any-return] + + def _reg_subapp_signals(self, subapp: "Application") -> None: + def reg_handler(signame: str) -> None: + subsig = getattr(subapp, signame) + + async def handler(app: "Application") -> None: + await subsig.send(subapp) + + appsig = getattr(self, signame) + appsig.append(handler) + + reg_handler("on_startup") + reg_handler("on_shutdown") + reg_handler("on_cleanup") + + def add_subapp(self, prefix: str, subapp: "Application") -> AbstractResource: + if not isinstance(prefix, str): + raise TypeError("Prefix must be str") + prefix = prefix.rstrip("/") + if not prefix: + raise ValueError("Prefix cannot be empty") + factory = partial(PrefixedSubAppResource, prefix, subapp) + return self._add_subapp(factory, subapp) + + def _add_subapp( + self, resource_factory: Callable[[], AbstractResource], subapp: "Application" + ) -> AbstractResource: + if self.frozen: + raise RuntimeError("Cannot add sub application to frozen application") + if subapp.frozen: + raise RuntimeError("Cannot add frozen application") + resource = resource_factory() + self.router.register_resource(resource) + self._reg_subapp_signals(subapp) + self._subapps.append(subapp) + subapp.pre_freeze() + if self._loop is not None: + subapp._set_loop(self._loop) + return resource + + def add_domain(self, domain: str, subapp: 
"Application") -> AbstractResource: + if not isinstance(domain, str): + raise TypeError("Domain must be str") + elif "*" in domain: + rule: Domain = MaskDomain(domain) + else: + rule = Domain(domain) + factory = partial(MatchedSubAppResource, rule, subapp) + return self._add_subapp(factory, subapp) + + def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]: + return self.router.add_routes(routes) + + @property + def on_response_prepare(self) -> _RespPrepareSignal: + return self._on_response_prepare + + @property + def on_startup(self) -> _AppSignal: + return self._on_startup + + @property + def on_shutdown(self) -> _AppSignal: + return self._on_shutdown + + @property + def on_cleanup(self) -> _AppSignal: + return self._on_cleanup + + @property + def cleanup_ctx(self) -> "CleanupContext": + return self._cleanup_ctx + + @property + def router(self) -> UrlDispatcher: + return self._router + + @property + def middlewares(self) -> _Middlewares: + return self._middlewares + + def _make_handler( + self, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + **kwargs: Any, + ) -> Server: + + if not issubclass(access_log_class, AbstractAccessLogger): + raise TypeError( + "access_log_class must be subclass of " + "aiohttp.abc.AbstractAccessLogger, got {}".format(access_log_class) + ) + + self._set_loop(loop) + self.freeze() + + kwargs["debug"] = self._debug + kwargs["access_log_class"] = access_log_class + if self._handler_args: + for k, v in self._handler_args.items(): + kwargs[k] = v + + return Server( + self._handle, # type: ignore[arg-type] + request_factory=self._make_request, + loop=self._loop, + **kwargs, + ) + + def make_handler( + self, + *, + loop: Optional[asyncio.AbstractEventLoop] = None, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + **kwargs: Any, + ) -> Server: + + warnings.warn( + "Application.make_handler(...) is deprecated, " "use AppRunner API instead", + DeprecationWarning, + stacklevel=2, + ) + + return self._make_handler( + loop=loop, access_log_class=access_log_class, **kwargs + ) + + async def startup(self) -> None: + """Causes on_startup signal + + Should be called in the event loop along with the request handler. + """ + await self.on_startup.send(self) + + async def shutdown(self) -> None: + """Causes on_shutdown signal + + Should be called before cleanup() + """ + await self.on_shutdown.send(self) + + async def cleanup(self) -> None: + """Causes on_cleanup signal + + Should be called after shutdown() + """ + if self.on_cleanup.frozen: + await self.on_cleanup.send(self) + else: + # If an exception occurs in startup, ensure cleanup contexts are completed. 
+ await self._cleanup_ctx._on_cleanup(self) + + def _make_request( + self, + message: RawRequestMessage, + payload: StreamReader, + protocol: RequestHandler, + writer: AbstractStreamWriter, + task: "asyncio.Task[None]", + _cls: Type[Request] = Request, + ) -> Request: + return _cls( + message, + payload, + protocol, + writer, + task, + self._loop, + client_max_size=self._client_max_size, + ) + + def _prepare_middleware(self) -> Iterator[Tuple[Middleware, bool]]: + for m in reversed(self._middlewares): + if getattr(m, "__middleware_version__", None) == 1: + yield m, True + else: + warnings.warn( + 'old-style middleware "{!r}" deprecated, ' "see #2252".format(m), + DeprecationWarning, + stacklevel=2, + ) + yield m, False + + yield _fix_request_current_app(self), True + + async def _handle(self, request: Request) -> StreamResponse: + loop = asyncio.get_event_loop() + debug = loop.get_debug() + match_info = await self._router.resolve(request) + if debug: # pragma: no cover + if not isinstance(match_info, AbstractMatchInfo): + raise TypeError( + "match_info should be AbstractMatchInfo " + "instance, not {!r}".format(match_info) + ) + match_info.add_app(self) + + match_info.freeze() + + resp = None + request._match_info = match_info + expect = request.headers.get(hdrs.EXPECT) + if expect: + resp = await match_info.expect_handler(request) + await request.writer.drain() + + if resp is None: + handler = match_info.handler + + if self._run_middlewares: + for app in match_info.apps[::-1]: + for m, new_style in app._middlewares_handlers: # type: ignore[union-attr] + if new_style: + handler = update_wrapper( + partial(m, handler=handler), handler + ) + else: + handler = await m(app, handler) # type: ignore[arg-type,assignment] + + resp = await handler(request) + + return resp + + def __call__(self) -> "Application": + """gunicorn compatibility""" + return self + + def __repr__(self) -> str: + return f"" + + def __bool__(self) -> bool: + return True + + +class CleanupError(RuntimeError): + @property + def exceptions(self) -> List[BaseException]: + return cast(List[BaseException], self.args[1]) + + +if TYPE_CHECKING: + _CleanupContextBase = FrozenList[Callable[[Application], AsyncIterator[None]]] +else: + _CleanupContextBase = FrozenList + + +class CleanupContext(_CleanupContextBase): + def __init__(self) -> None: + super().__init__() + self._exits: List[AsyncIterator[None]] = [] + + async def _on_startup(self, app: Application) -> None: + for cb in self: + it = cb(app).__aiter__() + await it.__anext__() + self._exits.append(it) + + async def _on_cleanup(self, app: Application) -> None: + errors = [] + for it in reversed(self._exits): + try: + await it.__anext__() + except StopAsyncIteration: + pass + except Exception as exc: + errors.append(exc) + else: + errors.append(RuntimeError(f"{it!r} has more than one 'yield'")) + if errors: + if len(errors) == 1: + raise errors[0] + else: + raise CleanupError("Multiple errors on cleanup stage", errors) diff --git a/llm/Lib/site-packages/aiohttp/web_exceptions.py b/llm/Lib/site-packages/aiohttp/web_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..49184109378542202a441366d324c50c0190705e --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_exceptions.py @@ -0,0 +1,452 @@ +import warnings +from typing import Any, Dict, Iterable, List, Optional, Set # noqa + +from yarl import URL + +from .typedefs import LooseHeaders, StrOrURL +from .web_response import Response + +__all__ = ( + "HTTPException", + "HTTPError", + "HTTPRedirection", 
+ "HTTPSuccessful", + "HTTPOk", + "HTTPCreated", + "HTTPAccepted", + "HTTPNonAuthoritativeInformation", + "HTTPNoContent", + "HTTPResetContent", + "HTTPPartialContent", + "HTTPMove", + "HTTPMultipleChoices", + "HTTPMovedPermanently", + "HTTPFound", + "HTTPSeeOther", + "HTTPNotModified", + "HTTPUseProxy", + "HTTPTemporaryRedirect", + "HTTPPermanentRedirect", + "HTTPClientError", + "HTTPBadRequest", + "HTTPUnauthorized", + "HTTPPaymentRequired", + "HTTPForbidden", + "HTTPNotFound", + "HTTPMethodNotAllowed", + "HTTPNotAcceptable", + "HTTPProxyAuthenticationRequired", + "HTTPRequestTimeout", + "HTTPConflict", + "HTTPGone", + "HTTPLengthRequired", + "HTTPPreconditionFailed", + "HTTPRequestEntityTooLarge", + "HTTPRequestURITooLong", + "HTTPUnsupportedMediaType", + "HTTPRequestRangeNotSatisfiable", + "HTTPExpectationFailed", + "HTTPMisdirectedRequest", + "HTTPUnprocessableEntity", + "HTTPFailedDependency", + "HTTPUpgradeRequired", + "HTTPPreconditionRequired", + "HTTPTooManyRequests", + "HTTPRequestHeaderFieldsTooLarge", + "HTTPUnavailableForLegalReasons", + "HTTPServerError", + "HTTPInternalServerError", + "HTTPNotImplemented", + "HTTPBadGateway", + "HTTPServiceUnavailable", + "HTTPGatewayTimeout", + "HTTPVersionNotSupported", + "HTTPVariantAlsoNegotiates", + "HTTPInsufficientStorage", + "HTTPNotExtended", + "HTTPNetworkAuthenticationRequired", +) + + +class NotAppKeyWarning(UserWarning): + """Warning when not using AppKey in Application.""" + + +############################################################ +# HTTP Exceptions +############################################################ + + +class HTTPException(Response, Exception): + + # You should set in subclasses: + # status = 200 + + status_code = -1 + empty_body = False + + __http_exception__ = True + + def __init__( + self, + *, + headers: Optional[LooseHeaders] = None, + reason: Optional[str] = None, + body: Any = None, + text: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: + if body is not None: + warnings.warn( + "body argument is deprecated for http web exceptions", + DeprecationWarning, + ) + Response.__init__( + self, + status=self.status_code, + headers=headers, + reason=reason, + body=body, + text=text, + content_type=content_type, + ) + Exception.__init__(self, self.reason) + if self.body is None and not self.empty_body: + self.text = f"{self.status}: {self.reason}" + + def __bool__(self) -> bool: + return True + + +class HTTPError(HTTPException): + """Base class for exceptions with status codes in the 400s and 500s.""" + + +class HTTPRedirection(HTTPException): + """Base class for exceptions with status codes in the 300s.""" + + +class HTTPSuccessful(HTTPException): + """Base class for exceptions with status codes in the 200s.""" + + +class HTTPOk(HTTPSuccessful): + status_code = 200 + + +class HTTPCreated(HTTPSuccessful): + status_code = 201 + + +class HTTPAccepted(HTTPSuccessful): + status_code = 202 + + +class HTTPNonAuthoritativeInformation(HTTPSuccessful): + status_code = 203 + + +class HTTPNoContent(HTTPSuccessful): + status_code = 204 + empty_body = True + + +class HTTPResetContent(HTTPSuccessful): + status_code = 205 + empty_body = True + + +class HTTPPartialContent(HTTPSuccessful): + status_code = 206 + + +############################################################ +# 3xx redirection +############################################################ + + +class HTTPMove(HTTPRedirection): + def __init__( + self, + location: StrOrURL, + *, + headers: Optional[LooseHeaders] = None, + reason: 
Optional[str] = None, + body: Any = None, + text: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: + if not location: + raise ValueError("HTTP redirects need a location to redirect to.") + super().__init__( + headers=headers, + reason=reason, + body=body, + text=text, + content_type=content_type, + ) + self.headers["Location"] = str(URL(location)) + self.location = location + + +class HTTPMultipleChoices(HTTPMove): + status_code = 300 + + +class HTTPMovedPermanently(HTTPMove): + status_code = 301 + + +class HTTPFound(HTTPMove): + status_code = 302 + + +# This one is safe after a POST (the redirected location will be +# retrieved with GET): +class HTTPSeeOther(HTTPMove): + status_code = 303 + + +class HTTPNotModified(HTTPRedirection): + # FIXME: this should include a date or etag header + status_code = 304 + empty_body = True + + +class HTTPUseProxy(HTTPMove): + # Not a move, but looks a little like one + status_code = 305 + + +class HTTPTemporaryRedirect(HTTPMove): + status_code = 307 + + +class HTTPPermanentRedirect(HTTPMove): + status_code = 308 + + +############################################################ +# 4xx client error +############################################################ + + +class HTTPClientError(HTTPError): + pass + + +class HTTPBadRequest(HTTPClientError): + status_code = 400 + + +class HTTPUnauthorized(HTTPClientError): + status_code = 401 + + +class HTTPPaymentRequired(HTTPClientError): + status_code = 402 + + +class HTTPForbidden(HTTPClientError): + status_code = 403 + + +class HTTPNotFound(HTTPClientError): + status_code = 404 + + +class HTTPMethodNotAllowed(HTTPClientError): + status_code = 405 + + def __init__( + self, + method: str, + allowed_methods: Iterable[str], + *, + headers: Optional[LooseHeaders] = None, + reason: Optional[str] = None, + body: Any = None, + text: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: + allow = ",".join(sorted(allowed_methods)) + super().__init__( + headers=headers, + reason=reason, + body=body, + text=text, + content_type=content_type, + ) + self.headers["Allow"] = allow + self.allowed_methods: Set[str] = set(allowed_methods) + self.method = method.upper() + + +class HTTPNotAcceptable(HTTPClientError): + status_code = 406 + + +class HTTPProxyAuthenticationRequired(HTTPClientError): + status_code = 407 + + +class HTTPRequestTimeout(HTTPClientError): + status_code = 408 + + +class HTTPConflict(HTTPClientError): + status_code = 409 + + +class HTTPGone(HTTPClientError): + status_code = 410 + + +class HTTPLengthRequired(HTTPClientError): + status_code = 411 + + +class HTTPPreconditionFailed(HTTPClientError): + status_code = 412 + + +class HTTPRequestEntityTooLarge(HTTPClientError): + status_code = 413 + + def __init__(self, max_size: float, actual_size: float, **kwargs: Any) -> None: + kwargs.setdefault( + "text", + "Maximum request body size {} exceeded, " + "actual body size {}".format(max_size, actual_size), + ) + super().__init__(**kwargs) + + +class HTTPRequestURITooLong(HTTPClientError): + status_code = 414 + + +class HTTPUnsupportedMediaType(HTTPClientError): + status_code = 415 + + +class HTTPRequestRangeNotSatisfiable(HTTPClientError): + status_code = 416 + + +class HTTPExpectationFailed(HTTPClientError): + status_code = 417 + + +class HTTPMisdirectedRequest(HTTPClientError): + status_code = 421 + + +class HTTPUnprocessableEntity(HTTPClientError): + status_code = 422 + + +class HTTPFailedDependency(HTTPClientError): + status_code = 424 + + +class 
HTTPUpgradeRequired(HTTPClientError): + status_code = 426 + + +class HTTPPreconditionRequired(HTTPClientError): + status_code = 428 + + +class HTTPTooManyRequests(HTTPClientError): + status_code = 429 + + +class HTTPRequestHeaderFieldsTooLarge(HTTPClientError): + status_code = 431 + + +class HTTPUnavailableForLegalReasons(HTTPClientError): + status_code = 451 + + def __init__( + self, + link: Optional[StrOrURL], + *, + headers: Optional[LooseHeaders] = None, + reason: Optional[str] = None, + body: Any = None, + text: Optional[str] = None, + content_type: Optional[str] = None, + ) -> None: + super().__init__( + headers=headers, + reason=reason, + body=body, + text=text, + content_type=content_type, + ) + self._link = None + if link: + self._link = URL(link) + self.headers["Link"] = f'<{str(self._link)}>; rel="blocked-by"' + + @property + def link(self) -> Optional[URL]: + return self._link + + +############################################################ +# 5xx Server Error +############################################################ +# Response status codes beginning with the digit "5" indicate cases in +# which the server is aware that it has erred or is incapable of +# performing the request. Except when responding to a HEAD request, the +# server SHOULD include an entity containing an explanation of the error +# situation, and whether it is a temporary or permanent condition. User +# agents SHOULD display any included entity to the user. These response +# codes are applicable to any request method. + + +class HTTPServerError(HTTPError): + pass + + +class HTTPInternalServerError(HTTPServerError): + status_code = 500 + + +class HTTPNotImplemented(HTTPServerError): + status_code = 501 + + +class HTTPBadGateway(HTTPServerError): + status_code = 502 + + +class HTTPServiceUnavailable(HTTPServerError): + status_code = 503 + + +class HTTPGatewayTimeout(HTTPServerError): + status_code = 504 + + +class HTTPVersionNotSupported(HTTPServerError): + status_code = 505 + + +class HTTPVariantAlsoNegotiates(HTTPServerError): + status_code = 506 + + +class HTTPInsufficientStorage(HTTPServerError): + status_code = 507 + + +class HTTPNotExtended(HTTPServerError): + status_code = 510 + + +class HTTPNetworkAuthenticationRequired(HTTPServerError): + status_code = 511 diff --git a/llm/Lib/site-packages/aiohttp/web_fileresponse.py b/llm/Lib/site-packages/aiohttp/web_fileresponse.py new file mode 100644 index 0000000000000000000000000000000000000000..b167eedbc4966a7087b2e2e374ee72189e59abb5 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_fileresponse.py @@ -0,0 +1,305 @@ +import asyncio +import mimetypes +import os +import pathlib +from typing import ( # noqa + IO, + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Final, + Iterator, + List, + Optional, + Tuple, + Union, + cast, +) + +from . 
import hdrs +from .abc import AbstractStreamWriter +from .helpers import ETAG_ANY, ETag, must_be_empty_body +from .typedefs import LooseHeaders, PathLike +from .web_exceptions import ( + HTTPNotModified, + HTTPPartialContent, + HTTPPreconditionFailed, + HTTPRequestRangeNotSatisfiable, +) +from .web_response import StreamResponse + +__all__ = ("FileResponse",) + +if TYPE_CHECKING: + from .web_request import BaseRequest + + +_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]] + + +NOSENDFILE: Final[bool] = bool(os.environ.get("AIOHTTP_NOSENDFILE")) + + +class FileResponse(StreamResponse): + """A response object can be used to send files.""" + + def __init__( + self, + path: PathLike, + chunk_size: int = 256 * 1024, + status: int = 200, + reason: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + ) -> None: + super().__init__(status=status, reason=reason, headers=headers) + + self._path = pathlib.Path(path) + self._chunk_size = chunk_size + + async def _sendfile_fallback( + self, writer: AbstractStreamWriter, fobj: IO[Any], offset: int, count: int + ) -> AbstractStreamWriter: + # To keep memory usage low,fobj is transferred in chunks + # controlled by the constructor's chunk_size argument. + + chunk_size = self._chunk_size + loop = asyncio.get_event_loop() + + await loop.run_in_executor(None, fobj.seek, offset) + + chunk = await loop.run_in_executor(None, fobj.read, chunk_size) + while chunk: + await writer.write(chunk) + count = count - chunk_size + if count <= 0: + break + chunk = await loop.run_in_executor(None, fobj.read, min(chunk_size, count)) + + await writer.drain() + return writer + + async def _sendfile( + self, request: "BaseRequest", fobj: IO[Any], offset: int, count: int + ) -> AbstractStreamWriter: + writer = await super().prepare(request) + assert writer is not None + + if NOSENDFILE or self.compression: + return await self._sendfile_fallback(writer, fobj, offset, count) + + loop = request._loop + transport = request.transport + assert transport is not None + + try: + await loop.sendfile(transport, fobj, offset, count) + except NotImplementedError: + return await self._sendfile_fallback(writer, fobj, offset, count) + + await super().write_eof() + return writer + + @staticmethod + def _strong_etag_match(etag_value: str, etags: Tuple[ETag, ...]) -> bool: + if len(etags) == 1 and etags[0].value == ETAG_ANY: + return True + return any(etag.value == etag_value for etag in etags if not etag.is_weak) + + async def _not_modified( + self, request: "BaseRequest", etag_value: str, last_modified: float + ) -> Optional[AbstractStreamWriter]: + self.set_status(HTTPNotModified.status_code) + self._length_check = False + self.etag = etag_value # type: ignore[assignment] + self.last_modified = last_modified # type: ignore[assignment] + # Delete any Content-Length headers provided by user. HTTP 304 + # should always have empty response body + return await super().prepare(request) + + async def _precondition_failed( + self, request: "BaseRequest" + ) -> Optional[AbstractStreamWriter]: + self.set_status(HTTPPreconditionFailed.status_code) + self.content_length = 0 + return await super().prepare(request) + + def _get_file_path_stat_and_gzip( + self, check_for_gzipped_file: bool + ) -> Tuple[pathlib.Path, os.stat_result, bool]: + """Return the file path, stat result, and gzip status. + + This method should be called from a thread executor + since it calls os.stat which may block. 
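+        (prepare() below dispatches this method through loop.run_in_executor for exactly that reason.)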
+ """ + filepath = self._path + if check_for_gzipped_file: + gzip_path = filepath.with_name(filepath.name + ".gz") + try: + return gzip_path, gzip_path.stat(), True + except OSError: + # Fall through and try the non-gzipped file + pass + + return filepath, filepath.stat(), False + + async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: + loop = asyncio.get_event_loop() + # Encoding comparisons should be case-insensitive + # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 + check_for_gzipped_file = ( + "gzip" in request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() + ) + filepath, st, gzip = await loop.run_in_executor( + None, self._get_file_path_stat_and_gzip, check_for_gzipped_file + ) + + etag_value = f"{st.st_mtime_ns:x}-{st.st_size:x}" + last_modified = st.st_mtime + + # https://tools.ietf.org/html/rfc7232#section-6 + ifmatch = request.if_match + if ifmatch is not None and not self._strong_etag_match(etag_value, ifmatch): + return await self._precondition_failed(request) + + unmodsince = request.if_unmodified_since + if ( + unmodsince is not None + and ifmatch is None + and st.st_mtime > unmodsince.timestamp() + ): + return await self._precondition_failed(request) + + ifnonematch = request.if_none_match + if ifnonematch is not None and self._strong_etag_match(etag_value, ifnonematch): + return await self._not_modified(request, etag_value, last_modified) + + modsince = request.if_modified_since + if ( + modsince is not None + and ifnonematch is None + and st.st_mtime <= modsince.timestamp() + ): + return await self._not_modified(request, etag_value, last_modified) + + if hdrs.CONTENT_TYPE not in self.headers: + ct, encoding = mimetypes.guess_type(str(filepath)) + if not ct: + ct = "application/octet-stream" + should_set_ct = True + else: + encoding = "gzip" if gzip else None + should_set_ct = False + + status = self._status + file_size = st.st_size + count = file_size + + start = None + + ifrange = request.if_range + if ifrange is None or st.st_mtime <= ifrange.timestamp(): + # If-Range header check: + # condition = cached date >= last modification date + # return 206 if True else 200. + # if False: + # Range header would not be processed, return 200 + # if True but Range header missing + # return 200 + try: + rng = request.http_range + start = rng.start + end = rng.stop + except ValueError: + # https://tools.ietf.org/html/rfc7233: + # A server generating a 416 (Range Not Satisfiable) response to + # a byte-range request SHOULD send a Content-Range header field + # with an unsatisfied-range value. + # The complete-length in a 416 response indicates the current + # length of the selected representation. + # + # Will do the same below. 
Many servers ignore this and do not + # send a Content-Range header with HTTP 416 + self.headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}" + self.set_status(HTTPRequestRangeNotSatisfiable.status_code) + return await super().prepare(request) + + # If a range request has been made, convert start, end slice + # notation into file pointer offset and count + if start is not None or end is not None: + if start < 0 and end is None: # return tail of file + start += file_size + if start < 0: + # if Range:bytes=-1000 in request header but file size + # is only 200, there would be trouble without this + start = 0 + count = file_size - start + else: + # rfc7233:If the last-byte-pos value is + # absent, or if the value is greater than or equal to + # the current length of the representation data, + # the byte range is interpreted as the remainder + # of the representation (i.e., the server replaces the + # value of last-byte-pos with a value that is one less than + # the current length of the selected representation). + count = ( + min(end if end is not None else file_size, file_size) - start + ) + + if start >= file_size: + # HTTP 416 should be returned in this case. + # + # According to https://tools.ietf.org/html/rfc7233: + # If a valid byte-range-set includes at least one + # byte-range-spec with a first-byte-pos that is less than + # the current length of the representation, or at least one + # suffix-byte-range-spec with a non-zero suffix-length, + # then the byte-range-set is satisfiable. Otherwise, the + # byte-range-set is unsatisfiable. + self.headers[hdrs.CONTENT_RANGE] = f"bytes */{file_size}" + self.set_status(HTTPRequestRangeNotSatisfiable.status_code) + return await super().prepare(request) + + status = HTTPPartialContent.status_code + # Even though you are sending the whole file, you should still + # return a HTTP 206 for a Range request. + self.set_status(status) + + if should_set_ct: + self.content_type = ct # type: ignore[assignment] + if encoding: + self.headers[hdrs.CONTENT_ENCODING] = encoding + if gzip: + self.headers[hdrs.VARY] = hdrs.ACCEPT_ENCODING + # Disable compression if we are already sending + # a compressed file since we don't want to double + # compress. + self._compression = False + + self.etag = etag_value # type: ignore[assignment] + self.last_modified = st.st_mtime # type: ignore[assignment] + self.content_length = count + + self.headers[hdrs.ACCEPT_RANGES] = "bytes" + + real_start = cast(int, start) + + if status == HTTPPartialContent.status_code: + self.headers[hdrs.CONTENT_RANGE] = "bytes {}-{}/{}".format( + real_start, real_start + count - 1, file_size + ) + + # If we are sending 0 bytes calling sendfile() will throw a ValueError + if count == 0 or must_be_empty_body(request.method, self.status): + return await super().prepare(request) + + fobj = await loop.run_in_executor(None, filepath.open, "rb") + if start: # be aware that start could be None or int=0 here. 
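+            # Illustrative sketch (not part of the upstream source): a minimal handler
+            # serving this response, assuming a file "data.bin" larger than 200 bytes:
+            #
+            #     async def handler(request):
+            #         return FileResponse("data.bin")
+            #
+            #     app.router.add_get("/data", handler)
+            #
+            # A request carrying "Range: bytes=100-199" is then answered by the branch
+            # above with 206 Partial Content, Content-Length: 100 and
+            # "Content-Range: bytes 100-199/<file size>".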
+ offset = start + else: + offset = 0 + + try: + return await self._sendfile(request, fobj, offset, count) + finally: + await asyncio.shield(loop.run_in_executor(None, fobj.close)) diff --git a/llm/Lib/site-packages/aiohttp/web_log.py b/llm/Lib/site-packages/aiohttp/web_log.py new file mode 100644 index 0000000000000000000000000000000000000000..ced0276955e511d69b26814a98f53ca2612f6795 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_log.py @@ -0,0 +1,213 @@ +import datetime +import functools +import logging +import os +import re +import time as time_mod +from collections import namedtuple +from typing import Any, Callable, Dict, Iterable, List, Tuple # noqa + +from .abc import AbstractAccessLogger +from .web_request import BaseRequest +from .web_response import StreamResponse + +KeyMethod = namedtuple("KeyMethod", "key method") + + +class AccessLogger(AbstractAccessLogger): + """Helper object to log access. + + Usage: + log = logging.getLogger("spam") + log_format = "%a %{User-Agent}i" + access_logger = AccessLogger(log, log_format) + access_logger.log(request, response, time) + + Format: + %% The percent sign + %a Remote IP-address (IP-address of proxy if using reverse proxy) + %t Time when the request was started to process + %P The process ID of the child that serviced the request + %r First line of request + %s Response status code + %b Size of response in bytes, including HTTP headers + %T Time taken to serve the request, in seconds + %Tf Time taken to serve the request, in seconds with floating fraction + in .06f format + %D Time taken to serve the request, in microseconds + %{FOO}i request.headers['FOO'] + %{FOO}o response.headers['FOO'] + %{FOO}e os.environ['FOO'] + + """ + + LOG_FORMAT_MAP = { + "a": "remote_address", + "t": "request_start_time", + "P": "process_id", + "r": "first_request_line", + "s": "response_status", + "b": "response_size", + "T": "request_time", + "Tf": "request_time_frac", + "D": "request_time_micro", + "i": "request_header", + "o": "response_header", + } + + LOG_FORMAT = '%a %t "%r" %s %b "%{Referer}i" "%{User-Agent}i"' + FORMAT_RE = re.compile(r"%(\{([A-Za-z0-9\-_]+)\}([ioe])|[atPrsbOD]|Tf?)") + CLEANUP_RE = re.compile(r"(%[^s])") + _FORMAT_CACHE: Dict[str, Tuple[str, List[KeyMethod]]] = {} + + def __init__(self, logger: logging.Logger, log_format: str = LOG_FORMAT) -> None: + """Initialise the logger. + + logger is a logger object to be used for logging. + log_format is a string with apache compatible log format description. + + """ + super().__init__(logger, log_format=log_format) + + _compiled_format = AccessLogger._FORMAT_CACHE.get(log_format) + if not _compiled_format: + _compiled_format = self.compile_format(log_format) + AccessLogger._FORMAT_CACHE[log_format] = _compiled_format + + self._log_format, self._methods = _compiled_format + + def compile_format(self, log_format: str) -> Tuple[str, List[KeyMethod]]: + """Translate log_format into form usable by modulo formatting + + All known atoms will be replaced with %s + Also methods for formatting of those atoms will be added to + _methods in appropriate order + + For example we have log_format = "%a %t" + This format will be translated to "%s %s" + Also contents of _methods will be + [self._format_a, self._format_t] + These method will be called and results will be passed + to translated string format. 
+ + Each _format_* method receive 'args' which is list of arguments + given to self.log + + Exceptions are _format_e, _format_i and _format_o methods which + also receive key name (by functools.partial) + + """ + # list of (key, method) tuples, we don't use an OrderedDict as users + # can repeat the same key more than once + methods = list() + + for atom in self.FORMAT_RE.findall(log_format): + if atom[1] == "": + format_key1 = self.LOG_FORMAT_MAP[atom[0]] + m = getattr(AccessLogger, "_format_%s" % atom[0]) + key_method = KeyMethod(format_key1, m) + else: + format_key2 = (self.LOG_FORMAT_MAP[atom[2]], atom[1]) + m = getattr(AccessLogger, "_format_%s" % atom[2]) + key_method = KeyMethod(format_key2, functools.partial(m, atom[1])) + + methods.append(key_method) + + log_format = self.FORMAT_RE.sub(r"%s", log_format) + log_format = self.CLEANUP_RE.sub(r"%\1", log_format) + return log_format, methods + + @staticmethod + def _format_i( + key: str, request: BaseRequest, response: StreamResponse, time: float + ) -> str: + if request is None: + return "(no headers)" + + # suboptimal, make istr(key) once + return request.headers.get(key, "-") + + @staticmethod + def _format_o( + key: str, request: BaseRequest, response: StreamResponse, time: float + ) -> str: + # suboptimal, make istr(key) once + return response.headers.get(key, "-") + + @staticmethod + def _format_a(request: BaseRequest, response: StreamResponse, time: float) -> str: + if request is None: + return "-" + ip = request.remote + return ip if ip is not None else "-" + + @staticmethod + def _format_t(request: BaseRequest, response: StreamResponse, time: float) -> str: + tz = datetime.timezone(datetime.timedelta(seconds=-time_mod.timezone)) + now = datetime.datetime.now(tz) + start_time = now - datetime.timedelta(seconds=time) + return start_time.strftime("[%d/%b/%Y:%H:%M:%S %z]") + + @staticmethod + def _format_P(request: BaseRequest, response: StreamResponse, time: float) -> str: + return "<%s>" % os.getpid() + + @staticmethod + def _format_r(request: BaseRequest, response: StreamResponse, time: float) -> str: + if request is None: + return "-" + return "{} {} HTTP/{}.{}".format( + request.method, + request.path_qs, + request.version.major, + request.version.minor, + ) + + @staticmethod + def _format_s(request: BaseRequest, response: StreamResponse, time: float) -> int: + return response.status + + @staticmethod + def _format_b(request: BaseRequest, response: StreamResponse, time: float) -> int: + return response.body_length + + @staticmethod + def _format_T(request: BaseRequest, response: StreamResponse, time: float) -> str: + return str(round(time)) + + @staticmethod + def _format_Tf(request: BaseRequest, response: StreamResponse, time: float) -> str: + return "%06f" % time + + @staticmethod + def _format_D(request: BaseRequest, response: StreamResponse, time: float) -> str: + return str(round(time * 1000000)) + + def _format_line( + self, request: BaseRequest, response: StreamResponse, time: float + ) -> Iterable[Tuple[str, Callable[[BaseRequest, StreamResponse, float], str]]]: + return [(key, method(request, response, time)) for key, method in self._methods] + + def log(self, request: BaseRequest, response: StreamResponse, time: float) -> None: + if not self.logger.isEnabledFor(logging.INFO): + # Avoid formatting the log line if it will not be emitted. 
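+        # Illustrative sketch (not part of the upstream source): the format string
+        # documented on this class is normally supplied when starting the app, e.g.
+        #
+        #     from aiohttp import web
+        #     app = web.Application()
+        #     web.run_app(app, access_log_format='%a %t "%r" %s %b "%{User-Agent}i"')
+        #
+        # Header atoms such as %{User-Agent}i are resolved by _format_i above.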
+ return + try: + fmt_info = self._format_line(request, response, time) + + values = list() + extra = dict() + for key, value in fmt_info: + values.append(value) + + if key.__class__ is str: + extra[key] = value + else: + k1, k2 = key # type: ignore[misc] + dct = extra.get(k1, {}) # type: ignore[var-annotated,has-type] + dct[k2] = value # type: ignore[index,has-type] + extra[k1] = dct # type: ignore[has-type,assignment] + + self.logger.info(self._log_format % tuple(values), extra=extra) + except Exception: + self.logger.exception("Error in logging") diff --git a/llm/Lib/site-packages/aiohttp/web_middlewares.py b/llm/Lib/site-packages/aiohttp/web_middlewares.py new file mode 100644 index 0000000000000000000000000000000000000000..b2ab1e3f309c6b2f5e6abafdbb8cc654422dad33 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_middlewares.py @@ -0,0 +1,116 @@ +import re +from typing import TYPE_CHECKING, Tuple, Type, TypeVar + +from .typedefs import Handler, Middleware +from .web_exceptions import HTTPMove, HTTPPermanentRedirect +from .web_request import Request +from .web_response import StreamResponse +from .web_urldispatcher import SystemRoute + +__all__ = ( + "middleware", + "normalize_path_middleware", +) + +if TYPE_CHECKING: + from .web_app import Application + +_Func = TypeVar("_Func") + + +async def _check_request_resolves(request: Request, path: str) -> Tuple[bool, Request]: + alt_request = request.clone(rel_url=path) + + match_info = await request.app.router.resolve(alt_request) + alt_request._match_info = match_info + + if match_info.http_exception is None: + return True, alt_request + + return False, request + + +def middleware(f: _Func) -> _Func: + f.__middleware_version__ = 1 # type: ignore[attr-defined] + return f + + +def normalize_path_middleware( + *, + append_slash: bool = True, + remove_slash: bool = False, + merge_slashes: bool = True, + redirect_class: Type[HTTPMove] = HTTPPermanentRedirect, +) -> Middleware: + """Factory for producing a middleware that normalizes the path of a request. + + Normalizing means: + - Add or remove a trailing slash to the path. + - Double slashes are replaced by one. + + The middleware returns as soon as it finds a path that resolves + correctly. The order if both merge and append/remove are enabled is + 1) merge slashes + 2) append/remove slash + 3) both merge slashes and append/remove slash. + If the path resolves with at least one of those conditions, it will + redirect to the new path. + + Only one of `append_slash` and `remove_slash` can be enabled. If both + are `True` the factory will raise an assertion error + + If `append_slash` is `True` the middleware will append a slash when + needed. If a resource is defined with trailing slash and the request + comes without it, it will append it automatically. + + If `remove_slash` is `True`, `append_slash` must be `False`. When enabled + the middleware will remove trailing slashes and redirect if the resource + is defined + + If merge_slashes is True, merge multiple consecutive slashes in the + path into one. + """ + correct_configuration = not (append_slash and remove_slash) + assert correct_configuration, "Cannot both remove and append slash" + + @middleware + async def impl(request: Request, handler: Handler) -> StreamResponse: + if isinstance(request.match_info.route, SystemRoute): + paths_to_check = [] + if "?" in request.raw_path: + path, query = request.raw_path.split("?", 1) + query = "?" 
+ query + else: + query = "" + path = request.raw_path + + if merge_slashes: + paths_to_check.append(re.sub("//+", "/", path)) + if append_slash and not request.path.endswith("/"): + paths_to_check.append(path + "/") + if remove_slash and request.path.endswith("/"): + paths_to_check.append(path[:-1]) + if merge_slashes and append_slash: + paths_to_check.append(re.sub("//+", "/", path + "/")) + if merge_slashes and remove_slash: + merged_slashes = re.sub("//+", "/", path) + paths_to_check.append(merged_slashes[:-1]) + + for path in paths_to_check: + path = re.sub("^//+", "/", path) # SECURITY: GHSA-v6wp-4m6f-gcjg + resolves, request = await _check_request_resolves(request, path) + if resolves: + raise redirect_class(request.raw_path + query) + + return await handler(request) + + return impl + + +def _fix_request_current_app(app: "Application") -> Middleware: + @middleware + async def impl(request: Request, handler: Handler) -> StreamResponse: + with request.match_info.set_current_app(app): + return await handler(request) + + return impl diff --git a/llm/Lib/site-packages/aiohttp/web_protocol.py b/llm/Lib/site-packages/aiohttp/web_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..9b6213ce6c6cc7eb984eadd8791c42cb004cdf8d --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_protocol.py @@ -0,0 +1,698 @@ +import asyncio +import asyncio.streams +import traceback +import warnings +from collections import deque +from contextlib import suppress +from html import escape as html_escape +from http import HTTPStatus +from logging import Logger +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Deque, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, +) + +import attr +import yarl + +from .abc import AbstractAccessLogger, AbstractStreamWriter +from .base_protocol import BaseProtocol +from .helpers import ceil_timeout, set_exception +from .http import ( + HttpProcessingError, + HttpRequestParser, + HttpVersion10, + RawRequestMessage, + StreamWriter, +) +from .log import access_logger, server_logger +from .streams import EMPTY_PAYLOAD, StreamReader +from .tcp_helpers import tcp_keepalive +from .web_exceptions import HTTPException +from .web_log import AccessLogger +from .web_request import BaseRequest +from .web_response import Response, StreamResponse + +__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError") + +if TYPE_CHECKING: + from .web_server import Server + + +_RequestFactory = Callable[ + [ + RawRequestMessage, + StreamReader, + "RequestHandler", + AbstractStreamWriter, + "asyncio.Task[None]", + ], + BaseRequest, +] + +_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]] + +ERROR = RawRequestMessage( + "UNKNOWN", + "/", + HttpVersion10, + {}, # type: ignore[arg-type] + {}, # type: ignore[arg-type] + True, + None, + False, + False, + yarl.URL("/"), +) + + +class RequestPayloadError(Exception): + """Payload parsing error.""" + + +class PayloadAccessError(Exception): + """Payload was accessed after response was sent.""" + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class _ErrInfo: + status: int + exc: BaseException + message: str + + +_MsgType = Tuple[Union[RawRequestMessage, _ErrInfo], StreamReader] + + +class RequestHandler(BaseProtocol): + """HTTP protocol implementation. + + RequestHandler handles incoming HTTP request. It reads request line, + request headers and request payload and calls handle_request() method. + By default it always returns with 404 response. 
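+    Instances are normally created by the Server factory in web_server.py,
+    which is used as the protocol factory for loop.create_server().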
+ + RequestHandler handles errors in incoming request, like bad + status line, bad headers or incomplete payload. If any error occurs, + connection gets closed. + + keepalive_timeout -- number of seconds before closing + keep-alive connection + + tcp_keepalive -- TCP keep-alive is on, default is on + + debug -- enable debug mode + + logger -- custom logger object + + access_log_class -- custom class for access_logger + + access_log -- custom logging object + + access_log_format -- access log format string + + loop -- Optional event loop + + max_line_size -- Optional maximum header line size + + max_field_size -- Optional maximum header field size + + max_headers -- Optional maximum header size + + timeout_ceil_threshold -- Optional value to specify + threshold to ceil() timeout + values + + """ + + KEEPALIVE_RESCHEDULE_DELAY = 1 + + __slots__ = ( + "_request_count", + "_keepalive", + "_manager", + "_request_handler", + "_request_factory", + "_tcp_keepalive", + "_keepalive_time", + "_keepalive_handle", + "_keepalive_timeout", + "_lingering_time", + "_messages", + "_message_tail", + "_waiter", + "_task_handler", + "_upgrade", + "_payload_parser", + "_request_parser", + "_reading_paused", + "logger", + "debug", + "access_log", + "access_logger", + "_close", + "_force_close", + "_current_request", + "_timeout_ceil_threshold", + ) + + def __init__( + self, + manager: "Server", + *, + loop: asyncio.AbstractEventLoop, + keepalive_timeout: float = 75.0, # NGINX default is 75 secs + tcp_keepalive: bool = True, + logger: Logger = server_logger, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log: Logger = access_logger, + access_log_format: str = AccessLogger.LOG_FORMAT, + debug: bool = False, + max_line_size: int = 8190, + max_headers: int = 32768, + max_field_size: int = 8190, + lingering_time: float = 10.0, + read_bufsize: int = 2**16, + auto_decompress: bool = True, + timeout_ceil_threshold: float = 5, + ): + super().__init__(loop) + + self._request_count = 0 + self._keepalive = False + self._current_request: Optional[BaseRequest] = None + self._manager: Optional[Server] = manager + self._request_handler: Optional[_RequestHandler] = manager.request_handler + self._request_factory: Optional[_RequestFactory] = manager.request_factory + + self._tcp_keepalive = tcp_keepalive + # placeholder to be replaced on keepalive timeout setup + self._keepalive_time = 0.0 + self._keepalive_handle: Optional[asyncio.Handle] = None + self._keepalive_timeout = keepalive_timeout + self._lingering_time = float(lingering_time) + + self._messages: Deque[_MsgType] = deque() + self._message_tail = b"" + + self._waiter: Optional[asyncio.Future[None]] = None + self._task_handler: Optional[asyncio.Task[None]] = None + + self._upgrade = False + self._payload_parser: Any = None + self._request_parser: Optional[HttpRequestParser] = HttpRequestParser( + self, + loop, + read_bufsize, + max_line_size=max_line_size, + max_field_size=max_field_size, + max_headers=max_headers, + payload_exception=RequestPayloadError, + auto_decompress=auto_decompress, + ) + + self._timeout_ceil_threshold: float = 5 + try: + self._timeout_ceil_threshold = float(timeout_ceil_threshold) + except (TypeError, ValueError): + pass + + self.logger = logger + self.debug = debug + self.access_log = access_log + if access_log: + self.access_logger: Optional[AbstractAccessLogger] = access_log_class( + access_log, access_log_format + ) + else: + self.access_logger = None + + self._close = False + self._force_close = False + + def 
__repr__(self) -> str: + return "<{} {}>".format( + self.__class__.__name__, + "connected" if self.transport is not None else "disconnected", + ) + + @property + def keepalive_timeout(self) -> float: + return self._keepalive_timeout + + async def shutdown(self, timeout: Optional[float] = 15.0) -> None: + """Do worker process exit preparations. + + We need to clean up everything and stop accepting requests. + It is especially important for keep-alive connections. + """ + self._force_close = True + + if self._keepalive_handle is not None: + self._keepalive_handle.cancel() + + if self._waiter: + self._waiter.cancel() + + # wait for handlers + with suppress(asyncio.CancelledError, asyncio.TimeoutError): + async with ceil_timeout(timeout): + if self._current_request is not None: + self._current_request._cancel(asyncio.CancelledError()) + + if self._task_handler is not None and not self._task_handler.done(): + await self._task_handler + + # force-close non-idle handler + if self._task_handler is not None: + self._task_handler.cancel() + + if self.transport is not None: + self.transport.close() + self.transport = None + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + + real_transport = cast(asyncio.Transport, transport) + if self._tcp_keepalive: + tcp_keepalive(real_transport) + + self._task_handler = self._loop.create_task(self.start()) + assert self._manager is not None + self._manager.connection_made(self, real_transport) + + def connection_lost(self, exc: Optional[BaseException]) -> None: + if self._manager is None: + return + self._manager.connection_lost(self, exc) + + super().connection_lost(exc) + + # Grab value before setting _manager to None. + handler_cancellation = self._manager.handler_cancellation + + self._manager = None + self._force_close = True + self._request_factory = None + self._request_handler = None + self._request_parser = None + + if self._keepalive_handle is not None: + self._keepalive_handle.cancel() + + if self._current_request is not None: + if exc is None: + exc = ConnectionResetError("Connection lost") + self._current_request._cancel(exc) + + if self._waiter is not None: + self._waiter.cancel() + + if handler_cancellation and self._task_handler is not None: + self._task_handler.cancel() + + self._task_handler = None + + if self._payload_parser is not None: + self._payload_parser.feed_eof() + self._payload_parser = None + + def set_parser(self, parser: Any) -> None: + # Actual type is WebReader + assert self._payload_parser is None + + self._payload_parser = parser + + if self._message_tail: + self._payload_parser.feed_data(self._message_tail) + self._message_tail = b"" + + def eof_received(self) -> None: + pass + + def data_received(self, data: bytes) -> None: + if self._force_close or self._close: + return + # parse http messages + messages: Sequence[_MsgType] + if self._payload_parser is None and not self._upgrade: + assert self._request_parser is not None + try: + messages, upgraded, tail = self._request_parser.feed_data(data) + except HttpProcessingError as exc: + messages = [ + (_ErrInfo(status=400, exc=exc, message=exc.message), EMPTY_PAYLOAD) + ] + upgraded = False + tail = b"" + + for msg, payload in messages or (): + self._request_count += 1 + self._messages.append((msg, payload)) + + waiter = self._waiter + if messages and waiter is not None and not waiter.done(): + # don't set result twice + waiter.set_result(None) + + self._upgrade = upgraded + if upgraded and tail: + self._message_tail = 
tail + + # no parser, just store + elif self._payload_parser is None and self._upgrade and data: + self._message_tail += data + + # feed payload + elif data: + eof, tail = self._payload_parser.feed_data(data) + if eof: + self.close() + + def keep_alive(self, val: bool) -> None: + """Set keep-alive connection mode. + + :param bool val: new state. + """ + self._keepalive = val + if self._keepalive_handle: + self._keepalive_handle.cancel() + self._keepalive_handle = None + + def close(self) -> None: + """Close connection. + + Stop accepting new pipelining messages and close + connection when handlers done processing messages. + """ + self._close = True + if self._waiter: + self._waiter.cancel() + + def force_close(self) -> None: + """Forcefully close connection.""" + self._force_close = True + if self._waiter: + self._waiter.cancel() + if self.transport is not None: + self.transport.close() + self.transport = None + + def log_access( + self, request: BaseRequest, response: StreamResponse, time: float + ) -> None: + if self.access_logger is not None: + self.access_logger.log(request, response, self._loop.time() - time) + + def log_debug(self, *args: Any, **kw: Any) -> None: + if self.debug: + self.logger.debug(*args, **kw) + + def log_exception(self, *args: Any, **kw: Any) -> None: + self.logger.exception(*args, **kw) + + def _process_keepalive(self) -> None: + if self._force_close or not self._keepalive: + return + + next = self._keepalive_time + self._keepalive_timeout + + # handler in idle state + if self._waiter: + if self._loop.time() > next: + self.force_close() + return + + # not all request handlers are done, + # reschedule itself to next second + self._keepalive_handle = self._loop.call_later( + self.KEEPALIVE_RESCHEDULE_DELAY, + self._process_keepalive, + ) + + async def _handle_request( + self, + request: BaseRequest, + start_time: float, + request_handler: Callable[[BaseRequest], Awaitable[StreamResponse]], + ) -> Tuple[StreamResponse, bool]: + assert self._request_handler is not None + try: + try: + self._current_request = request + resp = await request_handler(request) + finally: + self._current_request = None + except HTTPException as exc: + resp = exc + reset = await self.finish_response(request, resp, start_time) + except asyncio.CancelledError: + raise + except asyncio.TimeoutError as exc: + self.log_debug("Request handler timed out.", exc_info=exc) + resp = self.handle_error(request, 504) + reset = await self.finish_response(request, resp, start_time) + except Exception as exc: + resp = self.handle_error(request, 500, exc) + reset = await self.finish_response(request, resp, start_time) + else: + # Deprecation warning (See #2415) + if getattr(resp, "__http_exception__", False): + warnings.warn( + "returning HTTPException object is deprecated " + "(#2415) and will be removed, " + "please raise the exception instead", + DeprecationWarning, + ) + + reset = await self.finish_response(request, resp, start_time) + + return resp, reset + + async def start(self) -> None: + """Process incoming request. + + It reads request line, request headers and request payload, then + calls handle_request() method. Subclass has to override + handle_request(). start() handles various exceptions in request + or response handling. Connection is being closed always unless + keep_alive(True) specified. 
+ """ + loop = self._loop + handler = self._task_handler + assert handler is not None + manager = self._manager + assert manager is not None + keepalive_timeout = self._keepalive_timeout + resp = None + assert self._request_factory is not None + assert self._request_handler is not None + + while not self._force_close: + if not self._messages: + try: + # wait for next request + self._waiter = loop.create_future() + await self._waiter + except asyncio.CancelledError: + break + finally: + self._waiter = None + + message, payload = self._messages.popleft() + + start = loop.time() + + manager.requests_count += 1 + writer = StreamWriter(self, loop) + if isinstance(message, _ErrInfo): + # make request_factory work + request_handler = self._make_error_handler(message) + message = ERROR + else: + request_handler = self._request_handler + + request = self._request_factory(message, payload, self, writer, handler) + try: + # a new task is used for copy context vars (#3406) + task = self._loop.create_task( + self._handle_request(request, start, request_handler) + ) + try: + resp, reset = await task + except (asyncio.CancelledError, ConnectionError): + self.log_debug("Ignored premature client disconnection") + break + + # Drop the processed task from asyncio.Task.all_tasks() early + del task + if reset: + self.log_debug("Ignored premature client disconnection 2") + break + + # notify server about keep-alive + self._keepalive = bool(resp.keep_alive) + + # check payload + if not payload.is_eof(): + lingering_time = self._lingering_time + if not self._force_close and lingering_time: + self.log_debug( + "Start lingering close timer for %s sec.", lingering_time + ) + + now = loop.time() + end_t = now + lingering_time + + with suppress(asyncio.TimeoutError, asyncio.CancelledError): + while not payload.is_eof() and now < end_t: + async with ceil_timeout(end_t - now): + # read and ignore + await payload.readany() + now = loop.time() + + # if payload still uncompleted + if not payload.is_eof() and not self._force_close: + self.log_debug("Uncompleted request.") + self.close() + + set_exception(payload, PayloadAccessError()) + + except asyncio.CancelledError: + self.log_debug("Ignored premature client disconnection ") + break + except RuntimeError as exc: + if self.debug: + self.log_exception("Unhandled runtime exception", exc_info=exc) + self.force_close() + except Exception as exc: + self.log_exception("Unhandled exception", exc_info=exc) + self.force_close() + finally: + if self.transport is None and resp is not None: + self.log_debug("Ignored premature client disconnection.") + elif not self._force_close: + if self._keepalive and not self._close: + # start keep-alive timer + if keepalive_timeout is not None: + now = self._loop.time() + self._keepalive_time = now + if self._keepalive_handle is None: + self._keepalive_handle = loop.call_at( + now + keepalive_timeout, self._process_keepalive + ) + else: + break + + # remove handler, close transport if no handlers left + if not self._force_close: + self._task_handler = None + if self.transport is not None: + self.transport.close() + + async def finish_response( + self, request: BaseRequest, resp: StreamResponse, start_time: float + ) -> bool: + """Prepare the response and write_eof, then log access. + + This has to + be called within the context of any exception so the access logger + can get exception information. Returns True if the client disconnects + prematurely. 
+ """ + if self._request_parser is not None: + self._request_parser.set_upgraded(False) + self._upgrade = False + if self._message_tail: + self._request_parser.feed_data(self._message_tail) + self._message_tail = b"" + try: + prepare_meth = resp.prepare + except AttributeError: + if resp is None: + raise RuntimeError("Missing return " "statement on request handler") + else: + raise RuntimeError( + "Web-handler should return " + "a response instance, " + "got {!r}".format(resp) + ) + try: + await prepare_meth(request) + await resp.write_eof() + except ConnectionError: + self.log_access(request, resp, start_time) + return True + else: + self.log_access(request, resp, start_time) + return False + + def handle_error( + self, + request: BaseRequest, + status: int = 500, + exc: Optional[BaseException] = None, + message: Optional[str] = None, + ) -> StreamResponse: + """Handle errors. + + Returns HTTP response with specific status code. Logs additional + information. It always closes current connection. + """ + self.log_exception("Error handling request", exc_info=exc) + + # some data already got sent, connection is broken + if request.writer.output_size > 0: + raise ConnectionError( + "Response is sent already, cannot send another response " + "with the error message" + ) + + ct = "text/plain" + if status == HTTPStatus.INTERNAL_SERVER_ERROR: + title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR) + msg = HTTPStatus.INTERNAL_SERVER_ERROR.description + tb = None + if self.debug: + with suppress(Exception): + tb = traceback.format_exc() + + if "text/html" in request.headers.get("Accept", ""): + if tb: + tb = html_escape(tb) + msg = f"
<h2>Traceback:</h2>\n<pre>{tb}</pre>" + message = ( + "<html><head>" + "<title>{title}</title>" + "</head><body>\n<h1>{title}</h1>
" + "\n{msg}\n\n" + ).format(title=title, msg=msg) + ct = "text/html" + else: + if tb: + msg = tb + message = title + "\n\n" + msg + + resp = Response(status=status, text=message, content_type=ct) + resp.force_close() + + return resp + + def _make_error_handler( + self, err_info: _ErrInfo + ) -> Callable[[BaseRequest], Awaitable[StreamResponse]]: + async def handler(request: BaseRequest) -> StreamResponse: + return self.handle_error( + request, err_info.status, err_info.exc, err_info.message + ) + + return handler diff --git a/llm/Lib/site-packages/aiohttp/web_request.py b/llm/Lib/site-packages/aiohttp/web_request.py new file mode 100644 index 0000000000000000000000000000000000000000..68c46b03d27c838bc8ca7596e0088d7b2238ea47 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_request.py @@ -0,0 +1,901 @@ +import asyncio +import datetime +import io +import re +import socket +import string +import tempfile +import types +import warnings +from http.cookies import SimpleCookie +from types import MappingProxyType +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Final, + Iterator, + Mapping, + MutableMapping, + Optional, + Pattern, + Tuple, + Union, + cast, +) +from urllib.parse import parse_qsl + +import attr +from multidict import ( + CIMultiDict, + CIMultiDictProxy, + MultiDict, + MultiDictProxy, + MultiMapping, +) +from yarl import URL + +from . import hdrs +from .abc import AbstractStreamWriter +from .helpers import ( + _SENTINEL, + DEBUG, + ETAG_ANY, + LIST_QUOTED_ETAG_RE, + ChainMapProxy, + ETag, + HeadersMixin, + parse_http_date, + reify, + sentinel, + set_exception, +) +from .http_parser import RawRequestMessage +from .http_writer import HttpVersion +from .multipart import BodyPartReader, MultipartReader +from .streams import EmptyStreamReader, StreamReader +from .typedefs import ( + DEFAULT_JSON_DECODER, + JSONDecoder, + LooseHeaders, + RawHeaders, + StrOrURL, +) +from .web_exceptions import HTTPRequestEntityTooLarge +from .web_response import StreamResponse + +__all__ = ("BaseRequest", "FileField", "Request") + + +if TYPE_CHECKING: + from .web_app import Application + from .web_protocol import RequestHandler + from .web_urldispatcher import UrlMappingMatchInfo + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class FileField: + name: str + filename: str + file: io.BufferedReader + content_type: str + headers: "CIMultiDictProxy[str]" + + +_TCHAR: Final[str] = string.digits + string.ascii_letters + r"!#$%&'*+.^_`|~-" +# '-' at the end to prevent interpretation as range in a char class + +_TOKEN: Final[str] = rf"[{_TCHAR}]+" + +_QDTEXT: Final[str] = r"[{}]".format( + r"".join(chr(c) for c in (0x09, 0x20, 0x21) + tuple(range(0x23, 0x7F))) +) +# qdtext includes 0x5C to escape 0x5D ('\]') +# qdtext excludes obs-text (because obsoleted, and encoding not specified) + +_QUOTED_PAIR: Final[str] = r"\\[\t !-~]" + +_QUOTED_STRING: Final[str] = r'"(?:{quoted_pair}|{qdtext})*"'.format( + qdtext=_QDTEXT, quoted_pair=_QUOTED_PAIR +) + +_FORWARDED_PAIR: Final[ + str +] = r"({token})=({token}|{quoted_string})(:\d{{1,4}})?".format( + token=_TOKEN, quoted_string=_QUOTED_STRING +) + +_QUOTED_PAIR_REPLACE_RE: Final[Pattern[str]] = re.compile(r"\\([\t !-~])") +# same pattern as _QUOTED_PAIR but contains a capture group + +_FORWARDED_PAIR_RE: Final[Pattern[str]] = re.compile(_FORWARDED_PAIR) + +############################################################ +# HTTP Request +############################################################ + + +class BaseRequest(MutableMapping[str, Any], 
HeadersMixin): + + POST_METHODS = { + hdrs.METH_PATCH, + hdrs.METH_POST, + hdrs.METH_PUT, + hdrs.METH_TRACE, + hdrs.METH_DELETE, + } + + ATTRS = HeadersMixin.ATTRS | frozenset( + [ + "_message", + "_protocol", + "_payload_writer", + "_payload", + "_headers", + "_method", + "_version", + "_rel_url", + "_post", + "_read_bytes", + "_state", + "_cache", + "_task", + "_client_max_size", + "_loop", + "_transport_sslcontext", + "_transport_peername", + ] + ) + + def __init__( + self, + message: RawRequestMessage, + payload: StreamReader, + protocol: "RequestHandler", + payload_writer: AbstractStreamWriter, + task: "asyncio.Task[None]", + loop: asyncio.AbstractEventLoop, + *, + client_max_size: int = 1024**2, + state: Optional[Dict[str, Any]] = None, + scheme: Optional[str] = None, + host: Optional[str] = None, + remote: Optional[str] = None, + ) -> None: + if state is None: + state = {} + self._message = message + self._protocol = protocol + self._payload_writer = payload_writer + + self._payload = payload + self._headers = message.headers + self._method = message.method + self._version = message.version + self._cache: Dict[str, Any] = {} + url = message.url + if url.is_absolute(): + # absolute URL is given, + # override auto-calculating url, host, and scheme + # all other properties should be good + self._cache["url"] = url + self._cache["host"] = url.host + self._cache["scheme"] = url.scheme + self._rel_url = url.relative() + else: + self._rel_url = message.url + self._post: Optional[MultiDictProxy[Union[str, bytes, FileField]]] = None + self._read_bytes: Optional[bytes] = None + + self._state = state + self._task = task + self._client_max_size = client_max_size + self._loop = loop + + transport = self._protocol.transport + assert transport is not None + self._transport_sslcontext = transport.get_extra_info("sslcontext") + self._transport_peername = transport.get_extra_info("peername") + + if scheme is not None: + self._cache["scheme"] = scheme + if host is not None: + self._cache["host"] = host + if remote is not None: + self._cache["remote"] = remote + + def clone( + self, + *, + method: Union[str, _SENTINEL] = sentinel, + rel_url: Union[StrOrURL, _SENTINEL] = sentinel, + headers: Union[LooseHeaders, _SENTINEL] = sentinel, + scheme: Union[str, _SENTINEL] = sentinel, + host: Union[str, _SENTINEL] = sentinel, + remote: Union[str, _SENTINEL] = sentinel, + client_max_size: Union[int, _SENTINEL] = sentinel, + ) -> "BaseRequest": + """Clone itself with replacement some attributes. + + Creates and returns a new instance of Request object. If no parameters + are given, an exact copy is returned. If a parameter is not passed, it + will reuse the one from the current request object. 
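+        E.g. (illustrative), ``request.clone(scheme="https")`` returns a new
+        request whose ``secure`` property is True while reusing the original
+        payload and protocol objects.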
+ """ + if self._read_bytes: + raise RuntimeError("Cannot clone request " "after reading its content") + + dct: Dict[str, Any] = {} + if method is not sentinel: + dct["method"] = method + if rel_url is not sentinel: + new_url: URL = URL(rel_url) + dct["url"] = new_url + dct["path"] = str(new_url) + if headers is not sentinel: + # a copy semantic + dct["headers"] = CIMultiDictProxy(CIMultiDict(headers)) + dct["raw_headers"] = tuple( + (k.encode("utf-8"), v.encode("utf-8")) for k, v in headers.items() + ) + + message = self._message._replace(**dct) + + kwargs = {} + if scheme is not sentinel: + kwargs["scheme"] = scheme + if host is not sentinel: + kwargs["host"] = host + if remote is not sentinel: + kwargs["remote"] = remote + if client_max_size is sentinel: + client_max_size = self._client_max_size + + return self.__class__( + message, + self._payload, + self._protocol, + self._payload_writer, + self._task, + self._loop, + client_max_size=client_max_size, + state=self._state.copy(), + **kwargs, + ) + + @property + def task(self) -> "asyncio.Task[None]": + return self._task + + @property + def protocol(self) -> "RequestHandler": + return self._protocol + + @property + def transport(self) -> Optional[asyncio.Transport]: + if self._protocol is None: + return None + return self._protocol.transport + + @property + def writer(self) -> AbstractStreamWriter: + return self._payload_writer + + @property + def client_max_size(self) -> int: + return self._client_max_size + + @reify + def message(self) -> RawRequestMessage: + warnings.warn("Request.message is deprecated", DeprecationWarning, stacklevel=3) + return self._message + + @reify + def rel_url(self) -> URL: + return self._rel_url + + @reify + def loop(self) -> asyncio.AbstractEventLoop: + warnings.warn( + "request.loop property is deprecated", DeprecationWarning, stacklevel=2 + ) + return self._loop + + # MutableMapping API + + def __getitem__(self, key: str) -> Any: + return self._state[key] + + def __setitem__(self, key: str, value: Any) -> None: + self._state[key] = value + + def __delitem__(self, key: str) -> None: + del self._state[key] + + def __len__(self) -> int: + return len(self._state) + + def __iter__(self) -> Iterator[str]: + return iter(self._state) + + ######## + + @reify + def secure(self) -> bool: + """A bool indicating if the request is handled with SSL.""" + return self.scheme == "https" + + @reify + def forwarded(self) -> Tuple[Mapping[str, str], ...]: + """A tuple containing all parsed Forwarded header(s). + + Makes an effort to parse Forwarded headers as specified by RFC 7239: + + - It adds one (immutable) dictionary per Forwarded 'field-value', ie + per proxy. The element corresponds to the data in the Forwarded + field-value added by the first proxy encountered by the client. Each + subsequent item corresponds to those added by later proxies. + - It checks that every value has valid syntax in general as specified + in section 4: either a 'token' or a 'quoted-string'. + - It un-escapes found escape sequences. + - It does NOT validate 'by' and 'for' contents as specified in section + 6. + - It does NOT validate 'host' contents (Host ABNF). + - It does NOT validate 'proto' contents for valid URI scheme names. 
+ + Returns a tuple containing one or more immutable dicts + """ + elems = [] + for field_value in self._message.headers.getall(hdrs.FORWARDED, ()): + length = len(field_value) + pos = 0 + need_separator = False + elem: Dict[str, str] = {} + elems.append(types.MappingProxyType(elem)) + while 0 <= pos < length: + match = _FORWARDED_PAIR_RE.match(field_value, pos) + if match is not None: # got a valid forwarded-pair + if need_separator: + # bad syntax here, skip to next comma + pos = field_value.find(",", pos) + else: + name, value, port = match.groups() + if value[0] == '"': + # quoted string: remove quotes and unescape + value = _QUOTED_PAIR_REPLACE_RE.sub(r"\1", value[1:-1]) + if port: + value += port + elem[name.lower()] = value + pos += len(match.group(0)) + need_separator = True + elif field_value[pos] == ",": # next forwarded-element + need_separator = False + elem = {} + elems.append(types.MappingProxyType(elem)) + pos += 1 + elif field_value[pos] == ";": # next forwarded-pair + need_separator = False + pos += 1 + elif field_value[pos] in " \t": + # Allow whitespace even between forwarded-pairs, though + # RFC 7239 doesn't. This simplifies code and is in line + # with Postel's law. + pos += 1 + else: + # bad syntax here, skip to next comma + pos = field_value.find(",", pos) + return tuple(elems) + + @reify + def scheme(self) -> str: + """A string representing the scheme of the request. + + Hostname is resolved in this order: + + - overridden value by .clone(scheme=new_scheme) call. + - type of connection to peer: HTTPS if socket is SSL, HTTP otherwise. + + 'http' or 'https'. + """ + if self._transport_sslcontext: + return "https" + else: + return "http" + + @reify + def method(self) -> str: + """Read only property for getting HTTP method. + + The value is upper-cased str like 'GET', 'POST', 'PUT' etc. + """ + return self._method + + @reify + def version(self) -> HttpVersion: + """Read only property for getting HTTP version of request. + + Returns aiohttp.protocol.HttpVersion instance. + """ + return self._version + + @reify + def host(self) -> str: + """Hostname of the request. + + Hostname is resolved in this order: + + - overridden value by .clone(host=new_host) call. + - HOST HTTP header + - socket.getfqdn() value + """ + host = self._message.headers.get(hdrs.HOST) + if host is not None: + return host + return socket.getfqdn() + + @reify + def remote(self) -> Optional[str]: + """Remote IP of client initiated HTTP request. + + The IP is resolved in this order: + + - overridden value by .clone(remote=new_remote) call. + - peername of opened socket + """ + if self._transport_peername is None: + return None + if isinstance(self._transport_peername, (list, tuple)): + return str(self._transport_peername[0]) + return str(self._transport_peername) + + @reify + def url(self) -> URL: + url = URL.build(scheme=self.scheme, host=self.host) + return url.join(self._rel_url) + + @reify + def path(self) -> str: + """The URL including *PATH INFO* without the host or scheme. + + E.g., ``/app/blog`` + """ + return self._rel_url.path + + @reify + def path_qs(self) -> str: + """The URL including PATH_INFO and the query string. + + E.g, /app/blog?id=10 + """ + return str(self._rel_url) + + @reify + def raw_path(self) -> str: + """The URL including raw *PATH INFO* without the host or scheme. 
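# A small sketch of the reflection properties defined above (method, scheme,
# host, remote, rel_url, url) used from a plain handler; the logger name is
# an assumption for illustration.
import logging
from aiohttp import web

log = logging.getLogger("request.debug")  # hypothetical logger name

async def introspect(request: web.Request) -> web.Response:
    log.info(
        "%s %s from %s (scheme=%s, host=%s)",
        request.method, request.rel_url, request.remote, request.scheme, request.host,
    )
    # request.url joins scheme and host with the relative URL, as implemented above.
    return web.Response(text=str(request.url))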
+ + Warning, the path is unquoted and may contains non valid URL characters + + E.g., ``/my%2Fpath%7Cwith%21some%25strange%24characters`` + """ + return self._message.path + + @reify + def query(self) -> "MultiMapping[str]": + """A multidict with all the variables in the query string.""" + return MultiDictProxy(self._rel_url.query) + + @reify + def query_string(self) -> str: + """The query string in the URL. + + E.g., id=10 + """ + return self._rel_url.query_string + + @reify + def headers(self) -> "MultiMapping[str]": + """A case-insensitive multidict proxy with all headers.""" + return self._headers + + @reify + def raw_headers(self) -> RawHeaders: + """A sequence of pairs for all headers.""" + return self._message.raw_headers + + @reify + def if_modified_since(self) -> Optional[datetime.datetime]: + """The value of If-Modified-Since HTTP header, or None. + + This header is represented as a `datetime` object. + """ + return parse_http_date(self.headers.get(hdrs.IF_MODIFIED_SINCE)) + + @reify + def if_unmodified_since(self) -> Optional[datetime.datetime]: + """The value of If-Unmodified-Since HTTP header, or None. + + This header is represented as a `datetime` object. + """ + return parse_http_date(self.headers.get(hdrs.IF_UNMODIFIED_SINCE)) + + @staticmethod + def _etag_values(etag_header: str) -> Iterator[ETag]: + """Extract `ETag` objects from raw header.""" + if etag_header == ETAG_ANY: + yield ETag( + is_weak=False, + value=ETAG_ANY, + ) + else: + for match in LIST_QUOTED_ETAG_RE.finditer(etag_header): + is_weak, value, garbage = match.group(2, 3, 4) + # Any symbol captured by 4th group means + # that the following sequence is invalid. + if garbage: + break + + yield ETag( + is_weak=bool(is_weak), + value=value, + ) + + @classmethod + def _if_match_or_none_impl( + cls, header_value: Optional[str] + ) -> Optional[Tuple[ETag, ...]]: + if not header_value: + return None + + return tuple(cls._etag_values(header_value)) + + @reify + def if_match(self) -> Optional[Tuple[ETag, ...]]: + """The value of If-Match HTTP header, or None. + + This header is represented as a `tuple` of `ETag` objects. + """ + return self._if_match_or_none_impl(self.headers.get(hdrs.IF_MATCH)) + + @reify + def if_none_match(self) -> Optional[Tuple[ETag, ...]]: + """The value of If-None-Match HTTP header, or None. + + This header is represented as a `tuple` of `ETag` objects. + """ + return self._if_match_or_none_impl(self.headers.get(hdrs.IF_NONE_MATCH)) + + @reify + def if_range(self) -> Optional[datetime.datetime]: + """The value of If-Range HTTP header, or None. + + This header is represented as a `datetime` object. + """ + return parse_http_date(self.headers.get(hdrs.IF_RANGE)) + + @reify + def keep_alive(self) -> bool: + """Is keepalive enabled by client?""" + return not self._message.should_close + + @reify + def cookies(self) -> Mapping[str, str]: + """Return request cookies. + + A read-only dictionary-like object. + """ + raw = self.headers.get(hdrs.COOKIE, "") + parsed = SimpleCookie(raw) + return MappingProxyType({key: val.value for key, val in parsed.items()}) + + @reify + def http_range(self) -> slice: + """The content of Range HTTP header. + + Return a slice instance. 
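# A sketch of a conditional GET built on the If-None-Match parsing above;
# the ETag value is an assumption for illustration.
from aiohttp import web

CURRENT_ETAG = "v1"  # hypothetical version tag of the served resource

async def cached(request: web.Request) -> web.Response:
    etags = request.if_none_match or ()
    if any(tag.value in (CURRENT_ETAG, "*") for tag in etags):
        return web.Response(status=304)
    resp = web.Response(text="fresh payload")
    resp.etag = CURRENT_ETAG  # the setter validates and quotes the value
    return resp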
+ + """ + rng = self._headers.get(hdrs.RANGE) + start, end = None, None + if rng is not None: + try: + pattern = r"^bytes=(\d*)-(\d*)$" + start, end = re.findall(pattern, rng)[0] + except IndexError: # pattern was not found in header + raise ValueError("range not in acceptable format") + + end = int(end) if end else None + start = int(start) if start else None + + if start is None and end is not None: + # end with no start is to return tail of content + start = -end + end = None + + if start is not None and end is not None: + # end is inclusive in range header, exclusive for slice + end += 1 + + if start >= end: + raise ValueError("start cannot be after end") + + if start is end is None: # No valid range supplied + raise ValueError("No start or end of range specified") + + return slice(start, end, 1) + + @reify + def content(self) -> StreamReader: + """Return raw payload stream.""" + return self._payload + + @property + def has_body(self) -> bool: + """Return True if request's HTTP BODY can be read, False otherwise.""" + warnings.warn( + "Deprecated, use .can_read_body #2005", DeprecationWarning, stacklevel=2 + ) + return not self._payload.at_eof() + + @property + def can_read_body(self) -> bool: + """Return True if request's HTTP BODY can be read, False otherwise.""" + return not self._payload.at_eof() + + @reify + def body_exists(self) -> bool: + """Return True if request has HTTP BODY, False otherwise.""" + return type(self._payload) is not EmptyStreamReader + + async def release(self) -> None: + """Release request. + + Eat unread part of HTTP BODY if present. + """ + while not self._payload.at_eof(): + await self._payload.readany() + + async def read(self) -> bytes: + """Read request body if present. + + Returns bytes object with full request content. 
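# A sketch of serving a partial body with the http_range slice computed
# above; the in-memory DATA blob is an assumption for illustration.
from aiohttp import web

DATA = b"0123456789" * 100  # hypothetical payload

async def ranged(request: web.Request) -> web.Response:
    try:
        rng = request.http_range  # slice(start, stop, 1); both None without a Range header
    except ValueError:
        return web.Response(status=416)  # malformed Range header
    partial = rng.start is not None or rng.stop is not None
    return web.Response(body=DATA[rng], status=206 if partial else 200)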
+ """ + if self._read_bytes is None: + body = bytearray() + while True: + chunk = await self._payload.readany() + body.extend(chunk) + if self._client_max_size: + body_size = len(body) + if body_size >= self._client_max_size: + raise HTTPRequestEntityTooLarge( + max_size=self._client_max_size, actual_size=body_size + ) + if not chunk: + break + self._read_bytes = bytes(body) + return self._read_bytes + + async def text(self) -> str: + """Return BODY as text using encoding from .charset.""" + bytes_body = await self.read() + encoding = self.charset or "utf-8" + return bytes_body.decode(encoding) + + async def json(self, *, loads: JSONDecoder = DEFAULT_JSON_DECODER) -> Any: + """Return BODY as JSON.""" + body = await self.text() + return loads(body) + + async def multipart(self) -> MultipartReader: + """Return async iterator to process BODY as multipart.""" + return MultipartReader(self._headers, self._payload) + + async def post(self) -> "MultiDictProxy[Union[str, bytes, FileField]]": + """Return POST parameters.""" + if self._post is not None: + return self._post + if self._method not in self.POST_METHODS: + self._post = MultiDictProxy(MultiDict()) + return self._post + + content_type = self.content_type + if content_type not in ( + "", + "application/x-www-form-urlencoded", + "multipart/form-data", + ): + self._post = MultiDictProxy(MultiDict()) + return self._post + + out: MultiDict[Union[str, bytes, FileField]] = MultiDict() + + if content_type == "multipart/form-data": + multipart = await self.multipart() + max_size = self._client_max_size + + field = await multipart.next() + while field is not None: + size = 0 + field_ct = field.headers.get(hdrs.CONTENT_TYPE) + + if isinstance(field, BodyPartReader): + assert field.name is not None + + # Note that according to RFC 7578, the Content-Type header + # is optional, even for files, so we can't assume it's + # present. 
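# A sketch of the read()/text()/json()/post() helpers defined above in a
# trivial handler; how non-JSON bodies are handled here is an assumption
# made only for illustration.
from aiohttp import web

async def echo(request: web.Request) -> web.Response:
    if request.content_type == "application/json":
        payload = await request.json()
        return web.json_response({"received": payload})
    form = await request.post()  # urlencoded or multipart form fields
    return web.Response(text=f"got {len(form)} form field(s)")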
+ # https://tools.ietf.org/html/rfc7578#section-4.4 + if field.filename: + # store file in temp file + tmp = await self._loop.run_in_executor( + None, tempfile.TemporaryFile + ) + chunk = await field.read_chunk(size=2**16) + while chunk: + chunk = field.decode(chunk) + await self._loop.run_in_executor(None, tmp.write, chunk) + size += len(chunk) + if 0 < max_size < size: + await self._loop.run_in_executor(None, tmp.close) + raise HTTPRequestEntityTooLarge( + max_size=max_size, actual_size=size + ) + chunk = await field.read_chunk(size=2**16) + await self._loop.run_in_executor(None, tmp.seek, 0) + + if field_ct is None: + field_ct = "application/octet-stream" + + ff = FileField( + field.name, + field.filename, + cast(io.BufferedReader, tmp), + field_ct, + field.headers, + ) + out.add(field.name, ff) + else: + # deal with ordinary data + value = await field.read(decode=True) + if field_ct is None or field_ct.startswith("text/"): + charset = field.get_charset(default="utf-8") + out.add(field.name, value.decode(charset)) + else: + out.add(field.name, value) + size += len(value) + if 0 < max_size < size: + raise HTTPRequestEntityTooLarge( + max_size=max_size, actual_size=size + ) + else: + raise ValueError( + "To decode nested multipart you need " "to use custom reader", + ) + + field = await multipart.next() + else: + data = await self.read() + if data: + charset = self.charset or "utf-8" + out.extend( + parse_qsl( + data.rstrip().decode(charset), + keep_blank_values=True, + encoding=charset, + ) + ) + + self._post = MultiDictProxy(out) + return self._post + + def get_extra_info(self, name: str, default: Any = None) -> Any: + """Extra info from protocol transport""" + protocol = self._protocol + if protocol is None: + return default + + transport = protocol.transport + if transport is None: + return default + + return transport.get_extra_info(name, default) + + def __repr__(self) -> str: + ascii_encodable_path = self.path.encode("ascii", "backslashreplace").decode( + "ascii" + ) + return "<{} {} {} >".format( + self.__class__.__name__, self._method, ascii_encodable_path + ) + + def __eq__(self, other: object) -> bool: + return id(self) == id(other) + + def __bool__(self) -> bool: + return True + + async def _prepare_hook(self, response: StreamResponse) -> None: + return + + def _cancel(self, exc: BaseException) -> None: + set_exception(self._payload, exc) + + +class Request(BaseRequest): + + ATTRS = BaseRequest.ATTRS | frozenset(["_match_info"]) + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + # matchdict, route_name, handler + # or information about traversal lookup + + # initialized after route resolving + self._match_info: Optional[UrlMappingMatchInfo] = None + + if DEBUG: + + def __setattr__(self, name: str, val: Any) -> None: + if name not in self.ATTRS: + warnings.warn( + "Setting custom {}.{} attribute " + "is discouraged".format(self.__class__.__name__, name), + DeprecationWarning, + stacklevel=2, + ) + super().__setattr__(name, val) + + def clone( + self, + *, + method: Union[str, _SENTINEL] = sentinel, + rel_url: Union[StrOrURL, _SENTINEL] = sentinel, + headers: Union[LooseHeaders, _SENTINEL] = sentinel, + scheme: Union[str, _SENTINEL] = sentinel, + host: Union[str, _SENTINEL] = sentinel, + remote: Union[str, _SENTINEL] = sentinel, + client_max_size: Union[int, _SENTINEL] = sentinel, + ) -> "Request": + ret = super().clone( + method=method, + rel_url=rel_url, + headers=headers, + scheme=scheme, + host=host, + remote=remote, + 
client_max_size=client_max_size, + ) + new_ret = cast(Request, ret) + new_ret._match_info = self._match_info + return new_ret + + @reify + def match_info(self) -> "UrlMappingMatchInfo": + """Result of route resolving.""" + match_info = self._match_info + assert match_info is not None + return match_info + + @property + def app(self) -> "Application": + """Application instance.""" + match_info = self._match_info + assert match_info is not None + return match_info.current_app + + @property + def config_dict(self) -> ChainMapProxy: + match_info = self._match_info + assert match_info is not None + lst = match_info.apps + app = self.app + idx = lst.index(app) + sublist = list(reversed(lst[: idx + 1])) + return ChainMapProxy(sublist) + + async def _prepare_hook(self, response: StreamResponse) -> None: + match_info = self._match_info + if match_info is None: + return + for app in match_info._apps: + await app.on_response_prepare.send(self, response) diff --git a/llm/Lib/site-packages/aiohttp/web_response.py b/llm/Lib/site-packages/aiohttp/web_response.py new file mode 100644 index 0000000000000000000000000000000000000000..5f61f60b2cf227dce9df1ba32ff34fdd2d3aafce --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_response.py @@ -0,0 +1,819 @@ +import asyncio +import collections.abc +import datetime +import enum +import json +import math +import time +import warnings +from concurrent.futures import Executor +from http import HTTPStatus +from http.cookies import SimpleCookie +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterator, + MutableMapping, + Optional, + Union, + cast, +) + +from multidict import CIMultiDict, istr + +from . import hdrs, payload +from .abc import AbstractStreamWriter +from .compression_utils import ZLibCompressor +from .helpers import ( + ETAG_ANY, + QUOTED_ETAG_RE, + ETag, + HeadersMixin, + must_be_empty_body, + parse_http_date, + rfc822_formatted_time, + sentinel, + should_remove_content_length, + validate_etag_value, +) +from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11 +from .payload import Payload +from .typedefs import JSONEncoder, LooseHeaders + +__all__ = ("ContentCoding", "StreamResponse", "Response", "json_response") + + +if TYPE_CHECKING: + from .web_request import BaseRequest + + BaseClass = MutableMapping[str, Any] +else: + BaseClass = collections.abc.MutableMapping + + +class ContentCoding(enum.Enum): + # The content codings that we have support for. 
+ # + # Additional registered codings are listed at: + # https://www.iana.org/assignments/http-parameters/http-parameters.xhtml#content-coding + deflate = "deflate" + gzip = "gzip" + identity = "identity" + + +############################################################ +# HTTP Response classes +############################################################ + + +class StreamResponse(BaseClass, HeadersMixin): + + _length_check = True + + def __init__( + self, + *, + status: int = 200, + reason: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + ) -> None: + self._body = None + self._keep_alive: Optional[bool] = None + self._chunked = False + self._compression = False + self._compression_force: Optional[ContentCoding] = None + self._cookies = SimpleCookie() + + self._req: Optional[BaseRequest] = None + self._payload_writer: Optional[AbstractStreamWriter] = None + self._eof_sent = False + self._must_be_empty_body: Optional[bool] = None + self._body_length = 0 + self._state: Dict[str, Any] = {} + + if headers is not None: + self._headers: CIMultiDict[str] = CIMultiDict(headers) + else: + self._headers = CIMultiDict() + + self.set_status(status, reason) + + @property + def prepared(self) -> bool: + return self._payload_writer is not None + + @property + def task(self) -> "Optional[asyncio.Task[None]]": + if self._req: + return self._req.task + else: + return None + + @property + def status(self) -> int: + return self._status + + @property + def chunked(self) -> bool: + return self._chunked + + @property + def compression(self) -> bool: + return self._compression + + @property + def reason(self) -> str: + return self._reason + + def set_status( + self, + status: int, + reason: Optional[str] = None, + ) -> None: + assert not self.prepared, ( + "Cannot change the response status code after " "the headers have been sent" + ) + self._status = int(status) + if reason is None: + try: + reason = HTTPStatus(self._status).phrase + except ValueError: + reason = "" + self._reason = reason + + @property + def keep_alive(self) -> Optional[bool]: + return self._keep_alive + + def force_close(self) -> None: + self._keep_alive = False + + @property + def body_length(self) -> int: + return self._body_length + + @property + def output_length(self) -> int: + warnings.warn("output_length is deprecated", DeprecationWarning) + assert self._payload_writer + return self._payload_writer.buffer_size + + def enable_chunked_encoding(self, chunk_size: Optional[int] = None) -> None: + """Enables automatic chunked transfer encoding.""" + self._chunked = True + + if hdrs.CONTENT_LENGTH in self._headers: + raise RuntimeError( + "You can't enable chunked encoding when " "a content length is set" + ) + if chunk_size is not None: + warnings.warn("Chunk size is deprecated #1615", DeprecationWarning) + + def enable_compression( + self, force: Optional[Union[bool, ContentCoding]] = None + ) -> None: + """Enables response compression encoding.""" + # Backwards compatibility for when force was a bool <0.17. 
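# A sketch of a StreamResponse streamed with the chunked transfer encoding
# enabled by enable_chunked_encoding() above; the chunk contents are
# assumptions for illustration.
from aiohttp import web

async def stream(request: web.Request) -> web.StreamResponse:
    resp = web.StreamResponse(status=200, headers={"Content-Type": "text/plain"})
    resp.enable_chunked_encoding()
    await resp.prepare(request)
    for i in range(3):
        await resp.write(f"chunk {i}\n".encode())
    await resp.write_eof()
    return resp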
+ if type(force) == bool: + force = ContentCoding.deflate if force else ContentCoding.identity + warnings.warn( + "Using boolean for force is deprecated #3318", DeprecationWarning + ) + elif force is not None: + assert isinstance(force, ContentCoding), ( + "force should one of " "None, bool or " "ContentEncoding" + ) + + self._compression = True + self._compression_force = force + + @property + def headers(self) -> "CIMultiDict[str]": + return self._headers + + @property + def cookies(self) -> SimpleCookie: + return self._cookies + + def set_cookie( + self, + name: str, + value: str, + *, + expires: Optional[str] = None, + domain: Optional[str] = None, + max_age: Optional[Union[int, str]] = None, + path: str = "/", + secure: Optional[bool] = None, + httponly: Optional[bool] = None, + version: Optional[str] = None, + samesite: Optional[str] = None, + ) -> None: + """Set or update response cookie. + + Sets new cookie or updates existent with new value. + Also updates only those params which are not None. + """ + old = self._cookies.get(name) + if old is not None and old.coded_value == "": + # deleted cookie + self._cookies.pop(name, None) + + self._cookies[name] = value + c = self._cookies[name] + + if expires is not None: + c["expires"] = expires + elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT": + del c["expires"] + + if domain is not None: + c["domain"] = domain + + if max_age is not None: + c["max-age"] = str(max_age) + elif "max-age" in c: + del c["max-age"] + + c["path"] = path + + if secure is not None: + c["secure"] = secure + if httponly is not None: + c["httponly"] = httponly + if version is not None: + c["version"] = version + if samesite is not None: + c["samesite"] = samesite + + def del_cookie( + self, name: str, *, domain: Optional[str] = None, path: str = "/" + ) -> None: + """Delete cookie. + + Creates new empty expired cookie. + """ + # TODO: do we need domain/path here? 
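# A sketch of set_cookie()/del_cookie() defined above; the cookie name and
# value are assumptions for illustration.
from aiohttp import web

async def login(request: web.Request) -> web.Response:
    resp = web.Response(text="logged in")
    resp.set_cookie("session", "abc123", max_age=3600, httponly=True, samesite="Lax")
    return resp

async def logout(request: web.Request) -> web.Response:
    resp = web.Response(text="logged out")
    resp.del_cookie("session")
    return resp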
+ self._cookies.pop(name, None) + self.set_cookie( + name, + "", + max_age=0, + expires="Thu, 01 Jan 1970 00:00:00 GMT", + domain=domain, + path=path, + ) + + @property + def content_length(self) -> Optional[int]: + # Just a placeholder for adding setter + return super().content_length + + @content_length.setter + def content_length(self, value: Optional[int]) -> None: + if value is not None: + value = int(value) + if self._chunked: + raise RuntimeError( + "You can't set content length when " "chunked encoding is enable" + ) + self._headers[hdrs.CONTENT_LENGTH] = str(value) + else: + self._headers.pop(hdrs.CONTENT_LENGTH, None) + + @property + def content_type(self) -> str: + # Just a placeholder for adding setter + return super().content_type + + @content_type.setter + def content_type(self, value: str) -> None: + self.content_type # read header values if needed + self._content_type = str(value) + self._generate_content_type_header() + + @property + def charset(self) -> Optional[str]: + # Just a placeholder for adding setter + return super().charset + + @charset.setter + def charset(self, value: Optional[str]) -> None: + ctype = self.content_type # read header values if needed + if ctype == "application/octet-stream": + raise RuntimeError( + "Setting charset for application/octet-stream " + "doesn't make sense, setup content_type first" + ) + assert self._content_dict is not None + if value is None: + self._content_dict.pop("charset", None) + else: + self._content_dict["charset"] = str(value).lower() + self._generate_content_type_header() + + @property + def last_modified(self) -> Optional[datetime.datetime]: + """The value of Last-Modified HTTP header, or None. + + This header is represented as a `datetime` object. + """ + return parse_http_date(self._headers.get(hdrs.LAST_MODIFIED)) + + @last_modified.setter + def last_modified( + self, value: Optional[Union[int, float, datetime.datetime, str]] + ) -> None: + if value is None: + self._headers.pop(hdrs.LAST_MODIFIED, None) + elif isinstance(value, (int, float)): + self._headers[hdrs.LAST_MODIFIED] = time.strftime( + "%a, %d %b %Y %H:%M:%S GMT", time.gmtime(math.ceil(value)) + ) + elif isinstance(value, datetime.datetime): + self._headers[hdrs.LAST_MODIFIED] = time.strftime( + "%a, %d %b %Y %H:%M:%S GMT", value.utctimetuple() + ) + elif isinstance(value, str): + self._headers[hdrs.LAST_MODIFIED] = value + + @property + def etag(self) -> Optional[ETag]: + quoted_value = self._headers.get(hdrs.ETAG) + if not quoted_value: + return None + elif quoted_value == ETAG_ANY: + return ETag(value=ETAG_ANY) + match = QUOTED_ETAG_RE.fullmatch(quoted_value) + if not match: + return None + is_weak, value = match.group(1, 2) + return ETag( + is_weak=bool(is_weak), + value=value, + ) + + @etag.setter + def etag(self, value: Optional[Union[ETag, str]]) -> None: + if value is None: + self._headers.pop(hdrs.ETAG, None) + elif (isinstance(value, str) and value == ETAG_ANY) or ( + isinstance(value, ETag) and value.value == ETAG_ANY + ): + self._headers[hdrs.ETAG] = ETAG_ANY + elif isinstance(value, str): + validate_etag_value(value) + self._headers[hdrs.ETAG] = f'"{value}"' + elif isinstance(value, ETag) and isinstance(value.value, str): + validate_etag_value(value.value) + hdr_value = f'W/"{value.value}"' if value.is_weak else f'"{value.value}"' + self._headers[hdrs.ETAG] = hdr_value + else: + raise ValueError( + f"Unsupported etag type: {type(value)}. 
" + f"etag must be str, ETag or None" + ) + + def _generate_content_type_header( + self, CONTENT_TYPE: istr = hdrs.CONTENT_TYPE + ) -> None: + assert self._content_dict is not None + assert self._content_type is not None + params = "; ".join(f"{k}={v}" for k, v in self._content_dict.items()) + if params: + ctype = self._content_type + "; " + params + else: + ctype = self._content_type + self._headers[CONTENT_TYPE] = ctype + + async def _do_start_compression(self, coding: ContentCoding) -> None: + if coding != ContentCoding.identity: + assert self._payload_writer is not None + self._headers[hdrs.CONTENT_ENCODING] = coding.value + self._payload_writer.enable_compression(coding.value) + # Compressed payload may have different content length, + # remove the header + self._headers.popall(hdrs.CONTENT_LENGTH, None) + + async def _start_compression(self, request: "BaseRequest") -> None: + if self._compression_force: + await self._do_start_compression(self._compression_force) + else: + # Encoding comparisons should be case-insensitive + # https://www.rfc-editor.org/rfc/rfc9110#section-8.4.1 + accept_encoding = request.headers.get(hdrs.ACCEPT_ENCODING, "").lower() + for coding in ContentCoding: + if coding.value in accept_encoding: + await self._do_start_compression(coding) + return + + async def prepare(self, request: "BaseRequest") -> Optional[AbstractStreamWriter]: + if self._eof_sent: + return None + if self._payload_writer is not None: + return self._payload_writer + self._must_be_empty_body = must_be_empty_body(request.method, self.status) + return await self._start(request) + + async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: + self._req = request + writer = self._payload_writer = request._payload_writer + + await self._prepare_headers() + await request._prepare_hook(self) + await self._write_headers() + + return writer + + async def _prepare_headers(self) -> None: + request = self._req + assert request is not None + writer = self._payload_writer + assert writer is not None + keep_alive = self._keep_alive + if keep_alive is None: + keep_alive = request.keep_alive + self._keep_alive = keep_alive + + version = request.version + + headers = self._headers + for cookie in self._cookies.values(): + value = cookie.output(header="")[1:] + headers.add(hdrs.SET_COOKIE, value) + + if self._compression: + await self._start_compression(request) + + if self._chunked: + if version != HttpVersion11: + raise RuntimeError( + "Using chunked encoding is forbidden " + "for HTTP/{0.major}.{0.minor}".format(request.version) + ) + if not self._must_be_empty_body: + writer.enable_chunking() + headers[hdrs.TRANSFER_ENCODING] = "chunked" + if hdrs.CONTENT_LENGTH in headers: + del headers[hdrs.CONTENT_LENGTH] + elif self._length_check: + writer.length = self.content_length + if writer.length is None: + if version >= HttpVersion11: + if not self._must_be_empty_body: + writer.enable_chunking() + headers[hdrs.TRANSFER_ENCODING] = "chunked" + elif not self._must_be_empty_body: + keep_alive = False + + # HTTP 1.1: https://tools.ietf.org/html/rfc7230#section-3.3.2 + # HTTP 1.0: https://tools.ietf.org/html/rfc1945#section-10.4 + if self._must_be_empty_body: + if hdrs.CONTENT_LENGTH in headers and should_remove_content_length( + request.method, self.status + ): + del headers[hdrs.CONTENT_LENGTH] + # https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-10 + # https://datatracker.ietf.org/doc/html/rfc9112#section-6.1-13 + if hdrs.TRANSFER_ENCODING in headers: + del headers[hdrs.TRANSFER_ENCODING] + 
else: + headers.setdefault(hdrs.CONTENT_TYPE, "application/octet-stream") + headers.setdefault(hdrs.DATE, rfc822_formatted_time()) + headers.setdefault(hdrs.SERVER, SERVER_SOFTWARE) + + # connection header + if hdrs.CONNECTION not in headers: + if keep_alive: + if version == HttpVersion10: + headers[hdrs.CONNECTION] = "keep-alive" + else: + if version == HttpVersion11: + headers[hdrs.CONNECTION] = "close" + + async def _write_headers(self) -> None: + request = self._req + assert request is not None + writer = self._payload_writer + assert writer is not None + # status line + version = request.version + status_line = "HTTP/{}.{} {} {}".format( + version[0], version[1], self._status, self._reason + ) + await writer.write_headers(status_line, self._headers) + + async def write(self, data: bytes) -> None: + assert isinstance( + data, (bytes, bytearray, memoryview) + ), "data argument must be byte-ish (%r)" % type(data) + + if self._eof_sent: + raise RuntimeError("Cannot call write() after write_eof()") + if self._payload_writer is None: + raise RuntimeError("Cannot call write() before prepare()") + + await self._payload_writer.write(data) + + async def drain(self) -> None: + assert not self._eof_sent, "EOF has already been sent" + assert self._payload_writer is not None, "Response has not been started" + warnings.warn( + "drain method is deprecated, use await resp.write()", + DeprecationWarning, + stacklevel=2, + ) + await self._payload_writer.drain() + + async def write_eof(self, data: bytes = b"") -> None: + assert isinstance( + data, (bytes, bytearray, memoryview) + ), "data argument must be byte-ish (%r)" % type(data) + + if self._eof_sent: + return + + assert self._payload_writer is not None, "Response has not been started" + + await self._payload_writer.write_eof(data) + self._eof_sent = True + self._req = None + self._body_length = self._payload_writer.output_size + self._payload_writer = None + + def __repr__(self) -> str: + if self._eof_sent: + info = "eof" + elif self.prepared: + assert self._req is not None + info = f"{self._req.method} {self._req.path} " + else: + info = "not prepared" + return f"<{self.__class__.__name__} {self.reason} {info}>" + + def __getitem__(self, key: str) -> Any: + return self._state[key] + + def __setitem__(self, key: str, value: Any) -> None: + self._state[key] = value + + def __delitem__(self, key: str) -> None: + del self._state[key] + + def __len__(self) -> int: + return len(self._state) + + def __iter__(self) -> Iterator[str]: + return iter(self._state) + + def __hash__(self) -> int: + return hash(id(self)) + + def __eq__(self, other: object) -> bool: + return self is other + + +class Response(StreamResponse): + def __init__( + self, + *, + body: Any = None, + status: int = 200, + reason: Optional[str] = None, + text: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + content_type: Optional[str] = None, + charset: Optional[str] = None, + zlib_executor_size: Optional[int] = None, + zlib_executor: Optional[Executor] = None, + ) -> None: + if body is not None and text is not None: + raise ValueError("body and text are not allowed together") + + if headers is None: + real_headers: CIMultiDict[str] = CIMultiDict() + elif not isinstance(headers, CIMultiDict): + real_headers = CIMultiDict(headers) + else: + real_headers = headers # = cast('CIMultiDict[str]', headers) + + if content_type is not None and "charset" in content_type: + raise ValueError("charset must not be in content_type " "argument") + + if text is not None: + if 
hdrs.CONTENT_TYPE in real_headers: + if content_type or charset: + raise ValueError( + "passing both Content-Type header and " + "content_type or charset params " + "is forbidden" + ) + else: + # fast path for filling headers + if not isinstance(text, str): + raise TypeError("text argument must be str (%r)" % type(text)) + if content_type is None: + content_type = "text/plain" + if charset is None: + charset = "utf-8" + real_headers[hdrs.CONTENT_TYPE] = content_type + "; charset=" + charset + body = text.encode(charset) + text = None + else: + if hdrs.CONTENT_TYPE in real_headers: + if content_type is not None or charset is not None: + raise ValueError( + "passing both Content-Type header and " + "content_type or charset params " + "is forbidden" + ) + else: + if content_type is not None: + if charset is not None: + content_type += "; charset=" + charset + real_headers[hdrs.CONTENT_TYPE] = content_type + + super().__init__(status=status, reason=reason, headers=real_headers) + + if text is not None: + self.text = text + else: + self.body = body + + self._compressed_body: Optional[bytes] = None + self._zlib_executor_size = zlib_executor_size + self._zlib_executor = zlib_executor + + @property + def body(self) -> Optional[Union[bytes, Payload]]: + return self._body + + @body.setter + def body(self, body: bytes) -> None: + if body is None: + self._body: Optional[bytes] = None + self._body_payload: bool = False + elif isinstance(body, (bytes, bytearray)): + self._body = body + self._body_payload = False + else: + try: + self._body = body = payload.PAYLOAD_REGISTRY.get(body) + except payload.LookupError: + raise ValueError("Unsupported body type %r" % type(body)) + + self._body_payload = True + + headers = self._headers + + # set content-type + if hdrs.CONTENT_TYPE not in headers: + headers[hdrs.CONTENT_TYPE] = body.content_type + + # copy payload headers + if body.headers: + for (key, value) in body.headers.items(): + if key not in headers: + headers[key] = value + + self._compressed_body = None + + @property + def text(self) -> Optional[str]: + if self._body is None: + return None + return self._body.decode(self.charset or "utf-8") + + @text.setter + def text(self, text: str) -> None: + assert text is None or isinstance( + text, str + ), "text argument must be str (%r)" % type(text) + + if self.content_type == "application/octet-stream": + self.content_type = "text/plain" + if self.charset is None: + self.charset = "utf-8" + + self._body = text.encode(self.charset) + self._body_payload = False + self._compressed_body = None + + @property + def content_length(self) -> Optional[int]: + if self._chunked: + return None + + if hdrs.CONTENT_LENGTH in self._headers: + return super().content_length + + if self._compressed_body is not None: + # Return length of the compressed body + return len(self._compressed_body) + elif self._body_payload: + # A payload without content length, or a compressed payload + return None + elif self._body is not None: + return len(self._body) + else: + return 0 + + @content_length.setter + def content_length(self, value: Optional[int]) -> None: + raise RuntimeError("Content length is set automatically") + + async def write_eof(self, data: bytes = b"") -> None: + if self._eof_sent: + return + if self._compressed_body is None: + body: Optional[Union[bytes, Payload]] = self._body + else: + body = self._compressed_body + assert not data, f"data arg is not supported, got {data!r}" + assert self._req is not None + assert self._payload_writer is not None + if body is not None: + 
if self._must_be_empty_body: + await super().write_eof() + elif self._body_payload: + payload = cast(Payload, body) + await payload.write(self._payload_writer) + await super().write_eof() + else: + await super().write_eof(cast(bytes, body)) + else: + await super().write_eof() + + async def _start(self, request: "BaseRequest") -> AbstractStreamWriter: + if should_remove_content_length(request.method, self.status): + if hdrs.CONTENT_LENGTH in self._headers: + del self._headers[hdrs.CONTENT_LENGTH] + elif not self._chunked and hdrs.CONTENT_LENGTH not in self._headers: + if self._body_payload: + size = cast(Payload, self._body).size + if size is not None: + self._headers[hdrs.CONTENT_LENGTH] = str(size) + else: + body_len = len(self._body) if self._body else "0" + # https://www.rfc-editor.org/rfc/rfc9110.html#section-8.6-7 + if body_len != "0" or ( + self.status != 304 and request.method.upper() != hdrs.METH_HEAD + ): + self._headers[hdrs.CONTENT_LENGTH] = str(body_len) + + return await super()._start(request) + + async def _do_start_compression(self, coding: ContentCoding) -> None: + if self._body_payload or self._chunked: + return await super()._do_start_compression(coding) + + if coding != ContentCoding.identity: + # Instead of using _payload_writer.enable_compression, + # compress the whole body + compressor = ZLibCompressor( + encoding=str(coding.value), + max_sync_chunk_size=self._zlib_executor_size, + executor=self._zlib_executor, + ) + assert self._body is not None + if self._zlib_executor_size is None and len(self._body) > 1024 * 1024: + warnings.warn( + "Synchronous compression of large response bodies " + f"({len(self._body)} bytes) might block the async event loop. " + "Consider providing a custom value to zlib_executor_size/" + "zlib_executor response properties or disabling compression on it." + ) + self._compressed_body = ( + await compressor.compress(self._body) + compressor.flush() + ) + assert self._compressed_body is not None + + self._headers[hdrs.CONTENT_ENCODING] = coding.value + self._headers[hdrs.CONTENT_LENGTH] = str(len(self._compressed_body)) + + +def json_response( + data: Any = sentinel, + *, + text: Optional[str] = None, + body: Optional[bytes] = None, + status: int = 200, + reason: Optional[str] = None, + headers: Optional[LooseHeaders] = None, + content_type: str = "application/json", + dumps: JSONEncoder = json.dumps, +) -> Response: + if data is not sentinel: + if text or body: + raise ValueError("only one of data, text, or body should be specified") + else: + text = dumps(data) + return Response( + text=text, + body=body, + status=status, + reason=reason, + headers=headers, + content_type=content_type, + ) diff --git a/llm/Lib/site-packages/aiohttp/web_routedef.py b/llm/Lib/site-packages/aiohttp/web_routedef.py new file mode 100644 index 0000000000000000000000000000000000000000..39f637df49eddfa7f18f76237033c06b373d8be6 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_routedef.py @@ -0,0 +1,216 @@ +import abc +import os # noqa +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + List, + Optional, + Sequence, + Type, + Union, + overload, +) + +import attr + +from . 
import hdrs +from .abc import AbstractView +from .typedefs import Handler, PathLike + +if TYPE_CHECKING: + from .web_request import Request + from .web_response import StreamResponse + from .web_urldispatcher import AbstractRoute, UrlDispatcher +else: + Request = StreamResponse = UrlDispatcher = AbstractRoute = None + + +__all__ = ( + "AbstractRouteDef", + "RouteDef", + "StaticDef", + "RouteTableDef", + "head", + "options", + "get", + "post", + "patch", + "put", + "delete", + "route", + "view", + "static", +) + + +class AbstractRouteDef(abc.ABC): + @abc.abstractmethod + def register(self, router: UrlDispatcher) -> List[AbstractRoute]: + pass # pragma: no cover + + +_HandlerType = Union[Type[AbstractView], Handler] + + +@attr.s(auto_attribs=True, frozen=True, repr=False, slots=True) +class RouteDef(AbstractRouteDef): + method: str + path: str + handler: _HandlerType + kwargs: Dict[str, Any] + + def __repr__(self) -> str: + info = [] + for name, value in sorted(self.kwargs.items()): + info.append(f", {name}={value!r}") + return " {handler.__name__!r}" "{info}>".format( + method=self.method, path=self.path, handler=self.handler, info="".join(info) + ) + + def register(self, router: UrlDispatcher) -> List[AbstractRoute]: + if self.method in hdrs.METH_ALL: + reg = getattr(router, "add_" + self.method.lower()) + return [reg(self.path, self.handler, **self.kwargs)] + else: + return [ + router.add_route(self.method, self.path, self.handler, **self.kwargs) + ] + + +@attr.s(auto_attribs=True, frozen=True, repr=False, slots=True) +class StaticDef(AbstractRouteDef): + prefix: str + path: PathLike + kwargs: Dict[str, Any] + + def __repr__(self) -> str: + info = [] + for name, value in sorted(self.kwargs.items()): + info.append(f", {name}={value!r}") + return " {path}" "{info}>".format( + prefix=self.prefix, path=self.path, info="".join(info) + ) + + def register(self, router: UrlDispatcher) -> List[AbstractRoute]: + resource = router.add_static(self.prefix, self.path, **self.kwargs) + routes = resource.get_info().get("routes", {}) + return list(routes.values()) + + +def route(method: str, path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: + return RouteDef(method, path, handler, kwargs) + + +def head(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: + return route(hdrs.METH_HEAD, path, handler, **kwargs) + + +def options(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: + return route(hdrs.METH_OPTIONS, path, handler, **kwargs) + + +def get( + path: str, + handler: _HandlerType, + *, + name: Optional[str] = None, + allow_head: bool = True, + **kwargs: Any, +) -> RouteDef: + return route( + hdrs.METH_GET, path, handler, name=name, allow_head=allow_head, **kwargs + ) + + +def post(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: + return route(hdrs.METH_POST, path, handler, **kwargs) + + +def put(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: + return route(hdrs.METH_PUT, path, handler, **kwargs) + + +def patch(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: + return route(hdrs.METH_PATCH, path, handler, **kwargs) + + +def delete(path: str, handler: _HandlerType, **kwargs: Any) -> RouteDef: + return route(hdrs.METH_DELETE, path, handler, **kwargs) + + +def view(path: str, handler: Type[AbstractView], **kwargs: Any) -> RouteDef: + return route(hdrs.METH_ANY, path, handler, **kwargs) + + +def static(prefix: str, path: PathLike, **kwargs: Any) -> StaticDef: + return StaticDef(prefix, path, kwargs) + + +_Deco = 
Callable[[_HandlerType], _HandlerType] + + +class RouteTableDef(Sequence[AbstractRouteDef]): + """Route definition table""" + + def __init__(self) -> None: + self._items: List[AbstractRouteDef] = [] + + def __repr__(self) -> str: + return f"" + + @overload + def __getitem__(self, index: int) -> AbstractRouteDef: + ... + + @overload + def __getitem__(self, index: slice) -> List[AbstractRouteDef]: + ... + + def __getitem__(self, index): # type: ignore[no-untyped-def] + return self._items[index] + + def __iter__(self) -> Iterator[AbstractRouteDef]: + return iter(self._items) + + def __len__(self) -> int: + return len(self._items) + + def __contains__(self, item: object) -> bool: + return item in self._items + + def route(self, method: str, path: str, **kwargs: Any) -> _Deco: + def inner(handler: _HandlerType) -> _HandlerType: + self._items.append(RouteDef(method, path, handler, kwargs)) + return handler + + return inner + + def head(self, path: str, **kwargs: Any) -> _Deco: + return self.route(hdrs.METH_HEAD, path, **kwargs) + + def get(self, path: str, **kwargs: Any) -> _Deco: + return self.route(hdrs.METH_GET, path, **kwargs) + + def post(self, path: str, **kwargs: Any) -> _Deco: + return self.route(hdrs.METH_POST, path, **kwargs) + + def put(self, path: str, **kwargs: Any) -> _Deco: + return self.route(hdrs.METH_PUT, path, **kwargs) + + def patch(self, path: str, **kwargs: Any) -> _Deco: + return self.route(hdrs.METH_PATCH, path, **kwargs) + + def delete(self, path: str, **kwargs: Any) -> _Deco: + return self.route(hdrs.METH_DELETE, path, **kwargs) + + def options(self, path: str, **kwargs: Any) -> _Deco: + return self.route(hdrs.METH_OPTIONS, path, **kwargs) + + def view(self, path: str, **kwargs: Any) -> _Deco: + return self.route(hdrs.METH_ANY, path, **kwargs) + + def static(self, prefix: str, path: PathLike, **kwargs: Any) -> None: + self._items.append(StaticDef(prefix, path, kwargs)) diff --git a/llm/Lib/site-packages/aiohttp/web_runner.py b/llm/Lib/site-packages/aiohttp/web_runner.py new file mode 100644 index 0000000000000000000000000000000000000000..930e940d45753c9f9811ea84622bf833b5a1a206 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_runner.py @@ -0,0 +1,409 @@ +import asyncio +import signal +import socket +import warnings +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Callable, List, Optional, Set + +from yarl import URL + +from .typedefs import PathLike +from .web_app import Application +from .web_server import Server + +try: + from ssl import SSLContext +except ImportError: + SSLContext = object # type: ignore[misc,assignment] + + +__all__ = ( + "BaseSite", + "TCPSite", + "UnixSite", + "NamedPipeSite", + "SockSite", + "BaseRunner", + "AppRunner", + "ServerRunner", + "GracefulExit", +) + + +class GracefulExit(SystemExit): + code = 1 + + +def _raise_graceful_exit() -> None: + raise GracefulExit() + + +class BaseSite(ABC): + __slots__ = ("_runner", "_ssl_context", "_backlog", "_server") + + def __init__( + self, + runner: "BaseRunner", + *, + shutdown_timeout: float = 60.0, + ssl_context: Optional[SSLContext] = None, + backlog: int = 128, + ) -> None: + if runner.server is None: + raise RuntimeError("Call runner.setup() before making a site") + if shutdown_timeout != 60.0: + msg = "shutdown_timeout should be set on BaseRunner" + warnings.warn(msg, DeprecationWarning, stacklevel=2) + runner._shutdown_timeout = shutdown_timeout + self._runner = runner + self._ssl_context = ssl_context + self._backlog = backlog + self._server: 
Optional[asyncio.AbstractServer] = None + + @property + @abstractmethod + def name(self) -> str: + pass # pragma: no cover + + @abstractmethod + async def start(self) -> None: + self._runner._reg_site(self) + + async def stop(self) -> None: + self._runner._check_site(self) + if self._server is not None: # Maybe not started yet + self._server.close() + + self._runner._unreg_site(self) + + +class TCPSite(BaseSite): + __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port") + + def __init__( + self, + runner: "BaseRunner", + host: Optional[str] = None, + port: Optional[int] = None, + *, + shutdown_timeout: float = 60.0, + ssl_context: Optional[SSLContext] = None, + backlog: int = 128, + reuse_address: Optional[bool] = None, + reuse_port: Optional[bool] = None, + ) -> None: + super().__init__( + runner, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + ) + self._host = host + if port is None: + port = 8443 if self._ssl_context else 8080 + self._port = port + self._reuse_address = reuse_address + self._reuse_port = reuse_port + + @property + def name(self) -> str: + scheme = "https" if self._ssl_context else "http" + host = "0.0.0.0" if self._host is None else self._host + return str(URL.build(scheme=scheme, host=host, port=self._port)) + + async def start(self) -> None: + await super().start() + loop = asyncio.get_event_loop() + server = self._runner.server + assert server is not None + self._server = await loop.create_server( + server, + self._host, + self._port, + ssl=self._ssl_context, + backlog=self._backlog, + reuse_address=self._reuse_address, + reuse_port=self._reuse_port, + ) + + +class UnixSite(BaseSite): + __slots__ = ("_path",) + + def __init__( + self, + runner: "BaseRunner", + path: PathLike, + *, + shutdown_timeout: float = 60.0, + ssl_context: Optional[SSLContext] = None, + backlog: int = 128, + ) -> None: + super().__init__( + runner, + shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + ) + self._path = path + + @property + def name(self) -> str: + scheme = "https" if self._ssl_context else "http" + return f"{scheme}://unix:{self._path}:" + + async def start(self) -> None: + await super().start() + loop = asyncio.get_event_loop() + server = self._runner.server + assert server is not None + self._server = await loop.create_unix_server( + server, + self._path, + ssl=self._ssl_context, + backlog=self._backlog, + ) + + +class NamedPipeSite(BaseSite): + __slots__ = ("_path",) + + def __init__( + self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0 + ) -> None: + loop = asyncio.get_event_loop() + if not isinstance( + loop, asyncio.ProactorEventLoop # type: ignore[attr-defined] + ): + raise RuntimeError( + "Named Pipes only available in proactor" "loop under windows" + ) + super().__init__(runner, shutdown_timeout=shutdown_timeout) + self._path = path + + @property + def name(self) -> str: + return self._path + + async def start(self) -> None: + await super().start() + loop = asyncio.get_event_loop() + server = self._runner.server + assert server is not None + _server = await loop.start_serving_pipe( # type: ignore[attr-defined] + server, self._path + ) + self._server = _server[0] + + +class SockSite(BaseSite): + __slots__ = ("_sock", "_name") + + def __init__( + self, + runner: "BaseRunner", + sock: socket.socket, + *, + shutdown_timeout: float = 60.0, + ssl_context: Optional[SSLContext] = None, + backlog: int = 128, + ) -> None: + super().__init__( + runner, + 
shutdown_timeout=shutdown_timeout, + ssl_context=ssl_context, + backlog=backlog, + ) + self._sock = sock + scheme = "https" if self._ssl_context else "http" + if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX: + name = f"{scheme}://unix:{sock.getsockname()}:" + else: + host, port = sock.getsockname()[:2] + name = str(URL.build(scheme=scheme, host=host, port=port)) + self._name = name + + @property + def name(self) -> str: + return self._name + + async def start(self) -> None: + await super().start() + loop = asyncio.get_event_loop() + server = self._runner.server + assert server is not None + self._server = await loop.create_server( + server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog + ) + + +class BaseRunner(ABC): + __slots__ = ( + "shutdown_callback", + "_handle_signals", + "_kwargs", + "_server", + "_sites", + "_shutdown_timeout", + ) + + def __init__( + self, + *, + handle_signals: bool = False, + shutdown_timeout: float = 60.0, + **kwargs: Any, + ) -> None: + self.shutdown_callback: Optional[Callable[[], Awaitable[None]]] = None + self._handle_signals = handle_signals + self._kwargs = kwargs + self._server: Optional[Server] = None + self._sites: List[BaseSite] = [] + self._shutdown_timeout = shutdown_timeout + + @property + def server(self) -> Optional[Server]: + return self._server + + @property + def addresses(self) -> List[Any]: + ret: List[Any] = [] + for site in self._sites: + server = site._server + if server is not None: + sockets = server.sockets # type: ignore[attr-defined] + if sockets is not None: + for sock in sockets: + ret.append(sock.getsockname()) + return ret + + @property + def sites(self) -> Set[BaseSite]: + return set(self._sites) + + async def setup(self) -> None: + loop = asyncio.get_event_loop() + + if self._handle_signals: + try: + loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit) + loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit) + except NotImplementedError: # pragma: no cover + # add_signal_handler is not implemented on Windows + pass + + self._server = await self._make_server() + + @abstractmethod + async def shutdown(self) -> None: + """Call any shutdown hooks to help server close gracefully.""" + + async def cleanup(self) -> None: + # The loop over sites is intentional, an exception on gather() + # leaves self._sites in unpredictable state. + # The loop guaranties that a site is either deleted on success or + # still present on failure + for site in list(self._sites): + await site.stop() + + if self._server: # If setup succeeded + # Yield to event loop to ensure incoming requests prior to stopping the sites + # have all started to be handled before we proceed to close idle connections. 
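# A sketch of the manual runner/site lifecycle implemented by the classes in
# this file (AppRunner is defined a little further down); host and port are
# assumptions for illustration.
import asyncio
from aiohttp import web

async def serve() -> None:
    app = web.Application()
    runner = web.AppRunner(app)
    await runner.setup()
    site = web.TCPSite(runner, host="127.0.0.1", port=8080)
    await site.start()
    try:
        await asyncio.sleep(3600)  # keep serving for a while
    finally:
        await runner.cleanup()  # stops sites, then shuts the server down

# asyncio.run(serve())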
+ await asyncio.sleep(0) + self._server.pre_shutdown() + await self.shutdown() + + if self.shutdown_callback: + await self.shutdown_callback() + + await self._server.shutdown(self._shutdown_timeout) + await self._cleanup_server() + + self._server = None + if self._handle_signals: + loop = asyncio.get_running_loop() + try: + loop.remove_signal_handler(signal.SIGINT) + loop.remove_signal_handler(signal.SIGTERM) + except NotImplementedError: # pragma: no cover + # remove_signal_handler is not implemented on Windows + pass + + @abstractmethod + async def _make_server(self) -> Server: + pass # pragma: no cover + + @abstractmethod + async def _cleanup_server(self) -> None: + pass # pragma: no cover + + def _reg_site(self, site: BaseSite) -> None: + if site in self._sites: + raise RuntimeError(f"Site {site} is already registered in runner {self}") + self._sites.append(site) + + def _check_site(self, site: BaseSite) -> None: + if site not in self._sites: + raise RuntimeError(f"Site {site} is not registered in runner {self}") + + def _unreg_site(self, site: BaseSite) -> None: + if site not in self._sites: + raise RuntimeError(f"Site {site} is not registered in runner {self}") + self._sites.remove(site) + + +class ServerRunner(BaseRunner): + """Low-level web server runner""" + + __slots__ = ("_web_server",) + + def __init__( + self, web_server: Server, *, handle_signals: bool = False, **kwargs: Any + ) -> None: + super().__init__(handle_signals=handle_signals, **kwargs) + self._web_server = web_server + + async def shutdown(self) -> None: + pass + + async def _make_server(self) -> Server: + return self._web_server + + async def _cleanup_server(self) -> None: + pass + + +class AppRunner(BaseRunner): + """Web Application runner""" + + __slots__ = ("_app",) + + def __init__( + self, app: Application, *, handle_signals: bool = False, **kwargs: Any + ) -> None: + super().__init__(handle_signals=handle_signals, **kwargs) + if not isinstance(app, Application): + raise TypeError( + "The first argument should be web.Application " + "instance, got {!r}".format(app) + ) + self._app = app + + @property + def app(self) -> Application: + return self._app + + async def shutdown(self) -> None: + await self._app.shutdown() + + async def _make_server(self) -> Server: + loop = asyncio.get_event_loop() + self._app._set_loop(loop) + self._app.on_startup.freeze() + await self._app.startup() + self._app.freeze() + + return self._app._make_handler(loop=loop, **self._kwargs) + + async def _cleanup_server(self) -> None: + await self._app.cleanup() diff --git a/llm/Lib/site-packages/aiohttp/web_server.py b/llm/Lib/site-packages/aiohttp/web_server.py new file mode 100644 index 0000000000000000000000000000000000000000..31e9b79d2f8ea36dc711451b47677c37c596e1cf --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_server.py @@ -0,0 +1,77 @@ +"""Low level HTTP server.""" +import asyncio +from typing import Any, Awaitable, Callable, Dict, List, Optional # noqa + +from .abc import AbstractStreamWriter +from .helpers import get_running_loop +from .http_parser import RawRequestMessage +from .streams import StreamReader +from .web_protocol import RequestHandler, _RequestFactory, _RequestHandler +from .web_request import BaseRequest + +__all__ = ("Server",) + + +class Server: + def __init__( + self, + handler: _RequestHandler, + *, + request_factory: Optional[_RequestFactory] = None, + handler_cancellation: bool = False, + loop: Optional[asyncio.AbstractEventLoop] = None, + **kwargs: Any + ) -> None: + self._loop = 
get_running_loop(loop) + self._connections: Dict[RequestHandler, asyncio.Transport] = {} + self._kwargs = kwargs + self.requests_count = 0 + self.request_handler = handler + self.request_factory = request_factory or self._make_request + self.handler_cancellation = handler_cancellation + + @property + def connections(self) -> List[RequestHandler]: + return list(self._connections.keys()) + + def connection_made( + self, handler: RequestHandler, transport: asyncio.Transport + ) -> None: + self._connections[handler] = transport + + def connection_lost( + self, handler: RequestHandler, exc: Optional[BaseException] = None + ) -> None: + if handler in self._connections: + del self._connections[handler] + + def _make_request( + self, + message: RawRequestMessage, + payload: StreamReader, + protocol: RequestHandler, + writer: AbstractStreamWriter, + task: "asyncio.Task[None]", + ) -> BaseRequest: + return BaseRequest(message, payload, protocol, writer, task, self._loop) + + def pre_shutdown(self) -> None: + for conn in self._connections: + conn.close() + + async def shutdown(self, timeout: Optional[float] = None) -> None: + coros = (conn.shutdown(timeout) for conn in self._connections) + await asyncio.gather(*coros) + self._connections.clear() + + def __call__(self) -> RequestHandler: + try: + return RequestHandler(self, loop=self._loop, **self._kwargs) + except TypeError: + # Failsafe creation: remove all custom handler_args + kwargs = { + k: v + for k, v in self._kwargs.items() + if k in ["debug", "access_log_class"] + } + return RequestHandler(self, loop=self._loop, **kwargs) diff --git a/llm/Lib/site-packages/aiohttp/web_urldispatcher.py b/llm/Lib/site-packages/aiohttp/web_urldispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3fed49d1de6d0414f2f6e0e117473a0cccb43e --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_urldispatcher.py @@ -0,0 +1,1234 @@ +import abc +import asyncio +import base64 +import functools +import hashlib +import html +import inspect +import keyword +import os +import re +import warnings +from contextlib import contextmanager +from functools import wraps +from pathlib import Path +from types import MappingProxyType +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Container, + Dict, + Final, + Generator, + Iterable, + Iterator, + List, + Mapping, + NoReturn, + Optional, + Pattern, + Set, + Sized, + Tuple, + Type, + TypedDict, + Union, + cast, +) + +from yarl import URL, __version__ as yarl_version # type: ignore[attr-defined] + +from . 
import hdrs +from .abc import AbstractMatchInfo, AbstractRouter, AbstractView +from .helpers import DEBUG +from .http import HttpVersion11 +from .typedefs import Handler, PathLike +from .web_exceptions import ( + HTTPException, + HTTPExpectationFailed, + HTTPForbidden, + HTTPMethodNotAllowed, + HTTPNotFound, +) +from .web_fileresponse import FileResponse +from .web_request import Request +from .web_response import Response, StreamResponse +from .web_routedef import AbstractRouteDef + +__all__ = ( + "UrlDispatcher", + "UrlMappingMatchInfo", + "AbstractResource", + "Resource", + "PlainResource", + "DynamicResource", + "AbstractRoute", + "ResourceRoute", + "StaticResource", + "View", +) + + +if TYPE_CHECKING: + from .web_app import Application + + BaseDict = Dict[str, str] +else: + BaseDict = dict + +YARL_VERSION: Final[Tuple[int, ...]] = tuple(map(int, yarl_version.split(".")[:2])) + +HTTP_METHOD_RE: Final[Pattern[str]] = re.compile( + r"^[0-9A-Za-z!#\$%&'\*\+\-\.\^_`\|~]+$" +) +ROUTE_RE: Final[Pattern[str]] = re.compile( + r"(\{[_a-zA-Z][^{}]*(?:\{[^{}]*\}[^{}]*)*\})" +) +PATH_SEP: Final[str] = re.escape("/") + + +_ExpectHandler = Callable[[Request], Awaitable[Optional[StreamResponse]]] +_Resolve = Tuple[Optional["UrlMappingMatchInfo"], Set[str]] + +html_escape = functools.partial(html.escape, quote=True) + + +class _InfoDict(TypedDict, total=False): + path: str + + formatter: str + pattern: Pattern[str] + + directory: Path + prefix: str + routes: Mapping[str, "AbstractRoute"] + + app: "Application" + + domain: str + + rule: "AbstractRuleMatching" + + http_exception: HTTPException + + +class AbstractResource(Sized, Iterable["AbstractRoute"]): + def __init__(self, *, name: Optional[str] = None) -> None: + self._name = name + + @property + def name(self) -> Optional[str]: + return self._name + + @property + @abc.abstractmethod + def canonical(self) -> str: + """Exposes the resource's canonical path. + + For example '/foo/bar/{name}' + + """ + + @abc.abstractmethod # pragma: no branch + def url_for(self, **kwargs: str) -> URL: + """Construct url for resource with additional params.""" + + @abc.abstractmethod # pragma: no branch + async def resolve(self, request: Request) -> _Resolve: + """Resolve resource. + + Return (UrlMappingMatchInfo, allowed_methods) pair. + """ + + @abc.abstractmethod + def add_prefix(self, prefix: str) -> None: + """Add a prefix to processed URLs. + + Required for subapplications support. 
+ """ + + @abc.abstractmethod + def get_info(self) -> _InfoDict: + """Return a dict with additional info useful for introspection""" + + def freeze(self) -> None: + pass + + @abc.abstractmethod + def raw_match(self, path: str) -> bool: + """Perform a raw match against path""" + + +class AbstractRoute(abc.ABC): + def __init__( + self, + method: str, + handler: Union[Handler, Type[AbstractView]], + *, + expect_handler: Optional[_ExpectHandler] = None, + resource: Optional[AbstractResource] = None, + ) -> None: + + if expect_handler is None: + expect_handler = _default_expect_handler + + assert asyncio.iscoroutinefunction( + expect_handler + ), f"Coroutine is expected, got {expect_handler!r}" + + method = method.upper() + if not HTTP_METHOD_RE.match(method): + raise ValueError(f"{method} is not allowed HTTP method") + + assert callable(handler), handler + if asyncio.iscoroutinefunction(handler): + pass + elif inspect.isgeneratorfunction(handler): + warnings.warn( + "Bare generators are deprecated, " "use @coroutine wrapper", + DeprecationWarning, + ) + elif isinstance(handler, type) and issubclass(handler, AbstractView): + pass + else: + warnings.warn( + "Bare functions are deprecated, " "use async ones", DeprecationWarning + ) + + @wraps(handler) + async def handler_wrapper(request: Request) -> StreamResponse: + result = old_handler(request) + if asyncio.iscoroutine(result): + result = await result + assert isinstance(result, StreamResponse) + return result + + old_handler = handler + handler = handler_wrapper + + self._method = method + self._handler = handler + self._expect_handler = expect_handler + self._resource = resource + + @property + def method(self) -> str: + return self._method + + @property + def handler(self) -> Handler: + return self._handler + + @property + @abc.abstractmethod + def name(self) -> Optional[str]: + """Optional route's name, always equals to resource's name.""" + + @property + def resource(self) -> Optional[AbstractResource]: + return self._resource + + @abc.abstractmethod + def get_info(self) -> _InfoDict: + """Return a dict with additional info useful for introspection""" + + @abc.abstractmethod # pragma: no branch + def url_for(self, *args: str, **kwargs: str) -> URL: + """Construct url for route with additional params.""" + + async def handle_expect_header(self, request: Request) -> Optional[StreamResponse]: + return await self._expect_handler(request) + + +class UrlMappingMatchInfo(BaseDict, AbstractMatchInfo): + def __init__(self, match_dict: Dict[str, str], route: AbstractRoute): + super().__init__(match_dict) + self._route = route + self._apps: List[Application] = [] + self._current_app: Optional[Application] = None + self._frozen = False + + @property + def handler(self) -> Handler: + return self._route.handler + + @property + def route(self) -> AbstractRoute: + return self._route + + @property + def expect_handler(self) -> _ExpectHandler: + return self._route.handle_expect_header + + @property + def http_exception(self) -> Optional[HTTPException]: + return None + + def get_info(self) -> _InfoDict: # type: ignore[override] + return self._route.get_info() + + @property + def apps(self) -> Tuple["Application", ...]: + return tuple(self._apps) + + def add_app(self, app: "Application") -> None: + if self._frozen: + raise RuntimeError("Cannot change apps stack after .freeze() call") + if self._current_app is None: + self._current_app = app + self._apps.insert(0, app) + + @property + def current_app(self) -> "Application": + app = self._current_app + assert 
app is not None + return app + + @contextmanager + def set_current_app(self, app: "Application") -> Generator[None, None, None]: + if DEBUG: # pragma: no cover + if app not in self._apps: + raise RuntimeError( + "Expected one of the following apps {!r}, got {!r}".format( + self._apps, app + ) + ) + prev = self._current_app + self._current_app = app + try: + yield + finally: + self._current_app = prev + + def freeze(self) -> None: + self._frozen = True + + def __repr__(self) -> str: + return f"" + + +class MatchInfoError(UrlMappingMatchInfo): + def __init__(self, http_exception: HTTPException) -> None: + self._exception = http_exception + super().__init__({}, SystemRoute(self._exception)) + + @property + def http_exception(self) -> HTTPException: + return self._exception + + def __repr__(self) -> str: + return "".format( + self._exception.status, self._exception.reason + ) + + +async def _default_expect_handler(request: Request) -> None: + """Default handler for Expect header. + + Just send "100 Continue" to client. + raise HTTPExpectationFailed if value of header is not "100-continue" + """ + expect = request.headers.get(hdrs.EXPECT, "") + if request.version == HttpVersion11: + if expect.lower() == "100-continue": + await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n") + else: + raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect) + + +class Resource(AbstractResource): + def __init__(self, *, name: Optional[str] = None) -> None: + super().__init__(name=name) + self._routes: List[ResourceRoute] = [] + + def add_route( + self, + method: str, + handler: Union[Type[AbstractView], Handler], + *, + expect_handler: Optional[_ExpectHandler] = None, + ) -> "ResourceRoute": + + for route_obj in self._routes: + if route_obj.method == method or route_obj.method == hdrs.METH_ANY: + raise RuntimeError( + "Added route will never be executed, " + "method {route.method} is already " + "registered".format(route=route_obj) + ) + + route_obj = ResourceRoute(method, handler, self, expect_handler=expect_handler) + self.register_route(route_obj) + return route_obj + + def register_route(self, route: "ResourceRoute") -> None: + assert isinstance( + route, ResourceRoute + ), f"Instance of Route class is required, got {route!r}" + self._routes.append(route) + + async def resolve(self, request: Request) -> _Resolve: + allowed_methods: Set[str] = set() + + match_dict = self._match(request.rel_url.raw_path) + if match_dict is None: + return None, allowed_methods + + for route_obj in self._routes: + route_method = route_obj.method + allowed_methods.add(route_method) + + if route_method == request.method or route_method == hdrs.METH_ANY: + return (UrlMappingMatchInfo(match_dict, route_obj), allowed_methods) + else: + return None, allowed_methods + + @abc.abstractmethod + def _match(self, path: str) -> Optional[Dict[str, str]]: + pass # pragma: no cover + + def __len__(self) -> int: + return len(self._routes) + + def __iter__(self) -> Iterator["ResourceRoute"]: + return iter(self._routes) + + # TODO: implement all abstract methods + + +class PlainResource(Resource): + def __init__(self, path: str, *, name: Optional[str] = None) -> None: + super().__init__(name=name) + assert not path or path.startswith("/") + self._path = path + + @property + def canonical(self) -> str: + return self._path + + def freeze(self) -> None: + if not self._path: + self._path = "/" + + def add_prefix(self, prefix: str) -> None: + assert prefix.startswith("/") + assert not prefix.endswith("/") + assert len(prefix) > 1 + 
self._path = prefix + self._path + + def _match(self, path: str) -> Optional[Dict[str, str]]: + # string comparison is about 10 times faster than regexp matching + if self._path == path: + return {} + else: + return None + + def raw_match(self, path: str) -> bool: + return self._path == path + + def get_info(self) -> _InfoDict: + return {"path": self._path} + + def url_for(self) -> URL: # type: ignore[override] + return URL.build(path=self._path, encoded=True) + + def __repr__(self) -> str: + name = "'" + self.name + "' " if self.name is not None else "" + return f"" + + +class DynamicResource(Resource): + + DYN = re.compile(r"\{(?P[_a-zA-Z][_a-zA-Z0-9]*)\}") + DYN_WITH_RE = re.compile(r"\{(?P[_a-zA-Z][_a-zA-Z0-9]*):(?P.+)\}") + GOOD = r"[^{}/]+" + + def __init__(self, path: str, *, name: Optional[str] = None) -> None: + super().__init__(name=name) + pattern = "" + formatter = "" + for part in ROUTE_RE.split(path): + match = self.DYN.fullmatch(part) + if match: + pattern += "(?P<{}>{})".format(match.group("var"), self.GOOD) + formatter += "{" + match.group("var") + "}" + continue + + match = self.DYN_WITH_RE.fullmatch(part) + if match: + pattern += "(?P<{var}>{re})".format(**match.groupdict()) + formatter += "{" + match.group("var") + "}" + continue + + if "{" in part or "}" in part: + raise ValueError(f"Invalid path '{path}'['{part}']") + + part = _requote_path(part) + formatter += part + pattern += re.escape(part) + + try: + compiled = re.compile(pattern) + except re.error as exc: + raise ValueError(f"Bad pattern '{pattern}': {exc}") from None + assert compiled.pattern.startswith(PATH_SEP) + assert formatter.startswith("/") + self._pattern = compiled + self._formatter = formatter + + @property + def canonical(self) -> str: + return self._formatter + + def add_prefix(self, prefix: str) -> None: + assert prefix.startswith("/") + assert not prefix.endswith("/") + assert len(prefix) > 1 + self._pattern = re.compile(re.escape(prefix) + self._pattern.pattern) + self._formatter = prefix + self._formatter + + def _match(self, path: str) -> Optional[Dict[str, str]]: + match = self._pattern.fullmatch(path) + if match is None: + return None + else: + return { + key: _unquote_path(value) for key, value in match.groupdict().items() + } + + def raw_match(self, path: str) -> bool: + return self._formatter == path + + def get_info(self) -> _InfoDict: + return {"formatter": self._formatter, "pattern": self._pattern} + + def url_for(self, **parts: str) -> URL: + url = self._formatter.format_map({k: _quote_path(v) for k, v in parts.items()}) + return URL.build(path=url, encoded=True) + + def __repr__(self) -> str: + name = "'" + self.name + "' " if self.name is not None else "" + return "".format( + name=name, formatter=self._formatter + ) + + +class PrefixResource(AbstractResource): + def __init__(self, prefix: str, *, name: Optional[str] = None) -> None: + assert not prefix or prefix.startswith("/"), prefix + assert prefix in ("", "/") or not prefix.endswith("/"), prefix + super().__init__(name=name) + self._prefix = _requote_path(prefix) + self._prefix2 = self._prefix + "/" + + @property + def canonical(self) -> str: + return self._prefix + + def add_prefix(self, prefix: str) -> None: + assert prefix.startswith("/") + assert not prefix.endswith("/") + assert len(prefix) > 1 + self._prefix = prefix + self._prefix + self._prefix2 = self._prefix + "/" + + def raw_match(self, prefix: str) -> bool: + return False + + # TODO: impl missing abstract methods + + +class StaticResource(PrefixResource): + 
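+    # Serves files from a local directory under a URL prefix. Instances are
+    # normally created via UrlDispatcher.add_static(); a minimal sketch,
+    # assuming an existing Application ``app``:
+    #     app.router.add_static("/static", "./static", show_index=True)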
VERSION_KEY = "v" + + def __init__( + self, + prefix: str, + directory: PathLike, + *, + name: Optional[str] = None, + expect_handler: Optional[_ExpectHandler] = None, + chunk_size: int = 256 * 1024, + show_index: bool = False, + follow_symlinks: bool = False, + append_version: bool = False, + ) -> None: + super().__init__(prefix, name=name) + try: + directory = Path(directory) + if str(directory).startswith("~"): + directory = Path(os.path.expanduser(str(directory))) + directory = directory.resolve() + if not directory.is_dir(): + raise ValueError("Not a directory") + except (FileNotFoundError, ValueError) as error: + raise ValueError(f"No directory exists at '{directory}'") from error + self._directory = directory + self._show_index = show_index + self._chunk_size = chunk_size + self._follow_symlinks = follow_symlinks + self._expect_handler = expect_handler + self._append_version = append_version + + self._routes = { + "GET": ResourceRoute( + "GET", self._handle, self, expect_handler=expect_handler + ), + "HEAD": ResourceRoute( + "HEAD", self._handle, self, expect_handler=expect_handler + ), + } + + def url_for( # type: ignore[override] + self, + *, + filename: PathLike, + append_version: Optional[bool] = None, + ) -> URL: + if append_version is None: + append_version = self._append_version + filename = str(filename).lstrip("/") + + url = URL.build(path=self._prefix, encoded=True) + # filename is not encoded + if YARL_VERSION < (1, 6): + url = url / filename.replace("%", "%25") + else: + url = url / filename + + if append_version: + unresolved_path = self._directory.joinpath(filename) + try: + if self._follow_symlinks: + normalized_path = Path(os.path.normpath(unresolved_path)) + normalized_path.relative_to(self._directory) + filepath = normalized_path.resolve() + else: + filepath = unresolved_path.resolve() + filepath.relative_to(self._directory) + except (ValueError, FileNotFoundError): + # ValueError for case when path point to symlink + # with follow_symlinks is False + return url # relatively safe + if filepath.is_file(): + # TODO cache file content + # with file watcher for cache invalidation + with filepath.open("rb") as f: + file_bytes = f.read() + h = self._get_file_hash(file_bytes) + url = url.with_query({self.VERSION_KEY: h}) + return url + return url + + @staticmethod + def _get_file_hash(byte_array: bytes) -> str: + m = hashlib.sha256() # todo sha256 can be configurable param + m.update(byte_array) + b64 = base64.urlsafe_b64encode(m.digest()) + return b64.decode("ascii") + + def get_info(self) -> _InfoDict: + return { + "directory": self._directory, + "prefix": self._prefix, + "routes": self._routes, + } + + def set_options_route(self, handler: Handler) -> None: + if "OPTIONS" in self._routes: + raise RuntimeError("OPTIONS route was set already") + self._routes["OPTIONS"] = ResourceRoute( + "OPTIONS", handler, self, expect_handler=self._expect_handler + ) + + async def resolve(self, request: Request) -> _Resolve: + path = request.rel_url.raw_path + method = request.method + allowed_methods = set(self._routes) + if not path.startswith(self._prefix2) and path != self._prefix: + return None, set() + + if method not in allowed_methods: + return None, allowed_methods + + match_dict = {"filename": _unquote_path(path[len(self._prefix) + 1 :])} + return (UrlMappingMatchInfo(match_dict, self._routes[method]), allowed_methods) + + def __len__(self) -> int: + return len(self._routes) + + def __iter__(self) -> Iterator[AbstractRoute]: + return iter(self._routes.values()) + + async def 
_handle(self, request: Request) -> StreamResponse: + rel_url = request.match_info["filename"] + try: + filename = Path(rel_url) + if filename.anchor: + # rel_url is an absolute name like + # /static/\\machine_name\c$ or /static/D:\path + # where the static dir is totally different + raise HTTPForbidden() + unresolved_path = self._directory.joinpath(filename) + if self._follow_symlinks: + normalized_path = Path(os.path.normpath(unresolved_path)) + normalized_path.relative_to(self._directory) + filepath = normalized_path.resolve() + else: + filepath = unresolved_path.resolve() + filepath.relative_to(self._directory) + except (ValueError, FileNotFoundError) as error: + # relatively safe + raise HTTPNotFound() from error + except HTTPForbidden: + raise + except Exception as error: + # perm error or other kind! + request.app.logger.exception(error) + raise HTTPNotFound() from error + + # on opening a dir, load its contents if allowed + if filepath.is_dir(): + if self._show_index: + try: + return Response( + text=self._directory_as_html(filepath), content_type="text/html" + ) + except PermissionError: + raise HTTPForbidden() + else: + raise HTTPForbidden() + elif filepath.is_file(): + return FileResponse(filepath, chunk_size=self._chunk_size) + else: + raise HTTPNotFound + + def _directory_as_html(self, filepath: Path) -> str: + # returns directory's index as html + + # sanity check + assert filepath.is_dir() + + relative_path_to_dir = filepath.relative_to(self._directory).as_posix() + index_of = f"Index of /{html_escape(relative_path_to_dir)}" + h1 = f"
<h1>{index_of}</h1>"
+
+        index_list = []
+        dir_index = filepath.iterdir()
+        for _file in sorted(dir_index):
+            # show file url as relative to static path
+            rel_path = _file.relative_to(self._directory).as_posix()
+            quoted_file_url = _quote_path(f"{self._prefix}/{rel_path}")
+
+            # if file is a directory, add '/' to the end of the name
+            if _file.is_dir():
+                file_name = f"{_file.name}/"
+            else:
+                file_name = _file.name
+
+            index_list.append(
+                f'<li><a href="{quoted_file_url}">{html_escape(file_name)}</a></li>'
+            )
+        ul = "<ul>\n{}\n</ul>
    ".format("\n".join(index_list)) + body = f"\n{h1}\n{ul}\n" + + head_str = f"\n{index_of}\n" + html = f"\n{head_str}\n{body}\n" + + return html + + def __repr__(self) -> str: + name = "'" + self.name + "'" if self.name is not None else "" + return " {directory!r}>".format( + name=name, path=self._prefix, directory=self._directory + ) + + +class PrefixedSubAppResource(PrefixResource): + def __init__(self, prefix: str, app: "Application") -> None: + super().__init__(prefix) + self._app = app + for resource in app.router.resources(): + resource.add_prefix(prefix) + + def add_prefix(self, prefix: str) -> None: + super().add_prefix(prefix) + for resource in self._app.router.resources(): + resource.add_prefix(prefix) + + def url_for(self, *args: str, **kwargs: str) -> URL: + raise RuntimeError(".url_for() is not supported " "by sub-application root") + + def get_info(self) -> _InfoDict: + return {"app": self._app, "prefix": self._prefix} + + async def resolve(self, request: Request) -> _Resolve: + if ( + not request.url.raw_path.startswith(self._prefix2) + and request.url.raw_path != self._prefix + ): + return None, set() + match_info = await self._app.router.resolve(request) + match_info.add_app(self._app) + if isinstance(match_info.http_exception, HTTPMethodNotAllowed): + methods = match_info.http_exception.allowed_methods + else: + methods = set() + return match_info, methods + + def __len__(self) -> int: + return len(self._app.router.routes()) + + def __iter__(self) -> Iterator[AbstractRoute]: + return iter(self._app.router.routes()) + + def __repr__(self) -> str: + return " {app!r}>".format( + prefix=self._prefix, app=self._app + ) + + +class AbstractRuleMatching(abc.ABC): + @abc.abstractmethod # pragma: no branch + async def match(self, request: Request) -> bool: + """Return bool if the request satisfies the criteria""" + + @abc.abstractmethod # pragma: no branch + def get_info(self) -> _InfoDict: + """Return a dict with additional info useful for introspection""" + + @property + @abc.abstractmethod # pragma: no branch + def canonical(self) -> str: + """Return a str""" + + +class Domain(AbstractRuleMatching): + re_part = re.compile(r"(?!-)[a-z\d-]{1,63}(? None: + super().__init__() + self._domain = self.validation(domain) + + @property + def canonical(self) -> str: + return self._domain + + def validation(self, domain: str) -> str: + if not isinstance(domain, str): + raise TypeError("Domain must be str") + domain = domain.rstrip(".").lower() + if not domain: + raise ValueError("Domain cannot be empty") + elif "://" in domain: + raise ValueError("Scheme not supported") + url = URL("http://" + domain) + assert url.raw_host is not None + if not all(self.re_part.fullmatch(x) for x in url.raw_host.split(".")): + raise ValueError("Domain not valid") + if url.port == 80: + return url.raw_host + return f"{url.raw_host}:{url.port}" + + async def match(self, request: Request) -> bool: + host = request.headers.get(hdrs.HOST) + if not host: + return False + return self.match_domain(host) + + def match_domain(self, host: str) -> bool: + return host.lower() == self._domain + + def get_info(self) -> _InfoDict: + return {"domain": self._domain} + + +class MaskDomain(Domain): + re_part = re.compile(r"(?!-)[a-z\d\*-]{1,63}(? 
None: + super().__init__(domain) + mask = self._domain.replace(".", r"\.").replace("*", ".*") + self._mask = re.compile(mask) + + @property + def canonical(self) -> str: + return self._mask.pattern + + def match_domain(self, host: str) -> bool: + return self._mask.fullmatch(host) is not None + + +class MatchedSubAppResource(PrefixedSubAppResource): + def __init__(self, rule: AbstractRuleMatching, app: "Application") -> None: + AbstractResource.__init__(self) + self._prefix = "" + self._app = app + self._rule = rule + + @property + def canonical(self) -> str: + return self._rule.canonical + + def get_info(self) -> _InfoDict: + return {"app": self._app, "rule": self._rule} + + async def resolve(self, request: Request) -> _Resolve: + if not await self._rule.match(request): + return None, set() + match_info = await self._app.router.resolve(request) + match_info.add_app(self._app) + if isinstance(match_info.http_exception, HTTPMethodNotAllowed): + methods = match_info.http_exception.allowed_methods + else: + methods = set() + return match_info, methods + + def __repr__(self) -> str: + return " {app!r}>" "".format(app=self._app) + + +class ResourceRoute(AbstractRoute): + """A route with resource""" + + def __init__( + self, + method: str, + handler: Union[Handler, Type[AbstractView]], + resource: AbstractResource, + *, + expect_handler: Optional[_ExpectHandler] = None, + ) -> None: + super().__init__( + method, handler, expect_handler=expect_handler, resource=resource + ) + + def __repr__(self) -> str: + return " {handler!r}".format( + method=self.method, resource=self._resource, handler=self.handler + ) + + @property + def name(self) -> Optional[str]: + if self._resource is None: + return None + return self._resource.name + + def url_for(self, *args: str, **kwargs: str) -> URL: + """Construct url for route with additional params.""" + assert self._resource is not None + return self._resource.url_for(*args, **kwargs) + + def get_info(self) -> _InfoDict: + assert self._resource is not None + return self._resource.get_info() + + +class SystemRoute(AbstractRoute): + def __init__(self, http_exception: HTTPException) -> None: + super().__init__(hdrs.METH_ANY, self._handle) + self._http_exception = http_exception + + def url_for(self, *args: str, **kwargs: str) -> URL: + raise RuntimeError(".url_for() is not allowed for SystemRoute") + + @property + def name(self) -> Optional[str]: + return None + + def get_info(self) -> _InfoDict: + return {"http_exception": self._http_exception} + + async def _handle(self, request: Request) -> StreamResponse: + raise self._http_exception + + @property + def status(self) -> int: + return self._http_exception.status + + @property + def reason(self) -> str: + return self._http_exception.reason + + def __repr__(self) -> str: + return "".format(self=self) + + +class View(AbstractView): + async def _iter(self) -> StreamResponse: + if self.request.method not in hdrs.METH_ALL: + self._raise_allowed_methods() + method: Optional[Callable[[], Awaitable[StreamResponse]]] + method = getattr(self, self.request.method.lower(), None) + if method is None: + self._raise_allowed_methods() + ret = await method() + assert isinstance(ret, StreamResponse) + return ret + + def __await__(self) -> Generator[Any, None, StreamResponse]: + return self._iter().__await__() + + def _raise_allowed_methods(self) -> NoReturn: + allowed_methods = {m for m in hdrs.METH_ALL if hasattr(self, m.lower())} + raise HTTPMethodNotAllowed(self.request.method, allowed_methods) + + +class ResourcesView(Sized, 
Iterable[AbstractResource], Container[AbstractResource]): + def __init__(self, resources: List[AbstractResource]) -> None: + self._resources = resources + + def __len__(self) -> int: + return len(self._resources) + + def __iter__(self) -> Iterator[AbstractResource]: + yield from self._resources + + def __contains__(self, resource: object) -> bool: + return resource in self._resources + + +class RoutesView(Sized, Iterable[AbstractRoute], Container[AbstractRoute]): + def __init__(self, resources: List[AbstractResource]): + self._routes: List[AbstractRoute] = [] + for resource in resources: + for route in resource: + self._routes.append(route) + + def __len__(self) -> int: + return len(self._routes) + + def __iter__(self) -> Iterator[AbstractRoute]: + yield from self._routes + + def __contains__(self, route: object) -> bool: + return route in self._routes + + +class UrlDispatcher(AbstractRouter, Mapping[str, AbstractResource]): + + NAME_SPLIT_RE = re.compile(r"[.:-]") + + def __init__(self) -> None: + super().__init__() + self._resources: List[AbstractResource] = [] + self._named_resources: Dict[str, AbstractResource] = {} + + async def resolve(self, request: Request) -> UrlMappingMatchInfo: + method = request.method + allowed_methods: Set[str] = set() + + for resource in self._resources: + match_dict, allowed = await resource.resolve(request) + if match_dict is not None: + return match_dict + else: + allowed_methods |= allowed + + if allowed_methods: + return MatchInfoError(HTTPMethodNotAllowed(method, allowed_methods)) + else: + return MatchInfoError(HTTPNotFound()) + + def __iter__(self) -> Iterator[str]: + return iter(self._named_resources) + + def __len__(self) -> int: + return len(self._named_resources) + + def __contains__(self, resource: object) -> bool: + return resource in self._named_resources + + def __getitem__(self, name: str) -> AbstractResource: + return self._named_resources[name] + + def resources(self) -> ResourcesView: + return ResourcesView(self._resources) + + def routes(self) -> RoutesView: + return RoutesView(self._resources) + + def named_resources(self) -> Mapping[str, AbstractResource]: + return MappingProxyType(self._named_resources) + + def register_resource(self, resource: AbstractResource) -> None: + assert isinstance( + resource, AbstractResource + ), f"Instance of AbstractResource class is required, got {resource!r}" + if self.frozen: + raise RuntimeError("Cannot register a resource into frozen router.") + + name = resource.name + + if name is not None: + parts = self.NAME_SPLIT_RE.split(name) + for part in parts: + if keyword.iskeyword(part): + raise ValueError( + f"Incorrect route name {name!r}, " + "python keywords cannot be used " + "for route name" + ) + if not part.isidentifier(): + raise ValueError( + "Incorrect route name {!r}, " + "the name should be a sequence of " + "python identifiers separated " + "by dash, dot or column".format(name) + ) + if name in self._named_resources: + raise ValueError( + "Duplicate {!r}, " + "already handled by {!r}".format(name, self._named_resources[name]) + ) + self._named_resources[name] = resource + self._resources.append(resource) + + def add_resource(self, path: str, *, name: Optional[str] = None) -> Resource: + if path and not path.startswith("/"): + raise ValueError("path should be started with / or be empty") + # Reuse last added resource if path and name are the same + if self._resources: + resource = self._resources[-1] + if resource.name == name and resource.raw_match(path): + return cast(Resource, resource) + 
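+        # A path with no "{...}" placeholders (e.g. "/health") becomes a
+        # PlainResource matched by exact string comparison; anything with
+        # placeholders (e.g. "/users/{name}") is compiled into a regex-backed
+        # DynamicResource.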
if not ("{" in path or "}" in path or ROUTE_RE.search(path)): + resource = PlainResource(_requote_path(path), name=name) + self.register_resource(resource) + return resource + resource = DynamicResource(path, name=name) + self.register_resource(resource) + return resource + + def add_route( + self, + method: str, + path: str, + handler: Union[Handler, Type[AbstractView]], + *, + name: Optional[str] = None, + expect_handler: Optional[_ExpectHandler] = None, + ) -> AbstractRoute: + resource = self.add_resource(path, name=name) + return resource.add_route(method, handler, expect_handler=expect_handler) + + def add_static( + self, + prefix: str, + path: PathLike, + *, + name: Optional[str] = None, + expect_handler: Optional[_ExpectHandler] = None, + chunk_size: int = 256 * 1024, + show_index: bool = False, + follow_symlinks: bool = False, + append_version: bool = False, + ) -> AbstractResource: + """Add static files view. + + prefix - url prefix + path - folder with files + + """ + assert prefix.startswith("/") + if prefix.endswith("/"): + prefix = prefix[:-1] + resource = StaticResource( + prefix, + path, + name=name, + expect_handler=expect_handler, + chunk_size=chunk_size, + show_index=show_index, + follow_symlinks=follow_symlinks, + append_version=append_version, + ) + self.register_resource(resource) + return resource + + def add_head(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute: + """Shortcut for add_route with method HEAD.""" + return self.add_route(hdrs.METH_HEAD, path, handler, **kwargs) + + def add_options(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute: + """Shortcut for add_route with method OPTIONS.""" + return self.add_route(hdrs.METH_OPTIONS, path, handler, **kwargs) + + def add_get( + self, + path: str, + handler: Handler, + *, + name: Optional[str] = None, + allow_head: bool = True, + **kwargs: Any, + ) -> AbstractRoute: + """Shortcut for add_route with method GET. + + If allow_head is true, another + route is added allowing head requests to the same endpoint. + """ + resource = self.add_resource(path, name=name) + if allow_head: + resource.add_route(hdrs.METH_HEAD, handler, **kwargs) + return resource.add_route(hdrs.METH_GET, handler, **kwargs) + + def add_post(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute: + """Shortcut for add_route with method POST.""" + return self.add_route(hdrs.METH_POST, path, handler, **kwargs) + + def add_put(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute: + """Shortcut for add_route with method PUT.""" + return self.add_route(hdrs.METH_PUT, path, handler, **kwargs) + + def add_patch(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute: + """Shortcut for add_route with method PATCH.""" + return self.add_route(hdrs.METH_PATCH, path, handler, **kwargs) + + def add_delete(self, path: str, handler: Handler, **kwargs: Any) -> AbstractRoute: + """Shortcut for add_route with method DELETE.""" + return self.add_route(hdrs.METH_DELETE, path, handler, **kwargs) + + def add_view( + self, path: str, handler: Type[AbstractView], **kwargs: Any + ) -> AbstractRoute: + """Shortcut for add_route with ANY methods for a class-based view.""" + return self.add_route(hdrs.METH_ANY, path, handler, **kwargs) + + def freeze(self) -> None: + super().freeze() + for resource in self._resources: + resource.freeze() + + def add_routes(self, routes: Iterable[AbstractRouteDef]) -> List[AbstractRoute]: + """Append routes to route table. 
+ + Parameter should be a sequence of RouteDef objects. + + Returns a list of registered AbstractRoute instances. + """ + registered_routes = [] + for route_def in routes: + registered_routes.extend(route_def.register(self)) + return registered_routes + + +def _quote_path(value: str) -> str: + if YARL_VERSION < (1, 6): + value = value.replace("%", "%25") + return URL.build(path=value, encoded=False).raw_path + + +def _unquote_path(value: str) -> str: + return URL.build(path=value, encoded=True).path + + +def _requote_path(value: str) -> str: + # Quote non-ascii characters and other characters which must be quoted, + # but preserve existing %-sequences. + result = _quote_path(value) + if "%" in value: + result = result.replace("%25", "%") + return result diff --git a/llm/Lib/site-packages/aiohttp/web_ws.py b/llm/Lib/site-packages/aiohttp/web_ws.py new file mode 100644 index 0000000000000000000000000000000000000000..48bcb35742a23bb7d38227835206672480cf9537 --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/web_ws.py @@ -0,0 +1,539 @@ +import asyncio +import base64 +import binascii +import hashlib +import json +import sys +from typing import Any, Final, Iterable, Optional, Tuple, cast + +import attr +from multidict import CIMultiDict + +from . import hdrs +from .abc import AbstractStreamWriter +from .helpers import call_later, set_exception, set_result +from .http import ( + WS_CLOSED_MESSAGE, + WS_CLOSING_MESSAGE, + WS_KEY, + WebSocketError, + WebSocketReader, + WebSocketWriter, + WSCloseCode, + WSMessage, + WSMsgType as WSMsgType, + ws_ext_gen, + ws_ext_parse, +) +from .log import ws_logger +from .streams import EofStream, FlowControlDataQueue +from .typedefs import JSONDecoder, JSONEncoder +from .web_exceptions import HTTPBadRequest, HTTPException +from .web_request import BaseRequest +from .web_response import StreamResponse + +if sys.version_info >= (3, 11): + import asyncio as async_timeout +else: + import async_timeout + +__all__ = ( + "WebSocketResponse", + "WebSocketReady", + "WSMsgType", +) + +THRESHOLD_CONNLOST_ACCESS: Final[int] = 5 + + +@attr.s(auto_attribs=True, frozen=True, slots=True) +class WebSocketReady: + ok: bool + protocol: Optional[str] + + def __bool__(self) -> bool: + return self.ok + + +class WebSocketResponse(StreamResponse): + + _length_check = False + + def __init__( + self, + *, + timeout: float = 10.0, + receive_timeout: Optional[float] = None, + autoclose: bool = True, + autoping: bool = True, + heartbeat: Optional[float] = None, + protocols: Iterable[str] = (), + compress: bool = True, + max_msg_size: int = 4 * 1024 * 1024, + ) -> None: + super().__init__(status=101) + self._protocols = protocols + self._ws_protocol: Optional[str] = None + self._writer: Optional[WebSocketWriter] = None + self._reader: Optional[FlowControlDataQueue[WSMessage]] = None + self._closed = False + self._closing = False + self._conn_lost = 0 + self._close_code: Optional[int] = None + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._waiting: Optional[asyncio.Future[bool]] = None + self._exception: Optional[BaseException] = None + self._timeout = timeout + self._receive_timeout = receive_timeout + self._autoclose = autoclose + self._autoping = autoping + self._heartbeat = heartbeat + self._heartbeat_cb: Optional[asyncio.TimerHandle] = None + if heartbeat is not None: + self._pong_heartbeat = heartbeat / 2.0 + self._pong_response_cb: Optional[asyncio.TimerHandle] = None + self._compress = compress + self._max_msg_size = max_msg_size + + def _cancel_heartbeat(self) -> None: + 
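+        # Cancel any scheduled pong-timeout and heartbeat timer callbacks.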
if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = None + + if self._heartbeat_cb is not None: + self._heartbeat_cb.cancel() + self._heartbeat_cb = None + + def _reset_heartbeat(self) -> None: + self._cancel_heartbeat() + + if self._heartbeat is not None: + assert self._loop is not None + self._heartbeat_cb = call_later( + self._send_heartbeat, + self._heartbeat, + self._loop, + timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold + if self._req is not None + else 5, + ) + + def _send_heartbeat(self) -> None: + if self._heartbeat is not None and not self._closed: + assert self._loop is not None + # fire-and-forget a task is not perfect but maybe ok for + # sending ping. Otherwise we need a long-living heartbeat + # task in the class. + self._loop.create_task(self._writer.ping()) # type: ignore[union-attr] + + if self._pong_response_cb is not None: + self._pong_response_cb.cancel() + self._pong_response_cb = call_later( + self._pong_not_received, + self._pong_heartbeat, + self._loop, + timeout_ceil_threshold=self._req._protocol._timeout_ceil_threshold + if self._req is not None + else 5, + ) + + def _pong_not_received(self) -> None: + if self._req is not None and self._req.transport is not None: + self._closed = True + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + self._exception = asyncio.TimeoutError() + + async def prepare(self, request: BaseRequest) -> AbstractStreamWriter: + # make pre-check to don't hide it by do_handshake() exceptions + if self._payload_writer is not None: + return self._payload_writer + + protocol, writer = self._pre_start(request) + payload_writer = await super().prepare(request) + assert payload_writer is not None + self._post_start(request, protocol, writer) + await payload_writer.drain() + return payload_writer + + def _handshake( + self, request: BaseRequest + ) -> Tuple["CIMultiDict[str]", str, bool, bool]: + headers = request.headers + if "websocket" != headers.get(hdrs.UPGRADE, "").lower().strip(): + raise HTTPBadRequest( + text=( + "No WebSocket UPGRADE hdr: {}\n Can " + '"Upgrade" only to "WebSocket".' 
+ ).format(headers.get(hdrs.UPGRADE)) + ) + + if "upgrade" not in headers.get(hdrs.CONNECTION, "").lower(): + raise HTTPBadRequest( + text="No CONNECTION upgrade hdr: {}".format( + headers.get(hdrs.CONNECTION) + ) + ) + + # find common sub-protocol between client and server + protocol = None + if hdrs.SEC_WEBSOCKET_PROTOCOL in headers: + req_protocols = [ + str(proto.strip()) + for proto in headers[hdrs.SEC_WEBSOCKET_PROTOCOL].split(",") + ] + + for proto in req_protocols: + if proto in self._protocols: + protocol = proto + break + else: + # No overlap found: Return no protocol as per spec + ws_logger.warning( + "Client protocols %r don’t overlap server-known ones %r", + req_protocols, + self._protocols, + ) + + # check supported version + version = headers.get(hdrs.SEC_WEBSOCKET_VERSION, "") + if version not in ("13", "8", "7"): + raise HTTPBadRequest(text=f"Unsupported version: {version}") + + # check client handshake for validity + key = headers.get(hdrs.SEC_WEBSOCKET_KEY) + try: + if not key or len(base64.b64decode(key)) != 16: + raise HTTPBadRequest(text=f"Handshake error: {key!r}") + except binascii.Error: + raise HTTPBadRequest(text=f"Handshake error: {key!r}") from None + + accept_val = base64.b64encode( + hashlib.sha1(key.encode() + WS_KEY).digest() + ).decode() + response_headers = CIMultiDict( + { + hdrs.UPGRADE: "websocket", + hdrs.CONNECTION: "upgrade", + hdrs.SEC_WEBSOCKET_ACCEPT: accept_val, + } + ) + + notakeover = False + compress = 0 + if self._compress: + extensions = headers.get(hdrs.SEC_WEBSOCKET_EXTENSIONS) + # Server side always get return with no exception. + # If something happened, just drop compress extension + compress, notakeover = ws_ext_parse(extensions, isserver=True) + if compress: + enabledext = ws_ext_gen( + compress=compress, isserver=True, server_notakeover=notakeover + ) + response_headers[hdrs.SEC_WEBSOCKET_EXTENSIONS] = enabledext + + if protocol: + response_headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = protocol + return ( + response_headers, + protocol, + compress, + notakeover, + ) # type: ignore[return-value] + + def _pre_start(self, request: BaseRequest) -> Tuple[str, WebSocketWriter]: + self._loop = request._loop + + headers, protocol, compress, notakeover = self._handshake(request) + + self.set_status(101) + self.headers.update(headers) + self.force_close() + self._compress = compress + transport = request._protocol.transport + assert transport is not None + writer = WebSocketWriter( + request._protocol, transport, compress=compress, notakeover=notakeover + ) + + return protocol, writer + + def _post_start( + self, request: BaseRequest, protocol: str, writer: WebSocketWriter + ) -> None: + self._ws_protocol = protocol + self._writer = writer + + self._reset_heartbeat() + + loop = self._loop + assert loop is not None + self._reader = FlowControlDataQueue(request._protocol, 2**16, loop=loop) + request.protocol.set_parser( + WebSocketReader(self._reader, self._max_msg_size, compress=self._compress) + ) + # disable HTTP keepalive for WebSocket + request.protocol.keep_alive(False) + + def can_prepare(self, request: BaseRequest) -> WebSocketReady: + if self._writer is not None: + raise RuntimeError("Already started") + try: + _, protocol, _, _ = self._handshake(request) + except HTTPException: + return WebSocketReady(False, None) + else: + return WebSocketReady(True, protocol) + + @property + def closed(self) -> bool: + return self._closed + + @property + def close_code(self) -> Optional[int]: + return self._close_code + + @property + def ws_protocol(self) 
-> Optional[str]: + return self._ws_protocol + + @property + def compress(self) -> bool: + return self._compress + + def get_extra_info(self, name: str, default: Any = None) -> Any: + """Get optional transport information. + + If no value associated with ``name`` is found, ``default`` is returned. + """ + writer = self._writer + if writer is None: + return default + transport = writer.transport + if transport is None: + return default + return transport.get_extra_info(name, default) + + def exception(self) -> Optional[BaseException]: + return self._exception + + async def ping(self, message: bytes = b"") -> None: + if self._writer is None: + raise RuntimeError("Call .prepare() first") + await self._writer.ping(message) + + async def pong(self, message: bytes = b"") -> None: + # unsolicited pong + if self._writer is None: + raise RuntimeError("Call .prepare() first") + await self._writer.pong(message) + + async def send_str(self, data: str, compress: Optional[bool] = None) -> None: + if self._writer is None: + raise RuntimeError("Call .prepare() first") + if not isinstance(data, str): + raise TypeError("data argument must be str (%r)" % type(data)) + await self._writer.send(data, binary=False, compress=compress) + + async def send_bytes(self, data: bytes, compress: Optional[bool] = None) -> None: + if self._writer is None: + raise RuntimeError("Call .prepare() first") + if not isinstance(data, (bytes, bytearray, memoryview)): + raise TypeError("data argument must be byte-ish (%r)" % type(data)) + await self._writer.send(data, binary=True, compress=compress) + + async def send_json( + self, + data: Any, + compress: Optional[bool] = None, + *, + dumps: JSONEncoder = json.dumps, + ) -> None: + await self.send_str(dumps(data), compress=compress) + + async def write_eof(self) -> None: # type: ignore[override] + if self._eof_sent: + return + if self._payload_writer is None: + raise RuntimeError("Response has not been started") + + await self.close() + self._eof_sent = True + + async def close( + self, *, code: int = WSCloseCode.OK, message: bytes = b"", drain: bool = True + ) -> bool: + """Close websocket connection.""" + if self._writer is None: + raise RuntimeError("Call .prepare() first") + + self._cancel_heartbeat() + reader = self._reader + assert reader is not None + + # we need to break `receive()` cycle first, + # `close()` may be called from different task + if self._waiting is not None and not self._closed: + reader.feed_data(WS_CLOSING_MESSAGE, 0) + await self._waiting + + if self._closed: + return False + + self._closed = True + try: + await self._writer.close(code, message) + writer = self._payload_writer + assert writer is not None + if drain: + await writer.drain() + except (asyncio.CancelledError, asyncio.TimeoutError): + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + raise + except Exception as exc: + self._exception = exc + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + return True + + if self._closing: + self._close_transport() + return True + + reader = self._reader + assert reader is not None + try: + async with async_timeout.timeout(self._timeout): + msg = await reader.read() + except asyncio.CancelledError: + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + raise + except Exception as exc: + self._exception = exc + self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + return True + + if msg.type == WSMsgType.CLOSE: + self._set_code_close_transport(msg.data) + return True + + 
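+        # The peer did not reply with a close frame in time; record the
+        # closure as abnormal and remember the timeout as the exception.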
self._set_code_close_transport(WSCloseCode.ABNORMAL_CLOSURE) + self._exception = asyncio.TimeoutError() + return True + + def _set_closing(self, code: WSCloseCode) -> None: + """Set the close code and mark the connection as closing.""" + self._closing = True + self._close_code = code + + def _set_code_close_transport(self, code: WSCloseCode) -> None: + """Set the close code and close the transport.""" + self._close_code = code + self._close_transport() + + def _close_transport(self) -> None: + """Close the transport.""" + if self._req is not None and self._req.transport is not None: + self._req.transport.close() + + async def receive(self, timeout: Optional[float] = None) -> WSMessage: + if self._reader is None: + raise RuntimeError("Call .prepare() first") + + loop = self._loop + assert loop is not None + while True: + if self._waiting is not None: + raise RuntimeError("Concurrent call to receive() is not allowed") + + if self._closed: + self._conn_lost += 1 + if self._conn_lost >= THRESHOLD_CONNLOST_ACCESS: + raise RuntimeError("WebSocket connection is closed.") + return WS_CLOSED_MESSAGE + elif self._closing: + return WS_CLOSING_MESSAGE + + try: + self._waiting = loop.create_future() + try: + async with async_timeout.timeout(timeout or self._receive_timeout): + msg = await self._reader.read() + self._reset_heartbeat() + finally: + waiter = self._waiting + set_result(waiter, True) + self._waiting = None + except asyncio.TimeoutError: + raise + except EofStream: + self._close_code = WSCloseCode.OK + await self.close() + return WSMessage(WSMsgType.CLOSED, None, None) + except WebSocketError as exc: + self._close_code = exc.code + await self.close(code=exc.code) + return WSMessage(WSMsgType.ERROR, exc, None) + except Exception as exc: + self._exception = exc + self._set_closing(WSCloseCode.ABNORMAL_CLOSURE) + await self.close() + return WSMessage(WSMsgType.ERROR, exc, None) + + if msg.type == WSMsgType.CLOSE: + self._set_closing(msg.data) + # Could be closed while awaiting reader. + if not self._closed and self._autoclose: + # The client is likely going to close the + # connection out from under us so we do not + # want to drain any pending writes as it will + # likely result writing to a broken pipe. 
+ await self.close(drain=False) + elif msg.type == WSMsgType.CLOSING: + self._set_closing(WSCloseCode.OK) + elif msg.type == WSMsgType.PING and self._autoping: + await self.pong(msg.data) + continue + elif msg.type == WSMsgType.PONG and self._autoping: + continue + + return msg + + async def receive_str(self, *, timeout: Optional[float] = None) -> str: + msg = await self.receive(timeout) + if msg.type != WSMsgType.TEXT: + raise TypeError( + "Received message {}:{!r} is not WSMsgType.TEXT".format( + msg.type, msg.data + ) + ) + return cast(str, msg.data) + + async def receive_bytes(self, *, timeout: Optional[float] = None) -> bytes: + msg = await self.receive(timeout) + if msg.type != WSMsgType.BINARY: + raise TypeError(f"Received message {msg.type}:{msg.data!r} is not bytes") + return cast(bytes, msg.data) + + async def receive_json( + self, *, loads: JSONDecoder = json.loads, timeout: Optional[float] = None + ) -> Any: + data = await self.receive_str(timeout=timeout) + return loads(data) + + async def write(self, data: bytes) -> None: + raise RuntimeError("Cannot call .write() for websocket") + + def __aiter__(self) -> "WebSocketResponse": + return self + + async def __anext__(self) -> WSMessage: + msg = await self.receive() + if msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING, WSMsgType.CLOSED): + raise StopAsyncIteration + return msg + + def _cancel(self, exc: BaseException) -> None: + # web_protocol calls this from connection_lost + # or when the server is shutting down. + self._closing = True + if self._reader is not None: + set_exception(self._reader, exc) diff --git a/llm/Lib/site-packages/aiohttp/worker.py b/llm/Lib/site-packages/aiohttp/worker.py new file mode 100644 index 0000000000000000000000000000000000000000..aa8c40472afa3f27e5349a7a42e934cf14d991ab --- /dev/null +++ b/llm/Lib/site-packages/aiohttp/worker.py @@ -0,0 +1,247 @@ +"""Async gunicorn worker for aiohttp.web""" + +import asyncio +import os +import re +import signal +import sys +from types import FrameType +from typing import Any, Awaitable, Callable, Optional, Union # noqa + +from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat +from gunicorn.workers import base + +from aiohttp import web + +from .helpers import set_result +from .web_app import Application +from .web_log import AccessLogger + +try: + import ssl + + SSLContext = ssl.SSLContext +except ImportError: # pragma: no cover + ssl = None # type: ignore[assignment] + SSLContext = object # type: ignore[misc,assignment] + + +__all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker") + + +class GunicornWebWorker(base.Worker): # type: ignore[misc,no-any-unimported] + + DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT + DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default + + def __init__(self, *args: Any, **kw: Any) -> None: # pragma: no cover + super().__init__(*args, **kw) + + self._task: Optional[asyncio.Task[None]] = None + self.exit_code = 0 + self._notify_waiter: Optional[asyncio.Future[bool]] = None + + def init_process(self) -> None: + # create new event_loop after fork + asyncio.get_event_loop().close() + + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + super().init_process() + + def run(self) -> None: + self._task = self.loop.create_task(self._run()) + + try: # ignore all finalization problems + self.loop.run_until_complete(self._task) + except Exception: + self.log.exception("Exception in gunicorn worker") + self.loop.run_until_complete(self.loop.shutdown_asyncgens()) + self.loop.close() + + 
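+        # Propagate the worker's exit status to the gunicorn master process.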
sys.exit(self.exit_code) + + async def _run(self) -> None: + runner = None + if isinstance(self.wsgi, Application): + app = self.wsgi + elif asyncio.iscoroutinefunction(self.wsgi): + wsgi = await self.wsgi() + if isinstance(wsgi, web.AppRunner): + runner = wsgi + app = runner.app + else: + app = wsgi + else: + raise RuntimeError( + "wsgi app should be either Application or " + "async function returning Application, got {}".format(self.wsgi) + ) + + if runner is None: + access_log = self.log.access_log if self.cfg.accesslog else None + runner = web.AppRunner( + app, + logger=self.log, + keepalive_timeout=self.cfg.keepalive, + access_log=access_log, + access_log_format=self._get_valid_log_format( + self.cfg.access_log_format + ), + shutdown_timeout=self.cfg.graceful_timeout / 100 * 95, + ) + await runner.setup() + + ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None + + runner = runner + assert runner is not None + server = runner.server + assert server is not None + for sock in self.sockets: + site = web.SockSite( + runner, + sock, + ssl_context=ctx, + ) + await site.start() + + # If our parent changed then we shut down. + pid = os.getpid() + try: + while self.alive: # type: ignore[has-type] + self.notify() + + cnt = server.requests_count + if self.max_requests and cnt > self.max_requests: + self.alive = False + self.log.info("Max requests, shutting down: %s", self) + + elif pid == os.getpid() and self.ppid != os.getppid(): + self.alive = False + self.log.info("Parent changed, shutting down: %s", self) + else: + await self._wait_next_notify() + except BaseException: + pass + + await runner.cleanup() + + def _wait_next_notify(self) -> "asyncio.Future[bool]": + self._notify_waiter_done() + + loop = self.loop + assert loop is not None + self._notify_waiter = waiter = loop.create_future() + self.loop.call_later(1.0, self._notify_waiter_done, waiter) + + return waiter + + def _notify_waiter_done( + self, waiter: Optional["asyncio.Future[bool]"] = None + ) -> None: + if waiter is None: + waiter = self._notify_waiter + if waiter is not None: + set_result(waiter, True) + + if waiter is self._notify_waiter: + self._notify_waiter = None + + def init_signals(self) -> None: + # Set up signals through the event loop API. 
+ + self.loop.add_signal_handler( + signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None + ) + + self.loop.add_signal_handler( + signal.SIGTERM, self.handle_exit, signal.SIGTERM, None + ) + + self.loop.add_signal_handler( + signal.SIGINT, self.handle_quit, signal.SIGINT, None + ) + + self.loop.add_signal_handler( + signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None + ) + + self.loop.add_signal_handler( + signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None + ) + + self.loop.add_signal_handler( + signal.SIGABRT, self.handle_abort, signal.SIGABRT, None + ) + + # Don't let SIGTERM and SIGUSR1 disturb active requests + # by interrupting system calls + signal.siginterrupt(signal.SIGTERM, False) + signal.siginterrupt(signal.SIGUSR1, False) + # Reset signals so Gunicorn doesn't swallow subprocess return codes + # See: https://github.com/aio-libs/aiohttp/issues/6130 + + def handle_quit(self, sig: int, frame: Optional[FrameType]) -> None: + self.alive = False + + # worker_int callback + self.cfg.worker_int(self) + + # wakeup closing process + self._notify_waiter_done() + + def handle_abort(self, sig: int, frame: Optional[FrameType]) -> None: + self.alive = False + self.exit_code = 1 + self.cfg.worker_abort(self) + sys.exit(1) + + @staticmethod + def _create_ssl_context(cfg: Any) -> "SSLContext": + """Creates SSLContext instance for usage in asyncio.create_server. + + See ssl.SSLSocket.__init__ for more details. + """ + if ssl is None: # pragma: no cover + raise RuntimeError("SSL is not supported.") + + ctx = ssl.SSLContext(cfg.ssl_version) + ctx.load_cert_chain(cfg.certfile, cfg.keyfile) + ctx.verify_mode = cfg.cert_reqs + if cfg.ca_certs: + ctx.load_verify_locations(cfg.ca_certs) + if cfg.ciphers: + ctx.set_ciphers(cfg.ciphers) + return ctx + + def _get_valid_log_format(self, source_format: str) -> str: + if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT: + return self.DEFAULT_AIOHTTP_LOG_FORMAT + elif re.search(r"%\([^\)]+\)", source_format): + raise ValueError( + "Gunicorn's style options in form of `%(name)s` are not " + "supported for the log formatting. Please use aiohttp's " + "format specification to configure access log formatting: " + "http://docs.aiohttp.org/en/stable/logging.html" + "#format-specification" + ) + else: + return source_format + + +class GunicornUVLoopWebWorker(GunicornWebWorker): + def init_process(self) -> None: + import uvloop + + # Close any existing event loop before setting a + # new policy. + asyncio.get_event_loop().close() + + # Setup uvloop policy, so that every + # asyncio.get_event_loop() will create an instance + # of uvloop event loop. + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + super().init_process() diff --git a/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/INSTALLER b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/LICENSE b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7082a2d5b9047bfc09589f387053e24ea490bc54 --- /dev/null +++ b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. 
+ + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2013-2019 Nikolay Kim and Andrew Svetlov + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/METADATA b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..fc964525f05e8e34961f0398b1930b8dec64ef26 --- /dev/null +++ b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/METADATA @@ -0,0 +1,128 @@ +Metadata-Version: 2.1 +Name: aiosignal +Version: 1.3.1 +Summary: aiosignal: a list of registered asynchronous callbacks +Home-page: https://github.com/aio-libs/aiosignal +Maintainer: aiohttp team +Maintainer-email: team@aiohttp.org +License: Apache 2.0 +Project-URL: Chat: Gitter, https://gitter.im/aio-libs/Lobby +Project-URL: CI: GitHub Actions, https://github.com/aio-libs/aiosignal/actions +Project-URL: Coverage: codecov, https://codecov.io/github/aio-libs/aiosignal +Project-URL: Docs: RTD, https://docs.aiosignal.org +Project-URL: GitHub: issues, https://github.com/aio-libs/aiosignal/issues +Project-URL: GitHub: repo, https://github.com/aio-libs/aiosignal +Classifier: License :: OSI Approved :: Apache Software License +Classifier: Intended Audience :: Developers +Classifier: Programming Language :: Python +Classifier: Programming Language :: Python :: 3 +Classifier: Programming Language :: Python :: 3 :: Only +Classifier: Programming Language :: Python :: 3.7 +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Development Status :: 5 - Production/Stable +Classifier: Operating System :: POSIX +Classifier: Operating System :: MacOS :: MacOS X +Classifier: Operating System :: Microsoft :: Windows +Classifier: Framework :: AsyncIO +Requires-Python: >=3.7 +Description-Content-Type: text/x-rst +License-File: LICENSE +Requires-Dist: frozenlist (>=1.1.0) + +========= +aiosignal +========= + +.. image:: https://github.com/aio-libs/aiosignal/workflows/CI/badge.svg + :target: https://github.com/aio-libs/aiosignal/actions?query=workflow%3ACI + :alt: GitHub status for master branch + +.. image:: https://codecov.io/gh/aio-libs/aiosignal/branch/master/graph/badge.svg + :target: https://codecov.io/gh/aio-libs/aiosignal + :alt: codecov.io status for master branch + +.. image:: https://badge.fury.io/py/aiosignal.svg + :target: https://pypi.org/project/aiosignal + :alt: Latest PyPI package version + +.. image:: https://readthedocs.org/projects/aiosignal/badge/?version=latest + :target: https://aiosignal.readthedocs.io/ + :alt: Latest Read The Docs + +.. image:: https://img.shields.io/discourse/topics?server=https%3A%2F%2Faio-libs.discourse.group%2F + :target: https://aio-libs.discourse.group/ + :alt: Discourse group for io-libs + +.. image:: https://badges.gitter.im/Join%20Chat.svg + :target: https://gitter.im/aio-libs/Lobby + :alt: Chat on Gitter + +Introduction +============ + +A project to manage callbacks in `asyncio` projects. + +``Signal`` is a list of registered asynchronous callbacks. + +The signal's life-cycle has two stages: after creation its content +could be filled by using standard list operations: ``sig.append()`` +etc. + +After you call ``sig.freeze()`` the signal is *frozen*: adding, removing +and dropping callbacks is forbidden. + +The only available operation is calling the previously registered +callbacks by using ``await sig.send(data)``. + +For concrete usage examples see the `Signals + +section of the `Web Server Advanced +` chapter of the `aiohttp +documentation`_. 
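+
+As a minimal sketch of that life-cycle (the ``owner`` object and the
+``on_startup`` callback below are illustrative names, not part of the
+library)::
+
+    import asyncio
+    from aiosignal import Signal
+
+    async def on_startup(app):
+        print("starting", app)
+
+    async def main():
+        owner = object()            # any object may own the signal
+        sig = Signal(owner)         # stage 1: behaves like a list
+        sig.append(on_startup)
+        sig.freeze()                # stage 2: frozen, no further changes
+        await sig.send("demo-app")  # awaits every registered callback
+
+    asyncio.run(main())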
+ + +Installation +------------ + +:: + + $ pip install aiosignal + +The library requires Python 3.6 or newer. + + +Documentation +============= + +https://aiosignal.readthedocs.io/ + +Communication channels +====================== + +*gitter chat* https://gitter.im/aio-libs/Lobby + +Requirements +============ + +- Python >= 3.6 +- frozenlist >= 1.0.0 + +License +======= + +``aiosignal`` is offered under the Apache 2 license. + +Source code +=========== + +The project is hosted on GitHub_ + +Please file an issue in the `bug tracker +`_ if you have found a bug +or have some suggestions to improve the library. + +.. _GitHub: https://github.com/aio-libs/aiosignal +.. _aiohttp documentation: https://docs.aiohttp.org/ diff --git a/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/RECORD b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..63cbabf5c681369801fdd2925ab317963da469a7 --- /dev/null +++ b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/RECORD @@ -0,0 +1,10 @@ +aiosignal-1.3.1.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +aiosignal-1.3.1.dist-info/LICENSE,sha256=b9UkPpLdf5jsacesN3co50kFcJ_1J6W_mNbQJjwE9bY,11332 +aiosignal-1.3.1.dist-info/METADATA,sha256=c0HRnlYzfXKztZPTFDlPfygizTherhG5WdwXlvco0Ug,4008 +aiosignal-1.3.1.dist-info/RECORD,, +aiosignal-1.3.1.dist-info/WHEEL,sha256=ZL1lC_LiPDNRgDnOl2taCMc83aPEUZgHHv2h-LDgdiM,92 +aiosignal-1.3.1.dist-info/top_level.txt,sha256=z45aNOKGDdrI1roqZY3BGXQ22kJFPHBmVdwtLYLtXC0,10 +aiosignal/__init__.py,sha256=zQNfFYRSd84bswvpFv8ZWjEr5DeYwV3LXbMSyo2222s,867 +aiosignal/__init__.pyi,sha256=xeCddYSS8fZAkz8S4HuKSR2IDe3N7RW_LKcXDPPA1Xk,311 +aiosignal/__pycache__/__init__.cpython-311.pyc,, +aiosignal/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 diff --git a/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/WHEEL b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..5e1f087ca1ac49327ef76b101df80489a03c2e7f --- /dev/null +++ b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/WHEEL @@ -0,0 +1,5 @@ +Wheel-Version: 1.0 +Generator: bdist_wheel (0.38.2) +Root-Is-Purelib: true +Tag: py3-none-any + diff --git a/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/top_level.txt b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/top_level.txt new file mode 100644 index 0000000000000000000000000000000000000000..ac6df3afe74a5fd43afc7ab7f8393571a495fdc5 --- /dev/null +++ b/llm/Lib/site-packages/aiosignal-1.3.1.dist-info/top_level.txt @@ -0,0 +1 @@ +aiosignal diff --git a/llm/Lib/site-packages/aiosignal/__init__.py b/llm/Lib/site-packages/aiosignal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d288e6ede67df2bb8e5660e30372e190eb23e65 --- /dev/null +++ b/llm/Lib/site-packages/aiosignal/__init__.py @@ -0,0 +1,36 @@ +from frozenlist import FrozenList + +__version__ = "1.3.1" + +__all__ = ("Signal",) + + +class Signal(FrozenList): + """Coroutine-based signal implementation. + + To connect a callback to a signal, use any list method. + + Signals are fired using the send() coroutine, which takes named + arguments. + """ + + __slots__ = ("_owner",) + + def __init__(self, owner): + super().__init__() + self._owner = owner + + def __repr__(self): + return "".format( + self._owner, self.frozen, list(self) + ) + + async def send(self, *args, **kwargs): + """ + Sends data to all registered receivers. 
+ """ + if not self.frozen: + raise RuntimeError("Cannot send non-frozen signal.") + + for receiver in self: + await receiver(*args, **kwargs) # type: ignore diff --git a/llm/Lib/site-packages/aiosignal/__init__.pyi b/llm/Lib/site-packages/aiosignal/__init__.pyi new file mode 100644 index 0000000000000000000000000000000000000000..d4e3416d72246058259061578a82697e2bc0706e --- /dev/null +++ b/llm/Lib/site-packages/aiosignal/__init__.pyi @@ -0,0 +1,12 @@ +from typing import Any, Generic, TypeVar + +from frozenlist import FrozenList + +__all__ = ("Signal",) + +_T = TypeVar("_T") + +class Signal(FrozenList[_T], Generic[_T]): + def __init__(self, owner: Any) -> None: ... + def __repr__(self) -> str: ... + async def send(self, *args: Any, **kwargs: Any) -> None: ... diff --git a/llm/Lib/site-packages/aiosignal/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/aiosignal/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b01e54668ca26b59fb6763305e988e505020e583 Binary files /dev/null and b/llm/Lib/site-packages/aiosignal/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/aiosignal/py.typed b/llm/Lib/site-packages/aiosignal/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llm/Lib/site-packages/altair-5.3.0.dist-info/INSTALLER b/llm/Lib/site-packages/altair-5.3.0.dist-info/INSTALLER new file mode 100644 index 0000000000000000000000000000000000000000..a1b589e38a32041e49332e5e81c2d363dc418d68 --- /dev/null +++ b/llm/Lib/site-packages/altair-5.3.0.dist-info/INSTALLER @@ -0,0 +1 @@ +pip diff --git a/llm/Lib/site-packages/altair-5.3.0.dist-info/METADATA b/llm/Lib/site-packages/altair-5.3.0.dist-info/METADATA new file mode 100644 index 0000000000000000000000000000000000000000..c1ce9ef9e7e9fe5f64fb0e661b1f0f9a5ce18ab1 --- /dev/null +++ b/llm/Lib/site-packages/altair-5.3.0.dist-info/METADATA @@ -0,0 +1,221 @@ +Metadata-Version: 2.3 +Name: altair +Version: 5.3.0 +Summary: Vega-Altair: A declarative statistical visualization library for Python. 
+Project-URL: Documentation, https://altair-viz.github.io +Project-URL: Source, https://github.com/altair-viz/altair +Author: Vega-Altair Contributors +License-File: LICENSE +Keywords: declarative,interactive,json,statistics,vega-lite,visualization +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Intended Audience :: Science/Research +Classifier: License :: OSI Approved :: BSD License +Classifier: Natural Language :: English +Classifier: Programming Language :: Python :: 3.8 +Classifier: Programming Language :: Python :: 3.9 +Classifier: Programming Language :: Python :: 3.10 +Classifier: Programming Language :: Python :: 3.11 +Classifier: Programming Language :: Python :: 3.12 +Classifier: Typing :: Typed +Requires-Python: >=3.8 +Requires-Dist: jinja2 +Requires-Dist: jsonschema>=3.0 +Requires-Dist: numpy +Requires-Dist: packaging +Requires-Dist: pandas>=0.25 +Requires-Dist: toolz +Requires-Dist: typing-extensions>=4.0.1; python_version < '3.11' +Provides-Extra: all +Requires-Dist: altair-tiles>=0.3.0; extra == 'all' +Requires-Dist: anywidget>=0.9.0; extra == 'all' +Requires-Dist: pyarrow>=11; extra == 'all' +Requires-Dist: vega-datasets>=0.9.0; extra == 'all' +Requires-Dist: vegafusion[embed]>=1.6.6; extra == 'all' +Requires-Dist: vl-convert-python>=1.3.0; extra == 'all' +Provides-Extra: dev +Requires-Dist: geopandas; extra == 'dev' +Requires-Dist: hatch; extra == 'dev' +Requires-Dist: ipython; extra == 'dev' +Requires-Dist: m2r; extra == 'dev' +Requires-Dist: mypy; extra == 'dev' +Requires-Dist: pandas-stubs; extra == 'dev' +Requires-Dist: pytest; extra == 'dev' +Requires-Dist: pytest-cov; extra == 'dev' +Requires-Dist: ruff>=0.3.0; extra == 'dev' +Requires-Dist: types-jsonschema; extra == 'dev' +Requires-Dist: types-setuptools; extra == 'dev' +Provides-Extra: doc +Requires-Dist: docutils; extra == 'doc' +Requires-Dist: jinja2; extra == 'doc' +Requires-Dist: myst-parser; extra == 'doc' +Requires-Dist: numpydoc; extra == 'doc' +Requires-Dist: pillow<10,>=9; extra == 'doc' +Requires-Dist: pydata-sphinx-theme>=0.14.1; extra == 'doc' +Requires-Dist: scipy; extra == 'doc' +Requires-Dist: sphinx; extra == 'doc' +Requires-Dist: sphinx-copybutton; extra == 'doc' +Requires-Dist: sphinx-design; extra == 'doc' +Requires-Dist: sphinxext-altair; extra == 'doc' +Description-Content-Type: text/markdown + +# Vega-Altair + +[![github actions](https://github.com/altair-viz/altair/workflows/build/badge.svg)](https://github.com/altair-viz/altair/actions?query=workflow%3Abuild) +[![typedlib_mypy](https://www.mypy-lang.org/static/mypy_badge.svg)](https://www.mypy-lang.org) +[![JOSS Paper](https://joss.theoj.org/papers/10.21105/joss.01057/status.svg)](https://joss.theoj.org/papers/10.21105/joss.01057) +[![PyPI - Downloads](https://img.shields.io/pypi/dm/altair)](https://pypi.org/project/altair) + +**Vega-Altair** is a declarative statistical visualization library for Python. With Vega-Altair, you can spend more time understanding your data and its meaning. Vega-Altair's +API is simple, friendly and consistent and built on top of the powerful +[Vega-Lite](https://github.com/vega/vega-lite) JSON specification. This elegant +simplicity produces beautiful and effective visualizations with a minimal amount of code. 
+ +*Vega-Altair was originally developed by [Jake Vanderplas](https://github.com/jakevdp) and [Brian +Granger](https://github.com/ellisonbg) in close collaboration with the [UW +Interactive Data Lab](https://idl.cs.washington.edu/).* +*The Vega-Altair open source project is not affiliated with Altair Engineering, Inc.* + +## Documentation + +See [Vega-Altair's Documentation Site](https://altair-viz.github.io) as well as the [Tutorial Notebooks](https://github.com/altair-viz/altair_notebooks). You can +run the notebooks directly in your browser by clicking on one of the following badges: + +[![Binder](https://beta.mybinder.org/badge.svg)](https://beta.mybinder.org/v2/gh/altair-viz/altair_notebooks/master) +[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/altair-viz/altair_notebooks/blob/master/notebooks/Index.ipynb) + +## Example + +Here is an example using Vega-Altair to quickly visualize and display a dataset with the native Vega-Lite renderer in the JupyterLab: + +```python +import altair as alt + +# load a simple dataset as a pandas DataFrame +from vega_datasets import data +cars = data.cars() + +alt.Chart(cars).mark_point().encode( + x='Horsepower', + y='Miles_per_Gallon', + color='Origin', +) +``` + +![Vega-Altair Visualization](https://raw.githubusercontent.com/altair-viz/altair/main/images/cars.png) + +One of the unique features of Vega-Altair, inherited from Vega-Lite, is a declarative grammar of not just visualization, but _interaction_. +With a few modifications to the example above we can create a linked histogram that is filtered based on a selection of the scatter plot. + +```python +import altair as alt +from vega_datasets import data + +source = data.cars() + +brush = alt.selection_interval() + +points = alt.Chart(source).mark_point().encode( + x='Horsepower', + y='Miles_per_Gallon', + color=alt.condition(brush, 'Origin', alt.value('lightgray')) +).add_params( + brush +) + +bars = alt.Chart(source).mark_bar().encode( + y='Origin', + color='Origin', + x='count(Origin)' +).transform_filter( + brush +) + +points & bars +``` + +![Vega-Altair Visualization Gif](https://raw.githubusercontent.com/altair-viz/altair/main/images/cars_scatter_bar.gif) + +## Features + +* Carefully-designed, declarative Python API. +* Auto-generated internal Python API that guarantees visualizations are type-checked and + in full conformance with the [Vega-Lite](https://github.com/vega/vega-lite) + specification. +* Display visualizations in JupyterLab, Jupyter Notebook, Visual Studio Code, on GitHub and + [nbviewer](https://nbviewer.jupyter.org/), and many more. +* Export visualizations to various formats such as PNG/SVG images, stand-alone HTML pages and the +[Online Vega-Lite Editor](https://vega.github.io/editor/#/). +* Serialize visualizations as JSON files. + +## Installation + +Vega-Altair can be installed with: +```bash +pip install altair +``` + +If you are using the conda package manager, the equivalent is: +```bash +conda install altair -c conda-forge +``` + +For full installation instructions, please see [the documentation](https://altair-viz.github.io/getting_started/installation.html). + +## Getting Help + +If you have a question that is not addressed in the documentation, +you can post it on [StackOverflow](https://stackoverflow.com/questions/tagged/altair) using the `altair` tag. +For bugs and feature requests, please open a [Github Issue](https://github.com/altair-viz/altair/issues). 
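+
+## Saving and Serializing Charts
+
+As a brief, illustrative sketch of the export features listed under
+"Features" above (the file names are examples, and PNG/SVG export assumes
+the optional `vl-convert-python` package is installed):
+
+```python
+import altair as alt
+import pandas as pd
+
+source = pd.DataFrame({"x": [1, 2, 3], "y": [4, 5, 6]})
+chart = alt.Chart(source).mark_line().encode(x="x", y="y")
+
+chart.save("chart.html")     # stand-alone HTML page
+chart.save("chart.png")      # static image; uses vl-convert-python
+spec_json = chart.to_json()  # Vega-Lite specification as a JSON string
+```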
+ +## Development + +[![Hatch project](https://img.shields.io/badge/%F0%9F%A5%9A-Hatch-4051b5.svg)](https://github.com/pypa/hatch) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) +[![pytest](https://img.shields.io/badge/logo-pytest-blue?logo=pytest&labelColor=5c5c5c&label=%20)](https://github.com/pytest-dev/pytest) + +You can find the instructions on how to install the package for development in [the documentation](https://altair-viz.github.io/getting_started/installation.html). + +To run the tests and linters, use + +``` +hatch run test +``` + +For information on how to contribute your developments back to the Vega-Altair repository, see +[`CONTRIBUTING.md`](https://github.com/altair-viz/altair/blob/main/CONTRIBUTING.md) + +## Citing Vega-Altair + +[![JOSS Paper](https://joss.theoj.org/papers/10.21105/joss.01057/status.svg)](https://joss.theoj.org/papers/10.21105/joss.01057) + +If you use Vega-Altair in academic work, please consider citing https://joss.theoj.org/papers/10.21105/joss.01057 as + +```bib +@article{VanderPlas2018, + doi = {10.21105/joss.01057}, + url = {https://doi.org/10.21105/joss.01057}, + year = {2018}, + publisher = {The Open Journal}, + volume = {3}, + number = {32}, + pages = {1057}, + author = {Jacob VanderPlas and Brian Granger and Jeffrey Heer and Dominik Moritz and Kanit Wongsuphasawat and Arvind Satyanarayan and Eitan Lees and Ilia Timofeev and Ben Welsh and Scott Sievert}, + title = {Altair: Interactive Statistical Visualizations for Python}, + journal = {Journal of Open Source Software} +} +``` +Please additionally consider citing the [Vega-Lite](https://vega.github.io/vega-lite/) project, which Vega-Altair is based on: https://dl.acm.org/doi/10.1109/TVCG.2016.2599030 + +```bib +@article{Satyanarayan2017, + author={Satyanarayan, Arvind and Moritz, Dominik and Wongsuphasawat, Kanit and Heer, Jeffrey}, + title={Vega-Lite: A Grammar of Interactive Graphics}, + journal={IEEE transactions on visualization and computer graphics}, + year={2017}, + volume={23}, + number={1}, + pages={341-350}, + publisher={IEEE} +} +``` diff --git a/llm/Lib/site-packages/altair-5.3.0.dist-info/RECORD b/llm/Lib/site-packages/altair-5.3.0.dist-info/RECORD new file mode 100644 index 0000000000000000000000000000000000000000..83bef022793bec993d1c2b0443d029bb5963a5fd --- /dev/null +++ b/llm/Lib/site-packages/altair-5.3.0.dist-info/RECORD @@ -0,0 +1,96 @@ +altair-5.3.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4 +altair-5.3.0.dist-info/METADATA,sha256=l9xKPXtciPbAaiwLFlBsrk721JiyN31y5ghnqzbAxjE,9161 +altair-5.3.0.dist-info/RECORD,, +altair-5.3.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +altair-5.3.0.dist-info/WHEEL,sha256=uNdcs2TADwSd5pVaP0Z_kcjcvvTUklh2S7bxZMF8Uj0,87 +altair-5.3.0.dist-info/licenses/LICENSE,sha256=FL9avQqjkvgG-gu8iPUd42lk1TBxpodjHmIYB-VWMx4,1497 +altair/__init__.py,sha256=MdaJpEMW9rR_ak5eYdyy69IJOnTQItOtu6e8ditWci0,14638 +altair/__pycache__/__init__.cpython-311.pyc,, +altair/__pycache__/_magics.cpython-311.pyc,, +altair/_magics.py,sha256=HoCk3nSeCe25IiARz6nvTBj5tV1F3dU5z84w2L5SGPU,2996 +altair/expr/__init__.py,sha256=wHX2XE8k_mK-BSbteNDjH2e5vJI4PPAmVau-UR5Tom8,449 +altair/expr/__pycache__/__init__.cpython-311.pyc,, +altair/expr/__pycache__/consts.cpython-311.pyc,, +altair/expr/__pycache__/core.cpython-311.pyc,, +altair/expr/__pycache__/funcs.cpython-311.pyc,, 
+altair/expr/consts.py,sha256=tFviBb2Dg2W0dYAuHpjU2vkZaCdzNFO8wTWBCKczeq4,916 +altair/expr/core.py,sha256=iWAS8HHE9qmUuP6MsnwVxsjZV6xzKZBKBRGgj88qUgk,6848 +altair/expr/funcs.py,sha256=EIR1TowHBbcnYj4MWDmw9hjeeMa8tZSNUVFxoVeJEyc,34418 +altair/jupyter/__init__.py,sha256=6xkrWgyUEe075vkf8c6IJdNJkI83qMMy66to8egXbz4,852 +altair/jupyter/__pycache__/__init__.cpython-311.pyc,, +altair/jupyter/__pycache__/jupyter_chart.cpython-311.pyc,, +altair/jupyter/js/README.md,sha256=iwzryhAUgm-HD7KupvKV2TKY5U74S7mLyGKe0rgB4D8,172 +altair/jupyter/js/index.js,sha256=Fq5QVsTmNZ-YdONfAGRYhKba9PCjOMxBML-yevju0qc,7936 +altair/jupyter/jupyter_chart.py,sha256=B7NTCwH94ecB-_FKAWCDdCP5jAb1PEMNvUaMi-x8JSU,15289 +altair/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0 +altair/utils/__init__.py,sha256=7-6IsrndhpIInccWeneLbZC5S-gjasA1cFbTmsd0Nvs,702 +altair/utils/__pycache__/__init__.cpython-311.pyc,, +altair/utils/__pycache__/_dfi_types.cpython-311.pyc,, +altair/utils/__pycache__/_importers.cpython-311.pyc,, +altair/utils/__pycache__/_show.cpython-311.pyc,, +altair/utils/__pycache__/_transformed_data.cpython-311.pyc,, +altair/utils/__pycache__/_vegafusion_data.cpython-311.pyc,, +altair/utils/__pycache__/compiler.cpython-311.pyc,, +altair/utils/__pycache__/core.cpython-311.pyc,, +altair/utils/__pycache__/data.cpython-311.pyc,, +altair/utils/__pycache__/deprecation.cpython-311.pyc,, +altair/utils/__pycache__/display.cpython-311.pyc,, +altair/utils/__pycache__/execeval.cpython-311.pyc,, +altair/utils/__pycache__/html.cpython-311.pyc,, +altair/utils/__pycache__/mimebundle.cpython-311.pyc,, +altair/utils/__pycache__/plugin_registry.cpython-311.pyc,, +altair/utils/__pycache__/save.cpython-311.pyc,, +altair/utils/__pycache__/schemapi.cpython-311.pyc,, +altair/utils/__pycache__/selection.cpython-311.pyc,, +altair/utils/__pycache__/server.cpython-311.pyc,, +altair/utils/__pycache__/theme.cpython-311.pyc,, +altair/utils/_dfi_types.py,sha256=75RF6_BS7Go8pVAjwkCN__jD2YOrnZmNGBiq_PFlHEU,6503 +altair/utils/_importers.py,sha256=W34Z6rFdPGOcr2xMIYGFqKEGuGG_xyUcYwwFmDu8k40,3690 +altair/utils/_show.py,sha256=cTzcyXgpT2tqkYAH-1tCZm5MZIKyTkYMgsTKQMaViDQ,2261 +altair/utils/_transformed_data.py,sha256=H9V9Na0bAfypMNICazBxIg2ThHRnI9IF12Zc7UPYzL4,18128 +altair/utils/_vegafusion_data.py,sha256=44-iwwhz_BOwxZxSXjGRS3NUNZsRvmbYVhqAEENrnSQ,8205 +altair/utils/compiler.py,sha256=Gku2l4Z6xI_Ei7J1j4MOk_WEE9rhrp1OuoFCcnB9_BU,396 +altair/utils/core.py,sha256=wAHj3qgGWeMMp4Wy1oOE-tAShhKAkbWpoe54XSSd52A,27776 +altair/utils/data.py,sha256=Sm7q4SrBKuLRnqqYaX02WB5nptEtoK0Yx-1uwhzm_iA,13030 +altair/utils/deprecation.py,sha256=AJ-3kdQB7pvWbyDxl5ZUiZoM3O-UpI3X9BOcgzuXmG4,1785 +altair/utils/display.py,sha256=RKK09oNUWN91mWBa77U8RRMe5doz-JygH25KxXjZux0,8852 +altair/utils/execeval.py,sha256=cTpnWPGVPUAnM8mqs1t5a6l9Bp9bnZWPjwa8_mKFvLo,1423 +altair/utils/html.py,sha256=9RYtr0llnf39titWJgRNbaT98m5Zp_GZcqW10izItyw,9843 +altair/utils/mimebundle.py,sha256=Rbs4f_gg3kSEK_zhG98P8Jr_o29U9cwGTN7Bbu1C5Nc,12178 +altair/utils/plugin_registry.py,sha256=l8o0mghMeamhhLfebKVeerH2NQr3KCG7rJLZDncUKQM,7936 +altair/utils/save.py,sha256=CD8yttjyjORGxLb7cyjEcbcM_2tGNGjWJdO1dk0VAaU,8004 +altair/utils/schemapi.py,sha256=Ib1qyv2qN_NllogJ0Bd5qkRLIgxWbse7VFOzXXGsRtQ,52252 +altair/utils/selection.py,sha256=T46Lv1KqSIjgSvKY6FcMa9LQ8rdIS--9ZfN5ZhTvU4E,4175 +altair/utils/server.py,sha256=le_OAjvz0kh-6p_OclWUTpOMaPYjbvYCP1yRR4iPqLs,4076 +altair/utils/theme.py,sha256=b1VS--VOilgE8hrw05-DjIz_d5ND8qVqTrT-OWTqD_Y,221 
+altair/vegalite/__init__.py,sha256=BRCDyXuwAljV_GeiifreqCgCl1fevqeww1qGAnGTfQ4,31 +altair/vegalite/__pycache__/__init__.cpython-311.pyc,, +altair/vegalite/__pycache__/api.cpython-311.pyc,, +altair/vegalite/__pycache__/data.cpython-311.pyc,, +altair/vegalite/__pycache__/display.cpython-311.pyc,, +altair/vegalite/__pycache__/schema.cpython-311.pyc,, +altair/vegalite/api.py,sha256=kZEoyq1nKhFNTcW6GRMyv7NYlQXqcpFuZ3m8eRApmEg,35 +altair/vegalite/data.py,sha256=zl1fVhwEpXQYWxrTAp6UlHIQe565thPpaMtMX-jG2Hs,1190 +altair/vegalite/display.py,sha256=2mpJo3ZRZYjtvzTKgJcBwq6MpZIrWdZLMCU377bpIlY,357 +altair/vegalite/schema.py,sha256=DsBPJA_bpac1CFyYhnsynqatGtQ38ZOSFhaef7DHTE8,68 +altair/vegalite/v5/__init__.py,sha256=-f54c8F5P2CjACqRWyQvqduejoZluEgBYuxd_BhCPVA,379 +altair/vegalite/v5/__pycache__/__init__.cpython-311.pyc,, +altair/vegalite/v5/__pycache__/api.cpython-311.pyc,, +altair/vegalite/v5/__pycache__/compiler.cpython-311.pyc,, +altair/vegalite/v5/__pycache__/data.cpython-311.pyc,, +altair/vegalite/v5/__pycache__/display.cpython-311.pyc,, +altair/vegalite/v5/__pycache__/theme.cpython-311.pyc,, +altair/vegalite/v5/api.py,sha256=2EsZuC0jeLBU5S0C_jvM3D9QeqIg4XkdflkBHKaV3AM,156055 +altair/vegalite/v5/compiler.py,sha256=ykOlGGV0PATiars7wg7_gnIBHWWpoTb_qUGBRmwoeWo,818 +altair/vegalite/v5/data.py,sha256=7kq6diO8u_SPzcCL7BvCAK6T961WfVgf2cuEnFL8H4c,1114 +altair/vegalite/v5/display.py,sha256=NuGhNKi--Bv5y5mUiPqoezT5mHrvmkkyCDrGAdFstYg,5659 +altair/vegalite/v5/schema/__init__.py,sha256=bSrsoSMf46OAK6V0yOzLBtKXPMTmGv3icAZHbbUiUIo,183 +altair/vegalite/v5/schema/__pycache__/__init__.cpython-311.pyc,, +altair/vegalite/v5/schema/__pycache__/channels.cpython-311.pyc,, +altair/vegalite/v5/schema/__pycache__/core.cpython-311.pyc,, +altair/vegalite/v5/schema/__pycache__/mixins.cpython-311.pyc,, +altair/vegalite/v5/schema/channels.py,sha256=oLXWKpAh9rQ0LBEG5hb65GkuSKeM0uSpdptONg8xwKc,2784087 +altair/vegalite/v5/schema/core.py,sha256=30bcwQyCewxFdTutH7oW2VREIf4ER_n_9yRAdFIb0Do,2157102 +altair/vegalite/v5/schema/mixins.py,sha256=04grHxBE8iphRmUrD73pcQ2whg9DSJ--YwYiNwdVO48,443354 +altair/vegalite/v5/schema/vega-lite-schema.json,sha256=8B1wGe1jk6DdjWmpkL1Z5alpIRCJSzPfOTsSXtA13xg,1817604 +altair/vegalite/v5/theme.py,sha256=MFElxfYfnMkpd58XY14nroBasuTZZ49fC6SsQEmBc58,1532 diff --git a/llm/Lib/site-packages/altair-5.3.0.dist-info/REQUESTED b/llm/Lib/site-packages/altair-5.3.0.dist-info/REQUESTED new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llm/Lib/site-packages/altair-5.3.0.dist-info/WHEEL b/llm/Lib/site-packages/altair-5.3.0.dist-info/WHEEL new file mode 100644 index 0000000000000000000000000000000000000000..0309176f93ce82503b6ac764edadacb3d3127758 --- /dev/null +++ b/llm/Lib/site-packages/altair-5.3.0.dist-info/WHEEL @@ -0,0 +1,4 @@ +Wheel-Version: 1.0 +Generator: hatchling 1.22.4 +Root-Is-Purelib: true +Tag: py3-none-any diff --git a/llm/Lib/site-packages/altair-5.3.0.dist-info/licenses/LICENSE b/llm/Lib/site-packages/altair-5.3.0.dist-info/licenses/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..03635b4b1fb676f7e1ec7bc286345228cfca78ee --- /dev/null +++ b/llm/Lib/site-packages/altair-5.3.0.dist-info/licenses/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2015-2023, Vega-Altair Developers +All rights reserved. 
+ +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of vega-altair nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/llm/Lib/site-packages/altair/__init__.py b/llm/Lib/site-packages/altair/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7c792d3f51a73c8e9fe67aa459fe02455e7bea0 --- /dev/null +++ b/llm/Lib/site-packages/altair/__init__.py @@ -0,0 +1,627 @@ +# ruff: noqa +__version__ = "5.3.0" + +from typing import Any + +# Necessary as mypy would see expr as the module alt.expr although due to how +# the imports are set up it is expr in the alt.expr module +expr: Any + + +# The content of __all__ is automatically written by +# tools/update_init_file.py. Do not modify directly. 
+__all__ = [ + "Aggregate", + "AggregateOp", + "AggregateTransform", + "AggregatedFieldDef", + "Align", + "AllSortString", + "Angle", + "AngleDatum", + "AngleValue", + "AnyMark", + "AnyMarkConfig", + "AreaConfig", + "ArgmaxDef", + "ArgminDef", + "AutoSizeParams", + "AutosizeType", + "Axis", + "AxisConfig", + "AxisOrient", + "AxisResolveMap", + "BBox", + "BarConfig", + "BaseTitleNoValueRefs", + "Baseline", + "Bin", + "BinExtent", + "BinParams", + "BinTransform", + "BindCheckbox", + "BindDirect", + "BindInput", + "BindRadioSelect", + "BindRange", + "Binding", + "BinnedTimeUnit", + "Blend", + "BoxPlot", + "BoxPlotConfig", + "BoxPlotDef", + "BrushConfig", + "CalculateTransform", + "Categorical", + "Chart", + "ChartDataType", + "Color", + "ColorDatum", + "ColorDef", + "ColorName", + "ColorScheme", + "ColorValue", + "Column", + "CompositeMark", + "CompositeMarkDef", + "CompositionConfig", + "ConcatChart", + "ConcatSpecGenericSpec", + "ConditionalAxisColor", + "ConditionalAxisLabelAlign", + "ConditionalAxisLabelBaseline", + "ConditionalAxisLabelFontStyle", + "ConditionalAxisLabelFontWeight", + "ConditionalAxisNumber", + "ConditionalAxisNumberArray", + "ConditionalAxisPropertyAlignnull", + "ConditionalAxisPropertyColornull", + "ConditionalAxisPropertyFontStylenull", + "ConditionalAxisPropertyFontWeightnull", + "ConditionalAxisPropertyTextBaselinenull", + "ConditionalAxisPropertynumberArraynull", + "ConditionalAxisPropertynumbernull", + "ConditionalAxisPropertystringnull", + "ConditionalAxisString", + "ConditionalMarkPropFieldOrDatumDef", + "ConditionalMarkPropFieldOrDatumDefTypeForShape", + "ConditionalParameterMarkPropFieldOrDatumDef", + "ConditionalParameterMarkPropFieldOrDatumDefTypeForShape", + "ConditionalParameterStringFieldDef", + "ConditionalParameterValueDefGradientstringnullExprRef", + "ConditionalParameterValueDefTextExprRef", + "ConditionalParameterValueDefnumber", + "ConditionalParameterValueDefnumberArrayExprRef", + "ConditionalParameterValueDefnumberExprRef", + "ConditionalParameterValueDefstringExprRef", + "ConditionalParameterValueDefstringnullExprRef", + "ConditionalPredicateMarkPropFieldOrDatumDef", + "ConditionalPredicateMarkPropFieldOrDatumDefTypeForShape", + "ConditionalPredicateStringFieldDef", + "ConditionalPredicateValueDefAlignnullExprRef", + "ConditionalPredicateValueDefColornullExprRef", + "ConditionalPredicateValueDefFontStylenullExprRef", + "ConditionalPredicateValueDefFontWeightnullExprRef", + "ConditionalPredicateValueDefGradientstringnullExprRef", + "ConditionalPredicateValueDefTextBaselinenullExprRef", + "ConditionalPredicateValueDefTextExprRef", + "ConditionalPredicateValueDefnumber", + "ConditionalPredicateValueDefnumberArrayExprRef", + "ConditionalPredicateValueDefnumberArraynullExprRef", + "ConditionalPredicateValueDefnumberExprRef", + "ConditionalPredicateValueDefnumbernullExprRef", + "ConditionalPredicateValueDefstringExprRef", + "ConditionalPredicateValueDefstringnullExprRef", + "ConditionalStringFieldDef", + "ConditionalValueDefGradientstringnullExprRef", + "ConditionalValueDefTextExprRef", + "ConditionalValueDefnumber", + "ConditionalValueDefnumberArrayExprRef", + "ConditionalValueDefnumberExprRef", + "ConditionalValueDefstringExprRef", + "ConditionalValueDefstringnullExprRef", + "Config", + "CsvDataFormat", + "Cursor", + "Cyclical", + "Data", + "DataFormat", + "DataFrameLike", + "DataSource", + "DataType", + "Datasets", + "DateTime", + "DatumChannelMixin", + "DatumDef", + "Day", + "DensityTransform", + "DerivedStream", + "Description", + 
"DescriptionValue", + "Detail", + "Dict", + "DictInlineDataset", + "DictSelectionInit", + "DictSelectionInitInterval", + "Diverging", + "DomainUnionWith", + "DsvDataFormat", + "Element", + "Encoding", + "EncodingSortField", + "ErrorBand", + "ErrorBandConfig", + "ErrorBandDef", + "ErrorBar", + "ErrorBarConfig", + "ErrorBarDef", + "ErrorBarExtent", + "EventStream", + "EventType", + "Expr", + "ExprRef", + "ExtentTransform", + "Facet", + "FacetChart", + "FacetEncodingFieldDef", + "FacetFieldDef", + "FacetMapping", + "FacetSpec", + "FacetedEncoding", + "FacetedUnitSpec", + "Feature", + "FeatureCollection", + "FeatureGeometryGeoJsonProperties", + "Field", + "FieldChannelMixin", + "FieldDefWithoutScale", + "FieldEqualPredicate", + "FieldGTEPredicate", + "FieldGTPredicate", + "FieldLTEPredicate", + "FieldLTPredicate", + "FieldName", + "FieldOneOfPredicate", + "FieldOrDatumDefWithConditionDatumDefGradientstringnull", + "FieldOrDatumDefWithConditionDatumDefnumber", + "FieldOrDatumDefWithConditionDatumDefnumberArray", + "FieldOrDatumDefWithConditionDatumDefstringnull", + "FieldOrDatumDefWithConditionMarkPropFieldDefGradientstringnull", + "FieldOrDatumDefWithConditionMarkPropFieldDefTypeForShapestringnull", + "FieldOrDatumDefWithConditionMarkPropFieldDefnumber", + "FieldOrDatumDefWithConditionMarkPropFieldDefnumberArray", + "FieldOrDatumDefWithConditionStringDatumDefText", + "FieldOrDatumDefWithConditionStringFieldDefText", + "FieldOrDatumDefWithConditionStringFieldDefstring", + "FieldRange", + "FieldRangePredicate", + "FieldValidPredicate", + "Fill", + "FillDatum", + "FillOpacity", + "FillOpacityDatum", + "FillOpacityValue", + "FillValue", + "FilterTransform", + "Fit", + "FlattenTransform", + "FoldTransform", + "FontStyle", + "FontWeight", + "FormatConfig", + "Generator", + "GenericUnitSpecEncodingAnyMark", + "GeoJsonFeature", + "GeoJsonFeatureCollection", + "GeoJsonProperties", + "Geometry", + "GeometryCollection", + "Gradient", + "GradientStop", + "GraticuleGenerator", + "GraticuleParams", + "HConcatChart", + "HConcatSpecGenericSpec", + "Header", + "HeaderConfig", + "HexColor", + "Href", + "HrefValue", + "Impute", + "ImputeMethod", + "ImputeParams", + "ImputeSequence", + "ImputeTransform", + "InlineData", + "InlineDataset", + "Interpolate", + "IntervalSelectionConfig", + "IntervalSelectionConfigWithoutType", + "JoinAggregateFieldDef", + "JoinAggregateTransform", + "JsonDataFormat", + "JupyterChart", + "Key", + "LabelOverlap", + "LatLongDef", + "LatLongFieldDef", + "Latitude", + "Latitude2", + "Latitude2Datum", + "Latitude2Value", + "LatitudeDatum", + "LayerChart", + "LayerRepeatMapping", + "LayerRepeatSpec", + "LayerSpec", + "LayoutAlign", + "Legend", + "LegendBinding", + "LegendConfig", + "LegendOrient", + "LegendResolveMap", + "LegendStreamBinding", + "LineConfig", + "LineString", + "LinearGradient", + "LocalMultiTimeUnit", + "LocalSingleTimeUnit", + "Locale", + "LoessTransform", + "LogicalAndPredicate", + "LogicalNotPredicate", + "LogicalOrPredicate", + "Longitude", + "Longitude2", + "Longitude2Datum", + "Longitude2Value", + "LongitudeDatum", + "LookupData", + "LookupSelection", + "LookupTransform", + "Mark", + "MarkConfig", + "MarkDef", + "MarkPropDefGradientstringnull", + "MarkPropDefnumber", + "MarkPropDefnumberArray", + "MarkPropDefstringnullTypeForShape", + "MarkType", + "MaxRowsError", + "MergedStream", + "Month", + "MultiLineString", + "MultiPoint", + "MultiPolygon", + "MultiTimeUnit", + "NamedData", + "NonArgAggregateOp", + "NonLayerRepeatSpec", + "NonNormalizedSpec", + "NumberLocale", + 
"NumericArrayMarkPropDef", + "NumericMarkPropDef", + "OffsetDef", + "Opacity", + "OpacityDatum", + "OpacityValue", + "Order", + "OrderFieldDef", + "OrderOnlyDef", + "OrderValue", + "OrderValueDef", + "Orient", + "Orientation", + "OverlayMarkDef", + "Padding", + "Parameter", + "ParameterExpression", + "ParameterExtent", + "ParameterName", + "ParameterPredicate", + "Parse", + "ParseValue", + "PivotTransform", + "Point", + "PointSelectionConfig", + "PointSelectionConfigWithoutType", + "PolarDef", + "Polygon", + "Position", + "Position2Def", + "PositionDatumDef", + "PositionDatumDefBase", + "PositionDef", + "PositionFieldDef", + "PositionFieldDefBase", + "PositionValueDef", + "Predicate", + "PredicateComposition", + "PrimitiveValue", + "Projection", + "ProjectionConfig", + "ProjectionType", + "QuantileTransform", + "RadialGradient", + "Radius", + "Radius2", + "Radius2Datum", + "Radius2Value", + "RadiusDatum", + "RadiusValue", + "RangeConfig", + "RangeEnum", + "RangeRaw", + "RangeRawArray", + "RangeScheme", + "RectConfig", + "RegressionTransform", + "RelativeBandSize", + "RepeatChart", + "RepeatMapping", + "RepeatRef", + "RepeatSpec", + "Resolve", + "ResolveMode", + "Root", + "Row", + "RowColLayoutAlign", + "RowColboolean", + "RowColnumber", + "RowColumnEncodingFieldDef", + "SCHEMA_URL", + "SCHEMA_VERSION", + "SampleTransform", + "Scale", + "ScaleBinParams", + "ScaleBins", + "ScaleConfig", + "ScaleDatumDef", + "ScaleFieldDef", + "ScaleInterpolateEnum", + "ScaleInterpolateParams", + "ScaleResolveMap", + "ScaleType", + "SchemaBase", + "SchemeParams", + "SecondaryFieldDef", + "SelectionConfig", + "SelectionExpression", + "SelectionInit", + "SelectionInitInterval", + "SelectionInitIntervalMapping", + "SelectionInitMapping", + "SelectionParameter", + "SelectionPredicateComposition", + "SelectionResolution", + "SelectionType", + "SequenceGenerator", + "SequenceParams", + "SequentialMultiHue", + "SequentialSingleHue", + "Shape", + "ShapeDatum", + "ShapeDef", + "ShapeValue", + "SharedEncoding", + "SingleDefUnitChannel", + "SingleTimeUnit", + "Size", + "SizeDatum", + "SizeValue", + "Sort", + "SortArray", + "SortByChannel", + "SortByChannelDesc", + "SortByEncoding", + "SortField", + "SortOrder", + "Spec", + "SphereGenerator", + "StackOffset", + "StackTransform", + "StandardType", + "Step", + "StepFor", + "Stream", + "StringFieldDef", + "StringFieldDefWithCondition", + "StringValueDefWithCondition", + "Stroke", + "StrokeCap", + "StrokeDash", + "StrokeDashDatum", + "StrokeDashValue", + "StrokeDatum", + "StrokeJoin", + "StrokeOpacity", + "StrokeOpacityDatum", + "StrokeOpacityValue", + "StrokeValue", + "StrokeWidth", + "StrokeWidthDatum", + "StrokeWidthValue", + "StyleConfigIndex", + "SymbolShape", + "TOPLEVEL_ONLY_KEYS", + "Text", + "TextBaseline", + "TextDatum", + "TextDef", + "TextDirection", + "TextValue", + "Theta", + "Theta2", + "Theta2Datum", + "Theta2Value", + "ThetaDatum", + "ThetaValue", + "TickConfig", + "TickCount", + "TimeInterval", + "TimeIntervalStep", + "TimeLocale", + "TimeUnit", + "TimeUnitParams", + "TimeUnitTransform", + "TimeUnitTransformParams", + "Title", + "TitleAnchor", + "TitleConfig", + "TitleFrame", + "TitleOrient", + "TitleParams", + "Tooltip", + "TooltipContent", + "TooltipValue", + "TopLevelConcatSpec", + "TopLevelFacetSpec", + "TopLevelHConcatSpec", + "TopLevelLayerSpec", + "TopLevelMixin", + "TopLevelParameter", + "TopLevelRepeatSpec", + "TopLevelSelectionParameter", + "TopLevelSpec", + "TopLevelUnitSpec", + "TopLevelVConcatSpec", + "TopoDataFormat", + "Transform", + "Type", 
+ "TypeForShape", + "TypedFieldDef", + "URI", + "Undefined", + "UndefinedType", + "UnitSpec", + "UnitSpecWithFrame", + "Url", + "UrlData", + "UrlValue", + "UtcMultiTimeUnit", + "UtcSingleTimeUnit", + "VConcatChart", + "VConcatSpecGenericSpec", + "VEGAEMBED_VERSION", + "VEGALITE_VERSION", + "VEGA_VERSION", + "ValueChannelMixin", + "ValueDefWithConditionMarkPropFieldOrDatumDefGradientstringnull", + "ValueDefWithConditionMarkPropFieldOrDatumDefTypeForShapestringnull", + "ValueDefWithConditionMarkPropFieldOrDatumDefnumber", + "ValueDefWithConditionMarkPropFieldOrDatumDefnumberArray", + "ValueDefWithConditionMarkPropFieldOrDatumDefstringnull", + "ValueDefWithConditionStringFieldDefText", + "ValueDefnumber", + "ValueDefnumberwidthheightExprRef", + "VariableParameter", + "Vector10string", + "Vector12string", + "Vector2DateTime", + "Vector2Vector2number", + "Vector2boolean", + "Vector2number", + "Vector2string", + "Vector3number", + "Vector7string", + "VegaLite", + "VegaLiteSchema", + "ViewBackground", + "ViewConfig", + "WindowEventType", + "WindowFieldDef", + "WindowOnlyOp", + "WindowTransform", + "X", + "X2", + "X2Datum", + "X2Value", + "XDatum", + "XError", + "XError2", + "XError2Value", + "XErrorValue", + "XOffset", + "XOffsetDatum", + "XOffsetValue", + "XValue", + "Y", + "Y2", + "Y2Datum", + "Y2Value", + "YDatum", + "YError", + "YError2", + "YError2Value", + "YErrorValue", + "YOffset", + "YOffsetDatum", + "YOffsetValue", + "YValue", + "api", + "binding", + "binding_checkbox", + "binding_radio", + "binding_range", + "binding_select", + "channels", + "check_fields_and_encodings", + "compiler", + "concat", + "condition", + "core", + "curry", + "data", + "data_transformers", + "datum", + "default_data_transformer", + "display", + "expr", + "graticule", + "hconcat", + "jupyter", + "layer", + "limit_rows", + "load_ipython_extension", + "load_schema", + "mixins", + "overload", + "param", + "parse_shorthand", + "pipe", + "renderers", + "repeat", + "sample", + "schema", + "selection_interval", + "selection_point", + "sequence", + "sphere", + "theme", + "themes", + "to_csv", + "to_json", + "to_values", + "topo_feature", + "utils", + "v5", + "value", + "vconcat", + "vegalite", + "vegalite_compilers", + "with_property_setters", +] + + +def __dir__(): + return __all__ + + +from .vegalite import * +from .jupyter import JupyterChart + + +def load_ipython_extension(ipython): + from ._magics import vegalite + + ipython.register_magic_function(vegalite, "cell") diff --git a/llm/Lib/site-packages/altair/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/altair/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcc140d21ab43607361477394b92ca91c139a6af Binary files /dev/null and b/llm/Lib/site-packages/altair/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/__pycache__/_magics.cpython-311.pyc b/llm/Lib/site-packages/altair/__pycache__/_magics.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bd1ac9480ccfcd5bc02c8cd73a10675b4d48487 Binary files /dev/null and b/llm/Lib/site-packages/altair/__pycache__/_magics.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/_magics.py b/llm/Lib/site-packages/altair/_magics.py new file mode 100644 index 0000000000000000000000000000000000000000..bac190aa321b9f8bdd86a0107911b2fe718957f8 --- /dev/null +++ b/llm/Lib/site-packages/altair/_magics.py @@ -0,0 +1,109 @@ +""" +Magic functions for rendering vega-lite specifications +""" + 
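+# Illustrative usage (a sketch, not executed at import time): in an IPython
+# or Jupyter session, ``%load_ext altair`` registers this cell magic via
+# ``altair.load_ipython_extension``; afterwards a cell such as
+#
+#     %%vegalite cars --json
+#     {"mark": "point",
+#      "encoding": {"x": {"field": "Horsepower", "type": "quantitative"}}}
+#
+# renders the spec, injecting the local DataFrame ``cars`` (an assumed,
+# user-defined variable) as the dataset.
+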
+__all__ = ["vegalite"] + +import json +import warnings + +import IPython +from IPython.core import magic_arguments +import pandas as pd +from toolz import curried + +from altair.vegalite import v5 as vegalite_v5 + +try: + import yaml + + YAML_AVAILABLE = True +except ImportError: + YAML_AVAILABLE = False + + +RENDERERS = { + "vega-lite": { + "5": vegalite_v5.VegaLite, + }, +} + + +TRANSFORMERS = { + "vega-lite": { + "5": vegalite_v5.data_transformers, + }, +} + + +def _prepare_data(data, data_transformers): + """Convert input data to data for use within schema""" + if data is None or isinstance(data, dict): + return data + elif isinstance(data, pd.DataFrame): + return curried.pipe(data, data_transformers.get()) + elif isinstance(data, str): + return {"url": data} + else: + warnings.warn("data of type {} not recognized".format(type(data)), stacklevel=1) + return data + + +def _get_variable(name): + """Get a variable from the notebook namespace.""" + ip = IPython.get_ipython() + if ip is None: + raise ValueError( + "Magic command must be run within an IPython " + "environment, in which get_ipython() is defined." + ) + if name not in ip.user_ns: + raise NameError( + "argument '{}' does not match the name of any defined variable".format(name) + ) + return ip.user_ns[name] + + +@magic_arguments.magic_arguments() +@magic_arguments.argument( + "data", + nargs="?", + help="local variablename of a pandas DataFrame to be used as the dataset", +) +@magic_arguments.argument("-v", "--version", dest="version", default="v5") +@magic_arguments.argument("-j", "--json", dest="json", action="store_true") +def vegalite(line, cell): + """Cell magic for displaying vega-lite visualizations in CoLab. + + %%vegalite [dataframe] [--json] [--version='v5'] + + Visualize the contents of the cell using Vega-Lite, optionally + specifying a pandas DataFrame object to be used as the dataset. + + if --json is passed, then input is parsed as json rather than yaml. + """ + args = magic_arguments.parse_argstring(vegalite, line) + existing_versions = {"v5": "5"} + version = existing_versions[args.version] + assert version in RENDERERS["vega-lite"] + VegaLite = RENDERERS["vega-lite"][version] + data_transformers = TRANSFORMERS["vega-lite"][version] + + if args.json: + spec = json.loads(cell) + elif not YAML_AVAILABLE: + try: + spec = json.loads(cell) + except json.JSONDecodeError as err: + raise ValueError( + "%%vegalite: spec is not valid JSON. 
" + "Install pyyaml to parse spec as yaml" + ) from err + else: + spec = yaml.load(cell, Loader=yaml.SafeLoader) + + if args.data is not None: + data = _get_variable(args.data) + spec["data"] = _prepare_data(data, data_transformers) + + return VegaLite(spec) diff --git a/llm/Lib/site-packages/altair/expr/__init__.py b/llm/Lib/site-packages/altair/expr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64916fe2bac46f86dded12948cf2174965df3fc8 --- /dev/null +++ b/llm/Lib/site-packages/altair/expr/__init__.py @@ -0,0 +1,20 @@ +"""Tools for creating transform & filter expressions with a python syntax""" + +# ruff: noqa +from typing import Any + +from .core import datum, Expression +from .funcs import * +from .consts import * +from ..vegalite.v5.schema.core import ExprRef as _ExprRef + + +class _ExprType: + def __init__(self, expr): + vars(self).update(expr) + + def __call__(self, expr, **kwargs): + return _ExprRef(expr, **kwargs) + + +expr: Any = _ExprType(globals()) diff --git a/llm/Lib/site-packages/altair/expr/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/altair/expr/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d15ad478be0c645eff4634273bf6103266ef1c5 Binary files /dev/null and b/llm/Lib/site-packages/altair/expr/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/expr/__pycache__/consts.cpython-311.pyc b/llm/Lib/site-packages/altair/expr/__pycache__/consts.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39b7a45eceeeae65193b1351e9f327d47656f79b Binary files /dev/null and b/llm/Lib/site-packages/altair/expr/__pycache__/consts.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/expr/__pycache__/core.cpython-311.pyc b/llm/Lib/site-packages/altair/expr/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6a7a707c911204471a58993369d0f9df69579d2 Binary files /dev/null and b/llm/Lib/site-packages/altair/expr/__pycache__/core.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/expr/__pycache__/funcs.cpython-311.pyc b/llm/Lib/site-packages/altair/expr/__pycache__/funcs.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..540ba2b266f5c90d3468547e4d7a48b461429b48 Binary files /dev/null and b/llm/Lib/site-packages/altair/expr/__pycache__/funcs.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/expr/consts.py b/llm/Lib/site-packages/altair/expr/consts.py new file mode 100644 index 0000000000000000000000000000000000000000..974fb06a3c756a7e27106f4d1bb9c17b78a094fd --- /dev/null +++ b/llm/Lib/site-packages/altair/expr/consts.py @@ -0,0 +1,29 @@ +from typing import Dict + +from .core import ConstExpression + + +CONST_LISTING = { + "NaN": "not a number (same as JavaScript literal NaN)", + "LN10": "the natural log of 10 (alias to Math.LN10)", + "E": "the transcendental number e (alias to Math.E)", + "LOG10E": "the base 10 logarithm e (alias to Math.LOG10E)", + "LOG2E": "the base 2 logarithm of e (alias to Math.LOG2E)", + "SQRT1_2": "the square root of 0.5 (alias to Math.SQRT1_2)", + "LN2": "the natural log of 2 (alias to Math.LN2)", + "SQRT2": "the square root of 2 (alias to Math.SQRT1_2)", + "PI": "the transcendental number pi (alias to Math.PI)", +} + +NAME_MAP: Dict[str, str] = {} + + +def _populate_namespace(): + globals_ = globals() + for name, doc in CONST_LISTING.items(): + py_name = NAME_MAP.get(name, name) 
+ globals_[py_name] = ConstExpression(name, doc) + yield py_name + + +__all__ = list(_populate_namespace()) diff --git a/llm/Lib/site-packages/altair/expr/core.py b/llm/Lib/site-packages/altair/expr/core.py new file mode 100644 index 0000000000000000000000000000000000000000..9cc258c8b723613453d4033c85035e335a537318 --- /dev/null +++ b/llm/Lib/site-packages/altair/expr/core.py @@ -0,0 +1,234 @@ +from ..utils import SchemaBase + + +class DatumType: + """An object to assist in building Vega-Lite Expressions""" + + def __repr__(self): + return "datum" + + def __getattr__(self, attr): + if attr.startswith("__") and attr.endswith("__"): + raise AttributeError(attr) + return GetAttrExpression("datum", attr) + + def __getitem__(self, attr): + return GetItemExpression("datum", attr) + + def __call__(self, datum, **kwargs): + """Specify a datum for use in an encoding""" + return dict(datum=datum, **kwargs) + + +datum = DatumType() + + +def _js_repr(val): + """Return a javascript-safe string representation of val""" + if val is True: + return "true" + elif val is False: + return "false" + elif val is None: + return "null" + elif isinstance(val, OperatorMixin): + return val._to_expr() + else: + return repr(val) + + +# Designed to work with Expression and VariableParameter +class OperatorMixin: + def _to_expr(self): + return repr(self) + + def _from_expr(self, expr): + return expr + + def __add__(self, other): + comp_value = BinaryExpression("+", self, other) + return self._from_expr(comp_value) + + def __radd__(self, other): + comp_value = BinaryExpression("+", other, self) + return self._from_expr(comp_value) + + def __sub__(self, other): + comp_value = BinaryExpression("-", self, other) + return self._from_expr(comp_value) + + def __rsub__(self, other): + comp_value = BinaryExpression("-", other, self) + return self._from_expr(comp_value) + + def __mul__(self, other): + comp_value = BinaryExpression("*", self, other) + return self._from_expr(comp_value) + + def __rmul__(self, other): + comp_value = BinaryExpression("*", other, self) + return self._from_expr(comp_value) + + def __truediv__(self, other): + comp_value = BinaryExpression("/", self, other) + return self._from_expr(comp_value) + + def __rtruediv__(self, other): + comp_value = BinaryExpression("/", other, self) + return self._from_expr(comp_value) + + __div__ = __truediv__ + + __rdiv__ = __rtruediv__ + + def __mod__(self, other): + comp_value = BinaryExpression("%", self, other) + return self._from_expr(comp_value) + + def __rmod__(self, other): + comp_value = BinaryExpression("%", other, self) + return self._from_expr(comp_value) + + def __pow__(self, other): + # "**" Javascript operator is not supported in all browsers + comp_value = FunctionExpression("pow", (self, other)) + return self._from_expr(comp_value) + + def __rpow__(self, other): + # "**" Javascript operator is not supported in all browsers + comp_value = FunctionExpression("pow", (other, self)) + return self._from_expr(comp_value) + + def __neg__(self): + comp_value = UnaryExpression("-", self) + return self._from_expr(comp_value) + + def __pos__(self): + comp_value = UnaryExpression("+", self) + return self._from_expr(comp_value) + + # comparison operators + + def __eq__(self, other): + comp_value = BinaryExpression("===", self, other) + return self._from_expr(comp_value) + + def __ne__(self, other): + comp_value = BinaryExpression("!==", self, other) + return self._from_expr(comp_value) + + def __gt__(self, other): + comp_value = BinaryExpression(">", self, other) + return 
self._from_expr(comp_value) + + def __lt__(self, other): + comp_value = BinaryExpression("<", self, other) + return self._from_expr(comp_value) + + def __ge__(self, other): + comp_value = BinaryExpression(">=", self, other) + return self._from_expr(comp_value) + + def __le__(self, other): + comp_value = BinaryExpression("<=", self, other) + return self._from_expr(comp_value) + + def __abs__(self): + comp_value = FunctionExpression("abs", (self,)) + return self._from_expr(comp_value) + + # logical operators + + def __and__(self, other): + comp_value = BinaryExpression("&&", self, other) + return self._from_expr(comp_value) + + def __rand__(self, other): + comp_value = BinaryExpression("&&", other, self) + return self._from_expr(comp_value) + + def __or__(self, other): + comp_value = BinaryExpression("||", self, other) + return self._from_expr(comp_value) + + def __ror__(self, other): + comp_value = BinaryExpression("||", other, self) + return self._from_expr(comp_value) + + def __invert__(self): + comp_value = UnaryExpression("!", self) + return self._from_expr(comp_value) + + +class Expression(OperatorMixin, SchemaBase): + """Expression + + Base object for enabling build-up of Javascript expressions using + a Python syntax. Calling ``repr(obj)`` will return a Javascript + representation of the object and the operations it encodes. + """ + + _schema = {"type": "string"} + + def to_dict(self, *args, **kwargs): + return repr(self) + + def __setattr__(self, attr, val): + # We don't need the setattr magic defined in SchemaBase + return object.__setattr__(self, attr, val) + + # item access + def __getitem__(self, val): + return GetItemExpression(self, val) + + +class UnaryExpression(Expression): + def __init__(self, op, val): + super(UnaryExpression, self).__init__(op=op, val=val) + + def __repr__(self): + return "({op}{val})".format(op=self.op, val=_js_repr(self.val)) + + +class BinaryExpression(Expression): + def __init__(self, op, lhs, rhs): + super(BinaryExpression, self).__init__(op=op, lhs=lhs, rhs=rhs) + + def __repr__(self): + return "({lhs} {op} {rhs})".format( + op=self.op, lhs=_js_repr(self.lhs), rhs=_js_repr(self.rhs) + ) + + +class FunctionExpression(Expression): + def __init__(self, name, args): + super(FunctionExpression, self).__init__(name=name, args=args) + + def __repr__(self): + args = ",".join(_js_repr(arg) for arg in self.args) + return "{name}({args})".format(name=self.name, args=args) + + +class ConstExpression(Expression): + def __init__(self, name, doc): + self.__doc__ = """{}: {}""".format(name, doc) + super(ConstExpression, self).__init__(name=name, doc=doc) + + def __repr__(self): + return str(self.name) + + +class GetAttrExpression(Expression): + def __init__(self, group, name): + super(GetAttrExpression, self).__init__(group=group, name=name) + + def __repr__(self): + return "{}.{}".format(self.group, self.name) + + +class GetItemExpression(Expression): + def __init__(self, group, name): + super(GetItemExpression, self).__init__(group=group, name=name) + + def __repr__(self): + return "{}[{!r}]".format(self.group, self.name) diff --git a/llm/Lib/site-packages/altair/expr/funcs.py b/llm/Lib/site-packages/altair/expr/funcs.py new file mode 100644 index 0000000000000000000000000000000000000000..c4a73f4c9d118f9c64163086445eb2448630daea --- /dev/null +++ b/llm/Lib/site-packages/altair/expr/funcs.py @@ -0,0 +1,192 @@ +from .core import FunctionExpression + + +FUNCTION_LISTING = { + "isArray": r"Returns true if _value_ is an array, false otherwise.", + "isBoolean": 
r"Returns true if _value_ is a boolean (`true` or `false`), false otherwise.", + "isDate": r"Returns true if _value_ is a Date object, false otherwise. This method will return false for timestamp numbers or date-formatted strings; it recognizes Date objects only.", + "isDefined": r"Returns true if _value_ is a defined value, false if _value_ equals `undefined`. This method will return true for `null` and `NaN` values.", + "isNumber": r"Returns true if _value_ is a number, false otherwise. `NaN` and `Infinity` are considered numbers.", + "isObject": r"Returns true if _value_ is an object (including arrays and Dates), false otherwise.", + "isRegExp": r"Returns true if _value_ is a RegExp (regular expression) object, false otherwise.", + "isString": r"Returns true if _value_ is a string, false otherwise.", + "isValid": r"Returns true if _value_ is not `null`, `undefined`, or `NaN`, false otherwise.", + "toBoolean": r"Coerces the input _value_ to a string. Null values and empty strings are mapped to `null`.", + "toDate": r"Coerces the input _value_ to a Date instance. Null values and empty strings are mapped to `null`. If an optional _parser_ function is provided, it is used to perform date parsing, otherwise `Date.parse` is used. Be aware that `Date.parse` has different implementations across browsers!", + "toNumber": r"Coerces the input _value_ to a number. Null values and empty strings are mapped to `null`.", + "toString": r"Coerces the input _value_ to a string. Null values and empty strings are mapped to `null`.", + "if": r"If _test_ is truthy, returns _thenValue_. Otherwise, returns _elseValue_. The _if_ function is equivalent to the ternary operator `a ? b : c`.", + "isNaN": r"Returns true if _value_ is not a number. Same as JavaScript's `isNaN`.", + "isFinite": r"Returns true if _value_ is a finite number. Same as JavaScript's `isFinite`.", + "abs": r"Returns the absolute value of _value_. Same as JavaScript's `Math.abs`.", + "acos": r"Trigonometric arccosine. Same as JavaScript's `Math.acos`.", + "asin": r"Trigonometric arcsine. Same as JavaScript's `Math.asin`.", + "atan": r"Trigonometric arctangent. Same as JavaScript's `Math.atan`.", + "atan2": r"Returns the arctangent of _dy / dx_. Same as JavaScript's `Math.atan2`.", + "ceil": r"Rounds _value_ to the nearest integer of equal or greater value. Same as JavaScript's `Math.ceil`.", + "clamp": r"Restricts _value_ to be between the specified _min_ and _max_.", + "cos": r"Trigonometric cosine. Same as JavaScript's `Math.cos`.", + "exp": r"Returns the value of _e_ raised to the provided _exponent_. Same as JavaScript's `Math.exp`.", + "floor": r"Rounds _value_ to the nearest integer of equal or lower value. Same as JavaScript's `Math.floor`.", + "hypot": r"Returns the square root of the sum of squares of its arguments. Same as JavaScript's `Math.hypot`.", + "log": r"Returns the natural logarithm of _value_. Same as JavaScript's `Math.log`.", + "max": r"Returns the maximum argument value. Same as JavaScript's `Math.max`.", + "min": r"Returns the minimum argument value. Same as JavaScript's `Math.min`.", + "pow": r"Returns _value_ raised to the given _exponent_. Same as JavaScript's `Math.pow`.", + "random": r"Returns a pseudo-random number in the range [0,1). Same as JavaScript's `Math.random`.", + "round": r"Rounds _value_ to the nearest integer. Same as JavaScript's `Math.round`.", + "sin": r"Trigonometric sine. Same as JavaScript's `Math.sin`.", + "sqrt": r"Square root function. 
Same as JavaScript's `Math.sqrt`.", + "tan": r"Trigonometric tangent. Same as JavaScript's `Math.tan`.", + "sampleNormal": r"Returns a sample from a univariate [normal (Gaussian) probability distribution](https://en.wikipedia.org/wiki/Normal_distribution) with specified _mean_ and standard deviation _stdev_. If unspecified, the mean defaults to `0` and the standard deviation defaults to `1`.", + "cumulativeNormal": r"Returns the value of the [cumulative distribution function](https://en.wikipedia.org/wiki/Cumulative_distribution_function) at the given input domain _value_ for a normal distribution with specified _mean_ and standard deviation _stdev_. If unspecified, the mean defaults to `0` and the standard deviation defaults to `1`.", + "densityNormal": r"Returns the value of the [probability density function](https://en.wikipedia.org/wiki/Probability_density_function) at the given input domain _value_, for a normal distribution with specified _mean_ and standard deviation _stdev_. If unspecified, the mean defaults to `0` and the standard deviation defaults to `1`.", + "quantileNormal": r"Returns the quantile value (the inverse of the [cumulative distribution function](https://en.wikipedia.org/wiki/Cumulative_distribution_function)) for the given input _probability_, for a normal distribution with specified _mean_ and standard deviation _stdev_. If unspecified, the mean defaults to `0` and the standard deviation defaults to `1`.", + "sampleLogNormal": r"Returns a sample from a univariate [log-normal probability distribution](https://en.wikipedia.org/wiki/Log-normal_distribution) with specified log _mean_ and log standard deviation _stdev_. If unspecified, the log mean defaults to `0` and the log standard deviation defaults to `1`.", + "cumulativeLogNormal": r"Returns the value of the [cumulative distribution function](https://en.wikipedia.org/wiki/Cumulative_distribution_function) at the given input domain _value_ for a log-normal distribution with specified log _mean_ and log standard deviation _stdev_. If unspecified, the log mean defaults to `0` and the log standard deviation defaults to `1`.", + "densityLogNormal": r"Returns the value of the [probability density function](https://en.wikipedia.org/wiki/Probability_density_function) at the given input domain _value_, for a log-normal distribution with specified log _mean_ and log standard deviation _stdev_. If unspecified, the log mean defaults to `0` and the log standard deviation defaults to `1`.", + "quantileLogNormal": r"Returns the quantile value (the inverse of the [cumulative distribution function](https://en.wikipedia.org/wiki/Cumulative_distribution_function)) for the given input _probability_, for a log-normal distribution with specified log _mean_ and log standard deviation _stdev_. If unspecified, the log mean defaults to `0` and the log standard deviation defaults to `1`.", + "sampleUniform": r"Returns a sample from a univariate [continuous uniform probability distribution](https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)) over the interval [_min_, _max_). If unspecified, _min_ defaults to `0` and _max_ defaults to `1`. If only one argument is provided, it is interpreted as the _max_ value.", + "cumulativeUniform": r"Returns the value of the [cumulative distribution function](https://en.wikipedia.org/wiki/Cumulative_distribution_function) at the given input domain _value_ for a uniform distribution over the interval [_min_, _max_). If unspecified, _min_ defaults to `0` and _max_ defaults to `1`. 
If only one argument is provided, it is interpreted as the _max_ value.", + "densityUniform": r"Returns the value of the [probability density function](https://en.wikipedia.org/wiki/Probability_density_function) at the given input domain _value_, for a uniform distribution over the interval [_min_, _max_). If unspecified, _min_ defaults to `0` and _max_ defaults to `1`. If only one argument is provided, it is interpreted as the _max_ value.", + "quantileUniform": r"Returns the quantile value (the inverse of the [cumulative distribution function](https://en.wikipedia.org/wiki/Cumulative_distribution_function)) for the given input _probability_, for a uniform distribution over the interval [_min_, _max_). If unspecified, _min_ defaults to `0` and _max_ defaults to `1`. If only one argument is provided, it is interpreted as the _max_ value.", + "now": r"Returns the timestamp for the current time.", + "datetime": r"Returns a new `Date` instance. The _month_ is 0-based, such that `1` represents February.", + "date": r"Returns the day of the month for the given _datetime_ value, in local time.", + "day": r"Returns the day of the week for the given _datetime_ value, in local time.", + "dayofyear": r"Returns the one-based day of the year for the given _datetime_ value, in local time.", + "year": r"Returns the year for the given _datetime_ value, in local time.", + "quarter": r"Returns the quarter of the year (0-3) for the given _datetime_ value, in local time.", + "month": r"Returns the (zero-based) month for the given _datetime_ value, in local time.", + "week": r"Returns the week number of the year for the given _datetime_, in local time. This function assumes Sunday-based weeks. Days before the first Sunday of the year are considered to be in week 0, the first Sunday of the year is the start of week 1, the second Sunday week 2, _etc._.", + "hours": r"Returns the hours component for the given _datetime_ value, in local time.", + "minutes": r"Returns the minutes component for the given _datetime_ value, in local time.", + "seconds": r"Returns the seconds component for the given _datetime_ value, in local time.", + "milliseconds": r"Returns the milliseconds component for the given _datetime_ value, in local time.", + "time": r"Returns the epoch-based timestamp for the given _datetime_ value.", + "timezoneoffset": r"Returns the timezone offset from the local timezone to UTC for the given _datetime_ value.", + "timeOffset": r"Returns a new `Date` instance that offsets the given _date_ by the specified time [_unit_](../api/time/#time-units) in the local timezone. The optional _step_ argument indicates the number of time unit steps to offset by (default 1).", + "timeSequence": r"Returns an array of `Date` instances from _start_ (inclusive) to _stop_ (exclusive), with each entry separated by the given time [_unit_](../api/time/#time-units) in the local timezone. The optional _step_ argument indicates the number of time unit steps to take between each sequence entry (default 1).", + "utc": r"Returns a timestamp for the given UTC date. 
The _month_ is 0-based, such that `1` represents February.", + "utcdate": r"Returns the day of the month for the given _datetime_ value, in UTC time.", + "utcday": r"Returns the day of the week for the given _datetime_ value, in UTC time.", + "utcdayofyear": r"Returns the one-based day of the year for the given _datetime_ value, in UTC time.", + "utcyear": r"Returns the year for the given _datetime_ value, in UTC time.", + "utcquarter": r"Returns the quarter of the year (0-3) for the given _datetime_ value, in UTC time.", + "utcmonth": r"Returns the (zero-based) month for the given _datetime_ value, in UTC time.", + "utcweek": r"Returns the week number of the year for the given _datetime_, in UTC time. This function assumes Sunday-based weeks. Days before the first Sunday of the year are considered to be in week 0, the first Sunday of the year is the start of week 1, the second Sunday week 2, _etc._.", + "utchours": r"Returns the hours component for the given _datetime_ value, in UTC time.", + "utcminutes": r"Returns the minutes component for the given _datetime_ value, in UTC time.", + "utcseconds": r"Returns the seconds component for the given _datetime_ value, in UTC time.", + "utcmilliseconds": r"Returns the milliseconds component for the given _datetime_ value, in UTC time.", + "utcOffset": r"Returns a new `Date` instance that offsets the given _date_ by the specified time [_unit_](../api/time/#time-units) in UTC time. The optional _step_ argument indicates the number of time unit steps to offset by (default 1).", + "utcSequence": r"Returns an array of `Date` instances from _start_ (inclusive) to _stop_ (exclusive), with each entry separated by the given time [_unit_](../api/time/#time-units) in UTC time. The optional _step_ argument indicates the number of time unit steps to take between each sequence entry (default 1).", + "extent": r"Returns a new _[min, max]_ array with the minimum and maximum values of the input array, ignoring `null`, `undefined`, and `NaN` values.", + "clampRange": r"Clamps a two-element _range_ array in a span-preserving manner. If the span of the input _range_ is less than _(max - min)_ and an endpoint exceeds either the _min_ or _max_ value, the range is translated such that the span is preserved and one endpoint touches the boundary of the _[min, max]_ range. If the span exceeds _(max - min)_, the range _[min, max]_ is returned.", + "indexof": r"Returns the first index of _value_ in the input _array_, or the first index of _substring_ in the input _string_..", + "inrange": r"Tests whether _value_ lies within (or is equal to either) the first and last values of the _range_ array.", + "join": r"Returns a new string by concatenating all of the elements of the input _array_, separated by commas or a specified _separator_ string.", + "lastindexof": r"Returns the last index of _value_ in the input _array_, or the last index of _substring_ in the input _string_..", + "length": r"Returns the length of the input _array_, or the length of the input _string_.", + "lerp": r"Returns the linearly interpolated value between the first and last entries in the _array_ for the provided interpolation _fraction_ (typically between 0 and 1). For example, `lerp([0, 50], 0.5)` returns 25.", + "peek": r"Returns the last element in the input _array_. Similar to the built-in `Array.pop` method, except that it does not remove the last element. 
This method is a convenient shorthand for `array[array.length - 1]`.", + "pluck": r"Retrieves the value for the specified *field* from a given *array* of objects. The input *field* string may include nested properties (e.g., `foo.bar.bz`).", + "reverse": r"Returns a new array with elements in a reverse order of the input _array_. The first array element becomes the last, and the last array element becomes the first.", + "sequence": r"Returns an array containing an arithmetic sequence of numbers. If _step_ is omitted, it defaults to 1. If _start_ is omitted, it defaults to 0. The _stop_ value is exclusive; it is not included in the result. If _step_ is positive, the last element is the largest _start + i * step_ less than _stop_; if _step_ is negative, the last element is the smallest _start + i * step_ greater than _stop_. If the returned array would contain an infinite number of values, an empty range is returned. The arguments are not required to be integers.", + "slice": r"Returns a section of _array_ between the _start_ and _end_ indices. If the _end_ argument is negative, it is treated as an offset from the end of the array (_length(array) + end_).", + "span": r"Returns the span of _array_: the difference between the last and first elements, or _array[array.length-1] - array[0]_. Or if input is a string: a section of _string_ between the _start_ and _end_ indices. If the _end_ argument is negative, it is treated as an offset from the end of the string (_length(string) + end_)..", + "lower": r"Transforms _string_ to lower-case letters.", + "pad": r"Pads a _string_ value with repeated instances of a _character_ up to a specified _length_. If _character_ is not specified, a space (' ') is used. By default, padding is added to the end of a string. An optional _align_ parameter specifies if padding should be added to the `'left'` (beginning), `'center'`, or `'right'` (end) of the input string.", + "parseFloat": r"Parses the input _string_ to a floating-point value. Same as JavaScript's `parseFloat`.", + "parseInt": r"Parses the input _string_ to an integer value. Same as JavaScript's `parseInt`.", + "replace": r"Returns a new string with some or all matches of _pattern_ replaced by a _replacement_ string. The _pattern_ can be a string or a regular expression. If _pattern_ is a string, only the first instance will be replaced. Same as [JavaScript's String.replace](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/String/replace).", + "split": r"Returns an array of tokens created by splitting the input _string_ according to a provided _separator_ pattern. The result can optionally be constrained to return at most _limit_ tokens.", + "substring": r"Returns a section of _string_ between the _start_ and _end_ indices.", + "trim": r"Returns a trimmed string with preceding and trailing whitespace removed.", + "truncate": r"Truncates an input _string_ to a target _length_. The optional _align_ argument indicates what part of the string should be truncated: `'left'` (the beginning), `'center'`, or `'right'` (the end). By default, the `'right'` end of the string is truncated. The optional _ellipsis_ argument indicates the string to use to indicate truncated content; by default the ellipsis character `...` (`\\u2026`) is used.", + "upper": r"Transforms _string_ to upper-case letters.", + "merge": r"Merges the input objects _object1_, _object2_, etc into a new output object. 
Inputs are visited in sequential order, such that key values from later arguments can overwrite those from earlier arguments. Example: `merge({a:1, b:2}, {a:3}) -> {a:3, b:2}`.", + "dayFormat": r"Formats a (0-6) _weekday_ number as a full week day name, according to the current locale. For example: `dayFormat(0) -> \"Sunday\"`.", + "dayAbbrevFormat": r"Formats a (0-6) _weekday_ number as an abbreviated week day name, according to the current locale. For example: `dayAbbrevFormat(0) -> \"Sun\"`.", + "format": r"Formats a numeric _value_ as a string. The _specifier_ must be a valid [d3-format specifier](https://github.com/d3/d3-format/) (e.g., `format(value, ',.2f')`.", + "monthFormat": r"Formats a (zero-based) _month_ number as a full month name, according to the current locale. For example: `monthFormat(0) -> \"January\"`.", + "monthAbbrevFormat": r"Formats a (zero-based) _month_ number as an abbreviated month name, according to the current locale. For example: `monthAbbrevFormat(0) -> \"Jan\"`.", + "timeUnitSpecifier": r"Returns a time format specifier string for the given time [_units_](../api/time/#time-units). The optional _specifiers_ object provides a set of specifier sub-strings for customizing the format; for more, see the [timeUnitSpecifier API documentation](../api/time/#timeUnitSpecifier). The resulting specifier string can then be used as input to the [timeFormat](#timeFormat) or [utcFormat](#utcFormat) functions, or as the _format_ parameter of an axis or legend. For example: `timeFormat(date, timeUnitSpecifier('year'))` or `timeFormat(date, timeUnitSpecifier(['hours', 'minutes']))`.", + "timeFormat": r"Formats a datetime _value_ (either a `Date` object or timestamp) as a string, according to the local time. The _specifier_ must be a valid [d3-time-format specifier](https://github.com/d3/d3-time-format/). For example: `timeFormat(timestamp, '%A')`.", + "timeParse": r"Parses a _string_ value to a Date object, according to the local time. The _specifier_ must be a valid [d3-time-format specifier](https://github.com/d3/d3-time-format/). For example: `timeParse('June 30, 2015', '%B %d, %Y')`.", + "utcFormat": r"Formats a datetime _value_ (either a `Date` object or timestamp) as a string, according to [UTC](https://en.wikipedia.org/wiki/Coordinated_Universal_Time) time. The _specifier_ must be a valid [d3-time-format specifier](https://github.com/d3/d3-time-format/). For example: `utcFormat(timestamp, '%A')`.", + "utcParse": r"Parses a _string_ value to a Date object, according to [UTC](https://en.wikipedia.org/wiki/Coordinated_Universal_Time) time. The _specifier_ must be a valid [d3-time-format specifier](https://github.com/d3/d3-time-format/). For example: `utcParse('June 30, 2015', '%B %d, %Y')`.", + "regexp": r"Creates a regular expression instance from an input _pattern_ string and optional _flags_. Same as [JavaScript's `RegExp`](https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/RegExp).", + "test": r"Evaluates a regular expression _regexp_ against the input _string_, returning `true` if the string matches the pattern, `false` otherwise. For example: `test(/\\d{3}/, \"32-21-9483\") -> true`.", + "rgb": r"Constructs a new [RGB](https://en.wikipedia.org/wiki/RGB_color_model) color. If _r_, _g_ and _b_ are specified, these represent the channel values of the returned color; an _opacity_ may also be specified. If a CSS Color Module Level 3 _specifier_ string is specified, it is parsed and then converted to the RGB color space. 
Uses [d3-color's rgb function](https://github.com/d3/d3-color#rgb).", + "hsl": r"Constructs a new [HSL](https://en.wikipedia.org/wiki/HSL_and_HSV) color. If _h_, _s_ and _l_ are specified, these represent the channel values of the returned color; an _opacity_ may also be specified. If a CSS Color Module Level 3 _specifier_ string is specified, it is parsed and then converted to the HSL color space. Uses [d3-color's hsl function](https://github.com/d3/d3-color#hsl).", + "lab": r"Constructs a new [CIE LAB](https://en.wikipedia.org/wiki/Lab_color_space#CIELAB) color. If _l_, _a_ and _b_ are specified, these represent the channel values of the returned color; an _opacity_ may also be specified. If a CSS Color Module Level 3 _specifier_ string is specified, it is parsed and then converted to the LAB color space. Uses [d3-color's lab function](https://github.com/d3/d3-color#lab).", + "hcl": r"Constructs a new [HCL](https://en.wikipedia.org/wiki/Lab_color_space#CIELAB) (hue, chroma, luminance) color. If _h_, _c_ and _l_ are specified, these represent the channel values of the returned color; an _opacity_ may also be specified. If a CSS Color Module Level 3 _specifier_ string is specified, it is parsed and then converted to the HCL color space. Uses [d3-color's hcl function](https://github.com/d3/d3-color#hcl).", + "luminance": r"Returns the luminance for the given color _specifier_ (compatible with [d3-color's rgb function](https://github.com/d3/d3-color#rgb)). The luminance is calculated according to the [W3C Web Content Accessibility Guidelines](https://www.w3.org/TR/2008/REC-WCAG20-20081211/#relativeluminancedef).", + "contrast": r"Returns the contrast ratio between the input color specifiers as a float between 1 and 21. The contrast is calculated according to the [W3C Web Content Accessibility Guidelines](https://www.w3.org/TR/2008/REC-WCAG20-20081211/#contrast-ratiodef).", + "item": r"Returns the current scenegraph item that is the target of the event.", + "group": r"Returns the scenegraph group mark item in which the current event has occurred. If no arguments are provided, the immediate parent group is returned. If a group name is provided, the matching ancestor group item is returned.", + "xy": r"Returns the x- and y-coordinates for the current event as a two-element array. If no arguments are provided, the top-level coordinate space of the view is used. If a scenegraph _item_ (or string group name) is provided, the coordinate space of the group item is used.", + "x": r"Returns the x coordinate for the current event. If no arguments are provided, the top-level coordinate space of the view is used. If a scenegraph _item_ (or string group name) is provided, the coordinate space of the group item is used.", + "y": r"Returns the y coordinate for the current event. If no arguments are provided, the top-level coordinate space of the view is used. If a scenegraph _item_ (or string group name) is provided, the coordinate space of the group item is used.", + "pinchDistance": r"Returns the pixel distance between the first two touch points of a multi-touch event.", + "pinchAngle": r"Returns the angle of the line connecting the first two touch points of a multi-touch event.", + "inScope": r"Returns true if the given scenegraph _item_ is a descendant of the group mark in which the event handler was defined, false otherwise.", + "data": r"Returns the array of data objects for the Vega data set with the given _name_. 
If the data set is not found, returns an empty array.", + "indata": r"Tests if the data set with a given _name_ contains a datum with a _field_ value that matches the input _value_. For example: `indata('table', 'category', value)`.", + "scale": r"Applies the named scale transform (or projection) to the specified _value_. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the scale or projection.", + "invert": r"Inverts the named scale transform (or projection) for the specified _value_. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the scale or projection.", + "copy": r"Returns a copy (a new cloned instance) of the named scale transform of projection, or `undefined` if no scale or projection is found. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the scale or projection.", + "domain": r"Returns the scale domain array for the named scale transform, or an empty array if the scale is not found. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the scale.", + "range": r"Returns the scale range array for the named scale transform, or an empty array if the scale is not found. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the scale.", + "bandwidth": r"Returns the current band width for the named band scale transform, or zero if the scale is not found or is not a band scale. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the scale.", + "bandspace": r"Returns the number of steps needed within a band scale, based on the _count_ of domain elements and the inner and outer padding values. While normally calculated within the scale itself, this function can be helpful for determining the size of a chart's layout.", + "gradient": r"Returns a linear color gradient for the _scale_ (whose range must be a [continuous color scheme](../schemes)) and starting and ending points _p0_ and _p1_, each an _[x, y]_ array. The points _p0_ and _p1_ should be expressed in normalized coordinates in the domain [0, 1], relative to the bounds of the item being colored. If unspecified, _p0_ defaults to `[0, 0]` and _p1_ defaults to `[1, 0]`, for a horizontal gradient that spans the full bounds of an item. The optional _count_ argument indicates a desired target number of sample points to take from the color scale.", + "panLinear": r"Given a linear scale _domain_ array with numeric or datetime values, returns a new two-element domain array that is the result of panning the domain by a fractional _delta_. The _delta_ value represents fractional units of the scale range; for example, `0.5` indicates panning the scale domain to the right by half the scale range.", + "panLog": r"Given a log scale _domain_ array with numeric or datetime values, returns a new two-element domain array that is the result of panning the domain by a fractional _delta_. The _delta_ value represents fractional units of the scale range; for example, `0.5` indicates panning the scale domain to the right by half the scale range.", + "panPow": r"Given a power scale _domain_ array with numeric or datetime values and the given _exponent_, returns a new two-element domain array that is the result of panning the domain by a fractional _delta_. 
The _delta_ value represents fractional units of the scale range; for example, `0.5` indicates panning the scale domain to the right by half the scale range.", + "panSymlog": r"Given a symmetric log scale _domain_ array with numeric or datetime values parameterized by the given _constant_, returns a new two-element domain array that is the result of panning the domain by a fractional _delta_. The _delta_ value represents fractional units of the scale range; for example, `0.5` indicates panning the scale domain to the right by half the scale range.", + "zoomLinear": r"Given a linear scale _domain_ array with numeric or datetime values, returns a new two-element domain array that is the result of zooming the domain by a _scaleFactor_, centered at the provided fractional _anchor_. The _anchor_ value represents the zoom position in terms of fractional units of the scale range; for example, `0.5` indicates a zoom centered on the mid-point of the scale range.", + "zoomLog": r"Given a log scale _domain_ array with numeric or datetime values, returns a new two-element domain array that is the result of zooming the domain by a _scaleFactor_, centered at the provided fractional _anchor_. The _anchor_ value represents the zoom position in terms of fractional units of the scale range; for example, `0.5` indicates a zoom centered on the mid-point of the scale range.", + "zoomPow": r"Given a power scale _domain_ array with numeric or datetime values and the given _exponent_, returns a new two-element domain array that is the result of zooming the domain by a _scaleFactor_, centered at the provided fractional _anchor_. The _anchor_ value represents the zoom position in terms of fractional units of the scale range; for example, `0.5` indicates a zoom centered on the mid-point of the scale range.", + "zoomSymlog": r"Given a symmetric log scale _domain_ array with numeric or datetime values parameterized by the given _constant_, returns a new two-element domain array that is the result of zooming the domain by a _scaleFactor_, centered at the provided fractional _anchor_. The _anchor_ value represents the zoom position in terms of fractional units of the scale range; for example, `0.5` indicates a zoom centered on the mid-point of the scale range.", + "geoArea": r"Returns the projected planar area (typically in square pixels) of a GeoJSON _feature_ according to the named _projection_. If the _projection_ argument is `null`, computes the spherical area in steradians using unprojected longitude, latitude coordinates. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the projection. Uses d3-geo's [geoArea](https://github.com/d3/d3-geo#geoArea) and [path.area](https://github.com/d3/d3-geo#path_area) methods.", + "geoBounds": r"Returns the projected planar bounding box (typically in pixels) for the specified GeoJSON _feature_, according to the named _projection_. The bounding box is represented by a two-dimensional array: [[_x0_, _y0_], [_x1_, _y1_]], where _x0_ is the minimum x-coordinate, _y0_ is the minimum y-coordinate, _x1_ is the maximum x-coordinate, and _y1_ is the maximum y-coordinate. If the _projection_ argument is `null`, computes the spherical bounding box using unprojected longitude, latitude coordinates. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the projection. 
Uses d3-geo's [geoBounds](https://github.com/d3/d3-geo#geoBounds) and [path.bounds](https://github.com/d3/d3-geo#path_bounds) methods.", + "geoCentroid": r"Returns the projected planar centroid (typically in pixels) for the specified GeoJSON _feature_, according to the named _projection_. If the _projection_ argument is `null`, computes the spherical centroid using unprojected longitude, latitude coordinates. The optional _group_ argument takes a scenegraph group mark item to indicate the specific scope in which to look up the projection. Uses d3-geo's [geoCentroid](https://github.com/d3/d3-geo#geoCentroid) and [path.centroid](https://github.com/d3/d3-geo#path_centroid) methods.", + "treePath": r"For the hierarchy data set with the given _name_, returns the shortest path through from the _source_ node id to the _target_ node id. The path starts at the _source_ node, ascends to the least common ancestor of the _source_ node and the _target_ node, and then descends to the _target_ node.", + "treeAncestors": r"For the hierarchy data set with the given _name_, returns the array of ancestors nodes, starting with the input _node_, then followed by each parent up to the root.", + "containerSize": r"Returns the current CSS box size (`[el.clientWidth, el.clientHeight]`) of the parent DOM element that contains the Vega view. If there is no container element, returns `[undefined, undefined]`.", + "screen": r"Returns the [`window.screen`](https://developer.mozilla.org/en-US/docs/Web/API/Window/screen) object, or `{}` if Vega is not running in a browser environment.", + "windowSize": r"Returns the current window size (`[window.innerWidth, window.innerHeight]`) or `[undefined, undefined]` if Vega is not running in a browser environment.", + "warn": r"Logs a warning message and returns the last argument. For the message to appear in the console, the visualization view must have the appropriate logging level set.", + "info": r"Logs an informative message and returns the last argument. For the message to appear in the console, the visualization view must have the appropriate logging level set.", + "debug": r"Logs a debugging message and returns the last argument. For the message to appear in the console, the visualization view must have the appropriate logging level set.", +} + + +# This maps vega expression function names to the Python name +NAME_MAP = {"if": "if_"} + + +class ExprFunc: + def __init__(self, name, doc): + self.name = name + self.doc = doc + self.__doc__ = """{}(*args)\n {}""".format(name, doc) + + def __call__(self, *args): + return FunctionExpression(self.name, args) + + def __repr__(self): + return "<function expr.{}(*args)>".format(self.name) + + +def _populate_namespace(): + globals_ = globals() + for name, doc in FUNCTION_LISTING.items(): + py_name = NAME_MAP.get(name, name) + globals_[py_name] = ExprFunc(name, doc) + yield py_name + + +__all__ = list(_populate_namespace()) diff --git a/llm/Lib/site-packages/altair/jupyter/__init__.py b/llm/Lib/site-packages/altair/jupyter/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..651ab11e4cf8de15370bbf02efd36315c1d27e82 --- /dev/null +++ b/llm/Lib/site-packages/altair/jupyter/__init__.py @@ -0,0 +1,20 @@ +try: + import anywidget # noqa: F401 +except ImportError: + # When anywidget isn't available, create stand-in JupyterChart class + # that raises an informative import error on construction.
This + # way we can make JupyterChart available in the altair namespace + # when anywidget is not installed + class JupyterChart: + def __init__(self, *args, **kwargs): + raise ImportError( + "The Altair JupyterChart requires the anywidget \n" + "Python package which may be installed using pip with\n" + " pip install anywidget\n" + "or using conda with\n" + " conda install -c conda-forge anywidget\n" + "Afterwards, you will need to restart your Python kernel." + ) + +else: + from .jupyter_chart import JupyterChart # noqa: F401 diff --git a/llm/Lib/site-packages/altair/jupyter/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/altair/jupyter/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03341aa0d6ae96e854a1db8e7729ab692abaa20d Binary files /dev/null and b/llm/Lib/site-packages/altair/jupyter/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/jupyter/__pycache__/jupyter_chart.cpython-311.pyc b/llm/Lib/site-packages/altair/jupyter/__pycache__/jupyter_chart.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a794901c42c8d06e1c26d39f1e9ae1e8c49f037 Binary files /dev/null and b/llm/Lib/site-packages/altair/jupyter/__pycache__/jupyter_chart.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/jupyter/js/README.md b/llm/Lib/site-packages/altair/jupyter/js/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f1ec545894f60fea2a2096b4ac4b588c890b5192 --- /dev/null +++ b/llm/Lib/site-packages/altair/jupyter/js/README.md @@ -0,0 +1,2 @@ +# JupyterChart +This directory contains the JavaScript portion of the Altair `JupyterChart`. The `JupyterChart` is based on the [AnyWidget](https://anywidget.dev/) project. diff --git a/llm/Lib/site-packages/altair/jupyter/js/index.js b/llm/Lib/site-packages/altair/jupyter/js/index.js new file mode 100644 index 0000000000000000000000000000000000000000..ba080010563c38e910f238b2b6bb31d54a41b4e0 --- /dev/null +++ b/llm/Lib/site-packages/altair/jupyter/js/index.js @@ -0,0 +1,230 @@ +import vegaEmbed from "https://esm.sh/vega-embed@6?deps=vega@5&deps=vega-lite@5.17.0"; +import lodashDebounce from "https://esm.sh/lodash-es@4.17.21/debounce"; + +// Note: For offline support, the import lines above are removed and the remaining script +// is bundled using vl-convert's javascript_bundle function. See the documentation of +// the javascript_bundle function for details on the available imports and their names. +// If an additional import is required in the future, it will need to be added to vl-convert +// in order to preserve offline support. +async function render({ model, el }) { + let finalize; + + function showError(error){ + el.innerHTML = ( + '
<div style="color:red;">' + + '<p>JavaScript Error: ' + error.message + '</p>' + + "<p>This usually means there's a typo in your chart specification. " + + "See the javascript console for the full traceback.</p>" + + '</div>
    ' + ); + } + + const reembed = async () => { + if (finalize != null) { + finalize(); + } + + model.set("local_tz", Intl.DateTimeFormat().resolvedOptions().timeZone); + + let spec = structuredClone(model.get("spec")); + if (spec == null) { + // Remove any existing chart and return + while (el.firstChild) { + el.removeChild(el.lastChild); + } + model.save_changes(); + return; + } + let embedOptions = structuredClone(model.get("embed_options")) ?? undefined; + + let api; + try { + api = await vegaEmbed(el, spec, embedOptions); + } catch (error) { + showError(error) + return; + } + + finalize = api.finalize; + + // Debounce config + const wait = model.get("debounce_wait") ?? 10; + const debounceOpts = {leading: false, trailing: true}; + if (model.get("max_wait") ?? true) { + debounceOpts["maxWait"] = wait; + } + + const initialSelections = {}; + for (const selectionName of Object.keys(model.get("_vl_selections"))) { + const storeName = `${selectionName}_store`; + const selectionHandler = (_, value) => { + const newSelections = cleanJson(model.get("_vl_selections") ?? {}); + const store = cleanJson(api.view.data(storeName) ?? []); + + newSelections[selectionName] = {value, store}; + model.set("_vl_selections", newSelections); + model.save_changes(); + }; + api.view.addSignalListener(selectionName, lodashDebounce(selectionHandler, wait, debounceOpts)); + + initialSelections[selectionName] = { + value: cleanJson(api.view.signal(selectionName) ?? {}), + store: cleanJson(api.view.data(storeName) ?? []) + } + } + model.set("_vl_selections", initialSelections); + + const initialParams = {}; + for (const paramName of Object.keys(model.get("_params"))) { + const paramHandler = (_, value) => { + const newParams = JSON.parse(JSON.stringify(model.get("_params"))) || {}; + newParams[paramName] = value; + model.set("_params", newParams); + model.save_changes(); + }; + api.view.addSignalListener(paramName, lodashDebounce(paramHandler, wait, debounceOpts)); + + initialParams[paramName] = api.view.signal(paramName) ?? null + } + model.set("_params", initialParams); + model.save_changes(); + + // Param change callback + model.on('change:_params', async (new_params) => { + for (const [param, value] of Object.entries(new_params.changed._params)) { + api.view.signal(param, value); + } + await api.view.runAsync(); + }); + + // Add signal/data listeners + for (const watch of model.get("_js_watch_plan") ?? []) { + if (watch.namespace === "data") { + const dataHandler = (_, value) => { + model.set("_js_to_py_updates", [{ + namespace: "data", + name: watch.name, + scope: watch.scope, + value: cleanJson(value) + }]); + model.save_changes(); + }; + addDataListener(api.view, watch.name, watch.scope, lodashDebounce(dataHandler, wait, debounceOpts)) + + } else if (watch.namespace === "signal") { + const signalHandler = (_, value) => { + model.set("_js_to_py_updates", [{ + namespace: "signal", + name: watch.name, + scope: watch.scope, + value: cleanJson(value) + }]); + model.save_changes(); + }; + + addSignalListener(api.view, watch.name, watch.scope, lodashDebounce(signalHandler, wait, debounceOpts)) + } + } + + // Add signal/data updaters + model.on('change:_py_to_js_updates', async (updates) => { + for (const update of updates.changed._py_to_js_updates ?? 
[]) { + if (update.namespace === "signal") { + setSignalValue(api.view, update.name, update.scope, update.value); + } else if (update.namespace === "data") { + setDataValue(api.view, update.name, update.scope, update.value); + } + } + await api.view.runAsync(); + }); + } + + model.on('change:spec', reembed); + model.on('change:embed_options', reembed); + model.on('change:debounce_wait', reembed); + model.on('change:max_wait', reembed); + await reembed(); +} + +function cleanJson(data) { + return JSON.parse(JSON.stringify(data)) +} + +function getNestedRuntime(view, scope) { + var runtime = view._runtime; + for (const index of scope) { + runtime = runtime.subcontext[index]; + } + return runtime +} + +function lookupSignalOp(view, name, scope) { + let parent_runtime = getNestedRuntime(view, scope); + return parent_runtime.signals[name] ?? null; +} + +function dataRef(view, name, scope) { + let parent_runtime = getNestedRuntime(view, scope); + return parent_runtime.data[name]; +} + +export function setSignalValue(view, name, scope, value) { + let signal_op = lookupSignalOp(view, name, scope); + view.update(signal_op, value); +} + +export function setDataValue(view, name, scope, value) { + let dataset = dataRef(view, name, scope); + let changeset = view.changeset().remove(() => true).insert(value) + dataset.modified = true; + view.pulse(dataset.input, changeset); +} + +export function addSignalListener(view, name, scope, handler) { + let signal_op = lookupSignalOp(view, name, scope); + return addOperatorListener( + view, + name, + signal_op, + handler, + ); +} + +export function addDataListener(view, name, scope, handler) { + let dataset = dataRef(view, name, scope).values; + return addOperatorListener( + view, + name, + dataset, + handler, + ); +} + +// Private helpers from Vega for dealing with nested signals/data +function findOperatorHandler(op, handler) { + const h = (op._targets || []) + .filter(op => op._update && op._update.handler === handler); + return h.length ? h[0] : null; +} + +function addOperatorListener(view, name, op, handler) { + let h = findOperatorHandler(op, handler); + if (!h) { + h = trap(view, () => handler(name, op.value)); + h.handler = handler; + view.on(op, null, h); + } + return view; +} + +function trap(view, fn) { + return !fn ? 
null : function() { + try { + fn.apply(this, arguments); + } catch (error) { + view.error(error); + } + }; +} + +export default { render } diff --git a/llm/Lib/site-packages/altair/jupyter/jupyter_chart.py b/llm/Lib/site-packages/altair/jupyter/jupyter_chart.py new file mode 100644 index 0000000000000000000000000000000000000000..0331c9820cd9943e8ced1abaec1d0e99573b595a --- /dev/null +++ b/llm/Lib/site-packages/altair/jupyter/jupyter_chart.py @@ -0,0 +1,413 @@ +import json +import anywidget +import traitlets +import pathlib +from typing import Any, Set, Optional + +import altair as alt +from altair.utils._vegafusion_data import ( + using_vegafusion, + compile_to_vegafusion_chart_state, +) +from altair import TopLevelSpec +from altair.utils.selection import IndexSelection, PointSelection, IntervalSelection + +_here = pathlib.Path(__file__).parent + + +class Params(traitlets.HasTraits): + """ + Traitlet class storing a JupyterChart's params + """ + + def __init__(self, trait_values): + super().__init__() + + for key, value in trait_values.items(): + if isinstance(value, (int, float)): + traitlet_type = traitlets.Float() + elif isinstance(value, str): + traitlet_type = traitlets.Unicode() + elif isinstance(value, list): + traitlet_type = traitlets.List() + elif isinstance(value, dict): + traitlet_type = traitlets.Dict() + else: + traitlet_type = traitlets.Any() + + # Add the new trait. + self.add_traits(**{key: traitlet_type}) + + # Set the trait's value. + setattr(self, key, value) + + def __repr__(self): + return f"Params({self.trait_values()})" + + +class Selections(traitlets.HasTraits): + """ + Traitlet class storing a JupyterChart's selections + """ + + def __init__(self, trait_values): + super().__init__() + + for key, value in trait_values.items(): + if isinstance(value, IndexSelection): + traitlet_type = traitlets.Instance(IndexSelection) + elif isinstance(value, PointSelection): + traitlet_type = traitlets.Instance(PointSelection) + elif isinstance(value, IntervalSelection): + traitlet_type = traitlets.Instance(IntervalSelection) + else: + raise ValueError(f"Unexpected selection type: {type(value)}") + + # Add the new trait. + self.add_traits(**{key: traitlet_type}) + + # Set the trait's value. 
+ setattr(self, key, value) + + # Make read-only + self.observe(self._make_read_only, names=key) + + def __repr__(self): + return f"Selections({self.trait_values()})" + + def _make_read_only(self, change): + """ + Work around to make traits read-only, but still allow us to change + them internally + """ + if change["name"] in self.traits() and change["old"] != change["new"]: + self._set_value(change["name"], change["old"]) + raise ValueError( + "Selections may not be set from Python.\n" + f"Attempted to set select: {change['name']}" + ) + + def _set_value(self, key, value): + self.unobserve(self._make_read_only, names=key) + setattr(self, key, value) + self.observe(self._make_read_only, names=key) + + +def load_js_src() -> str: + return (_here / "js" / "index.js").read_text() + + +class JupyterChart(anywidget.AnyWidget): + _esm = load_js_src() + _css = r""" + .vega-embed { + /* Make sure action menu isn't cut off */ + overflow: visible; + } + """ + + # Public traitlets + chart = traitlets.Instance(TopLevelSpec, allow_none=True) + spec = traitlets.Dict(allow_none=True).tag(sync=True) + debounce_wait = traitlets.Float(default_value=10).tag(sync=True) + max_wait = traitlets.Bool(default_value=True).tag(sync=True) + local_tz = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True) + debug = traitlets.Bool(default_value=False) + embed_options = traitlets.Dict(default_value=None, allow_none=True).tag(sync=True) + + # Internal selection traitlets + _selection_types = traitlets.Dict() + _vl_selections = traitlets.Dict().tag(sync=True) + + # Internal param traitlets + _params = traitlets.Dict().tag(sync=True) + + # Internal comm traitlets for VegaFusion support + _chart_state = traitlets.Any(allow_none=True) + _js_watch_plan = traitlets.Any(allow_none=True).tag(sync=True) + _js_to_py_updates = traitlets.Any(allow_none=True).tag(sync=True) + _py_to_js_updates = traitlets.Any(allow_none=True).tag(sync=True) + + # Track whether charts are configured for offline use + _is_offline = False + + @classmethod + def enable_offline(cls, offline: bool = True): + """ + Configure JupyterChart's offline behavior + + Parameters + ---------- + offline: bool + If True, configure JupyterChart to operate in offline mode where JavaScript + dependencies are loaded from vl-convert. + If False, configure it to operate in online mode where JavaScript dependencies + are loaded from CDN dynamically. This is the default behavior. + """ + from altair.utils._importers import import_vl_convert, vl_version_for_vl_convert + + if offline: + if cls._is_offline: + # Already offline + return + + vlc = import_vl_convert() + + src_lines = load_js_src().split("\n") + + # Remove leading lines with only whitespace, comments, or imports + while src_lines and ( + len(src_lines[0].strip()) == 0 + or src_lines[0].startswith("import") + or src_lines[0].startswith("//") + ): + src_lines.pop(0) + + src = "\n".join(src_lines) + + # vl-convert's javascript_bundle function creates a self-contained JavaScript bundle + # for JavaScript snippets that import from a small set of dependencies that + # vl-convert includes. 
To see the available imports and their imported names, run + # import vl_convert as vlc + # help(vlc.javascript_bundle) + bundled_src = vlc.javascript_bundle( + src, vl_version=vl_version_for_vl_convert() + ) + cls._esm = bundled_src + cls._is_offline = True + else: + cls._esm = load_js_src() + cls._is_offline = False + + def __init__( + self, + chart: TopLevelSpec, + debounce_wait: int = 10, + max_wait: bool = True, + debug: bool = False, + embed_options: Optional[dict] = None, + **kwargs: Any, + ): + """ + Jupyter Widget for displaying and updating Altair Charts, and + retrieving selection and parameter values + + Parameters + ---------- + chart: Chart + Altair Chart instance + debounce_wait: int + Debouncing wait time in milliseconds. Updates will be sent from the client to the kernel + after debounce_wait milliseconds of no chart interactions. + max_wait: bool + If True (default), updates will be sent from the client to the kernel every debounce_wait + milliseconds even if there are ongoing chart interactions. If False, updates will not be + sent until chart interactions have completed. + debug: bool + If True, debug messages will be printed + embed_options: dict + Options to pass to vega-embed. + See https://github.com/vega/vega-embed?tab=readme-ov-file#options + """ + self.params = Params({}) + self.selections = Selections({}) + super().__init__( + chart=chart, + debounce_wait=debounce_wait, + max_wait=max_wait, + debug=debug, + embed_options=embed_options, + **kwargs, + ) + + @traitlets.observe("chart") + def _on_change_chart(self, change): + """ + Internal callback function that updates the JupyterChart's internal + state when the wrapped Chart instance changes + """ + new_chart = change.new + selection_watches = [] + selection_types = {} + initial_params = {} + initial_vl_selections = {} + empty_selections = {} + + if new_chart is None: + with self.hold_sync(): + self.spec = None + self._selection_types = selection_types + self._vl_selections = initial_vl_selections + self._params = initial_params + return + + params = getattr(new_chart, "params", []) + + if params is not alt.Undefined: + for param in new_chart.params: + if isinstance(param.name, alt.ParameterName): + clean_name = param.name.to_json().strip('"') + else: + clean_name = param.name + + select = getattr(param, "select", alt.Undefined) + + if select != alt.Undefined: + if not isinstance(select, dict): + select = select.to_dict() + + select_type = select["type"] + if select_type == "point": + if not ( + select.get("fields", None) or select.get("encodings", None) + ): + # Point selection with no associated fields or encodings specified. 
+ # This is an index-based selection + selection_types[clean_name] = "index" + empty_selections[clean_name] = IndexSelection( + name=clean_name, value=[], store=[] + ) + else: + selection_types[clean_name] = "point" + empty_selections[clean_name] = PointSelection( + name=clean_name, value=[], store=[] + ) + elif select_type == "interval": + selection_types[clean_name] = "interval" + empty_selections[clean_name] = IntervalSelection( + name=clean_name, value={}, store=[] + ) + else: + raise ValueError(f"Unexpected selection type {select.type}") + selection_watches.append(clean_name) + initial_vl_selections[clean_name] = {"value": None, "store": []} + else: + clean_value = param.value if param.value != alt.Undefined else None + initial_params[clean_name] = clean_value + + # Handle the params generated by transforms + for param_name in collect_transform_params(new_chart): + initial_params[param_name] = None + + # Setup params + self.params = Params(initial_params) + + def on_param_traitlet_changed(param_change): + new_params = dict(self._params) + new_params[param_change["name"]] = param_change["new"] + self._params = new_params + + self.params.observe(on_param_traitlet_changed) + + # Setup selections + self.selections = Selections(empty_selections) + + # Update properties all together + with self.hold_sync(): + if using_vegafusion(): + if self.local_tz is None: + self.spec = None + + def on_local_tz_change(change): + self._init_with_vegafusion(change["new"]) + + self.observe(on_local_tz_change, ["local_tz"]) + else: + self._init_with_vegafusion(self.local_tz) + else: + self.spec = new_chart.to_dict() + self._selection_types = selection_types + self._vl_selections = initial_vl_selections + self._params = initial_params + + def _init_with_vegafusion(self, local_tz: str): + if self.chart is not None: + vegalite_spec = self.chart.to_dict(context={"pre_transform": False}) + with self.hold_sync(): + self._chart_state = compile_to_vegafusion_chart_state( + vegalite_spec, local_tz + ) + self._js_watch_plan = self._chart_state.get_watch_plan()[ + "client_to_server" + ] + self.spec = self._chart_state.get_transformed_spec() + + # Callback to update chart state and send updates back to client + def on_js_to_py_updates(change): + if self.debug: + updates_str = json.dumps(change["new"], indent=2) + print( + f"JavaScript to Python VegaFusion updates:\n {updates_str}" + ) + updates = self._chart_state.update(change["new"]) + if self.debug: + updates_str = json.dumps(updates, indent=2) + print( + f"Python to JavaScript VegaFusion updates:\n {updates_str}" + ) + self._py_to_js_updates = updates + + self.observe(on_js_to_py_updates, ["_js_to_py_updates"]) + + @traitlets.observe("_params") + def _on_change_params(self, change): + for param_name, value in change.new.items(): + setattr(self.params, param_name, value) + + @traitlets.observe("_vl_selections") + def _on_change_selections(self, change): + """ + Internal callback function that updates the JupyterChart's public + selections traitlet in response to changes that the JavaScript logic + makes to the internal _selections traitlet. 
+ """ + for selection_name, selection_dict in change.new.items(): + value = selection_dict["value"] + store = selection_dict["store"] + selection_type = self._selection_types[selection_name] + if selection_type == "index": + self.selections._set_value( + selection_name, + IndexSelection.from_vega(selection_name, signal=value, store=store), + ) + elif selection_type == "point": + self.selections._set_value( + selection_name, + PointSelection.from_vega(selection_name, signal=value, store=store), + ) + elif selection_type == "interval": + self.selections._set_value( + selection_name, + IntervalSelection.from_vega( + selection_name, signal=value, store=store + ), + ) + + +def collect_transform_params(chart: TopLevelSpec) -> Set[str]: + """ + Collect the names of params that are defined by transforms + + Parameters + ---------- + chart: Chart from which to extract transform params + + Returns + ------- + set of param names + """ + transform_params = set() + + # Handle recursive case + for prop in ("layer", "concat", "hconcat", "vconcat"): + for child in getattr(chart, prop, []): + transform_params.update(collect_transform_params(child)) + + # Handle chart's own transforms + transforms = getattr(chart, "transform", []) + transforms = transforms if transforms != alt.Undefined else [] + for tx in transforms: + if hasattr(tx, "param"): + transform_params.add(tx.param) + + return transform_params diff --git a/llm/Lib/site-packages/altair/py.typed b/llm/Lib/site-packages/altair/py.typed new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/llm/Lib/site-packages/altair/utils/__init__.py b/llm/Lib/site-packages/altair/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dba1e1f81c174e445f5e396bf9d0fa421618fe89 --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/__init__.py @@ -0,0 +1,32 @@ +from .core import ( + infer_vegalite_type, + infer_encoding_types, + sanitize_dataframe, + sanitize_arrow_table, + parse_shorthand, + use_signature, + update_nested, + display_traceback, + SchemaBase, +) +from .html import spec_to_html +from .plugin_registry import PluginRegistry +from .deprecation import AltairDeprecationWarning +from .schemapi import Undefined + + +__all__ = ( + "infer_vegalite_type", + "infer_encoding_types", + "sanitize_dataframe", + "sanitize_arrow_table", + "spec_to_html", + "parse_shorthand", + "use_signature", + "update_nested", + "display_traceback", + "AltairDeprecationWarning", + "SchemaBase", + "Undefined", + "PluginRegistry", +) diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/__init__.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7c3bea7e2b10710a37b0734f731d9a46fc7add0 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/_dfi_types.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/_dfi_types.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea86f83ce963f2207971908537763218e1ce8664 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/_dfi_types.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/_importers.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/_importers.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9a478c85a62fc9b54e8d335951c3c4979829eb7e Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/_importers.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/_show.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/_show.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a58667c240bb6068bed5bcc0a6162aa87359af1 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/_show.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/_transformed_data.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/_transformed_data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f0b3988fba837e115556a128f1e03a23810b03f Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/_transformed_data.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/_vegafusion_data.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/_vegafusion_data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65b2f39a571f01302ed6e17b2ee2c737fd5235cb Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/_vegafusion_data.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/compiler.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/compiler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2b88fa78c835e0828af21776cba7f255120e2e9 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/compiler.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/core.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d611e94ceaa8e1b4f81e7aa4f213e20dde96c60d Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/core.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/data.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe1b1e7a851cf93bc12339aff6f34730ad23c9fb Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/data.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/deprecation.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/deprecation.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3b05af4e789131a87bc612f7fb07072cf9eed3f Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/deprecation.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/display.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/display.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7392df5c3af130e1272c4175c5d068be1f8ba2f Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/display.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/execeval.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/execeval.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..991f4cad2adfb909bf95abde02e8f5b8fb8e7acc Binary files /dev/null and 
b/llm/Lib/site-packages/altair/utils/__pycache__/execeval.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/html.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/html.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fc6b12bb4353b2c2d1e9df604aab1a6698fd308 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/html.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/mimebundle.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/mimebundle.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45f7bf10d19435daf7c8b4423e59a92c845ae640 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/mimebundle.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/plugin_registry.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/plugin_registry.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d4c4697c0dd6cb322eb3962f9bc6eeaef0d868d Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/plugin_registry.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/save.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/save.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17c89669c9a2d163743d7f575c9f0eb0d7b5ac08 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/save.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/schemapi.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/schemapi.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..431e15a7607384097af5a20d375c16d80d1a9aad Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/schemapi.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/selection.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/selection.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14a51c516bb74f4b668e2cc79a15ddec77207d02 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/selection.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/server.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/server.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d5d78890483d5e5c8ae5750efe263f007dc9eb6 Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/server.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/__pycache__/theme.cpython-311.pyc b/llm/Lib/site-packages/altair/utils/__pycache__/theme.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b7c97c2b174950a218c820c7392aedaa7677829e Binary files /dev/null and b/llm/Lib/site-packages/altair/utils/__pycache__/theme.cpython-311.pyc differ diff --git a/llm/Lib/site-packages/altair/utils/_dfi_types.py b/llm/Lib/site-packages/altair/utils/_dfi_types.py new file mode 100644 index 0000000000000000000000000000000000000000..a76435e7fdf9f771fd69d7bba37dc3c4a0f93fd1 --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/_dfi_types.py @@ -0,0 +1,165 @@ +# DataFrame Interchange Protocol Types +# Copied from https://data-apis.org/dataframe-protocol/latest/API.html, +# changed ABCs to Protocols, and subset 
the type hints to only those that are +# relevant for Altair. +# +# These classes are only for use in type signatures +import enum +from typing import Any, Iterable, Optional, Tuple, Protocol + + +class DtypeKind(enum.IntEnum): + """ + Integer enum for data types. + + Attributes + ---------- + INT : int + Matches to signed integer data type. + UINT : int + Matches to unsigned integer data type. + FLOAT : int + Matches to floating point data type. + BOOL : int + Matches to boolean data type. + STRING : int + Matches to string data type (UTF-8 encoded). + DATETIME : int + Matches to datetime data type. + CATEGORICAL : int + Matches to categorical data type. + """ + + INT = 0 + UINT = 1 + FLOAT = 2 + BOOL = 20 + STRING = 21 # UTF-8 + DATETIME = 22 + CATEGORICAL = 23 + + +# Type hint of first element would actually be DtypeKind but can't use that +# as other libraries won't use an instance of our own Enum in this module but have +# their own. Type checkers will raise an error on that even though the enums +# are identical. +Dtype = Tuple[Any, int, str, str] # see Column.dtype + + +class Column(Protocol): + @property + def dtype(self) -> Dtype: + """ + Dtype description as a tuple ``(kind, bit-width, format string, endianness)``. + + Bit-width : the number of bits as an integer + Format string : data type description format string in Apache Arrow C + Data Interface format. + Endianness : current only native endianness (``=``) is supported + + Notes: + - Kind specifiers are aligned with DLPack where possible (hence the + jump to 20, leave enough room for future extension) + - Masks must be specified as boolean with either bit width 1 (for bit + masks) or 8 (for byte masks). + - Dtype width in bits was preferred over bytes + - Endianness isn't too useful, but included now in case in the future + we need to support non-native endianness + - Went with Apache Arrow format strings over NumPy format strings + because they're more complete from a dataframe perspective + - Format strings are mostly useful for datetime specification, and + for categoricals. + - For categoricals, the format string describes the type of the + categorical in the data buffer. In case of a separate encoding of + the categorical (e.g. an integer to string mapping), this can + be derived from ``self.describe_categorical``. + - Data types not included: complex, Arrow-style null, binary, decimal, + and nested (list, struct, map, union) dtypes. + """ + pass + + # Have to use a generic Any return type as not all libraries who implement + # the dataframe interchange protocol implement the TypedDict that is usually + # returned here in the same way. As TypedDicts are invariant, even a slight change + # will lead to an error by a type checker. See PR in which this code was added + # for details. + @property + def describe_categorical(self) -> Any: + """ + If the dtype is categorical, there are two options: + - There are only values in the data buffer. + - There is a separate non-categorical Column encoding categorical values. + + Raises TypeError if the dtype is not categorical + + Returns the dictionary with description on how to interpret the data buffer: + - "is_ordered" : bool, whether the ordering of dictionary indices is + semantically meaningful. + - "is_dictionary" : bool, whether a mapping of + categorical values to other objects exists + - "categories" : Column representing the (implicit) mapping of indices to + category values (e.g. an array of cat1, cat2, ...). + None if not a dictionary-style categorical. 
+ + TBD: are there any other in-memory representations that are needed? + """ + pass + + +class DataFrame(Protocol): + """ + A data frame class, with only the methods required by the interchange + protocol defined. + + A "data frame" represents an ordered collection of named columns. + A column's "name" must be a unique string. + Columns may be accessed by name or by position. + + This could be a public data frame class, or an object with the methods and + attributes defined on this DataFrame class could be returned from the + ``__dataframe__`` method of a public data frame class in a library adhering + to the dataframe interchange protocol specification. + """ + + def __dataframe__( + self, nan_as_null: bool = False, allow_copy: bool = True + ) -> "DataFrame": + """ + Construct a new exchange object, potentially changing the parameters. + + ``nan_as_null`` is a keyword intended for the consumer to tell the + producer to overwrite null values in the data with ``NaN``. + It is intended for cases where the consumer does not support the bit + mask or byte mask that is the producer's native representation. + ``allow_copy`` is a keyword that defines whether or not the library is + allowed to make a copy of the data. For example, copying data would be + necessary if a library supports strided buffers, given that this protocol + specifies contiguous buffers. + """ + pass + + def column_names(self) -> Iterable[str]: + """ + Return an iterator yielding the column names. + """ + pass + + def get_column_by_name(self, name: str) -> Column: + """ + Return the column whose name is the indicated name. + """ + pass + + def get_chunks(self, n_chunks: Optional[int] = None) -> Iterable["DataFrame"]: + """ + Return an iterator yielding the chunks. + + By default (None), yields the chunks that the data is stored as by the + producer. If given, ``n_chunks`` must be a multiple of + ``self.num_chunks()``, meaning the producer must subdivide each chunk + before yielding it. + + Note that the producer must ensure that all columns are chunked the + same way. 
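These Protocol classes only describe the shape of interchange objects. As a rough sketch of consuming them, using pandas (which has exposed __dataframe__ since 1.5) purely as an example producer:

import pandas as pd

df = pd.DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
xchg = df.__dataframe__()                     # an object matching the DataFrame protocol above
print(list(xchg.column_names()))              # ['x', 'y']
kind, bit_width, fmt, endianness = xchg.get_column_by_name("x").dtype
print(kind, bit_width)                        # integer kind (DtypeKind.INT == 0), 64 bits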
+ """ + pass diff --git a/llm/Lib/site-packages/altair/utils/_importers.py b/llm/Lib/site-packages/altair/utils/_importers.py new file mode 100644 index 0000000000000000000000000000000000000000..b7fa8a9584a7b4f31bcae19369b7b1fd6d23872b --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/_importers.py @@ -0,0 +1,97 @@ +from types import ModuleType +from packaging.version import Version +from importlib.metadata import version as importlib_version + + +def import_vegafusion() -> ModuleType: + min_version = "1.5.0" + try: + version = importlib_version("vegafusion") + embed_version = importlib_version("vegafusion-python-embed") + if version != embed_version or Version(version) < Version(min_version): + raise RuntimeError( + "The versions of the vegafusion and vegafusion-python-embed packages must match\n" + f"and must be version {min_version} or greater.\n" + f"Found:\n" + f" - vegafusion=={version}\n" + f" - vegafusion-python-embed=={embed_version}\n" + ) + import vegafusion as vf # type: ignore + + return vf + except ImportError as err: + raise ImportError( + 'The "vegafusion" data transformer and chart.transformed_data feature requires\n' + f"version {min_version} or greater of the 'vegafusion-python-embed' and 'vegafusion' packages.\n" + "These can be installed with pip using:\n" + f' pip install "vegafusion[embed]>={min_version}"\n' + "Or with conda using:\n" + f' conda install -c conda-forge "vegafusion-python-embed>={min_version}" ' + f'"vegafusion>={min_version}"\n\n' + f"ImportError: {err.args[0]}" + ) from err + + +def import_vl_convert() -> ModuleType: + min_version = "1.3.0" + try: + version = importlib_version("vl-convert-python") + if Version(version) < Version(min_version): + raise RuntimeError( + f"The vl-convert-python package must be version {min_version} or greater. " + f"Found version {version}" + ) + import vl_convert as vlc + + return vlc + except ImportError as err: + raise ImportError( + f"The vl-convert Vega-Lite compiler and file export feature requires\n" + f"version {min_version} or greater of the 'vl-convert-python' package. \n" + f"This can be installed with pip using:\n" + f' pip install "vl-convert-python>={min_version}"\n' + "or conda:\n" + f' conda install -c conda-forge "vl-convert-python>={min_version}"\n\n' + f"ImportError: {err.args[0]}" + ) from err + + +def vl_version_for_vl_convert() -> str: + from ..vegalite import SCHEMA_VERSION + + # Compute VlConvert's vl_version string (of the form 'v5_2') + # from SCHEMA_VERSION (of the form 'v5.2.0') + return "_".join(SCHEMA_VERSION.split(".")[:2]) + + +def import_pyarrow_interchange() -> ModuleType: + min_version = "11.0.0" + try: + version = importlib_version("pyarrow") + + if Version(version) < Version(min_version): + raise RuntimeError( + f"The pyarrow package must be version {min_version} or greater. " + f"Found version {version}" + ) + import pyarrow.interchange as pi + + return pi + except ImportError as err: + raise ImportError( + f"Usage of the DataFrame Interchange Protocol requires\n" + f"version {min_version} or greater of the pyarrow package. 
\n" + f"This can be installed with pip using:\n" + f' pip install "pyarrow>={min_version}"\n' + "or conda:\n" + f' conda install -c conda-forge "pyarrow>={min_version}"\n\n' + f"ImportError: {err.args[0]}" + ) from err + + +def pyarrow_available() -> bool: + try: + import_pyarrow_interchange() + return True + except (ImportError, RuntimeError): + return False diff --git a/llm/Lib/site-packages/altair/utils/_show.py b/llm/Lib/site-packages/altair/utils/_show.py new file mode 100644 index 0000000000000000000000000000000000000000..0030570acb9e66be9a6f2e45f7e212b0020c8b56 --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/_show.py @@ -0,0 +1,73 @@ +import webbrowser +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Union, Iterable, Optional + + +def open_html_in_browser( + html: Union[str, bytes], + using: Union[str, Iterable[str], None] = None, + port: Optional[int] = None, +): + """ + Display an html document in a web browser without creating a temp file. + + Instantiates a simple http server and uses the webbrowser module to + open the server's URL + + Parameters + ---------- + html: str + HTML string to display + using: str or iterable of str + Name of the web browser to open (e.g. "chrome", "firefox", etc.). + If an iterable, choose the first browser available on the system. + If none, choose the system default browser. + port: int + Port to use. Defaults to a random port + """ + # Encode html to bytes + if isinstance(html, str): + html_bytes = html.encode("utf8") + else: + html_bytes = html + + browser = None + + if using is None: + browser = webbrowser.get(None) + else: + # normalize using to an iterable + if isinstance(using, str): + using = [using] + + for browser_key in using: + try: + browser = webbrowser.get(browser_key) + if browser is not None: + break + except webbrowser.Error: + pass + + if browser is None: + raise ValueError("Failed to locate a browser with name in " + str(using)) + + class OneShotRequestHandler(BaseHTTPRequestHandler): + def do_GET(self): + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + + bufferSize = 1024 * 1024 + for i in range(0, len(html_bytes), bufferSize): + self.wfile.write(html_bytes[i : i + bufferSize]) + + def log_message(self, format, *args): + # Silence stderr logging + pass + + # Use specified port if provided, otherwise choose a random port (port value of 0) + server = HTTPServer( + ("127.0.0.1", port if port is not None else 0), OneShotRequestHandler + ) + browser.open("http://127.0.0.1:%s" % server.server_port) + server.handle_request() diff --git a/llm/Lib/site-packages/altair/utils/_transformed_data.py b/llm/Lib/site-packages/altair/utils/_transformed_data.py new file mode 100644 index 0000000000000000000000000000000000000000..1d886eb5e9cf96818dff92a6115429b40683d457 --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/_transformed_data.py @@ -0,0 +1,557 @@ +from typing import List, Optional, Tuple, Dict, Iterable, overload, Union + +from altair import ( + Chart, + FacetChart, + LayerChart, + HConcatChart, + VConcatChart, + ConcatChart, + TopLevelUnitSpec, + FacetedUnitSpec, + UnitSpec, + UnitSpecWithFrame, + NonNormalizedSpec, + TopLevelLayerSpec, + LayerSpec, + TopLevelConcatSpec, + ConcatSpecGenericSpec, + TopLevelHConcatSpec, + HConcatSpecGenericSpec, + TopLevelVConcatSpec, + VConcatSpecGenericSpec, + TopLevelFacetSpec, + FacetSpec, + data_transformers, +) +from altair.utils._vegafusion_data import get_inline_tables, import_vegafusion +from 
altair.utils.core import DataFrameLike +from altair.utils.schemapi import Undefined + +Scope = Tuple[int, ...] +FacetMapping = Dict[Tuple[str, Scope], Tuple[str, Scope]] + + +# For the transformed_data functionality, the chart classes in the values +# can be considered equivalent to the chart class in the key. +_chart_class_mapping = { + Chart: ( + Chart, + TopLevelUnitSpec, + FacetedUnitSpec, + UnitSpec, + UnitSpecWithFrame, + NonNormalizedSpec, + ), + LayerChart: (LayerChart, TopLevelLayerSpec, LayerSpec), + ConcatChart: (ConcatChart, TopLevelConcatSpec, ConcatSpecGenericSpec), + HConcatChart: (HConcatChart, TopLevelHConcatSpec, HConcatSpecGenericSpec), + VConcatChart: (VConcatChart, TopLevelVConcatSpec, VConcatSpecGenericSpec), + FacetChart: (FacetChart, TopLevelFacetSpec, FacetSpec), +} + + +@overload +def transformed_data( + chart: Union[Chart, FacetChart], + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, +) -> Optional[DataFrameLike]: ... + + +@overload +def transformed_data( + chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart], + row_limit: Optional[int] = None, + exclude: Optional[Iterable[str]] = None, +) -> List[DataFrameLike]: ... + + +def transformed_data(chart, row_limit=None, exclude=None): + """Evaluate a Chart's transforms + + Evaluate the data transforms associated with a Chart and return the + transformed data as one or more DataFrames + + Parameters + ---------- + chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart + Altair chart to evaluate transforms on + row_limit : int (optional) + Maximum number of rows to return for each DataFrame. None (default) for unlimited + exclude : iterable of str + Set of the names of charts to exclude + + Returns + ------- + DataFrame or list of DataFrames or None + If input chart is a Chart or Facet Chart, returns a DataFrame of the + transformed data. 
Otherwise, returns a list of DataFrames of the + transformed data + """ + vf = import_vegafusion() + + if isinstance(chart, Chart): + # Add mark if none is specified to satisfy Vega-Lite + if chart.mark == Undefined: + chart = chart.mark_point() + + # Deep copy chart so that we can rename marks without affecting caller + chart = chart.copy(deep=True) + + # Ensure that all views are named so that we can look them up in the + # resulting Vega specification + chart_names = name_views(chart, 0, exclude=exclude) + + # Compile to Vega and extract inline DataFrames + with data_transformers.enable("vegafusion"): + vega_spec = chart.to_dict(format="vega", context={"pre_transform": False}) + inline_datasets = get_inline_tables(vega_spec) + + # Build mapping from mark names to vega datasets + facet_mapping = get_facet_mapping(vega_spec) + dataset_mapping = get_datasets_for_view_names(vega_spec, chart_names, facet_mapping) + + # Build a list of vega dataset names that corresponds to the order + # of the chart components + dataset_names = [] + for chart_name in chart_names: + if chart_name in dataset_mapping: + dataset_names.append(dataset_mapping[chart_name]) + else: + raise ValueError("Failed to locate all datasets") + + # Extract transformed datasets with VegaFusion + datasets, warnings = vf.runtime.pre_transform_datasets( + vega_spec, + dataset_names, + row_limit=row_limit, + inline_datasets=inline_datasets, + ) + + if isinstance(chart, (Chart, FacetChart)): + # Return DataFrame (or None if it was excluded) if input was a simple Chart + if not datasets: + return None + else: + return datasets[0] + else: + # Otherwise return the list of DataFrames + return datasets + + +# The equivalent classes from _chart_class_mapping should also be added +# to the type hints below for `chart` as the function would also work for them. +# However, this was not possible so far as mypy then complains about +# "Overloaded function signatures 1 and 2 overlap with incompatible return types [misc]" +# This might be due to the complex type hierarchy of the chart classes. +# See also https://github.com/python/mypy/issues/5119 +# and https://github.com/python/mypy/issues/4020 which show that mypy might not have +# a very consistent behavior for overloaded functions. +# The same error appeared when trying it with Protocols for the concat and layer charts. +# This function is only used internally and so we accept this inconsistency for now. +def name_views( + chart: Union[ + Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, ConcatChart + ], + i: int = 0, + exclude: Optional[Iterable[str]] = None, +) -> List[str]: + """Name unnamed chart views + + Name unnamed charts views so that we can look them up later in + the compiled Vega spec. + + Note: This function mutates the input chart by applying names to + unnamed views. 
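For context, a minimal sketch of using the transformed_data entry point defined above. The sample data is illustrative, and the optional vegafusion packages described in _importers must be installed:

import altair as alt
import pandas as pd
from altair.utils._transformed_data import transformed_data

source = pd.DataFrame({"category": ["a", "a", "b"], "value": [1, 2, 3]})
chart = alt.Chart(source).mark_bar().encode(x="category", y="sum(value)")

df = transformed_data(chart)   # DataFrame with one row per category and the summed values
# For compound charts (layer/concat), the same call returns a list of DataFrames instead.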
+ + Parameters + ---------- + chart : Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart + Altair chart to apply names to + i : int (default 0) + Starting chart index + exclude : iterable of str + Names of charts to exclude + + Returns + ------- + list of str + List of the names of the charts and subcharts + """ + exclude = set(exclude) if exclude is not None else set() + if isinstance(chart, _chart_class_mapping[Chart]) or isinstance( + chart, _chart_class_mapping[FacetChart] + ): + if chart.name not in exclude: + if chart.name in (None, Undefined): + # Add name since none is specified + chart.name = Chart._get_name() + return [chart.name] + else: + return [] + else: + if isinstance(chart, _chart_class_mapping[LayerChart]): + subcharts = chart.layer + elif isinstance(chart, _chart_class_mapping[HConcatChart]): + subcharts = chart.hconcat + elif isinstance(chart, _chart_class_mapping[VConcatChart]): + subcharts = chart.vconcat + elif isinstance(chart, _chart_class_mapping[ConcatChart]): + subcharts = chart.concat + else: + raise ValueError( + "transformed_data accepts an instance of " + "Chart, FacetChart, LayerChart, HConcatChart, VConcatChart, or ConcatChart\n" + f"Received value of type: {type(chart)}" + ) + + chart_names: List[str] = [] + for subchart in subcharts: + for name in name_views(subchart, i=i + len(chart_names), exclude=exclude): + chart_names.append(name) + return chart_names + + +def get_group_mark_for_scope(vega_spec: dict, scope: Scope) -> Optional[dict]: + """Get the group mark at a particular scope + + Parameters + ---------- + vega_spec : dict + Top-level Vega specification dictionary + scope : tuple of int + Scope tuple. If empty, the original Vega specification is returned. + Otherwise, the nested group mark at the scope specified is returned. + + Returns + ------- + dict or None + Top-level Vega spec (if scope is empty) + or group mark (if scope is non-empty) + or None (if group mark at scope does not exist) + + Examples + -------- + >>> spec = { + ... "marks": [ + ... { + ... "type": "group", + ... "marks": [{"type": "symbol"}] + ... }, + ... { + ... "type": "group", + ... "marks": [{"type": "rect"}]} + ... ] + ... } + >>> get_group_mark_for_scope(spec, (1,)) + {'type': 'group', 'marks': [{'type': 'rect'}]} + """ + group = vega_spec + + # Find group at scope + for scope_value in scope: + group_index = 0 + child_group = None + for mark in group.get("marks", []): + if mark.get("type") == "group": + if group_index == scope_value: + child_group = mark + break + group_index += 1 + if child_group is None: + return None + group = child_group + + return group + + +def get_datasets_for_scope(vega_spec: dict, scope: Scope) -> List[str]: + """Get the names of the datasets that are defined at a given scope + + Parameters + ---------- + vega_spec : dict + Top-leve Vega specification + scope : tuple of int + Scope tuple. If empty, the names of top-level datasets are returned + Otherwise, the names of the datasets defined in the nested group mark + at the specified scope are returned. + + Returns + ------- + list of str + List of the names of the datasets defined at the specified scope + + Examples + -------- + >>> spec = { + ... "data": [ + ... {"name": "data1"} + ... ], + ... "marks": [ + ... { + ... "type": "group", + ... "data": [ + ... {"name": "data2"} + ... ], + ... "marks": [{"type": "symbol"}] + ... }, + ... { + ... "type": "group", + ... "data": [ + ... {"name": "data3"}, + ... {"name": "data4"}, + ... ], + ... "marks": [{"type": "rect"}] + ... 
} + ... ] + ... } + + >>> get_datasets_for_scope(spec, ()) + ['data1'] + + >>> get_datasets_for_scope(spec, (0,)) + ['data2'] + + >>> get_datasets_for_scope(spec, (1,)) + ['data3', 'data4'] + + Returns empty when no group mark exists at scope + >>> get_datasets_for_scope(spec, (1, 3)) + [] + """ + group = get_group_mark_for_scope(vega_spec, scope) or {} + + # get datasets from group + datasets = [] + for dataset in group.get("data", []): + datasets.append(dataset["name"]) + + # Add facet dataset + facet_dataset = group.get("from", {}).get("facet", {}).get("name", None) + if facet_dataset: + datasets.append(facet_dataset) + return datasets + + +def get_definition_scope_for_data_reference( + vega_spec: dict, data_name: str, usage_scope: Scope +) -> Optional[Scope]: + """Return the scope that a dataset is defined at, for a given usage scope + + Parameters + ---------- + vega_spec: dict + Top-level Vega specification + data_name: str + The name of a dataset reference + usage_scope: tuple of int + The scope that the dataset is referenced in + + Returns + ------- + tuple of int + The scope where the referenced dataset is defined, + or None if no such dataset is found + + Examples + -------- + >>> spec = { + ... "data": [ + ... {"name": "data1"} + ... ], + ... "marks": [ + ... { + ... "type": "group", + ... "data": [ + ... {"name": "data2"} + ... ], + ... "marks": [{ + ... "type": "symbol", + ... "encode": { + ... "update": { + ... "x": {"field": "x", "data": "data1"}, + ... "y": {"field": "y", "data": "data2"}, + ... } + ... } + ... }] + ... } + ... ] + ... } + + data1 is referenced at scope [0] and defined at scope [] + >>> get_definition_scope_for_data_reference(spec, "data1", (0,)) + () + + data2 is referenced at scope [0] and defined at scope [0] + >>> get_definition_scope_for_data_reference(spec, "data2", (0,)) + (0,) + + If data2 is not visible at scope [] (the top level), + because it's defined in scope [0] + >>> repr(get_definition_scope_for_data_reference(spec, "data2", ())) + 'None' + """ + for i in reversed(range(len(usage_scope) + 1)): + scope = usage_scope[:i] + datasets = get_datasets_for_scope(vega_spec, scope) + if data_name in datasets: + return scope + return None + + +def get_facet_mapping(group: dict, scope: Scope = ()) -> FacetMapping: + """Create mapping from facet definitions to source datasets + + Parameters + ---------- + group : dict + Top-level Vega spec or nested group mark + scope : tuple of int + Scope of the group dictionary within a top-level Vega spec + + Returns + ------- + dict + Dictionary from (facet_name, facet_scope) to (dataset_name, dataset_scope) + + Examples + -------- + >>> spec = { + ... "data": [ + ... {"name": "data1"} + ... ], + ... "marks": [ + ... { + ... "type": "group", + ... "from": { + ... "facet": { + ... "name": "facet1", + ... "data": "data1", + ... "groupby": ["colA"] + ... } + ... } + ... } + ... ] + ... 
} + >>> get_facet_mapping(spec) + {('facet1', (0,)): ('data1', ())} + """ + facet_mapping = {} + group_index = 0 + mark_group = get_group_mark_for_scope(group, scope) or {} + for mark in mark_group.get("marks", []): + if mark.get("type", None) == "group": + # Get facet for this group + group_scope = scope + (group_index,) + facet = mark.get("from", {}).get("facet", None) + if facet is not None: + facet_name = facet.get("name", None) + facet_data = facet.get("data", None) + if facet_name is not None and facet_data is not None: + definition_scope = get_definition_scope_for_data_reference( + group, facet_data, scope + ) + if definition_scope is not None: + facet_mapping[(facet_name, group_scope)] = ( + facet_data, + definition_scope, + ) + + # Handle children recursively + child_mapping = get_facet_mapping(group, scope=group_scope) + facet_mapping.update(child_mapping) + group_index += 1 + + return facet_mapping + + +def get_from_facet_mapping( + scoped_dataset: Tuple[str, Scope], facet_mapping: FacetMapping +) -> Tuple[str, Scope]: + """Apply facet mapping to a scoped dataset + + Parameters + ---------- + scoped_dataset : (str, tuple of int) + A dataset name and scope tuple + facet_mapping : dict from (str, tuple of int) to (str, tuple of int) + The facet mapping produced by get_facet_mapping + + Returns + ------- + (str, tuple of int) + Dataset name and scope tuple that has been mapped as many times as possible + + Examples + -------- + Facet mapping as produced by get_facet_mapping + >>> facet_mapping = {("facet1", (0,)): ("data1", ()), ("facet2", (0, 1)): ("facet1", (0,))} + >>> get_from_facet_mapping(("facet2", (0, 1)), facet_mapping) + ('data1', ()) + """ + while scoped_dataset in facet_mapping: + scoped_dataset = facet_mapping[scoped_dataset] + return scoped_dataset + + +def get_datasets_for_view_names( + group: dict, + vl_chart_names: List[str], + facet_mapping: FacetMapping, + scope: Scope = (), +) -> Dict[str, Tuple[str, Scope]]: + """Get the Vega datasets that correspond to the provided Altair view names + + Parameters + ---------- + group : dict + Top-level Vega spec or nested group mark + vl_chart_names : list of str + List of the Vega-Lite + facet_mapping : dict from (str, tuple of int) to (str, tuple of int) + The facet mapping produced by get_facet_mapping + scope : tuple of int + Scope of the group dictionary within a top-level Vega spec + + Returns + ------- + dict from str to (str, tuple of int) + Dict from Altair view names to scoped datasets + """ + datasets = {} + group_index = 0 + mark_group = get_group_mark_for_scope(group, scope) or {} + for mark in mark_group.get("marks", []): + for vl_chart_name in vl_chart_names: + if mark.get("name", "") == f"{vl_chart_name}_cell": + data_name = mark.get("from", {}).get("facet", None).get("data", None) + scoped_data_name = (data_name, scope) + datasets[vl_chart_name] = get_from_facet_mapping( + scoped_data_name, facet_mapping + ) + break + + name = mark.get("name", "") + if mark.get("type", "") == "group": + group_data_names = get_datasets_for_view_names( + group, vl_chart_names, facet_mapping, scope=scope + (group_index,) + ) + for k, v in group_data_names.items(): + datasets.setdefault(k, v) + group_index += 1 + else: + for vl_chart_name in vl_chart_names: + if name.startswith(vl_chart_name) and name.endswith("_marks"): + data_name = mark.get("from", {}).get("data", None) + scoped_data = get_definition_scope_for_data_reference( + group, data_name, scope + ) + if scoped_data is not None: + datasets[vl_chart_name] = 
get_from_facet_mapping( + (data_name, scoped_data), facet_mapping + ) + break + + return datasets diff --git a/llm/Lib/site-packages/altair/utils/_vegafusion_data.py b/llm/Lib/site-packages/altair/utils/_vegafusion_data.py new file mode 100644 index 0000000000000000000000000000000000000000..ce30e8d6d9bab4ebb72d2dcdb3e51d664d30c1bc --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/_vegafusion_data.py @@ -0,0 +1,256 @@ +from toolz import curried +import uuid +from weakref import WeakValueDictionary + +from typing import ( + Union, + Dict, + Set, + MutableMapping, + TypedDict, + Final, + TYPE_CHECKING, +) + +from altair.utils._importers import import_vegafusion +from altair.utils.core import DataFrameLike +from altair.utils.data import DataType, ToValuesReturnType, MaxRowsError +from altair.vegalite.data import default_data_transformer + +if TYPE_CHECKING: + from vegafusion.runtime import ChartState # type: ignore + +# Temporary storage for dataframes that have been extracted +# from charts by the vegafusion data transformer. Use a WeakValueDictionary +# rather than a dict so that the Python interpreter is free to garbage +# collect the stored DataFrames. +extracted_inline_tables: MutableMapping[str, DataFrameLike] = WeakValueDictionary() + +# Special URL prefix that VegaFusion uses to denote that a +# dataset in a Vega spec corresponds to an entry in the `inline_datasets` +# kwarg of vf.runtime.pre_transform_spec(). +VEGAFUSION_PREFIX: Final = "vegafusion+dataset://" + + +class _ToVegaFusionReturnUrlDict(TypedDict): + url: str + + +@curried.curry +def vegafusion_data_transformer( + data: DataType, max_rows: int = 100000 +) -> Union[_ToVegaFusionReturnUrlDict, ToValuesReturnType]: + """VegaFusion Data Transformer""" + if hasattr(data, "__geo_interface__"): + # Use default transformer for geo interface objects + # # (e.g. a geopandas GeoDataFrame) + return default_data_transformer(data) + elif isinstance(data, DataFrameLike): + table_name = f"table_{uuid.uuid4()}".replace("-", "_") + extracted_inline_tables[table_name] = data + return {"url": VEGAFUSION_PREFIX + table_name} + else: + # Use default transformer if we don't recognize data type + return default_data_transformer(data) + + +def get_inline_table_names(vega_spec: dict) -> Set[str]: + """Get a set of the inline datasets names in the provided Vega spec + + Inline datasets are encoded as URLs that start with the table:// + prefix. + + Parameters + ---------- + vega_spec: dict + A Vega specification dict + + Returns + ------- + set of str + Set of the names of the inline datasets that are referenced + in the specification. + + Examples + -------- + >>> spec = { + ... "data": [ + ... { + ... "name": "foo", + ... "url": "https://path/to/file.csv" + ... }, + ... { + ... "name": "bar", + ... "url": "vegafusion+dataset://inline_dataset_123" + ... } + ... ] + ... 
} + >>> get_inline_table_names(spec) + {'inline_dataset_123'} + """ + table_names = set() + + # Process datasets + for data in vega_spec.get("data", []): + url = data.get("url", "") + if url.startswith(VEGAFUSION_PREFIX): + name = url[len(VEGAFUSION_PREFIX) :] + table_names.add(name) + + # Recursively process child marks, which may have their own datasets + for mark in vega_spec.get("marks", []): + table_names.update(get_inline_table_names(mark)) + + return table_names + + +def get_inline_tables(vega_spec: dict) -> Dict[str, DataFrameLike]: + """Get the inline tables referenced by a Vega specification + + Note: This function should only be called on a Vega spec that corresponds + to a chart that was processed by the vegafusion_data_transformer. + Furthermore, this function may only be called once per spec because + the returned dataframes are deleted from internal storage. + + Parameters + ---------- + vega_spec: dict + A Vega specification dict + + Returns + ------- + dict from str to dataframe + dict from inline dataset name to dataframe object + """ + table_names = get_inline_table_names(vega_spec) + tables = {} + for table_name in table_names: + try: + tables[table_name] = extracted_inline_tables.pop(table_name) + except KeyError: + # named dataset that was provided by the user + pass + return tables + + +def compile_to_vegafusion_chart_state( + vegalite_spec: dict, local_tz: str +) -> "ChartState": + """Compile a Vega-Lite spec to a VegaFusion ChartState + + Note: This function should only be called on a Vega-Lite spec + that was generated with the "vegafusion" data transformer enabled. + In particular, this spec may contain references to extract datasets + using table:// prefixed URLs. + + Parameters + ---------- + vegalite_spec: dict + A Vega-Lite spec that was generated from an Altair chart with + the "vegafusion" data transformer enabled + local_tz: str + Local timezone name (e.g. 'America/New_York') + + Returns + ------- + ChartState + A VegaFusion ChartState object + """ + # Local import to avoid circular ImportError + from altair import vegalite_compilers, data_transformers + + vf = import_vegafusion() + + # Compile Vega-Lite spec to Vega + compiler = vegalite_compilers.get() + if compiler is None: + raise ValueError("No active vega-lite compiler plugin found") + + vega_spec = compiler(vegalite_spec) + + # Retrieve dict of inline tables referenced by the spec + inline_tables = get_inline_tables(vega_spec) + + # Pre-evaluate transforms in vega spec with vegafusion + row_limit = data_transformers.options.get("max_rows", None) + + chart_state = vf.runtime.new_chart_state( + vega_spec, + local_tz=local_tz, + inline_datasets=inline_tables, + row_limit=row_limit, + ) + + # Check from row limit warning and convert to MaxRowsError + handle_row_limit_exceeded(row_limit, chart_state.get_warnings()) + + return chart_state + + +def compile_with_vegafusion(vegalite_spec: dict) -> dict: + """Compile a Vega-Lite spec to Vega and pre-transform with VegaFusion + + Note: This function should only be called on a Vega-Lite spec + that was generated with the "vegafusion" data transformer enabled. + In particular, this spec may contain references to extract datasets + using table:// prefixed URLs. 
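As a hedged end-to-end sketch of the intended flow (sample data is illustrative; the optional vegafusion and vegafusion-python-embed packages are assumed to be installed):

import altair as alt
import pandas as pd
from altair.utils._vegafusion_data import compile_with_vegafusion

source = pd.DataFrame({"x": list(range(5)), "y": [1, 3, 2, 5, 4]})
chart = alt.Chart(source).mark_line().encode(x="x", y="y")

with alt.data_transformers.enable("vegafusion"):
    # The transformer replaces the inline DataFrame with a vegafusion+dataset:// URL,
    # which compile_with_vegafusion resolves again via get_inline_tables().
    vegalite_spec = chart.to_dict(context={"pre_transform": False})
    vega_spec = compile_with_vegafusion(vegalite_spec)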
+ + Parameters + ---------- + vegalite_spec: dict + A Vega-Lite spec that was generated from an Altair chart with + the "vegafusion" data transformer enabled + + Returns + ------- + dict + A Vega spec that has been pre-transformed by VegaFusion + """ + # Local import to avoid circular ImportError + from altair import vegalite_compilers, data_transformers + + vf = import_vegafusion() + + # Compile Vega-Lite spec to Vega + compiler = vegalite_compilers.get() + if compiler is None: + raise ValueError("No active vega-lite compiler plugin found") + + vega_spec = compiler(vegalite_spec) + + # Retrieve dict of inline tables referenced by the spec + inline_tables = get_inline_tables(vega_spec) + + # Pre-evaluate transforms in vega spec with vegafusion + row_limit = data_transformers.options.get("max_rows", None) + transformed_vega_spec, warnings = vf.runtime.pre_transform_spec( + vega_spec, + vf.get_local_tz(), + inline_datasets=inline_tables, + row_limit=row_limit, + ) + + # Check from row limit warning and convert to MaxRowsError + handle_row_limit_exceeded(row_limit, warnings) + + return transformed_vega_spec + + +def handle_row_limit_exceeded(row_limit: int, warnings: list): + for warning in warnings: + if warning.get("type") == "RowLimitExceeded": + raise MaxRowsError( + "The number of dataset rows after filtering and aggregation exceeds\n" + f"the current limit of {row_limit}. Try adding an aggregation to reduce\n" + "the size of the dataset that must be loaded into the browser. Or, disable\n" + "the limit by calling alt.data_transformers.disable_max_rows(). Note that\n" + "disabling this limit may cause the browser to freeze or crash." + ) + + +def using_vegafusion() -> bool: + """Check whether the vegafusion data transformer is enabled""" + # Local import to avoid circular ImportError + from altair import data_transformers + + return data_transformers.active == "vegafusion" diff --git a/llm/Lib/site-packages/altair/utils/compiler.py b/llm/Lib/site-packages/altair/utils/compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..0944c92fddb4a620d11287a1df1583464b9b1c92 --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/compiler.py @@ -0,0 +1,11 @@ +from typing import Callable +from altair.utils import PluginRegistry + +# ============================================================================== +# Vega-Lite to Vega compiler registry +# ============================================================================== +VegaLiteCompilerType = Callable[[dict], dict] + + +class VegaLiteCompilerRegistry(PluginRegistry[VegaLiteCompilerType]): + pass diff --git a/llm/Lib/site-packages/altair/utils/core.py b/llm/Lib/site-packages/altair/utils/core.py new file mode 100644 index 0000000000000000000000000000000000000000..baf1013f7d774c6caa79a9b77534bfcb2c577d79 --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/core.py @@ -0,0 +1,852 @@ +""" +Utility routines +""" + +from collections.abc import Mapping, MutableMapping +from copy import deepcopy +import json +import itertools +import re +import sys +import traceback +import warnings +from typing import ( + Callable, + TypeVar, + Any, + Union, + Dict, + Optional, + Tuple, + Sequence, + Type, + cast, +) +from types import ModuleType + +import jsonschema +import pandas as pd +import numpy as np +from pandas.api.types import infer_dtype + +from altair.utils.schemapi import SchemaBase +from altair.utils._dfi_types import Column, DtypeKind, DataFrame as DfiDataFrame + +if sys.version_info >= (3, 10): + from typing import 
ParamSpec +else: + from typing_extensions import ParamSpec + +from typing import Literal, Protocol, TYPE_CHECKING, runtime_checkable + +if TYPE_CHECKING: + from pandas.core.interchange.dataframe_protocol import Column as PandasColumn + +V = TypeVar("V") +P = ParamSpec("P") + + +@runtime_checkable +class DataFrameLike(Protocol): + def __dataframe__( + self, nan_as_null: bool = False, allow_copy: bool = True + ) -> DfiDataFrame: ... + + +TYPECODE_MAP = { + "ordinal": "O", + "nominal": "N", + "quantitative": "Q", + "temporal": "T", + "geojson": "G", +} + +INV_TYPECODE_MAP = {v: k for k, v in TYPECODE_MAP.items()} + + +# aggregates from vega-lite version 4.6.0 +AGGREGATES = [ + "argmax", + "argmin", + "average", + "count", + "distinct", + "max", + "mean", + "median", + "min", + "missing", + "product", + "q1", + "q3", + "ci0", + "ci1", + "stderr", + "stdev", + "stdevp", + "sum", + "valid", + "values", + "variance", + "variancep", + "exponential", + "exponentialb", +] + +# window aggregates from vega-lite version 4.6.0 +WINDOW_AGGREGATES = [ + "row_number", + "rank", + "dense_rank", + "percent_rank", + "cume_dist", + "ntile", + "lag", + "lead", + "first_value", + "last_value", + "nth_value", +] + +# timeUnits from vega-lite version 4.17.0 +TIMEUNITS = [ + "year", + "quarter", + "month", + "week", + "day", + "dayofyear", + "date", + "hours", + "minutes", + "seconds", + "milliseconds", + "yearquarter", + "yearquartermonth", + "yearmonth", + "yearmonthdate", + "yearmonthdatehours", + "yearmonthdatehoursminutes", + "yearmonthdatehoursminutesseconds", + "yearweek", + "yearweekday", + "yearweekdayhours", + "yearweekdayhoursminutes", + "yearweekdayhoursminutesseconds", + "yeardayofyear", + "quartermonth", + "monthdate", + "monthdatehours", + "monthdatehoursminutes", + "monthdatehoursminutesseconds", + "weekday", + "weeksdayhours", + "weekdayhours", + "weekdayhoursminutes", + "weekdayhoursminutesseconds", + "dayhours", + "dayhoursminutes", + "dayhoursminutesseconds", + "hoursminutes", + "hoursminutesseconds", + "minutesseconds", + "secondsmilliseconds", + "utcyear", + "utcquarter", + "utcmonth", + "utcweek", + "utcday", + "utcdayofyear", + "utcdate", + "utchours", + "utcminutes", + "utcseconds", + "utcmilliseconds", + "utcyearquarter", + "utcyearquartermonth", + "utcyearmonth", + "utcyearmonthdate", + "utcyearmonthdatehours", + "utcyearmonthdatehoursminutes", + "utcyearmonthdatehoursminutesseconds", + "utcyearweek", + "utcyearweekday", + "utcyearweekdayhours", + "utcyearweekdayhoursminutes", + "utcyearweekdayhoursminutesseconds", + "utcyeardayofyear", + "utcquartermonth", + "utcmonthdate", + "utcmonthdatehours", + "utcmonthdatehoursminutes", + "utcmonthdatehoursminutesseconds", + "utcweekday", + "utcweeksdayhours", + "utcweekdayhoursminutes", + "utcweekdayhoursminutesseconds", + "utcdayhours", + "utcdayhoursminutes", + "utcdayhoursminutesseconds", + "utchoursminutes", + "utchoursminutesseconds", + "utcminutesseconds", + "utcsecondsmilliseconds", +] + + +InferredVegaLiteType = Literal["ordinal", "nominal", "quantitative", "temporal"] + + +def infer_vegalite_type( + data: object, +) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]: + """ + From an array-like input, infer the correct vega typecode + ('ordinal', 'nominal', 'quantitative', or 'temporal') + + Parameters + ---------- + data: object + """ + typ = infer_dtype(data, skipna=False) + + if typ in [ + "floating", + "mixed-integer-float", + "integer", + "mixed-integer", + "complex", + ]: + return "quantitative" + elif typ == 
"categorical" and hasattr(data, "cat") and data.cat.ordered: + return ("ordinal", data.cat.categories.tolist()) + elif typ in ["string", "bytes", "categorical", "boolean", "mixed", "unicode"]: + return "nominal" + elif typ in [ + "datetime", + "datetime64", + "timedelta", + "timedelta64", + "date", + "time", + "period", + ]: + return "temporal" + else: + warnings.warn( + "I don't know how to infer vegalite type from '{}'. " + "Defaulting to nominal.".format(typ), + stacklevel=1, + ) + return "nominal" + + +def merge_props_geom(feat: dict) -> dict: + """ + Merge properties with geometry + * Overwrites 'type' and 'geometry' entries if existing + """ + + geom = {k: feat[k] for k in ("type", "geometry")} + try: + feat["properties"].update(geom) + props_geom = feat["properties"] + except (AttributeError, KeyError): + # AttributeError when 'properties' equals None + # KeyError when 'properties' is non-existing + props_geom = geom + + return props_geom + + +def sanitize_geo_interface(geo: MutableMapping) -> dict: + """Santize a geo_interface to prepare it for serialization. + + * Make a copy + * Convert type array or _Array to list + * Convert tuples to lists (using json.loads/dumps) + * Merge properties with geometry + """ + + geo = deepcopy(geo) + + # convert type _Array or array to list + for key in geo.keys(): + if str(type(geo[key]).__name__).startswith(("_Array", "array")): + geo[key] = geo[key].tolist() + + # convert (nested) tuples to lists + geo_dct: dict = json.loads(json.dumps(geo)) + + # sanitize features + if geo_dct["type"] == "FeatureCollection": + geo_dct = geo_dct["features"] + if len(geo_dct) > 0: + for idx, feat in enumerate(geo_dct): + geo_dct[idx] = merge_props_geom(feat) + elif geo_dct["type"] == "Feature": + geo_dct = merge_props_geom(geo_dct) + else: + geo_dct = {"type": "Feature", "geometry": geo_dct} + + return geo_dct + + +def numpy_is_subtype(dtype: Any, subtype: Any) -> bool: + try: + return np.issubdtype(dtype, subtype) + except (NotImplementedError, TypeError): + return False + + +def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: # noqa: C901 + """Sanitize a DataFrame to prepare it for serialization. + + * Make a copy + * Convert RangeIndex columns to strings + * Raise ValueError if column names are not strings + * Raise ValueError if it has a hierarchical index. + * Convert categoricals to strings. + * Convert np.bool_ dtypes to Python bool objects + * Convert np.int dtypes to Python int objects + * Convert floats to objects and replace NaNs/infs with None. + * Convert DateTime dtypes into appropriate string representations + * Convert Nullable integers to objects and replace NaN with None + * Convert Nullable boolean to objects and replace NaN with None + * convert dedicated string column to objects and replace NaN with None + * Raise a ValueError for TimeDelta dtypes + """ + df = df.copy() + + if isinstance(df.columns, pd.RangeIndex): + df.columns = df.columns.astype(str) + + for col_name in df.columns: + if not isinstance(col_name, str): + raise ValueError( + "Dataframe contains invalid column name: {0!r}. 
" + "Column names must be strings".format(col_name) + ) + + if isinstance(df.index, pd.MultiIndex): + raise ValueError("Hierarchical indices not supported") + if isinstance(df.columns, pd.MultiIndex): + raise ValueError("Hierarchical indices not supported") + + def to_list_if_array(val): + if isinstance(val, np.ndarray): + return val.tolist() + else: + return val + + for dtype_item in df.dtypes.items(): + # We know that the column names are strings from the isinstance check + # further above but mypy thinks it is of type Hashable and therefore does not + # let us assign it to the col_name variable which is already of type str. + col_name = cast(str, dtype_item[0]) + dtype = dtype_item[1] + dtype_name = str(dtype) + if dtype_name == "category": + # Work around bug in to_json for categorical types in older versions + # of pandas as they do not properly convert NaN values to null in to_json. + # We can probably remove this part once we require pandas >= 1.0 + col = df[col_name].astype(object) + df[col_name] = col.where(col.notnull(), None) + elif dtype_name == "string": + # dedicated string datatype (since 1.0) + # https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type + col = df[col_name].astype(object) + df[col_name] = col.where(col.notnull(), None) + elif dtype_name == "bool": + # convert numpy bools to objects; np.bool is not JSON serializable + df[col_name] = df[col_name].astype(object) + elif dtype_name == "boolean": + # dedicated boolean datatype (since 1.0) + # https://pandas.io/docs/user_guide/boolean.html + col = df[col_name].astype(object) + df[col_name] = col.where(col.notnull(), None) + elif dtype_name.startswith("datetime") or dtype_name.startswith("timestamp"): + # Convert datetimes to strings. This needs to be a full ISO string + # with time, which is why we cannot use ``col.astype(str)``. + # This is because Javascript parses date-only times in UTC, but + # parses full ISO-8601 dates as local time, and dates in Vega and + # Vega-Lite are displayed in local time by default. + # (see https://github.com/altair-viz/altair/issues/1027) + df[col_name] = ( + df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "") + ) + elif dtype_name.startswith("timedelta"): + raise ValueError( + 'Field "{col_name}" has type "{dtype}" which is ' + "not supported by Altair. Please convert to " + "either a timestamp or a numerical value." + "".format(col_name=col_name, dtype=dtype) + ) + elif dtype_name.startswith("geometry"): + # geopandas >=0.6.1 uses the dtype geometry. 
Continue here + # otherwise it will give an error on np.issubdtype(dtype, np.integer) + continue + elif ( + dtype_name + in { + "Int8", + "Int16", + "Int32", + "Int64", + "UInt8", + "UInt16", + "UInt32", + "UInt64", + "Float32", + "Float64", + } + ): # nullable integer datatypes (since 24.0) and nullable float datatypes (since 1.2.0) + # https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support + col = df[col_name].astype(object) + df[col_name] = col.where(col.notnull(), None) + elif numpy_is_subtype(dtype, np.integer): + # convert integers to objects; np.int is not JSON serializable + df[col_name] = df[col_name].astype(object) + elif numpy_is_subtype(dtype, np.floating): + # For floats, convert to Python float: np.float is not JSON serializable + # Also convert NaN/inf values to null, as they are not JSON serializable + col = df[col_name] + bad_values = col.isnull() | np.isinf(col) + df[col_name] = col.astype(object).where(~bad_values, None) + elif dtype == object: + # Convert numpy arrays saved as objects to lists + # Arrays are not JSON serializable + col = df[col_name].astype(object).apply(to_list_if_array) + df[col_name] = col.where(col.notnull(), None) + return df + + +def sanitize_arrow_table(pa_table): + """Sanitize arrow table for JSON serialization""" + import pyarrow as pa + import pyarrow.compute as pc + + arrays = [] + schema = pa_table.schema + for name in schema.names: + array = pa_table[name] + dtype_name = str(schema.field(name).type) + if dtype_name.startswith("timestamp") or dtype_name.startswith("date"): + arrays.append(pc.strftime(array)) + elif dtype_name.startswith("duration"): + raise ValueError( + 'Field "{col_name}" has type "{dtype}" which is ' + "not supported by Altair. Please convert to " + "either a timestamp or a numerical value." + "".format(col_name=name, dtype=dtype_name) + ) + else: + arrays.append(array) + + return pa.Table.from_arrays(arrays, names=schema.names) + + +def parse_shorthand( + shorthand: Union[Dict[str, Any], str], + data: Optional[Union[pd.DataFrame, DataFrameLike]] = None, + parse_aggregates: bool = True, + parse_window_ops: bool = False, + parse_timeunits: bool = True, + parse_types: bool = True, +) -> Dict[str, Any]: + """General tool to parse shorthand values + + These are of the form: + + - "col_name" + - "col_name:O" + - "average(col_name)" + - "average(col_name):O" + + Optionally, a dataframe may be supplied, from which the type + will be inferred if not specified in the shorthand. + + Parameters + ---------- + shorthand : dict or string + The shorthand representation to be parsed + data : DataFrame, optional + If specified and of type DataFrame, then use these values to infer the + column type if not provided by the shorthand. + parse_aggregates : boolean + If True (default), then parse aggregate functions within the shorthand. + parse_window_ops : boolean + If True then parse window operations within the shorthand (default:False) + parse_timeunits : boolean + If True (default), then parse timeUnits from within the shorthand + parse_types : boolean + If True (default), then parse typecodes within the shorthand + + Returns + ------- + attrs : dict + a dictionary of attributes extracted from the shorthand + + Examples + -------- + >>> data = pd.DataFrame({'foo': ['A', 'B', 'A', 'B'], + ... 
'bar': [1, 2, 3, 4]}) + + >>> parse_shorthand('name') == {'field': 'name'} + True + + >>> parse_shorthand('name:Q') == {'field': 'name', 'type': 'quantitative'} + True + + >>> parse_shorthand('average(col)') == {'aggregate': 'average', 'field': 'col'} + True + + >>> parse_shorthand('foo:O') == {'field': 'foo', 'type': 'ordinal'} + True + + >>> parse_shorthand('min(foo):Q') == {'aggregate': 'min', 'field': 'foo', 'type': 'quantitative'} + True + + >>> parse_shorthand('month(col)') == {'field': 'col', 'timeUnit': 'month', 'type': 'temporal'} + True + + >>> parse_shorthand('year(col):O') == {'field': 'col', 'timeUnit': 'year', 'type': 'ordinal'} + True + + >>> parse_shorthand('foo', data) == {'field': 'foo', 'type': 'nominal'} + True + + >>> parse_shorthand('bar', data) == {'field': 'bar', 'type': 'quantitative'} + True + + >>> parse_shorthand('bar:O', data) == {'field': 'bar', 'type': 'ordinal'} + True + + >>> parse_shorthand('sum(bar)', data) == {'aggregate': 'sum', 'field': 'bar', 'type': 'quantitative'} + True + + >>> parse_shorthand('count()', data) == {'aggregate': 'count', 'type': 'quantitative'} + True + """ + from altair.utils._importers import pyarrow_available + + if not shorthand: + return {} + + valid_typecodes = list(TYPECODE_MAP) + list(INV_TYPECODE_MAP) + + units = { + "field": "(?P.*)", + "type": "(?P{})".format("|".join(valid_typecodes)), + "agg_count": "(?Pcount)", + "op_count": "(?Pcount)", + "aggregate": "(?P{})".format("|".join(AGGREGATES)), + "window_op": "(?P{})".format("|".join(AGGREGATES + WINDOW_AGGREGATES)), + "timeUnit": "(?P{})".format("|".join(TIMEUNITS)), + } + + patterns = [] + + if parse_aggregates: + patterns.extend([r"{agg_count}\(\)"]) + patterns.extend([r"{aggregate}\({field}\)"]) + if parse_window_ops: + patterns.extend([r"{op_count}\(\)"]) + patterns.extend([r"{window_op}\({field}\)"]) + if parse_timeunits: + patterns.extend([r"{timeUnit}\({field}\)"]) + + patterns.extend([r"{field}"]) + + if parse_types: + patterns = list(itertools.chain(*((p + ":{type}", p) for p in patterns))) + + regexps = ( + re.compile(r"\A" + p.format(**units) + r"\Z", re.DOTALL) for p in patterns + ) + + # find matches depending on valid fields passed + if isinstance(shorthand, dict): + attrs = shorthand + else: + attrs = next( + exp.match(shorthand).groupdict() # type: ignore[union-attr] + for exp in regexps + if exp.match(shorthand) is not None + ) + + # Handle short form of the type expression + if "type" in attrs: + attrs["type"] = INV_TYPECODE_MAP.get(attrs["type"], attrs["type"]) + + # counts are quantitative by default + if attrs == {"aggregate": "count"}: + attrs["type"] = "quantitative" + + # times are temporal by default + if "timeUnit" in attrs and "type" not in attrs: + attrs["type"] = "temporal" + + # if data is specified and type is not, infer type from data + if "type" not in attrs: + if pyarrow_available() and data is not None and isinstance(data, DataFrameLike): + dfi = data.__dataframe__() + if "field" in attrs: + unescaped_field = attrs["field"].replace("\\", "") + if unescaped_field in dfi.column_names(): + column = dfi.get_column_by_name(unescaped_field) + try: + attrs["type"] = infer_vegalite_type_for_dfi_column(column) + except (NotImplementedError, AttributeError, ValueError): + # Fall back to pandas-based inference. 
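+                        # (Editorial note, illustrative: this covers e.g. an
+                        # extension dtype that the interchange protocol cannot
+                        # describe; the original pandas column is still usable
+                        # for type inference in the fallback below.)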
+ # Note: The AttributeError catch is a workaround for + # https://github.com/pandas-dev/pandas/issues/55332 + if isinstance(data, pd.DataFrame): + attrs["type"] = infer_vegalite_type(data[unescaped_field]) + else: + raise + + if isinstance(attrs["type"], tuple): + attrs["sort"] = attrs["type"][1] + attrs["type"] = attrs["type"][0] + elif isinstance(data, pd.DataFrame): + # Fallback if pyarrow is not installed or if pandas is older than 1.5 + # + # Remove escape sequences so that types can be inferred for columns with special characters + if "field" in attrs and attrs["field"].replace("\\", "") in data.columns: + attrs["type"] = infer_vegalite_type( + data[attrs["field"].replace("\\", "")] + ) + # ordered categorical dataframe columns return the type and sort order as a tuple + if isinstance(attrs["type"], tuple): + attrs["sort"] = attrs["type"][1] + attrs["type"] = attrs["type"][0] + + # If an unescaped colon is still present, it's often due to an incorrect data type specification + # but could also be due to using a column name with ":" in it. + if ( + "field" in attrs + and ":" in attrs["field"] + and attrs["field"][attrs["field"].rfind(":") - 1] != "\\" + ): + raise ValueError( + '"{}" '.format(attrs["field"].split(":")[-1]) + + "is not one of the valid encoding data types: {}.".format( + ", ".join(TYPECODE_MAP.values()) + ) + + "\nFor more details, see https://altair-viz.github.io/user_guide/encodings/index.html#encoding-data-types. " + + "If you are trying to use a column name that contains a colon, " + + 'prefix it with a backslash; for example "column\\:name" instead of "column:name".' + ) + return attrs + + +def infer_vegalite_type_for_dfi_column( + column: Union[Column, "PandasColumn"], +) -> Union[InferredVegaLiteType, Tuple[InferredVegaLiteType, list]]: + from pyarrow.interchange.from_dataframe import column_to_array + + try: + kind = column.dtype[0] + except NotImplementedError as e: + # Edge case hack: + # dtype access fails for pandas column with datetime64[ns, UTC] type, + # but all we need to know is that its temporal, so check the + # error message for the presence of datetime64. + # + # See https://github.com/pandas-dev/pandas/issues/54239 + if "datetime64" in e.args[0] or "timestamp" in e.args[0]: + return "temporal" + raise e + + if ( + kind == DtypeKind.CATEGORICAL + and column.describe_categorical["is_ordered"] + and column.describe_categorical["categories"] is not None + ): + # Treat ordered categorical column as Vega-Lite ordinal + categories_column = column.describe_categorical["categories"] + categories_array = column_to_array(categories_column) + return "ordinal", categories_array.to_pylist() + if kind in (DtypeKind.STRING, DtypeKind.CATEGORICAL, DtypeKind.BOOL): + return "nominal" + elif kind in (DtypeKind.INT, DtypeKind.UINT, DtypeKind.FLOAT): + return "quantitative" + elif kind == DtypeKind.DATETIME: + return "temporal" + else: + raise ValueError(f"Unexpected DtypeKind: {kind}") + + +def use_signature(Obj: Callable[P, Any]): + """Apply call signature and documentation of Obj to the decorated method""" + + def decorate(f: Callable[..., V]) -> Callable[P, V]: + # call-signature of f is exposed via __wrapped__. + # we want it to mimic Obj.__init__ + f.__wrapped__ = Obj.__init__ # type: ignore + f._uses_signature = Obj # type: ignore + + # Supplement the docstring of f with information from Obj + if Obj.__doc__: + # Patch in a reference to the class this docstring is copied from, + # to generate a hyperlink. 
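+            # Editorial example (hypothetical class name): if Obj were a channel
+            # class called X, the first docstring line of the decorated method
+            # would become "Refer to :class:`X`", which Sphinx renders as a link.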
+ doclines = Obj.__doc__.splitlines() + doclines[0] = f"Refer to :class:`{Obj.__name__}`" + + if f.__doc__: + doc = f.__doc__ + "\n".join(doclines[1:]) + else: + doc = "\n".join(doclines) + try: + f.__doc__ = doc + except AttributeError: + # __doc__ is not modifiable for classes in Python < 3.3 + pass + + return f + + return decorate + + +def update_nested( + original: MutableMapping, update: Mapping, copy: bool = False +) -> MutableMapping: + """Update nested dictionaries + + Parameters + ---------- + original : MutableMapping + the original (nested) dictionary, which will be updated in-place + update : Mapping + the nested dictionary of updates + copy : bool, default False + if True, then copy the original dictionary rather than modifying it + + Returns + ------- + original : MutableMapping + a reference to the (modified) original dict + + Examples + -------- + >>> original = {'x': {'b': 2, 'c': 4}} + >>> update = {'x': {'b': 5, 'd': 6}, 'y': 40} + >>> update_nested(original, update) # doctest: +SKIP + {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40} + >>> original # doctest: +SKIP + {'x': {'b': 5, 'c': 4, 'd': 6}, 'y': 40} + """ + if copy: + original = deepcopy(original) + for key, val in update.items(): + if isinstance(val, Mapping): + orig_val = original.get(key, {}) + if isinstance(orig_val, MutableMapping): + original[key] = update_nested(orig_val, val) + else: + original[key] = val + else: + original[key] = val + return original + + +def display_traceback(in_ipython: bool = True): + exc_info = sys.exc_info() + + if in_ipython: + from IPython.core.getipython import get_ipython + + ip = get_ipython() + else: + ip = None + + if ip is not None: + ip.showtraceback(exc_info) + else: + traceback.print_exception(*exc_info) + + +def infer_encoding_types(args: Sequence, kwargs: MutableMapping, channels: ModuleType): + """Infer typed keyword arguments for args and kwargs + + Parameters + ---------- + args : Sequence + Sequence of function args + kwargs : MutableMapping + Dict of function kwargs + channels : ModuleType + The module containing all altair encoding channel classes. + + Returns + ------- + kwargs : dict + All args and kwargs in a single dict, with keys and types + based on the channels mapping. + """ + # Construct a dictionary of channel type to encoding name + # TODO: cache this somehow? + channel_objs = (getattr(channels, name) for name in dir(channels)) + channel_objs = ( + c for c in channel_objs if isinstance(c, type) and issubclass(c, SchemaBase) + ) + channel_to_name: Dict[Type[SchemaBase], str] = { + c: c._encoding_name for c in channel_objs + } + name_to_channel: Dict[str, Dict[str, Type[SchemaBase]]] = {} + for chan, name in channel_to_name.items(): + chans = name_to_channel.setdefault(name, {}) + if chan.__name__.endswith("Datum"): + key = "datum" + elif chan.__name__.endswith("Value"): + key = "value" + else: + key = "field" + chans[key] = chan + + # First use the mapping to convert args to kwargs based on their types. 
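+    # Editorial sketch (not part of the upstream source; X, Y and the column
+    # names are hypothetical): a call like
+    #
+    #     infer_encoding_types((X("a:Q"),), {"y": "b:N"}, channels)
+    #
+    # maps the positional X instance to kwargs["x"] via channel_to_name in the
+    # loop below, and later wraps the plain string "b:N" in the Y field channel
+    # class inside _wrap_in_channel_class.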
+ for arg in args: + if isinstance(arg, (list, tuple)) and len(arg) > 0: + type_ = type(arg[0]) + else: + type_ = type(arg) + + encoding = channel_to_name.get(type_, None) + if encoding is None: + raise NotImplementedError("positional of type {}" "".format(type_)) + if encoding in kwargs: + raise ValueError("encoding {} specified twice.".format(encoding)) + kwargs[encoding] = arg + + def _wrap_in_channel_class(obj, encoding): + if isinstance(obj, SchemaBase): + return obj + + if isinstance(obj, str): + obj = {"shorthand": obj} + + if isinstance(obj, (list, tuple)): + return [_wrap_in_channel_class(subobj, encoding) for subobj in obj] + + if encoding not in name_to_channel: + warnings.warn( + "Unrecognized encoding channel '{}'".format(encoding), stacklevel=1 + ) + return obj + + classes = name_to_channel[encoding] + cls = classes["value"] if "value" in obj else classes["field"] + + try: + # Don't force validation here; some objects won't be valid until + # they're created in the context of a chart. + return cls.from_dict(obj, validate=False) + except jsonschema.ValidationError: + # our attempts at finding the correct class have failed + return obj + + return { + encoding: _wrap_in_channel_class(obj, encoding) + for encoding, obj in kwargs.items() + } diff --git a/llm/Lib/site-packages/altair/utils/data.py b/llm/Lib/site-packages/altair/utils/data.py new file mode 100644 index 0000000000000000000000000000000000000000..871b43092a257e270a5a790e64fdb56ebeca2fbf --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/data.py @@ -0,0 +1,360 @@ +import json +import os +import random +import hashlib +import warnings +from typing import Union, MutableMapping, Optional, Dict, Sequence, TYPE_CHECKING, List + +import pandas as pd +from toolz import curried +from typing import TypeVar + +from ._importers import import_pyarrow_interchange +from .core import sanitize_dataframe, sanitize_arrow_table, DataFrameLike +from .core import sanitize_geo_interface +from .deprecation import AltairDeprecationWarning +from .plugin_registry import PluginRegistry + + +from typing import Protocol, TypedDict, Literal + + +if TYPE_CHECKING: + import pyarrow.lib + + +class SupportsGeoInterface(Protocol): + __geo_interface__: MutableMapping + + +DataType = Union[dict, pd.DataFrame, SupportsGeoInterface, DataFrameLike] +TDataType = TypeVar("TDataType", bound=DataType) + +VegaLiteDataDict = Dict[str, Union[str, dict, List[dict]]] +ToValuesReturnType = Dict[str, Union[dict, List[dict]]] + + +# ============================================================================== +# Data transformer registry +# +# A data transformer is a callable that takes a supported data type and returns +# a transformed dictionary version of it which is compatible with the VegaLite schema. +# The dict objects will be the Data portion of the VegaLite schema. +# +# Renderers only deal with the dict form of a +# VegaLite spec, after the Data model has been put into a schema compliant +# form. 
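+#
+# Editorial sketch (not part of the upstream module; the transformer name and
+# sample size are made up): a custom transformer can be registered on the
+# DataTransformerRegistry instance that Altair exposes as alt.data_transformers:
+#
+#     def sample_then_inline(data, **kwargs):
+#         # downsample, then inline the remaining rows as a "values" dict
+#         return to_values(sample(data, n=100))
+#
+#     data_transformers.register("sample100", sample_then_inline)
+#     data_transformers.enable("sample100")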
+# ============================================================================== +class DataTransformerType(Protocol): + def __call__(self, data: DataType, **kwargs) -> VegaLiteDataDict: + pass + + +class DataTransformerRegistry(PluginRegistry[DataTransformerType]): + _global_settings = {"consolidate_datasets": True} + + @property + def consolidate_datasets(self) -> bool: + return self._global_settings["consolidate_datasets"] + + @consolidate_datasets.setter + def consolidate_datasets(self, value: bool) -> None: + self._global_settings["consolidate_datasets"] = value + + +# ============================================================================== +class MaxRowsError(Exception): + """Raised when a data model has too many rows.""" + + pass + + +@curried.curry +def limit_rows(data: TDataType, max_rows: Optional[int] = 5000) -> TDataType: + """Raise MaxRowsError if the data model has more than max_rows. + + If max_rows is None, then do not perform any check. + """ + check_data_type(data) + + def raise_max_rows_error(): + raise MaxRowsError( + "The number of rows in your dataset is greater " + f"than the maximum allowed ({max_rows}).\n\n" + "Try enabling the VegaFusion data transformer which " + "raises this limit by pre-evaluating data\n" + "transformations in Python.\n" + " >> import altair as alt\n" + ' >> alt.data_transformers.enable("vegafusion")\n\n' + "Or, see https://altair-viz.github.io/user_guide/large_datasets.html " + "for additional information\n" + "on how to plot large datasets." + ) + + if hasattr(data, "__geo_interface__"): + if data.__geo_interface__["type"] == "FeatureCollection": + values = data.__geo_interface__["features"] + else: + values = data.__geo_interface__ + elif isinstance(data, pd.DataFrame): + values = data + elif isinstance(data, dict): + if "values" in data: + values = data["values"] + else: + # mypy gets confused as it doesn't see Dict[Any, Any] + # as equivalent to TDataType + return data # type: ignore[return-value] + elif isinstance(data, DataFrameLike): + pa_table = arrow_table_from_dfi_dataframe(data) + if max_rows is not None and pa_table.num_rows > max_rows: + raise_max_rows_error() + # Return pyarrow Table instead of input since the + # `arrow_table_from_dfi_dataframe` call above may be expensive + return pa_table + + if max_rows is not None and len(values) > max_rows: + raise_max_rows_error() + + return data + + +@curried.curry +def sample( + data: DataType, n: Optional[int] = None, frac: Optional[float] = None +) -> Optional[Union[pd.DataFrame, Dict[str, Sequence], "pyarrow.lib.Table"]]: + """Reduce the size of the data model by sampling without replacement.""" + check_data_type(data) + if isinstance(data, pd.DataFrame): + return data.sample(n=n, frac=frac) + elif isinstance(data, dict): + if "values" in data: + values = data["values"] + if not n: + if frac is None: + raise ValueError( + "frac cannot be None if n is None and data is a dictionary" + ) + n = int(frac * len(values)) + values = random.sample(values, n) + return {"values": values} + else: + # Maybe this should raise an error or return something useful? + return None + elif isinstance(data, DataFrameLike): + pa_table = arrow_table_from_dfi_dataframe(data) + if not n: + if frac is None: + raise ValueError( + "frac cannot be None if n is None with this data input type" + ) + n = int(frac * len(pa_table)) + indices = random.sample(range(len(pa_table)), n) + return pa_table.take(indices) + else: + # Maybe this should raise an error or return something useful? 
Currently, + # if data is of type SupportsGeoInterface it lands here + return None + + +class _JsonFormatDict(TypedDict): + type: Literal["json"] + + +class _CsvFormatDict(TypedDict): + type: Literal["csv"] + + +class _ToJsonReturnUrlDict(TypedDict): + url: str + format: _JsonFormatDict + + +class _ToCsvReturnUrlDict(TypedDict): + url: str + format: _CsvFormatDict + + +@curried.curry +def to_json( + data: DataType, + prefix: str = "altair-data", + extension: str = "json", + filename: str = "{prefix}-{hash}.{extension}", + urlpath: str = "", +) -> _ToJsonReturnUrlDict: + """ + Write the data model to a .json file and return a url based data model. + """ + data_json = _data_to_json_string(data) + data_hash = _compute_data_hash(data_json) + filename = filename.format(prefix=prefix, hash=data_hash, extension=extension) + with open(filename, "w") as f: + f.write(data_json) + return {"url": os.path.join(urlpath, filename), "format": {"type": "json"}} + + +@curried.curry +def to_csv( + data: Union[dict, pd.DataFrame, DataFrameLike], + prefix: str = "altair-data", + extension: str = "csv", + filename: str = "{prefix}-{hash}.{extension}", + urlpath: str = "", +) -> _ToCsvReturnUrlDict: + """Write the data model to a .csv file and return a url based data model.""" + data_csv = _data_to_csv_string(data) + data_hash = _compute_data_hash(data_csv) + filename = filename.format(prefix=prefix, hash=data_hash, extension=extension) + with open(filename, "w") as f: + f.write(data_csv) + return {"url": os.path.join(urlpath, filename), "format": {"type": "csv"}} + + +@curried.curry +def to_values(data: DataType) -> ToValuesReturnType: + """Replace a DataFrame by a data model with values.""" + check_data_type(data) + if hasattr(data, "__geo_interface__"): + if isinstance(data, pd.DataFrame): + data = sanitize_dataframe(data) + # Maybe the type could be further clarified here that it is + # SupportGeoInterface and then the ignore statement is not needed? 
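+        # Editorial note: for inputs exposing __geo_interface__ (e.g. a
+        # GeoDataFrame-like object), sanitize_geo_interface() deep-copies the
+        # mapping, converts arrays and tuples to lists, and merges each
+        # feature's "type"/"geometry" entries into its properties dict so the
+        # result can be embedded as inline values.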
+ data_sanitized = sanitize_geo_interface(data.__geo_interface__) # type: ignore[arg-type] + return {"values": data_sanitized} + elif isinstance(data, pd.DataFrame): + data = sanitize_dataframe(data) + return {"values": data.to_dict(orient="records")} + elif isinstance(data, dict): + if "values" not in data: + raise KeyError("values expected in data dict, but not present.") + return data + elif isinstance(data, DataFrameLike): + pa_table = sanitize_arrow_table(arrow_table_from_dfi_dataframe(data)) + return {"values": pa_table.to_pylist()} + else: + # Should never reach this state as tested by check_data_type + raise ValueError("Unrecognized data type: {}".format(type(data))) + + +def check_data_type(data: DataType) -> None: + if not isinstance(data, (dict, pd.DataFrame, DataFrameLike)) and not any( + hasattr(data, attr) for attr in ["__geo_interface__"] + ): + raise TypeError( + "Expected dict, DataFrame or a __geo_interface__ attribute, got: {}".format( + type(data) + ) + ) + + +# ============================================================================== +# Private utilities +# ============================================================================== +def _compute_data_hash(data_str: str) -> str: + return hashlib.sha256(data_str.encode()).hexdigest()[:32] + + +def _data_to_json_string(data: DataType) -> str: + """Return a JSON string representation of the input data""" + check_data_type(data) + if hasattr(data, "__geo_interface__"): + if isinstance(data, pd.DataFrame): + data = sanitize_dataframe(data) + # Maybe the type could be further clarified here that it is + # SupportGeoInterface and then the ignore statement is not needed? + data = sanitize_geo_interface(data.__geo_interface__) # type: ignore[arg-type] + return json.dumps(data) + elif isinstance(data, pd.DataFrame): + data = sanitize_dataframe(data) + return data.to_json(orient="records", double_precision=15) + elif isinstance(data, dict): + if "values" not in data: + raise KeyError("values expected in data dict, but not present.") + return json.dumps(data["values"], sort_keys=True) + elif isinstance(data, DataFrameLike): + pa_table = arrow_table_from_dfi_dataframe(data) + return json.dumps(pa_table.to_pylist()) + else: + raise NotImplementedError( + "to_json only works with data expressed as " "a DataFrame or as a dict" + ) + + +def _data_to_csv_string(data: Union[dict, pd.DataFrame, DataFrameLike]) -> str: + """return a CSV string representation of the input data""" + check_data_type(data) + if hasattr(data, "__geo_interface__"): + raise NotImplementedError( + "to_csv does not work with data that " + "contains the __geo_interface__ attribute" + ) + elif isinstance(data, pd.DataFrame): + data = sanitize_dataframe(data) + return data.to_csv(index=False) + elif isinstance(data, dict): + if "values" not in data: + raise KeyError("values expected in data dict, but not present") + return pd.DataFrame.from_dict(data["values"]).to_csv(index=False) + elif isinstance(data, DataFrameLike): + # experimental interchange dataframe support + import pyarrow as pa + import pyarrow.csv as pa_csv + + pa_table = arrow_table_from_dfi_dataframe(data) + csv_buffer = pa.BufferOutputStream() + pa_csv.write_csv(pa_table, csv_buffer) + return csv_buffer.getvalue().to_pybytes().decode() + else: + raise NotImplementedError( + "to_csv only works with data expressed as " "a DataFrame or as a dict" + ) + + +def pipe(data, *funcs): + """ + Pipe a value through a sequence of functions + + Deprecated: use toolz.curried.pipe() instead. 
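+
+    Illustrative usage of the recommended replacement (editorial example with
+    made-up data):
+
+        >>> from toolz.curried import pipe as tz_pipe
+        >>> tz_pipe({"values": [{"a": 1}]}, limit_rows, to_values)  # doctest: +SKIP
+        {'values': [{'a': 1}]}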
+ """ + warnings.warn( + "alt.pipe() is deprecated, and will be removed in a future release. " + "Use toolz.curried.pipe() instead.", + AltairDeprecationWarning, + stacklevel=1, + ) + return curried.pipe(data, *funcs) + + +def curry(*args, **kwargs): + """Curry a callable function + + Deprecated: use toolz.curried.curry() instead. + """ + warnings.warn( + "alt.curry() is deprecated, and will be removed in a future release. " + "Use toolz.curried.curry() instead.", + AltairDeprecationWarning, + stacklevel=1, + ) + return curried.curry(*args, **kwargs) + + +def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> "pyarrow.lib.Table": + """Convert a DataFrame Interchange Protocol compatible object to an Arrow Table""" + import pyarrow as pa + + # First check if the dataframe object has a method to convert to arrow. + # Give this preference over the pyarrow from_dataframe function since the object + # has more control over the conversion, and may have broader compatibility. + # This is the case for Polars, which supports Date32 columns in direct conversion + # while pyarrow does not yet support this type in from_dataframe + for convert_method_name in ("arrow", "to_arrow", "to_arrow_table"): + convert_method = getattr(dfi_df, convert_method_name, None) + if callable(convert_method): + result = convert_method() + if isinstance(result, pa.Table): + return result + + pi = import_pyarrow_interchange() + return pi.from_dataframe(dfi_df) diff --git a/llm/Lib/site-packages/altair/utils/deprecation.py b/llm/Lib/site-packages/altair/utils/deprecation.py new file mode 100644 index 0000000000000000000000000000000000000000..f0ed26ae98f9a71512f85ca589fd4e160033f97b --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/deprecation.py @@ -0,0 +1,71 @@ +import warnings +import functools + + +class AltairDeprecationWarning(UserWarning): + pass + + +def deprecated(message=None): + """Decorator to deprecate a function or class. + + Parameters + ---------- + message : string (optional) + The deprecation message + """ + + def wrapper(obj): + return _deprecate(obj, message=message) + + return wrapper + + +def _deprecate(obj, name=None, message=None): + """Return a version of a class or function that raises a deprecation warning. + + Parameters + ---------- + obj : class or function + The object to create a deprecated version of. + name : string (optional) + The name of the deprecated object + message : string (optional) + The deprecation message + + Returns + ------- + deprecated_obj : + The deprecated version of obj + + Examples + -------- + >>> class Foo: pass + >>> OldFoo = _deprecate(Foo, "OldFoo") + >>> f = OldFoo() # doctest: +SKIP + AltairDeprecationWarning: alt.OldFoo is deprecated. Use alt.Foo instead. + """ + if message is None: + message = "alt.{} is deprecated. Use alt.{} instead." 
"".format( + name, obj.__name__ + ) + if isinstance(obj, type): + return type( + name, + (obj,), + { + "__doc__": obj.__doc__, + "__init__": _deprecate(obj.__init__, "__init__", message), + }, + ) + elif callable(obj): + + @functools.wraps(obj) + def new_obj(*args, **kwargs): + warnings.warn(message, AltairDeprecationWarning, stacklevel=1) + return obj(*args, **kwargs) + + new_obj._deprecated = True + return new_obj + else: + raise ValueError("Cannot deprecate object of type {}".format(type(obj))) diff --git a/llm/Lib/site-packages/altair/utils/display.py b/llm/Lib/site-packages/altair/utils/display.py new file mode 100644 index 0000000000000000000000000000000000000000..c1585eda2e0489f7386dc66c08f2d0121125cd28 --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/display.py @@ -0,0 +1,229 @@ +import json +import pkgutil +import textwrap +from typing import Callable, Dict, Optional, Tuple, Any, Union +import uuid + +from ._vegafusion_data import compile_with_vegafusion, using_vegafusion +from .plugin_registry import PluginRegistry, PluginEnabler +from .mimebundle import spec_to_mimebundle +from .schemapi import validate_jsonschema + + +# ============================================================================== +# Renderer registry +# ============================================================================== +# MimeBundleType needs to be the same as what are acceptable return values +# for _repr_mimebundle_, +# see https://ipython.readthedocs.io/en/stable/config/integrating.html#MyObject._repr_mimebundle_ +MimeBundleDataType = Dict[str, Any] +MimeBundleMetaDataType = Dict[str, Any] +MimeBundleType = Union[ + MimeBundleDataType, Tuple[MimeBundleDataType, MimeBundleMetaDataType] +] +RendererType = Callable[..., MimeBundleType] +# Subtype of MimeBundleType as more specific in the values of the dictionaries +DefaultRendererReturnType = Tuple[ + Dict[str, Union[str, dict]], Dict[str, Dict[str, Any]] +] + + +class RendererRegistry(PluginRegistry[RendererType]): + entrypoint_err_messages = { + "notebook": textwrap.dedent( + """ + To use the 'notebook' renderer, you must install the vega package + and the associated Jupyter extension. + See https://altair-viz.github.io/getting_started/installation.html + for more information. + """ + ), + "altair_viewer": textwrap.dedent( + """ + To use the 'altair_viewer' renderer, you must install the altair_viewer + package; see http://github.com/altair-viz/altair_viewer/ + for more information. + """ + ), + } + + def set_embed_options( + self, + defaultStyle: Optional[Union[bool, str]] = None, + renderer: Optional[str] = None, + width: Optional[int] = None, + height: Optional[int] = None, + padding: Optional[int] = None, + scaleFactor: Optional[float] = None, + actions: Optional[Union[bool, Dict[str, bool]]] = None, + format_locale: Optional[Union[str, dict]] = None, + time_format_locale: Optional[Union[str, dict]] = None, + **kwargs, + ) -> PluginEnabler: + """Set options for embeddings of Vega & Vega-Lite charts. + + Options are fully documented at https://github.com/vega/vega-embed. + Similar to the `enable()` method, this can be used as either + a persistent global switch, or as a temporary local setting using + a context manager (i.e. a `with` statement). + + Parameters + ---------- + defaultStyle : bool or string + Specify a default stylesheet for embed actions. + renderer : string + The renderer to use for the view. 
One of "canvas" (default) or "svg" + width : integer + The view width in pixels + height : integer + The view height in pixels + padding : integer + The view padding in pixels + scaleFactor : number + The number by which to multiply the width and height (default 1) + of an exported PNG or SVG image. + actions : bool or dict + Determines if action links ("Export as PNG/SVG", "View Source", + "View Vega" (only for Vega-Lite), "Open in Vega Editor") are + included with the embedded view. If the value is true, all action + links will be shown and none if the value is false. This property + can take a key-value mapping object that maps keys (export, source, + compiled, editor) to boolean values for determining if + each action link should be shown. + format_locale : str or dict + d3-format locale name or dictionary. Defaults to "en-US" for United States English. + See https://github.com/d3/d3-format/tree/main/locale for available names and example + definitions. + time_format_locale : str or dict + d3-time-format locale name or dictionary. Defaults to "en-US" for United States English. + See https://github.com/d3/d3-time-format/tree/main/locale for available names and example + definitions. + **kwargs : + Additional options are passed directly to embed options. + """ + options: Dict[str, Optional[Union[bool, str, float, Dict[str, bool]]]] = { + "defaultStyle": defaultStyle, + "renderer": renderer, + "width": width, + "height": height, + "padding": padding, + "scaleFactor": scaleFactor, + "actions": actions, + "formatLocale": format_locale, + "timeFormatLocale": time_format_locale, + } + kwargs.update({key: val for key, val in options.items() if val is not None}) + return self.enable(None, embed_options=kwargs) + + +# ============================================================================== +# VegaLite v1/v2 renderer logic +# ============================================================================== + + +class Displayable: + """A base display class for VegaLite v1/v2. + + This class takes a VegaLite v1/v2 spec and does the following: + + 1. Optionally validates the spec against a schema. + 2. Uses the RendererPlugin to grab a renderer and call it when the + IPython/Jupyter display method (_repr_mimebundle_) is called. + + The spec passed to this class must be fully schema compliant and already + have the data portion of the spec fully processed and ready to serialize. + In practice, this means, the data portion of the spec should have been passed + through appropriate data model transformers. + """ + + renderers: Optional[RendererRegistry] = None + schema_path = ("altair", "") + + def __init__(self, spec: dict, validate: bool = False) -> None: + self.spec = spec + self.validate = validate + self._validate() + + def _validate(self) -> None: + """Validate the spec against the schema.""" + data = pkgutil.get_data(*self.schema_path) + assert data is not None + schema_dict: dict = json.loads(data.decode("utf-8")) + validate_jsonschema( + self.spec, + schema_dict, + ) + + def _repr_mimebundle_( + self, include: Any = None, exclude: Any = None + ) -> MimeBundleType: + """Return a MIME bundle for display in Jupyter frontends.""" + if self.renderers is not None: + renderer_func = self.renderers.get() + assert renderer_func is not None + return renderer_func(self.spec) + else: + return {} + + +def default_renderer_base( + spec: dict, mime_type: str, str_repr: str, **options +) -> DefaultRendererReturnType: + """A default renderer for Vega or VegaLite that works for modern frontends. 
+ + This renderer works with modern frontends (JupyterLab, nteract) that know + how to render the custom VegaLite MIME type listed above. + """ + # Local import to avoid circular ImportError + from altair.vegalite.v5.display import VEGA_MIME_TYPE, VEGALITE_MIME_TYPE + + assert isinstance(spec, dict) + bundle: Dict[str, Union[str, dict]] = {} + metadata: Dict[str, Dict[str, Any]] = {} + + if using_vegafusion(): + spec = compile_with_vegafusion(spec) + + # Swap mimetype from Vega-Lite to Vega. + # If mimetype was JSON, leave it alone + if mime_type == VEGALITE_MIME_TYPE: + mime_type = VEGA_MIME_TYPE + + bundle[mime_type] = spec + bundle["text/plain"] = str_repr + if options: + metadata[mime_type] = options + return bundle, metadata + + +def json_renderer_base( + spec: dict, str_repr: str, **options +) -> DefaultRendererReturnType: + """A renderer that returns a MIME type of application/json. + + In JupyterLab/nteract this is rendered as a nice JSON tree. + """ + return default_renderer_base( + spec, mime_type="application/json", str_repr=str_repr, **options + ) + + +class HTMLRenderer: + """Object to render charts as HTML, with a unique output div each time""" + + def __init__(self, output_div: str = "altair-viz-{}", **kwargs) -> None: + self._output_div = output_div + self.kwargs = kwargs + + @property + def output_div(self) -> str: + return self._output_div.format(uuid.uuid4().hex) + + def __call__(self, spec: dict, **metadata) -> Dict[str, str]: + kwargs = self.kwargs.copy() + kwargs.update(metadata) + # To get proper return value type, would need to write complex + # overload signatures for spec_to_mimebundle based on `format` + return spec_to_mimebundle( # type: ignore[return-value] + spec, format="html", output_div=self.output_div, **kwargs + ) diff --git a/llm/Lib/site-packages/altair/utils/execeval.py b/llm/Lib/site-packages/altair/utils/execeval.py new file mode 100644 index 0000000000000000000000000000000000000000..7e98fd378bfee470858b3a79e4ab3473086f2eef --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/execeval.py @@ -0,0 +1,53 @@ +import ast +import sys + + +class _CatchDisplay: + """Class to temporarily catch sys.displayhook""" + + def __init__(self): + self.output = None + + def __enter__(self): + self.old_hook = sys.displayhook + sys.displayhook = self + return self + + def __exit__(self, type, value, traceback): + sys.displayhook = self.old_hook + # Returning False will cause exceptions to propagate + return False + + def __call__(self, output): + self.output = output + + +def eval_block(code, namespace=None, filename=""): + """ + Execute a multi-line block of code in the given namespace + + If the final statement in the code is an expression, return + the result of the expression. 
+ """ + tree = ast.parse(code, filename="", mode="exec") + if namespace is None: + namespace = {} + catch_display = _CatchDisplay() + + if isinstance(tree.body[-1], ast.Expr): + to_exec, to_eval = tree.body[:-1], tree.body[-1:] + else: + to_exec, to_eval = tree.body, [] + + for node in to_exec: + compiled = compile(ast.Module([node], []), filename=filename, mode="exec") + exec(compiled, namespace) + + with catch_display: + for node in to_eval: + compiled = compile( + ast.Interactive([node]), filename=filename, mode="single" + ) + exec(compiled, namespace) + + return catch_display.output diff --git a/llm/Lib/site-packages/altair/utils/html.py b/llm/Lib/site-packages/altair/utils/html.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd89b2ddef982aad26e77970fbc020131f250fe --- /dev/null +++ b/llm/Lib/site-packages/altair/utils/html.py @@ -0,0 +1,303 @@ +import json +from typing import Optional, Dict + +import jinja2 + +from altair.utils._importers import import_vl_convert, vl_version_for_vl_convert + +HTML_TEMPLATE = jinja2.Template( + """ +{%- if fullhtml -%} + + + +{%- endif %} + +{%- if not requirejs %} + + {%- if mode == 'vega-lite' %} + + {%- endif %} + +{%- endif %} +{%- if fullhtml %} +{%- if requirejs %} + + +{%- endif %} + + +{%- endif %} +
    + +{%- if fullhtml %} + + +{%- endif %} +""" +) + + +HTML_TEMPLATE_UNIVERSAL = jinja2.Template( + """ + +
    + +""" +) + + +# This is like the HTML_TEMPLATE template, but includes vega javascript inline +# so that the resulting file is not dependent on external resources. This was +# ported over from altair_saver. +# +# implies requirejs=False and full_html=True +INLINE_HTML_TEMPLATE = jinja2.Template( + """\ + + + + + + + +
    + + + +""" +) + + +TEMPLATES: Dict[str, jinja2.Template] = { + "standard": HTML_TEMPLATE, + "universal": HTML_TEMPLATE_UNIVERSAL, + "inline": INLINE_HTML_TEMPLATE, +} + + +def spec_to_html( + spec: dict, + mode: str, + vega_version: Optional[str], + vegaembed_version: Optional[str], + vegalite_version: Optional[str] = None, + base_url: str = "https://cdn.jsdelivr.net/npm", + output_div: str = "vis", + embed_options: Optional[dict] = None, + json_kwds: Optional[dict] = None, + fullhtml: bool = True, + requirejs: bool = False, + template: str = "standard", +) -> str: + """Embed a Vega/Vega-Lite spec into an HTML page + + Parameters + ---------- + spec : dict + a dictionary representing a vega-lite plot spec. + mode : string {'vega' | 'vega-lite'} + The rendering mode. This value is overridden by embed_options['mode'], + if it is present. + vega_version : string + For html output, the version of vega.js to use. + vegalite_version : string + For html output, the version of vegalite.js to use. + vegaembed_version : string + For html output, the version of vegaembed.js to use. + base_url : string (optional) + The base url from which to load the javascript libraries. + output_div : string (optional) + The id of the div element where the plot will be shown. + embed_options : dict (optional) + Dictionary of options to pass to the vega-embed script. Default + entry is {'mode': mode}. + json_kwds : dict (optional) + Dictionary of keywords to pass to json.dumps(). + fullhtml : boolean (optional) + If True (default) then return a full html page. If False, then return + an HTML snippet that can be embedded into an HTML page. + requirejs : boolean (optional) + If False (default) then load libraries from base_url using