ce304fafe19161978ad512b385c65426bad519e5a0b8fb3f0659eace3d2ea3cc
Browse files- lib/python3.11/site-packages/mlx-0.0.7.dist-info/INSTALLER +1 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/LICENSE +21 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/METADATA +122 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/RECORD +199 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/REQUESTED +0 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/WHEEL +5 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/top_level.txt +1 -0
- lib/python3.11/site-packages/mlx/nn/layers/base.py +532 -0
- lib/python3.11/site-packages/mlx/nn/layers/containers.py +24 -0
- lib/python3.11/site-packages/mlx/nn/layers/convolution.py +126 -0
- lib/python3.11/site-packages/mlx/nn/layers/dropout.py +137 -0
- lib/python3.11/site-packages/mlx/nn/layers/embedding.py +30 -0
- lib/python3.11/site-packages/mlx/nn/layers/linear.py +141 -0
- lib/python3.11/site-packages/mlx/nn/layers/normalization.py +368 -0
- lib/python3.11/site-packages/mlx/nn/layers/positional_encoding.py +199 -0
- lib/python3.11/site-packages/mlx/nn/layers/quantized.py +125 -0
- lib/python3.11/site-packages/mlx/nn/layers/transformer.py +354 -0
- lib/python3.11/site-packages/mlx/nn/losses.py +374 -0
- lib/python3.11/site-packages/mlx/nn/utils.py +33 -0
- lib/python3.11/site-packages/mlx/optimizers.py +500 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfig.cmake +57 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets.cmake +107 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/extension.cmake +56 -0
- lib/python3.11/site-packages/mlx/utils.py +145 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/INSTALLER +1 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/LICENSE +19 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/METADATA +253 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/RECORD +16 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/REQUESTED +0 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/WHEEL +4 -0
- lib/python3.11/site-packages/more_itertools/__init__.py +6 -0
- lib/python3.11/site-packages/more_itertools/__init__.pyi +2 -0
- lib/python3.11/site-packages/more_itertools/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/more_itertools/__pycache__/more.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/more_itertools/__pycache__/recipes.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/more_itertools/more.py +0 -0
- lib/python3.11/site-packages/more_itertools/more.pyi +684 -0
- lib/python3.11/site-packages/more_itertools/py.typed +0 -0
- lib/python3.11/site-packages/more_itertools/recipes.py +977 -0
- lib/python3.11/site-packages/more_itertools/recipes.pyi +122 -0
- lib/python3.11/site-packages/mpmath/__init__.py +468 -0
- lib/python3.11/site-packages/mpmath/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_base.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_fp.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_iv.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp_python.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc +0 -0
    	
        lib/python3.11/site-packages/mlx-0.0.7.dist-info/INSTALLER
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            pip
         | 
    	
        lib/python3.11/site-packages/mlx-0.0.7.dist-info/LICENSE
    ADDED
    
    | @@ -0,0 +1,21 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            MIT License
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Copyright © 2023 Apple Inc.
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            Permission is hereby granted, free of charge, to any person obtaining a copy
         | 
| 6 | 
            +
            of this software and associated documentation files (the "Software"), to deal
         | 
| 7 | 
            +
            in the Software without restriction, including without limitation the rights
         | 
| 8 | 
            +
            to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
         | 
| 9 | 
            +
            copies of the Software, and to permit persons to whom the Software is
         | 
| 10 | 
            +
            furnished to do so, subject to the following conditions:
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            The above copyright notice and this permission notice shall be included in all
         | 
| 13 | 
            +
            copies or substantial portions of the Software.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         | 
| 16 | 
            +
            IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         | 
| 17 | 
            +
            FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         | 
| 18 | 
            +
            AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         | 
| 19 | 
            +
            LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         | 
| 20 | 
            +
            OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         | 
| 21 | 
            +
            SOFTWARE.
         | 
    	
        lib/python3.11/site-packages/mlx-0.0.7.dist-info/METADATA
    ADDED
    
    | @@ -0,0 +1,122 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Metadata-Version: 2.1
         | 
| 2 | 
            +
            Name: mlx
         | 
| 3 | 
            +
            Version: 0.0.7
         | 
| 4 | 
            +
            Summary: A framework for machine learning on Apple silicon.
         | 
| 5 | 
            +
            Author: MLX Contributors
         | 
| 6 | 
            +
            Author-email: [email protected]
         | 
| 7 | 
            +
            Requires-Python: >=3.8
         | 
| 8 | 
            +
            Description-Content-Type: text/markdown
         | 
| 9 | 
            +
            License-File: LICENSE
         | 
| 10 | 
            +
            Provides-Extra: dev
         | 
| 11 | 
            +
            Requires-Dist: pre-commit ; extra == 'dev'
         | 
| 12 | 
            +
            Requires-Dist: pybind11-stubgen ; extra == 'dev'
         | 
| 13 | 
            +
            Provides-Extra: testing
         | 
| 14 | 
            +
            Requires-Dist: numpy ; extra == 'testing'
         | 
| 15 | 
            +
            Requires-Dist: torch ; extra == 'testing'
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            # MLX
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            [**Quickstart**](#quickstart) | [**Installation**](#installation) |
         | 
| 20 | 
            +
            [**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
         | 
| 21 | 
            +
            [**Examples**](#examples) 
         | 
| 22 | 
            +
             | 
| 23 | 
            +
            [](https://circleci.com/gh/ml-explore/mlx)
         | 
| 24 | 
            +
             | 
| 25 | 
            +
            MLX is an array framework for machine learning on Apple silicon, brought to you
         | 
| 26 | 
            +
            by Apple machine learning research.
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            Some key features of MLX include:
         | 
| 29 | 
            +
             | 
| 30 | 
            +
             - **Familiar APIs**: MLX has a Python API that closely follows NumPy.
         | 
| 31 | 
            +
               MLX also has a fully featured C++ API, which closely mirrors the Python API. 
         | 
| 32 | 
            +
               MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
         | 
| 33 | 
            +
               that closely follow PyTorch to simplify building more complex models.
         | 
| 34 | 
            +
             | 
| 35 | 
            +
             - **Composable function transformations**: MLX supports composable function
         | 
| 36 | 
            +
               transformations for automatic differentiation, automatic vectorization,
         | 
| 37 | 
            +
               and computation graph optimization.
         | 
| 38 | 
            +
             | 
| 39 | 
            +
             - **Lazy computation**: Computations in MLX are lazy. Arrays are only
         | 
| 40 | 
            +
               materialized when needed.
         | 
| 41 | 
            +
             | 
| 42 | 
            +
             - **Dynamic graph construction**: Computation graphs in MLX are constructed
         | 
| 43 | 
            +
               dynamically. Changing the shapes of function arguments does not trigger
         | 
| 44 | 
            +
               slow compilations, and debugging is simple and intuitive.
         | 
| 45 | 
            +
             | 
| 46 | 
            +
             - **Multi-device**: Operations can run on any of the supported devices
         | 
| 47 | 
            +
               (currently the CPU and the GPU).
         | 
| 48 | 
            +
             | 
| 49 | 
            +
             - **Unified memory**: A notable difference from MLX and other frameworks
         | 
| 50 | 
            +
               is the *unified memory model*. Arrays in MLX live in shared memory.
         | 
| 51 | 
            +
               Operations on MLX arrays can be performed on any of the supported
         | 
| 52 | 
            +
               device types without transferring data.
         | 
| 53 | 
            +
             | 
| 54 | 
            +
            MLX is designed by machine learning researchers for machine learning
         | 
| 55 | 
            +
            researchers. The framework is intended to be user-friendly, but still efficient
         | 
| 56 | 
            +
            to train and deploy models. The design of the framework itself is also
         | 
| 57 | 
            +
            conceptually simple. We intend to make it easy for researchers to extend and
         | 
| 58 | 
            +
            improve MLX with the goal of quickly exploring new ideas. 
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            The design of MLX is inspired by frameworks like
         | 
| 61 | 
            +
            [NumPy](https://numpy.org/doc/stable/index.html),
         | 
| 62 | 
            +
            [PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), and
         | 
| 63 | 
            +
            [ArrayFire](https://arrayfire.org/).
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            ## Examples
         | 
| 66 | 
            +
             | 
| 67 | 
            +
            The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
         | 
| 68 | 
            +
            variety of examples, including:
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            - [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
         | 
| 71 | 
            +
            - Large-scale text generation with
         | 
| 72 | 
            +
              [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
         | 
| 73 | 
            +
              finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
         | 
| 74 | 
            +
            - Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
         | 
| 75 | 
            +
            - Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
         | 
| 76 | 
            +
             | 
| 77 | 
            +
            ## Quickstart
         | 
| 78 | 
            +
             | 
| 79 | 
            +
            See the [quick start
         | 
| 80 | 
            +
            guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
         | 
| 81 | 
            +
            in the documentation.
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            ## Installation
         | 
| 84 | 
            +
             | 
| 85 | 
            +
            MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
         | 
| 86 | 
            +
             | 
| 87 | 
            +
            ```
         | 
| 88 | 
            +
            pip install mlx
         | 
| 89 | 
            +
            ```
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            Checkout the
         | 
| 92 | 
            +
            [documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
         | 
| 93 | 
            +
            for more information on building the C++ and Python APIs from source.
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            ## Contributing 
         | 
| 96 | 
            +
             | 
| 97 | 
            +
            Check out the [contribution guidelines](CONTRIBUTING.md) for more information
         | 
| 98 | 
            +
            on contributing to MLX. See the
         | 
| 99 | 
            +
            [docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
         | 
| 100 | 
            +
            information on building from source, and running tests.
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            We are grateful for all of [our
         | 
| 103 | 
            +
            contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
         | 
| 104 | 
            +
            to MLX and wish to be acknowledged, please add your name to the list in your
         | 
| 105 | 
            +
            pull request.
         | 
| 106 | 
            +
             | 
| 107 | 
            +
            ## Citing MLX
         | 
| 108 | 
            +
             | 
| 109 | 
            +
            The MLX software suite was initially developed with equal contribution by Awni
         | 
| 110 | 
            +
            Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
         | 
| 111 | 
            +
            MLX useful in your research and wish to cite it, please use the following
         | 
| 112 | 
            +
            BibTex entry:
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            ```
         | 
| 115 | 
            +
            @software{mlx2023,
         | 
| 116 | 
            +
              author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
         | 
| 117 | 
            +
              title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
         | 
| 118 | 
            +
              url = {https://github.com/ml-explore},
         | 
| 119 | 
            +
              version = {0.0},
         | 
| 120 | 
            +
              year = {2023},
         | 
| 121 | 
            +
            }
         | 
| 122 | 
            +
            ```
         | 
    	
        lib/python3.11/site-packages/mlx-0.0.7.dist-info/RECORD
    ADDED
    
    | @@ -0,0 +1,199 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            mlx-0.0.7.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
         | 
| 2 | 
            +
            mlx-0.0.7.dist-info/LICENSE,sha256=zPq3zLLqMG9xUxyMp3u1VQdgbNkHaLHjK4tSq1tIzwE,1066
         | 
| 3 | 
            +
            mlx-0.0.7.dist-info/METADATA,sha256=CK9nqz-nrSgrFZpRYpc86IYysL_W7zObqiiczNHD5OA,4731
         | 
| 4 | 
            +
            mlx-0.0.7.dist-info/RECORD,,
         | 
| 5 | 
            +
            mlx-0.0.7.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 6 | 
            +
            mlx-0.0.7.dist-info/WHEEL,sha256=9CmlYCCxe9UEGaJDl4r_tPKbyfF_OdzBb5_jMO3h1CA,110
         | 
| 7 | 
            +
            mlx-0.0.7.dist-info/top_level.txt,sha256=NCVZzyms-obPlDq2Sa--m3oyfnnmLma38KCnE6gmcFE,4
         | 
| 8 | 
            +
            mlx/__pycache__/_reprlib_fix.cpython-311.pyc,,
         | 
| 9 | 
            +
            mlx/__pycache__/extension.cpython-311.pyc,,
         | 
| 10 | 
            +
            mlx/__pycache__/optimizers.cpython-311.pyc,,
         | 
| 11 | 
            +
            mlx/__pycache__/utils.cpython-311.pyc,,
         | 
| 12 | 
            +
            mlx/_reprlib_fix.py,sha256=YoRQib4PSzl1FIMvs4BwkvipAwOrfufqtV89vNF4Yhc,518
         | 
| 13 | 
            +
            mlx/core.cpython-311-darwin.so,sha256=4GMiWnLS2sZGsrIwB9mAryue5Lztf-SttAn_AzadmCc,1098840
         | 
| 14 | 
            +
            mlx/extension.py,sha256=tqKrGeP1ElK9UfBMHUeOq56qWnMMxZrt7DVVM8QZ7yU,3769
         | 
| 15 | 
            +
            mlx/include/metal_cpp/Foundation/Foundation.hpp,sha256=pYbEkzTrb2O0jJPzI0m9cBNvVSDTh_t39Dv7h0ZDma0,1808
         | 
| 16 | 
            +
            mlx/include/metal_cpp/Foundation/NSArray.hpp,sha256=8nhS0gF060n7taK9aIynkuMPxlbKgqgijsj61yvhrPQ,4818
         | 
| 17 | 
            +
            mlx/include/metal_cpp/Foundation/NSAutoreleasePool.hpp,sha256=K8_qNgYgyeP49WUSsz90Lo41qTb-x5Xjs23T9UNSkkU,3302
         | 
| 18 | 
            +
            mlx/include/metal_cpp/Foundation/NSBundle.hpp,sha256=y0Le2v43bCAdIdrSClvLpRtCL8c4qdB3WK59Gz_AYYs,16845
         | 
| 19 | 
            +
            mlx/include/metal_cpp/Foundation/NSData.hpp,sha256=rAPEZQn1W05O_sZ6LYexLhvv1CmllqyFWddGIrxGJ8s,2198
         | 
| 20 | 
            +
            mlx/include/metal_cpp/Foundation/NSDate.hpp,sha256=353UgslEWyV-D1_T7_Sr1fCP4nVjixwJK2QGnRZQKgg,2085
         | 
| 21 | 
            +
            mlx/include/metal_cpp/Foundation/NSDefines.hpp,sha256=c8BfIDwW9yCKiofyatyeLJ9YvnHO87PsBqSQtciN3tY,2203
         | 
| 22 | 
            +
            mlx/include/metal_cpp/Foundation/NSDictionary.hpp,sha256=Lt7EBR6ppPFhhnoZq9UqMETHf_L1k9A0gs637ZTEQSU,5738
         | 
| 23 | 
            +
            mlx/include/metal_cpp/Foundation/NSEnumerator.hpp,sha256=P4mkVmpWN0kdXnqqTMkyIovckr4_IZAoFzyg2IZDRs4,3141
         | 
| 24 | 
            +
            mlx/include/metal_cpp/Foundation/NSError.hpp,sha256=rFqCSIXVsWOsH7Ikegv2irMF3aoMn7fM6teB9VipMgc,7802
         | 
| 25 | 
            +
            mlx/include/metal_cpp/Foundation/NSLock.hpp,sha256=OjnoMx18m4E89eDPliPMw3cG8LtAqNsoBSxSO0l9DoY,4350
         | 
| 26 | 
            +
            mlx/include/metal_cpp/Foundation/NSNotification.hpp,sha256=rpAgI4upmIG2inWZ9ACYy-Q8za7muZ2WmUtA5l8WGHw,4696
         | 
| 27 | 
            +
            mlx/include/metal_cpp/Foundation/NSNumber.hpp,sha256=4a2fEU1Fm_wrWOs4zULhYQne0WymnRGsDeSGq1bDbuQ,22098
         | 
| 28 | 
            +
            mlx/include/metal_cpp/Foundation/NSObjCRuntime.hpp,sha256=SHYqTLWQ2F8T4dfXJ98OlATOexXQ3NnXzmFyVpb4wFM,1665
         | 
| 29 | 
            +
            mlx/include/metal_cpp/Foundation/NSObject.hpp,sha256=T_64vByLI4n3HIE05MVJZteUIAdN-y0oogsczYT-JTE,11105
         | 
| 30 | 
            +
            mlx/include/metal_cpp/Foundation/NSPrivate.hpp,sha256=JYewPlM0BKHLnmQiItPlYuc1HNAO-akJjhXnEKUZO00,20217
         | 
| 31 | 
            +
            mlx/include/metal_cpp/Foundation/NSProcessInfo.hpp,sha256=1NLI7-3mR8cseE_jE6c-D2_pRJKXA9fRl4V2uRZtiDI,16642
         | 
| 32 | 
            +
            mlx/include/metal_cpp/Foundation/NSRange.hpp,sha256=7VFC565QzvzfF1CI91osTozl-d1mM9UToCcGyXmf5cc,3147
         | 
| 33 | 
            +
            mlx/include/metal_cpp/Foundation/NSSet.hpp,sha256=0rORF6rpK2-R1AQCgT4CMNnGr18xZ8P5x82LEQLQ2sc,3368
         | 
| 34 | 
            +
            mlx/include/metal_cpp/Foundation/NSSharedPtr.hpp,sha256=208e_0hJddLllsT16pJa4rQAyFqA1uYuez5gWyRS6EU,8496
         | 
| 35 | 
            +
            mlx/include/metal_cpp/Foundation/NSString.hpp,sha256=AqRvSkbHPncbC36TJuR-4beEVT7ZPwnrVjmyoassaL8,11001
         | 
| 36 | 
            +
            mlx/include/metal_cpp/Foundation/NSTypes.hpp,sha256=9IaLleaqYbdjlMvT_mqFlYDlzg4vF_8MQQTO2Sdp6vo,1893
         | 
| 37 | 
            +
            mlx/include/metal_cpp/Foundation/NSURL.hpp,sha256=b1oGmKqL5F1WrWSaCK3Iqs0vxL8utVbOgtYW4RZUe_I,3668
         | 
| 38 | 
            +
            mlx/include/metal_cpp/LICENSE.txt,sha256=mgnj5ca95epXBMNZliB1O16xkRb5_WScN7W8RJeiXOo,11344
         | 
| 39 | 
            +
            mlx/include/metal_cpp/Metal/MTLAccelerationStructure.hpp,sha256=jqK65WOzpPUNE6H-XYb1PC3myk6vBIQr4fRrClCDZYs,85372
         | 
| 40 | 
            +
            mlx/include/metal_cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp,sha256=EC1flOCenKN-CNLIbHcpli7kP1PQGxkGUU8ZdMYQ4Qs,16289
         | 
| 41 | 
            +
            mlx/include/metal_cpp/Metal/MTLAccelerationStructureTypes.hpp,sha256=I29afo7zt4fiCY2ip3vPN72lhb6Whgjpx0lcHabfib4,5800
         | 
| 42 | 
            +
            mlx/include/metal_cpp/Metal/MTLArgument.hpp,sha256=_gQiklV3kRM2ukcYg0JCz3rUsK-zUOiZMYrp_thHjXU,23515
         | 
| 43 | 
            +
            mlx/include/metal_cpp/Metal/MTLArgumentEncoder.hpp,sha256=PaV9gzYeC1JLqDdcAJK9H6DaPVV2rclNWvNPK70KUWs,10739
         | 
| 44 | 
            +
            mlx/include/metal_cpp/Metal/MTLBinaryArchive.hpp,sha256=I8a_7I2WI-ftil5s4bT8N5rnUeRUt2HtsM2O9LW_t2M,5132
         | 
| 45 | 
            +
            mlx/include/metal_cpp/Metal/MTLBlitCommandEncoder.hpp,sha256=6QF2GR7c4XQXfA5COwCb2Q6xRXVMkYnO1gmUlrNw05o,16482
         | 
| 46 | 
            +
            mlx/include/metal_cpp/Metal/MTLBlitPass.hpp,sha256=zPt4836TDUuDPyRCyF7tJBEtEMpoWziJLlSRgWleURs,6982
         | 
| 47 | 
            +
            mlx/include/metal_cpp/Metal/MTLBuffer.hpp,sha256=IDVEI2qXwq6rU_tQsPkOpma-IV7othYmJIT2QCOBsic,3542
         | 
| 48 | 
            +
            mlx/include/metal_cpp/Metal/MTLCaptureManager.hpp,sha256=ANj6r5G9Gc1l9WikJQmZZlOWvfViJclzyPwKF59UvUg,7524
         | 
| 49 | 
            +
            mlx/include/metal_cpp/Metal/MTLCaptureScope.hpp,sha256=DU5_HrnsaHBttsqHfXCk-oPr8ifhMdP-DiYpS1G-qC0,3733
         | 
| 50 | 
            +
            mlx/include/metal_cpp/Metal/MTLCommandBuffer.hpp,sha256=29LVLOZlzB8G0bVfnHm867TXc0FPqbM7eXbNrZ2zRzE,17642
         | 
| 51 | 
            +
            mlx/include/metal_cpp/Metal/MTLCommandEncoder.hpp,sha256=IQGAhZMikxwU1IINtvw3MSTLwlbQKRWuTNZnZAw_0m8,2946
         | 
| 52 | 
            +
            mlx/include/metal_cpp/Metal/MTLCommandQueue.hpp,sha256=Uz74PTJAhqMi8XgLrj1RZOio-FYxLnTO9Msdy1dE7RA,2960
         | 
| 53 | 
            +
            mlx/include/metal_cpp/Metal/MTLComputeCommandEncoder.hpp,sha256=Z2E_Ny8OJj4gvC_YZtku5sni2t4o9Mbm7xvJk2tpgHc,17157
         | 
| 54 | 
            +
            mlx/include/metal_cpp/Metal/MTLComputePass.hpp,sha256=VOyFJlwaB89gmrtLdmnAz7gl97a5HbTYo-M6icWbN5M,7792
         | 
| 55 | 
            +
            mlx/include/metal_cpp/Metal/MTLComputePipeline.hpp,sha256=WTBUdVq1sxvHocezHO5wpVmaa94Xnthq27kSUCedcMY,15017
         | 
| 56 | 
            +
            mlx/include/metal_cpp/Metal/MTLCounters.hpp,sha256=RmlZCNFZIVPC4ue6FoIgYVq1Weq2Nu1GQD-G0yIwimg,9056
         | 
| 57 | 
            +
            mlx/include/metal_cpp/Metal/MTLDefines.hpp,sha256=1MG6A6zwLk99Yknu7PAvBvIMefVhumr1kFoKxQ-6xKk,1900
         | 
| 58 | 
            +
            mlx/include/metal_cpp/Metal/MTLDepthStencil.hpp,sha256=nP5OtW6QGUK7tzrXtJWJlch9bXkDYPYgp5BjMAptH6E,9631
         | 
| 59 | 
            +
            mlx/include/metal_cpp/Metal/MTLDevice.hpp,sha256=HyHg83T53RuLU1y_ZthBvEIzpOi1Lu0nXptaJfdrhwc,63880
         | 
| 60 | 
            +
            mlx/include/metal_cpp/Metal/MTLDrawable.hpp,sha256=07X3GSg7uJEWoOMWxI77cIWdr2NaBfJEXOtPJZr4M44,3193
         | 
| 61 | 
            +
            mlx/include/metal_cpp/Metal/MTLDynamicLibrary.hpp,sha256=yBUYfobLLHd00UOGALdsEF0FBiF1hUDr7G3FdKy6uMY,2606
         | 
| 62 | 
            +
            mlx/include/metal_cpp/Metal/MTLEvent.hpp,sha256=W8QBQWIUn-DcuBYzg0NB0L1W_Zs6G6QuQCeKUKM7tIo,4976
         | 
| 63 | 
            +
            mlx/include/metal_cpp/Metal/MTLFence.hpp,sha256=DuxPaIkQN9fWi1E6RHg8YrKNeB2nxGOZMuHXd8TiE9c,1723
         | 
| 64 | 
            +
            mlx/include/metal_cpp/Metal/MTLFunctionConstantValues.hpp,sha256=dz9s7_TYqjTJbthwEJTDoqdXsZ34JE2k4BERCYzWqWQ,3088
         | 
| 65 | 
            +
            mlx/include/metal_cpp/Metal/MTLFunctionDescriptor.hpp,sha256=JkzBrxhlg4_ezt2NLomw-MsGuBqdV9embOekJow-mks,5410
         | 
| 66 | 
            +
            mlx/include/metal_cpp/Metal/MTLFunctionHandle.hpp,sha256=HykCHezEwao-pmH6eL57qOyNt_G5RYsrsofxFDyHjiQ,1842
         | 
| 67 | 
            +
            mlx/include/metal_cpp/Metal/MTLFunctionLog.hpp,sha256=ZiIPh0Ef5oeiYIWsFZJGnOXZj2bejYFDe_3NVoRT2nY,3262
         | 
| 68 | 
            +
            mlx/include/metal_cpp/Metal/MTLFunctionStitching.hpp,sha256=fEipcQvZ2d8GeyPyn0AiyMup4doS9Yf-uzZRFpjkfyI,11407
         | 
| 69 | 
            +
            mlx/include/metal_cpp/Metal/MTLHeaderBridge.hpp,sha256=qYNlYMWNAx4z4iNnMyrqdQr_NJ5md2n1HMSxmP5ykRM,105217
         | 
| 70 | 
            +
            mlx/include/metal_cpp/Metal/MTLHeap.hpp,sha256=O9zMAFEKDGKAt710IDSwdsW3mYSyM7Mt2IVvJNJnAq0,11624
         | 
| 71 | 
            +
            mlx/include/metal_cpp/Metal/MTLIOCommandBuffer.hpp,sha256=OweZPLAmFoJFypxGIWisfb6Bzlvgv5qIYB2MVu6jBLw,7417
         | 
| 72 | 
            +
            mlx/include/metal_cpp/Metal/MTLIOCommandQueue.hpp,sha256=Al9BpGDKaSu5zQblMphnr9EXvrVee2ZaZWMba6L7QSs,7465
         | 
| 73 | 
            +
            mlx/include/metal_cpp/Metal/MTLIOCompressor.hpp,sha256=s7UjSrJLyfQt69RzWBY692G1audMRzjIVR-7GjkZFOQ,3043
         | 
| 74 | 
            +
            mlx/include/metal_cpp/Metal/MTLIndirectCommandBuffer.hpp,sha256=J3H4HH0p8qq_DYU8pQC_w4qDmYtNqj6FzoOo_AVGfxA,9715
         | 
| 75 | 
            +
            mlx/include/metal_cpp/Metal/MTLIndirectCommandEncoder.hpp,sha256=AW1Qzyggw9k9EqS71CvlFUHu4lag6nvvzK-LnV4oG0k,11239
         | 
| 76 | 
            +
            mlx/include/metal_cpp/Metal/MTLIntersectionFunctionTable.hpp,sha256=NUP7rHPWulH7M2M1wvzlyjjF7S0BJvG-_XFn4nWehcg,8181
         | 
| 77 | 
            +
            mlx/include/metal_cpp/Metal/MTLLibrary.hpp,sha256=le2P__N8XA4RCN62L2ZB-F9uo2p8BjQ-A6j0kEOqzws,23904
         | 
| 78 | 
            +
            mlx/include/metal_cpp/Metal/MTLLinkedFunctions.hpp,sha256=FmUoNjOGSegryBa0PDa1qcxEF0RUESixbd7OPyU-qDM,3867
         | 
| 79 | 
            +
            mlx/include/metal_cpp/Metal/MTLParallelRenderCommandEncoder.hpp,sha256=uHyC4bf8BjFUz6xEkFg79Lqzz9CV-0NNnNoQ-U3xfsA,3899
         | 
| 80 | 
            +
            mlx/include/metal_cpp/Metal/MTLPipeline.hpp,sha256=1EaeWaT8YNfvPaX0gqktNMOLtojBF6cVMlZiFT7vjyg,3775
         | 
| 81 | 
            +
            mlx/include/metal_cpp/Metal/MTLPixelFormat.hpp,sha256=DYqej_xuztbD0lhsioWTYWO2Sg7jRShyJ7izoVlex5Y,5910
         | 
| 82 | 
            +
            mlx/include/metal_cpp/Metal/MTLPrivate.hpp,sha256=vPTld95A83b1nIVT4qwsdcrttngMK5L5Ow-0HTr-zm4,6415
         | 
| 83 | 
            +
            mlx/include/metal_cpp/Metal/MTLRasterizationRate.hpp,sha256=LtcYTnjLSTaxTw98WdUg_9R9BaAozvXOTHcsnmSqbTo,15447
         | 
| 84 | 
            +
            mlx/include/metal_cpp/Metal/MTLRenderCommandEncoder.hpp,sha256=JNisXL9uSMMm_5B2vkVfuI8dgrMMMIrC8-E7P0pjdeo,60813
         | 
| 85 | 
            +
            mlx/include/metal_cpp/Metal/MTLRenderPass.hpp,sha256=YQCYdLOnOWsUgIbQimBDFo4l_FmbZfeHIJFADL3NnRQ,32752
         | 
| 86 | 
            +
            mlx/include/metal_cpp/Metal/MTLRenderPipeline.hpp,sha256=TuBYsiWg4hpumgaKccAn8TVAeLw26l_FdVGa6JsaMLA,72929
         | 
| 87 | 
            +
            mlx/include/metal_cpp/Metal/MTLResource.hpp,sha256=Be28y-6-9GvFuugCUhnykiTUqbMqYgVovDjUdpIrFDA,5244
         | 
| 88 | 
            +
            mlx/include/metal_cpp/Metal/MTLResourceStateCommandEncoder.hpp,sha256=eOjzuG-vDtQLJ8QOzhiqxXxiMEShXyn9cn-Uhmtigr8,5330
         | 
| 89 | 
            +
            mlx/include/metal_cpp/Metal/MTLResourceStatePass.hpp,sha256=N5sMbFCjfEdf10_3VZJfQvDl3nLuydqrKnUgiJzynRU,7585
         | 
| 90 | 
            +
            mlx/include/metal_cpp/Metal/MTLSampler.hpp,sha256=0rifEvFw4V8_kMmusAU4NXiZeZbCSX_L4x0K6QyEkog,10838
         | 
| 91 | 
            +
            mlx/include/metal_cpp/Metal/MTLStageInputOutputDescriptor.hpp,sha256=LfDoa6Dj5WQFS2Qg_FQRCkVgBpbbfjBdFtn1SIzsCyI,13120
         | 
| 92 | 
            +
            mlx/include/metal_cpp/Metal/MTLTexture.hpp,sha256=jM-mSxqfGBNXE6xwzx1M7CC8GOiyungRuop-6uTuf-M,24807
         | 
| 93 | 
            +
            mlx/include/metal_cpp/Metal/MTLTypes.hpp,sha256=9_le5gsR93J2awZRAk_Rep_yAWKShAQOgI6faZKD_tg,4389
         | 
| 94 | 
            +
            mlx/include/metal_cpp/Metal/MTLVersion.hpp,sha256=iaG4OsKd3JlcsoLEvyAx_5TdCRup1ykH0FUYqbu6Lxo,1509
         | 
| 95 | 
            +
            mlx/include/metal_cpp/Metal/MTLVertexDescriptor.hpp,sha256=4oRbxIiJV2eafs-SqTXTVgPttIC22KtTZ0t4AXnLlgc,12017
         | 
| 96 | 
            +
            mlx/include/metal_cpp/Metal/MTLVisibleFunctionTable.hpp,sha256=P5Cv_bIIfJ4ktplCeZ3fAhHtLEXngF5cAlQgAr55c-0,3833
         | 
| 97 | 
            +
            mlx/include/metal_cpp/Metal/Metal.hpp,sha256=Bm5ldz-FxywsGmAdjZ8U2X5hlYjwfzHUjKCoW38T-w8,3190
         | 
| 98 | 
            +
            mlx/include/metal_cpp/MetalFX/MTLFXDefines.hpp,sha256=Ms80cxZbVVAFYDogYahvmzyVZqkDezrwXKXB_79m3KE,2104
         | 
| 99 | 
            +
            mlx/include/metal_cpp/MetalFX/MTLFXPrivate.hpp,sha256=LcubGmMx98T8OduyEawEx8swCEo_n5Z3ZfwuxwvAtz0,15147
         | 
| 100 | 
            +
            mlx/include/metal_cpp/MetalFX/MTLFXSpatialScaler.hpp,sha256=8QCySJ9V5JfuTNAHujkvgKQJ_hdM-yV8cqdw5v8l1Ns,18385
         | 
| 101 | 
            +
            mlx/include/metal_cpp/MetalFX/MTLFXTemporalScaler.hpp,sha256=57Qfhq3YqYOd0qwLfOBvz_viu55-Pmor0_IF99UWMAg,32887
         | 
| 102 | 
            +
            mlx/include/metal_cpp/MetalFX/MetalFX.hpp,sha256=Vm0W_ycCRb0UOOvLIPYm5gIEHW1nikG5vTvoKfTgAtE,1350
         | 
| 103 | 
            +
            mlx/include/metal_cpp/QuartzCore/CADefines.hpp,sha256=q6k-jUxe5uulNnsoqkM61OuyEVSJO4W8gIAVM6mSbwE,1895
         | 
| 104 | 
            +
            mlx/include/metal_cpp/QuartzCore/CAMetalDrawable.hpp,sha256=RE4_Cp9FacJMh7AlqZw-IhelLMXvUv_YpXVLQ2aNVrU,2363
         | 
| 105 | 
            +
            mlx/include/metal_cpp/QuartzCore/CAMetalLayer.hpp,sha256=i4WjUgWUsQVRPj8dxk4aBC06TKaXsKxRdgNNIgNe0pk,5502
         | 
| 106 | 
            +
            mlx/include/metal_cpp/QuartzCore/CAPrivate.hpp,sha256=aFEVW34Ih_Gs1aGv5Rml0pNaFxDwWLP6qBQWw3N6EzA,5010
         | 
| 107 | 
            +
            mlx/include/metal_cpp/QuartzCore/QuartzCore.hpp,sha256=-fQqOpndpszueKWJ4U_jovMID_s0QxCOJHpQoUj8JBE,1346
         | 
| 108 | 
            +
            mlx/include/metal_cpp/README.md,sha256=ZVbXv3dwSy-n_EmPpYlB2cJXkUCZXdE0-jJsY9pHsd0,14584
         | 
| 109 | 
            +
            mlx/include/metal_cpp/SingleHeader/MakeSingleHeader.py,sha256=JqlbeaangDoU2dRPyOfKgU5jkbmUghucuumqPEbZOWc,8984
         | 
| 110 | 
            +
            mlx/include/metal_cpp/SingleHeader/__pycache__/MakeSingleHeader.cpython-311.pyc,,
         | 
| 111 | 
            +
            mlx/include/mlx/3rdparty/pocketfft.h,sha256=Dq6iEwS_MkY5_xECLuY_r0io-rZK_lw8LR7---HX36Y,110508
         | 
| 112 | 
            +
            mlx/include/mlx/allocator.h,sha256=r0AgIuBbvq67OeO8eUD8cb4a72HDF6-PISAnOMfC-kM,1519
         | 
| 113 | 
            +
            mlx/include/mlx/array.h,sha256=D_dkEyq5zbBSc8kK5KNgQbAikvihW9Xn_1HlwI9uaps,11348
         | 
| 114 | 
            +
            mlx/include/mlx/backend/accelerate/utils.h,sha256=ryQ9t4EI6UkMBCE48MVcGUSstbRDlwM9YI__vs4dTzY,752
         | 
| 115 | 
            +
            mlx/include/mlx/backend/common/arange.h,sha256=J44RZZx_DB7DSELCtDFHBThHfLopu_JWj_5ew_q0ROw,1823
         | 
| 116 | 
            +
            mlx/include/mlx/backend/common/binary.h,sha256=uvS_NdciYF-6VHhKrDNFL1Sd2Ag8hn_kClpYmcu6E9E,15169
         | 
| 117 | 
            +
            mlx/include/mlx/backend/common/copy.h,sha256=5-b2CRTRoyCJ4WL5RkPAQXaLKaEn8pz-XojHGs4u1Rc,690
         | 
| 118 | 
            +
            mlx/include/mlx/backend/common/erf.h,sha256=Wm6jXR40EgAmyh0zbGOvRbZh58LVhe3_P1-hwN9GPuo,274
         | 
| 119 | 
            +
            mlx/include/mlx/backend/common/reduce.h,sha256=iR0xKDMZSEDrf2W0TLLXuqxVesf-nOioGu0MDGalT-0,11138
         | 
| 120 | 
            +
            mlx/include/mlx/backend/common/threefry.h,sha256=sXTdUFRAgEXf789u7IhWIj7n-Gtzwo4HiQXDA6BwFyU,572
         | 
| 121 | 
            +
            mlx/include/mlx/backend/common/unary.h,sha256=5vAiPfmNjXxMohP6RqZiOngCXcXa76RtrLDEyTeQobs,3356
         | 
| 122 | 
            +
            mlx/include/mlx/backend/common/utils.h,sha256=4pS1LCC-Tux3x4Y0A3LeTyTLGDSxbjapNkNFWUNjrc8,610
         | 
| 123 | 
            +
            mlx/include/mlx/backend/metal/allocator.h,sha256=lluCLBSXYbp6WGtl9zKsikpMHrptRjeZJ0Cf70zFkoc,1491
         | 
| 124 | 
            +
            mlx/include/mlx/backend/metal/copy.h,sha256=jXW0_jDwX97NmnwigeZwovS_MZ1PeZix2iuZf7f9Xts,400
         | 
| 125 | 
            +
            mlx/include/mlx/backend/metal/device.h,sha256=1Q0xSLn21dTzD5rWanoBBg1eMqOPf2958Z0rJL9GJag,2196
         | 
| 126 | 
            +
            mlx/include/mlx/backend/metal/kernels/atomic.h,sha256=0Iph_jNZoTyTrR-iEXXDzIwFoRWOkhqndX6IHJxFhKg,9217
         | 
| 127 | 
            +
            mlx/include/mlx/backend/metal/kernels/bf16.h,sha256=YdmOaznu1O1UNTyGlnpnEkvi1r44EouYNBAZ4SAX2yA,11904
         | 
| 128 | 
            +
            mlx/include/mlx/backend/metal/kernels/bf16_math.h,sha256=VRM_iZpBcY3Sg3yxkNPzjU30x32zcmd0diIagA1D5l4,26337
         | 
| 129 | 
            +
            mlx/include/mlx/backend/metal/kernels/complex.h,sha256=XDYKDiBdeM5KGpfpmwXmelZ00LmFCdxiGOIGoUMEeyI,3526
         | 
| 130 | 
            +
            mlx/include/mlx/backend/metal/kernels/conv_params.h,sha256=pzJn2uZ8TqY1WXRyGRx-MzBYBOfMSyhiDPgu8LMSElk,592
         | 
| 131 | 
            +
            mlx/include/mlx/backend/metal/kernels/defines.h,sha256=Bkplnhk_KYkqUzpB9TSlYUuXyucJQZYNyW9AR55g8fM,477
         | 
| 132 | 
            +
            mlx/include/mlx/backend/metal/kernels/erf.h,sha256=k1CfVsb-rEQaXGYJcAbgVR90VwEY5w_88BhJCSsj97k,2736
         | 
| 133 | 
            +
            mlx/include/mlx/backend/metal/kernels/gemm/conv.h,sha256=ilmFKKdzy8qwn1Zt0cu43f_mKsMLhIJy6DgIdx5ZU5I,14236
         | 
| 134 | 
            +
            mlx/include/mlx/backend/metal/kernels/gemm/gemm.h,sha256=x80JGVUliIDkFyzoGGd5iC0veHZVPx_BFkeVdGzgVOk,16184
         | 
| 135 | 
            +
            mlx/include/mlx/backend/metal/kernels/reduce.h,sha256=RgBsV24CtUU_dXCNzDP8sojJdjRz2bdF3PqB8UNJrlA,3554
         | 
| 136 | 
            +
            mlx/include/mlx/backend/metal/kernels/utils.h,sha256=mRGC9B2cjzCZT2kxY__EM0QmRQQBL8GSHC0rg7_4cJo,7202
         | 
| 137 | 
            +
            mlx/include/mlx/backend/metal/matmul.h,sha256=n32FxTg1zL-_oIaPRrs9QcI5TywhMemcm2KJ4zF-n0Q,592
         | 
| 138 | 
            +
            mlx/include/mlx/backend/metal/metal.h,sha256=9kgNkbx0Pqy-W3o7stP0cjvSAcoU0I20dZEsKmsSAS4,552
         | 
| 139 | 
            +
            mlx/include/mlx/backend/metal/mps/gemm.h,sha256=501lddlggxogceZI3wkWK98yIZurE0qgVxx0ICHyBxs,11307
         | 
| 140 | 
            +
            mlx/include/mlx/backend/metal/utils.h,sha256=ape52Rl1et96MA4H_a0cbVveLvhASJlszqX6mxl-nis,4205
         | 
| 141 | 
            +
            mlx/include/mlx/device.h,sha256=KtG7G03J5T0Ad2zn931OyRGfCa1XKn3ZyEUoIBY0No4,563
         | 
| 142 | 
            +
            mlx/include/mlx/dtype.h,sha256=jrOMl7xEhOZ7Dn0udDfi-Mq8QbT3atbd7CL0SZ4Yxjk,2575
         | 
| 143 | 
            +
            mlx/include/mlx/fft.h,sha256=MpW4UMjetc81i8pt3NJnTCvtXuaP92CX3w2iatYZHTw,4179
         | 
| 144 | 
            +
            mlx/include/mlx/graph_utils.h,sha256=23pv6xyLYTl-l2o9Z3DqlcvGVGZEA-uIV29jamcK-1Q,593
         | 
| 145 | 
            +
            mlx/include/mlx/io/load.h,sha256=vl7Z0pSJTjgG7KkSLXbe5oLjPhnVkWR9XdvC-ZCf_fk,2508
         | 
| 146 | 
            +
            mlx/include/mlx/io/safetensor.h,sha256=ro-grE_iXiGbCgygGFSeqsvxD_DDiwKckCsyb424e-g,624
         | 
| 147 | 
            +
            mlx/include/mlx/linalg.h,sha256=-CJ8NhHmycVsT29ObUmVZbxd2gwCpj8qj2RvOv-Glto,1771
         | 
| 148 | 
            +
            mlx/include/mlx/mlx.h,sha256=raZfU4bjGTSRAZa8P1XkoxHHeiooTZVI9C63ydGwm8c,296
         | 
| 149 | 
            +
            mlx/include/mlx/ops.h,sha256=322gdgz9l7pz-M0Pgpo0T6LbAe_w0RMV-A1-jRN_Le0,32224
         | 
| 150 | 
            +
            mlx/include/mlx/primitives.h,sha256=uJhmcwsXyseQlfeJBDjLCNrhPM91vursiw8_R79yUGg,44227
         | 
| 151 | 
            +
            mlx/include/mlx/random.h,sha256=aNoyWf-zAg9WKgqfe1yBpOvxedygIUZMQBgsd3kNnzQ,4887
         | 
| 152 | 
            +
            mlx/include/mlx/scheduler.h,sha256=sZX0BU1J15e_IOzBsv95KwAQmU4wrayd3b6ClvxGA4M,3864
         | 
| 153 | 
            +
            mlx/include/mlx/stream.h,sha256=7BvNtylObXbyIHZ0gCrbMb1bN3DUq7PfzEa-bIvqEXE,686
         | 
| 154 | 
            +
            mlx/include/mlx/transforms.h,sha256=Qt8pjvxAPYMoa8Dz2WUcJBdeKEFKNdczV0O3UTE_XmY,6527
         | 
| 155 | 
            +
            mlx/include/mlx/transforms_impl.h,sha256=RttUce4KHg3PLUFc9aRVjxenCIztlBSTlQq5LnVxKVg,542
         | 
| 156 | 
            +
            mlx/include/mlx/types/bf16.h,sha256=fLnTB0fbL2vIs_CCvWhrr7SfDnFUEpDPmg7RNmrRnzE,6753
         | 
| 157 | 
            +
            mlx/include/mlx/types/complex.h,sha256=oaTl_rCth3TO33sT2Db6SX--n9lIfE_dw7rjE06Qqmc,2878
         | 
| 158 | 
            +
            mlx/include/mlx/types/fp16.h,sha256=j0E7-PKkVkCT-bOMRdBEYdvxUdo9rf6w80dNSNCYNaQ,8322
         | 
| 159 | 
            +
            mlx/include/mlx/types/half_types.h,sha256=-D-ilDAh5shp1xp4lx0J87y9iFfBs43gLNi-XqzmLWA,1367
         | 
| 160 | 
            +
            mlx/include/mlx/utils.h,sha256=cchUDEk5qbWNIPHf_cRwbsxnczz2_moUr_Dipz-cfhM,1524
         | 
| 161 | 
            +
            mlx/lib/libmlx.dylib,sha256=ir7-RqHznJKyhGSBTwWnMPqYmbF3V3A0A8bvNi8GrJM,12420704
         | 
| 162 | 
            +
            mlx/lib/mlx.metallib,sha256=Lu30EADtJwKD2hHYibsQGqTIjG-PDsaP5rBApb5CRQE,59495531
         | 
| 163 | 
            +
            mlx/nn/__init__.py,sha256=b-1hqkAnzpthKJNKgRGVHsltNEqOU7URFMoY2Z46Buw,126
         | 
| 164 | 
            +
            mlx/nn/__pycache__/__init__.cpython-311.pyc,,
         | 
| 165 | 
            +
            mlx/nn/__pycache__/losses.cpython-311.pyc,,
         | 
| 166 | 
            +
            mlx/nn/__pycache__/utils.cpython-311.pyc,,
         | 
| 167 | 
            +
            mlx/nn/layers/__init__.py,sha256=2Ge23mO7W5VOMkogmcY2dqTKRRD3lY0g_WhVDP-9Hi0,1249
         | 
| 168 | 
            +
            mlx/nn/layers/__pycache__/__init__.cpython-311.pyc,,
         | 
| 169 | 
            +
            mlx/nn/layers/__pycache__/activations.cpython-311.pyc,,
         | 
| 170 | 
            +
            mlx/nn/layers/__pycache__/base.cpython-311.pyc,,
         | 
| 171 | 
            +
            mlx/nn/layers/__pycache__/containers.cpython-311.pyc,,
         | 
| 172 | 
            +
            mlx/nn/layers/__pycache__/convolution.cpython-311.pyc,,
         | 
| 173 | 
            +
            mlx/nn/layers/__pycache__/dropout.cpython-311.pyc,,
         | 
| 174 | 
            +
            mlx/nn/layers/__pycache__/embedding.cpython-311.pyc,,
         | 
| 175 | 
            +
            mlx/nn/layers/__pycache__/linear.cpython-311.pyc,,
         | 
| 176 | 
            +
            mlx/nn/layers/__pycache__/normalization.cpython-311.pyc,,
         | 
| 177 | 
            +
            mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc,,
         | 
| 178 | 
            +
            mlx/nn/layers/__pycache__/quantized.cpython-311.pyc,,
         | 
| 179 | 
            +
            mlx/nn/layers/__pycache__/transformer.cpython-311.pyc,,
         | 
| 180 | 
            +
            mlx/nn/layers/activations.py,sha256=ukopw8ZFpSIQijXscov80Gv_0j0jKh0Y4Z3qrUFzpkk,11826
         | 
| 181 | 
            +
            mlx/nn/layers/base.py,sha256=iFp3PVGUhT9pAaWNnWtNY-yZ-uCQxvH-O1CYbKoGCug,19485
         | 
| 182 | 
            +
            mlx/nn/layers/containers.py,sha256=Ke8gPBPZvrtrjK1qrbRNGtSV4rvBkYTKVBW2pQXrNvY,618
         | 
| 183 | 
            +
            mlx/nn/layers/convolution.py,sha256=ZP43OBZTKF3ui_2Xrtx7VvAlLic_jcn2oL-MPLU7Rys,4060
         | 
| 184 | 
            +
            mlx/nn/layers/dropout.py,sha256=1GR5uWU82eFd296WuPB7CqNcNj3acEIlS3_4ZAxnx-g,4529
         | 
| 185 | 
            +
            mlx/nn/layers/embedding.py,sha256=5G7TvVyEaHTlQCASNVfBGea_1ywDVNqDkxHP2U-JEcs,878
         | 
| 186 | 
            +
            mlx/nn/layers/linear.py,sha256=ekgkElIUlfdWZNRDZkoyQFTotzyjNpD_0v7Ac8pnako,4137
         | 
| 187 | 
            +
            mlx/nn/layers/normalization.py,sha256=m92TwFfNQ4NbQlxtlTfJsWZlBJtO0_tR0M8H2d9wZEs,12021
         | 
| 188 | 
            +
            mlx/nn/layers/positional_encoding.py,sha256=_4zbYLMGYI0EVq4b6eitmB2hgBXblcrJJYsq0KCSafs,6611
         | 
| 189 | 
            +
            mlx/nn/layers/quantized.py,sha256=FA_NzLBheO2i5AClMWCkaFrNXJGScxlwvumqaJTIxAA,4136
         | 
| 190 | 
            +
            mlx/nn/layers/transformer.py,sha256=IUTNGvVoaKLS7Gwes_yK3SqgX6voPgMqk7nmdE4nyI0,12235
         | 
| 191 | 
            +
            mlx/nn/losses.py,sha256=1yWCxt-QSv4lGV9U6KJyeTzybQWW4LPZhipkLFx7hWk,12158
         | 
| 192 | 
            +
            mlx/nn/utils.py,sha256=tQJ62bu2eAAhzYqV7QriQtNGrYSz85XGxHolEPkGiGs,999
         | 
| 193 | 
            +
            mlx/optimizers.py,sha256=krPDwdaMy1JTlaVBgRHlg7PNqyxiVRkfNEbEf_wV-5g,17090
         | 
| 194 | 
            +
            mlx/share/cmake/MLX/MLXConfig.cmake,sha256=02CzwjTAHLTyLzFMtE7LLzyX_CSHreASjFyiYEpfIp4,1884
         | 
| 195 | 
            +
            mlx/share/cmake/MLX/MLXConfigVersion.cmake,sha256=p6TWi4EVJlFiJx_BxHZqpPBhkKaze00W3S-JRcRHD78,2762
         | 
| 196 | 
            +
            mlx/share/cmake/MLX/MLXTargets-release.cmake,sha256=bUjlsca3AGGUHl1X9lfm3CiQiSFeYTkm2c7GLzITeJs,788
         | 
| 197 | 
            +
            mlx/share/cmake/MLX/MLXTargets.cmake,sha256=vS5sI59URF-Pz8SPKMDVbyTPq1ENZ3zoR_VPzWToXkw,4705
         | 
| 198 | 
            +
            mlx/share/cmake/MLX/extension.cmake,sha256=q09-X3ckTrFCZOc5UNyoqjUC94OQQSXXWW7Bmg7sxok,1637
         | 
| 199 | 
            +
            mlx/utils.py,sha256=F-bmWn6a4VFRUvnSeZPvIFwT8Z6IHEtgEAfcLiGBrz8,4667
         | 
    	
        lib/python3.11/site-packages/mlx-0.0.7.dist-info/REQUESTED
    ADDED
    
    | 
            File without changes
         | 
    	
        lib/python3.11/site-packages/mlx-0.0.7.dist-info/WHEEL
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Wheel-Version: 1.0
         | 
| 2 | 
            +
            Generator: bdist_wheel (0.41.2)
         | 
| 3 | 
            +
            Root-Is-Purelib: false
         | 
| 4 | 
            +
            Tag: cp311-cp311-macosx_14_0_arm64
         | 
| 5 | 
            +
             | 
    	
        lib/python3.11/site-packages/mlx-0.0.7.dist-info/top_level.txt
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            mlx
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/base.py
    ADDED
    
    | @@ -0,0 +1,532 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import textwrap
         | 
| 4 | 
            +
            from typing import Any, Callable, List, Optional, Tuple, Union
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import mlx.core as mx
         | 
| 7 | 
            +
            from mlx.utils import tree_flatten, tree_unflatten
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class Module(dict):
         | 
| 11 | 
            +
                """Base class for building neural networks with MLX.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                All the layers provided in :mod:`mlx.nn.layers` subclass this class and
         | 
| 14 | 
            +
                your models should do the same.
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
         | 
| 17 | 
            +
                instances in arbitrary nesting of python lists or dicts. The ``Module``
         | 
| 18 | 
            +
                then allows recursively extracting all the :class:`mlx.core.array` instances
         | 
| 19 | 
            +
                using :meth:`mlx.nn.Module.parameters`.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                In addition, the ``Module`` has the concept of trainable and non trainable
         | 
| 22 | 
            +
                parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
         | 
| 23 | 
            +
                the gradients are returned only with respect to the trainable parameters.
         | 
| 24 | 
            +
                All arrays in a module are trainable unless they are added in the "frozen"
         | 
| 25 | 
            +
                set by calling :meth:`freeze`.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                .. code-block:: python
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                    import mlx.core as mx
         | 
| 30 | 
            +
                    import mlx.nn as nn
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    class MyMLP(nn.Module):
         | 
| 33 | 
            +
                        def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
         | 
| 34 | 
            +
                            super().__init__()
         | 
| 35 | 
            +
             | 
| 36 | 
            +
                            self.in_proj = nn.Linear(in_dims, hidden_dims)
         | 
| 37 | 
            +
                            self.out_proj = nn.Linear(hidden_dims, out_dims)
         | 
| 38 | 
            +
             | 
| 39 | 
            +
                        def __call__(self, x):
         | 
| 40 | 
            +
                            x = self.in_proj(x)
         | 
| 41 | 
            +
                            x = mx.maximum(x, 0)
         | 
| 42 | 
            +
                            return self.out_proj(x)
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                    model = MyMLP(2, 1)
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                    # All the model parameters are created but since MLX is lazy by
         | 
| 47 | 
            +
                    # default, they are not evaluated yet. Calling `mx.eval` actually
         | 
| 48 | 
            +
                    # allocates memory and initializes the parameters.
         | 
| 49 | 
            +
                    mx.eval(model.parameters())
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    # Setting a parameter to a new value is as simply as accessing that
         | 
| 52 | 
            +
                    # parameter and assigning a new array to it.
         | 
| 53 | 
            +
                    model.in_proj.weight = model.in_proj.weight * 2
         | 
| 54 | 
            +
                    mx.eval(model.parameters())
         | 
| 55 | 
            +
                """
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                def __init__(self):
         | 
| 58 | 
            +
                    """Should be called by the subclasses of ``Module``."""
         | 
| 59 | 
            +
                    self._no_grad = set()
         | 
| 60 | 
            +
                    self._training = True
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                @property
         | 
| 63 | 
            +
                def training(self):
         | 
| 64 | 
            +
                    """Boolean indicating if the model is in training mode."""
         | 
| 65 | 
            +
                    return self._training
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def _extra_repr(self):
         | 
| 68 | 
            +
                    return ""
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def __repr__(self):
         | 
| 71 | 
            +
                    children = tree_flatten(self.children(), is_leaf=self.is_module)
         | 
| 72 | 
            +
                    value = f"{type(self).__name__}({self._extra_repr()}"
         | 
| 73 | 
            +
                    for k, v in children:
         | 
| 74 | 
            +
                        value += "\n"
         | 
| 75 | 
            +
                        value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
         | 
| 76 | 
            +
                    if children:
         | 
| 77 | 
            +
                        value += "\n"
         | 
| 78 | 
            +
                    value += ")"
         | 
| 79 | 
            +
             | 
| 80 | 
            +
                    return value
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                def __getattr__(self, key: str):
         | 
| 83 | 
            +
                    if key in self:
         | 
| 84 | 
            +
                        return self[key]
         | 
| 85 | 
            +
                    else:
         | 
| 86 | 
            +
                        raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def __setattr__(self, key: str, val: Any):
         | 
| 89 | 
            +
                    self[key] = val
         | 
| 90 | 
            +
             | 
| 91 | 
            +
                def load_weights(
         | 
| 92 | 
            +
                    self,
         | 
| 93 | 
            +
                    file_or_weights: Union[str, List[Tuple[str, mx.array]]],
         | 
| 94 | 
            +
                    strict: bool = True,
         | 
| 95 | 
            +
                ):
         | 
| 96 | 
            +
                    """
         | 
| 97 | 
            +
                    Update the model's weights from a ``.npz`` or a list.
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    Args:
         | 
| 100 | 
            +
                        file_or_weights (str or list(tuple(str, mx.array))): The path to
         | 
| 101 | 
            +
                            the weights ``.npz`` file or a list of pairs of parameter names
         | 
| 102 | 
            +
                            and arrays.
         | 
| 103 | 
            +
                        strict (bool, optional): If ``True`` then checks that the provided
         | 
| 104 | 
            +
                          weights exactly match the parameters of the model. Otherwise,
         | 
| 105 | 
            +
                          only the weights actually contained in the model are loaded and
         | 
| 106 | 
            +
                          shapes are not checked. Default: ``True``.
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    Example:
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                        .. code-block:: python
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                            import mlx.core as mx
         | 
| 113 | 
            +
                            import mlx.nn as nn
         | 
| 114 | 
            +
                            model = nn.Linear(10, 10)
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                            # Load from file
         | 
| 117 | 
            +
                            model.load_weights("weights.npz")
         | 
| 118 | 
            +
             | 
| 119 | 
            +
                            # Load from list
         | 
| 120 | 
            +
                            weights = [
         | 
| 121 | 
            +
                                ("weight", mx.random.uniform(shape=(10, 10))),
         | 
| 122 | 
            +
                                ("bias",  mx.zeros((10,))),
         | 
| 123 | 
            +
                            ]
         | 
| 124 | 
            +
                            model.load_weights(weights)
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                            # Missing weight
         | 
| 127 | 
            +
                            weights = [
         | 
| 128 | 
            +
                                ("weight", mx.random.uniform(shape=(10, 10))),
         | 
| 129 | 
            +
                            ]
         | 
| 130 | 
            +
             | 
| 131 | 
            +
                            # Raises a ValueError exception
         | 
| 132 | 
            +
                            model.load_weights(weights)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                            # Ok, only updates the weight but not the bias
         | 
| 135 | 
            +
                            model.load_weights(weights, strict=False)
         | 
| 136 | 
            +
                    """
         | 
| 137 | 
            +
                    weights = file_or_weights
         | 
| 138 | 
            +
                    if isinstance(weights, str):
         | 
| 139 | 
            +
                        weights = list(mx.load(weights).items())
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    if strict:
         | 
| 142 | 
            +
                        new_weights = dict(weights)
         | 
| 143 | 
            +
                        curr_weights = dict(tree_flatten(self.parameters()))
         | 
| 144 | 
            +
                        if extras := (new_weights.keys() - curr_weights.keys()):
         | 
| 145 | 
            +
                            extras = " ".join(extras)
         | 
| 146 | 
            +
                            raise ValueError(f"Received parameters not in model: {extras}.")
         | 
| 147 | 
            +
                        if missing := (curr_weights.keys() - new_weights.keys()):
         | 
| 148 | 
            +
                            missing = " ".join(missing)
         | 
| 149 | 
            +
                            raise ValueError(f"Missing parameters: {missing}.")
         | 
| 150 | 
            +
                        for k, v in curr_weights.items():
         | 
| 151 | 
            +
                            v_new = new_weights[k]
         | 
| 152 | 
            +
                            if not isinstance(v_new, mx.array):
         | 
| 153 | 
            +
                                raise ValueError(
         | 
| 154 | 
            +
                                    "Expected mx.array but received "
         | 
| 155 | 
            +
                                    f"{type(v_new)} for parameter {k}"
         | 
| 156 | 
            +
                                )
         | 
| 157 | 
            +
                            if v_new.shape != v.shape:
         | 
| 158 | 
            +
                                raise ValueError(
         | 
| 159 | 
            +
                                    f"Expected shape {v.shape} but received "
         | 
| 160 | 
            +
                                    f" shape {v_new.shape} for parameter {k}"
         | 
| 161 | 
            +
                                )
         | 
| 162 | 
            +
             | 
| 163 | 
            +
                    self.update(tree_unflatten(weights))
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                def save_weights(self, file: str):
         | 
| 166 | 
            +
                    """
         | 
| 167 | 
            +
                    Save the model's weights to a ``.npz`` file.
         | 
| 168 | 
            +
                    """
         | 
| 169 | 
            +
                    mx.savez(file, **dict(tree_flatten(self.parameters())))
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                @staticmethod
         | 
| 172 | 
            +
                def is_module(value):
         | 
| 173 | 
            +
                    return isinstance(value, Module)
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                @staticmethod
         | 
| 176 | 
            +
                def valid_child_filter(module, key, value):
         | 
| 177 | 
            +
                    return isinstance(value, (dict, list))
         | 
| 178 | 
            +
             | 
| 179 | 
            +
                @staticmethod
         | 
| 180 | 
            +
                def valid_parameter_filter(module, key, value):
         | 
| 181 | 
            +
                    return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                @staticmethod
         | 
| 184 | 
            +
                def trainable_parameter_filter(module, key, value):
         | 
| 185 | 
            +
                    return (
         | 
| 186 | 
            +
                        Module.valid_parameter_filter(module, key, value)
         | 
| 187 | 
            +
                        and key not in module._no_grad
         | 
| 188 | 
            +
                    )
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                def filter_and_map(
         | 
| 191 | 
            +
                    self,
         | 
| 192 | 
            +
                    filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
         | 
| 193 | 
            +
                    map_fn: Optional[Callable] = None,
         | 
| 194 | 
            +
                    is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
         | 
| 195 | 
            +
                ):
         | 
| 196 | 
            +
                    """Recursively filter the contents of the module using ``filter_fn``,
         | 
| 197 | 
            +
                    namely only select keys and values where ``filter_fn`` returns true.
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                    This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
         | 
| 200 | 
            +
                    but it can also be used to extract any subset of the module's parameters.
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    Args:
         | 
| 203 | 
            +
                        filter_fn (Callable): Given a value, the key in which it is found
         | 
| 204 | 
            +
                            and the containing module, decide whether to keep the value or
         | 
| 205 | 
            +
                            drop it.
         | 
| 206 | 
            +
                        map_fn (Callable, optional): Optionally transform the value before
         | 
| 207 | 
            +
                            returning it.
         | 
| 208 | 
            +
                        is_leaf_fn (Callable, optional): Given a value, the key in which it
         | 
| 209 | 
            +
                            is found and the containing module decide if it is a leaf.
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    Returns:
         | 
| 212 | 
            +
                        A dictionary containing the contents of the module recursively filtered
         | 
| 213 | 
            +
                    """
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    map_fn = map_fn or (lambda x: x)
         | 
| 216 | 
            +
                    is_leaf_fn = is_leaf_fn or (
         | 
| 217 | 
            +
                        lambda m, k, v: not isinstance(v, (Module, dict, list))
         | 
| 218 | 
            +
                    )
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    def unwrap(vk, v):
         | 
| 221 | 
            +
                        if is_leaf_fn(self, vk, v):
         | 
| 222 | 
            +
                            return map_fn(v)
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                        if isinstance(v, Module):
         | 
| 225 | 
            +
                            return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                        if isinstance(v, dict):
         | 
| 228 | 
            +
                            nd = {}
         | 
| 229 | 
            +
                            for k, v in v.items():
         | 
| 230 | 
            +
                                tk = f"{vk}.{k}"
         | 
| 231 | 
            +
                                nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
         | 
| 232 | 
            +
                            return nd
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                        if isinstance(v, list):
         | 
| 235 | 
            +
                            nl = []
         | 
| 236 | 
            +
                            for i, vi in enumerate(v):
         | 
| 237 | 
            +
                                tk = f"{vk}.{i}"
         | 
| 238 | 
            +
                                nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
         | 
| 239 | 
            +
                            return nl
         | 
| 240 | 
            +
             | 
| 241 | 
            +
                        raise RuntimeError("Unexpected leaf found while traversing the module")
         | 
| 242 | 
            +
             | 
| 243 | 
            +
                    return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                def parameters(self):
         | 
| 246 | 
            +
                    """Recursively return all the :class:`mlx.core.array` members of this Module
         | 
| 247 | 
            +
                    as a dict of dicts and lists."""
         | 
| 248 | 
            +
                    return self.filter_and_map(self.valid_parameter_filter)
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                def trainable_parameters(self):
         | 
| 251 | 
            +
                    """Recursively return all the non frozen :class:`mlx.core.array` members of
         | 
| 252 | 
            +
                    this Module as a dict of dicts and lists."""
         | 
| 253 | 
            +
                    return self.filter_and_map(self.trainable_parameter_filter)
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                def children(self):
         | 
| 256 | 
            +
                    """Return the direct descendants of this Module instance."""
         | 
| 257 | 
            +
                    return self.filter_and_map(
         | 
| 258 | 
            +
                        self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
         | 
| 259 | 
            +
                    )
         | 
| 260 | 
            +
             | 
| 261 | 
            +
                def leaf_modules(self):
         | 
| 262 | 
            +
                    """Return the submodules that do not contain other modules."""
         | 
| 263 | 
            +
             | 
| 264 | 
            +
                    def _is_leaf_module(m, k, v):
         | 
| 265 | 
            +
                        return isinstance(v, Module) and len(tree_flatten(v.children())) == 0
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                    return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
         | 
| 268 | 
            +
             | 
| 269 | 
            +
                def update(self, parameters: dict):
         | 
| 270 | 
            +
                    """Replace the parameters of this Module with the provided ones in the
         | 
| 271 | 
            +
                    dict of dicts and lists.
         | 
| 272 | 
            +
             | 
| 273 | 
            +
                    Commonly used by the optimizer to change the model to the updated
         | 
| 274 | 
            +
                    (optimized) parameters. Also used by the :meth:`mlx.nn.value_and_grad` to set the
         | 
| 275 | 
            +
                    tracers in the model in order to compute gradients.
         | 
| 276 | 
            +
             | 
| 277 | 
            +
                    The passed in parameters dictionary need not be a full dictionary
         | 
| 278 | 
            +
                    similar to :meth:`parameters`. Only the provided locations will be
         | 
| 279 | 
            +
                    updated.
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                    Args:
         | 
| 282 | 
            +
                        parameters (dict): A complete or partial dictionary of the modules
         | 
| 283 | 
            +
                                           parameters.
         | 
| 284 | 
            +
                    """
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    def apply(dst, parameters):
         | 
| 287 | 
            +
                        if isinstance(parameters, dict):
         | 
| 288 | 
            +
                            for k in parameters:
         | 
| 289 | 
            +
                                if k in dst:
         | 
| 290 | 
            +
                                    current_value = dst[k]
         | 
| 291 | 
            +
                                    new_value = parameters[k]
         | 
| 292 | 
            +
                                    if isinstance(current_value, mx.array):
         | 
| 293 | 
            +
                                        dst[k] = new_value
         | 
| 294 | 
            +
                                    elif isinstance(current_value, Module):
         | 
| 295 | 
            +
                                        current_value.update(new_value)
         | 
| 296 | 
            +
                                    elif isinstance(current_value, (dict, list)):
         | 
| 297 | 
            +
                                        apply(current_value, new_value)
         | 
| 298 | 
            +
                        elif isinstance(parameters, list):
         | 
| 299 | 
            +
                            for i in range(len(dst)):
         | 
| 300 | 
            +
                                current_value = dst[i]
         | 
| 301 | 
            +
                                new_value = parameters[i]
         | 
| 302 | 
            +
                                if isinstance(current_value, mx.array):
         | 
| 303 | 
            +
                                    dst[i] = new_value
         | 
| 304 | 
            +
                                elif isinstance(current_value, Module):
         | 
| 305 | 
            +
                                    current_value.update(new_value)
         | 
| 306 | 
            +
                                elif isinstance(current_value, (dict, list)):
         | 
| 307 | 
            +
                                    apply(current_value, new_value)
         | 
| 308 | 
            +
             | 
| 309 | 
            +
                    apply(self, parameters)
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                def apply(
         | 
| 312 | 
            +
                    self,
         | 
| 313 | 
            +
                    map_fn: Callable[[mx.array], mx.array],
         | 
| 314 | 
            +
                    filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
         | 
| 315 | 
            +
                ):
         | 
| 316 | 
            +
                    """Map all the parameters using the provided ``map_fn`` and immediately
         | 
| 317 | 
            +
                    update the module with the mapped parameters.
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    For instance running ``model.apply(lambda x: x.astype(mx.float16))``
         | 
| 320 | 
            +
                    casts all parameters to 16 bit floats.
         | 
| 321 | 
            +
             | 
| 322 | 
            +
                    Args:
         | 
| 323 | 
            +
                        map_fn (Callable): Maps an array to another array
         | 
| 324 | 
            +
                        filter_fn (Callable, optional): Filter to select which arrays to
         | 
| 325 | 
            +
                            map (default: :meth:`Module.valid_parameter_filter`).
         | 
| 326 | 
            +
                    """
         | 
| 327 | 
            +
                    filter_fn = filter_fn or Module.valid_parameter_filter
         | 
| 328 | 
            +
                    self.update(self.filter_and_map(filter_fn, map_fn))
         | 
| 329 | 
            +
             | 
| 330 | 
            +
                def update_modules(self, modules: dict):
         | 
| 331 | 
            +
                    """Replace the child modules of this :class:`Module` instance with the
         | 
| 332 | 
            +
                    provided ones in the dict of dicts and lists.
         | 
| 333 | 
            +
             | 
| 334 | 
            +
                    It is the equivalent of :meth:`Module.update` but for modules instead
         | 
| 335 | 
            +
                    of parameters and allows us to flexibly edit complex architectures by
         | 
| 336 | 
            +
                    programmatically swapping layers.
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    The passed in parameters dictionary need not be a full dictionary
         | 
| 339 | 
            +
                    similar to :meth:`parameters`. Only the provided locations will be
         | 
| 340 | 
            +
                    updated.
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                    Args:
         | 
| 343 | 
            +
                        modules (dict): A complete or partial dictionary of the modules
         | 
| 344 | 
            +
                            submodules.
         | 
| 345 | 
            +
                    """
         | 
| 346 | 
            +
             | 
| 347 | 
            +
                    def apply(dst, modules):
         | 
| 348 | 
            +
                        if isinstance(modules, dict):
         | 
| 349 | 
            +
                            for k in modules:
         | 
| 350 | 
            +
                                if k in dst:
         | 
| 351 | 
            +
                                    current_value = dst[k]
         | 
| 352 | 
            +
                                    new_value = modules[k]
         | 
| 353 | 
            +
                                    if self.is_module(current_value) and self.is_module(new_value):
         | 
| 354 | 
            +
                                        dst[k] = new_value
         | 
| 355 | 
            +
                                    elif isinstance(current_value, (dict, list)):
         | 
| 356 | 
            +
                                        apply(current_value, new_value)
         | 
| 357 | 
            +
                        elif isinstance(modules, list):
         | 
| 358 | 
            +
                            for i in range(len(dst)):
         | 
| 359 | 
            +
                                current_value = dst[i]
         | 
| 360 | 
            +
                                new_value = modules[i]
         | 
| 361 | 
            +
                                if self.is_module(current_value) and self.is_module(new_value):
         | 
| 362 | 
            +
                                    dst[i] = new_value
         | 
| 363 | 
            +
                                elif isinstance(current_value, (dict, list)):
         | 
| 364 | 
            +
                                    apply(current_value, new_value)
         | 
| 365 | 
            +
             | 
| 366 | 
            +
                    apply(self, modules)
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
         | 
| 369 | 
            +
                    """Apply a function to all the modules in this instance (including this
         | 
| 370 | 
            +
                    instance).
         | 
| 371 | 
            +
             | 
| 372 | 
            +
                    Args:
         | 
| 373 | 
            +
                        apply_fn (Callable): The function to apply to the modules.
         | 
| 374 | 
            +
                    """
         | 
| 375 | 
            +
                    module_stack = [("", self)]
         | 
| 376 | 
            +
                    while module_stack:
         | 
| 377 | 
            +
                        prefix, mod = module_stack.pop()
         | 
| 378 | 
            +
                        apply_fn(prefix, mod)
         | 
| 379 | 
            +
                        prefix = "." + prefix if prefix else ""
         | 
| 380 | 
            +
                        module_stack.extend(
         | 
| 381 | 
            +
                            tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
         | 
| 382 | 
            +
                        )
         | 
| 383 | 
            +
             | 
| 384 | 
            +
                def modules(self):
         | 
| 385 | 
            +
                    """Return a list with all the modules in this instance.
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                    Returns:
         | 
| 388 | 
            +
                        A list of :class:`mlx.nn.Module` instances.
         | 
| 389 | 
            +
                    """
         | 
| 390 | 
            +
                    modulelist = []
         | 
| 391 | 
            +
                    self.apply_to_modules(lambda k, m: modulelist.append(m))
         | 
| 392 | 
            +
                    return modulelist
         | 
| 393 | 
            +
             | 
| 394 | 
            +
                def named_modules(self):
         | 
| 395 | 
            +
                    """Return a list with all the modules in this instance and their name
         | 
| 396 | 
            +
                    with dot notation.
         | 
| 397 | 
            +
             | 
| 398 | 
            +
                    Returns:
         | 
| 399 | 
            +
                        A list of tuples (str, :class:`mlx.nn.Module`).
         | 
| 400 | 
            +
                    """
         | 
| 401 | 
            +
                    modulelist = []
         | 
| 402 | 
            +
                    self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
         | 
| 403 | 
            +
                    return modulelist
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                def _validate_keys(self, keys, strict):
         | 
| 406 | 
            +
                    keys = keys if isinstance(keys, list) else [keys]
         | 
| 407 | 
            +
                    if strict:
         | 
| 408 | 
            +
                        for k in keys:
         | 
| 409 | 
            +
                            if k not in self:
         | 
| 410 | 
            +
                                raise KeyError(f"Module doesn't contain member {k}.")
         | 
| 411 | 
            +
                    return keys
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                def freeze(
         | 
| 414 | 
            +
                    self,
         | 
| 415 | 
            +
                    *,
         | 
| 416 | 
            +
                    recurse: bool = True,
         | 
| 417 | 
            +
                    keys: Optional[Union[str, List[str]]] = None,
         | 
| 418 | 
            +
                    strict: bool = False,
         | 
| 419 | 
            +
                ):
         | 
| 420 | 
            +
                    """Freeze the Module's parameters or some of them. Freezing a parameter means not
         | 
| 421 | 
            +
                    computing gradients for it.
         | 
| 422 | 
            +
             | 
| 423 | 
            +
                    This function is idempotent i.e. freezing a frozen model is a no-op.
         | 
| 424 | 
            +
             | 
| 425 | 
            +
                    Example:
         | 
| 426 | 
            +
                        For instance to only train the attention parameters from a Transformer:
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                        .. code-block:: python
         | 
| 429 | 
            +
             | 
| 430 | 
            +
                            model = nn.Transformer()
         | 
| 431 | 
            +
                            model.freeze()
         | 
| 432 | 
            +
                            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                    Args:
         | 
| 435 | 
            +
                        recurse (bool, optional): If True then freeze the parameters of the
         | 
| 436 | 
            +
                            submodules as well. Default: ``True``.
         | 
| 437 | 
            +
                        keys (str or list[str], optional): If provided then only these
         | 
| 438 | 
            +
                            parameters will be frozen otherwise all the parameters of a
         | 
| 439 | 
            +
                            module. For instance freeze all biases by calling
         | 
| 440 | 
            +
                            ``module.freeze(keys="bias")``.
         | 
| 441 | 
            +
                        strict (bool, optional): If set to ``True`` validate that the passed keys exist.
         | 
| 442 | 
            +
                            Default: ``False``.
         | 
| 443 | 
            +
                    """
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                    def _freeze_impl(_, m):
         | 
| 446 | 
            +
                        local_keys = keys
         | 
| 447 | 
            +
                        if local_keys is None:
         | 
| 448 | 
            +
                            local_keys = tree_flatten(
         | 
| 449 | 
            +
                                m.filter_and_map(
         | 
| 450 | 
            +
                                    lambda m, k, v: (not isinstance(v, Module))
         | 
| 451 | 
            +
                                    and m.valid_parameter_filter(m, k, v)
         | 
| 452 | 
            +
                                )
         | 
| 453 | 
            +
                            )
         | 
| 454 | 
            +
                            local_keys = [k for (k, v) in local_keys]
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                        local_keys = m._validate_keys(local_keys, strict)
         | 
| 457 | 
            +
                        m._no_grad.update(local_keys)
         | 
| 458 | 
            +
             | 
| 459 | 
            +
                    if recurse:
         | 
| 460 | 
            +
                        self.apply_to_modules(_freeze_impl)
         | 
| 461 | 
            +
                    else:
         | 
| 462 | 
            +
                        _freeze_impl("", self)
         | 
| 463 | 
            +
             | 
| 464 | 
            +
                def unfreeze(
         | 
| 465 | 
            +
                    self,
         | 
| 466 | 
            +
                    *,
         | 
| 467 | 
            +
                    recurse: bool = True,
         | 
| 468 | 
            +
                    keys: Optional[Union[str, List[str]]] = None,
         | 
| 469 | 
            +
                    strict: bool = False,
         | 
| 470 | 
            +
                ):
         | 
| 471 | 
            +
                    """Unfreeze the Module's parameters or some of them.
         | 
| 472 | 
            +
             | 
| 473 | 
            +
                    This function is idempotent ie unfreezing a model that is not frozen is
         | 
| 474 | 
            +
                    a noop.
         | 
| 475 | 
            +
             | 
| 476 | 
            +
                    Example:
         | 
| 477 | 
            +
             | 
| 478 | 
            +
                        For instance to only train the biases of a Transformer one can do:
         | 
| 479 | 
            +
             | 
| 480 | 
            +
                        .. code-block:: python
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                            model = nn.Transformer()
         | 
| 483 | 
            +
                            model.freeze()
         | 
| 484 | 
            +
                            model.unfreeze(keys="bias")
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                    Args:
         | 
| 487 | 
            +
                        recurse (bool, optional): If True then unfreeze the parameters of the
         | 
| 488 | 
            +
                            submodules as well. Default: ``True``.
         | 
| 489 | 
            +
                        keys (str or list[str], optional): If provided then only these
         | 
| 490 | 
            +
                            parameters will be unfrozen otherwise all the parameters of a
         | 
| 491 | 
            +
                            module. For instance unfreeze all biases by calling
         | 
| 492 | 
            +
                            ``module.unfreeze(keys="bias")``.
         | 
| 493 | 
            +
                        strict (bool, optional): If set to ``True`` validate that the passed keys exist.
         | 
| 494 | 
            +
                            Default: ``False``.
         | 
| 495 | 
            +
                    """
         | 
| 496 | 
            +
             | 
| 497 | 
            +
                    def _unfreeze_impl(_, m):
         | 
| 498 | 
            +
                        if keys is None:
         | 
| 499 | 
            +
                            m._no_grad.clear()
         | 
| 500 | 
            +
             | 
| 501 | 
            +
                        else:
         | 
| 502 | 
            +
                            local_keys = m._validate_keys(keys, strict)
         | 
| 503 | 
            +
                            m._no_grad.difference_update(local_keys)
         | 
| 504 | 
            +
             | 
| 505 | 
            +
                    if recurse:
         | 
| 506 | 
            +
                        self.apply_to_modules(_unfreeze_impl)
         | 
| 507 | 
            +
                    else:
         | 
| 508 | 
            +
                        _unfreeze_impl("", self)
         | 
| 509 | 
            +
             | 
| 510 | 
            +
                def train(self, mode: bool = True):
         | 
| 511 | 
            +
                    """Set the model in or out of training mode.
         | 
| 512 | 
            +
             | 
| 513 | 
            +
                    Training mode only applies to certain layers. For example
         | 
| 514 | 
            +
                    :obj:`Dropout` applies a random mask in training mode, but is the
         | 
| 515 | 
            +
                    identity in evaluation mode.
         | 
| 516 | 
            +
             | 
| 517 | 
            +
                    Args:
         | 
| 518 | 
            +
                        mode (bool): Indicate if the model should be in training or
         | 
| 519 | 
            +
                            evaluation mode. Default: ``True``.
         | 
| 520 | 
            +
                    """
         | 
| 521 | 
            +
             | 
| 522 | 
            +
                    def _set_train(_, m):
         | 
| 523 | 
            +
                        m._training = mode
         | 
| 524 | 
            +
             | 
| 525 | 
            +
                    self.apply_to_modules(_set_train)
         | 
| 526 | 
            +
             | 
| 527 | 
            +
                def eval(self):
         | 
| 528 | 
            +
                    """Set the model to evaluation mode.
         | 
| 529 | 
            +
             | 
| 530 | 
            +
                    See :func:`train`.
         | 
| 531 | 
            +
                    """
         | 
| 532 | 
            +
                    self.train(False)
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/containers.py
    ADDED
    
    | @@ -0,0 +1,24 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class Sequential(Module):
         | 
| 7 | 
            +
                """A layer that calls the passed callables in order.
         | 
| 8 | 
            +
             | 
| 9 | 
            +
                We can pass either modules or plain callables to the Sequential module. If
         | 
| 10 | 
            +
                our functions have learnable parameters they should be implemented as
         | 
| 11 | 
            +
                ``nn.Module`` instances.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                Args:
         | 
| 14 | 
            +
                    modules (tuple of Callables): The modules to call in order
         | 
| 15 | 
            +
                """
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                def __init__(self, *modules):
         | 
| 18 | 
            +
                    super().__init__()
         | 
| 19 | 
            +
                    self.layers = list(modules)
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def __call__(self, x):
         | 
| 22 | 
            +
                    for m in self.layers:
         | 
| 23 | 
            +
                        x = m(x)
         | 
| 24 | 
            +
                    return x
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/convolution.py
    ADDED
    
    | @@ -0,0 +1,126 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            from typing import Union
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import mlx.core as mx
         | 
| 7 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class Conv1d(Module):
         | 
| 11 | 
            +
                """Applies a 1-dimensional convolution over the multi-channel input sequence.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                The channels are expected to be last i.e. the input shape should be ``NLC`` where:
         | 
| 14 | 
            +
                    - ``N`` is the batch dimension
         | 
| 15 | 
            +
                    - ``L`` is the sequence length
         | 
| 16 | 
            +
                    - ``C`` is the number of input channels
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                Args:
         | 
| 19 | 
            +
                    in_channels (int): The number of input channels
         | 
| 20 | 
            +
                    out_channels (int): The number of output channels
         | 
| 21 | 
            +
                    kernel_size (int): The size of the convolution filters
         | 
| 22 | 
            +
                    stride (int, optional): The stride when applying the filter.
         | 
| 23 | 
            +
                        Default: 1.
         | 
| 24 | 
            +
                    padding (int, optional): How many positions to 0-pad the input with.
         | 
| 25 | 
            +
                        Default: 0.
         | 
| 26 | 
            +
                    bias (bool, optional): If ``True`` add a learnable bias to the output.
         | 
| 27 | 
            +
                        Default: ``True``
         | 
| 28 | 
            +
                """
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def __init__(
         | 
| 31 | 
            +
                    self,
         | 
| 32 | 
            +
                    in_channels: int,
         | 
| 33 | 
            +
                    out_channels: int,
         | 
| 34 | 
            +
                    kernel_size: int,
         | 
| 35 | 
            +
                    stride: int = 1,
         | 
| 36 | 
            +
                    padding: int = 0,
         | 
| 37 | 
            +
                    bias: bool = True,
         | 
| 38 | 
            +
                ):
         | 
| 39 | 
            +
                    super().__init__()
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    scale = math.sqrt(1 / (in_channels * kernel_size))
         | 
| 42 | 
            +
                    self.weight = mx.random.uniform(
         | 
| 43 | 
            +
                        low=-scale,
         | 
| 44 | 
            +
                        high=scale,
         | 
| 45 | 
            +
                        shape=(out_channels, kernel_size, in_channels),
         | 
| 46 | 
            +
                    )
         | 
| 47 | 
            +
                    if bias:
         | 
| 48 | 
            +
                        self.bias = mx.zeros((out_channels,))
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                    self.padding = padding
         | 
| 51 | 
            +
                    self.stride = stride
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                def _extra_repr(self):
         | 
| 54 | 
            +
                    return (
         | 
| 55 | 
            +
                        f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
         | 
| 56 | 
            +
                        f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
         | 
| 57 | 
            +
                        f"padding={self.padding}, bias={'bias' in self}"
         | 
| 58 | 
            +
                    )
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                def __call__(self, x):
         | 
| 61 | 
            +
                    y = mx.conv1d(x, self.weight, self.stride, self.padding)
         | 
| 62 | 
            +
                    if "bias" in self:
         | 
| 63 | 
            +
                        y = y + self.bias
         | 
| 64 | 
            +
                    return y
         | 
| 65 | 
            +
             | 
| 66 | 
            +
             | 
| 67 | 
            +
            class Conv2d(Module):
         | 
| 68 | 
            +
                """Applies a 2-dimensional convolution over the multi-channel input image.
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                The channels are expected to be last i.e. the input shape should be ``NHWC`` where:
         | 
| 71 | 
            +
                    - ``N`` is the batch dimension
         | 
| 72 | 
            +
                    - ``H`` is the input image height
         | 
| 73 | 
            +
                    - ``W`` is the input image width
         | 
| 74 | 
            +
                    - ``C`` is the number of input channels
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                Args:
         | 
| 77 | 
            +
                    in_channels (int): The number of input channels.
         | 
| 78 | 
            +
                    out_channels (int): The number of output channels.
         | 
| 79 | 
            +
                    kernel_size (int or tuple): The size of the convolution filters.
         | 
| 80 | 
            +
                    stride (int or tuple, optional): The size of the stride when
         | 
| 81 | 
            +
                        applying the filter. Default: 1.
         | 
| 82 | 
            +
                    padding (int or tuple, optional): How many positions to 0-pad
         | 
| 83 | 
            +
                        the input with. Default: 0.
         | 
| 84 | 
            +
                    bias (bool, optional): If ``True`` add a learnable bias to the
         | 
| 85 | 
            +
                        output. Default: ``True``
         | 
| 86 | 
            +
                """
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                def __init__(
         | 
| 89 | 
            +
                    self,
         | 
| 90 | 
            +
                    in_channels: int,
         | 
| 91 | 
            +
                    out_channels: int,
         | 
| 92 | 
            +
                    kernel_size: Union[int, tuple],
         | 
| 93 | 
            +
                    stride: Union[int, tuple] = 1,
         | 
| 94 | 
            +
                    padding: Union[int, tuple] = 0,
         | 
| 95 | 
            +
                    bias: bool = True,
         | 
| 96 | 
            +
                ):
         | 
| 97 | 
            +
                    super().__init__()
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                    kernel_size, stride, padding = map(
         | 
| 100 | 
            +
                        lambda x: (x, x) if isinstance(x, int) else x,
         | 
| 101 | 
            +
                        (kernel_size, stride, padding),
         | 
| 102 | 
            +
                    )
         | 
| 103 | 
            +
                    scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
         | 
| 104 | 
            +
                    self.weight = mx.random.uniform(
         | 
| 105 | 
            +
                        low=-scale,
         | 
| 106 | 
            +
                        high=scale,
         | 
| 107 | 
            +
                        shape=(out_channels, *kernel_size, in_channels),
         | 
| 108 | 
            +
                    )
         | 
| 109 | 
            +
                    if bias:
         | 
| 110 | 
            +
                        self.bias = mx.zeros((out_channels,))
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                    self.padding = padding
         | 
| 113 | 
            +
                    self.stride = stride
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                def _extra_repr(self):
         | 
| 116 | 
            +
                    return (
         | 
| 117 | 
            +
                        f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
         | 
| 118 | 
            +
                        f"kernel_size={self.weight.shape[1:2]}, stride={self.stride}, "
         | 
| 119 | 
            +
                        f"padding={self.padding}, bias={'bias' in self}"
         | 
| 120 | 
            +
                    )
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                def __call__(self, x):
         | 
| 123 | 
            +
                    y = mx.conv2d(x, self.weight, self.stride, self.padding)
         | 
| 124 | 
            +
                    if "bias" in self:
         | 
| 125 | 
            +
                        y = y + self.bias
         | 
| 126 | 
            +
                    return y
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/dropout.py
    ADDED
    
    | @@ -0,0 +1,137 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import mlx.core as mx
         | 
| 4 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 5 | 
            +
             | 
| 6 | 
            +
             | 
| 7 | 
            +
            class Dropout(Module):
         | 
| 8 | 
            +
                r"""Randomly zero a portion of the elements during training.
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
         | 
| 11 | 
            +
                :math:`p` is the probability of zeroing an element. This is done so the
         | 
| 12 | 
            +
                expected value of a given element will remain the same.
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                Args:
         | 
| 15 | 
            +
                    p (float): The probability to zero an element
         | 
| 16 | 
            +
                """
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                def __init__(self, p: float = 0.5):
         | 
| 19 | 
            +
                    super().__init__()
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                    if p < 0 or p >= 1:
         | 
| 22 | 
            +
                        raise ValueError(f"The dropout probability {p} is not in [0, 1)")
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                    self._p_1 = 1 - p
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                def _extra_repr(self):
         | 
| 27 | 
            +
                    return f"p={1-self._p_1}"
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def __call__(self, x):
         | 
| 30 | 
            +
                    if self._p_1 == 1 or not self.training:
         | 
| 31 | 
            +
                        return x
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    mask = mx.random.bernoulli(self._p_1, x.shape)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                    return (1 / self._p_1) * mask * x
         | 
| 36 | 
            +
             | 
| 37 | 
            +
             | 
| 38 | 
            +
            class Dropout2d(Module):
         | 
| 39 | 
            +
                r"""Apply 2D channel-wise dropout during training.
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                Randomly zero out entire channels independently with probability :math:`p`.
         | 
| 42 | 
            +
                This layer expects the channels to be last, i.e. the input shape should be
         | 
| 43 | 
            +
                ``NWHC`` or ``WHC`` where:``N`` is the batch dimension,``H`` is the input
         | 
| 44 | 
            +
                image height,``W`` is the input image width, and``C`` is the number of
         | 
| 45 | 
            +
                input channels
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                The remaining channels are scaled by :math:`\frac{1}{1-p}` to
         | 
| 48 | 
            +
                maintain the expected value of each element. Unlike traditional dropout,
         | 
| 49 | 
            +
                which zeros individual entries, this layer zeros entire channels. This is
         | 
| 50 | 
            +
                beneficial for early convolution layers where adjacent pixels are
         | 
| 51 | 
            +
                correlated. In such case, traditional dropout may not effectively
         | 
| 52 | 
            +
                regularize activations. For more details, see [1].
         | 
| 53 | 
            +
             | 
| 54 | 
            +
                [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.
         | 
| 55 | 
            +
                Efficient Object Localization Using Convolutional Networks. CVPR 2015.
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                Args:
         | 
| 58 | 
            +
                    p (float): Probability of zeroing a channel during training.
         | 
| 59 | 
            +
                """
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                def __init__(self, p: float = 0.5):
         | 
| 62 | 
            +
                    super().__init__()
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    if p < 0 or p >= 1:
         | 
| 65 | 
            +
                        raise ValueError(f"The dropout probability {p} is not in [0, 1)")
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    self._p_1 = 1 - p
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                def _extra_repr(self):
         | 
| 70 | 
            +
                    return f"p={1-self._p_1}"
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                def __call__(self, x):
         | 
| 73 | 
            +
                    if x.ndim not in (3, 4):
         | 
| 74 | 
            +
                        raise ValueError(
         | 
| 75 | 
            +
                            f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
         | 
| 76 | 
            +
                        )
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    if self._p_1 == 1 or not self.training:
         | 
| 79 | 
            +
                        return x
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    # Dropout is applied on the whole channel
         | 
| 82 | 
            +
                    # 3D input: (1, 1, C)
         | 
| 83 | 
            +
                    # 4D input: (B, 1, 1, C)
         | 
| 84 | 
            +
                    mask_shape = x.shape
         | 
| 85 | 
            +
                    mask_shape[-2] = mask_shape[-3] = 1
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
         | 
| 88 | 
            +
                    return (1 / self._p_1) * mask * x
         | 
| 89 | 
            +
             | 
| 90 | 
            +
             | 
| 91 | 
            +
            class Dropout3d(Module):
         | 
| 92 | 
            +
                r"""Apply 3D channel-wise dropout during training.
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                Randomly zero out entire channels independently with probability :math:`p`.
         | 
| 95 | 
            +
                This layer expects the channels to be last, i.e., the input shape should be
         | 
| 96 | 
            +
                `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
         | 
| 97 | 
            +
                `H` is the input image height, `W` is the input image width, and `C` is
         | 
| 98 | 
            +
                the number of input channels.
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                The remaining channels are scaled by :math:`\frac{1}{1-p}` to
         | 
| 101 | 
            +
                maintain the expected value of each element. Unlike traditional dropout,
         | 
| 102 | 
            +
                which zeros individual entries, this layer zeros entire channels. This is
         | 
| 103 | 
            +
                often beneficial for convolutional layers processing 3D data, like in
         | 
| 104 | 
            +
                medical imaging or video processing.
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                Args:
         | 
| 107 | 
            +
                    p (float): Probability of zeroing a channel during training.
         | 
| 108 | 
            +
                """
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                def __init__(self, p: float = 0.5):
         | 
| 111 | 
            +
                    super().__init__()
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                    if p < 0 or p >= 1:
         | 
| 114 | 
            +
                        raise ValueError(f"The dropout probability {p} is not in [0, 1)")
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                    self._p_1 = 1 - p
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                def _extra_repr(self):
         | 
| 119 | 
            +
                    return f"p={1-self._p_1}"
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                def __call__(self, x):
         | 
| 122 | 
            +
                    if x.ndim not in (4, 5):
         | 
| 123 | 
            +
                        raise ValueError(
         | 
| 124 | 
            +
                            f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
         | 
| 125 | 
            +
                        )
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    if self._p_1 == 1 or not self.training:
         | 
| 128 | 
            +
                        return x
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                    # Dropout is applied on the whole channel
         | 
| 131 | 
            +
                    # 4D input: (1, 1, 1, C)
         | 
| 132 | 
            +
                    # 5D input: (B, 1, 1, 1, C)
         | 
| 133 | 
            +
                    mask_shape = list(x.shape)
         | 
| 134 | 
            +
                    mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
         | 
| 137 | 
            +
                    return (1 / self._p_1) * mask * x
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/embedding.py
    ADDED
    
    | @@ -0,0 +1,30 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import mlx.core as mx
         | 
| 6 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class Embedding(Module):
         | 
| 10 | 
            +
                """Implements a simple lookup table that maps each input integer to a
         | 
| 11 | 
            +
                high-dimensional vector.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                Typically used to embed discrete tokens for processing by neural networks.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                Args:
         | 
| 16 | 
            +
                    num_embeddings (int): How many possible discrete tokens can we embed.
         | 
| 17 | 
            +
                                          Usually called the vocabulary size.
         | 
| 18 | 
            +
                    dims (int): The dimensionality of the embeddings.
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def __init__(self, num_embeddings: int, dims: int):
         | 
| 22 | 
            +
                    super().__init__()
         | 
| 23 | 
            +
                    scale = math.sqrt(1 / dims)
         | 
| 24 | 
            +
                    self.weight = mx.random.normal((num_embeddings, dims)) * scale
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                def _extra_repr(self):
         | 
| 27 | 
            +
                    return f"{self.weight.shape[0]}, {self.weight.shape[1]}"
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def __call__(self, x):
         | 
| 30 | 
            +
                    return self.weight[x]
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/linear.py
    ADDED
    
    | @@ -0,0 +1,141 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            from typing import Any
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import mlx.core as mx
         | 
| 7 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class Identity(Module):
         | 
| 11 | 
            +
                r"""A placeholder identity operator that is argument-insensitive.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                Args:
         | 
| 14 | 
            +
                    args: any argument (unused)
         | 
| 15 | 
            +
                    kwargs: any keyword argument (unused)
         | 
| 16 | 
            +
                """
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                def __init__(self, *args: Any, **kwargs: Any) -> None:
         | 
| 19 | 
            +
                    super().__init__()
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def __call__(self, x: mx.array) -> mx.array:
         | 
| 22 | 
            +
                    return x
         | 
| 23 | 
            +
             | 
| 24 | 
            +
             | 
| 25 | 
            +
            class Linear(Module):
         | 
| 26 | 
            +
                r"""Applies an affine transformation to the input.
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                Concretely:
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                .. math::
         | 
| 31 | 
            +
             | 
| 32 | 
            +
                    y = x W^\top + b
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                where:
         | 
| 35 | 
            +
                where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``.
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
         | 
| 38 | 
            +
                where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.
         | 
| 39 | 
            +
             | 
| 40 | 
            +
                Args:
         | 
| 41 | 
            +
                    input_dims (int): The dimensionality of the input features
         | 
| 42 | 
            +
                    output_dims (int): The dimensionality of the output features
         | 
| 43 | 
            +
                    bias (bool, optional): If set to ``False`` then the layer will
         | 
| 44 | 
            +
                      not use a bias. Default is ``True``.
         | 
| 45 | 
            +
                """
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
         | 
| 48 | 
            +
                    super().__init__()
         | 
| 49 | 
            +
                    scale = math.sqrt(1.0 / input_dims)
         | 
| 50 | 
            +
                    self.weight = mx.random.uniform(
         | 
| 51 | 
            +
                        low=-scale,
         | 
| 52 | 
            +
                        high=scale,
         | 
| 53 | 
            +
                        shape=(output_dims, input_dims),
         | 
| 54 | 
            +
                    )
         | 
| 55 | 
            +
                    if bias:
         | 
| 56 | 
            +
                        self.bias = mx.random.uniform(
         | 
| 57 | 
            +
                            low=-scale,
         | 
| 58 | 
            +
                            high=scale,
         | 
| 59 | 
            +
                            shape=(output_dims,),
         | 
| 60 | 
            +
                        )
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                def _extra_repr(self) -> str:
         | 
| 63 | 
            +
                    return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                def __call__(self, x: mx.array) -> mx.array:
         | 
| 66 | 
            +
                    x = x @ self.weight.T
         | 
| 67 | 
            +
                    if "bias" in self:
         | 
| 68 | 
            +
                        x = x + self.bias
         | 
| 69 | 
            +
                    return x
         | 
| 70 | 
            +
             | 
| 71 | 
            +
             | 
| 72 | 
            +
            class Bilinear(Module):
         | 
| 73 | 
            +
                r"""Applies a bilinear transformation to the inputs.
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                Concretely:
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                .. math::
         | 
| 78 | 
            +
             | 
| 79 | 
            +
                    y_i = x_1^\top W_i x_2 + b_i
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                where:
         | 
| 82 | 
            +
                :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims ]``,
         | 
| 83 | 
            +
                and :math:`i` indexes the output dimension.
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
         | 
| 86 | 
            +
                where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                Args:
         | 
| 89 | 
            +
                    input1_dims (int): The dimensionality of the input1 features
         | 
| 90 | 
            +
                    input2_dims (int): The dimensionality of the input2 features
         | 
| 91 | 
            +
                    output_dims (int): The dimensionality of the output features
         | 
| 92 | 
            +
                    bias (bool, optional): If set to ``False`` then the layer will
         | 
| 93 | 
            +
                      not use a bias. Default is ``True``.
         | 
| 94 | 
            +
                """
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def __init__(
         | 
| 97 | 
            +
                    self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
         | 
| 98 | 
            +
                ) -> None:
         | 
| 99 | 
            +
                    super().__init__()
         | 
| 100 | 
            +
                    scale = math.sqrt(1.0 / input1_dims)
         | 
| 101 | 
            +
                    self.weight = mx.random.uniform(
         | 
| 102 | 
            +
                        low=-scale,
         | 
| 103 | 
            +
                        high=scale,
         | 
| 104 | 
            +
                        shape=(output_dims, input2_dims, input1_dims),
         | 
| 105 | 
            +
                    )
         | 
| 106 | 
            +
                    if bias:
         | 
| 107 | 
            +
                        self.bias = mx.random.uniform(
         | 
| 108 | 
            +
                            low=-scale,
         | 
| 109 | 
            +
                            high=scale,
         | 
| 110 | 
            +
                            shape=(output_dims,),
         | 
| 111 | 
            +
                        )
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                def _extra_repr(self) -> str:
         | 
| 114 | 
            +
                    out, in2, in1 = self.weight.shape
         | 
| 115 | 
            +
                    return (
         | 
| 116 | 
            +
                        f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
         | 
| 117 | 
            +
                        f"bias={'bias' in self}"
         | 
| 118 | 
            +
                    )
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
         | 
| 121 | 
            +
                    # Normalize shapes
         | 
| 122 | 
            +
                    out, in2, in1 = self.weight.shape
         | 
| 123 | 
            +
                    xshape = x1.shape[:-1]
         | 
| 124 | 
            +
                    x1 = x1.reshape(-1, in1)
         | 
| 125 | 
            +
                    x2 = x2.reshape(-1, 1, in2)
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                    # Perform the bilinear transformation
         | 
| 128 | 
            +
                    w = self.weight.reshape(out * in2, in1)
         | 
| 129 | 
            +
                    y = x1 @ w.T
         | 
| 130 | 
            +
                    y = y.reshape(-1, out, in2).swapaxes(-2, -1)
         | 
| 131 | 
            +
                    y = x2 @ y
         | 
| 132 | 
            +
                    y = y.squeeze(1)
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                    # Reset the shape
         | 
| 135 | 
            +
                    y = y.reshape(*xshape, out)
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                    # Apply the bias
         | 
| 138 | 
            +
                    if "bias" in self:
         | 
| 139 | 
            +
                        y = y + self.bias
         | 
| 140 | 
            +
             | 
| 141 | 
            +
                    return y
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/normalization.py
    ADDED
    
    | @@ -0,0 +1,368 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from typing import Tuple
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import mlx.core as mx
         | 
| 6 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class InstanceNorm(Module):
         | 
| 10 | 
            +
                r"""Applies instance normalization [1] on the inputs.
         | 
| 11 | 
            +
             | 
| 12 | 
            +
                Computes
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                .. math::
         | 
| 15 | 
            +
             | 
| 16 | 
            +
                    y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                where :math:`\gamma` and :math:`\beta` are learned per feature dimension
         | 
| 19 | 
            +
                parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,
         | 
| 20 | 
            +
                if :attr:`affine` is ``True``.
         | 
| 21 | 
            +
             | 
| 22 | 
            +
                Args:
         | 
| 23 | 
            +
                    dims (int): The number of features of the input.
         | 
| 24 | 
            +
                    eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
         | 
| 25 | 
            +
                    affine (bool): Default: ``False``.
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                Shape:
         | 
| 28 | 
            +
                  - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
         | 
| 29 | 
            +
                  - Output: Same shape as the input.
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                Examples:
         | 
| 32 | 
            +
                    >>> import mlx.core as mx
         | 
| 33 | 
            +
                    >>> import mlx.nn as nn
         | 
| 34 | 
            +
                    >>> x = mx.random.normal((8, 4, 4, 16))
         | 
| 35 | 
            +
                    >>> inorm = nn.InstanceNorm(dims=16)
         | 
| 36 | 
            +
                    >>> output = inorm(x)
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                References:
         | 
| 39 | 
            +
                    [1]: https://arxiv.org/abs/1607.08022
         | 
| 40 | 
            +
                """
         | 
| 41 | 
            +
             | 
| 42 | 
            +
                def __init__(
         | 
| 43 | 
            +
                    self,
         | 
| 44 | 
            +
                    dims: int,
         | 
| 45 | 
            +
                    eps: float = 1e-5,
         | 
| 46 | 
            +
                    affine: bool = False,
         | 
| 47 | 
            +
                ):
         | 
| 48 | 
            +
                    super().__init__()
         | 
| 49 | 
            +
                    if affine:
         | 
| 50 | 
            +
                        self.weight = mx.ones((dims,))
         | 
| 51 | 
            +
                        self.bias = mx.zeros((dims,))
         | 
| 52 | 
            +
                    self.dims = dims
         | 
| 53 | 
            +
                    self.eps = eps
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def _extra_repr(self):
         | 
| 56 | 
            +
                    return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                def __call__(self, x: mx.array) -> mx.array:
         | 
| 59 | 
            +
                    reduction_axes = tuple(range(1, x.ndim - 1))
         | 
| 60 | 
            +
                    # Compute stats
         | 
| 61 | 
            +
                    mean = mx.mean(x, axis=reduction_axes, keepdims=True)
         | 
| 62 | 
            +
                    var = mx.var(x, axis=reduction_axes, keepdims=True)
         | 
| 63 | 
            +
                    # Normalize
         | 
| 64 | 
            +
                    x = (x - mean) * mx.rsqrt(var + self.eps)
         | 
| 65 | 
            +
                    # Scale and shift if necessary
         | 
| 66 | 
            +
                    return (self.weight * x + self.bias) if "weight" in self else x
         | 
| 67 | 
            +
             | 
| 68 | 
            +
             | 
| 69 | 
            +
            class LayerNorm(Module):
         | 
| 70 | 
            +
                r"""Applies layer normalization [1] on the inputs.
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                Computes
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                .. math::
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                    y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                where :math:`\gamma` and :math:`\beta` are learned per feature dimension
         | 
| 79 | 
            +
                parameters initialized at 1 and 0 respectively.
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                [1]: https://arxiv.org/abs/1607.06450
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                Args:
         | 
| 84 | 
            +
                    dims (int): The feature dimension of the input to normalize over
         | 
| 85 | 
            +
                    eps (float): A small additive constant for numerical stability
         | 
| 86 | 
            +
                    affine (bool): If True learn an affine transform to apply after the
         | 
| 87 | 
            +
                        normalization
         | 
| 88 | 
            +
                """
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
         | 
| 91 | 
            +
                    super().__init__()
         | 
| 92 | 
            +
                    if affine:
         | 
| 93 | 
            +
                        self.bias = mx.zeros((dims,))
         | 
| 94 | 
            +
                        self.weight = mx.ones((dims,))
         | 
| 95 | 
            +
                    self.eps = eps
         | 
| 96 | 
            +
                    self.dims = dims
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                def _extra_repr(self):
         | 
| 99 | 
            +
                    return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
         | 
| 100 | 
            +
             | 
| 101 | 
            +
                def __call__(self, x):
         | 
| 102 | 
            +
                    means = mx.mean(x, axis=-1, keepdims=True)
         | 
| 103 | 
            +
                    var = mx.var(x, axis=-1, keepdims=True)
         | 
| 104 | 
            +
                    x = (x - means) * mx.rsqrt(var + self.eps)
         | 
| 105 | 
            +
                    return (self.weight * x + self.bias) if "weight" in self else x
         | 
| 106 | 
            +
             | 
| 107 | 
            +
             | 
| 108 | 
            +
            class RMSNorm(Module):
         | 
| 109 | 
            +
                r"""Applies Root Mean Square normalization [1] to the inputs.
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                Computes
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                ..  math::
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                    y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma
         | 
| 116 | 
            +
             | 
| 117 | 
            +
                where :math:`\gamma` is a learned per feature dimension parameter initialized at
         | 
| 118 | 
            +
                1.
         | 
| 119 | 
            +
             | 
| 120 | 
            +
                [1]: https://arxiv.org/abs/1910.07467
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                Args:
         | 
| 123 | 
            +
                    dims (int): The feature dimension of the input to normalize over
         | 
| 124 | 
            +
                    eps (float): A small additive constant for numerical stability
         | 
| 125 | 
            +
                """
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                def __init__(self, dims: int, eps: float = 1e-5):
         | 
| 128 | 
            +
                    super().__init__()
         | 
| 129 | 
            +
                    self.weight = mx.ones((dims,))
         | 
| 130 | 
            +
                    self.eps = eps
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                def _extra_repr(self):
         | 
| 133 | 
            +
                    return f"{self.weight.shape[0]}, eps={self.eps}"
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                def __call__(self, x):
         | 
| 136 | 
            +
                    # S is 1/sqrt(N) where N is the size of the features of x and is used
         | 
| 137 | 
            +
                    # to compute a numerically more stable RMS of x by multiplying with S
         | 
| 138 | 
            +
                    # first and summing.
         | 
| 139 | 
            +
                    #
         | 
| 140 | 
            +
                    # This way we prefer underflow over overflow which is controlled with
         | 
| 141 | 
            +
                    # the parameter epsilon anyway.
         | 
| 142 | 
            +
                    S = 1 / x.shape[-1] ** 0.5
         | 
| 143 | 
            +
             | 
| 144 | 
            +
                    n = (x * S).square().sum(axis=-1, keepdims=True)
         | 
| 145 | 
            +
                    n = mx.rsqrt(n + self.eps)
         | 
| 146 | 
            +
             | 
| 147 | 
            +
                    return self.weight * x * n
         | 
| 148 | 
            +
             | 
| 149 | 
            +
             | 
| 150 | 
            +
            class GroupNorm(Module):
         | 
| 151 | 
            +
                r"""Applies Group Normalization [1] to the inputs.
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                Computes the same normalization as layer norm, namely
         | 
| 154 | 
            +
             | 
| 155 | 
            +
                .. math::
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                where :math:`\gamma` and :math:`\beta` are learned per feature dimension
         | 
| 160 | 
            +
                parameters initialized at 1 and 0 respectively. However, the mean and
         | 
| 161 | 
            +
                variance are computed over the spatial dimensions and each group of
         | 
| 162 | 
            +
                features. In particular, the input is split into num_groups across the
         | 
| 163 | 
            +
                feature dimension.
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                The feature dimension is assumed to be the last dimension and the dimensions
         | 
| 166 | 
            +
                that precede it (except the first) are considered the spatial dimensions.
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                [1]: https://arxiv.org/abs/1803.08494
         | 
| 169 | 
            +
             | 
| 170 | 
            +
                Args:
         | 
| 171 | 
            +
                    num_groups (int): Number of groups to separate the features into
         | 
| 172 | 
            +
                    dims (int): The feature dimensions of the input to normalize over
         | 
| 173 | 
            +
                    eps (float): A small additive constant for numerical stability
         | 
| 174 | 
            +
                    affine (bool): If True learn an affine transform to apply after the
         | 
| 175 | 
            +
                        normalization.
         | 
| 176 | 
            +
                    pytorch_compatible (bool): If True perform the group normalization in
         | 
| 177 | 
            +
                        the same order/grouping as PyTorch.
         | 
| 178 | 
            +
                """
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def __init__(
         | 
| 181 | 
            +
                    self,
         | 
| 182 | 
            +
                    num_groups: int,
         | 
| 183 | 
            +
                    dims: int,
         | 
| 184 | 
            +
                    eps: float = 1e-5,
         | 
| 185 | 
            +
                    affine: bool = True,
         | 
| 186 | 
            +
                    pytorch_compatible: bool = False,
         | 
| 187 | 
            +
                ):
         | 
| 188 | 
            +
                    super().__init__()
         | 
| 189 | 
            +
                    if affine:
         | 
| 190 | 
            +
                        self.bias = mx.zeros((dims,))
         | 
| 191 | 
            +
                        self.weight = mx.ones((dims,))
         | 
| 192 | 
            +
                    self.num_groups = num_groups
         | 
| 193 | 
            +
                    self.dims = dims
         | 
| 194 | 
            +
                    self.eps = eps
         | 
| 195 | 
            +
                    self.pytorch_compatible = pytorch_compatible
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                def _extra_repr(self):
         | 
| 198 | 
            +
                    return (
         | 
| 199 | 
            +
                        f"{self.num_groups}, {self.dims}, eps={self.eps}, "
         | 
| 200 | 
            +
                        f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
         | 
| 201 | 
            +
                    )
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                def _pytorch_compatible_group_norm(self, x):
         | 
| 204 | 
            +
                    num_groups = self.num_groups
         | 
| 205 | 
            +
                    batch, *rest, dims = x.shape
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                    # Split into groups
         | 
| 208 | 
            +
                    x = x.reshape(batch, -1, num_groups, dims // num_groups)
         | 
| 209 | 
            +
                    x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                    # Normalize
         | 
| 212 | 
            +
                    means = mx.mean(x, axis=1, keepdims=True)
         | 
| 213 | 
            +
                    var = mx.var(x, axis=1, keepdims=True)
         | 
| 214 | 
            +
                    x = (x - means) * mx.rsqrt(var + self.eps)
         | 
| 215 | 
            +
                    x = x.reshape(batch, -1, dims // num_groups, num_groups)
         | 
| 216 | 
            +
                    x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                    return x
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                def _group_norm(self, x):
         | 
| 221 | 
            +
                    num_groups = self.num_groups
         | 
| 222 | 
            +
                    batch, *rest, dims = x.shape
         | 
| 223 | 
            +
             | 
| 224 | 
            +
                    # Split into groups
         | 
| 225 | 
            +
                    x = x.reshape(batch, -1, num_groups)
         | 
| 226 | 
            +
             | 
| 227 | 
            +
                    # Normalize
         | 
| 228 | 
            +
                    means = mx.mean(x, axis=1, keepdims=True)
         | 
| 229 | 
            +
                    var = mx.var(x, axis=1, keepdims=True)
         | 
| 230 | 
            +
                    x = (x - means) * mx.rsqrt(var + self.eps)
         | 
| 231 | 
            +
                    x = x.reshape(batch, *rest, dims)
         | 
| 232 | 
            +
             | 
| 233 | 
            +
                    return x
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                def __call__(self, x):
         | 
| 236 | 
            +
                    group_norm = (
         | 
| 237 | 
            +
                        self._pytorch_compatible_group_norm
         | 
| 238 | 
            +
                        if self.pytorch_compatible
         | 
| 239 | 
            +
                        else self._group_norm
         | 
| 240 | 
            +
                    )
         | 
| 241 | 
            +
                    x = group_norm(x)
         | 
| 242 | 
            +
                    return (self.weight * x + self.bias) if "weight" in self else x
         | 
| 243 | 
            +
             | 
| 244 | 
            +
             | 
| 245 | 
            +
            class BatchNorm(Module):
         | 
| 246 | 
            +
                r"""Applies Batch Normalization over a 2D or 3D input.
         | 
| 247 | 
            +
             | 
| 248 | 
            +
                Computes
         | 
| 249 | 
            +
             | 
| 250 | 
            +
                .. math::
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                    y = \frac{x - E[x]}{\sqrt{Var[x]} + \epsilon} \gamma + \beta,
         | 
| 253 | 
            +
             | 
| 254 | 
            +
                where :math:`\gamma` and :math:`\beta` are learned per feature dimension
         | 
| 255 | 
            +
                parameters initialized at 1 and 0 respectively.
         | 
| 256 | 
            +
             | 
| 257 | 
            +
                The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
         | 
| 258 | 
            +
                batch, ``C`` is the number of features or channels, and ``L`` is the
         | 
| 259 | 
            +
                sequence length. The output has the same shape as the input. For
         | 
| 260 | 
            +
                four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
         | 
| 261 | 
            +
                the height and width respectively.
         | 
| 262 | 
            +
             | 
| 263 | 
            +
                For more information on Batch Normalization, see the original paper `Batch
         | 
| 264 | 
            +
                Normalization: Accelerating Deep Network Training by Reducing Internal
         | 
| 265 | 
            +
                Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                Args:
         | 
| 268 | 
            +
                    num_features (int): The feature dimension to normalize over.
         | 
| 269 | 
            +
                    eps (float, optional): A small additive constant for numerical
         | 
| 270 | 
            +
                        stability. Default: ``1e-5``.
         | 
| 271 | 
            +
                    momentum (float, optional): The momentum for updating the running
         | 
| 272 | 
            +
                        mean and variance. Default: ``0.1``.
         | 
| 273 | 
            +
                    affine (bool, optional): If ``True``, apply a learned affine
         | 
| 274 | 
            +
                        transformation after the normalization. Default: ``True``.
         | 
| 275 | 
            +
                    track_running_stats (bool, optional): If ``True``, track the
         | 
| 276 | 
            +
                        running mean and variance. Default: ``True``.
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                Examples:
         | 
| 279 | 
            +
                    >>> import mlx.core as mx
         | 
| 280 | 
            +
                    >>> import mlx.nn as nn
         | 
| 281 | 
            +
                    >>> x = mx.random.normal((5, 4))
         | 
| 282 | 
            +
                    >>> bn = nn.BatchNorm(num_features=4, affine=True)
         | 
| 283 | 
            +
                    >>> output = bn(x)
         | 
| 284 | 
            +
                """
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                def __init__(
         | 
| 287 | 
            +
                    self,
         | 
| 288 | 
            +
                    num_features: int,
         | 
| 289 | 
            +
                    eps: float = 1e-5,
         | 
| 290 | 
            +
                    momentum: float = 0.1,
         | 
| 291 | 
            +
                    affine: bool = True,
         | 
| 292 | 
            +
                    track_running_stats: bool = True,
         | 
| 293 | 
            +
                ):
         | 
| 294 | 
            +
                    super().__init__()
         | 
| 295 | 
            +
             | 
| 296 | 
            +
                    self.num_features = num_features
         | 
| 297 | 
            +
                    self.eps = eps
         | 
| 298 | 
            +
                    self.momentum = momentum
         | 
| 299 | 
            +
                    self.track_running_stats = track_running_stats
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                    if affine:
         | 
| 302 | 
            +
                        self.weight = mx.ones((num_features,))
         | 
| 303 | 
            +
                        self.bias = mx.zeros((num_features,))
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                    if self.track_running_stats:
         | 
| 306 | 
            +
                        self.running_mean = mx.zeros((num_features,))
         | 
| 307 | 
            +
                        self.running_var = mx.ones((num_features,))
         | 
| 308 | 
            +
                        self.freeze(keys=["running_mean", "running_var"], recurse=False)
         | 
| 309 | 
            +
             | 
| 310 | 
            +
                def unfreeze(self, *args, **kwargs):
         | 
| 311 | 
            +
                    """Wrap unfreeze to make sure that running_mean and var are always
         | 
| 312 | 
            +
                    frozen parameters."""
         | 
| 313 | 
            +
                    super().unfreeze(*args, **kwargs)
         | 
| 314 | 
            +
                    self.freeze(keys=["running_mean", "running_var"], recurse=False)
         | 
| 315 | 
            +
             | 
| 316 | 
            +
                def _extra_repr(self):
         | 
| 317 | 
            +
                    return (
         | 
| 318 | 
            +
                        f"{self.num_features}, eps={self.eps}, "
         | 
| 319 | 
            +
                        f"momentum={self.momentum}, affine={'weight' in self}, "
         | 
| 320 | 
            +
                        f"track_running_stats={self.track_running_stats}"
         | 
| 321 | 
            +
                    )
         | 
| 322 | 
            +
             | 
| 323 | 
            +
                def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
         | 
| 324 | 
            +
                    """
         | 
| 325 | 
            +
                    Calculate the mean and variance of the input tensor across the batch
         | 
| 326 | 
            +
                    and spatial dimensions.
         | 
| 327 | 
            +
             | 
| 328 | 
            +
                    Args:
         | 
| 329 | 
            +
                        x (array): Input tensor.
         | 
| 330 | 
            +
             | 
| 331 | 
            +
                    Returns:
         | 
| 332 | 
            +
                        tuple: Tuple containing mean and variance.
         | 
| 333 | 
            +
                    """
         | 
| 334 | 
            +
                    reduction_axes = tuple(range(0, x.ndim - 1))
         | 
| 335 | 
            +
             | 
| 336 | 
            +
                    mean = mx.mean(x, axis=reduction_axes, keepdims=True)
         | 
| 337 | 
            +
                    var = mx.var(x, axis=reduction_axes, keepdims=True)
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                    return mean, var
         | 
| 340 | 
            +
             | 
| 341 | 
            +
                def __call__(self, x: mx.array) -> mx.array:
         | 
| 342 | 
            +
                    """
         | 
| 343 | 
            +
                    Forward pass of BatchNorm.
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    Args:
         | 
| 346 | 
            +
                        x (array): Input tensor.
         | 
| 347 | 
            +
             | 
| 348 | 
            +
                    Returns:
         | 
| 349 | 
            +
                        array: Normalized output tensor.
         | 
| 350 | 
            +
                    """
         | 
| 351 | 
            +
                    if x.ndim < 2 or x.ndim > 4:
         | 
| 352 | 
            +
                        raise ValueError(
         | 
| 353 | 
            +
                            f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
         | 
| 354 | 
            +
                        )
         | 
| 355 | 
            +
             | 
| 356 | 
            +
                    # Calculate the mean and variance used to normalize the input x. If we
         | 
| 357 | 
            +
                    # are in training mode update the running stats if needed.
         | 
| 358 | 
            +
                    mean, var = self._calc_stats(x)
         | 
| 359 | 
            +
                    if self.training and self.track_running_stats:
         | 
| 360 | 
            +
                        mu = self.momentum
         | 
| 361 | 
            +
                        self.running_mean = (1 - mu) * self.running_mean + mu * mean
         | 
| 362 | 
            +
                        self.running_var = (1 - mu) * self.running_var + mu * var
         | 
| 363 | 
            +
                    elif self.track_running_stats:
         | 
| 364 | 
            +
                        mean = self.running_mean
         | 
| 365 | 
            +
                        var = self.running_var
         | 
| 366 | 
            +
             | 
| 367 | 
            +
                    x = (x - mean) * mx.rsqrt(var + self.eps)
         | 
| 368 | 
            +
                    return (self.weight * x + self.bias) if "weight" in self else x
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/positional_encoding.py
    ADDED
    
    | @@ -0,0 +1,199 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            from typing import Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import mlx.core as mx
         | 
| 7 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class RoPE(Module):
         | 
| 11 | 
            +
                """Implements the rotary positional encoding.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                The traditional implementation rotates consecutive pairs of elements in the
         | 
| 14 | 
            +
                feature dimension while the default implementation rotates pairs with
         | 
| 15 | 
            +
                stride half the feature dimensions for efficiency.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                For more details see `RoFormer: Enhanced Transformer with Rotary Position
         | 
| 18 | 
            +
                Embedding <https://arxiv.org/abs/2104.09864>`_.
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                Args:
         | 
| 21 | 
            +
                    dims (int): The feature dimensions to be rotated. If the input feature
         | 
| 22 | 
            +
                        is larger than dims then the rest is left unchanged.
         | 
| 23 | 
            +
                    traditional (bool, optional): If set to True choose the traditional
         | 
| 24 | 
            +
                        implementation which is slightly less efficient. Default: ``False``.
         | 
| 25 | 
            +
                    base (float, optional): The base used to compute angular frequency for
         | 
| 26 | 
            +
                        each dimension in the positional encodings. Default: ``10000``.
         | 
| 27 | 
            +
                    scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
         | 
| 28 | 
            +
                """
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                def __init__(
         | 
| 31 | 
            +
                    self,
         | 
| 32 | 
            +
                    dims: int,
         | 
| 33 | 
            +
                    traditional: bool = False,
         | 
| 34 | 
            +
                    base: float = 10000,
         | 
| 35 | 
            +
                    scale: float = 1.0,
         | 
| 36 | 
            +
                ):
         | 
| 37 | 
            +
                    super().__init__()
         | 
| 38 | 
            +
                    self.dims = dims
         | 
| 39 | 
            +
                    self.traditional = traditional
         | 
| 40 | 
            +
                    self.base = base
         | 
| 41 | 
            +
                    self.scale = scale
         | 
| 42 | 
            +
             | 
| 43 | 
            +
                def _extra_repr(self):
         | 
| 44 | 
            +
                    return f"{self.dims}, traditional={self.traditional}"
         | 
| 45 | 
            +
             | 
| 46 | 
            +
                def _compute_rope(self, costheta, sintheta, x):
         | 
| 47 | 
            +
                    x1 = x[..., : self.dims // 2]
         | 
| 48 | 
            +
                    x2 = x[..., self.dims // 2 : self.dims]
         | 
| 49 | 
            +
                    rx1 = x1 * costheta - x2 * sintheta
         | 
| 50 | 
            +
                    rx2 = x1 * sintheta + x2 * costheta
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                    if self.dims < x.shape[-1]:
         | 
| 53 | 
            +
                        rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
         | 
| 54 | 
            +
                    else:
         | 
| 55 | 
            +
                        rx = mx.concatenate([rx1, rx2], axis=-1)
         | 
| 56 | 
            +
             | 
| 57 | 
            +
                    return rx
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                def _compute_traditional_rope(self, costheta, sintheta, x):
         | 
| 60 | 
            +
                    x1 = x[..., ::2]
         | 
| 61 | 
            +
                    x2 = x[..., 1::2]
         | 
| 62 | 
            +
                    rx1 = x1 * costheta - x2 * sintheta
         | 
| 63 | 
            +
                    rx2 = x1 * sintheta + x2 * costheta
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    if self.dims < x.shape[-1]:
         | 
| 66 | 
            +
                        raise NotImplementedError(
         | 
| 67 | 
            +
                            "RoPE doesn't implement partial traditional application"
         | 
| 68 | 
            +
                        )
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    return rx
         | 
| 73 | 
            +
             | 
| 74 | 
            +
                def __call__(self, x, offset: int = 0):
         | 
| 75 | 
            +
                    shape = x.shape
         | 
| 76 | 
            +
                    x = mx.reshape(x, (-1, shape[-2], shape[-1]))
         | 
| 77 | 
            +
                    N = x.shape[1] + offset
         | 
| 78 | 
            +
                    costheta, sintheta = RoPE.create_cos_sin_theta(
         | 
| 79 | 
            +
                        N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
         | 
| 80 | 
            +
                    )
         | 
| 81 | 
            +
             | 
| 82 | 
            +
                    rope = (
         | 
| 83 | 
            +
                        self._compute_traditional_rope if self.traditional else self._compute_rope
         | 
| 84 | 
            +
                    )
         | 
| 85 | 
            +
                    rx = rope(costheta, sintheta, x)
         | 
| 86 | 
            +
             | 
| 87 | 
            +
                    return mx.reshape(rx, shape)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
                @staticmethod
         | 
| 90 | 
            +
                def create_cos_sin_theta(
         | 
| 91 | 
            +
                    N: int,
         | 
| 92 | 
            +
                    D: int,
         | 
| 93 | 
            +
                    offset: int = 0,
         | 
| 94 | 
            +
                    base: float = 10000,
         | 
| 95 | 
            +
                    scale: float = 1.0,
         | 
| 96 | 
            +
                    dtype=mx.float32,
         | 
| 97 | 
            +
                ):
         | 
| 98 | 
            +
                    D = D // 2
         | 
| 99 | 
            +
                    positions = mx.arange(offset, N, dtype=dtype) * scale
         | 
| 100 | 
            +
                    freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
         | 
| 101 | 
            +
                    theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
         | 
| 102 | 
            +
                    return mx.cos(theta), mx.sin(theta)
         | 
| 103 | 
            +
             | 
| 104 | 
            +
             | 
| 105 | 
            +
            class SinusoidalPositionalEncoding(Module):
         | 
| 106 | 
            +
                r"""Implements sinusoidal positional encoding.
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                For more details see the paper `Attention Is All You Need
         | 
| 109 | 
            +
                <https://arxiv.org/abs/1706.03762>`_.
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                Args:
         | 
| 112 | 
            +
                    dims (int): The dimensionality of the resulting positional embeddings.
         | 
| 113 | 
            +
                    min_freq (float, optional): The minimum frequency expected. Default:
         | 
| 114 | 
            +
                        ``0.0001``.
         | 
| 115 | 
            +
                    max_freq (float, optional): The maximum frequency expected. Default:
         | 
| 116 | 
            +
                        ``1``.
         | 
| 117 | 
            +
                    scale (float, optional): A multiplicative scale for the embeddings.
         | 
| 118 | 
            +
                        Default: ``sqrt(dims//2)``.
         | 
| 119 | 
            +
                    cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
         | 
| 120 | 
            +
                        instead of the reverse. Default: ``False``.
         | 
| 121 | 
            +
                    full_turns (bool, optional): If ``True`` multiply the frequencies with
         | 
| 122 | 
            +
                        :math:`2\pi`. Default: ``False``.
         | 
| 123 | 
            +
                """
         | 
| 124 | 
            +
             | 
| 125 | 
            +
                def __init__(
         | 
| 126 | 
            +
                    self,
         | 
| 127 | 
            +
                    dims: int,
         | 
| 128 | 
            +
                    min_freq: float = 0.0001,
         | 
| 129 | 
            +
                    max_freq: float = 1,
         | 
| 130 | 
            +
                    scale: Optional[float] = None,
         | 
| 131 | 
            +
                    cos_first: bool = False,
         | 
| 132 | 
            +
                    full_turns: bool = False,
         | 
| 133 | 
            +
                ):
         | 
| 134 | 
            +
                    super().__init__()
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                    one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
         | 
| 137 | 
            +
                    min_freq = math.log(min_freq)
         | 
| 138 | 
            +
                    max_freq = math.log(max_freq)
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                    # Start with underscore so it is not included in the parameters
         | 
| 141 | 
            +
                    self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
         | 
| 142 | 
            +
                    if full_turns:
         | 
| 143 | 
            +
                        self._sigmas = self._sigmas * (2 * math.pi)
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    # Save some constants that define the implementation
         | 
| 146 | 
            +
                    self.scale = scale or (2 / dims) ** 0.5
         | 
| 147 | 
            +
                    self.cos_first = cos_first
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                def __call__(self, x):
         | 
| 150 | 
            +
                    y = x[..., None] * self._sigmas
         | 
| 151 | 
            +
                    cosy = mx.cos(y)
         | 
| 152 | 
            +
                    siny = mx.sin(y)
         | 
| 153 | 
            +
             | 
| 154 | 
            +
                    if self.cos_first:
         | 
| 155 | 
            +
                        y = mx.concatenate([cosy, siny], axis=-1)
         | 
| 156 | 
            +
                    else:
         | 
| 157 | 
            +
                        y = mx.concatenate([siny, cosy], axis=-1)
         | 
| 158 | 
            +
             | 
| 159 | 
            +
                    if self.scale != 1:
         | 
| 160 | 
            +
                        y = y * self.scale
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                    return y
         | 
| 163 | 
            +
             | 
| 164 | 
            +
             | 
| 165 | 
            +
            class ALiBi(Module):
         | 
| 166 | 
            +
                @staticmethod
         | 
| 167 | 
            +
                def create_alibi_matrix(
         | 
| 168 | 
            +
                    q_sequence_length: int,
         | 
| 169 | 
            +
                    k_sequence_length: int,
         | 
| 170 | 
            +
                    num_heads: int,
         | 
| 171 | 
            +
                    offset: int,
         | 
| 172 | 
            +
                    dtype=mx.float32,
         | 
| 173 | 
            +
                ):
         | 
| 174 | 
            +
                    x1 = mx.arange(offset, q_sequence_length)
         | 
| 175 | 
            +
                    x2 = mx.arange(0, k_sequence_length)
         | 
| 176 | 
            +
                    distance_matrix = -mx.abs(
         | 
| 177 | 
            +
                        mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
         | 
| 178 | 
            +
                    )
         | 
| 179 | 
            +
                    alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
         | 
| 180 | 
            +
                    alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
         | 
| 181 | 
            +
                    return alibi_mask
         | 
| 182 | 
            +
             | 
| 183 | 
            +
                @staticmethod
         | 
| 184 | 
            +
                def create_alibi_slope(num_heads):
         | 
| 185 | 
            +
                    x = (2**8) ** (1 / num_heads)
         | 
| 186 | 
            +
                    out = mx.power(x, -mx.arange(1, num_heads + 1))
         | 
| 187 | 
            +
                    return mx.expand_dims(out, axis=(-1, -2))
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                def __call__(self, attention_scores, offset=0, mask=None):
         | 
| 190 | 
            +
                    alibi_mask = ALiBi.create_alibi_matrix(
         | 
| 191 | 
            +
                        q_sequence_length=attention_scores.shape[-2] + offset,
         | 
| 192 | 
            +
                        k_sequence_length=attention_scores.shape[-1],
         | 
| 193 | 
            +
                        num_heads=attention_scores.shape[1],
         | 
| 194 | 
            +
                        offset=offset,
         | 
| 195 | 
            +
                        dtype=attention_scores.dtype,
         | 
| 196 | 
            +
                    )
         | 
| 197 | 
            +
                    if mask is not None:
         | 
| 198 | 
            +
                        alibi_mask = alibi_mask + mask
         | 
| 199 | 
            +
                    return attention_scores + alibi_mask
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/quantized.py
    ADDED
    
    | @@ -0,0 +1,125 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import mlx.core as mx
         | 
| 6 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 7 | 
            +
            from mlx.nn.layers.linear import Linear
         | 
| 8 | 
            +
            from mlx.utils import tree_flatten, tree_map
         | 
| 9 | 
            +
             | 
| 10 | 
            +
             | 
| 11 | 
            +
            class QuantizedLinear(Module):
         | 
| 12 | 
            +
                """Applies an affine transformation to the input using a quantized weight matrix.
         | 
| 13 | 
            +
             | 
| 14 | 
            +
                It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its
         | 
| 15 | 
            +
                parameters are frozen and will not be included in any gradient computation
         | 
| 16 | 
            +
                but this will probably change in the future.
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                QuantizedLinear also provides two useful classmethods to convert linear
         | 
| 19 | 
            +
                layers to QuantizedLinear layers.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
         | 
| 22 | 
            +
                  linear transformation up to the quantization error.
         | 
| 23 | 
            +
                - :meth:`quantize_module` swaps all the linear layers of the passed module
         | 
| 24 | 
            +
                  with QuantizedLinear ones.
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                Args:
         | 
| 27 | 
            +
                    input_dims (int): The dimensionality of the input features
         | 
| 28 | 
            +
                    output_dims (int): The dimensionality of the output features
         | 
| 29 | 
            +
                    bias (bool, optional): If set to ``False`` then the layer will not use
         | 
| 30 | 
            +
                        a bias. (default: True).
         | 
| 31 | 
            +
                    group_size (int, optional): The group size to use for the quantized
         | 
| 32 | 
            +
                        weight. See :func:`~mlx.core.quantize`. (default: 64)
         | 
| 33 | 
            +
                    bits (int, optional): The bit width to use for the quantized weight.
         | 
| 34 | 
            +
                        See :func:`~mlx.core.quantize`. (default: 4)
         | 
| 35 | 
            +
                """
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def __init__(
         | 
| 38 | 
            +
                    self,
         | 
| 39 | 
            +
                    input_dims: int,
         | 
| 40 | 
            +
                    output_dims: int,
         | 
| 41 | 
            +
                    bias: bool = True,
         | 
| 42 | 
            +
                    group_size: int = 64,
         | 
| 43 | 
            +
                    bits: int = 4,
         | 
| 44 | 
            +
                ):
         | 
| 45 | 
            +
                    super().__init__()
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    # Quantization config
         | 
| 48 | 
            +
                    self.group_size = group_size
         | 
| 49 | 
            +
                    self.bits = bits
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                    # Initialize the quantized weight
         | 
| 52 | 
            +
                    scale = math.sqrt(1 / input_dims)
         | 
| 53 | 
            +
                    weight = mx.random.uniform(
         | 
| 54 | 
            +
                        low=-scale,
         | 
| 55 | 
            +
                        high=scale,
         | 
| 56 | 
            +
                        shape=(output_dims, input_dims),
         | 
| 57 | 
            +
                    )
         | 
| 58 | 
            +
                    self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    # And bias if needed
         | 
| 61 | 
            +
                    if bias:
         | 
| 62 | 
            +
                        self.bias = mx.zeros((output_dims,))
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    # Freeze this model's parameters
         | 
| 65 | 
            +
                    self.freeze()
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def unfreeze(self, *args, **kwargs):
         | 
| 68 | 
            +
                    """Wrap unfreeze so that we unfreeze any layers we might contain but
         | 
| 69 | 
            +
                    our parameters will remain frozen."""
         | 
| 70 | 
            +
                    super().unfreeze(*args, **kwargs)
         | 
| 71 | 
            +
                    self.freeze(recurse=False)
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                def _extra_repr(self):
         | 
| 74 | 
            +
                    out_dims, in_dims = self.weight.shape
         | 
| 75 | 
            +
                    in_dims *= 32 // self.bits
         | 
| 76 | 
            +
                    return (
         | 
| 77 | 
            +
                        f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
         | 
| 78 | 
            +
                        f"group_size={self.group_size}, bits={self.bits}"
         | 
| 79 | 
            +
                    )
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                def __call__(self, x):
         | 
| 82 | 
            +
                    x = mx.quantized_matmul(
         | 
| 83 | 
            +
                        x,
         | 
| 84 | 
            +
                        self.weight,
         | 
| 85 | 
            +
                        scales=self.scales,
         | 
| 86 | 
            +
                        biases=self.biases,
         | 
| 87 | 
            +
                        transpose=True,
         | 
| 88 | 
            +
                        group_size=self.group_size,
         | 
| 89 | 
            +
                        bits=self.bits,
         | 
| 90 | 
            +
                    )
         | 
| 91 | 
            +
                    if "bias" in self:
         | 
| 92 | 
            +
                        x = x + self.bias
         | 
| 93 | 
            +
                    return x
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                @classmethod
         | 
| 96 | 
            +
                def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
         | 
| 97 | 
            +
                    """Create a QuantizedLinear layer from the parameters of a provided
         | 
| 98 | 
            +
                    linear layer."""
         | 
| 99 | 
            +
                    output_dims, input_dims = linear_layer.weight.shape
         | 
| 100 | 
            +
                    ql = cls(input_dims, output_dims, False, group_size, bits)
         | 
| 101 | 
            +
                    ql.weight, ql.scales, ql.biases = mx.quantize(
         | 
| 102 | 
            +
                        linear_layer.weight, group_size, bits
         | 
| 103 | 
            +
                    )
         | 
| 104 | 
            +
                    if "bias" in linear_layer:
         | 
| 105 | 
            +
                        ql.bias = linear_layer.bias
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                    return ql
         | 
| 108 | 
            +
             | 
| 109 | 
            +
                @classmethod
         | 
| 110 | 
            +
                def quantize_module(
         | 
| 111 | 
            +
                    cls,
         | 
| 112 | 
            +
                    model: Module,
         | 
| 113 | 
            +
                    group_size: int = 64,
         | 
| 114 | 
            +
                    bits: int = 4,
         | 
| 115 | 
            +
                    linear_class_predicate=lambda m: isinstance(m, Linear),
         | 
| 116 | 
            +
                ):
         | 
| 117 | 
            +
                    def _quantize_if_linear(m):
         | 
| 118 | 
            +
                        if linear_class_predicate(m):
         | 
| 119 | 
            +
                            return cls.from_linear(m, group_size, bits)
         | 
| 120 | 
            +
                        else:
         | 
| 121 | 
            +
                            return m
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                    leaves = model.leaf_modules()
         | 
| 124 | 
            +
                    leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
         | 
| 125 | 
            +
                    model.update_modules(leaves)
         | 
    	
        lib/python3.11/site-packages/mlx/nn/layers/transformer.py
    ADDED
    
    | @@ -0,0 +1,354 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            from typing import Any, Callable, Optional
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import mlx.core as mx
         | 
| 7 | 
            +
            from mlx.nn.layers.activations import relu
         | 
| 8 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 9 | 
            +
            from mlx.nn.layers.dropout import Dropout
         | 
| 10 | 
            +
            from mlx.nn.layers.linear import Linear
         | 
| 11 | 
            +
            from mlx.nn.layers.normalization import LayerNorm
         | 
| 12 | 
            +
             | 
| 13 | 
            +
             | 
| 14 | 
            +
            class MultiHeadAttention(Module):
         | 
| 15 | 
            +
                """Implements the scaled dot product attention with multiple heads.
         | 
| 16 | 
            +
             | 
| 17 | 
            +
                Given inputs for queries, keys and values the ``MultiHeadAttention``
         | 
| 18 | 
            +
                produces new values by aggregating information from the input values
         | 
| 19 | 
            +
                according to the similarities of the input queries and keys.
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                All inputs as well as the output are linearly projected without biases by
         | 
| 22 | 
            +
                default.
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                ``MultiHeadAttention`` also takes an optional additive attention mask that
         | 
| 25 | 
            +
                should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The
         | 
| 26 | 
            +
                mask should have ``-inf`` or very large negative numbers at the positions
         | 
| 27 | 
            +
                that should *not* be attended to.
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                Args:
         | 
| 30 | 
            +
                    dims (int): The model dimensions. This is also the default
         | 
| 31 | 
            +
                        value for the queries, keys, values, and the output.
         | 
| 32 | 
            +
                    num_heads (int): The number of attention heads to use.
         | 
| 33 | 
            +
                    query_input_dims (int, optional): The input dimensions of the queries.
         | 
| 34 | 
            +
                        Default: ``dims``.
         | 
| 35 | 
            +
                    key_input_dims (int, optional): The input dimensions of the keys.
         | 
| 36 | 
            +
                        Default: ``dims``.
         | 
| 37 | 
            +
                    value_input_dims (int, optional): The input dimensions of the values.
         | 
| 38 | 
            +
                        Default: ``key_input_dims``.
         | 
| 39 | 
            +
                    value_dims (int, optional): The dimensions of the values after the
         | 
| 40 | 
            +
                        projection. Default: ``dims``.
         | 
| 41 | 
            +
                    value_output_dims (int, optional): The dimensions the new values will
         | 
| 42 | 
            +
                        be projected to. Default: ``dims``.
         | 
| 43 | 
            +
                    bias (bool, optional): Whether or not to use a bias in the projections.
         | 
| 44 | 
            +
                        Default: ``False``.
         | 
| 45 | 
            +
                """
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                def __init__(
         | 
| 48 | 
            +
                    self,
         | 
| 49 | 
            +
                    dims: int,
         | 
| 50 | 
            +
                    num_heads: int,
         | 
| 51 | 
            +
                    query_input_dims: Optional[int] = None,
         | 
| 52 | 
            +
                    key_input_dims: Optional[int] = None,
         | 
| 53 | 
            +
                    value_input_dims: Optional[int] = None,
         | 
| 54 | 
            +
                    value_dims: Optional[int] = None,
         | 
| 55 | 
            +
                    value_output_dims: Optional[int] = None,
         | 
| 56 | 
            +
                    bias: bool = False,
         | 
| 57 | 
            +
                ):
         | 
| 58 | 
            +
                    super().__init__()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    if (dims % num_heads) != 0:
         | 
| 61 | 
            +
                        raise ValueError(
         | 
| 62 | 
            +
                            "The input feature dimensions should be divisible by the "
         | 
| 63 | 
            +
                            f"number of heads ({dims} % {num_heads}) != 0"
         | 
| 64 | 
            +
                        )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
                    query_input_dims = query_input_dims or dims
         | 
| 67 | 
            +
                    key_input_dims = key_input_dims or dims
         | 
| 68 | 
            +
                    value_input_dims = value_input_dims or key_input_dims
         | 
| 69 | 
            +
                    value_dims = value_dims or dims
         | 
| 70 | 
            +
                    value_output_dims = value_output_dims or dims
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    self.num_heads = num_heads
         | 
| 73 | 
            +
                    self.query_proj = Linear(query_input_dims, dims, bias=bias)
         | 
| 74 | 
            +
                    self.key_proj = Linear(key_input_dims, dims, bias=bias)
         | 
| 75 | 
            +
                    self.value_proj = Linear(value_input_dims, value_dims, bias=bias)
         | 
| 76 | 
            +
                    self.out_proj = Linear(value_dims, value_output_dims, bias=bias)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def __call__(self, queries, keys, values, mask=None):
         | 
| 79 | 
            +
                    queries = self.query_proj(queries)
         | 
| 80 | 
            +
                    keys = self.key_proj(keys)
         | 
| 81 | 
            +
                    values = self.value_proj(values)
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                    num_heads = self.num_heads
         | 
| 84 | 
            +
                    B, L, D = queries.shape
         | 
| 85 | 
            +
                    _, S, _ = keys.shape
         | 
| 86 | 
            +
                    queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
         | 
| 87 | 
            +
                    keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
         | 
| 88 | 
            +
                    values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
         | 
| 89 | 
            +
             | 
| 90 | 
            +
                    # Dimensions are [batch x num heads x sequence x hidden dim]
         | 
| 91 | 
            +
                    scale = math.sqrt(1 / queries.shape[-1])
         | 
| 92 | 
            +
                    scores = (queries * scale) @ keys
         | 
| 93 | 
            +
                    if mask is not None:
         | 
| 94 | 
            +
                        scores = scores + mask.astype(scores.dtype)
         | 
| 95 | 
            +
                    scores = mx.softmax(scores, axis=-1)
         | 
| 96 | 
            +
                    values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                    return self.out_proj(values_hat)
         | 
| 99 | 
            +
             | 
| 100 | 
            +
                @staticmethod
         | 
| 101 | 
            +
                def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
         | 
| 102 | 
            +
                    indices = mx.arange(N)
         | 
| 103 | 
            +
                    mask = indices[:, None] < indices[None]
         | 
| 104 | 
            +
                    # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
         | 
| 105 | 
            +
                    # TODO: Should replace this with finfo(dtype).min
         | 
| 106 | 
            +
                    mask = mask.astype(dtype) * -1e9
         | 
| 107 | 
            +
                    return mask
         | 
| 108 | 
            +
             | 
| 109 | 
            +
             | 
| 110 | 
            +
            class TransformerEncoderLayer(Module):
         | 
| 111 | 
            +
                def __init__(
         | 
| 112 | 
            +
                    self,
         | 
| 113 | 
            +
                    dims: int,
         | 
| 114 | 
            +
                    num_heads: int,
         | 
| 115 | 
            +
                    mlp_dims: Optional[int] = None,
         | 
| 116 | 
            +
                    dropout: float = 0.0,
         | 
| 117 | 
            +
                    activation: Callable[[Any], Any] = relu,
         | 
| 118 | 
            +
                    norm_first: bool = False,
         | 
| 119 | 
            +
                ):
         | 
| 120 | 
            +
                    super().__init__()
         | 
| 121 | 
            +
                    mlp_dims = mlp_dims or dims * 4
         | 
| 122 | 
            +
                    self.attention = MultiHeadAttention(dims, num_heads)
         | 
| 123 | 
            +
                    self.ln1 = LayerNorm(dims)
         | 
| 124 | 
            +
                    self.ln2 = LayerNorm(dims)
         | 
| 125 | 
            +
                    self.linear1 = Linear(dims, mlp_dims)
         | 
| 126 | 
            +
                    self.linear2 = Linear(mlp_dims, dims)
         | 
| 127 | 
            +
                    self.dropout1 = Dropout(dropout)
         | 
| 128 | 
            +
                    self.dropout2 = Dropout(dropout)
         | 
| 129 | 
            +
                    self.activation = activation
         | 
| 130 | 
            +
                    self.norm_first = norm_first
         | 
| 131 | 
            +
             | 
| 132 | 
            +
                def __call__(self, x, mask):
         | 
| 133 | 
            +
                    if self.norm_first:
         | 
| 134 | 
            +
                        y = self.ln1(x)
         | 
| 135 | 
            +
                        y = self.attention(y, y, y, mask)
         | 
| 136 | 
            +
                        y = self.dropout1(y)
         | 
| 137 | 
            +
                        x = x + y
         | 
| 138 | 
            +
             | 
| 139 | 
            +
                        y = self.ln2(x)
         | 
| 140 | 
            +
                        y = self.linear1(y)
         | 
| 141 | 
            +
                        y = self.activation(y)
         | 
| 142 | 
            +
                        y = self.dropout2(y)
         | 
| 143 | 
            +
                        y = self.linear2(y)
         | 
| 144 | 
            +
                        y = x + y
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                    else:
         | 
| 147 | 
            +
                        y = self.attention(x, x, x, mask)
         | 
| 148 | 
            +
                        y = self.dropout1(y)
         | 
| 149 | 
            +
                        y = self.ln1(x + y)
         | 
| 150 | 
            +
             | 
| 151 | 
            +
                        y = self.linear1(y)
         | 
| 152 | 
            +
                        y = self.activation(y)
         | 
| 153 | 
            +
                        y = self.dropout2(y)
         | 
| 154 | 
            +
                        y = self.linear2(y)
         | 
| 155 | 
            +
                        y = self.ln2(x + y)
         | 
| 156 | 
            +
             | 
| 157 | 
            +
                    return y
         | 
| 158 | 
            +
             | 
| 159 | 
            +
             | 
| 160 | 
            +
            class TransformerEncoder(Module):
         | 
| 161 | 
            +
                def __init__(
         | 
| 162 | 
            +
                    self,
         | 
| 163 | 
            +
                    num_layers: int,
         | 
| 164 | 
            +
                    dims: int,
         | 
| 165 | 
            +
                    num_heads: int,
         | 
| 166 | 
            +
                    mlp_dims: Optional[int] = None,
         | 
| 167 | 
            +
                    dropout: float = 0.0,
         | 
| 168 | 
            +
                    activation=relu,
         | 
| 169 | 
            +
                    norm_first: bool = False,
         | 
| 170 | 
            +
                ):
         | 
| 171 | 
            +
                    super().__init__()
         | 
| 172 | 
            +
                    self.layers = [
         | 
| 173 | 
            +
                        TransformerEncoderLayer(
         | 
| 174 | 
            +
                            dims, num_heads, mlp_dims, dropout, activation, norm_first
         | 
| 175 | 
            +
                        )
         | 
| 176 | 
            +
                        for i in range(num_layers)
         | 
| 177 | 
            +
                    ]
         | 
| 178 | 
            +
                    self.ln = LayerNorm(dims)
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                def __call__(self, x, mask):
         | 
| 181 | 
            +
                    for l in self.layers:
         | 
| 182 | 
            +
                        x = l(x, mask)
         | 
| 183 | 
            +
                    return self.ln(x)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
             | 
| 186 | 
            +
            class TransformerDecoderLayer(Module):
         | 
| 187 | 
            +
                def __init__(
         | 
| 188 | 
            +
                    self,
         | 
| 189 | 
            +
                    dims: int,
         | 
| 190 | 
            +
                    num_heads: int,
         | 
| 191 | 
            +
                    mlp_dims: Optional[int] = None,
         | 
| 192 | 
            +
                    dropout: float = 0.0,
         | 
| 193 | 
            +
                    activation: Callable[[Any], Any] = relu,
         | 
| 194 | 
            +
                    norm_first: bool = False,
         | 
| 195 | 
            +
                ):
         | 
| 196 | 
            +
                    super().__init__()
         | 
| 197 | 
            +
                    mlp_dims = mlp_dims or dims * 4
         | 
| 198 | 
            +
                    self.self_attention = MultiHeadAttention(dims, num_heads)
         | 
| 199 | 
            +
                    self.cross_attention = MultiHeadAttention(dims, num_heads)
         | 
| 200 | 
            +
                    self.ln1 = LayerNorm(dims)
         | 
| 201 | 
            +
                    self.ln2 = LayerNorm(dims)
         | 
| 202 | 
            +
                    self.ln3 = LayerNorm(dims)
         | 
| 203 | 
            +
                    self.linear1 = Linear(dims, mlp_dims)
         | 
| 204 | 
            +
                    self.linear2 = Linear(mlp_dims, dims)
         | 
| 205 | 
            +
                    self.dropout1 = Dropout(dropout)
         | 
| 206 | 
            +
                    self.dropout2 = Dropout(dropout)
         | 
| 207 | 
            +
                    self.dropout3 = Dropout(dropout)
         | 
| 208 | 
            +
                    self.activation = activation
         | 
| 209 | 
            +
                    self.norm_first = norm_first
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                def __call__(self, x, memory, x_mask, memory_mask):
         | 
| 212 | 
            +
                    if self.norm_first:
         | 
| 213 | 
            +
                        y = self.ln1(x)
         | 
| 214 | 
            +
                        y = self.self_attention(y, y, y, x_mask)
         | 
| 215 | 
            +
                        y = self.dropout1(y)
         | 
| 216 | 
            +
                        x = x + y
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                        y = self.ln2(x)
         | 
| 219 | 
            +
                        y = self.cross_attention(y, memory, memory, memory_mask)
         | 
| 220 | 
            +
                        y = self.dropout2(y)
         | 
| 221 | 
            +
                        x = x + y
         | 
| 222 | 
            +
             | 
| 223 | 
            +
                        y = self.ln3(x)
         | 
| 224 | 
            +
                        y = self.linear1(y)
         | 
| 225 | 
            +
                        y = self.activation(y)
         | 
| 226 | 
            +
                        y = self.dropout3(y)
         | 
| 227 | 
            +
                        y = self.linear2(y)
         | 
| 228 | 
            +
                        y = x + y
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    else:
         | 
| 231 | 
            +
                        y = self.self_attention(x, x, x, x_mask)
         | 
| 232 | 
            +
                        y = self.dropout1(y)
         | 
| 233 | 
            +
                        x = self.ln1(x + y)
         | 
| 234 | 
            +
             | 
| 235 | 
            +
                        y = self.cross_attention(y, memory, memory, memory_mask)
         | 
| 236 | 
            +
                        y = self.dropout2(y)
         | 
| 237 | 
            +
                        x = self.ln1(x + y)
         | 
| 238 | 
            +
             | 
| 239 | 
            +
                        y = self.linear1(x)
         | 
| 240 | 
            +
                        y = self.activation(y)
         | 
| 241 | 
            +
                        y = self.dropout3(y)
         | 
| 242 | 
            +
                        y = self.linear2(y)
         | 
| 243 | 
            +
                        y = self.ln3(x + y)
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                    return y
         | 
| 246 | 
            +
             | 
| 247 | 
            +
             | 
| 248 | 
            +
            class TransformerDecoder(Module):
         | 
| 249 | 
            +
                def __init__(
         | 
| 250 | 
            +
                    self,
         | 
| 251 | 
            +
                    num_layers: int,
         | 
| 252 | 
            +
                    dims: int,
         | 
| 253 | 
            +
                    num_heads: int,
         | 
| 254 | 
            +
                    mlp_dims: Optional[int] = None,
         | 
| 255 | 
            +
                    dropout: float = 0.0,
         | 
| 256 | 
            +
                    activation=relu,
         | 
| 257 | 
            +
                    norm_first: bool = False,
         | 
| 258 | 
            +
                ):
         | 
| 259 | 
            +
                    super().__init__()
         | 
| 260 | 
            +
                    self.layers = [
         | 
| 261 | 
            +
                        TransformerDecoderLayer(
         | 
| 262 | 
            +
                            dims, num_heads, mlp_dims, dropout, activation, norm_first
         | 
| 263 | 
            +
                        )
         | 
| 264 | 
            +
                        for i in range(num_layers)
         | 
| 265 | 
            +
                    ]
         | 
| 266 | 
            +
                    self.ln = LayerNorm(dims)
         | 
| 267 | 
            +
             | 
| 268 | 
            +
                def __call__(self, x, memory, x_mask, memory_mask):
         | 
| 269 | 
            +
                    for l in self.layers:
         | 
| 270 | 
            +
                        x = l(x, memory, x_mask, memory_mask)
         | 
| 271 | 
            +
                    return self.ln(x)
         | 
| 272 | 
            +
             | 
| 273 | 
            +
             | 
| 274 | 
            +
            class Transformer(Module):
         | 
| 275 | 
            +
                """
         | 
| 276 | 
            +
                Implements a standard Transformer model.
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                The implementation is based on `Attention Is All You Need
         | 
| 279 | 
            +
                <https://arxiv.org/abs/1706.03762>`_.
         | 
| 280 | 
            +
             | 
| 281 | 
            +
                The Transformer model contains an encoder and a decoder. The encoder
         | 
| 282 | 
            +
                processes the input sequence and the decoder generates the output sequence.
         | 
| 283 | 
            +
                The interaction between encoder and decoder happens through the attention
         | 
| 284 | 
            +
                mechanism.
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                Args:
         | 
| 287 | 
            +
                    dims (int, optional): The number of expected features in the
         | 
| 288 | 
            +
                        encoder/decoder inputs. Default: ``512``.
         | 
| 289 | 
            +
                    num_heads (int, optional): The number of attention heads. Default:
         | 
| 290 | 
            +
                        ``8``.
         | 
| 291 | 
            +
                    num_encoder_layers (int, optional): The number of encoder layers in the
         | 
| 292 | 
            +
                        Transformer encoder. Default: ``6``.
         | 
| 293 | 
            +
                    num_decoder_layers (int, optional): The number of decoder layers in the
         | 
| 294 | 
            +
                        Transformer decoder. Default: ``6``.
         | 
| 295 | 
            +
                    mlp_dims (int, optional): The hidden dimension of the MLP block in each
         | 
| 296 | 
            +
                        Transformer layer. Defaults to ``4*dims`` if not provided. Default:
         | 
| 297 | 
            +
                        ``None``.
         | 
| 298 | 
            +
                    dropout (float, optional): The dropout value for the Transformer
         | 
| 299 | 
            +
                        encoder and decoder. Dropout is used after each attention layer and
         | 
| 300 | 
            +
                        the activation in the MLP layer. Default: ``0.0``.
         | 
| 301 | 
            +
                    activation (function, optional): the activation function for the MLP
         | 
| 302 | 
            +
                        hidden layer. Default: :func:`mlx.nn.relu`.
         | 
| 303 | 
            +
                    custom_encoder (nn.Module, optional): A custom encoder to replace the
         | 
| 304 | 
            +
                        standard Transformer encoder. Default: ``None``.
         | 
| 305 | 
            +
                    custom_decoder (nn.Module, optional): A custom decoder to replace the
         | 
| 306 | 
            +
                        standard Transformer decoder. Default: ``None``.
         | 
| 307 | 
            +
                    norm_first (bool, optional): if ``True``, encoder and decoder layers
         | 
| 308 | 
            +
                        will perform layer normalization before attention and MLP
         | 
| 309 | 
            +
                        operations, otherwise after. Default: ``False``.
         | 
| 310 | 
            +
                """
         | 
| 311 | 
            +
             | 
| 312 | 
            +
                def __init__(
         | 
| 313 | 
            +
                    self,
         | 
| 314 | 
            +
                    dims: int = 512,
         | 
| 315 | 
            +
                    num_heads: int = 8,
         | 
| 316 | 
            +
                    num_encoder_layers: int = 6,
         | 
| 317 | 
            +
                    num_decoder_layers: int = 6,
         | 
| 318 | 
            +
                    mlp_dims: Optional[int] = None,
         | 
| 319 | 
            +
                    dropout: float = 0.0,
         | 
| 320 | 
            +
                    activation: Callable[[Any], Any] = relu,
         | 
| 321 | 
            +
                    custom_encoder: Optional[Any] = None,
         | 
| 322 | 
            +
                    custom_decoder: Optional[Any] = None,
         | 
| 323 | 
            +
                    norm_first: bool = False,
         | 
| 324 | 
            +
                ):
         | 
| 325 | 
            +
                    super().__init__()
         | 
| 326 | 
            +
                    if custom_encoder is not None:
         | 
| 327 | 
            +
                        self.encoder = custom_encoder
         | 
| 328 | 
            +
                    else:
         | 
| 329 | 
            +
                        self.encoder = TransformerEncoder(
         | 
| 330 | 
            +
                            num_encoder_layers,
         | 
| 331 | 
            +
                            dims,
         | 
| 332 | 
            +
                            num_heads,
         | 
| 333 | 
            +
                            mlp_dims,
         | 
| 334 | 
            +
                            dropout,
         | 
| 335 | 
            +
                            activation,
         | 
| 336 | 
            +
                            norm_first,
         | 
| 337 | 
            +
                        )
         | 
| 338 | 
            +
             | 
| 339 | 
            +
                    if custom_decoder is not None:
         | 
| 340 | 
            +
                        self.decoder = custom_decoder
         | 
| 341 | 
            +
                    else:
         | 
| 342 | 
            +
                        self.decoder = TransformerDecoder(
         | 
| 343 | 
            +
                            num_decoder_layers,
         | 
| 344 | 
            +
                            dims,
         | 
| 345 | 
            +
                            num_heads,
         | 
| 346 | 
            +
                            mlp_dims,
         | 
| 347 | 
            +
                            dropout,
         | 
| 348 | 
            +
                            activation,
         | 
| 349 | 
            +
                            norm_first,
         | 
| 350 | 
            +
                        )
         | 
| 351 | 
            +
             | 
| 352 | 
            +
                def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
         | 
| 353 | 
            +
                    memory = self.encoder(src, src_mask)
         | 
| 354 | 
            +
                    return self.decoder(tgt, memory, tgt_mask, memory_mask)
         | 
    	
        lib/python3.11/site-packages/mlx/nn/losses.py
    ADDED
    
    | @@ -0,0 +1,374 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import mlx.core as mx
         | 
| 6 | 
            +
            from mlx.nn.layers.base import Module
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            def cross_entropy(
         | 
| 10 | 
            +
                logits: mx.array,
         | 
| 11 | 
            +
                targets: mx.array,
         | 
| 12 | 
            +
                weights: mx.array = None,
         | 
| 13 | 
            +
                axis: int = -1,
         | 
| 14 | 
            +
                label_smoothing: float = 0.0,
         | 
| 15 | 
            +
                reduction: str = "none",
         | 
| 16 | 
            +
            ) -> mx.array:
         | 
| 17 | 
            +
                """
         | 
| 18 | 
            +
                Computes the cross entropy loss.
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                Args:
         | 
| 21 | 
            +
                    logits (array): The unnormalized predicted logits.
         | 
| 22 | 
            +
                    targets (array): The target values, as class indices.
         | 
| 23 | 
            +
                    weights (array, optional): Weights for each target. Default: ``None``.
         | 
| 24 | 
            +
                    axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
         | 
| 25 | 
            +
                    label_smoothing (float, optional): Label smoothing factor. Default: ``0``.
         | 
| 26 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 27 | 
            +
                        ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                Returns:
         | 
| 30 | 
            +
                    array: The computed cross entropy loss.
         | 
| 31 | 
            +
                """
         | 
| 32 | 
            +
                if label_smoothing < 0 or label_smoothing >= 1:
         | 
| 33 | 
            +
                    raise ValueError(f"Label smoothing must in [0, 1), got {label_smoothing}.")
         | 
| 34 | 
            +
             | 
| 35 | 
            +
                score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
         | 
| 36 | 
            +
                logsumexp_logits = mx.logsumexp(logits, axis=axis)
         | 
| 37 | 
            +
                if label_smoothing > 0:
         | 
| 38 | 
            +
                    # Adjust the true class score with label smoothing
         | 
| 39 | 
            +
                    adjusted_score = (1 - label_smoothing) * score
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    # Calculate the mean logit across the classes for smoothed loss
         | 
| 42 | 
            +
                    mean_logits = logits.mean(axis=axis)
         | 
| 43 | 
            +
                    smoothed_loss = -mean_logits * label_smoothing
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                    # Combine the adjusted score and smoothed loss with the logsumexp logits
         | 
| 46 | 
            +
                    loss = logsumexp_logits - adjusted_score + smoothed_loss
         | 
| 47 | 
            +
                else:
         | 
| 48 | 
            +
                    loss = logsumexp_logits - score
         | 
| 49 | 
            +
             | 
| 50 | 
            +
                # Apply weights if provided
         | 
| 51 | 
            +
                if weights is not None:
         | 
| 52 | 
            +
                    if weights.shape != targets.shape:
         | 
| 53 | 
            +
                        raise ValueError(
         | 
| 54 | 
            +
                            f"Weights with shape {weights.shape} is not the same as "
         | 
| 55 | 
            +
                            f"targets with shape {targets.shape}."
         | 
| 56 | 
            +
                        )
         | 
| 57 | 
            +
                    loss *= weights
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                # Apply reduction
         | 
| 60 | 
            +
                return _reduce(loss, reduction)
         | 
| 61 | 
            +
             | 
| 62 | 
            +
             | 
| 63 | 
            +
            def binary_cross_entropy(
         | 
| 64 | 
            +
                logits: mx.array, targets: mx.array, reduction: str = "none"
         | 
| 65 | 
            +
            ) -> mx.array:
         | 
| 66 | 
            +
                """
         | 
| 67 | 
            +
                Computes the binary cross entropy loss.
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                Args:
         | 
| 70 | 
            +
                    logits (array): The unnormalized (pre-sigmoid) predicted logits.
         | 
| 71 | 
            +
                    targets (array): The binary target values in {0, 1}.
         | 
| 72 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 73 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
         | 
| 74 | 
            +
             | 
| 75 | 
            +
                Returns:
         | 
| 76 | 
            +
                    array: The computed binary cross entropy loss.
         | 
| 77 | 
            +
                Examples:
         | 
| 78 | 
            +
                    >>> import mlx.core as mx
         | 
| 79 | 
            +
                    >>> import mlx.nn as nn
         | 
| 80 | 
            +
                    >>> inputs = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
         | 
| 81 | 
            +
                    >>> targets = mx.array([0, 0, 1, 1])
         | 
| 82 | 
            +
                    >>> loss = nn.losses.binary_cross_entropy(inputs, targets, "mean")
         | 
| 83 | 
            +
                    >>> loss
         | 
| 84 | 
            +
                    array([0.612192], dtype=float32)
         | 
| 85 | 
            +
                """
         | 
| 86 | 
            +
                loss = mx.logaddexp(0.0, logits) - targets * logits
         | 
| 87 | 
            +
                return _reduce(loss, reduction)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
            +
            def l1_loss(
         | 
| 91 | 
            +
                predictions: mx.array, targets: mx.array, reduction: str = "mean"
         | 
| 92 | 
            +
            ) -> mx.array:
         | 
| 93 | 
            +
                """
         | 
| 94 | 
            +
                Computes the L1 loss.
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                Args:
         | 
| 97 | 
            +
                    predictions (array): The predicted values.
         | 
| 98 | 
            +
                    targets (array): The target values.
         | 
| 99 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 100 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                Returns:
         | 
| 103 | 
            +
                    array: The computed L1 loss.
         | 
| 104 | 
            +
                """
         | 
| 105 | 
            +
                if predictions.shape != targets.shape:
         | 
| 106 | 
            +
                    raise ValueError(
         | 
| 107 | 
            +
                        f"Predictions shape {predictions.shape} does not match "
         | 
| 108 | 
            +
                        f"targets shape {targets.shape}."
         | 
| 109 | 
            +
                    )
         | 
| 110 | 
            +
                loss = mx.abs(predictions - targets)
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                return _reduce(loss, reduction)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
             | 
| 115 | 
            +
            def mse_loss(
         | 
| 116 | 
            +
                predictions: mx.array, targets: mx.array, reduction: str = "mean"
         | 
| 117 | 
            +
            ) -> mx.array:
         | 
| 118 | 
            +
                """
         | 
| 119 | 
            +
                Computes the mean squared error loss.
         | 
| 120 | 
            +
             | 
| 121 | 
            +
                Args:
         | 
| 122 | 
            +
                    predictions (array): The predicted values.
         | 
| 123 | 
            +
                    targets (array): The target values.
         | 
| 124 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 125 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
         | 
| 126 | 
            +
             | 
| 127 | 
            +
                Returns:
         | 
| 128 | 
            +
                    array: The computed mean squared error loss.
         | 
| 129 | 
            +
                """
         | 
| 130 | 
            +
                if predictions.shape != targets.shape:
         | 
| 131 | 
            +
                    raise ValueError(
         | 
| 132 | 
            +
                        f"Predictions shape {predictions.shape} does not match "
         | 
| 133 | 
            +
                        f"targets shape {targets.shape}."
         | 
| 134 | 
            +
                    )
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                loss = mx.square(predictions - targets)
         | 
| 137 | 
            +
                return _reduce(loss, reduction)
         | 
| 138 | 
            +
             | 
| 139 | 
            +
             | 
| 140 | 
            +
            def nll_loss(
         | 
| 141 | 
            +
                inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
         | 
| 142 | 
            +
            ) -> mx.array:
         | 
| 143 | 
            +
                """
         | 
| 144 | 
            +
                Computes the negative log likelihood loss.
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                Args:
         | 
| 147 | 
            +
                    inputs (array): The predicted distribution in log space.
         | 
| 148 | 
            +
                    targets (array): The target values.
         | 
| 149 | 
            +
                    axis (int, optional): The distribution axis. Default: ``-1``.
         | 
| 150 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 151 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                Returns:
         | 
| 154 | 
            +
                    array: The computed NLL loss.
         | 
| 155 | 
            +
                """
         | 
| 156 | 
            +
                loss = -mx.take_along_axis(inputs, targets[..., None], axis).squeeze(-1)
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                return _reduce(loss, reduction)
         | 
| 159 | 
            +
             | 
| 160 | 
            +
             | 
| 161 | 
            +
            def kl_div_loss(
         | 
| 162 | 
            +
                inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
         | 
| 163 | 
            +
            ) -> mx.array:
         | 
| 164 | 
            +
                """
         | 
| 165 | 
            +
                Computes the Kullback-Leibler divergence loss.
         | 
| 166 | 
            +
             | 
| 167 | 
            +
                Computes the following when ``reduction == 'none'``:
         | 
| 168 | 
            +
             | 
| 169 | 
            +
                .. code-block:: python
         | 
| 170 | 
            +
             | 
| 171 | 
            +
                    mx.exp(targets) * (targets - inputs).sum(axis)
         | 
| 172 | 
            +
             | 
| 173 | 
            +
                Args:
         | 
| 174 | 
            +
                    inputs (array): Log probabilities for the predicted distribution.
         | 
| 175 | 
            +
                    targets (array): Log probabilities for the target distribution.
         | 
| 176 | 
            +
                    axis (int, optional): The distribution axis. Default: ``-1``.
         | 
| 177 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 178 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
         | 
| 179 | 
            +
             | 
| 180 | 
            +
                Returns:
         | 
| 181 | 
            +
                    array: The computed Kullback-Leibler divergence loss.
         | 
| 182 | 
            +
                """
         | 
| 183 | 
            +
                loss = mx.sum(mx.exp(targets) * (targets - inputs), axis)
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                return _reduce(loss, reduction)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
             | 
| 188 | 
            +
            def smooth_l1_loss(
         | 
| 189 | 
            +
                predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
         | 
| 190 | 
            +
            ) -> mx.array:
         | 
| 191 | 
            +
                r"""
         | 
| 192 | 
            +
                Computes the smooth L1 loss.
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                The smooth L1 loss is a variant of the L1 loss which replaces the absolute
         | 
| 195 | 
            +
                difference with a squared difference when the absolute difference is less
         | 
| 196 | 
            +
                than ``beta``.
         | 
| 197 | 
            +
             | 
| 198 | 
            +
                The formula for the smooth L1 Loss is:
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                .. math::
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                   l =
         | 
| 203 | 
            +
                      \begin{cases}
         | 
| 204 | 
            +
                        0.5 (x - y)^2, & \text{ if } & (x - y) < \beta \\
         | 
| 205 | 
            +
                        |x - y| - 0.5 \beta, &  & \text{otherwise}
         | 
| 206 | 
            +
                      \end{cases}
         | 
| 207 | 
            +
             | 
| 208 | 
            +
                Args:
         | 
| 209 | 
            +
                    predictions (array): Predicted values.
         | 
| 210 | 
            +
                    targets (array): Ground truth values.
         | 
| 211 | 
            +
                    beta (float, optional): The threshold after which the loss changes
         | 
| 212 | 
            +
                      from the squared to the absolute difference. Default: ``1.0``.
         | 
| 213 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 214 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
         | 
| 215 | 
            +
             | 
| 216 | 
            +
                Returns:
         | 
| 217 | 
            +
                    array: The computed smooth L1 loss.
         | 
| 218 | 
            +
                """
         | 
| 219 | 
            +
                if predictions.shape != targets.shape:
         | 
| 220 | 
            +
                    raise ValueError(
         | 
| 221 | 
            +
                        f"Predictions shape {predictions.shape} does not match "
         | 
| 222 | 
            +
                        f"targets shape {targets.shape}."
         | 
| 223 | 
            +
                    )
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                diff = predictions - targets
         | 
| 226 | 
            +
                loss = mx.where(
         | 
| 227 | 
            +
                    diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
         | 
| 228 | 
            +
                )
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                return _reduce(loss, reduction)
         | 
| 231 | 
            +
             | 
| 232 | 
            +
             | 
| 233 | 
            +
            def triplet_loss(
         | 
| 234 | 
            +
                anchors: mx.array,
         | 
| 235 | 
            +
                positives: mx.array,
         | 
| 236 | 
            +
                negatives: mx.array,
         | 
| 237 | 
            +
                axis: int = -1,
         | 
| 238 | 
            +
                p: int = 2,
         | 
| 239 | 
            +
                margin: float = 1.0,
         | 
| 240 | 
            +
                eps: float = 1e-6,
         | 
| 241 | 
            +
                reduction: str = "none",
         | 
| 242 | 
            +
            ) -> mx.array:
         | 
| 243 | 
            +
                r"""
         | 
| 244 | 
            +
                Computes the triplet loss for a set of anchor, positive, and negative samples.
         | 
| 245 | 
            +
                Margin is represented with alpha in the math section.
         | 
| 246 | 
            +
             | 
| 247 | 
            +
                .. math::
         | 
| 248 | 
            +
             | 
| 249 | 
            +
                   L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                Args:
         | 
| 252 | 
            +
                    anchors (array): The anchor samples.
         | 
| 253 | 
            +
                    positives (array): The positive samples.
         | 
| 254 | 
            +
                    negatives (array): The negative samples.
         | 
| 255 | 
            +
                    axis (int, optional): The distribution axis. Default: ``-1``.
         | 
| 256 | 
            +
                    p (int, optional): The norm degree for pairwise distance. Default: ``2``.
         | 
| 257 | 
            +
                    margin (float, optional): Margin for the triplet loss. Defaults to ``1.0``.
         | 
| 258 | 
            +
                    eps (float, optional): Small positive constant to prevent numerical instability. Defaults to ``1e-6``.
         | 
| 259 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 260 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                Returns:
         | 
| 263 | 
            +
                    array: Computed triplet loss. If reduction is "none", returns a tensor of the same shape as input;
         | 
| 264 | 
            +
                              if reduction is "mean" or "sum", returns a scalar tensor.
         | 
| 265 | 
            +
                """
         | 
| 266 | 
            +
                loss = mx.maximum(
         | 
| 267 | 
            +
                    mx.sqrt(mx.power(anchors - positives, p).sum(axis) + eps)
         | 
| 268 | 
            +
                    - mx.sqrt(mx.power(anchors - negatives, p).sum(axis) + eps)
         | 
| 269 | 
            +
                    + margin,
         | 
| 270 | 
            +
                    0,
         | 
| 271 | 
            +
                )
         | 
| 272 | 
            +
                return _reduce(loss, reduction)
         | 
| 273 | 
            +
             | 
| 274 | 
            +
             | 
| 275 | 
            +
            def _reduce(loss: mx.array, reduction: str = "none"):
         | 
| 276 | 
            +
                if reduction == "mean":
         | 
| 277 | 
            +
                    return mx.mean(loss)
         | 
| 278 | 
            +
                elif reduction == "sum":
         | 
| 279 | 
            +
                    return mx.sum(loss)
         | 
| 280 | 
            +
                elif reduction == "none":
         | 
| 281 | 
            +
                    return loss
         | 
| 282 | 
            +
                else:
         | 
| 283 | 
            +
                    raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
         | 
| 284 | 
            +
             | 
| 285 | 
            +
             | 
| 286 | 
            +
            def hinge_loss(
         | 
| 287 | 
            +
                inputs: mx.array, targets: mx.array, reduction: str = "none"
         | 
| 288 | 
            +
            ) -> mx.array:
         | 
| 289 | 
            +
                r"""
         | 
| 290 | 
            +
                Computes the hinge loss between inputs and targets.
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                .. math::
         | 
| 293 | 
            +
             | 
| 294 | 
            +
                   \text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})
         | 
| 295 | 
            +
             | 
| 296 | 
            +
             | 
| 297 | 
            +
                Args:
         | 
| 298 | 
            +
                    inputs (array): The predicted values.
         | 
| 299 | 
            +
                    targets (array): The target values. They should be -1 or 1.
         | 
| 300 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 301 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                Returns:
         | 
| 304 | 
            +
                    array: The computed hinge loss.
         | 
| 305 | 
            +
                """
         | 
| 306 | 
            +
                loss = mx.maximum(1 - inputs * targets, 0)
         | 
| 307 | 
            +
             | 
| 308 | 
            +
                return _reduce(loss, reduction)
         | 
| 309 | 
            +
             | 
| 310 | 
            +
             | 
| 311 | 
            +
            def huber_loss(
         | 
| 312 | 
            +
                inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
         | 
| 313 | 
            +
            ) -> mx.array:
         | 
| 314 | 
            +
                r"""
         | 
| 315 | 
            +
                Computes the Huber loss between inputs and targets.
         | 
| 316 | 
            +
             | 
| 317 | 
            +
                .. math::
         | 
| 318 | 
            +
             | 
| 319 | 
            +
                    L_{\delta}(a) =
         | 
| 320 | 
            +
                    \left\{ \begin{array}{ll}
         | 
| 321 | 
            +
                        \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
         | 
| 322 | 
            +
                        \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
         | 
| 323 | 
            +
                    \end{array} \right.
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                Args:
         | 
| 326 | 
            +
                    inputs (array): The predicted values.
         | 
| 327 | 
            +
                    targets (array): The target values.
         | 
| 328 | 
            +
                    delta (float, optional): The threshold at which to change between L1 and L2 loss.
         | 
| 329 | 
            +
                      Default: ``1.0``.
         | 
| 330 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 331 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
         | 
| 332 | 
            +
             | 
| 333 | 
            +
                Returns:
         | 
| 334 | 
            +
                    array: The computed Huber loss.
         | 
| 335 | 
            +
                """
         | 
| 336 | 
            +
                errors = inputs - targets
         | 
| 337 | 
            +
                abs_errors = mx.abs(errors)
         | 
| 338 | 
            +
                quadratic = mx.minimum(abs_errors, delta)
         | 
| 339 | 
            +
                linear = abs_errors - quadratic
         | 
| 340 | 
            +
                loss = 0.5 * quadratic**2 + delta * linear
         | 
| 341 | 
            +
             | 
| 342 | 
            +
                return _reduce(loss, reduction)
         | 
| 343 | 
            +
             | 
| 344 | 
            +
             | 
| 345 | 
            +
            def log_cosh_loss(
         | 
| 346 | 
            +
                inputs: mx.array, targets: mx.array, reduction: str = "none"
         | 
| 347 | 
            +
            ) -> mx.array:
         | 
| 348 | 
            +
                r"""
         | 
| 349 | 
            +
                Computes the log cosh loss between inputs and targets.
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                Logcosh acts like L2 loss for small errors, ensuring stable gradients,
         | 
| 352 | 
            +
                and like the L1 loss for large errors, reducing sensitivity to outliers. This
         | 
| 353 | 
            +
                dual behavior offers a balanced, robust approach for regression tasks.
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                .. math::
         | 
| 356 | 
            +
             | 
| 357 | 
            +
                   \text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
         | 
| 358 | 
            +
                        \frac{1}{n} \sum_{i=1}^{n}
         | 
| 359 | 
            +
                        \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))
         | 
| 360 | 
            +
             | 
| 361 | 
            +
             | 
| 362 | 
            +
                Args:
         | 
| 363 | 
            +
                    inputs (array): The predicted values.
         | 
| 364 | 
            +
                    targets (array): The target values.
         | 
| 365 | 
            +
                    reduction (str, optional): Specifies the reduction to apply to the output:
         | 
| 366 | 
            +
                      ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                Returns:
         | 
| 369 | 
            +
                    array: The computed log cosh loss.
         | 
| 370 | 
            +
                """
         | 
| 371 | 
            +
                errors = inputs - targets
         | 
| 372 | 
            +
                loss = mx.logaddexp(errors, -errors) - math.log(2)
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                return _reduce(loss, reduction)
         | 
    	
        lib/python3.11/site-packages/mlx/nn/utils.py
    ADDED
    
    | @@ -0,0 +1,33 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from typing import Callable
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            import mlx.core as mx
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
| 8 | 
            +
            def value_and_grad(model: "mlx.nn.Module", fn: Callable):
         | 
| 9 | 
            +
                """Transform the passed function ``fn`` to a function that computes the
         | 
| 10 | 
            +
                gradients of ``fn`` wrt the model's trainable parameters and also its
         | 
| 11 | 
            +
                value.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                Args:
         | 
| 14 | 
            +
                    model (mlx.nn.Module): The model whose trainable parameters to compute
         | 
| 15 | 
            +
                                           gradients for
         | 
| 16 | 
            +
                    fn (Callable): The scalar function to compute gradients for
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                Returns:
         | 
| 19 | 
            +
                    A callable that returns the value of ``fn`` and the gradients wrt the
         | 
| 20 | 
            +
                    trainable parameters of ``model``
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                def inner_fn(params, *args, **kwargs):
         | 
| 24 | 
            +
                    model.update(params)
         | 
| 25 | 
            +
                    return fn(*args, **kwargs)
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                value_grad_fn = mx.value_and_grad(inner_fn)
         | 
| 28 | 
            +
             | 
| 29 | 
            +
                def wrapped_value_grad_fn(*args, **kwargs):
         | 
| 30 | 
            +
                    value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
         | 
| 31 | 
            +
                    return value, grad
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                return wrapped_value_grad_fn
         | 
    	
        lib/python3.11/site-packages/mlx/optimizers.py
    ADDED
    
    | @@ -0,0 +1,500 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import math
         | 
| 4 | 
            +
            from typing import List
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import mlx.core as mx
         | 
| 7 | 
            +
            from mlx.utils import tree_map
         | 
| 8 | 
            +
             | 
| 9 | 
            +
             | 
| 10 | 
            +
            class OptimizerState(dict):
         | 
| 11 | 
            +
                """The optimizer state implements a recursively defined
         | 
| 12 | 
            +
                :class:`collections.defaultdict`, namely a missing key in an optimizer
         | 
| 13 | 
            +
                state is an :class:`OptimizerState`.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                .. note::
         | 
| 16 | 
            +
                   :meth:`OptimizerState.get` in contrast to a normal dictionary also sets
         | 
| 17 | 
            +
                   the key to the ``default`` value if the ``key`` was not present in the
         | 
| 18 | 
            +
                   dictionary.
         | 
| 19 | 
            +
                """
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                def __getitem__(self, key):
         | 
| 22 | 
            +
                    if key not in self:
         | 
| 23 | 
            +
                        self[key] = OptimizerState()
         | 
| 24 | 
            +
                    return super().__getitem__(key)
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                def get(self, key, default):
         | 
| 27 | 
            +
                    """If ``key`` doesn't exist set its value to ``default`` and then return it."""
         | 
| 28 | 
            +
                    if key not in self:
         | 
| 29 | 
            +
                        self[key] = default
         | 
| 30 | 
            +
                    return super().__getitem__(key)
         | 
| 31 | 
            +
             | 
| 32 | 
            +
             | 
| 33 | 
            +
            class Optimizer:
         | 
| 34 | 
            +
                """The base class for all optimizers. It allows us to implement an
         | 
| 35 | 
            +
                optimizer on a per-parameter basis and apply it to a parameter tree.
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                Attributes:
         | 
| 38 | 
            +
                    state (OptimizerState): It holds the optimizer's state dictionary.
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                def __init__(self):
         | 
| 42 | 
            +
                    self.state = OptimizerState()
         | 
| 43 | 
            +
             | 
| 44 | 
            +
                def update(self, model: "mlx.nn.Module", gradients: dict):
         | 
| 45 | 
            +
                    """Apply the gradients to the parameters of the model and update the
         | 
| 46 | 
            +
                    model with the new parameters.
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    Args:
         | 
| 49 | 
            +
                        model (mlx.nn.Module): An mlx module to be updated.
         | 
| 50 | 
            +
                        gradients (dict): A Python tree of gradients, most likely computed
         | 
| 51 | 
            +
                                          via :func:`mlx.nn.value_and_grad`.
         | 
| 52 | 
            +
                    """
         | 
| 53 | 
            +
                    model.update(self.apply_gradients(gradients, model))
         | 
| 54 | 
            +
             | 
| 55 | 
            +
                def apply_gradients(self, gradients: dict, model: dict):
         | 
| 56 | 
            +
                    """Apply the gradients to the parameters and return the updated parameters.
         | 
| 57 | 
            +
             | 
| 58 | 
            +
                    Can be used to update a model via
         | 
| 59 | 
            +
                    ``model.update(opt.apply_gradients(grads, model))`` which is precisely
         | 
| 60 | 
            +
                    how :meth:`Optimizer.update` is implemented.
         | 
| 61 | 
            +
             | 
| 62 | 
            +
                    Args:
         | 
| 63 | 
            +
                        gradients (dict): A Python tree of gradients.
         | 
| 64 | 
            +
                        model (dict): A Python tree of parameters. It can be a superset of
         | 
| 65 | 
            +
                                      the gradients. In that case the returned python tree
         | 
| 66 | 
            +
                                      will be of the same structure as the gradients.
         | 
| 67 | 
            +
                    """
         | 
| 68 | 
            +
                    return tree_map(self.apply_single, gradients, model, self.state)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                def apply_single(
         | 
| 71 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 72 | 
            +
                ):
         | 
| 73 | 
            +
                    """To be extended by the children classes to implement each optimizer's
         | 
| 74 | 
            +
                    update."""
         | 
| 75 | 
            +
                    raise NotImplementedError()
         | 
| 76 | 
            +
             | 
| 77 | 
            +
             | 
| 78 | 
            +
            class SGD(Optimizer):
         | 
| 79 | 
            +
                r"""Stochastic gradient descent optimizer.
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                Updates a parameter :math:`w` with a gradient :math:`g` as follows
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                .. math::
         | 
| 84 | 
            +
             | 
| 85 | 
            +
                    v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
         | 
| 86 | 
            +
                    w_{t+1} &= w_t - \lambda v_{t+1}
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                Args:
         | 
| 89 | 
            +
                    learning_rate (float): The learning rate :math:`\lambda`.
         | 
| 90 | 
            +
                    momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
         | 
| 91 | 
            +
                    weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
         | 
| 92 | 
            +
                    dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
         | 
| 93 | 
            +
                    nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``
         | 
| 94 | 
            +
                """
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                def __init__(
         | 
| 97 | 
            +
                    self,
         | 
| 98 | 
            +
                    learning_rate: float,
         | 
| 99 | 
            +
                    momentum: float = 0.0,
         | 
| 100 | 
            +
                    weight_decay: float = 0.0,
         | 
| 101 | 
            +
                    dampening: float = 0.0,
         | 
| 102 | 
            +
                    nesterov: bool = False,
         | 
| 103 | 
            +
                ):
         | 
| 104 | 
            +
                    if nesterov and (momentum <= 0 or dampening != 0):
         | 
| 105 | 
            +
                        raise ValueError(
         | 
| 106 | 
            +
                            "Nesterov momentum requires a momentum and zero dampening."
         | 
| 107 | 
            +
                        )
         | 
| 108 | 
            +
                    super().__init__()
         | 
| 109 | 
            +
             | 
| 110 | 
            +
                    self.learning_rate = learning_rate
         | 
| 111 | 
            +
                    self.momentum = momentum
         | 
| 112 | 
            +
                    self.weight_decay = weight_decay
         | 
| 113 | 
            +
                    self.dampening = dampening
         | 
| 114 | 
            +
                    self.nesterov = nesterov
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                def apply_single(
         | 
| 117 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 118 | 
            +
                ):
         | 
| 119 | 
            +
                    """Performs the SGD parameter update and stores :math:`v` in the
         | 
| 120 | 
            +
                    optimizer state."""
         | 
| 121 | 
            +
                    if self.momentum <= 0:
         | 
| 122 | 
            +
                        return parameter - self.learning_rate * gradient
         | 
| 123 | 
            +
             | 
| 124 | 
            +
                    v = state.get("v", mx.zeros_like(gradient))
         | 
| 125 | 
            +
             | 
| 126 | 
            +
                    if self.weight_decay != 0:
         | 
| 127 | 
            +
                        gradient += self.weight_decay * parameter
         | 
| 128 | 
            +
             | 
| 129 | 
            +
                    v = self.momentum * v
         | 
| 130 | 
            +
                    if self.dampening > 0:
         | 
| 131 | 
            +
                        v += (1 - self.dampening) * gradient
         | 
| 132 | 
            +
                    else:
         | 
| 133 | 
            +
                        v += gradient
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                    if self.nesterov:
         | 
| 136 | 
            +
                        update = gradient + self.momentum * v
         | 
| 137 | 
            +
                    else:
         | 
| 138 | 
            +
                        update = v
         | 
| 139 | 
            +
                    state["v"] = v
         | 
| 140 | 
            +
                    return parameter - self.learning_rate * update
         | 
| 141 | 
            +
             | 
| 142 | 
            +
             | 
| 143 | 
            +
            class RMSprop(Optimizer):
         | 
| 144 | 
            +
                r"""Implementation of the RMSprop optimizer [1].
         | 
| 145 | 
            +
             | 
| 146 | 
            +
                [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
         | 
| 147 | 
            +
             | 
| 148 | 
            +
                .. math::
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                    v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
         | 
| 151 | 
            +
                    w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                Args:
         | 
| 154 | 
            +
                    learning_rate (float): The learning rate :math:`\lambda`.
         | 
| 155 | 
            +
                    alpha (float, optional): The smoothing constant :math:`\alpha`.
         | 
| 156 | 
            +
                      Default: ``0.99``
         | 
| 157 | 
            +
                    eps (float, optional): The term :math:`\epsilon` added to the denominator
         | 
| 158 | 
            +
                      to improve numerical stability. Default: ``1e-8``
         | 
| 159 | 
            +
                """
         | 
| 160 | 
            +
             | 
| 161 | 
            +
                def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
         | 
| 162 | 
            +
                    super().__init__()
         | 
| 163 | 
            +
             | 
| 164 | 
            +
                    self.learning_rate = learning_rate
         | 
| 165 | 
            +
                    self.alpha = alpha
         | 
| 166 | 
            +
                    self.eps = eps
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    if self.alpha < 0.0:
         | 
| 169 | 
            +
                        raise ValueError(
         | 
| 170 | 
            +
                            f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
         | 
| 171 | 
            +
                        )
         | 
| 172 | 
            +
                    if self.eps < 0.0:
         | 
| 173 | 
            +
                        raise ValueError(
         | 
| 174 | 
            +
                            f"RMSprop epsilon should be >0, {self.eps} was provided instead"
         | 
| 175 | 
            +
                        )
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                def apply_single(
         | 
| 178 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 179 | 
            +
                ):
         | 
| 180 | 
            +
                    """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
         | 
| 181 | 
            +
                    lr = self.learning_rate
         | 
| 182 | 
            +
                    alpha = self.alpha
         | 
| 183 | 
            +
                    eps = self.eps
         | 
| 184 | 
            +
             | 
| 185 | 
            +
                    v = state.get("v", mx.zeros_like(gradient))
         | 
| 186 | 
            +
                    v = alpha * v + (1 - alpha) * mx.square(gradient)
         | 
| 187 | 
            +
                    state["v"] = v
         | 
| 188 | 
            +
             | 
| 189 | 
            +
                    return parameter - lr * gradient / (mx.sqrt(v) + eps)
         | 
| 190 | 
            +
             | 
| 191 | 
            +
             | 
| 192 | 
            +
            class Adagrad(Optimizer):
         | 
| 193 | 
            +
                r"""Implementation of the Adagrad optimizer [1].
         | 
| 194 | 
            +
             | 
| 195 | 
            +
                Our Adagrad implementation follows the original paper. In detail,
         | 
| 196 | 
            +
             | 
| 197 | 
            +
                [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
         | 
| 198 | 
            +
                for online learning and stochastic optimization. JMLR 2011.
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                .. math::
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    v_{t+1} &= v_t + g_t^2 \\
         | 
| 203 | 
            +
                    w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
         | 
| 204 | 
            +
             | 
| 205 | 
            +
                Args:
         | 
| 206 | 
            +
                    learning_rate (float): The learning rate :math:`\lambda`.
         | 
| 207 | 
            +
                    eps (float, optional): The term :math:`\epsilon` added to the
         | 
| 208 | 
            +
                      denominator to improve numerical stability. Default: ``1e-8``
         | 
| 209 | 
            +
                """
         | 
| 210 | 
            +
             | 
| 211 | 
            +
                def __init__(self, learning_rate: float, eps: float = 1e-8):
         | 
| 212 | 
            +
                    super().__init__()
         | 
| 213 | 
            +
             | 
| 214 | 
            +
                    self.learning_rate = learning_rate
         | 
| 215 | 
            +
                    self.eps = eps
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                    if self.eps < 0.0:
         | 
| 218 | 
            +
                        raise ValueError(
         | 
| 219 | 
            +
                            f"Adagrad epsilon should be >0, {self.eps} was provided instead"
         | 
| 220 | 
            +
                        )
         | 
| 221 | 
            +
             | 
| 222 | 
            +
                def apply_single(
         | 
| 223 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 224 | 
            +
                ):
         | 
| 225 | 
            +
                    """Performs the Adagrad parameter update and stores :math:`v` in the
         | 
| 226 | 
            +
                    optimizer state."""
         | 
| 227 | 
            +
                    lr = self.learning_rate
         | 
| 228 | 
            +
                    eps = self.eps
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                    v = state.get("v", mx.zeros_like(gradient))
         | 
| 231 | 
            +
                    v = v + mx.square(gradient)
         | 
| 232 | 
            +
                    state["v"] = v
         | 
| 233 | 
            +
             | 
| 234 | 
            +
                    return parameter - lr * gradient / (mx.sqrt(v) + eps)
         | 
| 235 | 
            +
             | 
| 236 | 
            +
             | 
| 237 | 
            +
            class AdaDelta(Optimizer):
         | 
| 238 | 
            +
                r"""Implementation of the AdaDelta optimizer with learning rate[1].
         | 
| 239 | 
            +
             | 
| 240 | 
            +
                Our AdaDelta implementation follows the original paper. In detail,
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
         | 
| 243 | 
            +
             | 
| 244 | 
            +
                .. math::
         | 
| 245 | 
            +
             | 
| 246 | 
            +
                    v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
         | 
| 247 | 
            +
                    \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
         | 
| 248 | 
            +
                    u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
         | 
| 249 | 
            +
                    w_{t+1} &= w_t - \lambda \Delta w_{t+1}
         | 
| 250 | 
            +
             | 
| 251 | 
            +
                Args:
         | 
| 252 | 
            +
                    learning_rate (float): The learning rate :math:`\lambda`.
         | 
| 253 | 
            +
                    rho (float, optional): The coefficient :math:`\rho` used for computing a
         | 
| 254 | 
            +
                        running average of squared gradients. Default: ``0.9``
         | 
| 255 | 
            +
                    eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
         | 
| 256 | 
            +
                      numerical stability. Default: `1e-8`
         | 
| 257 | 
            +
                """
         | 
| 258 | 
            +
             | 
| 259 | 
            +
                def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
         | 
| 260 | 
            +
                    super().__init__()
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    self.learning_rate = learning_rate
         | 
| 263 | 
            +
                    self.rho = rho
         | 
| 264 | 
            +
                    self.eps = eps
         | 
| 265 | 
            +
                    if self.rho < 0.0:
         | 
| 266 | 
            +
                        raise ValueError(
         | 
| 267 | 
            +
                            f"AdaDelta rho should be >=0, {self.rho} was provided instead"
         | 
| 268 | 
            +
                        )
         | 
| 269 | 
            +
                    if self.eps < 0.0:
         | 
| 270 | 
            +
                        raise ValueError(
         | 
| 271 | 
            +
                            f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
         | 
| 272 | 
            +
                        )
         | 
| 273 | 
            +
             | 
| 274 | 
            +
                def apply_single(
         | 
| 275 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 276 | 
            +
                ):
         | 
| 277 | 
            +
                    """Performs the AdaDelta parameter update and stores :math:`v` and
         | 
| 278 | 
            +
                    :math:`u` in the optimizer state."""
         | 
| 279 | 
            +
                    lr = self.learning_rate
         | 
| 280 | 
            +
                    rho = self.rho
         | 
| 281 | 
            +
                    eps = self.eps
         | 
| 282 | 
            +
             | 
| 283 | 
            +
                    v = state.get("v", mx.zeros_like(gradient))
         | 
| 284 | 
            +
                    u = state.get("s", mx.zeros_like(gradient))
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    v = rho * v + (1 - rho) * mx.square(gradient)
         | 
| 287 | 
            +
                    d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
         | 
| 288 | 
            +
                    u = rho * u + (1 - rho) * mx.square(d)
         | 
| 289 | 
            +
             | 
| 290 | 
            +
                    state["v"] = v
         | 
| 291 | 
            +
                    state["u"] = u
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    return parameter - lr * d
         | 
| 294 | 
            +
             | 
| 295 | 
            +
             | 
| 296 | 
            +
            class Adam(Optimizer):
         | 
| 297 | 
            +
                r"""Implementation of the Adam optimizer [1].
         | 
| 298 | 
            +
             | 
| 299 | 
            +
                Our Adam implementation follows the original paper and omits the bias
         | 
| 300 | 
            +
                correction in the first and second moment estimates. In detail,
         | 
| 301 | 
            +
             | 
| 302 | 
            +
                [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
         | 
| 303 | 
            +
                optimization. ICLR 2015.
         | 
| 304 | 
            +
             | 
| 305 | 
            +
                .. math::
         | 
| 306 | 
            +
             | 
| 307 | 
            +
                    m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         | 
| 308 | 
            +
                    v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
         | 
| 309 | 
            +
                    w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}
         | 
| 310 | 
            +
             | 
| 311 | 
            +
                Args:
         | 
| 312 | 
            +
                    learning_rate (float): The learning rate :math:`\lambda`.
         | 
| 313 | 
            +
                    betas (Tuple[float, float], optional): The coefficients
         | 
| 314 | 
            +
                      :math:`(\beta_1, \beta_2)` used for computing running averages of the
         | 
| 315 | 
            +
                      gradient and its square. Default: ``(0.9, 0.999)``
         | 
| 316 | 
            +
                    eps (float, optional): The term :math:`\epsilon` added to the
         | 
| 317 | 
            +
                      denominator to improve numerical stability. Default: ``1e-8``
         | 
| 318 | 
            +
                """
         | 
| 319 | 
            +
             | 
| 320 | 
            +
                def __init__(
         | 
| 321 | 
            +
                    self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
         | 
| 322 | 
            +
                ):
         | 
| 323 | 
            +
                    super().__init__()
         | 
| 324 | 
            +
             | 
| 325 | 
            +
                    self.learning_rate = learning_rate
         | 
| 326 | 
            +
                    self.betas = betas
         | 
| 327 | 
            +
                    self.eps = eps
         | 
| 328 | 
            +
             | 
| 329 | 
            +
                def apply_single(
         | 
| 330 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 331 | 
            +
                ):
         | 
| 332 | 
            +
                    """Performs the Adam parameter update and stores :math:`v` and
         | 
| 333 | 
            +
                    :math:`m` in the optimizer state."""
         | 
| 334 | 
            +
                    lr = self.learning_rate
         | 
| 335 | 
            +
                    b1, b2 = self.betas
         | 
| 336 | 
            +
                    eps = self.eps
         | 
| 337 | 
            +
             | 
| 338 | 
            +
                    m = state.get("m", gradient)
         | 
| 339 | 
            +
                    v = state.get("v", mx.square(gradient))
         | 
| 340 | 
            +
                    m = b1 * m + (1 - b1) * gradient
         | 
| 341 | 
            +
                    v = b2 * v + (1 - b2) * mx.square(gradient)
         | 
| 342 | 
            +
                    state["m"] = m
         | 
| 343 | 
            +
                    state["v"] = v
         | 
| 344 | 
            +
             | 
| 345 | 
            +
                    return parameter - lr * m / (mx.sqrt(v) + eps)
         | 
| 346 | 
            +
             | 
| 347 | 
            +
             | 
| 348 | 
            +
            class AdamW(Adam):
         | 
| 349 | 
            +
                r"""Implementation of the AdamW optimizer [1].
         | 
| 350 | 
            +
             | 
| 351 | 
            +
                Following the above convention, in contrast with [1], we do not use bias
         | 
| 352 | 
            +
                correction in the first and second moments for AdamW. We update the weights
         | 
| 353 | 
            +
                with a weight_decay (:math:`\lambda`) value:
         | 
| 354 | 
            +
             | 
| 355 | 
            +
                [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
         | 
| 356 | 
            +
                regularization. ICLR 2019.
         | 
| 357 | 
            +
             | 
| 358 | 
            +
                .. math::
         | 
| 359 | 
            +
             | 
| 360 | 
            +
                    m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         | 
| 361 | 
            +
                    v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
         | 
| 362 | 
            +
                    w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)
         | 
| 363 | 
            +
             | 
| 364 | 
            +
                Args:
         | 
| 365 | 
            +
                    learning_rate (float): The learning rate :math:`\alpha`.
         | 
| 366 | 
            +
                    betas (Tuple[float, float], optional): The coefficients
         | 
| 367 | 
            +
                      :math:`(\beta_1, \beta_2)` used for computing running averages of the
         | 
| 368 | 
            +
                      gradient and its square. Default: ``(0.9, 0.999)``
         | 
| 369 | 
            +
                    eps (float, optional): The term :math:`\epsilon` added to the
         | 
| 370 | 
            +
                      denominator to improve numerical stability. Default: ``1e-8``
         | 
| 371 | 
            +
                    weight_decay (float, optional): The weight decay :math:`\lambda`.
         | 
| 372 | 
            +
                      Default: ``0``.
         | 
| 373 | 
            +
                """
         | 
| 374 | 
            +
             | 
| 375 | 
            +
                def __init__(
         | 
| 376 | 
            +
                    self,
         | 
| 377 | 
            +
                    learning_rate: float,
         | 
| 378 | 
            +
                    betas: List[float] = [0.9, 0.999],
         | 
| 379 | 
            +
                    eps: float = 1e-8,
         | 
| 380 | 
            +
                    weight_decay: float = 0.01,
         | 
| 381 | 
            +
                ):
         | 
| 382 | 
            +
                    super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
         | 
| 383 | 
            +
                    self.weight_decay = weight_decay
         | 
| 384 | 
            +
             | 
| 385 | 
            +
                def apply_single(
         | 
| 386 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 387 | 
            +
                ):
         | 
| 388 | 
            +
                    """Performs the AdamW parameter update by modifying the parameters
         | 
| 389 | 
            +
                    passed into Adam.
         | 
| 390 | 
            +
                    """
         | 
| 391 | 
            +
             | 
| 392 | 
            +
                    return super().apply_single(
         | 
| 393 | 
            +
                        gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
         | 
| 394 | 
            +
                    )
         | 
| 395 | 
            +
             | 
| 396 | 
            +
             | 
| 397 | 
            +
            class Adamax(Adam):
         | 
| 398 | 
            +
                r"""Implementation of the Adamax optimizer. It is a variant of Adam based
         | 
| 399 | 
            +
                on the infinity norm [1].
         | 
| 400 | 
            +
             | 
| 401 | 
            +
                Our Adam implementation follows the original paper and omits the bias
         | 
| 402 | 
            +
                correction in the first and second moment estimates. In detail,
         | 
| 403 | 
            +
             | 
| 404 | 
            +
                [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
         | 
| 405 | 
            +
                optimization. ICLR 2015.
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                .. math::
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                    m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
         | 
| 410 | 
            +
                    v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
         | 
| 411 | 
            +
                    w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}
         | 
| 412 | 
            +
             | 
| 413 | 
            +
                Args:
         | 
| 414 | 
            +
                    learning_rate (float): The learning rate :math:`\lambda`.
         | 
| 415 | 
            +
                    betas (Tuple[float, float], optional): The coefficients
         | 
| 416 | 
            +
                      :math:`(\beta_1, \beta_2)` used for computing running averages of the
         | 
| 417 | 
            +
                      gradient and its square. Default: ``(0.9, 0.999)``
         | 
| 418 | 
            +
                    eps (float, optional): The term :math:`\epsilon` added to the
         | 
| 419 | 
            +
                      denominator to improve numerical stability. Default: ``1e-8``
         | 
| 420 | 
            +
                """
         | 
| 421 | 
            +
             | 
| 422 | 
            +
                def __init__(
         | 
| 423 | 
            +
                    self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
         | 
| 424 | 
            +
                ):
         | 
| 425 | 
            +
                    super().__init__(learning_rate, betas, eps)
         | 
| 426 | 
            +
             | 
| 427 | 
            +
                def apply_single(
         | 
| 428 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 429 | 
            +
                ):
         | 
| 430 | 
            +
                    """Performs the Adamax parameter update and stores :math:`v` and
         | 
| 431 | 
            +
                    :math:`m` in the optimizer state."""
         | 
| 432 | 
            +
                    lr = self.learning_rate
         | 
| 433 | 
            +
                    b1, b2 = self.betas
         | 
| 434 | 
            +
                    eps = self.eps
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                    m = state.get("m", mx.zeros_like(gradient))
         | 
| 437 | 
            +
                    v = state.get("v", mx.zeros_like(gradient))
         | 
| 438 | 
            +
             | 
| 439 | 
            +
                    m = b1 * m + (1 - b1) * gradient
         | 
| 440 | 
            +
                    v = mx.maximum(b2 * v, mx.abs(gradient))
         | 
| 441 | 
            +
                    state["m"] = m
         | 
| 442 | 
            +
                    state["v"] = v
         | 
| 443 | 
            +
             | 
| 444 | 
            +
                    return parameter - lr * m / (v + eps)
         | 
| 445 | 
            +
             | 
| 446 | 
            +
             | 
| 447 | 
            +
            class Lion(Optimizer):
         | 
| 448 | 
            +
                r"""Implementation of the Lion optimizer [1].
         | 
| 449 | 
            +
             | 
| 450 | 
            +
                Since updates are computed through the sign operation, they tend to
         | 
| 451 | 
            +
                have larger norm than for other optimizers such as SGD and Adam.
         | 
| 452 | 
            +
                We recommend a learning rate that is 3-10x smaller than AdamW and a
         | 
| 453 | 
            +
                weight decay 3-10x larger than AdamW to maintain the strength
         | 
| 454 | 
            +
                (lr * wd). Our Lion implementation follows the original paper. In
         | 
| 455 | 
            +
                detail,
         | 
| 456 | 
            +
             | 
| 457 | 
            +
                [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
         | 
| 458 | 
            +
                preprint arXiv:2302.06675.
         | 
| 459 | 
            +
             | 
| 460 | 
            +
                .. math::
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t
         | 
| 463 | 
            +
                    m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t
         | 
| 464 | 
            +
                    w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)
         | 
| 465 | 
            +
             | 
| 466 | 
            +
                Args:
         | 
| 467 | 
            +
                    learning_rate (float): The learning rate :math:`\eta`.
         | 
| 468 | 
            +
                    betas (Tuple[float, float], optional): The coefficients
         | 
| 469 | 
            +
                      :math:`(\beta_1, \beta_2)` used for computing the gradient
         | 
| 470 | 
            +
                      momentum and update direction. Default: ``(0.9, 0.99)``
         | 
| 471 | 
            +
                    weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
         | 
| 472 | 
            +
                """
         | 
| 473 | 
            +
             | 
| 474 | 
            +
                def __init__(
         | 
| 475 | 
            +
                    self,
         | 
| 476 | 
            +
                    learning_rate: float,
         | 
| 477 | 
            +
                    betas: List[float] = [0.9, 0.99],
         | 
| 478 | 
            +
                    weight_decay: float = 0.0,
         | 
| 479 | 
            +
                ):
         | 
| 480 | 
            +
                    super().__init__()
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                    self.learning_rate = learning_rate
         | 
| 483 | 
            +
                    self.betas = betas
         | 
| 484 | 
            +
                    self.weight_decay = weight_decay
         | 
| 485 | 
            +
             | 
| 486 | 
            +
                def apply_single(
         | 
| 487 | 
            +
                    self, gradient: mx.array, parameter: mx.array, state: OptimizerState
         | 
| 488 | 
            +
                ):
         | 
| 489 | 
            +
                    """Performs the Lion parameter update and stores :math:`m`
         | 
| 490 | 
            +
                    in the optimizer state."""
         | 
| 491 | 
            +
                    lr = self.learning_rate
         | 
| 492 | 
            +
                    b1, b2 = self.betas
         | 
| 493 | 
            +
                    weight_decay = self.weight_decay
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                    m = state.get("m", gradient)
         | 
| 496 | 
            +
                    c = b1 * m + (1 - b1) * gradient
         | 
| 497 | 
            +
                    state["m"] = b2 * m + (1 - b2) * gradient
         | 
| 498 | 
            +
                    if weight_decay > 0:
         | 
| 499 | 
            +
                        parameter = (1 - lr * weight_decay) * parameter
         | 
| 500 | 
            +
                    return parameter - lr * mx.sign(c)
         | 
    	
        lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfig.cmake
    ADDED
    
    | @@ -0,0 +1,57 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Find MLX
         | 
| 2 | 
            +
            #
         | 
| 3 | 
            +
            # Defines the following variables:
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            #   MLX_FOUND            : True if MLX is found
         | 
| 6 | 
            +
            #   MLX_INCLUDE_DIRS     : Include directory
         | 
| 7 | 
            +
            #   MLX_LIBRARIES        : Libraries to link against
         | 
| 8 | 
            +
            #   MLX_CXX_FLAGS        : Additional compiler flags
         | 
| 9 | 
            +
            #   MLX_BUILD_ACCELERATE : True if MLX was built with accelerate 
         | 
| 10 | 
            +
            #   MLX_BUILD_METAL      : True if MLX was built with metal 
         | 
| 11 | 
            +
             | 
| 12 | 
            +
             | 
| 13 | 
            +
            ####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() #######
         | 
| 14 | 
            +
            ####### Any changes to this file will be overwritten by the next CMake run ####
         | 
| 15 | 
            +
            ####### The input file was mlx.pc.in                            ########
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            macro(set_and_check _var _file)
         | 
| 20 | 
            +
              set(${_var} "${_file}")
         | 
| 21 | 
            +
              if(NOT EXISTS "${_file}")
         | 
| 22 | 
            +
                message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !")
         | 
| 23 | 
            +
              endif()
         | 
| 24 | 
            +
            endmacro()
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            ####################################################################################
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            include(${PACKAGE_PREFIX_DIR}/share/cmake/MLX/MLXTargets.cmake)
         | 
| 29 | 
            +
            include(${PACKAGE_PREFIX_DIR}/share/cmake/MLX/extension.cmake)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            set_and_check(MLX_LIBRARY_DIRS ${PACKAGE_PREFIX_DIR}/lib)
         | 
| 32 | 
            +
            set_and_check(MLX_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
         | 
| 33 | 
            +
            set(MLX_LIBRARIES mlx)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS})
         | 
| 36 | 
            +
             | 
| 37 | 
            +
            if (ON)
         | 
| 38 | 
            +
                set(MLX_BUILD_ACCELERATE ON)
         | 
| 39 | 
            +
                set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK)
         | 
| 40 | 
            +
            endif()
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            if (ON)
         | 
| 43 | 
            +
                set(MLX_BUILD_METAL ON)
         | 
| 44 | 
            +
                set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
         | 
| 45 | 
            +
                set_and_check(MLX_INCLUDE_DIRS 
         | 
| 46 | 
            +
                    ${MLX_INCLUDE_DIRS} 
         | 
| 47 | 
            +
                    ${PACKAGE_PREFIX_DIR}/include/metal_cpp
         | 
| 48 | 
            +
                )
         | 
| 49 | 
            +
            endif()
         | 
| 50 | 
            +
             | 
| 51 | 
            +
            set_target_properties(mlx PROPERTIES
         | 
| 52 | 
            +
                CXX_STANDARD 17
         | 
| 53 | 
            +
                INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}"
         | 
| 54 | 
            +
            )
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            include(FindPackageHandleStandardArgs)
         | 
| 57 | 
            +
            find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
         | 
    	
        lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfigVersion.cmake
    ADDED
    
    | @@ -0,0 +1,65 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # This is a basic version file for the Config-mode of find_package().
         | 
| 2 | 
            +
            # It is used by write_basic_package_version_file() as input file for configure_file()
         | 
| 3 | 
            +
            # to create a version-file which can be installed along a config.cmake file.
         | 
| 4 | 
            +
            #
         | 
| 5 | 
            +
            # The created file sets PACKAGE_VERSION_EXACT if the current version string and
         | 
| 6 | 
            +
            # the requested version string are exactly the same and it sets
         | 
| 7 | 
            +
            # PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
         | 
| 8 | 
            +
            # but only if the requested major version is the same as the current one.
         | 
| 9 | 
            +
            # The variable CVF_VERSION must be set before calling configure_file().
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            set(PACKAGE_VERSION "0.0.7")
         | 
| 13 | 
            +
             | 
| 14 | 
            +
            if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
         | 
| 15 | 
            +
              set(PACKAGE_VERSION_COMPATIBLE FALSE)
         | 
| 16 | 
            +
            else()
         | 
| 17 | 
            +
             | 
| 18 | 
            +
              if("0.0.7" MATCHES "^([0-9]+)\\.")
         | 
| 19 | 
            +
                set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}")
         | 
| 20 | 
            +
                if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0)
         | 
| 21 | 
            +
                  string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}")
         | 
| 22 | 
            +
                endif()
         | 
| 23 | 
            +
              else()
         | 
| 24 | 
            +
                set(CVF_VERSION_MAJOR "0.0.7")
         | 
| 25 | 
            +
              endif()
         | 
| 26 | 
            +
             | 
| 27 | 
            +
              if(PACKAGE_FIND_VERSION_RANGE)
         | 
| 28 | 
            +
                # both endpoints of the range must have the expected major version
         | 
| 29 | 
            +
                math (EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1")
         | 
| 30 | 
            +
                if (NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
         | 
| 31 | 
            +
                    OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR)
         | 
| 32 | 
            +
                      OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT)))
         | 
| 33 | 
            +
                  set(PACKAGE_VERSION_COMPATIBLE FALSE)
         | 
| 34 | 
            +
                elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
         | 
| 35 | 
            +
                    AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX)
         | 
| 36 | 
            +
                    OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX)))
         | 
| 37 | 
            +
                  set(PACKAGE_VERSION_COMPATIBLE TRUE)
         | 
| 38 | 
            +
                else()
         | 
| 39 | 
            +
                  set(PACKAGE_VERSION_COMPATIBLE FALSE)
         | 
| 40 | 
            +
                endif()
         | 
| 41 | 
            +
              else()
         | 
| 42 | 
            +
                if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR)
         | 
| 43 | 
            +
                  set(PACKAGE_VERSION_COMPATIBLE TRUE)
         | 
| 44 | 
            +
                else()
         | 
| 45 | 
            +
                  set(PACKAGE_VERSION_COMPATIBLE FALSE)
         | 
| 46 | 
            +
                endif()
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
         | 
| 49 | 
            +
                  set(PACKAGE_VERSION_EXACT TRUE)
         | 
| 50 | 
            +
                endif()
         | 
| 51 | 
            +
              endif()
         | 
| 52 | 
            +
            endif()
         | 
| 53 | 
            +
             | 
| 54 | 
            +
             | 
| 55 | 
            +
            # if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
         | 
| 56 | 
            +
            if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
         | 
| 57 | 
            +
              return()
         | 
| 58 | 
            +
            endif()
         | 
| 59 | 
            +
             | 
| 60 | 
            +
            # check that the installed version has the same 32/64bit-ness as the one which is currently searching:
         | 
| 61 | 
            +
            if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
         | 
| 62 | 
            +
              math(EXPR installedBits "8 * 8")
         | 
| 63 | 
            +
              set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
         | 
| 64 | 
            +
              set(PACKAGE_VERSION_UNSUITABLE TRUE)
         | 
| 65 | 
            +
            endif()
         | 
    	
        lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets-release.cmake
    ADDED
    
    | @@ -0,0 +1,19 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            #----------------------------------------------------------------
         | 
| 2 | 
            +
            # Generated CMake target import file for configuration "Release".
         | 
| 3 | 
            +
            #----------------------------------------------------------------
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            # Commands may need to know the format version.
         | 
| 6 | 
            +
            set(CMAKE_IMPORT_FILE_VERSION 1)
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            # Import target "mlx" for configuration "Release"
         | 
| 9 | 
            +
            set_property(TARGET mlx APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
         | 
| 10 | 
            +
            set_target_properties(mlx PROPERTIES
         | 
| 11 | 
            +
              IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libmlx.dylib"
         | 
| 12 | 
            +
              IMPORTED_SONAME_RELEASE "@rpath/libmlx.dylib"
         | 
| 13 | 
            +
              )
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            list(APPEND _cmake_import_check_targets mlx )
         | 
| 16 | 
            +
            list(APPEND _cmake_import_check_files_for_mlx "${_IMPORT_PREFIX}/lib/libmlx.dylib" )
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Commands beyond this point should not need to know the version.
         | 
| 19 | 
            +
            set(CMAKE_IMPORT_FILE_VERSION)
         | 
    	
        lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets.cmake
    ADDED
    
    | @@ -0,0 +1,107 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Generated by CMake
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
         | 
| 4 | 
            +
               message(FATAL_ERROR "CMake >= 2.8.0 required")
         | 
| 5 | 
            +
            endif()
         | 
| 6 | 
            +
            if(CMAKE_VERSION VERSION_LESS "2.8.3")
         | 
| 7 | 
            +
               message(FATAL_ERROR "CMake >= 2.8.3 required")
         | 
| 8 | 
            +
            endif()
         | 
| 9 | 
            +
            cmake_policy(PUSH)
         | 
| 10 | 
            +
            cmake_policy(VERSION 2.8.3...3.24)
         | 
| 11 | 
            +
            #----------------------------------------------------------------
         | 
| 12 | 
            +
            # Generated CMake target import file.
         | 
| 13 | 
            +
            #----------------------------------------------------------------
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # Commands may need to know the format version.
         | 
| 16 | 
            +
            set(CMAKE_IMPORT_FILE_VERSION 1)
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            # Protect against multiple inclusion, which would fail when already imported targets are added once more.
         | 
| 19 | 
            +
            set(_cmake_targets_defined "")
         | 
| 20 | 
            +
            set(_cmake_targets_not_defined "")
         | 
| 21 | 
            +
            set(_cmake_expected_targets "")
         | 
| 22 | 
            +
            foreach(_cmake_expected_target IN ITEMS mlx)
         | 
| 23 | 
            +
              list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
         | 
| 24 | 
            +
              if(TARGET "${_cmake_expected_target}")
         | 
| 25 | 
            +
                list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
         | 
| 26 | 
            +
              else()
         | 
| 27 | 
            +
                list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
         | 
| 28 | 
            +
              endif()
         | 
| 29 | 
            +
            endforeach()
         | 
| 30 | 
            +
            unset(_cmake_expected_target)
         | 
| 31 | 
            +
            if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
         | 
| 32 | 
            +
              unset(_cmake_targets_defined)
         | 
| 33 | 
            +
              unset(_cmake_targets_not_defined)
         | 
| 34 | 
            +
              unset(_cmake_expected_targets)
         | 
| 35 | 
            +
              unset(CMAKE_IMPORT_FILE_VERSION)
         | 
| 36 | 
            +
              cmake_policy(POP)
         | 
| 37 | 
            +
              return()
         | 
| 38 | 
            +
            endif()
         | 
| 39 | 
            +
            if(NOT _cmake_targets_defined STREQUAL "")
         | 
| 40 | 
            +
              string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
         | 
| 41 | 
            +
              string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
         | 
| 42 | 
            +
              message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
         | 
| 43 | 
            +
            endif()
         | 
| 44 | 
            +
            unset(_cmake_targets_defined)
         | 
| 45 | 
            +
            unset(_cmake_targets_not_defined)
         | 
| 46 | 
            +
            unset(_cmake_expected_targets)
         | 
| 47 | 
            +
             | 
| 48 | 
            +
             | 
| 49 | 
            +
            # Compute the installation prefix relative to this file.
         | 
| 50 | 
            +
            get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
         | 
| 51 | 
            +
            get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
         | 
| 52 | 
            +
            get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
         | 
| 53 | 
            +
            get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
         | 
| 54 | 
            +
            if(_IMPORT_PREFIX STREQUAL "/")
         | 
| 55 | 
            +
              set(_IMPORT_PREFIX "")
         | 
| 56 | 
            +
            endif()
         | 
| 57 | 
            +
             | 
| 58 | 
            +
            # Create imported target mlx
         | 
| 59 | 
            +
            add_library(mlx SHARED IMPORTED)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            set_target_properties(mlx PROPERTIES
         | 
| 62 | 
            +
              INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include/metal_cpp;${_IMPORT_PREFIX}/include/json;${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
         | 
| 63 | 
            +
              INTERFACE_LINK_LIBRARIES "/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Metal.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Foundation.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/QuartzCore.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Accelerate.framework"
         | 
| 64 | 
            +
            )
         | 
| 65 | 
            +
             | 
| 66 | 
            +
            if(CMAKE_VERSION VERSION_LESS 2.8.12)
         | 
| 67 | 
            +
              message(FATAL_ERROR "This file relies on consumers using CMake 2.8.12 or greater.")
         | 
| 68 | 
            +
            endif()
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            # Load information for each installed configuration.
         | 
| 71 | 
            +
            file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/MLXTargets-*.cmake")
         | 
| 72 | 
            +
            foreach(_cmake_config_file IN LISTS _cmake_config_files)
         | 
| 73 | 
            +
              include("${_cmake_config_file}")
         | 
| 74 | 
            +
            endforeach()
         | 
| 75 | 
            +
            unset(_cmake_config_file)
         | 
| 76 | 
            +
            unset(_cmake_config_files)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            # Cleanup temporary variables.
         | 
| 79 | 
            +
            set(_IMPORT_PREFIX)
         | 
| 80 | 
            +
             | 
| 81 | 
            +
            # Loop over all imported files and verify that they actually exist
         | 
| 82 | 
            +
            foreach(_cmake_target IN LISTS _cmake_import_check_targets)
         | 
| 83 | 
            +
              foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
         | 
| 84 | 
            +
                if(NOT EXISTS "${_cmake_file}")
         | 
| 85 | 
            +
                  message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
         | 
| 86 | 
            +
               \"${_cmake_file}\"
         | 
| 87 | 
            +
            but this file does not exist.  Possible reasons include:
         | 
| 88 | 
            +
            * The file was deleted, renamed, or moved to another location.
         | 
| 89 | 
            +
            * An install or uninstall procedure did not complete successfully.
         | 
| 90 | 
            +
            * The installation package was faulty and contained
         | 
| 91 | 
            +
               \"${CMAKE_CURRENT_LIST_FILE}\"
         | 
| 92 | 
            +
            but not all the files it references.
         | 
| 93 | 
            +
            ")
         | 
| 94 | 
            +
                endif()
         | 
| 95 | 
            +
              endforeach()
         | 
| 96 | 
            +
              unset(_cmake_file)
         | 
| 97 | 
            +
              unset("_cmake_import_check_files_for_${_cmake_target}")
         | 
| 98 | 
            +
            endforeach()
         | 
| 99 | 
            +
            unset(_cmake_target)
         | 
| 100 | 
            +
            unset(_cmake_import_check_targets)
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            # This file does not depend on other imported targets which have
         | 
| 103 | 
            +
            # been exported from the same project but in a separate export set.
         | 
| 104 | 
            +
             | 
| 105 | 
            +
            # Commands beyond this point should not need to know the version.
         | 
| 106 | 
            +
            set(CMAKE_IMPORT_FILE_VERSION)
         | 
| 107 | 
            +
            cmake_policy(POP)
         | 
    	
        lib/python3.11/site-packages/mlx/share/cmake/MLX/extension.cmake
    ADDED
    
    | @@ -0,0 +1,56 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            include(CMakeParseArguments)
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            ###############################################################################
         | 
| 4 | 
            +
            # Build metal library
         | 
| 5 | 
            +
            #
         | 
| 6 | 
            +
            # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
         | 
| 7 | 
            +
            # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
         | 
| 8 | 
            +
            #
         | 
| 9 | 
            +
            # Args:
         | 
| 10 | 
            +
            #     TARGET: Custom target to be added for the metal library 
         | 
| 11 | 
            +
            #     TITLE: Name of the .metallib
         | 
| 12 | 
            +
            #     OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
         | 
| 13 | 
            +
            #     SOURCES: List of source files
         | 
| 14 | 
            +
            #     INCLUDE_DIRS: List of include dirs
         | 
| 15 | 
            +
            #     DEPS: List of dependency files (like headers)
         | 
| 16 | 
            +
            #
         | 
| 17 | 
            +
            macro(mlx_build_metallib)
         | 
| 18 | 
            +
              # Parse args
         | 
| 19 | 
            +
              set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
         | 
| 20 | 
            +
              set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
         | 
| 21 | 
            +
              cmake_parse_arguments(
         | 
| 22 | 
            +
                  MTLLIB 
         | 
| 23 | 
            +
                  ""
         | 
| 24 | 
            +
                  "${oneValueArgs}"
         | 
| 25 | 
            +
                  "${multiValueArgs}" 
         | 
| 26 | 
            +
                  ${ARGN}
         | 
| 27 | 
            +
              )
         | 
| 28 | 
            +
             | 
| 29 | 
            +
              # Set output
         | 
| 30 | 
            +
              set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
         | 
| 31 | 
            +
             | 
| 32 | 
            +
              # Collect compile options 
         | 
| 33 | 
            +
              set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
         | 
| 34 | 
            +
             | 
| 35 | 
            +
              # Prepare metallib build command
         | 
| 36 | 
            +
              add_custom_command(
         | 
| 37 | 
            +
                OUTPUT ${MTLLIB_BUILD_TARGET}
         | 
| 38 | 
            +
                COMMAND xcrun -sdk macosx metal 
         | 
| 39 | 
            +
                              "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
         | 
| 40 | 
            +
                              ${MTLLIB_COMPILE_OPTIONS}
         | 
| 41 | 
            +
                              ${MTLLIB_SOURCES}
         | 
| 42 | 
            +
                              -o ${MTLLIB_BUILD_TARGET}
         | 
| 43 | 
            +
                DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
         | 
| 44 | 
            +
                COMMAND_EXPAND_LISTS
         | 
| 45 | 
            +
                COMMENT "Building ${MTLLIB_TITLE}.metallib"
         | 
| 46 | 
            +
                VERBATIM
         | 
| 47 | 
            +
              )
         | 
| 48 | 
            +
             | 
| 49 | 
            +
              # Add metallib custom target
         | 
| 50 | 
            +
              add_custom_target(
         | 
| 51 | 
            +
                ${MTLLIB_TARGET}
         | 
| 52 | 
            +
                DEPENDS
         | 
| 53 | 
            +
                ${MTLLIB_BUILD_TARGET}
         | 
| 54 | 
            +
              )
         | 
| 55 | 
            +
             | 
| 56 | 
            +
            endmacro(mlx_build_metallib)
         | 
    	
        lib/python3.11/site-packages/mlx/utils.py
    ADDED
    
    | @@ -0,0 +1,145 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Copyright © 2023 Apple Inc.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from collections import defaultdict
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            def tree_map(fn, tree, *rest, is_leaf=None):
         | 
| 7 | 
            +
                """Applies ``fn`` to the leaves of the python tree ``tree`` and
         | 
| 8 | 
            +
                returns a new collection with the results.
         | 
| 9 | 
            +
             | 
| 10 | 
            +
                If ``rest`` is provided, every item is assumed to be a superset of ``tree``
         | 
| 11 | 
            +
                and the corresponding leaves are provided as extra positional arguments to
         | 
| 12 | 
            +
                ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
         | 
| 13 | 
            +
                than to :func:`map`.
         | 
| 14 | 
            +
             | 
| 15 | 
            +
                The keyword argument ``is_leaf`` decides what constitutes a leaf from
         | 
| 16 | 
            +
                ``tree`` similar to :func:`tree_flatten`.
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                .. code-block:: python
         | 
| 19 | 
            +
             | 
| 20 | 
            +
                    import mlx.nn as nn
         | 
| 21 | 
            +
                    from mlx.utils import tree_map
         | 
| 22 | 
            +
             | 
| 23 | 
            +
                    model = nn.Linear(10, 10)
         | 
| 24 | 
            +
                    print(model.parameters().keys())
         | 
| 25 | 
            +
                    # dict_keys(['weight', 'bias'])
         | 
| 26 | 
            +
             | 
| 27 | 
            +
                    # square the parameters
         | 
| 28 | 
            +
                    model.update(tree_map(lambda x: x*x, model.parameters()))
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                Args:
         | 
| 31 | 
            +
                    fn (Callable): The function that processes the leaves of the tree
         | 
| 32 | 
            +
                    tree (Any): The main python tree that will be iterated upon
         | 
| 33 | 
            +
                    rest (Tuple[Any]): Extra trees to be iterated together with tree
         | 
| 34 | 
            +
                    is_leaf (Optional[Callable]): An optional callable that returns True if
         | 
| 35 | 
            +
                        the passed object is considered a leaf or False otherwise.
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                Returns:
         | 
| 38 | 
            +
                    A python tree with the new values returned by ``fn``.
         | 
| 39 | 
            +
                """
         | 
| 40 | 
            +
                if is_leaf is not None and is_leaf(tree):
         | 
| 41 | 
            +
                    return fn(tree, *rest)
         | 
| 42 | 
            +
                elif isinstance(tree, (list, tuple)):
         | 
| 43 | 
            +
                    TreeType = type(tree)
         | 
| 44 | 
            +
                    return TreeType(
         | 
| 45 | 
            +
                        tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
         | 
| 46 | 
            +
                        for i, child in enumerate(tree)
         | 
| 47 | 
            +
                    )
         | 
| 48 | 
            +
                elif isinstance(tree, dict):
         | 
| 49 | 
            +
                    return {
         | 
| 50 | 
            +
                        k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
         | 
| 51 | 
            +
                        for k, child in tree.items()
         | 
| 52 | 
            +
                    }
         | 
| 53 | 
            +
                else:
         | 
| 54 | 
            +
                    return fn(tree, *rest)
         | 
| 55 | 
            +
             | 
| 56 | 
            +
             | 
| 57 | 
            +
            def tree_flatten(tree, prefix="", is_leaf=None):
         | 
| 58 | 
            +
                """Flattens a python tree to a list of key, value tuples.
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                The keys are using the dot notation to define trees of arbitrary depth and
         | 
| 61 | 
            +
                complexity.
         | 
| 62 | 
            +
             | 
| 63 | 
            +
                .. code-block:: python
         | 
| 64 | 
            +
             | 
| 65 | 
            +
                    from mlx.utils import tree_flatten
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                    print(tree_flatten([[[0]]]))
         | 
| 68 | 
            +
                    # [("0.0.0", 0)]
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    print(tree_flatten([[[0]]], ".hello"))
         | 
| 71 | 
            +
                    # [("hello.0.0.0", 0)]
         | 
| 72 | 
            +
             | 
| 73 | 
            +
                .. note::
         | 
| 74 | 
            +
                   Dictionaries should have keys that are valid python identifiers.
         | 
| 75 | 
            +
             | 
| 76 | 
            +
                Args:
         | 
| 77 | 
            +
                    tree (Any): The python tree to be flattened.
         | 
| 78 | 
            +
                    prefix (str): A prefix to use for the keys. The first character is
         | 
| 79 | 
            +
                        always discarded.
         | 
| 80 | 
            +
                    is_leaf (Callable): An optional callable that returns True if the
         | 
| 81 | 
            +
                        passed object is considered a leaf or False otherwise.
         | 
| 82 | 
            +
             | 
| 83 | 
            +
                Returns:
         | 
| 84 | 
            +
                    List[Tuple[str, Any]]: The flat representation of the python tree.
         | 
| 85 | 
            +
                """
         | 
| 86 | 
            +
                flat_tree = []
         | 
| 87 | 
            +
             | 
| 88 | 
            +
                if is_leaf is None or not is_leaf(tree):
         | 
| 89 | 
            +
                    if isinstance(tree, (list, tuple)):
         | 
| 90 | 
            +
                        for i, t in enumerate(tree):
         | 
| 91 | 
            +
                            flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
         | 
| 92 | 
            +
                        return flat_tree
         | 
| 93 | 
            +
                    if isinstance(tree, dict):
         | 
| 94 | 
            +
                        for k, t in tree.items():
         | 
| 95 | 
            +
                            flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
         | 
| 96 | 
            +
                        return flat_tree
         | 
| 97 | 
            +
             | 
| 98 | 
            +
                return [(prefix[1:], tree)]
         | 
| 99 | 
            +
             | 
| 100 | 
            +
             | 
| 101 | 
            +
            def tree_unflatten(tree):
         | 
| 102 | 
            +
                """Recreate a python tree from its flat representation.
         | 
| 103 | 
            +
             | 
| 104 | 
            +
                .. code-block:: python
         | 
| 105 | 
            +
             | 
| 106 | 
            +
                    from mlx.utils import tree_unflatten
         | 
| 107 | 
            +
             | 
| 108 | 
            +
                    d = tree_unflatten([("hello.world", 42)])
         | 
| 109 | 
            +
                    print(d)
         | 
| 110 | 
            +
                    # {"hello": {"world": 42}}
         | 
| 111 | 
            +
             | 
| 112 | 
            +
                Args:
         | 
| 113 | 
            +
                    tree (List[Tuple[str, Any]]): The flat representation of a python tree.
         | 
| 114 | 
            +
                                                  For instance as returned by :meth:`tree_flatten`.
         | 
| 115 | 
            +
             | 
| 116 | 
            +
                Returns:
         | 
| 117 | 
            +
                    A python tree.
         | 
| 118 | 
            +
                """
         | 
| 119 | 
            +
                if len(tree) == 1 and tree[0][0] == "":
         | 
| 120 | 
            +
                    return tree[0][1]
         | 
| 121 | 
            +
             | 
| 122 | 
            +
                try:
         | 
| 123 | 
            +
                    int(tree[0][0].split(".", maxsplit=1)[0])
         | 
| 124 | 
            +
                    is_list = True
         | 
| 125 | 
            +
                except ValueError:
         | 
| 126 | 
            +
                    is_list = False
         | 
| 127 | 
            +
             | 
| 128 | 
            +
                # collect children
         | 
| 129 | 
            +
                children = defaultdict(list)
         | 
| 130 | 
            +
                for key, value in tree:
         | 
| 131 | 
            +
                    current_idx, *next_idx = key.split(".", maxsplit=1)
         | 
| 132 | 
            +
                    next_idx = "" if not next_idx else next_idx[0]
         | 
| 133 | 
            +
                    children[current_idx].append((next_idx, value))
         | 
| 134 | 
            +
             | 
| 135 | 
            +
                # recursively map them to the original container
         | 
| 136 | 
            +
                if is_list:
         | 
| 137 | 
            +
                    keys = sorted((int(idx), idx) for idx in children.keys())
         | 
| 138 | 
            +
                    l = []
         | 
| 139 | 
            +
                    for i, k in keys:
         | 
| 140 | 
            +
                        # if i <= len(l), no {} will be appended.
         | 
| 141 | 
            +
                        l.extend([{} for _ in range(i - len(l))])
         | 
| 142 | 
            +
                        l.append(tree_unflatten(children[k]))
         | 
| 143 | 
            +
                    return l
         | 
| 144 | 
            +
                else:
         | 
| 145 | 
            +
                    return {k: tree_unflatten(v) for k, v in children.items()}
         | 
    	
        lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/INSTALLER
    ADDED
    
    | @@ -0,0 +1 @@ | |
|  | 
|  | |
| 1 | 
            +
            pip
         | 
    	
        lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/LICENSE
    ADDED
    
    | @@ -0,0 +1,19 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Copyright (c) 2012 Erik Rose
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            Permission is hereby granted, free of charge, to any person obtaining a copy of
         | 
| 4 | 
            +
            this software and associated documentation files (the "Software"), to deal in
         | 
| 5 | 
            +
            the Software without restriction, including without limitation the rights to
         | 
| 6 | 
            +
            use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
         | 
| 7 | 
            +
            of the Software, and to permit persons to whom the Software is furnished to do
         | 
| 8 | 
            +
            so, subject to the following conditions:
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            The above copyright notice and this permission notice shall be included in all
         | 
| 11 | 
            +
            copies or substantial portions of the Software.
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
         | 
| 14 | 
            +
            IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
         | 
| 15 | 
            +
            FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
         | 
| 16 | 
            +
            AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
         | 
| 17 | 
            +
            LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
         | 
| 18 | 
            +
            OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
         | 
| 19 | 
            +
            SOFTWARE.
         | 
    	
        lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/METADATA
    ADDED
    
    | @@ -0,0 +1,253 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Metadata-Version: 2.1
         | 
| 2 | 
            +
            Name: more-itertools
         | 
| 3 | 
            +
            Version: 10.1.0
         | 
| 4 | 
            +
            Summary: More routines for operating on iterables, beyond itertools
         | 
| 5 | 
            +
            Keywords: itertools,iterator,iteration,filter,peek,peekable,chunk,chunked
         | 
| 6 | 
            +
            Author-email: Erik Rose <[email protected]>
         | 
| 7 | 
            +
            Requires-Python: >=3.8
         | 
| 8 | 
            +
            Description-Content-Type: text/x-rst
         | 
| 9 | 
            +
            Classifier: Development Status :: 5 - Production/Stable
         | 
| 10 | 
            +
            Classifier: Intended Audience :: Developers
         | 
| 11 | 
            +
            Classifier: Natural Language :: English
         | 
| 12 | 
            +
            Classifier: License :: OSI Approved :: MIT License
         | 
| 13 | 
            +
            Classifier: Programming Language :: Python :: 3
         | 
| 14 | 
            +
            Classifier: Programming Language :: Python :: 3.8
         | 
| 15 | 
            +
            Classifier: Programming Language :: Python :: 3.9
         | 
| 16 | 
            +
            Classifier: Programming Language :: Python :: 3.10
         | 
| 17 | 
            +
            Classifier: Programming Language :: Python :: 3.11
         | 
| 18 | 
            +
            Classifier: Programming Language :: Python :: 3 :: Only
         | 
| 19 | 
            +
            Classifier: Programming Language :: Python :: Implementation :: CPython
         | 
| 20 | 
            +
            Classifier: Programming Language :: Python :: Implementation :: PyPy
         | 
| 21 | 
            +
            Classifier: Topic :: Software Development :: Libraries
         | 
| 22 | 
            +
            Project-URL: Homepage, https://github.com/more-itertools/more-itertools
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            ==============
         | 
| 25 | 
            +
            More Itertools
         | 
| 26 | 
            +
            ==============
         | 
| 27 | 
            +
             | 
| 28 | 
            +
            .. image:: https://readthedocs.org/projects/more-itertools/badge/?version=latest
         | 
| 29 | 
            +
              :target: https://more-itertools.readthedocs.io/en/stable/
         | 
| 30 | 
            +
             | 
| 31 | 
            +
            Python's ``itertools`` library is a gem - you can compose elegant solutions
         | 
| 32 | 
            +
            for a variety of problems with the functions it provides. In ``more-itertools``
         | 
| 33 | 
            +
            we collect additional building blocks, recipes, and routines for working with
         | 
| 34 | 
            +
            Python iterables.
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 37 | 
            +
            | Grouping               | `chunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked>`_,                                                                               |
         | 
| 38 | 
            +
            |                        | `ichunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ichunked>`_,                                                                             |
         | 
| 39 | 
            +
            |                        | `chunked_even <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked_even>`_,                                                                     |
         | 
| 40 | 
            +
            |                        | `sliced <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliced>`_,                                                                                 |
         | 
| 41 | 
            +
            |                        | `constrained_batches <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.constrained_batches>`_,                                                       |
         | 
| 42 | 
            +
            |                        | `distribute <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute>`_,                                                                         |
         | 
| 43 | 
            +
            |                        | `divide <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.divide>`_,                                                                                 |
         | 
| 44 | 
            +
            |                        | `split_at <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_at>`_,                                                                             |
         | 
| 45 | 
            +
            |                        | `split_before <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_before>`_,                                                                     |
         | 
| 46 | 
            +
            |                        | `split_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_after>`_,                                                                       |
         | 
| 47 | 
            +
            |                        | `split_into <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_into>`_,                                                                         |
         | 
| 48 | 
            +
            |                        | `split_when <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_when>`_,                                                                         |
         | 
| 49 | 
            +
            |                        | `bucket <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.bucket>`_,                                                                                 |
         | 
| 50 | 
            +
            |                        | `unzip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unzip>`_,                                                                                   |
         | 
| 51 | 
            +
            |                        | `batched <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.batched>`_,                                                                               |
         | 
| 52 | 
            +
            |                        | `grouper <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.grouper>`_,                                                                               |
         | 
| 53 | 
            +
            |                        | `partition <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partition>`_,                                                                           |
         | 
| 54 | 
            +
            |                        | `transpose <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.transpose>`_                                                                            |
         | 
| 55 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 56 | 
            +
            | Lookahead and lookback | `spy <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.spy>`_,                                                                                       |
         | 
| 57 | 
            +
            |                        | `peekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.peekable>`_,                                                                             |
         | 
| 58 | 
            +
            |                        | `seekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.seekable>`_                                                                              |
         | 
| 59 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 60 | 
            +
            | Windowing              | `windowed <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed>`_,                                                                             |
         | 
| 61 | 
            +
            |                        | `substrings <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings>`_,                                                                         |
         | 
| 62 | 
            +
            |                        | `substrings_indexes <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings_indexes>`_,                                                         |
         | 
| 63 | 
            +
            |                        | `stagger <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.stagger>`_,                                                                               |
         | 
| 64 | 
            +
            |                        | `windowed_complete <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed_complete>`_,                                                           |
         | 
| 65 | 
            +
            |                        | `pairwise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pairwise>`_,                                                                             |
         | 
| 66 | 
            +
            |                        | `triplewise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.triplewise>`_,                                                                         |
         | 
| 67 | 
            +
            |                        | `sliding_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliding_window>`_,                                                                 |
         | 
| 68 | 
            +
            |                        | `subslices <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.subslices>`_                                                                            |
         | 
| 69 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 70 | 
            +
            | Augmenting             | `count_cycle <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.count_cycle>`_,                                                                       |
         | 
| 71 | 
            +
            |                        | `intersperse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse>`_,                                                                       |
         | 
| 72 | 
            +
            |                        | `padded <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padded>`_,                                                                                 |
         | 
| 73 | 
            +
            |                        | `repeat_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeat_each>`_,                                                                       |
         | 
| 74 | 
            +
            |                        | `mark_ends <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.mark_ends>`_,                                                                           |
         | 
| 75 | 
            +
            |                        | `repeat_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeat_last>`_,                                                                       |
         | 
| 76 | 
            +
            |                        | `adjacent <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.adjacent>`_,                                                                             |
         | 
| 77 | 
            +
            |                        | `groupby_transform <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.groupby_transform>`_,                                                           |
         | 
| 78 | 
            +
            |                        | `pad_none <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pad_none>`_,                                                                             |
         | 
| 79 | 
            +
            |                        | `ncycles <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ncycles>`_                                                                                |
         | 
| 80 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 81 | 
            +
            | Combining              | `collapse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collapse>`_,                                                                             |
         | 
| 82 | 
            +
            |                        | `sort_together <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sort_together>`_,                                                                   |
         | 
| 83 | 
            +
            |                        | `interleave <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave>`_,                                                                         |
         | 
| 84 | 
            +
            |                        | `interleave_longest <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_longest>`_,                                                         |
         | 
| 85 | 
            +
            |                        | `interleave_evenly <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_evenly>`_,                                                           |
         | 
| 86 | 
            +
            |                        | `zip_offset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_offset>`_,                                                                         |
         | 
| 87 | 
            +
            |                        | `zip_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_equal>`_,                                                                           |
         | 
| 88 | 
            +
            |                        | `zip_broadcast <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_broadcast>`_,                                                                   |
         | 
| 89 | 
            +
            |                        | `dotproduct <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.dotproduct>`_,                                                                         |
         | 
| 90 | 
            +
            |                        | `convolve <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.convolve>`_,                                                                             |
         | 
| 91 | 
            +
            |                        | `flatten <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.flatten>`_,                                                                               |
         | 
| 92 | 
            +
            |                        | `roundrobin <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.roundrobin>`_,                                                                         |
         | 
| 93 | 
            +
            |                        | `prepend <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.prepend>`_,                                                                               |
         | 
| 94 | 
            +
            |                        | `value_chain <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.value_chain>`_,                                                                       |
         | 
| 95 | 
            +
            |                        | `partial_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partial_product>`_                                                                |
         | 
| 96 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 97 | 
            +
            | Summarizing            | `ilen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ilen>`_,                                                                                     |
         | 
| 98 | 
            +
            |                        | `unique_to_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_to_each>`_,                                                                 |
         | 
| 99 | 
            +
            |                        | `sample <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sample>`_,                                                                                 |
         | 
| 100 | 
            +
            |                        | `consecutive_groups <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consecutive_groups>`_,                                                         |
         | 
| 101 | 
            +
            |                        | `run_length <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.run_length>`_,                                                                         |
         | 
| 102 | 
            +
            |                        | `map_reduce <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_reduce>`_,                                                                         |
         | 
| 103 | 
            +
            |                        | `exactly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.exactly_n>`_,                                                                           |
         | 
| 104 | 
            +
            |                        | `is_sorted <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.is_sorted>`_,                                                                           |
         | 
| 105 | 
            +
            |                        | `all_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_equal>`_,                                                                           |
         | 
| 106 | 
            +
            |                        | `all_unique <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_unique>`_,                                                                         |
         | 
| 107 | 
            +
            |                        | `minmax <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.minmax>`_,                                                                                 |
         | 
| 108 | 
            +
            |                        | `first_true <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first_true>`_,                                                                         |
         | 
| 109 | 
            +
            |                        | `quantify <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.quantify>`_,                                                                             |
         | 
| 110 | 
            +
            |                        | `iequals <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iequals>`_                                                                                |
         | 
| 111 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 112 | 
            +
            | Selecting              | `islice_extended <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.islice_extended>`_,                                                               |
         | 
| 113 | 
            +
            |                        | `first <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first>`_,                                                                                   |
         | 
| 114 | 
            +
            |                        | `last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.last>`_,                                                                                     |
         | 
| 115 | 
            +
            |                        | `one <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.one>`_,                                                                                       |
         | 
| 116 | 
            +
            |                        | `only <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.only>`_,                                                                                     |
         | 
| 117 | 
            +
            |                        | `strictly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strictly_n>`_,                                                                         |
         | 
| 118 | 
            +
            |                        | `strip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strip>`_,                                                                                   |
         | 
| 119 | 
            +
            |                        | `lstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.lstrip>`_,                                                                                 |
         | 
| 120 | 
            +
            |                        | `rstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rstrip>`_,                                                                                 |
         | 
| 121 | 
            +
            |                        | `filter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.filter_except>`_,                                                                   |
         | 
| 122 | 
            +
            |                        | `map_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_except>`_,                                                                         |
         | 
| 123 | 
            +
            |                        | `nth_or_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_or_last>`_,                                                                       |
         | 
| 124 | 
            +
            |                        | `unique_in_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_in_window>`_,                                                             |
         | 
| 125 | 
            +
            |                        | `before_and_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.before_and_after>`_,                                                             |
         | 
| 126 | 
            +
            |                        | `nth <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth>`_,                                                                                       |
         | 
| 127 | 
            +
            |                        | `take <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.take>`_,                                                                                     |
         | 
| 128 | 
            +
            |                        | `tail <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tail>`_,                                                                                     |
         | 
| 129 | 
            +
            |                        | `unique_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_everseen>`_,                                                               |
         | 
| 130 | 
            +
            |                        | `unique_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_justseen>`_,                                                               |
         | 
| 131 | 
            +
            |                        | `duplicates_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_everseen>`_,                                                       |
         | 
| 132 | 
            +
            |                        | `duplicates_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_justseen>`_,                                                       |
         | 
| 133 | 
            +
            |                        | `longest_common_prefix <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.longest_common_prefix>`_,                                                   |
         | 
| 134 | 
            +
            |                        | `takewhile_inclusive <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.takewhile_inclusive>`_                                                        |
         | 
| 135 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 136 | 
            +
            | Combinatorics          | `distinct_permutations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_permutations>`_,                                                   |
         | 
| 137 | 
            +
            |                        | `distinct_combinations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_combinations>`_,                                                   |
         | 
| 138 | 
            +
            |                        | `circular_shifts <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.circular_shifts>`_,                                                               |
         | 
| 139 | 
            +
            |                        | `partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partitions>`_,                                                                         |
         | 
| 140 | 
            +
            |                        | `set_partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.set_partitions>`_,                                                                 |
         | 
| 141 | 
            +
            |                        | `product_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.product_index>`_,                                                                   |
         | 
| 142 | 
            +
            |                        | `combination_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.combination_index>`_,                                                           |
         | 
| 143 | 
            +
            |                        | `permutation_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.permutation_index>`_,                                                           |
         | 
| 144 | 
            +
            |                        | `combination_with_replacement_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.combination_with_replacement_index>`_,                         |
         | 
| 145 | 
            +
            |                        | `gray_product  <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.gray_product>`_,                                                                    |
         | 
| 146 | 
            +
            |                        | `outer_product  <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.outer_product>`_,                                                                  |
         | 
| 147 | 
            +
            |                        | `powerset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.powerset>`_,                                                                             |
         | 
| 148 | 
            +
            |                        | `random_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_product>`_,                                                                 |
         | 
| 149 | 
            +
            |                        | `random_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_permutation>`_,                                                         |
         | 
| 150 | 
            +
            |                        | `random_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination>`_,                                                         |
         | 
| 151 | 
            +
            |                        | `random_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination_with_replacement>`_,                       |
         | 
| 152 | 
            +
            |                        | `nth_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_product>`_,                                                                       |
         | 
| 153 | 
            +
            |                        | `nth_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_permutation>`_,                                                               |
         | 
| 154 | 
            +
            |                        | `nth_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination>`_,                                                               |
         | 
| 155 | 
            +
            |                        | `nth_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination_with_replacement>`_                              |
         | 
| 156 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 157 | 
            +
            | Wrapping               | `always_iterable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_iterable>`_,                                                               |
         | 
| 158 | 
            +
            |                        | `always_reversible <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_reversible>`_,                                                           |
         | 
| 159 | 
            +
            |                        | `countable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.countable>`_,                                                                           |
         | 
| 160 | 
            +
            |                        | `consumer <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consumer>`_,                                                                             |
         | 
| 161 | 
            +
            |                        | `with_iter <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.with_iter>`_,                                                                           |
         | 
| 162 | 
            +
            |                        | `iter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_except>`_                                                                        |
         | 
| 163 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 164 | 
            +
            | Others                 | `locate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.locate>`_,                                                                                 |
         | 
| 165 | 
            +
            |                        | `rlocate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rlocate>`_,                                                                               |
         | 
| 166 | 
            +
            |                        | `replace <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.replace>`_,                                                                               |
         | 
| 167 | 
            +
            |                        | `numeric_range <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.numeric_range>`_,                                                                   |
         | 
| 168 | 
            +
            |                        | `side_effect <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.side_effect>`_,                                                                       |
         | 
| 169 | 
            +
            |                        | `iterate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iterate>`_,                                                                               |
         | 
| 170 | 
            +
            |                        | `difference <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.difference>`_,                                                                         |
         | 
| 171 | 
            +
            |                        | `make_decorator <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.make_decorator>`_,                                                                 |
         | 
| 172 | 
            +
            |                        | `SequenceView <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.SequenceView>`_,                                                                     |
         | 
| 173 | 
            +
            |                        | `time_limited <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.time_limited>`_,                                                                     |
         | 
| 174 | 
            +
            |                        | `map_if <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_if>`_,                                                                                 |
         | 
| 175 | 
            +
            |                        | `iter_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_index>`_,                                                                         |
         | 
| 176 | 
            +
            |                        | `consume <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consume>`_,                                                                               |
         | 
| 177 | 
            +
            |                        | `tabulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tabulate>`_,                                                                             |
         | 
| 178 | 
            +
            |                        | `repeatfunc <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeatfunc>`_,                                                                         |
         | 
| 179 | 
            +
            |                        | `polynomial_from_roots <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_from_roots>`_,                                                   |
         | 
| 180 | 
            +
            |                        | `polynomial_eval <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_eval>`_,                                                               |
         | 
| 181 | 
            +
            |                        | `polynomial_derivative <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_derivative>`_,                                                   |
         | 
| 182 | 
            +
            |                        | `sieve <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sieve>`_,                                                                                   |
         | 
| 183 | 
            +
            |                        | `factor <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.factor>`_,                                                                                 |
         | 
| 184 | 
            +
            |                        | `matmul <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.matmul>`_,                                                                                 |
         | 
| 185 | 
            +
            |                        | `sum_of_squares <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sum_of_squares>`_                                                                  |
         | 
| 186 | 
            +
            +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
         | 
| 187 | 
            +
             | 
| 188 | 
            +
             | 
| 189 | 
            +
            Getting started
         | 
| 190 | 
            +
            ===============
         | 
| 191 | 
            +
             | 
| 192 | 
            +
            To get started, install the library with `pip <https://pip.pypa.io/en/stable/>`_:
         | 
| 193 | 
            +
             | 
| 194 | 
            +
            .. code-block:: shell
         | 
| 195 | 
            +
             | 
| 196 | 
            +
                pip install more-itertools
         | 
| 197 | 
            +
             | 
| 198 | 
            +
            The recipes from the `itertools docs <https://docs.python.org/3/library/itertools.html#itertools-recipes>`_
         | 
| 199 | 
            +
            are included in the top-level package:
         | 
| 200 | 
            +
             | 
| 201 | 
            +
            .. code-block:: python
         | 
| 202 | 
            +
             | 
| 203 | 
            +
                >>> from more_itertools import flatten
         | 
| 204 | 
            +
                >>> iterable = [(0, 1), (2, 3)]
         | 
| 205 | 
            +
                >>> list(flatten(iterable))
         | 
| 206 | 
            +
                [0, 1, 2, 3]
         | 
| 207 | 
            +
             | 
| 208 | 
            +
            Several new recipes are available as well:
         | 
| 209 | 
            +
             | 
| 210 | 
            +
            .. code-block:: python
         | 
| 211 | 
            +
             | 
| 212 | 
            +
                >>> from more_itertools import chunked
         | 
| 213 | 
            +
                >>> iterable = [0, 1, 2, 3, 4, 5, 6, 7, 8]
         | 
| 214 | 
            +
                >>> list(chunked(iterable, 3))
         | 
| 215 | 
            +
                [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
         | 
| 216 | 
            +
             | 
| 217 | 
            +
                >>> from more_itertools import spy
         | 
| 218 | 
            +
                >>> iterable = (x * x for x in range(1, 6))
         | 
| 219 | 
            +
                >>> head, iterable = spy(iterable, n=3)
         | 
| 220 | 
            +
                >>> list(head)
         | 
| 221 | 
            +
                [1, 4, 9]
         | 
| 222 | 
            +
                >>> list(iterable)
         | 
| 223 | 
            +
                [1, 4, 9, 16, 25]
         | 
| 224 | 
            +
             | 
| 225 | 
            +
             | 
| 226 | 
            +
             | 
| 227 | 
            +
            For the full listing of functions, see the `API documentation <https://more-itertools.readthedocs.io/en/stable/api.html>`_.
         | 
| 228 | 
            +
             | 
| 229 | 
            +
             | 
| 230 | 
            +
            Links elsewhere
         | 
| 231 | 
            +
            ===============
         | 
| 232 | 
            +
             | 
| 233 | 
            +
            Blog posts about ``more-itertools``:
         | 
| 234 | 
            +
             | 
| 235 | 
            +
            * `Yo, I heard you like decorators <https://www.bbayles.com/index/decorator_factory>`__
         | 
| 236 | 
            +
            * `Tour of Python Itertools <https://martinheinz.dev/blog/16>`__ (`Alternate <https://dev.to/martinheinz/tour-of-python-itertools-4122>`__)
         | 
| 237 | 
            +
            * `Real-World Python More Itertools <https://www.gidware.com/real-world-more-itertools/>`_
         | 
| 238 | 
            +
             | 
| 239 | 
            +
             | 
| 240 | 
            +
            Development
         | 
| 241 | 
            +
            ===========
         | 
| 242 | 
            +
             | 
| 243 | 
            +
            ``more-itertools`` is maintained by `@erikrose <https://github.com/erikrose>`_
         | 
| 244 | 
            +
            and `@bbayles <https://github.com/bbayles>`_, with help from `many others <https://github.com/more-itertools/more-itertools/graphs/contributors>`_.
         | 
| 245 | 
            +
            If you have a problem or suggestion, please file a bug or pull request in this
         | 
| 246 | 
            +
            repository. Thanks for contributing!
         | 
| 247 | 
            +
             | 
| 248 | 
            +
             | 
| 249 | 
            +
            Version History
         | 
| 250 | 
            +
            ===============
         | 
| 251 | 
            +
             | 
| 252 | 
            +
            The version history can be found in `documentation <https://more-itertools.readthedocs.io/en/stable/versions.html>`_.
         | 
| 253 | 
            +
             | 
    	
        lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/RECORD
    ADDED
    
    | @@ -0,0 +1,16 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            more_itertools-10.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
         | 
| 2 | 
            +
            more_itertools-10.1.0.dist-info/LICENSE,sha256=CfHIyelBrz5YTVlkHqm4fYPAyw_QB-te85Gn4mQ8GkY,1053
         | 
| 3 | 
            +
            more_itertools-10.1.0.dist-info/METADATA,sha256=s6T6Rg5Sq9hJE6KhX38MrPxWh8X0d9Fsdld4qdbrQVQ,33830
         | 
| 4 | 
            +
            more_itertools-10.1.0.dist-info/RECORD,,
         | 
| 5 | 
            +
            more_itertools-10.1.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 6 | 
            +
            more_itertools-10.1.0.dist-info/WHEEL,sha256=rSgq_JpHF9fHR1lx53qwg_1-2LypZE_qmcuXbVUq948,81
         | 
| 7 | 
            +
            more_itertools/__init__.py,sha256=weQyJxnCVH2EuAaoxC8ZEX-ViZmpHm5kLHO05O8LfRM,149
         | 
| 8 | 
            +
            more_itertools/__init__.pyi,sha256=5B3eTzON1BBuOLob1vCflyEb2lSd6usXQQ-Cv-hXkeA,43
         | 
| 9 | 
            +
            more_itertools/__pycache__/__init__.cpython-311.pyc,,
         | 
| 10 | 
            +
            more_itertools/__pycache__/more.cpython-311.pyc,,
         | 
| 11 | 
            +
            more_itertools/__pycache__/recipes.cpython-311.pyc,,
         | 
| 12 | 
            +
            more_itertools/more.py,sha256=6XrPO3vvd3miZb6ygQtY5SoJeLk2PAwfEQOlZHwZgXw,140481
         | 
| 13 | 
            +
            more_itertools/more.pyi,sha256=nbYRtg-0YR0XMnFOGMzRzbgEOtNC_cfkUaOMOmNVWMA,20726
         | 
| 14 | 
            +
            more_itertools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
         | 
| 15 | 
            +
            more_itertools/recipes.py,sha256=GNzJluz0ZYBZhF_H3bC7YkNCm9fZzOoBiQvh1WNNDOU,26354
         | 
| 16 | 
            +
            more_itertools/recipes.pyi,sha256=77IGo2ZqPtp9IkRu5MzQBgatLBl52vO3MRrFvPDcHzM,4237
         | 
    	
        lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/REQUESTED
    ADDED
    
    | 
            File without changes
         | 
    	
        lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/WHEEL
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            Wheel-Version: 1.0
         | 
| 2 | 
            +
            Generator: flit 3.8.0
         | 
| 3 | 
            +
            Root-Is-Purelib: true
         | 
| 4 | 
            +
            Tag: py3-none-any
         | 
    	
        lib/python3.11/site-packages/more_itertools/__init__.py
    ADDED
    
    | @@ -0,0 +1,6 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """More routines for operating on iterables, beyond itertools"""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from .more import *  # noqa
         | 
| 4 | 
            +
            from .recipes import *  # noqa
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            __version__ = '10.1.0'
         | 
    	
        lib/python3.11/site-packages/more_itertools/__init__.pyi
    ADDED
    
    | @@ -0,0 +1,2 @@ | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .more import *
         | 
| 2 | 
            +
            from .recipes import *
         | 
    	
        lib/python3.11/site-packages/more_itertools/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (394 Bytes). View file | 
|  | 
    	
        lib/python3.11/site-packages/more_itertools/__pycache__/more.cpython-311.pyc
    ADDED
    
    | Binary file (179 kB). View file | 
|  | 
    	
        lib/python3.11/site-packages/more_itertools/__pycache__/recipes.cpython-311.pyc
    ADDED
    
    | Binary file (37.2 kB). View file | 
|  | 
    	
        lib/python3.11/site-packages/more_itertools/more.py
    ADDED
    
    | The diff for this file is too large to render. 
		See raw diff | 
|  | 
    	
        lib/python3.11/site-packages/more_itertools/more.pyi
    ADDED
    
    | @@ -0,0 +1,684 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Stubs for more_itertools.more"""
         | 
| 2 | 
            +
            from __future__ import annotations
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from types import TracebackType
         | 
| 5 | 
            +
            from typing import (
         | 
| 6 | 
            +
                Any,
         | 
| 7 | 
            +
                Callable,
         | 
| 8 | 
            +
                Container,
         | 
| 9 | 
            +
                ContextManager,
         | 
| 10 | 
            +
                Generic,
         | 
| 11 | 
            +
                Hashable,
         | 
| 12 | 
            +
                Iterable,
         | 
| 13 | 
            +
                Iterator,
         | 
| 14 | 
            +
                overload,
         | 
| 15 | 
            +
                Reversible,
         | 
| 16 | 
            +
                Sequence,
         | 
| 17 | 
            +
                Sized,
         | 
| 18 | 
            +
                Type,
         | 
| 19 | 
            +
                TypeVar,
         | 
| 20 | 
            +
                type_check_only,
         | 
| 21 | 
            +
            )
         | 
| 22 | 
            +
            from typing_extensions import Protocol
         | 
| 23 | 
            +
             | 
| 24 | 
            +
            # Type and type variable definitions
         | 
| 25 | 
            +
            _T = TypeVar('_T')
         | 
| 26 | 
            +
            _T1 = TypeVar('_T1')
         | 
| 27 | 
            +
            _T2 = TypeVar('_T2')
         | 
| 28 | 
            +
            _U = TypeVar('_U')
         | 
| 29 | 
            +
            _V = TypeVar('_V')
         | 
| 30 | 
            +
            _W = TypeVar('_W')
         | 
| 31 | 
            +
            _T_co = TypeVar('_T_co', covariant=True)
         | 
| 32 | 
            +
            _GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
         | 
| 33 | 
            +
            _Raisable = BaseException | Type[BaseException]
         | 
| 34 | 
            +
             | 
| 35 | 
            +
            @type_check_only
         | 
| 36 | 
            +
            class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ...
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            @type_check_only
         | 
| 39 | 
            +
            class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ...
         | 
| 40 | 
            +
             | 
| 41 | 
            +
            @type_check_only
         | 
| 42 | 
            +
            class _SupportsSlicing(Protocol[_T_co]):
         | 
| 43 | 
            +
                def __getitem__(self, __k: slice) -> _T_co: ...
         | 
| 44 | 
            +
             | 
| 45 | 
            +
            def chunked(
         | 
| 46 | 
            +
                iterable: Iterable[_T], n: int | None, strict: bool = ...
         | 
| 47 | 
            +
            ) -> Iterator[list[_T]]: ...
         | 
| 48 | 
            +
            @overload
         | 
| 49 | 
            +
            def first(iterable: Iterable[_T]) -> _T: ...
         | 
| 50 | 
            +
            @overload
         | 
| 51 | 
            +
            def first(iterable: Iterable[_T], default: _U) -> _T | _U: ...
         | 
| 52 | 
            +
            @overload
         | 
| 53 | 
            +
            def last(iterable: Iterable[_T]) -> _T: ...
         | 
| 54 | 
            +
            @overload
         | 
| 55 | 
            +
            def last(iterable: Iterable[_T], default: _U) -> _T | _U: ...
         | 
| 56 | 
            +
            @overload
         | 
| 57 | 
            +
            def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ...
         | 
| 58 | 
            +
            @overload
         | 
| 59 | 
            +
            def nth_or_last(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
         | 
| 60 | 
            +
             | 
| 61 | 
            +
            class peekable(Generic[_T], Iterator[_T]):
         | 
| 62 | 
            +
                def __init__(self, iterable: Iterable[_T]) -> None: ...
         | 
| 63 | 
            +
                def __iter__(self) -> peekable[_T]: ...
         | 
| 64 | 
            +
                def __bool__(self) -> bool: ...
         | 
| 65 | 
            +
                @overload
         | 
| 66 | 
            +
                def peek(self) -> _T: ...
         | 
| 67 | 
            +
                @overload
         | 
| 68 | 
            +
                def peek(self, default: _U) -> _T | _U: ...
         | 
| 69 | 
            +
                def prepend(self, *items: _T) -> None: ...
         | 
| 70 | 
            +
                def __next__(self) -> _T: ...
         | 
| 71 | 
            +
                @overload
         | 
| 72 | 
            +
                def __getitem__(self, index: int) -> _T: ...
         | 
| 73 | 
            +
                @overload
         | 
| 74 | 
            +
                def __getitem__(self, index: slice) -> list[_T]: ...
         | 
| 75 | 
            +
             | 
| 76 | 
            +
            def consumer(func: _GenFn) -> _GenFn: ...
         | 
| 77 | 
            +
            def ilen(iterable: Iterable[object]) -> int: ...
         | 
| 78 | 
            +
            def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
         | 
| 79 | 
            +
            def with_iter(
         | 
| 80 | 
            +
                context_manager: ContextManager[Iterable[_T]],
         | 
| 81 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 82 | 
            +
            def one(
         | 
| 83 | 
            +
                iterable: Iterable[_T],
         | 
| 84 | 
            +
                too_short: _Raisable | None = ...,
         | 
| 85 | 
            +
                too_long: _Raisable | None = ...,
         | 
| 86 | 
            +
            ) -> _T: ...
         | 
| 87 | 
            +
            def raise_(exception: _Raisable, *args: Any) -> None: ...
         | 
| 88 | 
            +
            def strictly_n(
         | 
| 89 | 
            +
                iterable: Iterable[_T],
         | 
| 90 | 
            +
                n: int,
         | 
| 91 | 
            +
                too_short: _GenFn | None = ...,
         | 
| 92 | 
            +
                too_long: _GenFn | None = ...,
         | 
| 93 | 
            +
            ) -> list[_T]: ...
         | 
| 94 | 
            +
            def distinct_permutations(
         | 
| 95 | 
            +
                iterable: Iterable[_T], r: int | None = ...
         | 
| 96 | 
            +
            ) -> Iterator[tuple[_T, ...]]: ...
         | 
| 97 | 
            +
            def intersperse(
         | 
| 98 | 
            +
                e: _U, iterable: Iterable[_T], n: int = ...
         | 
| 99 | 
            +
            ) -> Iterator[_T | _U]: ...
         | 
| 100 | 
            +
            def unique_to_each(*iterables: Iterable[_T]) -> list[list[_T]]: ...
         | 
| 101 | 
            +
            @overload
         | 
| 102 | 
            +
            def windowed(
         | 
| 103 | 
            +
                seq: Iterable[_T], n: int, *, step: int = ...
         | 
| 104 | 
            +
            ) -> Iterator[tuple[_T | None, ...]]: ...
         | 
| 105 | 
            +
            @overload
         | 
| 106 | 
            +
            def windowed(
         | 
| 107 | 
            +
                seq: Iterable[_T], n: int, fillvalue: _U, step: int = ...
         | 
| 108 | 
            +
            ) -> Iterator[tuple[_T | _U, ...]]: ...
         | 
| 109 | 
            +
            def substrings(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
         | 
| 110 | 
            +
            def substrings_indexes(
         | 
| 111 | 
            +
                seq: Sequence[_T], reverse: bool = ...
         | 
| 112 | 
            +
            ) -> Iterator[tuple[Sequence[_T], int, int]]: ...
         | 
| 113 | 
            +
             | 
| 114 | 
            +
            class bucket(Generic[_T, _U], Container[_U]):
         | 
| 115 | 
            +
                def __init__(
         | 
| 116 | 
            +
                    self,
         | 
| 117 | 
            +
                    iterable: Iterable[_T],
         | 
| 118 | 
            +
                    key: Callable[[_T], _U],
         | 
| 119 | 
            +
                    validator: Callable[[object], object] | None = ...,
         | 
| 120 | 
            +
                ) -> None: ...
         | 
| 121 | 
            +
                def __contains__(self, value: object) -> bool: ...
         | 
| 122 | 
            +
                def __iter__(self) -> Iterator[_U]: ...
         | 
| 123 | 
            +
                def __getitem__(self, value: object) -> Iterator[_T]: ...
         | 
| 124 | 
            +
             | 
| 125 | 
            +
            def spy(
         | 
| 126 | 
            +
                iterable: Iterable[_T], n: int = ...
         | 
| 127 | 
            +
            ) -> tuple[list[_T], Iterator[_T]]: ...
         | 
| 128 | 
            +
            def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ...
         | 
| 129 | 
            +
            def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ...
         | 
| 130 | 
            +
            def interleave_evenly(
         | 
| 131 | 
            +
                iterables: list[Iterable[_T]], lengths: list[int] | None = ...
         | 
| 132 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 133 | 
            +
            def collapse(
         | 
| 134 | 
            +
                iterable: Iterable[Any],
         | 
| 135 | 
            +
                base_type: type | None = ...,
         | 
| 136 | 
            +
                levels: int | None = ...,
         | 
| 137 | 
            +
            ) -> Iterator[Any]: ...
         | 
| 138 | 
            +
            @overload
         | 
| 139 | 
            +
            def side_effect(
         | 
| 140 | 
            +
                func: Callable[[_T], object],
         | 
| 141 | 
            +
                iterable: Iterable[_T],
         | 
| 142 | 
            +
                chunk_size: None = ...,
         | 
| 143 | 
            +
                before: Callable[[], object] | None = ...,
         | 
| 144 | 
            +
                after: Callable[[], object] | None = ...,
         | 
| 145 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 146 | 
            +
            @overload
         | 
| 147 | 
            +
            def side_effect(
         | 
| 148 | 
            +
                func: Callable[[list[_T]], object],
         | 
| 149 | 
            +
                iterable: Iterable[_T],
         | 
| 150 | 
            +
                chunk_size: int,
         | 
| 151 | 
            +
                before: Callable[[], object] | None = ...,
         | 
| 152 | 
            +
                after: Callable[[], object] | None = ...,
         | 
| 153 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 154 | 
            +
            def sliced(
         | 
| 155 | 
            +
                seq: _SupportsSlicing[_T], n: int, strict: bool = ...
         | 
| 156 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 157 | 
            +
            def split_at(
         | 
| 158 | 
            +
                iterable: Iterable[_T],
         | 
| 159 | 
            +
                pred: Callable[[_T], object],
         | 
| 160 | 
            +
                maxsplit: int = ...,
         | 
| 161 | 
            +
                keep_separator: bool = ...,
         | 
| 162 | 
            +
            ) -> Iterator[list[_T]]: ...
         | 
| 163 | 
            +
            def split_before(
         | 
| 164 | 
            +
                iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
         | 
| 165 | 
            +
            ) -> Iterator[list[_T]]: ...
         | 
| 166 | 
            +
            def split_after(
         | 
| 167 | 
            +
                iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
         | 
| 168 | 
            +
            ) -> Iterator[list[_T]]: ...
         | 
| 169 | 
            +
            def split_when(
         | 
| 170 | 
            +
                iterable: Iterable[_T],
         | 
| 171 | 
            +
                pred: Callable[[_T, _T], object],
         | 
| 172 | 
            +
                maxsplit: int = ...,
         | 
| 173 | 
            +
            ) -> Iterator[list[_T]]: ...
         | 
| 174 | 
            +
            def split_into(
         | 
| 175 | 
            +
                iterable: Iterable[_T], sizes: Iterable[int | None]
         | 
| 176 | 
            +
            ) -> Iterator[list[_T]]: ...
         | 
| 177 | 
            +
            @overload
         | 
| 178 | 
            +
            def padded(
         | 
| 179 | 
            +
                iterable: Iterable[_T],
         | 
| 180 | 
            +
                *,
         | 
| 181 | 
            +
                n: int | None = ...,
         | 
| 182 | 
            +
                next_multiple: bool = ...,
         | 
| 183 | 
            +
            ) -> Iterator[_T | None]: ...
         | 
| 184 | 
            +
            @overload
         | 
| 185 | 
            +
            def padded(
         | 
| 186 | 
            +
                iterable: Iterable[_T],
         | 
| 187 | 
            +
                fillvalue: _U,
         | 
| 188 | 
            +
                n: int | None = ...,
         | 
| 189 | 
            +
                next_multiple: bool = ...,
         | 
| 190 | 
            +
            ) -> Iterator[_T | _U]: ...
         | 
| 191 | 
            +
            @overload
         | 
| 192 | 
            +
            def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ...
         | 
| 193 | 
            +
            @overload
         | 
| 194 | 
            +
            def repeat_last(iterable: Iterable[_T], default: _U) -> Iterator[_T | _U]: ...
         | 
| 195 | 
            +
            def distribute(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
         | 
| 196 | 
            +
            @overload
         | 
| 197 | 
            +
            def stagger(
         | 
| 198 | 
            +
                iterable: Iterable[_T],
         | 
| 199 | 
            +
                offsets: _SizedIterable[int] = ...,
         | 
| 200 | 
            +
                longest: bool = ...,
         | 
| 201 | 
            +
            ) -> Iterator[tuple[_T | None, ...]]: ...
         | 
| 202 | 
            +
            @overload
         | 
| 203 | 
            +
            def stagger(
         | 
| 204 | 
            +
                iterable: Iterable[_T],
         | 
| 205 | 
            +
                offsets: _SizedIterable[int] = ...,
         | 
| 206 | 
            +
                longest: bool = ...,
         | 
| 207 | 
            +
                fillvalue: _U = ...,
         | 
| 208 | 
            +
            ) -> Iterator[tuple[_T | _U, ...]]: ...
         | 
| 209 | 
            +
             | 
| 210 | 
            +
            class UnequalIterablesError(ValueError):
         | 
| 211 | 
            +
                def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ...
         | 
| 212 | 
            +
             | 
| 213 | 
            +
            @overload
         | 
| 214 | 
            +
            def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ...
         | 
| 215 | 
            +
            @overload
         | 
| 216 | 
            +
            def zip_equal(
         | 
| 217 | 
            +
                __iter1: Iterable[_T1], __iter2: Iterable[_T2]
         | 
| 218 | 
            +
            ) -> Iterator[tuple[_T1, _T2]]: ...
         | 
| 219 | 
            +
            @overload
         | 
| 220 | 
            +
            def zip_equal(
         | 
| 221 | 
            +
                __iter1: Iterable[_T],
         | 
| 222 | 
            +
                __iter2: Iterable[_T],
         | 
| 223 | 
            +
                __iter3: Iterable[_T],
         | 
| 224 | 
            +
                *iterables: Iterable[_T],
         | 
| 225 | 
            +
            ) -> Iterator[tuple[_T, ...]]: ...
         | 
| 226 | 
            +
            @overload
         | 
| 227 | 
            +
            def zip_offset(
         | 
| 228 | 
            +
                __iter1: Iterable[_T1],
         | 
| 229 | 
            +
                *,
         | 
| 230 | 
            +
                offsets: _SizedIterable[int],
         | 
| 231 | 
            +
                longest: bool = ...,
         | 
| 232 | 
            +
                fillvalue: None = None,
         | 
| 233 | 
            +
            ) -> Iterator[tuple[_T1 | None]]: ...
         | 
| 234 | 
            +
            @overload
         | 
| 235 | 
            +
            def zip_offset(
         | 
| 236 | 
            +
                __iter1: Iterable[_T1],
         | 
| 237 | 
            +
                __iter2: Iterable[_T2],
         | 
| 238 | 
            +
                *,
         | 
| 239 | 
            +
                offsets: _SizedIterable[int],
         | 
| 240 | 
            +
                longest: bool = ...,
         | 
| 241 | 
            +
                fillvalue: None = None,
         | 
| 242 | 
            +
            ) -> Iterator[tuple[_T1 | None, _T2 | None]]: ...
         | 
| 243 | 
            +
            @overload
         | 
| 244 | 
            +
            def zip_offset(
         | 
| 245 | 
            +
                __iter1: Iterable[_T],
         | 
| 246 | 
            +
                __iter2: Iterable[_T],
         | 
| 247 | 
            +
                __iter3: Iterable[_T],
         | 
| 248 | 
            +
                *iterables: Iterable[_T],
         | 
| 249 | 
            +
                offsets: _SizedIterable[int],
         | 
| 250 | 
            +
                longest: bool = ...,
         | 
| 251 | 
            +
                fillvalue: None = None,
         | 
| 252 | 
            +
            ) -> Iterator[tuple[_T | None, ...]]: ...
         | 
| 253 | 
            +
            @overload
         | 
| 254 | 
            +
            def zip_offset(
         | 
| 255 | 
            +
                __iter1: Iterable[_T1],
         | 
| 256 | 
            +
                *,
         | 
| 257 | 
            +
                offsets: _SizedIterable[int],
         | 
| 258 | 
            +
                longest: bool = ...,
         | 
| 259 | 
            +
                fillvalue: _U,
         | 
| 260 | 
            +
            ) -> Iterator[tuple[_T1 | _U]]: ...
         | 
| 261 | 
            +
            @overload
         | 
| 262 | 
            +
            def zip_offset(
         | 
| 263 | 
            +
                __iter1: Iterable[_T1],
         | 
| 264 | 
            +
                __iter2: Iterable[_T2],
         | 
| 265 | 
            +
                *,
         | 
| 266 | 
            +
                offsets: _SizedIterable[int],
         | 
| 267 | 
            +
                longest: bool = ...,
         | 
| 268 | 
            +
                fillvalue: _U,
         | 
| 269 | 
            +
            ) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ...
         | 
| 270 | 
            +
            @overload
         | 
| 271 | 
            +
            def zip_offset(
         | 
| 272 | 
            +
                __iter1: Iterable[_T],
         | 
| 273 | 
            +
                __iter2: Iterable[_T],
         | 
| 274 | 
            +
                __iter3: Iterable[_T],
         | 
| 275 | 
            +
                *iterables: Iterable[_T],
         | 
| 276 | 
            +
                offsets: _SizedIterable[int],
         | 
| 277 | 
            +
                longest: bool = ...,
         | 
| 278 | 
            +
                fillvalue: _U,
         | 
| 279 | 
            +
            ) -> Iterator[tuple[_T | _U, ...]]: ...
         | 
| 280 | 
            +
            def sort_together(
         | 
| 281 | 
            +
                iterables: Iterable[Iterable[_T]],
         | 
| 282 | 
            +
                key_list: Iterable[int] = ...,
         | 
| 283 | 
            +
                key: Callable[..., Any] | None = ...,
         | 
| 284 | 
            +
                reverse: bool = ...,
         | 
| 285 | 
            +
            ) -> list[tuple[_T, ...]]: ...
         | 
| 286 | 
            +
            def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ...
         | 
| 287 | 
            +
            def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
         | 
| 288 | 
            +
            def always_iterable(
         | 
| 289 | 
            +
                obj: object,
         | 
| 290 | 
            +
                base_type: type | tuple[type | tuple[Any, ...], ...] | None = ...,
         | 
| 291 | 
            +
            ) -> Iterator[Any]: ...
         | 
| 292 | 
            +
            def adjacent(
         | 
| 293 | 
            +
                predicate: Callable[[_T], bool],
         | 
| 294 | 
            +
                iterable: Iterable[_T],
         | 
| 295 | 
            +
                distance: int = ...,
         | 
| 296 | 
            +
            ) -> Iterator[tuple[bool, _T]]: ...
         | 
| 297 | 
            +
            @overload
         | 
| 298 | 
            +
            def groupby_transform(
         | 
| 299 | 
            +
                iterable: Iterable[_T],
         | 
| 300 | 
            +
                keyfunc: None = None,
         | 
| 301 | 
            +
                valuefunc: None = None,
         | 
| 302 | 
            +
                reducefunc: None = None,
         | 
| 303 | 
            +
            ) -> Iterator[tuple[_T, Iterator[_T]]]: ...
         | 
| 304 | 
            +
            @overload
         | 
| 305 | 
            +
            def groupby_transform(
         | 
| 306 | 
            +
                iterable: Iterable[_T],
         | 
| 307 | 
            +
                keyfunc: Callable[[_T], _U],
         | 
| 308 | 
            +
                valuefunc: None,
         | 
| 309 | 
            +
                reducefunc: None,
         | 
| 310 | 
            +
            ) -> Iterator[tuple[_U, Iterator[_T]]]: ...
         | 
| 311 | 
            +
            @overload
         | 
| 312 | 
            +
            def groupby_transform(
         | 
| 313 | 
            +
                iterable: Iterable[_T],
         | 
| 314 | 
            +
                keyfunc: None,
         | 
| 315 | 
            +
                valuefunc: Callable[[_T], _V],
         | 
| 316 | 
            +
                reducefunc: None,
         | 
| 317 | 
            +
            ) -> Iterable[tuple[_T, Iterable[_V]]]: ...
         | 
| 318 | 
            +
            @overload
         | 
| 319 | 
            +
            def groupby_transform(
         | 
| 320 | 
            +
                iterable: Iterable[_T],
         | 
| 321 | 
            +
                keyfunc: Callable[[_T], _U],
         | 
| 322 | 
            +
                valuefunc: Callable[[_T], _V],
         | 
| 323 | 
            +
                reducefunc: None,
         | 
| 324 | 
            +
            ) -> Iterable[tuple[_U, Iterator[_V]]]: ...
         | 
| 325 | 
            +
            @overload
         | 
| 326 | 
            +
            def groupby_transform(
         | 
| 327 | 
            +
                iterable: Iterable[_T],
         | 
| 328 | 
            +
                keyfunc: None,
         | 
| 329 | 
            +
                valuefunc: None,
         | 
| 330 | 
            +
                reducefunc: Callable[[Iterator[_T]], _W],
         | 
| 331 | 
            +
            ) -> Iterable[tuple[_T, _W]]: ...
         | 
| 332 | 
            +
            @overload
         | 
| 333 | 
            +
            def groupby_transform(
         | 
| 334 | 
            +
                iterable: Iterable[_T],
         | 
| 335 | 
            +
                keyfunc: Callable[[_T], _U],
         | 
| 336 | 
            +
                valuefunc: None,
         | 
| 337 | 
            +
                reducefunc: Callable[[Iterator[_T]], _W],
         | 
| 338 | 
            +
            ) -> Iterable[tuple[_U, _W]]: ...
         | 
| 339 | 
            +
            @overload
         | 
| 340 | 
            +
            def groupby_transform(
         | 
| 341 | 
            +
                iterable: Iterable[_T],
         | 
| 342 | 
            +
                keyfunc: None,
         | 
| 343 | 
            +
                valuefunc: Callable[[_T], _V],
         | 
| 344 | 
            +
                reducefunc: Callable[[Iterable[_V]], _W],
         | 
| 345 | 
            +
            ) -> Iterable[tuple[_T, _W]]: ...
         | 
| 346 | 
            +
            @overload
         | 
| 347 | 
            +
            def groupby_transform(
         | 
| 348 | 
            +
                iterable: Iterable[_T],
         | 
| 349 | 
            +
                keyfunc: Callable[[_T], _U],
         | 
| 350 | 
            +
                valuefunc: Callable[[_T], _V],
         | 
| 351 | 
            +
                reducefunc: Callable[[Iterable[_V]], _W],
         | 
| 352 | 
            +
            ) -> Iterable[tuple[_U, _W]]: ...
         | 
| 353 | 
            +
             | 
| 354 | 
            +
            class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]):
         | 
| 355 | 
            +
                @overload
         | 
| 356 | 
            +
                def __init__(self, __stop: _T) -> None: ...
         | 
| 357 | 
            +
                @overload
         | 
| 358 | 
            +
                def __init__(self, __start: _T, __stop: _T) -> None: ...
         | 
| 359 | 
            +
                @overload
         | 
| 360 | 
            +
                def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ...
         | 
| 361 | 
            +
                def __bool__(self) -> bool: ...
         | 
| 362 | 
            +
                def __contains__(self, elem: object) -> bool: ...
         | 
| 363 | 
            +
                def __eq__(self, other: object) -> bool: ...
         | 
| 364 | 
            +
                @overload
         | 
| 365 | 
            +
                def __getitem__(self, key: int) -> _T: ...
         | 
| 366 | 
            +
                @overload
         | 
| 367 | 
            +
                def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ...
         | 
| 368 | 
            +
                def __hash__(self) -> int: ...
         | 
| 369 | 
            +
                def __iter__(self) -> Iterator[_T]: ...
         | 
| 370 | 
            +
                def __len__(self) -> int: ...
         | 
| 371 | 
            +
                def __reduce__(
         | 
| 372 | 
            +
                    self,
         | 
| 373 | 
            +
                ) -> tuple[Type[numeric_range[_T, _U]], tuple[_T, _T, _U]]: ...
         | 
| 374 | 
            +
                def __repr__(self) -> str: ...
         | 
| 375 | 
            +
                def __reversed__(self) -> Iterator[_T]: ...
         | 
| 376 | 
            +
                def count(self, value: _T) -> int: ...
         | 
| 377 | 
            +
                def index(self, value: _T) -> int: ...  # type: ignore
         | 
| 378 | 
            +
             | 
| 379 | 
            +
            def count_cycle(
         | 
| 380 | 
            +
                iterable: Iterable[_T], n: int | None = ...
         | 
| 381 | 
            +
            ) -> Iterable[tuple[int, _T]]: ...
         | 
| 382 | 
            +
            def mark_ends(
         | 
| 383 | 
            +
                iterable: Iterable[_T],
         | 
| 384 | 
            +
            ) -> Iterable[tuple[bool, bool, _T]]: ...
         | 
| 385 | 
            +
            def locate(
         | 
| 386 | 
            +
                iterable: Iterable[object],
         | 
| 387 | 
            +
                pred: Callable[..., Any] = ...,
         | 
| 388 | 
            +
                window_size: int | None = ...,
         | 
| 389 | 
            +
            ) -> Iterator[int]: ...
         | 
| 390 | 
            +
            def lstrip(
         | 
| 391 | 
            +
                iterable: Iterable[_T], pred: Callable[[_T], object]
         | 
| 392 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 393 | 
            +
            def rstrip(
         | 
| 394 | 
            +
                iterable: Iterable[_T], pred: Callable[[_T], object]
         | 
| 395 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 396 | 
            +
            def strip(
         | 
| 397 | 
            +
                iterable: Iterable[_T], pred: Callable[[_T], object]
         | 
| 398 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 399 | 
            +
             | 
| 400 | 
            +
            class islice_extended(Generic[_T], Iterator[_T]):
         | 
| 401 | 
            +
                def __init__(self, iterable: Iterable[_T], *args: int | None) -> None: ...
         | 
| 402 | 
            +
                def __iter__(self) -> islice_extended[_T]: ...
         | 
| 403 | 
            +
                def __next__(self) -> _T: ...
         | 
| 404 | 
            +
                def __getitem__(self, index: slice) -> islice_extended[_T]: ...
         | 
| 405 | 
            +
             | 
| 406 | 
            +
            def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ...
         | 
| 407 | 
            +
            def consecutive_groups(
         | 
| 408 | 
            +
                iterable: Iterable[_T], ordering: Callable[[_T], int] = ...
         | 
| 409 | 
            +
            ) -> Iterator[Iterator[_T]]: ...
         | 
| 410 | 
            +
            @overload
         | 
| 411 | 
            +
            def difference(
         | 
| 412 | 
            +
                iterable: Iterable[_T],
         | 
| 413 | 
            +
                func: Callable[[_T, _T], _U] = ...,
         | 
| 414 | 
            +
                *,
         | 
| 415 | 
            +
                initial: None = ...,
         | 
| 416 | 
            +
            ) -> Iterator[_T | _U]: ...
         | 
| 417 | 
            +
            @overload
         | 
| 418 | 
            +
            def difference(
         | 
| 419 | 
            +
                iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U
         | 
| 420 | 
            +
            ) -> Iterator[_U]: ...
         | 
| 421 | 
            +
             | 
| 422 | 
            +
            class SequenceView(Generic[_T], Sequence[_T]):
         | 
| 423 | 
            +
                def __init__(self, target: Sequence[_T]) -> None: ...
         | 
| 424 | 
            +
                @overload
         | 
| 425 | 
            +
                def __getitem__(self, index: int) -> _T: ...
         | 
| 426 | 
            +
                @overload
         | 
| 427 | 
            +
                def __getitem__(self, index: slice) -> Sequence[_T]: ...
         | 
| 428 | 
            +
                def __len__(self) -> int: ...
         | 
| 429 | 
            +
             | 
| 430 | 
            +
            class seekable(Generic[_T], Iterator[_T]):
         | 
| 431 | 
            +
                def __init__(
         | 
| 432 | 
            +
                    self, iterable: Iterable[_T], maxlen: int | None = ...
         | 
| 433 | 
            +
                ) -> None: ...
         | 
| 434 | 
            +
                def __iter__(self) -> seekable[_T]: ...
         | 
| 435 | 
            +
                def __next__(self) -> _T: ...
         | 
| 436 | 
            +
                def __bool__(self) -> bool: ...
         | 
| 437 | 
            +
                @overload
         | 
| 438 | 
            +
                def peek(self) -> _T: ...
         | 
| 439 | 
            +
                @overload
         | 
| 440 | 
            +
                def peek(self, default: _U) -> _T | _U: ...
         | 
| 441 | 
            +
                def elements(self) -> SequenceView[_T]: ...
         | 
| 442 | 
            +
                def seek(self, index: int) -> None: ...
         | 
| 443 | 
            +
                def relative_seek(self, count: int) -> None: ...
         | 
| 444 | 
            +
             | 
| 445 | 
            +
            class run_length:
         | 
| 446 | 
            +
                @staticmethod
         | 
| 447 | 
            +
                def encode(iterable: Iterable[_T]) -> Iterator[tuple[_T, int]]: ...
         | 
| 448 | 
            +
                @staticmethod
         | 
| 449 | 
            +
                def decode(iterable: Iterable[tuple[_T, int]]) -> Iterator[_T]: ...
         | 
| 450 | 
            +
             | 
| 451 | 
            +
            def exactly_n(
         | 
| 452 | 
            +
                iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ...
         | 
| 453 | 
            +
            ) -> bool: ...
         | 
| 454 | 
            +
            def circular_shifts(iterable: Iterable[_T]) -> list[tuple[_T, ...]]: ...
         | 
| 455 | 
            +
            def make_decorator(
         | 
| 456 | 
            +
                wrapping_func: Callable[..., _U], result_index: int = ...
         | 
| 457 | 
            +
            ) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ...
         | 
| 458 | 
            +
            @overload
         | 
| 459 | 
            +
            def map_reduce(
         | 
| 460 | 
            +
                iterable: Iterable[_T],
         | 
| 461 | 
            +
                keyfunc: Callable[[_T], _U],
         | 
| 462 | 
            +
                valuefunc: None = ...,
         | 
| 463 | 
            +
                reducefunc: None = ...,
         | 
| 464 | 
            +
            ) -> dict[_U, list[_T]]: ...
         | 
| 465 | 
            +
            @overload
         | 
| 466 | 
            +
            def map_reduce(
         | 
| 467 | 
            +
                iterable: Iterable[_T],
         | 
| 468 | 
            +
                keyfunc: Callable[[_T], _U],
         | 
| 469 | 
            +
                valuefunc: Callable[[_T], _V],
         | 
| 470 | 
            +
                reducefunc: None = ...,
         | 
| 471 | 
            +
            ) -> dict[_U, list[_V]]: ...
         | 
| 472 | 
            +
            @overload
         | 
| 473 | 
            +
            def map_reduce(
         | 
| 474 | 
            +
                iterable: Iterable[_T],
         | 
| 475 | 
            +
                keyfunc: Callable[[_T], _U],
         | 
| 476 | 
            +
                valuefunc: None = ...,
         | 
| 477 | 
            +
                reducefunc: Callable[[list[_T]], _W] = ...,
         | 
| 478 | 
            +
            ) -> dict[_U, _W]: ...
         | 
| 479 | 
            +
            @overload
         | 
| 480 | 
            +
            def map_reduce(
         | 
| 481 | 
            +
                iterable: Iterable[_T],
         | 
| 482 | 
            +
                keyfunc: Callable[[_T], _U],
         | 
| 483 | 
            +
                valuefunc: Callable[[_T], _V],
         | 
| 484 | 
            +
                reducefunc: Callable[[list[_V]], _W],
         | 
| 485 | 
            +
            ) -> dict[_U, _W]: ...
         | 
| 486 | 
            +
            def rlocate(
         | 
| 487 | 
            +
                iterable: Iterable[_T],
         | 
| 488 | 
            +
                pred: Callable[..., object] = ...,
         | 
| 489 | 
            +
                window_size: int | None = ...,
         | 
| 490 | 
            +
            ) -> Iterator[int]: ...
         | 
| 491 | 
            +
            def replace(
         | 
| 492 | 
            +
                iterable: Iterable[_T],
         | 
| 493 | 
            +
                pred: Callable[..., object],
         | 
| 494 | 
            +
                substitutes: Iterable[_U],
         | 
| 495 | 
            +
                count: int | None = ...,
         | 
| 496 | 
            +
                window_size: int = ...,
         | 
| 497 | 
            +
            ) -> Iterator[_T | _U]: ...
         | 
| 498 | 
            +
            def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ...
         | 
| 499 | 
            +
            def set_partitions(
         | 
| 500 | 
            +
                iterable: Iterable[_T], k: int | None = ...
         | 
| 501 | 
            +
            ) -> Iterator[list[list[_T]]]: ...
         | 
| 502 | 
            +
             | 
| 503 | 
            +
            class time_limited(Generic[_T], Iterator[_T]):
         | 
| 504 | 
            +
                def __init__(
         | 
| 505 | 
            +
                    self, limit_seconds: float, iterable: Iterable[_T]
         | 
| 506 | 
            +
                ) -> None: ...
         | 
| 507 | 
            +
                def __iter__(self) -> islice_extended[_T]: ...
         | 
| 508 | 
            +
                def __next__(self) -> _T: ...
         | 
| 509 | 
            +
             | 
| 510 | 
            +
            @overload
         | 
| 511 | 
            +
            def only(
         | 
| 512 | 
            +
                iterable: Iterable[_T], *, too_long: _Raisable | None = ...
         | 
| 513 | 
            +
            ) -> _T | None: ...
         | 
| 514 | 
            +
            @overload
         | 
| 515 | 
            +
            def only(
         | 
| 516 | 
            +
                iterable: Iterable[_T], default: _U, too_long: _Raisable | None = ...
         | 
| 517 | 
            +
            ) -> _T | _U: ...
         | 
| 518 | 
            +
            def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ...
         | 
| 519 | 
            +
            def distinct_combinations(
         | 
| 520 | 
            +
                iterable: Iterable[_T], r: int
         | 
| 521 | 
            +
            ) -> Iterator[tuple[_T, ...]]: ...
         | 
| 522 | 
            +
            def filter_except(
         | 
| 523 | 
            +
                validator: Callable[[Any], object],
         | 
| 524 | 
            +
                iterable: Iterable[_T],
         | 
| 525 | 
            +
                *exceptions: Type[BaseException],
         | 
| 526 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 527 | 
            +
            def map_except(
         | 
| 528 | 
            +
                function: Callable[[Any], _U],
         | 
| 529 | 
            +
                iterable: Iterable[_T],
         | 
| 530 | 
            +
                *exceptions: Type[BaseException],
         | 
| 531 | 
            +
            ) -> Iterator[_U]: ...
         | 
| 532 | 
            +
            def map_if(
         | 
| 533 | 
            +
                iterable: Iterable[Any],
         | 
| 534 | 
            +
                pred: Callable[[Any], bool],
         | 
| 535 | 
            +
                func: Callable[[Any], Any],
         | 
| 536 | 
            +
                func_else: Callable[[Any], Any] | None = ...,
         | 
| 537 | 
            +
            ) -> Iterator[Any]: ...
         | 
| 538 | 
            +
            def sample(
         | 
| 539 | 
            +
                iterable: Iterable[_T],
         | 
| 540 | 
            +
                k: int,
         | 
| 541 | 
            +
                weights: Iterable[float] | None = ...,
         | 
| 542 | 
            +
            ) -> list[_T]: ...
         | 
| 543 | 
            +
            def is_sorted(
         | 
| 544 | 
            +
                iterable: Iterable[_T],
         | 
| 545 | 
            +
                key: Callable[[_T], _U] | None = ...,
         | 
| 546 | 
            +
                reverse: bool = False,
         | 
| 547 | 
            +
                strict: bool = False,
         | 
| 548 | 
            +
            ) -> bool: ...
         | 
| 549 | 
            +
             | 
| 550 | 
            +
            class AbortThread(BaseException):
         | 
| 551 | 
            +
                pass
         | 
| 552 | 
            +
             | 
| 553 | 
            +
            class callback_iter(Generic[_T], Iterator[_T]):
         | 
| 554 | 
            +
                def __init__(
         | 
| 555 | 
            +
                    self,
         | 
| 556 | 
            +
                    func: Callable[..., Any],
         | 
| 557 | 
            +
                    callback_kwd: str = ...,
         | 
| 558 | 
            +
                    wait_seconds: float = ...,
         | 
| 559 | 
            +
                ) -> None: ...
         | 
| 560 | 
            +
                def __enter__(self) -> callback_iter[_T]: ...
         | 
| 561 | 
            +
                def __exit__(
         | 
| 562 | 
            +
                    self,
         | 
| 563 | 
            +
                    exc_type: Type[BaseException] | None,
         | 
| 564 | 
            +
                    exc_value: BaseException | None,
         | 
| 565 | 
            +
                    traceback: TracebackType | None,
         | 
| 566 | 
            +
                ) -> bool | None: ...
         | 
| 567 | 
            +
                def __iter__(self) -> callback_iter[_T]: ...
         | 
| 568 | 
            +
                def __next__(self) -> _T: ...
         | 
| 569 | 
            +
                def _reader(self) -> Iterator[_T]: ...
         | 
| 570 | 
            +
                @property
         | 
| 571 | 
            +
                def done(self) -> bool: ...
         | 
| 572 | 
            +
                @property
         | 
| 573 | 
            +
                def result(self) -> Any: ...
         | 
| 574 | 
            +
             | 
| 575 | 
            +
            def windowed_complete(
         | 
| 576 | 
            +
                iterable: Iterable[_T], n: int
         | 
| 577 | 
            +
            ) -> Iterator[tuple[_T, ...]]: ...
         | 
| 578 | 
            +
            def all_unique(
         | 
| 579 | 
            +
                iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
         | 
| 580 | 
            +
            ) -> bool: ...
         | 
| 581 | 
            +
            def nth_product(index: int, *args: Iterable[_T]) -> tuple[_T, ...]: ...
         | 
| 582 | 
            +
            def nth_combination_with_replacement(
         | 
| 583 | 
            +
                iterable: Iterable[_T], r: int, index: int
         | 
| 584 | 
            +
            ) -> tuple[_T, ...]: ...
         | 
| 585 | 
            +
            def nth_permutation(
         | 
| 586 | 
            +
                iterable: Iterable[_T], r: int, index: int
         | 
| 587 | 
            +
            ) -> tuple[_T, ...]: ...
         | 
| 588 | 
            +
            def value_chain(*args: _T | Iterable[_T]) -> Iterable[_T]: ...
         | 
| 589 | 
            +
            def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
         | 
| 590 | 
            +
            def combination_index(
         | 
| 591 | 
            +
                element: Iterable[_T], iterable: Iterable[_T]
         | 
| 592 | 
            +
            ) -> int: ...
         | 
| 593 | 
            +
            def combination_with_replacement_index(
         | 
| 594 | 
            +
                element: Iterable[_T], iterable: Iterable[_T]
         | 
| 595 | 
            +
            ) -> int: ...
         | 
| 596 | 
            +
            def permutation_index(
         | 
| 597 | 
            +
                element: Iterable[_T], iterable: Iterable[_T]
         | 
| 598 | 
            +
            ) -> int: ...
         | 
| 599 | 
            +
            def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ...
         | 
| 600 | 
            +
             | 
| 601 | 
            +
            class countable(Generic[_T], Iterator[_T]):
         | 
| 602 | 
            +
                def __init__(self, iterable: Iterable[_T]) -> None: ...
         | 
| 603 | 
            +
                def __iter__(self) -> countable[_T]: ...
         | 
| 604 | 
            +
                def __next__(self) -> _T: ...
         | 
| 605 | 
            +
             | 
| 606 | 
            +
            def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ...
         | 
| 607 | 
            +
            def zip_broadcast(
         | 
| 608 | 
            +
                *objects: _T | Iterable[_T],
         | 
| 609 | 
            +
                scalar_types: type | tuple[type | tuple[Any, ...], ...] | None = ...,
         | 
| 610 | 
            +
                strict: bool = ...,
         | 
| 611 | 
            +
            ) -> Iterable[tuple[_T, ...]]: ...
         | 
| 612 | 
            +
            def unique_in_window(
         | 
| 613 | 
            +
                iterable: Iterable[_T], n: int, key: Callable[[_T], _U] | None = ...
         | 
| 614 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 615 | 
            +
            def duplicates_everseen(
         | 
| 616 | 
            +
                iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
         | 
| 617 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 618 | 
            +
            def duplicates_justseen(
         | 
| 619 | 
            +
                iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
         | 
| 620 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 621 | 
            +
             | 
| 622 | 
            +
            class _SupportsLessThan(Protocol):
         | 
| 623 | 
            +
                def __lt__(self, __other: Any) -> bool: ...
         | 
| 624 | 
            +
             | 
| 625 | 
            +
            _SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan)
         | 
| 626 | 
            +
             | 
| 627 | 
            +
            @overload
         | 
| 628 | 
            +
            def minmax(
         | 
| 629 | 
            +
                iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None
         | 
| 630 | 
            +
            ) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
         | 
| 631 | 
            +
            @overload
         | 
| 632 | 
            +
            def minmax(
         | 
| 633 | 
            +
                iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan]
         | 
| 634 | 
            +
            ) -> tuple[_T, _T]: ...
         | 
| 635 | 
            +
            @overload
         | 
| 636 | 
            +
            def minmax(
         | 
| 637 | 
            +
                iterable_or_value: Iterable[_SupportsLessThanT],
         | 
| 638 | 
            +
                *,
         | 
| 639 | 
            +
                key: None = None,
         | 
| 640 | 
            +
                default: _U,
         | 
| 641 | 
            +
            ) -> _U | tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
         | 
| 642 | 
            +
            @overload
         | 
| 643 | 
            +
            def minmax(
         | 
| 644 | 
            +
                iterable_or_value: Iterable[_T],
         | 
| 645 | 
            +
                *,
         | 
| 646 | 
            +
                key: Callable[[_T], _SupportsLessThan],
         | 
| 647 | 
            +
                default: _U,
         | 
| 648 | 
            +
            ) -> _U | tuple[_T, _T]: ...
         | 
| 649 | 
            +
            @overload
         | 
| 650 | 
            +
            def minmax(
         | 
| 651 | 
            +
                iterable_or_value: _SupportsLessThanT,
         | 
| 652 | 
            +
                __other: _SupportsLessThanT,
         | 
| 653 | 
            +
                *others: _SupportsLessThanT,
         | 
| 654 | 
            +
            ) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
         | 
| 655 | 
            +
            @overload
         | 
| 656 | 
            +
            def minmax(
         | 
| 657 | 
            +
                iterable_or_value: _T,
         | 
| 658 | 
            +
                __other: _T,
         | 
| 659 | 
            +
                *others: _T,
         | 
| 660 | 
            +
                key: Callable[[_T], _SupportsLessThan],
         | 
| 661 | 
            +
            ) -> tuple[_T, _T]: ...
         | 
| 662 | 
            +
            def longest_common_prefix(
         | 
| 663 | 
            +
                iterables: Iterable[Iterable[_T]],
         | 
| 664 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 665 | 
            +
            def iequals(*iterables: Iterable[object]) -> bool: ...
         | 
| 666 | 
            +
            def constrained_batches(
         | 
| 667 | 
            +
                iterable: Iterable[object],
         | 
| 668 | 
            +
                max_size: int,
         | 
| 669 | 
            +
                max_count: int | None = ...,
         | 
| 670 | 
            +
                get_len: Callable[[_T], object] = ...,
         | 
| 671 | 
            +
                strict: bool = ...,
         | 
| 672 | 
            +
            ) -> Iterator[tuple[_T]]: ...
         | 
| 673 | 
            +
            def gray_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
         | 
| 674 | 
            +
            def partial_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
         | 
| 675 | 
            +
            def takewhile_inclusive(
         | 
| 676 | 
            +
                predicate: Callable[[_T], bool], iterable: Iterable[_T]
         | 
| 677 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 678 | 
            +
            def outer_product(
         | 
| 679 | 
            +
                func: Callable[[_T, _U], _V],
         | 
| 680 | 
            +
                xs: Iterable[_T],
         | 
| 681 | 
            +
                ys: Iterable[_U],
         | 
| 682 | 
            +
                *args: Any,
         | 
| 683 | 
            +
                **kwargs: Any,
         | 
| 684 | 
            +
            ) -> Iterator[tuple[_V, ...]]: ...
         | 
    	
        lib/python3.11/site-packages/more_itertools/py.typed
    ADDED
    
    | 
            File without changes
         | 
    	
        lib/python3.11/site-packages/more_itertools/recipes.py
    ADDED
    
    | @@ -0,0 +1,977 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Imported from the recipes section of the itertools documentation.
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            All functions taken from the recipes section of the itertools library docs
         | 
| 4 | 
            +
            [1]_.
         | 
| 5 | 
            +
            Some backward-compatible usability improvements have been made.
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            .. [1] http://docs.python.org/library/itertools.html#recipes
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            """
         | 
| 10 | 
            +
            import math
         | 
| 11 | 
            +
            import operator
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from collections import deque
         | 
| 14 | 
            +
            from collections.abc import Sized
         | 
| 15 | 
            +
            from functools import partial, reduce
         | 
| 16 | 
            +
            from itertools import (
         | 
| 17 | 
            +
                chain,
         | 
| 18 | 
            +
                combinations,
         | 
| 19 | 
            +
                compress,
         | 
| 20 | 
            +
                count,
         | 
| 21 | 
            +
                cycle,
         | 
| 22 | 
            +
                groupby,
         | 
| 23 | 
            +
                islice,
         | 
| 24 | 
            +
                product,
         | 
| 25 | 
            +
                repeat,
         | 
| 26 | 
            +
                starmap,
         | 
| 27 | 
            +
                tee,
         | 
| 28 | 
            +
                zip_longest,
         | 
| 29 | 
            +
            )
         | 
| 30 | 
            +
            from random import randrange, sample, choice
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            __all__ = [
         | 
| 33 | 
            +
                'all_equal',
         | 
| 34 | 
            +
                'batched',
         | 
| 35 | 
            +
                'before_and_after',
         | 
| 36 | 
            +
                'consume',
         | 
| 37 | 
            +
                'convolve',
         | 
| 38 | 
            +
                'dotproduct',
         | 
| 39 | 
            +
                'first_true',
         | 
| 40 | 
            +
                'factor',
         | 
| 41 | 
            +
                'flatten',
         | 
| 42 | 
            +
                'grouper',
         | 
| 43 | 
            +
                'iter_except',
         | 
| 44 | 
            +
                'iter_index',
         | 
| 45 | 
            +
                'matmul',
         | 
| 46 | 
            +
                'ncycles',
         | 
| 47 | 
            +
                'nth',
         | 
| 48 | 
            +
                'nth_combination',
         | 
| 49 | 
            +
                'padnone',
         | 
| 50 | 
            +
                'pad_none',
         | 
| 51 | 
            +
                'pairwise',
         | 
| 52 | 
            +
                'partition',
         | 
| 53 | 
            +
                'polynomial_eval',
         | 
| 54 | 
            +
                'polynomial_from_roots',
         | 
| 55 | 
            +
                'polynomial_derivative',
         | 
| 56 | 
            +
                'powerset',
         | 
| 57 | 
            +
                'prepend',
         | 
| 58 | 
            +
                'quantify',
         | 
| 59 | 
            +
                'random_combination_with_replacement',
         | 
| 60 | 
            +
                'random_combination',
         | 
| 61 | 
            +
                'random_permutation',
         | 
| 62 | 
            +
                'random_product',
         | 
| 63 | 
            +
                'repeatfunc',
         | 
| 64 | 
            +
                'roundrobin',
         | 
| 65 | 
            +
                'sieve',
         | 
| 66 | 
            +
                'sliding_window',
         | 
| 67 | 
            +
                'subslices',
         | 
| 68 | 
            +
                'sum_of_squares',
         | 
| 69 | 
            +
                'tabulate',
         | 
| 70 | 
            +
                'tail',
         | 
| 71 | 
            +
                'take',
         | 
| 72 | 
            +
                'transpose',
         | 
| 73 | 
            +
                'triplewise',
         | 
| 74 | 
            +
                'unique_everseen',
         | 
| 75 | 
            +
                'unique_justseen',
         | 
| 76 | 
            +
            ]
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            _marker = object()
         | 
| 79 | 
            +
             | 
| 80 | 
            +
             | 
| 81 | 
            +
            # zip with strict is available for Python 3.10+
         | 
| 82 | 
            +
            try:
         | 
| 83 | 
            +
                zip(strict=True)
         | 
| 84 | 
            +
            except TypeError:
         | 
| 85 | 
            +
                _zip_strict = zip
         | 
| 86 | 
            +
            else:
         | 
| 87 | 
            +
                _zip_strict = partial(zip, strict=True)
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            # math.sumprod is available for Python 3.12+
         | 
| 90 | 
            +
            _sumprod = getattr(math, 'sumprod', lambda x, y: dotproduct(x, y))
         | 
| 91 | 
            +
             | 
| 92 | 
            +
             | 
| 93 | 
            +
            def take(n, iterable):
         | 
| 94 | 
            +
                """Return first *n* items of the iterable as a list.
         | 
| 95 | 
            +
             | 
| 96 | 
            +
                    >>> take(3, range(10))
         | 
| 97 | 
            +
                    [0, 1, 2]
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                If there are fewer than *n* items in the iterable, all of them are
         | 
| 100 | 
            +
                returned.
         | 
| 101 | 
            +
             | 
| 102 | 
            +
                    >>> take(10, range(3))
         | 
| 103 | 
            +
                    [0, 1, 2]
         | 
| 104 | 
            +
             | 
| 105 | 
            +
                """
         | 
| 106 | 
            +
                return list(islice(iterable, n))
         | 
| 107 | 
            +
             | 
| 108 | 
            +
             | 
| 109 | 
            +
            def tabulate(function, start=0):
         | 
| 110 | 
            +
                """Return an iterator over the results of ``func(start)``,
         | 
| 111 | 
            +
                ``func(start + 1)``, ``func(start + 2)``...
         | 
| 112 | 
            +
             | 
| 113 | 
            +
                *func* should be a function that accepts one integer argument.
         | 
| 114 | 
            +
             | 
| 115 | 
            +
                If *start* is not specified it defaults to 0. It will be incremented each
         | 
| 116 | 
            +
                time the iterator is advanced.
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    >>> square = lambda x: x ** 2
         | 
| 119 | 
            +
                    >>> iterator = tabulate(square, -3)
         | 
| 120 | 
            +
                    >>> take(4, iterator)
         | 
| 121 | 
            +
                    [9, 4, 1, 0]
         | 
| 122 | 
            +
             | 
| 123 | 
            +
                """
         | 
| 124 | 
            +
                return map(function, count(start))
         | 
| 125 | 
            +
             | 
| 126 | 
            +
             | 
| 127 | 
            +
            def tail(n, iterable):
         | 
| 128 | 
            +
                """Return an iterator over the last *n* items of *iterable*.
         | 
| 129 | 
            +
             | 
| 130 | 
            +
                >>> t = tail(3, 'ABCDEFG')
         | 
| 131 | 
            +
                >>> list(t)
         | 
| 132 | 
            +
                ['E', 'F', 'G']
         | 
| 133 | 
            +
             | 
| 134 | 
            +
                """
         | 
| 135 | 
            +
                # If the given iterable has a length, then we can use islice to get its
         | 
| 136 | 
            +
                # final elements. Note that if the iterable is not actually Iterable,
         | 
| 137 | 
            +
                # either islice or deque will throw a TypeError. This is why we don't
         | 
| 138 | 
            +
                # check if it is Iterable.
         | 
| 139 | 
            +
                if isinstance(iterable, Sized):
         | 
| 140 | 
            +
                    yield from islice(iterable, max(0, len(iterable) - n), None)
         | 
| 141 | 
            +
                else:
         | 
| 142 | 
            +
                    yield from iter(deque(iterable, maxlen=n))
         | 
| 143 | 
            +
             | 
| 144 | 
            +
             | 
| 145 | 
            +
            def consume(iterator, n=None):
         | 
| 146 | 
            +
                """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
         | 
| 147 | 
            +
                entirely.
         | 
| 148 | 
            +
             | 
| 149 | 
            +
                Efficiently exhausts an iterator without returning values. Defaults to
         | 
| 150 | 
            +
                consuming the whole iterator, but an optional second argument may be
         | 
| 151 | 
            +
                provided to limit consumption.
         | 
| 152 | 
            +
             | 
| 153 | 
            +
                    >>> i = (x for x in range(10))
         | 
| 154 | 
            +
                    >>> next(i)
         | 
| 155 | 
            +
                    0
         | 
| 156 | 
            +
                    >>> consume(i, 3)
         | 
| 157 | 
            +
                    >>> next(i)
         | 
| 158 | 
            +
                    4
         | 
| 159 | 
            +
                    >>> consume(i)
         | 
| 160 | 
            +
                    >>> next(i)
         | 
| 161 | 
            +
                    Traceback (most recent call last):
         | 
| 162 | 
            +
                      File "<stdin>", line 1, in <module>
         | 
| 163 | 
            +
                    StopIteration
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                If the iterator has fewer items remaining than the provided limit, the
         | 
| 166 | 
            +
                whole iterator will be consumed.
         | 
| 167 | 
            +
             | 
| 168 | 
            +
                    >>> i = (x for x in range(3))
         | 
| 169 | 
            +
                    >>> consume(i, 5)
         | 
| 170 | 
            +
                    >>> next(i)
         | 
| 171 | 
            +
                    Traceback (most recent call last):
         | 
| 172 | 
            +
                      File "<stdin>", line 1, in <module>
         | 
| 173 | 
            +
                    StopIteration
         | 
| 174 | 
            +
             | 
| 175 | 
            +
                """
         | 
| 176 | 
            +
                # Use functions that consume iterators at C speed.
         | 
| 177 | 
            +
                if n is None:
         | 
| 178 | 
            +
                    # feed the entire iterator into a zero-length deque
         | 
| 179 | 
            +
                    deque(iterator, maxlen=0)
         | 
| 180 | 
            +
                else:
         | 
| 181 | 
            +
                    # advance to the empty slice starting at position n
         | 
| 182 | 
            +
                    next(islice(iterator, n, n), None)
         | 
| 183 | 
            +
             | 
| 184 | 
            +
             | 
| 185 | 
            +
            def nth(iterable, n, default=None):
         | 
| 186 | 
            +
                """Returns the nth item or a default value.
         | 
| 187 | 
            +
             | 
| 188 | 
            +
                >>> l = range(10)
         | 
| 189 | 
            +
                >>> nth(l, 3)
         | 
| 190 | 
            +
                3
         | 
| 191 | 
            +
                >>> nth(l, 20, "zebra")
         | 
| 192 | 
            +
                'zebra'
         | 
| 193 | 
            +
             | 
| 194 | 
            +
                """
         | 
| 195 | 
            +
                return next(islice(iterable, n, None), default)
         | 
| 196 | 
            +
             | 
| 197 | 
            +
             | 
| 198 | 
            +
            def all_equal(iterable):
         | 
| 199 | 
            +
                """
         | 
| 200 | 
            +
                Returns ``True`` if all the elements are equal to each other.
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    >>> all_equal('aaaa')
         | 
| 203 | 
            +
                    True
         | 
| 204 | 
            +
                    >>> all_equal('aaab')
         | 
| 205 | 
            +
                    False
         | 
| 206 | 
            +
             | 
| 207 | 
            +
                """
         | 
| 208 | 
            +
                g = groupby(iterable)
         | 
| 209 | 
            +
                return next(g, True) and not next(g, False)
         | 
| 210 | 
            +
             | 
| 211 | 
            +
             | 
| 212 | 
            +
            def quantify(iterable, pred=bool):
         | 
| 213 | 
            +
                """Return the how many times the predicate is true.
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                >>> quantify([True, False, True])
         | 
| 216 | 
            +
                2
         | 
| 217 | 
            +
             | 
| 218 | 
            +
                """
         | 
| 219 | 
            +
                return sum(map(pred, iterable))
         | 
| 220 | 
            +
             | 
| 221 | 
            +
             | 
| 222 | 
            +
            def pad_none(iterable):
         | 
| 223 | 
            +
                """Returns the sequence of elements and then returns ``None`` indefinitely.
         | 
| 224 | 
            +
             | 
| 225 | 
            +
                    >>> take(5, pad_none(range(3)))
         | 
| 226 | 
            +
                    [0, 1, 2, None, None]
         | 
| 227 | 
            +
             | 
| 228 | 
            +
                Useful for emulating the behavior of the built-in :func:`map` function.
         | 
| 229 | 
            +
             | 
| 230 | 
            +
                See also :func:`padded`.
         | 
| 231 | 
            +
             | 
| 232 | 
            +
                """
         | 
| 233 | 
            +
                return chain(iterable, repeat(None))
         | 
| 234 | 
            +
             | 
| 235 | 
            +
             | 
| 236 | 
            +
            padnone = pad_none
         | 
| 237 | 
            +
             | 
| 238 | 
            +
             | 
| 239 | 
            +
            def ncycles(iterable, n):
         | 
| 240 | 
            +
                """Returns the sequence elements *n* times
         | 
| 241 | 
            +
             | 
| 242 | 
            +
                >>> list(ncycles(["a", "b"], 3))
         | 
| 243 | 
            +
                ['a', 'b', 'a', 'b', 'a', 'b']
         | 
| 244 | 
            +
             | 
| 245 | 
            +
                """
         | 
| 246 | 
            +
                return chain.from_iterable(repeat(tuple(iterable), n))
         | 
| 247 | 
            +
             | 
| 248 | 
            +
             | 
| 249 | 
            +
            def dotproduct(vec1, vec2):
         | 
| 250 | 
            +
                """Returns the dot product of the two iterables.
         | 
| 251 | 
            +
             | 
| 252 | 
            +
                >>> dotproduct([10, 10], [20, 20])
         | 
| 253 | 
            +
                400
         | 
| 254 | 
            +
             | 
| 255 | 
            +
                """
         | 
| 256 | 
            +
                return sum(map(operator.mul, vec1, vec2))
         | 
| 257 | 
            +
             | 
| 258 | 
            +
             | 
| 259 | 
            +
            def flatten(listOfLists):
         | 
| 260 | 
            +
                """Return an iterator flattening one level of nesting in a list of lists.
         | 
| 261 | 
            +
             | 
| 262 | 
            +
                    >>> list(flatten([[0, 1], [2, 3]]))
         | 
| 263 | 
            +
                    [0, 1, 2, 3]
         | 
| 264 | 
            +
             | 
| 265 | 
            +
                See also :func:`collapse`, which can flatten multiple levels of nesting.
         | 
| 266 | 
            +
             | 
| 267 | 
            +
                """
         | 
| 268 | 
            +
                return chain.from_iterable(listOfLists)
         | 
| 269 | 
            +
             | 
| 270 | 
            +
             | 
| 271 | 
            +
            def repeatfunc(func, times=None, *args):
         | 
| 272 | 
            +
                """Call *func* with *args* repeatedly, returning an iterable over the
         | 
| 273 | 
            +
                results.
         | 
| 274 | 
            +
             | 
| 275 | 
            +
                If *times* is specified, the iterable will terminate after that many
         | 
| 276 | 
            +
                repetitions:
         | 
| 277 | 
            +
             | 
| 278 | 
            +
                    >>> from operator import add
         | 
| 279 | 
            +
                    >>> times = 4
         | 
| 280 | 
            +
                    >>> args = 3, 5
         | 
| 281 | 
            +
                    >>> list(repeatfunc(add, times, *args))
         | 
| 282 | 
            +
                    [8, 8, 8, 8]
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                If *times* is ``None`` the iterable will not terminate:
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    >>> from random import randrange
         | 
| 287 | 
            +
                    >>> times = None
         | 
| 288 | 
            +
                    >>> args = 1, 11
         | 
| 289 | 
            +
                    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
         | 
| 290 | 
            +
                    [2, 4, 8, 1, 8, 4]
         | 
| 291 | 
            +
             | 
| 292 | 
            +
                """
         | 
| 293 | 
            +
                if times is None:
         | 
| 294 | 
            +
                    return starmap(func, repeat(args))
         | 
| 295 | 
            +
                return starmap(func, repeat(args, times))
         | 
| 296 | 
            +
             | 
| 297 | 
            +
             | 
| 298 | 
            +
            def _pairwise(iterable):
         | 
| 299 | 
            +
                """Returns an iterator of paired items, overlapping, from the original
         | 
| 300 | 
            +
             | 
| 301 | 
            +
                >>> take(4, pairwise(count()))
         | 
| 302 | 
            +
                [(0, 1), (1, 2), (2, 3), (3, 4)]
         | 
| 303 | 
            +
             | 
| 304 | 
            +
                On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                """
         | 
| 307 | 
            +
                a, b = tee(iterable)
         | 
| 308 | 
            +
                next(b, None)
         | 
| 309 | 
            +
                return zip(a, b)
         | 
| 310 | 
            +
             | 
| 311 | 
            +
             | 
| 312 | 
            +
            try:
         | 
| 313 | 
            +
                from itertools import pairwise as itertools_pairwise
         | 
| 314 | 
            +
            except ImportError:
         | 
| 315 | 
            +
                pairwise = _pairwise
         | 
| 316 | 
            +
            else:
         | 
| 317 | 
            +
             | 
| 318 | 
            +
                def pairwise(iterable):
         | 
| 319 | 
            +
                    return itertools_pairwise(iterable)
         | 
| 320 | 
            +
             | 
| 321 | 
            +
                pairwise.__doc__ = _pairwise.__doc__
         | 
| 322 | 
            +
             | 
| 323 | 
            +
             | 
| 324 | 
            +
            class UnequalIterablesError(ValueError):
         | 
| 325 | 
            +
                def __init__(self, details=None):
         | 
| 326 | 
            +
                    msg = 'Iterables have different lengths'
         | 
| 327 | 
            +
                    if details is not None:
         | 
| 328 | 
            +
                        msg += (': index 0 has length {}; index {} has length {}').format(
         | 
| 329 | 
            +
                            *details
         | 
| 330 | 
            +
                        )
         | 
| 331 | 
            +
             | 
| 332 | 
            +
                    super().__init__(msg)
         | 
| 333 | 
            +
             | 
| 334 | 
            +
             | 
| 335 | 
            +
            def _zip_equal_generator(iterables):
         | 
| 336 | 
            +
                for combo in zip_longest(*iterables, fillvalue=_marker):
         | 
| 337 | 
            +
                    for val in combo:
         | 
| 338 | 
            +
                        if val is _marker:
         | 
| 339 | 
            +
                            raise UnequalIterablesError()
         | 
| 340 | 
            +
                    yield combo
         | 
| 341 | 
            +
             | 
| 342 | 
            +
             | 
| 343 | 
            +
            def _zip_equal(*iterables):
         | 
| 344 | 
            +
                # Check whether the iterables are all the same size.
         | 
| 345 | 
            +
                try:
         | 
| 346 | 
            +
                    first_size = len(iterables[0])
         | 
| 347 | 
            +
                    for i, it in enumerate(iterables[1:], 1):
         | 
| 348 | 
            +
                        size = len(it)
         | 
| 349 | 
            +
                        if size != first_size:
         | 
| 350 | 
            +
                            raise UnequalIterablesError(details=(first_size, i, size))
         | 
| 351 | 
            +
                    # All sizes are equal, we can use the built-in zip.
         | 
| 352 | 
            +
                    return zip(*iterables)
         | 
| 353 | 
            +
                # If any one of the iterables didn't have a length, start reading
         | 
| 354 | 
            +
                # them until one runs out.
         | 
| 355 | 
            +
                except TypeError:
         | 
| 356 | 
            +
                    return _zip_equal_generator(iterables)
         | 
| 357 | 
            +
             | 
| 358 | 
            +
             | 
| 359 | 
            +
            def grouper(iterable, n, incomplete='fill', fillvalue=None):
         | 
| 360 | 
            +
                """Group elements from *iterable* into fixed-length groups of length *n*.
         | 
| 361 | 
            +
             | 
| 362 | 
            +
                >>> list(grouper('ABCDEF', 3))
         | 
| 363 | 
            +
                [('A', 'B', 'C'), ('D', 'E', 'F')]
         | 
| 364 | 
            +
             | 
| 365 | 
            +
                The keyword arguments *incomplete* and *fillvalue* control what happens for
         | 
| 366 | 
            +
                iterables whose length is not a multiple of *n*.
         | 
| 367 | 
            +
             | 
| 368 | 
            +
                When *incomplete* is `'fill'`, the last group will contain instances of
         | 
| 369 | 
            +
                *fillvalue*.
         | 
| 370 | 
            +
             | 
| 371 | 
            +
                >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
         | 
| 372 | 
            +
                [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
         | 
| 373 | 
            +
             | 
| 374 | 
            +
                When *incomplete* is `'ignore'`, the last group will not be emitted.
         | 
| 375 | 
            +
             | 
| 376 | 
            +
                >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
         | 
| 377 | 
            +
                [('A', 'B', 'C'), ('D', 'E', 'F')]
         | 
| 378 | 
            +
             | 
| 379 | 
            +
                When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised.
         | 
| 380 | 
            +
             | 
| 381 | 
            +
                >>> it = grouper('ABCDEFG', 3, incomplete='strict')
         | 
| 382 | 
            +
                >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
         | 
| 383 | 
            +
                Traceback (most recent call last):
         | 
| 384 | 
            +
                ...
         | 
| 385 | 
            +
                UnequalIterablesError
         | 
| 386 | 
            +
             | 
| 387 | 
            +
                """
         | 
| 388 | 
            +
                args = [iter(iterable)] * n
         | 
| 389 | 
            +
                if incomplete == 'fill':
         | 
| 390 | 
            +
                    return zip_longest(*args, fillvalue=fillvalue)
         | 
| 391 | 
            +
                if incomplete == 'strict':
         | 
| 392 | 
            +
                    return _zip_equal(*args)
         | 
| 393 | 
            +
                if incomplete == 'ignore':
         | 
| 394 | 
            +
                    return zip(*args)
         | 
| 395 | 
            +
                else:
         | 
| 396 | 
            +
                    raise ValueError('Expected fill, strict, or ignore')
         | 
| 397 | 
            +
             | 
| 398 | 
            +
             | 
| 399 | 
            +
            def roundrobin(*iterables):
         | 
| 400 | 
            +
                """Yields an item from each iterable, alternating between them.
         | 
| 401 | 
            +
             | 
| 402 | 
            +
                    >>> list(roundrobin('ABC', 'D', 'EF'))
         | 
| 403 | 
            +
                    ['A', 'D', 'E', 'B', 'F', 'C']
         | 
| 404 | 
            +
             | 
| 405 | 
            +
                This function produces the same output as :func:`interleave_longest`, but
         | 
| 406 | 
            +
                may perform better for some inputs (in particular when the number of
         | 
| 407 | 
            +
                iterables is small).
         | 
| 408 | 
            +
             | 
| 409 | 
            +
                """
         | 
| 410 | 
            +
                # Recipe credited to George Sakkis
         | 
| 411 | 
            +
                pending = len(iterables)
         | 
| 412 | 
            +
                nexts = cycle(iter(it).__next__ for it in iterables)
         | 
| 413 | 
            +
                while pending:
         | 
| 414 | 
            +
                    try:
         | 
| 415 | 
            +
                        for next in nexts:
         | 
| 416 | 
            +
                            yield next()
         | 
| 417 | 
            +
                    except StopIteration:
         | 
| 418 | 
            +
                        pending -= 1
         | 
| 419 | 
            +
                        nexts = cycle(islice(nexts, pending))
         | 
| 420 | 
            +
             | 
| 421 | 
            +
             | 
| 422 | 
            +
            def partition(pred, iterable):
         | 
| 423 | 
            +
                """
         | 
| 424 | 
            +
                Returns a 2-tuple of iterables derived from the input iterable.
         | 
| 425 | 
            +
                The first yields the items that have ``pred(item) == False``.
         | 
| 426 | 
            +
                The second yields the items that have ``pred(item) == True``.
         | 
| 427 | 
            +
             | 
| 428 | 
            +
                    >>> is_odd = lambda x: x % 2 != 0
         | 
| 429 | 
            +
                    >>> iterable = range(10)
         | 
| 430 | 
            +
                    >>> even_items, odd_items = partition(is_odd, iterable)
         | 
| 431 | 
            +
                    >>> list(even_items), list(odd_items)
         | 
| 432 | 
            +
                    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
         | 
| 433 | 
            +
             | 
| 434 | 
            +
                If *pred* is None, :func:`bool` is used.
         | 
| 435 | 
            +
             | 
| 436 | 
            +
                    >>> iterable = [0, 1, False, True, '', ' ']
         | 
| 437 | 
            +
                    >>> false_items, true_items = partition(None, iterable)
         | 
| 438 | 
            +
                    >>> list(false_items), list(true_items)
         | 
| 439 | 
            +
                    ([0, False, ''], [1, True, ' '])
         | 
| 440 | 
            +
             | 
| 441 | 
            +
                """
         | 
| 442 | 
            +
                if pred is None:
         | 
| 443 | 
            +
                    pred = bool
         | 
| 444 | 
            +
             | 
| 445 | 
            +
                t1, t2, p = tee(iterable, 3)
         | 
| 446 | 
            +
                p1, p2 = tee(map(pred, p))
         | 
| 447 | 
            +
                return (compress(t1, map(operator.not_, p1)), compress(t2, p2))
         | 
| 448 | 
            +
             | 
| 449 | 
            +
             | 
| 450 | 
            +
            def powerset(iterable):
         | 
| 451 | 
            +
                """Yields all possible subsets of the iterable.
         | 
| 452 | 
            +
             | 
| 453 | 
            +
                    >>> list(powerset([1, 2, 3]))
         | 
| 454 | 
            +
                    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
         | 
| 455 | 
            +
             | 
| 456 | 
            +
                :func:`powerset` will operate on iterables that aren't :class:`set`
         | 
| 457 | 
            +
                instances, so repeated elements in the input will produce repeated elements
         | 
| 458 | 
            +
                in the output. Use :func:`unique_everseen` on the input to avoid generating
         | 
| 459 | 
            +
                duplicates:
         | 
| 460 | 
            +
             | 
| 461 | 
            +
                    >>> seq = [1, 1, 0]
         | 
| 462 | 
            +
                    >>> list(powerset(seq))
         | 
| 463 | 
            +
                    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
         | 
| 464 | 
            +
                    >>> from more_itertools import unique_everseen
         | 
| 465 | 
            +
                    >>> list(powerset(unique_everseen(seq)))
         | 
| 466 | 
            +
                    [(), (1,), (0,), (1, 0)]
         | 
| 467 | 
            +
             | 
| 468 | 
            +
                """
         | 
| 469 | 
            +
                s = list(iterable)
         | 
| 470 | 
            +
                return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
         | 
| 471 | 
            +
             | 
| 472 | 
            +
             | 
| 473 | 
            +
            def unique_everseen(iterable, key=None):
         | 
| 474 | 
            +
                """
         | 
| 475 | 
            +
                Yield unique elements, preserving order.
         | 
| 476 | 
            +
             | 
| 477 | 
            +
                    >>> list(unique_everseen('AAAABBBCCDAABBB'))
         | 
| 478 | 
            +
                    ['A', 'B', 'C', 'D']
         | 
| 479 | 
            +
                    >>> list(unique_everseen('ABBCcAD', str.lower))
         | 
| 480 | 
            +
                    ['A', 'B', 'C', 'D']
         | 
| 481 | 
            +
             | 
| 482 | 
            +
                Sequences with a mix of hashable and unhashable items can be used.
         | 
| 483 | 
            +
                The function will be slower (i.e., `O(n^2)`) for unhashable items.
         | 
| 484 | 
            +
             | 
| 485 | 
            +
                Remember that ``list`` objects are unhashable - you can use the *key*
         | 
| 486 | 
            +
                parameter to transform the list to a tuple (which is hashable) to
         | 
| 487 | 
            +
                avoid a slowdown.
         | 
| 488 | 
            +
             | 
| 489 | 
            +
                    >>> iterable = ([1, 2], [2, 3], [1, 2])
         | 
| 490 | 
            +
                    >>> list(unique_everseen(iterable))  # Slow
         | 
| 491 | 
            +
                    [[1, 2], [2, 3]]
         | 
| 492 | 
            +
                    >>> list(unique_everseen(iterable, key=tuple))  # Faster
         | 
| 493 | 
            +
                    [[1, 2], [2, 3]]
         | 
| 494 | 
            +
             | 
| 495 | 
            +
                Similary, you may want to convert unhashable ``set`` objects with
         | 
| 496 | 
            +
                ``key=frozenset``. For ``dict`` objects,
         | 
| 497 | 
            +
                ``key=lambda x: frozenset(x.items())`` can be used.
         | 
| 498 | 
            +
             | 
| 499 | 
            +
                """
         | 
| 500 | 
            +
                seenset = set()
         | 
| 501 | 
            +
                seenset_add = seenset.add
         | 
| 502 | 
            +
                seenlist = []
         | 
| 503 | 
            +
                seenlist_add = seenlist.append
         | 
| 504 | 
            +
                use_key = key is not None
         | 
| 505 | 
            +
             | 
| 506 | 
            +
                for element in iterable:
         | 
| 507 | 
            +
                    k = key(element) if use_key else element
         | 
| 508 | 
            +
                    try:
         | 
| 509 | 
            +
                        if k not in seenset:
         | 
| 510 | 
            +
                            seenset_add(k)
         | 
| 511 | 
            +
                            yield element
         | 
| 512 | 
            +
                    except TypeError:
         | 
| 513 | 
            +
                        if k not in seenlist:
         | 
| 514 | 
            +
                            seenlist_add(k)
         | 
| 515 | 
            +
                            yield element
         | 
| 516 | 
            +
             | 
| 517 | 
            +
             | 
| 518 | 
            +
            def unique_justseen(iterable, key=None):
         | 
| 519 | 
            +
                """Yields elements in order, ignoring serial duplicates
         | 
| 520 | 
            +
             | 
| 521 | 
            +
                >>> list(unique_justseen('AAAABBBCCDAABBB'))
         | 
| 522 | 
            +
                ['A', 'B', 'C', 'D', 'A', 'B']
         | 
| 523 | 
            +
                >>> list(unique_justseen('ABBCcAD', str.lower))
         | 
| 524 | 
            +
                ['A', 'B', 'C', 'A', 'D']
         | 
| 525 | 
            +
             | 
| 526 | 
            +
                """
         | 
| 527 | 
            +
                return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
         | 
| 528 | 
            +
             | 
| 529 | 
            +
             | 
| 530 | 
            +
            def iter_except(func, exception, first=None):
         | 
| 531 | 
            +
                """Yields results from a function repeatedly until an exception is raised.
         | 
| 532 | 
            +
             | 
| 533 | 
            +
                Converts a call-until-exception interface to an iterator interface.
         | 
| 534 | 
            +
                Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
         | 
| 535 | 
            +
                to end the loop.
         | 
| 536 | 
            +
             | 
| 537 | 
            +
                    >>> l = [0, 1, 2]
         | 
| 538 | 
            +
                    >>> list(iter_except(l.pop, IndexError))
         | 
| 539 | 
            +
                    [2, 1, 0]
         | 
| 540 | 
            +
             | 
| 541 | 
            +
                Multiple exceptions can be specified as a stopping condition:
         | 
| 542 | 
            +
             | 
| 543 | 
            +
                    >>> l = [1, 2, 3, '...', 4, 5, 6]
         | 
| 544 | 
            +
                    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
         | 
| 545 | 
            +
                    [7, 6, 5]
         | 
| 546 | 
            +
                    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
         | 
| 547 | 
            +
                    [4, 3, 2]
         | 
| 548 | 
            +
                    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
         | 
| 549 | 
            +
                    []
         | 
| 550 | 
            +
             | 
| 551 | 
            +
                """
         | 
| 552 | 
            +
                try:
         | 
| 553 | 
            +
                    if first is not None:
         | 
| 554 | 
            +
                        yield first()
         | 
| 555 | 
            +
                    while 1:
         | 
| 556 | 
            +
                        yield func()
         | 
| 557 | 
            +
                except exception:
         | 
| 558 | 
            +
                    pass
         | 
| 559 | 
            +
             | 
| 560 | 
            +
             | 
| 561 | 
            +
            def first_true(iterable, default=None, pred=None):
         | 
| 562 | 
            +
                """
         | 
| 563 | 
            +
                Returns the first true value in the iterable.
         | 
| 564 | 
            +
             | 
| 565 | 
            +
                If no true value is found, returns *default*
         | 
| 566 | 
            +
             | 
| 567 | 
            +
                If *pred* is not None, returns the first item for which
         | 
| 568 | 
            +
                ``pred(item) == True`` .
         | 
| 569 | 
            +
             | 
| 570 | 
            +
                    >>> first_true(range(10))
         | 
| 571 | 
            +
                    1
         | 
| 572 | 
            +
                    >>> first_true(range(10), pred=lambda x: x > 5)
         | 
| 573 | 
            +
                    6
         | 
| 574 | 
            +
                    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
         | 
| 575 | 
            +
                    'missing'
         | 
| 576 | 
            +
             | 
| 577 | 
            +
                """
         | 
| 578 | 
            +
                return next(filter(pred, iterable), default)
         | 
| 579 | 
            +
             | 
| 580 | 
            +
             | 
| 581 | 
            +
            def random_product(*args, repeat=1):
         | 
| 582 | 
            +
                """Draw an item at random from each of the input iterables.
         | 
| 583 | 
            +
             | 
| 584 | 
            +
                    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
         | 
| 585 | 
            +
                    ('c', 3, 'Z')
         | 
| 586 | 
            +
             | 
| 587 | 
            +
                If *repeat* is provided as a keyword argument, that many items will be
         | 
| 588 | 
            +
                drawn from each iterable.
         | 
| 589 | 
            +
             | 
| 590 | 
            +
                    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
         | 
| 591 | 
            +
                    ('a', 2, 'd', 3)
         | 
| 592 | 
            +
             | 
| 593 | 
            +
                This equivalent to taking a random selection from
         | 
| 594 | 
            +
                ``itertools.product(*args, **kwarg)``.
         | 
| 595 | 
            +
             | 
| 596 | 
            +
                """
         | 
| 597 | 
            +
                pools = [tuple(pool) for pool in args] * repeat
         | 
| 598 | 
            +
                return tuple(choice(pool) for pool in pools)
         | 
| 599 | 
            +
             | 
| 600 | 
            +
             | 
| 601 | 
            +
            def random_permutation(iterable, r=None):
         | 
| 602 | 
            +
                """Return a random *r* length permutation of the elements in *iterable*.
         | 
| 603 | 
            +
             | 
| 604 | 
            +
                If *r* is not specified or is ``None``, then *r* defaults to the length of
         | 
| 605 | 
            +
                *iterable*.
         | 
| 606 | 
            +
             | 
| 607 | 
            +
                    >>> random_permutation(range(5))  # doctest:+SKIP
         | 
| 608 | 
            +
                    (3, 4, 0, 1, 2)
         | 
| 609 | 
            +
             | 
| 610 | 
            +
                This equivalent to taking a random selection from
         | 
| 611 | 
            +
                ``itertools.permutations(iterable, r)``.
         | 
| 612 | 
            +
             | 
| 613 | 
            +
                """
         | 
| 614 | 
            +
                pool = tuple(iterable)
         | 
| 615 | 
            +
                r = len(pool) if r is None else r
         | 
| 616 | 
            +
                return tuple(sample(pool, r))
         | 
| 617 | 
            +
             | 
| 618 | 
            +
             | 
| 619 | 
            +
            def random_combination(iterable, r):
         | 
| 620 | 
            +
                """Return a random *r* length subsequence of the elements in *iterable*.
         | 
| 621 | 
            +
             | 
| 622 | 
            +
                    >>> random_combination(range(5), 3)  # doctest:+SKIP
         | 
| 623 | 
            +
                    (2, 3, 4)
         | 
| 624 | 
            +
             | 
| 625 | 
            +
                This equivalent to taking a random selection from
         | 
| 626 | 
            +
                ``itertools.combinations(iterable, r)``.
         | 
| 627 | 
            +
             | 
| 628 | 
            +
                """
         | 
| 629 | 
            +
                pool = tuple(iterable)
         | 
| 630 | 
            +
                n = len(pool)
         | 
| 631 | 
            +
                indices = sorted(sample(range(n), r))
         | 
| 632 | 
            +
                return tuple(pool[i] for i in indices)
         | 
| 633 | 
            +
             | 
| 634 | 
            +
             | 
| 635 | 
            +
            def random_combination_with_replacement(iterable, r):
         | 
| 636 | 
            +
                """Return a random *r* length subsequence of elements in *iterable*,
         | 
| 637 | 
            +
                allowing individual elements to be repeated.
         | 
| 638 | 
            +
             | 
| 639 | 
            +
                    >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
         | 
| 640 | 
            +
                    (0, 0, 1, 2, 2)
         | 
| 641 | 
            +
             | 
| 642 | 
            +
                This equivalent to taking a random selection from
         | 
| 643 | 
            +
                ``itertools.combinations_with_replacement(iterable, r)``.
         | 
| 644 | 
            +
             | 
| 645 | 
            +
                """
         | 
| 646 | 
            +
                pool = tuple(iterable)
         | 
| 647 | 
            +
                n = len(pool)
         | 
| 648 | 
            +
                indices = sorted(randrange(n) for i in range(r))
         | 
| 649 | 
            +
                return tuple(pool[i] for i in indices)
         | 
| 650 | 
            +
             | 
| 651 | 
            +
             | 
| 652 | 
            +
            def nth_combination(iterable, r, index):
         | 
| 653 | 
            +
                """Equivalent to ``list(combinations(iterable, r))[index]``.
         | 
| 654 | 
            +
             | 
| 655 | 
            +
                The subsequences of *iterable* that are of length *r* can be ordered
         | 
| 656 | 
            +
                lexicographically. :func:`nth_combination` computes the subsequence at
         | 
| 657 | 
            +
                sort position *index* directly, without computing the previous
         | 
| 658 | 
            +
                subsequences.
         | 
| 659 | 
            +
             | 
| 660 | 
            +
                    >>> nth_combination(range(5), 3, 5)
         | 
| 661 | 
            +
                    (0, 3, 4)
         | 
| 662 | 
            +
             | 
| 663 | 
            +
                ``ValueError`` will be raised If *r* is negative or greater than the length
         | 
| 664 | 
            +
                of *iterable*.
         | 
| 665 | 
            +
                ``IndexError`` will be raised if the given *index* is invalid.
         | 
| 666 | 
            +
                """
         | 
| 667 | 
            +
                pool = tuple(iterable)
         | 
| 668 | 
            +
                n = len(pool)
         | 
| 669 | 
            +
                if (r < 0) or (r > n):
         | 
| 670 | 
            +
                    raise ValueError
         | 
| 671 | 
            +
             | 
| 672 | 
            +
                c = 1
         | 
| 673 | 
            +
                k = min(r, n - r)
         | 
| 674 | 
            +
                for i in range(1, k + 1):
         | 
| 675 | 
            +
                    c = c * (n - k + i) // i
         | 
| 676 | 
            +
             | 
| 677 | 
            +
                if index < 0:
         | 
| 678 | 
            +
                    index += c
         | 
| 679 | 
            +
             | 
| 680 | 
            +
                if (index < 0) or (index >= c):
         | 
| 681 | 
            +
                    raise IndexError
         | 
| 682 | 
            +
             | 
| 683 | 
            +
                result = []
         | 
| 684 | 
            +
                while r:
         | 
| 685 | 
            +
                    c, n, r = c * r // n, n - 1, r - 1
         | 
| 686 | 
            +
                    while index >= c:
         | 
| 687 | 
            +
                        index -= c
         | 
| 688 | 
            +
                        c, n = c * (n - r) // n, n - 1
         | 
| 689 | 
            +
                    result.append(pool[-1 - n])
         | 
| 690 | 
            +
             | 
| 691 | 
            +
                return tuple(result)
         | 
| 692 | 
            +
             | 
| 693 | 
            +
             | 
| 694 | 
            +
            def prepend(value, iterator):
         | 
| 695 | 
            +
                """Yield *value*, followed by the elements in *iterator*.
         | 
| 696 | 
            +
             | 
| 697 | 
            +
                    >>> value = '0'
         | 
| 698 | 
            +
                    >>> iterator = ['1', '2', '3']
         | 
| 699 | 
            +
                    >>> list(prepend(value, iterator))
         | 
| 700 | 
            +
                    ['0', '1', '2', '3']
         | 
| 701 | 
            +
             | 
| 702 | 
            +
                To prepend multiple values, see :func:`itertools.chain`
         | 
| 703 | 
            +
                or :func:`value_chain`.
         | 
| 704 | 
            +
             | 
| 705 | 
            +
                """
         | 
| 706 | 
            +
                return chain([value], iterator)
         | 
| 707 | 
            +
             | 
| 708 | 
            +
             | 
| 709 | 
            +
            def convolve(signal, kernel):
         | 
| 710 | 
            +
                """Convolve the iterable *signal* with the iterable *kernel*.
         | 
| 711 | 
            +
             | 
| 712 | 
            +
                    >>> signal = (1, 2, 3, 4, 5)
         | 
| 713 | 
            +
                    >>> kernel = [3, 2, 1]
         | 
| 714 | 
            +
                    >>> list(convolve(signal, kernel))
         | 
| 715 | 
            +
                    [3, 8, 14, 20, 26, 14, 5]
         | 
| 716 | 
            +
             | 
| 717 | 
            +
                Note: the input arguments are not interchangeable, as the *kernel*
         | 
| 718 | 
            +
                is immediately consumed and stored.
         | 
| 719 | 
            +
             | 
| 720 | 
            +
                """
         | 
| 721 | 
            +
                # This implementation intentionally doesn't match the one in the itertools
         | 
| 722 | 
            +
                # documentation.
         | 
| 723 | 
            +
                kernel = tuple(kernel)[::-1]
         | 
| 724 | 
            +
                n = len(kernel)
         | 
| 725 | 
            +
                window = deque([0], maxlen=n) * n
         | 
| 726 | 
            +
                for x in chain(signal, repeat(0, n - 1)):
         | 
| 727 | 
            +
                    window.append(x)
         | 
| 728 | 
            +
                    yield _sumprod(kernel, window)
         | 
| 729 | 
            +
             | 
| 730 | 
            +
             | 
| 731 | 
            +
            def before_and_after(predicate, it):
         | 
| 732 | 
            +
                """A variant of :func:`takewhile` that allows complete access to the
         | 
| 733 | 
            +
                remainder of the iterator.
         | 
| 734 | 
            +
             | 
| 735 | 
            +
                     >>> it = iter('ABCdEfGhI')
         | 
| 736 | 
            +
                     >>> all_upper, remainder = before_and_after(str.isupper, it)
         | 
| 737 | 
            +
                     >>> ''.join(all_upper)
         | 
| 738 | 
            +
                     'ABC'
         | 
| 739 | 
            +
                     >>> ''.join(remainder) # takewhile() would lose the 'd'
         | 
| 740 | 
            +
                     'dEfGhI'
         | 
| 741 | 
            +
             | 
| 742 | 
            +
                Note that the first iterator must be fully consumed before the second
         | 
| 743 | 
            +
                iterator can generate valid results.
         | 
| 744 | 
            +
                """
         | 
| 745 | 
            +
                it = iter(it)
         | 
| 746 | 
            +
                transition = []
         | 
| 747 | 
            +
             | 
| 748 | 
            +
                def true_iterator():
         | 
| 749 | 
            +
                    for elem in it:
         | 
| 750 | 
            +
                        if predicate(elem):
         | 
| 751 | 
            +
                            yield elem
         | 
| 752 | 
            +
                        else:
         | 
| 753 | 
            +
                            transition.append(elem)
         | 
| 754 | 
            +
                            return
         | 
| 755 | 
            +
             | 
| 756 | 
            +
                # Note: this is different from itertools recipes to allow nesting
         | 
| 757 | 
            +
                # before_and_after remainders into before_and_after again. See tests
         | 
| 758 | 
            +
                # for an example.
         | 
| 759 | 
            +
                remainder_iterator = chain(transition, it)
         | 
| 760 | 
            +
             | 
| 761 | 
            +
                return true_iterator(), remainder_iterator
         | 
| 762 | 
            +
             | 
| 763 | 
            +
             | 
| 764 | 
            +
            def triplewise(iterable):
         | 
| 765 | 
            +
                """Return overlapping triplets from *iterable*.
         | 
| 766 | 
            +
             | 
| 767 | 
            +
                >>> list(triplewise('ABCDE'))
         | 
| 768 | 
            +
                [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
         | 
| 769 | 
            +
             | 
| 770 | 
            +
                """
         | 
| 771 | 
            +
                for (a, _), (b, c) in pairwise(pairwise(iterable)):
         | 
| 772 | 
            +
                    yield a, b, c
         | 
| 773 | 
            +
             | 
| 774 | 
            +
             | 
| 775 | 
            +
            def sliding_window(iterable, n):
         | 
| 776 | 
            +
                """Return a sliding window of width *n* over *iterable*.
         | 
| 777 | 
            +
             | 
| 778 | 
            +
                    >>> list(sliding_window(range(6), 4))
         | 
| 779 | 
            +
                    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
         | 
| 780 | 
            +
             | 
| 781 | 
            +
                If *iterable* has fewer than *n* items, then nothing is yielded:
         | 
| 782 | 
            +
             | 
| 783 | 
            +
                    >>> list(sliding_window(range(3), 4))
         | 
| 784 | 
            +
                    []
         | 
| 785 | 
            +
             | 
| 786 | 
            +
                For a variant with more features, see :func:`windowed`.
         | 
| 787 | 
            +
                """
         | 
| 788 | 
            +
                it = iter(iterable)
         | 
| 789 | 
            +
                window = deque(islice(it, n - 1), maxlen=n)
         | 
| 790 | 
            +
                for x in it:
         | 
| 791 | 
            +
                    window.append(x)
         | 
| 792 | 
            +
                    yield tuple(window)
         | 
| 793 | 
            +
             | 
| 794 | 
            +
             | 
| 795 | 
            +
            def subslices(iterable):
         | 
| 796 | 
            +
                """Return all contiguous non-empty subslices of *iterable*.
         | 
| 797 | 
            +
             | 
| 798 | 
            +
                    >>> list(subslices('ABC'))
         | 
| 799 | 
            +
                    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]
         | 
| 800 | 
            +
             | 
| 801 | 
            +
                This is similar to :func:`substrings`, but emits items in a different
         | 
| 802 | 
            +
                order.
         | 
| 803 | 
            +
                """
         | 
| 804 | 
            +
                seq = list(iterable)
         | 
| 805 | 
            +
                slices = starmap(slice, combinations(range(len(seq) + 1), 2))
         | 
| 806 | 
            +
                return map(operator.getitem, repeat(seq), slices)
         | 
| 807 | 
            +
             | 
| 808 | 
            +
             | 
| 809 | 
            +
            def polynomial_from_roots(roots):
         | 
| 810 | 
            +
                """Compute a polynomial's coefficients from its roots.
         | 
| 811 | 
            +
             | 
| 812 | 
            +
                >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
         | 
| 813 | 
            +
                >>> polynomial_from_roots(roots)  # x^3 - 4 * x^2 - 17 * x + 60
         | 
| 814 | 
            +
                [1, -4, -17, 60]
         | 
| 815 | 
            +
                """
         | 
| 816 | 
            +
                factors = zip(repeat(1), map(operator.neg, roots))
         | 
| 817 | 
            +
                return list(reduce(convolve, factors, [1]))
         | 
| 818 | 
            +
             | 
| 819 | 
            +
             | 
| 820 | 
            +
            def iter_index(iterable, value, start=0):
         | 
| 821 | 
            +
                """Yield the index of each place in *iterable* that *value* occurs,
         | 
| 822 | 
            +
                beginning with index *start*.
         | 
| 823 | 
            +
             | 
| 824 | 
            +
                See :func:`locate` for a more general means of finding the indexes
         | 
| 825 | 
            +
                associated with particular values.
         | 
| 826 | 
            +
             | 
| 827 | 
            +
                >>> list(iter_index('AABCADEAF', 'A'))
         | 
| 828 | 
            +
                [0, 1, 4, 7]
         | 
| 829 | 
            +
                """
         | 
| 830 | 
            +
                try:
         | 
| 831 | 
            +
                    seq_index = iterable.index
         | 
| 832 | 
            +
                except AttributeError:
         | 
| 833 | 
            +
                    # Slow path for general iterables
         | 
| 834 | 
            +
                    it = islice(iterable, start, None)
         | 
| 835 | 
            +
                    i = start - 1
         | 
| 836 | 
            +
                    try:
         | 
| 837 | 
            +
                        while True:
         | 
| 838 | 
            +
                            i = i + operator.indexOf(it, value) + 1
         | 
| 839 | 
            +
                            yield i
         | 
| 840 | 
            +
                    except ValueError:
         | 
| 841 | 
            +
                        pass
         | 
| 842 | 
            +
                else:
         | 
| 843 | 
            +
                    # Fast path for sequences
         | 
| 844 | 
            +
                    i = start - 1
         | 
| 845 | 
            +
                    try:
         | 
| 846 | 
            +
                        while True:
         | 
| 847 | 
            +
                            i = seq_index(value, i + 1)
         | 
| 848 | 
            +
                            yield i
         | 
| 849 | 
            +
                    except ValueError:
         | 
| 850 | 
            +
                        pass
         | 
| 851 | 
            +
             | 
| 852 | 
            +
             | 
| 853 | 
            +
            def sieve(n):
         | 
| 854 | 
            +
                """Yield the primes less than n.
         | 
| 855 | 
            +
             | 
| 856 | 
            +
                >>> list(sieve(30))
         | 
| 857 | 
            +
                [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
         | 
| 858 | 
            +
                """
         | 
| 859 | 
            +
                data = bytearray((0, 1)) * (n // 2)
         | 
| 860 | 
            +
                data[:3] = 0, 0, 0
         | 
| 861 | 
            +
                limit = math.isqrt(n) + 1
         | 
| 862 | 
            +
                for p in compress(range(limit), data):
         | 
| 863 | 
            +
                    data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
         | 
| 864 | 
            +
                data[2] = 1
         | 
| 865 | 
            +
                return iter_index(data, 1) if n > 2 else iter([])
         | 
| 866 | 
            +
             | 
| 867 | 
            +
             | 
| 868 | 
            +
            def _batched(iterable, n):
         | 
| 869 | 
            +
                """Batch data into lists of length *n*. The last batch may be shorter.
         | 
| 870 | 
            +
             | 
| 871 | 
            +
                >>> list(batched('ABCDEFG', 3))
         | 
| 872 | 
            +
                [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
         | 
| 873 | 
            +
             | 
| 874 | 
            +
                On Python 3.12 and above, this is an alias for :func:`itertools.batched`.
         | 
| 875 | 
            +
                """
         | 
| 876 | 
            +
                if n < 1:
         | 
| 877 | 
            +
                    raise ValueError('n must be at least one')
         | 
| 878 | 
            +
                it = iter(iterable)
         | 
| 879 | 
            +
                while True:
         | 
| 880 | 
            +
                    batch = tuple(islice(it, n))
         | 
| 881 | 
            +
                    if not batch:
         | 
| 882 | 
            +
                        break
         | 
| 883 | 
            +
                    yield batch
         | 
| 884 | 
            +
             | 
| 885 | 
            +
             | 
| 886 | 
            +
            try:
         | 
| 887 | 
            +
                from itertools import batched as itertools_batched
         | 
| 888 | 
            +
            except ImportError:
         | 
| 889 | 
            +
                batched = _batched
         | 
| 890 | 
            +
            else:
         | 
| 891 | 
            +
             | 
| 892 | 
            +
                def batched(iterable, n):
         | 
| 893 | 
            +
                    return itertools_batched(iterable, n)
         | 
| 894 | 
            +
             | 
| 895 | 
            +
                batched.__doc__ = _batched.__doc__
         | 
| 896 | 
            +
             | 
| 897 | 
            +
             | 
| 898 | 
            +
            def transpose(it):
         | 
| 899 | 
            +
                """Swap the rows and columns of the input.
         | 
| 900 | 
            +
             | 
| 901 | 
            +
                >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
         | 
| 902 | 
            +
                [(1, 11), (2, 22), (3, 33)]
         | 
| 903 | 
            +
             | 
| 904 | 
            +
                The caller should ensure that the dimensions of the input are compatible.
         | 
| 905 | 
            +
                If the input is empty, no output will be produced.
         | 
| 906 | 
            +
                """
         | 
| 907 | 
            +
                return _zip_strict(*it)
         | 
| 908 | 
            +
             | 
| 909 | 
            +
             | 
| 910 | 
            +
            def matmul(m1, m2):
         | 
| 911 | 
            +
                """Multiply two matrices.
         | 
| 912 | 
            +
                >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
         | 
| 913 | 
            +
                [(49, 80), (41, 60)]
         | 
| 914 | 
            +
             | 
| 915 | 
            +
                The caller should ensure that the dimensions of the input matrices are
         | 
| 916 | 
            +
                compatible with each other.
         | 
| 917 | 
            +
                """
         | 
| 918 | 
            +
                n = len(m2[0])
         | 
| 919 | 
            +
                return batched(starmap(_sumprod, product(m1, transpose(m2))), n)
         | 
| 920 | 
            +
             | 
| 921 | 
            +
             | 
| 922 | 
            +
            def factor(n):
         | 
| 923 | 
            +
                """Yield the prime factors of n.
         | 
| 924 | 
            +
                >>> list(factor(360))
         | 
| 925 | 
            +
                [2, 2, 2, 3, 3, 5]
         | 
| 926 | 
            +
                """
         | 
| 927 | 
            +
                for prime in sieve(math.isqrt(n) + 1):
         | 
| 928 | 
            +
                    while True:
         | 
| 929 | 
            +
                        if n % prime:
         | 
| 930 | 
            +
                            break
         | 
| 931 | 
            +
                        yield prime
         | 
| 932 | 
            +
                        n //= prime
         | 
| 933 | 
            +
                        if n == 1:
         | 
| 934 | 
            +
                            return
         | 
| 935 | 
            +
                if n > 1:
         | 
| 936 | 
            +
                    yield n
         | 
| 937 | 
            +
             | 
| 938 | 
            +
             | 
| 939 | 
            +
            def polynomial_eval(coefficients, x):
         | 
| 940 | 
            +
                """Evaluate a polynomial at a specific value.
         | 
| 941 | 
            +
             | 
| 942 | 
            +
                Example: evaluating x^3 - 4 * x^2 - 17 * x + 60 at x = 2.5:
         | 
| 943 | 
            +
             | 
| 944 | 
            +
                >>> coefficients = [1, -4, -17, 60]
         | 
| 945 | 
            +
                >>> x = 2.5
         | 
| 946 | 
            +
                >>> polynomial_eval(coefficients, x)
         | 
| 947 | 
            +
                8.125
         | 
| 948 | 
            +
                """
         | 
| 949 | 
            +
                n = len(coefficients)
         | 
| 950 | 
            +
                if n == 0:
         | 
| 951 | 
            +
                    return x * 0  # coerce zero to the type of x
         | 
| 952 | 
            +
                powers = map(pow, repeat(x), reversed(range(n)))
         | 
| 953 | 
            +
                return _sumprod(coefficients, powers)
         | 
| 954 | 
            +
             | 
| 955 | 
            +
             | 
| 956 | 
            +
            def sum_of_squares(it):
         | 
| 957 | 
            +
                """Return the sum of the squares of the input values.
         | 
| 958 | 
            +
             | 
| 959 | 
            +
                >>> sum_of_squares([10, 20, 30])
         | 
| 960 | 
            +
                1400
         | 
| 961 | 
            +
                """
         | 
| 962 | 
            +
                return _sumprod(*tee(it))
         | 
| 963 | 
            +
             | 
| 964 | 
            +
             | 
| 965 | 
            +
            def polynomial_derivative(coefficients):
         | 
| 966 | 
            +
                """Compute the first derivative of a polynomial.
         | 
| 967 | 
            +
             | 
| 968 | 
            +
                Example: evaluating the derivative of x^3 - 4 * x^2 - 17 * x + 60
         | 
| 969 | 
            +
             | 
| 970 | 
            +
                >>> coefficients = [1, -4, -17, 60]
         | 
| 971 | 
            +
                >>> derivative_coefficients = polynomial_derivative(coefficients)
         | 
| 972 | 
            +
                >>> derivative_coefficients
         | 
| 973 | 
            +
                [3, -8, -17]
         | 
| 974 | 
            +
                """
         | 
| 975 | 
            +
                n = len(coefficients)
         | 
| 976 | 
            +
                powers = reversed(range(1, n))
         | 
| 977 | 
            +
                return list(map(operator.mul, coefficients, powers))
         | 
    	
        lib/python3.11/site-packages/more_itertools/recipes.pyi
    ADDED
    
    | @@ -0,0 +1,122 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Stubs for more_itertools.recipes"""
         | 
| 2 | 
            +
            from __future__ import annotations
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            from typing import (
         | 
| 5 | 
            +
                Any,
         | 
| 6 | 
            +
                Callable,
         | 
| 7 | 
            +
                Iterable,
         | 
| 8 | 
            +
                Iterator,
         | 
| 9 | 
            +
                overload,
         | 
| 10 | 
            +
                Sequence,
         | 
| 11 | 
            +
                Type,
         | 
| 12 | 
            +
                TypeVar,
         | 
| 13 | 
            +
            )
         | 
| 14 | 
            +
             | 
| 15 | 
            +
            # Type and type variable definitions
         | 
| 16 | 
            +
            _T = TypeVar('_T')
         | 
| 17 | 
            +
            _U = TypeVar('_U')
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            def take(n: int, iterable: Iterable[_T]) -> list[_T]: ...
         | 
| 20 | 
            +
            def tabulate(
         | 
| 21 | 
            +
                function: Callable[[int], _T], start: int = ...
         | 
| 22 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 23 | 
            +
            def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ...
         | 
| 24 | 
            +
            def consume(iterator: Iterable[_T], n: int | None = ...) -> None: ...
         | 
| 25 | 
            +
            @overload
         | 
| 26 | 
            +
            def nth(iterable: Iterable[_T], n: int) -> _T | None: ...
         | 
| 27 | 
            +
            @overload
         | 
| 28 | 
            +
            def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
         | 
| 29 | 
            +
            def all_equal(iterable: Iterable[object]) -> bool: ...
         | 
| 30 | 
            +
            def quantify(
         | 
| 31 | 
            +
                iterable: Iterable[_T], pred: Callable[[_T], bool] = ...
         | 
| 32 | 
            +
            ) -> int: ...
         | 
| 33 | 
            +
            def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
         | 
| 34 | 
            +
            def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
         | 
| 35 | 
            +
            def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ...
         | 
| 36 | 
            +
            def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ...
         | 
| 37 | 
            +
            def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ...
         | 
| 38 | 
            +
            def repeatfunc(
         | 
| 39 | 
            +
                func: Callable[..., _U], times: int | None = ..., *args: Any
         | 
| 40 | 
            +
            ) -> Iterator[_U]: ...
         | 
| 41 | 
            +
            def pairwise(iterable: Iterable[_T]) -> Iterator[tuple[_T, _T]]: ...
         | 
| 42 | 
            +
            def grouper(
         | 
| 43 | 
            +
                iterable: Iterable[_T],
         | 
| 44 | 
            +
                n: int,
         | 
| 45 | 
            +
                incomplete: str = ...,
         | 
| 46 | 
            +
                fillvalue: _U = ...,
         | 
| 47 | 
            +
            ) -> Iterator[tuple[_T | _U, ...]]: ...
         | 
| 48 | 
            +
            def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ...
         | 
| 49 | 
            +
            def partition(
         | 
| 50 | 
            +
                pred: Callable[[_T], object] | None, iterable: Iterable[_T]
         | 
| 51 | 
            +
            ) -> tuple[Iterator[_T], Iterator[_T]]: ...
         | 
| 52 | 
            +
            def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
         | 
| 53 | 
            +
            def unique_everseen(
         | 
| 54 | 
            +
                iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
         | 
| 55 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 56 | 
            +
            def unique_justseen(
         | 
| 57 | 
            +
                iterable: Iterable[_T], key: Callable[[_T], object] | None = ...
         | 
| 58 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 59 | 
            +
            @overload
         | 
| 60 | 
            +
            def iter_except(
         | 
| 61 | 
            +
                func: Callable[[], _T],
         | 
| 62 | 
            +
                exception: Type[BaseException] | tuple[Type[BaseException], ...],
         | 
| 63 | 
            +
                first: None = ...,
         | 
| 64 | 
            +
            ) -> Iterator[_T]: ...
         | 
| 65 | 
            +
            @overload
         | 
| 66 | 
            +
            def iter_except(
         | 
| 67 | 
            +
                func: Callable[[], _T],
         | 
| 68 | 
            +
                exception: Type[BaseException] | tuple[Type[BaseException], ...],
         | 
| 69 | 
            +
                first: Callable[[], _U],
         | 
| 70 | 
            +
            ) -> Iterator[_T | _U]: ...
         | 
| 71 | 
            +
            @overload
         | 
| 72 | 
            +
            def first_true(
         | 
| 73 | 
            +
                iterable: Iterable[_T], *, pred: Callable[[_T], object] | None = ...
         | 
| 74 | 
            +
            ) -> _T | None: ...
         | 
| 75 | 
            +
            @overload
         | 
| 76 | 
            +
            def first_true(
         | 
| 77 | 
            +
                iterable: Iterable[_T],
         | 
| 78 | 
            +
                default: _U,
         | 
| 79 | 
            +
                pred: Callable[[_T], object] | None = ...,
         | 
| 80 | 
            +
            ) -> _T | _U: ...
         | 
| 81 | 
            +
            def random_product(
         | 
| 82 | 
            +
                *args: Iterable[_T], repeat: int = ...
         | 
| 83 | 
            +
            ) -> tuple[_T, ...]: ...
         | 
| 84 | 
            +
            def random_permutation(
         | 
| 85 | 
            +
                iterable: Iterable[_T], r: int | None = ...
         | 
| 86 | 
            +
            ) -> tuple[_T, ...]: ...
         | 
| 87 | 
            +
            def random_combination(iterable: Iterable[_T], r: int) -> tuple[_T, ...]: ...
         | 
| 88 | 
            +
            def random_combination_with_replacement(
         | 
| 89 | 
            +
                iterable: Iterable[_T], r: int
         | 
| 90 | 
            +
            ) -> tuple[_T, ...]: ...
         | 
| 91 | 
            +
            def nth_combination(
         | 
| 92 | 
            +
                iterable: Iterable[_T], r: int, index: int
         | 
| 93 | 
            +
            ) -> tuple[_T, ...]: ...
         | 
| 94 | 
            +
            def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[_T | _U]: ...
         | 
| 95 | 
            +
            def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ...
         | 
| 96 | 
            +
            def before_and_after(
         | 
| 97 | 
            +
                predicate: Callable[[_T], bool], it: Iterable[_T]
         | 
| 98 | 
            +
            ) -> tuple[Iterator[_T], Iterator[_T]]: ...
         | 
| 99 | 
            +
            def triplewise(iterable: Iterable[_T]) -> Iterator[tuple[_T, _T, _T]]: ...
         | 
| 100 | 
            +
            def sliding_window(
         | 
| 101 | 
            +
                iterable: Iterable[_T], n: int
         | 
| 102 | 
            +
            ) -> Iterator[tuple[_T, ...]]: ...
         | 
| 103 | 
            +
            def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ...
         | 
| 104 | 
            +
            def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ...
         | 
| 105 | 
            +
            def iter_index(
         | 
| 106 | 
            +
                iterable: Iterable[object],
         | 
| 107 | 
            +
                value: Any,
         | 
| 108 | 
            +
                start: int | None = ...,
         | 
| 109 | 
            +
            ) -> Iterator[int]: ...
         | 
| 110 | 
            +
            def sieve(n: int) -> Iterator[int]: ...
         | 
| 111 | 
            +
            def batched(
         | 
| 112 | 
            +
                iterable: Iterable[_T],
         | 
| 113 | 
            +
                n: int,
         | 
| 114 | 
            +
            ) -> Iterator[tuple[_T]]: ...
         | 
| 115 | 
            +
            def transpose(
         | 
| 116 | 
            +
                it: Iterable[Iterable[_T]],
         | 
| 117 | 
            +
            ) -> Iterator[tuple[_T, ...]]: ...
         | 
| 118 | 
            +
            def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ...
         | 
| 119 | 
            +
            def factor(n: int) -> Iterator[int]: ...
         | 
| 120 | 
            +
            def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ...
         | 
| 121 | 
            +
            def sum_of_squares(it: Iterable[_T]) -> _T: ...
         | 
| 122 | 
            +
            def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ...
         | 
    	
        lib/python3.11/site-packages/mpmath/__init__.py
    ADDED
    
    | @@ -0,0 +1,468 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            __version__ = '1.3.0'
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from .usertools import monitor, timing
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from .ctx_fp import FPContext
         | 
| 6 | 
            +
            from .ctx_mp import MPContext
         | 
| 7 | 
            +
            from .ctx_iv import MPIntervalContext
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            fp = FPContext()
         | 
| 10 | 
            +
            mp = MPContext()
         | 
| 11 | 
            +
            iv = MPIntervalContext()
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            fp._mp = mp
         | 
| 14 | 
            +
            mp._mp = mp
         | 
| 15 | 
            +
            iv._mp = mp
         | 
| 16 | 
            +
            mp._fp = fp
         | 
| 17 | 
            +
            fp._fp = fp
         | 
| 18 | 
            +
            mp._iv = iv
         | 
| 19 | 
            +
            fp._iv = iv
         | 
| 20 | 
            +
            iv._iv = iv
         | 
| 21 | 
            +
             | 
| 22 | 
            +
            # XXX: extremely bad pickle hack
         | 
| 23 | 
            +
            from . import ctx_mp as _ctx_mp
         | 
| 24 | 
            +
            _ctx_mp._mpf_module.mpf = mp.mpf
         | 
| 25 | 
            +
            _ctx_mp._mpf_module.mpc = mp.mpc
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            make_mpf = mp.make_mpf
         | 
| 28 | 
            +
            make_mpc = mp.make_mpc
         | 
| 29 | 
            +
             | 
| 30 | 
            +
            extraprec = mp.extraprec
         | 
| 31 | 
            +
            extradps = mp.extradps
         | 
| 32 | 
            +
            workprec = mp.workprec
         | 
| 33 | 
            +
            workdps = mp.workdps
         | 
| 34 | 
            +
            autoprec = mp.autoprec
         | 
| 35 | 
            +
            maxcalls = mp.maxcalls
         | 
| 36 | 
            +
            memoize = mp.memoize
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            mag = mp.mag
         | 
| 39 | 
            +
             | 
| 40 | 
            +
            bernfrac = mp.bernfrac
         | 
| 41 | 
            +
             | 
| 42 | 
            +
            qfrom = mp.qfrom
         | 
| 43 | 
            +
            mfrom = mp.mfrom
         | 
| 44 | 
            +
            kfrom = mp.kfrom
         | 
| 45 | 
            +
            taufrom = mp.taufrom
         | 
| 46 | 
            +
            qbarfrom = mp.qbarfrom
         | 
| 47 | 
            +
            ellipfun = mp.ellipfun
         | 
| 48 | 
            +
            jtheta = mp.jtheta
         | 
| 49 | 
            +
            kleinj = mp.kleinj
         | 
| 50 | 
            +
            eta = mp.eta
         | 
| 51 | 
            +
             | 
| 52 | 
            +
            qp = mp.qp
         | 
| 53 | 
            +
            qhyper = mp.qhyper
         | 
| 54 | 
            +
            qgamma = mp.qgamma
         | 
| 55 | 
            +
            qfac = mp.qfac
         | 
| 56 | 
            +
             | 
| 57 | 
            +
            nint_distance = mp.nint_distance
         | 
| 58 | 
            +
             | 
| 59 | 
            +
            plot = mp.plot
         | 
| 60 | 
            +
            cplot = mp.cplot
         | 
| 61 | 
            +
            splot = mp.splot
         | 
| 62 | 
            +
             | 
| 63 | 
            +
            odefun = mp.odefun
         | 
| 64 | 
            +
             | 
| 65 | 
            +
            jacobian = mp.jacobian
         | 
| 66 | 
            +
            findroot = mp.findroot
         | 
| 67 | 
            +
            multiplicity = mp.multiplicity
         | 
| 68 | 
            +
             | 
| 69 | 
            +
            isinf = mp.isinf
         | 
| 70 | 
            +
            isnan = mp.isnan
         | 
| 71 | 
            +
            isnormal = mp.isnormal
         | 
| 72 | 
            +
            isint = mp.isint
         | 
| 73 | 
            +
            isfinite = mp.isfinite
         | 
| 74 | 
            +
            almosteq = mp.almosteq
         | 
| 75 | 
            +
            nan = mp.nan
         | 
| 76 | 
            +
            rand = mp.rand
         | 
| 77 | 
            +
             | 
| 78 | 
            +
            absmin = mp.absmin
         | 
| 79 | 
            +
            absmax = mp.absmax
         | 
| 80 | 
            +
             | 
| 81 | 
            +
            fraction = mp.fraction
         | 
| 82 | 
            +
             | 
| 83 | 
            +
            linspace = mp.linspace
         | 
| 84 | 
            +
            arange = mp.arange
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            mpmathify = convert = mp.convert
         | 
| 87 | 
            +
            mpc = mp.mpc
         | 
| 88 | 
            +
             | 
| 89 | 
            +
            mpi = iv._mpi
         | 
| 90 | 
            +
             | 
| 91 | 
            +
            nstr = mp.nstr
         | 
| 92 | 
            +
            nprint = mp.nprint
         | 
| 93 | 
            +
            chop = mp.chop
         | 
| 94 | 
            +
             | 
| 95 | 
            +
            fneg = mp.fneg
         | 
| 96 | 
            +
            fadd = mp.fadd
         | 
| 97 | 
            +
            fsub = mp.fsub
         | 
| 98 | 
            +
            fmul = mp.fmul
         | 
| 99 | 
            +
            fdiv = mp.fdiv
         | 
| 100 | 
            +
            fprod = mp.fprod
         | 
| 101 | 
            +
             | 
| 102 | 
            +
            quad = mp.quad
         | 
| 103 | 
            +
            quadgl = mp.quadgl
         | 
| 104 | 
            +
            quadts = mp.quadts
         | 
| 105 | 
            +
            quadosc = mp.quadosc
         | 
| 106 | 
            +
            quadsubdiv = mp.quadsubdiv
         | 
| 107 | 
            +
             | 
| 108 | 
            +
            invertlaplace = mp.invertlaplace
         | 
| 109 | 
            +
            invlaptalbot = mp.invlaptalbot
         | 
| 110 | 
            +
            invlapstehfest = mp.invlapstehfest
         | 
| 111 | 
            +
            invlapdehoog = mp.invlapdehoog
         | 
| 112 | 
            +
             | 
| 113 | 
            +
            pslq = mp.pslq
         | 
| 114 | 
            +
            identify = mp.identify
         | 
| 115 | 
            +
            findpoly = mp.findpoly
         | 
| 116 | 
            +
             | 
| 117 | 
            +
            richardson = mp.richardson
         | 
| 118 | 
            +
            shanks = mp.shanks
         | 
| 119 | 
            +
            levin = mp.levin
         | 
| 120 | 
            +
            cohen_alt = mp.cohen_alt
         | 
| 121 | 
            +
            nsum = mp.nsum
         | 
| 122 | 
            +
            nprod = mp.nprod
         | 
| 123 | 
            +
            difference = mp.difference
         | 
| 124 | 
            +
            diff = mp.diff
         | 
| 125 | 
            +
            diffs = mp.diffs
         | 
| 126 | 
            +
            diffs_prod = mp.diffs_prod
         | 
| 127 | 
            +
            diffs_exp = mp.diffs_exp
         | 
| 128 | 
            +
            diffun = mp.diffun
         | 
| 129 | 
            +
            differint = mp.differint
         | 
| 130 | 
            +
            taylor = mp.taylor
         | 
| 131 | 
            +
            pade = mp.pade
         | 
| 132 | 
            +
            polyval = mp.polyval
         | 
| 133 | 
            +
            polyroots = mp.polyroots
         | 
| 134 | 
            +
            fourier = mp.fourier
         | 
| 135 | 
            +
            fourierval = mp.fourierval
         | 
| 136 | 
            +
            sumem = mp.sumem
         | 
| 137 | 
            +
            sumap = mp.sumap
         | 
| 138 | 
            +
            chebyfit = mp.chebyfit
         | 
| 139 | 
            +
            limit = mp.limit
         | 
| 140 | 
            +
             | 
| 141 | 
            +
            matrix = mp.matrix
         | 
| 142 | 
            +
            eye = mp.eye
         | 
| 143 | 
            +
            diag = mp.diag
         | 
| 144 | 
            +
            zeros = mp.zeros
         | 
| 145 | 
            +
            ones = mp.ones
         | 
| 146 | 
            +
            hilbert = mp.hilbert
         | 
| 147 | 
            +
            randmatrix = mp.randmatrix
         | 
| 148 | 
            +
            swap_row = mp.swap_row
         | 
| 149 | 
            +
            extend = mp.extend
         | 
| 150 | 
            +
            norm = mp.norm
         | 
| 151 | 
            +
            mnorm = mp.mnorm
         | 
| 152 | 
            +
             | 
| 153 | 
            +
            lu_solve = mp.lu_solve
         | 
| 154 | 
            +
            lu = mp.lu
         | 
| 155 | 
            +
            qr = mp.qr
         | 
| 156 | 
            +
            unitvector = mp.unitvector
         | 
| 157 | 
            +
            inverse = mp.inverse
         | 
| 158 | 
            +
            residual = mp.residual
         | 
| 159 | 
            +
            qr_solve = mp.qr_solve
         | 
| 160 | 
            +
            cholesky = mp.cholesky
         | 
| 161 | 
            +
            cholesky_solve = mp.cholesky_solve
         | 
| 162 | 
            +
            det = mp.det
         | 
| 163 | 
            +
            cond = mp.cond
         | 
| 164 | 
            +
            hessenberg = mp.hessenberg
         | 
| 165 | 
            +
            schur = mp.schur
         | 
| 166 | 
            +
            eig = mp.eig
         | 
| 167 | 
            +
            eig_sort = mp.eig_sort
         | 
| 168 | 
            +
            eigsy = mp.eigsy
         | 
| 169 | 
            +
            eighe = mp.eighe
         | 
| 170 | 
            +
            eigh = mp.eigh
         | 
| 171 | 
            +
            svd_r = mp.svd_r
         | 
| 172 | 
            +
            svd_c = mp.svd_c
         | 
| 173 | 
            +
            svd = mp.svd
         | 
| 174 | 
            +
            gauss_quadrature = mp.gauss_quadrature
         | 
| 175 | 
            +
             | 
| 176 | 
            +
            expm = mp.expm
         | 
| 177 | 
            +
            sqrtm = mp.sqrtm
         | 
| 178 | 
            +
            powm = mp.powm
         | 
| 179 | 
            +
            logm = mp.logm
         | 
| 180 | 
            +
            sinm = mp.sinm
         | 
| 181 | 
            +
            cosm = mp.cosm
         | 
| 182 | 
            +
             | 
| 183 | 
            +
            mpf = mp.mpf
         | 
| 184 | 
            +
            j = mp.j
         | 
| 185 | 
            +
            exp = mp.exp
         | 
| 186 | 
            +
            expj = mp.expj
         | 
| 187 | 
            +
            expjpi = mp.expjpi
         | 
| 188 | 
            +
            ln = mp.ln
         | 
| 189 | 
            +
            im = mp.im
         | 
| 190 | 
            +
            re = mp.re
         | 
| 191 | 
            +
            inf = mp.inf
         | 
| 192 | 
            +
            ninf = mp.ninf
         | 
| 193 | 
            +
            sign = mp.sign
         | 
| 194 | 
            +
             | 
| 195 | 
            +
            eps = mp.eps
         | 
| 196 | 
            +
            pi = mp.pi
         | 
| 197 | 
            +
            ln2 = mp.ln2
         | 
| 198 | 
            +
            ln10 = mp.ln10
         | 
| 199 | 
            +
            phi = mp.phi
         | 
| 200 | 
            +
            e = mp.e
         | 
| 201 | 
            +
            euler = mp.euler
         | 
| 202 | 
            +
            catalan = mp.catalan
         | 
| 203 | 
            +
            khinchin = mp.khinchin
         | 
| 204 | 
            +
            glaisher = mp.glaisher
         | 
| 205 | 
            +
            apery = mp.apery
         | 
| 206 | 
            +
            degree = mp.degree
         | 
| 207 | 
            +
            twinprime = mp.twinprime
         | 
| 208 | 
            +
            mertens = mp.mertens
         | 
| 209 | 
            +
             | 
| 210 | 
            +
            ldexp = mp.ldexp
         | 
| 211 | 
            +
            frexp = mp.frexp
         | 
| 212 | 
            +
             | 
| 213 | 
            +
            fsum = mp.fsum
         | 
| 214 | 
            +
            fdot = mp.fdot
         | 
| 215 | 
            +
             | 
| 216 | 
            +
            sqrt = mp.sqrt
         | 
| 217 | 
            +
            cbrt = mp.cbrt
         | 
| 218 | 
            +
            exp = mp.exp
         | 
| 219 | 
            +
            ln = mp.ln
         | 
| 220 | 
            +
            log = mp.log
         | 
| 221 | 
            +
            log10 = mp.log10
         | 
| 222 | 
            +
            power = mp.power
         | 
| 223 | 
            +
            cos = mp.cos
         | 
| 224 | 
            +
            sin = mp.sin
         | 
| 225 | 
            +
            tan = mp.tan
         | 
| 226 | 
            +
            cosh = mp.cosh
         | 
| 227 | 
            +
            sinh = mp.sinh
         | 
| 228 | 
            +
            tanh = mp.tanh
         | 
| 229 | 
            +
            acos = mp.acos
         | 
| 230 | 
            +
            asin = mp.asin
         | 
| 231 | 
            +
            atan = mp.atan
         | 
| 232 | 
            +
            asinh = mp.asinh
         | 
| 233 | 
            +
            acosh = mp.acosh
         | 
| 234 | 
            +
            atanh = mp.atanh
         | 
| 235 | 
            +
            sec = mp.sec
         | 
| 236 | 
            +
            csc = mp.csc
         | 
| 237 | 
            +
            cot = mp.cot
         | 
| 238 | 
            +
            sech = mp.sech
         | 
| 239 | 
            +
            csch = mp.csch
         | 
| 240 | 
            +
            coth = mp.coth
         | 
| 241 | 
            +
            asec = mp.asec
         | 
| 242 | 
            +
            acsc = mp.acsc
         | 
| 243 | 
            +
            acot = mp.acot
         | 
| 244 | 
            +
            asech = mp.asech
         | 
| 245 | 
            +
            acsch = mp.acsch
         | 
| 246 | 
            +
            acoth = mp.acoth
         | 
| 247 | 
            +
            cospi = mp.cospi
         | 
| 248 | 
            +
            sinpi = mp.sinpi
         | 
| 249 | 
            +
            sinc = mp.sinc
         | 
| 250 | 
            +
            sincpi = mp.sincpi
         | 
| 251 | 
            +
            cos_sin = mp.cos_sin
         | 
| 252 | 
            +
            cospi_sinpi = mp.cospi_sinpi
         | 
| 253 | 
            +
            fabs = mp.fabs
         | 
| 254 | 
            +
            re = mp.re
         | 
| 255 | 
            +
            im = mp.im
         | 
| 256 | 
            +
            conj = mp.conj
         | 
| 257 | 
            +
            floor = mp.floor
         | 
| 258 | 
            +
            ceil = mp.ceil
         | 
| 259 | 
            +
            nint = mp.nint
         | 
| 260 | 
            +
            frac = mp.frac
         | 
| 261 | 
            +
            root = mp.root
         | 
| 262 | 
            +
            nthroot = mp.nthroot
         | 
| 263 | 
            +
            hypot = mp.hypot
         | 
| 264 | 
            +
            fmod = mp.fmod
         | 
| 265 | 
            +
            ldexp = mp.ldexp
         | 
| 266 | 
            +
            frexp = mp.frexp
         | 
| 267 | 
            +
            sign = mp.sign
         | 
| 268 | 
            +
            arg = mp.arg
         | 
| 269 | 
            +
            phase = mp.phase
         | 
| 270 | 
            +
            polar = mp.polar
         | 
| 271 | 
            +
            rect = mp.rect
         | 
| 272 | 
            +
            degrees = mp.degrees
         | 
| 273 | 
            +
            radians = mp.radians
         | 
| 274 | 
            +
            atan2 = mp.atan2
         | 
| 275 | 
            +
            fib = mp.fib
         | 
| 276 | 
            +
            fibonacci = mp.fibonacci
         | 
| 277 | 
            +
            lambertw = mp.lambertw
         | 
| 278 | 
            +
            zeta = mp.zeta
         | 
| 279 | 
            +
            altzeta = mp.altzeta
         | 
| 280 | 
            +
            gamma = mp.gamma
         | 
| 281 | 
            +
            rgamma = mp.rgamma
         | 
| 282 | 
            +
            factorial = mp.factorial
         | 
| 283 | 
            +
            fac = mp.fac
         | 
| 284 | 
            +
            fac2 = mp.fac2
         | 
| 285 | 
            +
            beta = mp.beta
         | 
| 286 | 
            +
            betainc = mp.betainc
         | 
| 287 | 
            +
            psi = mp.psi
         | 
| 288 | 
            +
            #psi0 = mp.psi0
         | 
| 289 | 
            +
            #psi1 = mp.psi1
         | 
| 290 | 
            +
            #psi2 = mp.psi2
         | 
| 291 | 
            +
            #psi3 = mp.psi3
         | 
| 292 | 
            +
            polygamma = mp.polygamma
         | 
| 293 | 
            +
            digamma = mp.digamma
         | 
| 294 | 
            +
            #trigamma = mp.trigamma
         | 
| 295 | 
            +
            #tetragamma = mp.tetragamma
         | 
| 296 | 
            +
            #pentagamma = mp.pentagamma
         | 
| 297 | 
            +
            harmonic = mp.harmonic
         | 
| 298 | 
            +
            bernoulli = mp.bernoulli
         | 
| 299 | 
            +
            bernfrac = mp.bernfrac
         | 
| 300 | 
            +
            stieltjes = mp.stieltjes
         | 
| 301 | 
            +
            hurwitz = mp.hurwitz
         | 
| 302 | 
            +
            dirichlet = mp.dirichlet
         | 
| 303 | 
            +
            bernpoly = mp.bernpoly
         | 
| 304 | 
            +
            eulerpoly = mp.eulerpoly
         | 
| 305 | 
            +
            eulernum = mp.eulernum
         | 
| 306 | 
            +
            polylog = mp.polylog
         | 
| 307 | 
            +
            clsin = mp.clsin
         | 
| 308 | 
            +
            clcos = mp.clcos
         | 
| 309 | 
            +
            gammainc = mp.gammainc
         | 
| 310 | 
            +
            gammaprod = mp.gammaprod
         | 
| 311 | 
            +
            binomial = mp.binomial
         | 
| 312 | 
            +
            rf = mp.rf
         | 
| 313 | 
            +
            ff = mp.ff
         | 
| 314 | 
            +
            hyper = mp.hyper
         | 
| 315 | 
            +
            hyp0f1 = mp.hyp0f1
         | 
| 316 | 
            +
            hyp1f1 = mp.hyp1f1
         | 
| 317 | 
            +
            hyp1f2 = mp.hyp1f2
         | 
| 318 | 
            +
            hyp2f1 = mp.hyp2f1
         | 
| 319 | 
            +
            hyp2f2 = mp.hyp2f2
         | 
| 320 | 
            +
            hyp2f0 = mp.hyp2f0
         | 
| 321 | 
            +
            hyp2f3 = mp.hyp2f3
         | 
| 322 | 
            +
            hyp3f2 = mp.hyp3f2
         | 
| 323 | 
            +
            hyperu = mp.hyperu
         | 
| 324 | 
            +
            hypercomb = mp.hypercomb
         | 
| 325 | 
            +
            meijerg = mp.meijerg
         | 
| 326 | 
            +
            appellf1 = mp.appellf1
         | 
| 327 | 
            +
            appellf2 = mp.appellf2
         | 
| 328 | 
            +
            appellf3 = mp.appellf3
         | 
| 329 | 
            +
            appellf4 = mp.appellf4
         | 
| 330 | 
            +
            hyper2d = mp.hyper2d
         | 
| 331 | 
            +
            bihyper = mp.bihyper
         | 
| 332 | 
            +
            erf = mp.erf
         | 
| 333 | 
            +
            erfc = mp.erfc
         | 
| 334 | 
            +
            erfi = mp.erfi
         | 
| 335 | 
            +
            erfinv = mp.erfinv
         | 
| 336 | 
            +
            npdf = mp.npdf
         | 
| 337 | 
            +
            ncdf = mp.ncdf
         | 
| 338 | 
            +
            expint = mp.expint
         | 
| 339 | 
            +
            e1 = mp.e1
         | 
| 340 | 
            +
            ei = mp.ei
         | 
| 341 | 
            +
            li = mp.li
         | 
| 342 | 
            +
            ci = mp.ci
         | 
| 343 | 
            +
            si = mp.si
         | 
| 344 | 
            +
            chi = mp.chi
         | 
| 345 | 
            +
            shi = mp.shi
         | 
| 346 | 
            +
            fresnels = mp.fresnels
         | 
| 347 | 
            +
            fresnelc = mp.fresnelc
         | 
| 348 | 
            +
            airyai = mp.airyai
         | 
| 349 | 
            +
            airybi = mp.airybi
         | 
| 350 | 
            +
            airyaizero = mp.airyaizero
         | 
| 351 | 
            +
            airybizero = mp.airybizero
         | 
| 352 | 
            +
            scorergi = mp.scorergi
         | 
| 353 | 
            +
            scorerhi = mp.scorerhi
         | 
| 354 | 
            +
            ellipk = mp.ellipk
         | 
| 355 | 
            +
            ellipe = mp.ellipe
         | 
| 356 | 
            +
            ellipf = mp.ellipf
         | 
| 357 | 
            +
            ellippi = mp.ellippi
         | 
| 358 | 
            +
            elliprc = mp.elliprc
         | 
| 359 | 
            +
            elliprj = mp.elliprj
         | 
| 360 | 
            +
            elliprf = mp.elliprf
         | 
| 361 | 
            +
            elliprd = mp.elliprd
         | 
| 362 | 
            +
            elliprg = mp.elliprg
         | 
| 363 | 
            +
            agm = mp.agm
         | 
| 364 | 
            +
            jacobi = mp.jacobi
         | 
| 365 | 
            +
            chebyt = mp.chebyt
         | 
| 366 | 
            +
            chebyu = mp.chebyu
         | 
| 367 | 
            +
            legendre = mp.legendre
         | 
| 368 | 
            +
            legenp = mp.legenp
         | 
| 369 | 
            +
            legenq = mp.legenq
         | 
| 370 | 
            +
            hermite = mp.hermite
         | 
| 371 | 
            +
            pcfd = mp.pcfd
         | 
| 372 | 
            +
            pcfu = mp.pcfu
         | 
| 373 | 
            +
            pcfv = mp.pcfv
         | 
| 374 | 
            +
            pcfw = mp.pcfw
         | 
| 375 | 
            +
            gegenbauer = mp.gegenbauer
         | 
| 376 | 
            +
            laguerre = mp.laguerre
         | 
| 377 | 
            +
            spherharm = mp.spherharm
         | 
| 378 | 
            +
            besselj = mp.besselj
         | 
| 379 | 
            +
            j0 = mp.j0
         | 
| 380 | 
            +
            j1 = mp.j1
         | 
| 381 | 
            +
            besseli = mp.besseli
         | 
| 382 | 
            +
            bessely = mp.bessely
         | 
| 383 | 
            +
            besselk = mp.besselk
         | 
| 384 | 
            +
            besseljzero = mp.besseljzero
         | 
| 385 | 
            +
            besselyzero = mp.besselyzero
         | 
| 386 | 
            +
            hankel1 = mp.hankel1
         | 
| 387 | 
            +
            hankel2 = mp.hankel2
         | 
| 388 | 
            +
            struveh = mp.struveh
         | 
| 389 | 
            +
            struvel = mp.struvel
         | 
| 390 | 
            +
            angerj = mp.angerj
         | 
| 391 | 
            +
            webere = mp.webere
         | 
| 392 | 
            +
            lommels1 = mp.lommels1
         | 
| 393 | 
            +
            lommels2 = mp.lommels2
         | 
| 394 | 
            +
            whitm = mp.whitm
         | 
| 395 | 
            +
            whitw = mp.whitw
         | 
| 396 | 
            +
            ber = mp.ber
         | 
| 397 | 
            +
            bei = mp.bei
         | 
| 398 | 
            +
            ker = mp.ker
         | 
| 399 | 
            +
            kei = mp.kei
         | 
| 400 | 
            +
            coulombc = mp.coulombc
         | 
| 401 | 
            +
            coulombf = mp.coulombf
         | 
| 402 | 
            +
            coulombg = mp.coulombg
         | 
| 403 | 
            +
            barnesg = mp.barnesg
         | 
| 404 | 
            +
            superfac = mp.superfac
         | 
| 405 | 
            +
            hyperfac = mp.hyperfac
         | 
| 406 | 
            +
            loggamma = mp.loggamma
         | 
| 407 | 
            +
            siegeltheta = mp.siegeltheta
         | 
| 408 | 
            +
            siegelz = mp.siegelz
         | 
| 409 | 
            +
            grampoint = mp.grampoint
         | 
| 410 | 
            +
            zetazero = mp.zetazero
         | 
| 411 | 
            +
            riemannr = mp.riemannr
         | 
| 412 | 
            +
            primepi = mp.primepi
         | 
| 413 | 
            +
            primepi2 = mp.primepi2
         | 
| 414 | 
            +
            primezeta = mp.primezeta
         | 
| 415 | 
            +
            bell = mp.bell
         | 
| 416 | 
            +
            polyexp = mp.polyexp
         | 
| 417 | 
            +
            expm1 = mp.expm1
         | 
| 418 | 
            +
            log1p = mp.log1p
         | 
| 419 | 
            +
            powm1 = mp.powm1
         | 
| 420 | 
            +
            unitroots = mp.unitroots
         | 
| 421 | 
            +
            cyclotomic = mp.cyclotomic
         | 
| 422 | 
            +
            mangoldt = mp.mangoldt
         | 
| 423 | 
            +
            secondzeta = mp.secondzeta
         | 
| 424 | 
            +
            nzeros = mp.nzeros
         | 
| 425 | 
            +
            backlunds = mp.backlunds
         | 
| 426 | 
            +
            lerchphi = mp.lerchphi
         | 
| 427 | 
            +
            stirling1 = mp.stirling1
         | 
| 428 | 
            +
            stirling2 = mp.stirling2
         | 
| 429 | 
            +
            squarew = mp.squarew
         | 
| 430 | 
            +
            trianglew = mp.trianglew
         | 
| 431 | 
            +
            sawtoothw = mp.sawtoothw
         | 
| 432 | 
            +
            unit_triangle = mp.unit_triangle
         | 
| 433 | 
            +
            sigmoid = mp.sigmoid
         | 
| 434 | 
            +
             | 
| 435 | 
            +
            # be careful when changing this name, don't use test*!
         | 
| 436 | 
            +
            def runtests():
         | 
| 437 | 
            +
                """
         | 
| 438 | 
            +
                Run all mpmath tests and print output.
         | 
| 439 | 
            +
                """
         | 
| 440 | 
            +
                import os.path
         | 
| 441 | 
            +
                from inspect import getsourcefile
         | 
| 442 | 
            +
                from .tests import runtests as tests
         | 
| 443 | 
            +
                testdir = os.path.dirname(os.path.abspath(getsourcefile(tests)))
         | 
| 444 | 
            +
                importdir = os.path.abspath(testdir + '/../..')
         | 
| 445 | 
            +
                tests.testit(importdir, testdir)
         | 
| 446 | 
            +
             | 
| 447 | 
            +
            def doctests(filter=[]):
         | 
| 448 | 
            +
                import sys
         | 
| 449 | 
            +
                from timeit import default_timer as clock
         | 
| 450 | 
            +
                for i, arg in enumerate(sys.argv):
         | 
| 451 | 
            +
                    if '__init__.py' in arg:
         | 
| 452 | 
            +
                        filter = [sn for sn in sys.argv[i+1:] if not sn.startswith("-")]
         | 
| 453 | 
            +
                        break
         | 
| 454 | 
            +
                import doctest
         | 
| 455 | 
            +
                globs = globals().copy()
         | 
| 456 | 
            +
                for obj in globs: #sorted(globs.keys()):
         | 
| 457 | 
            +
                    if filter:
         | 
| 458 | 
            +
                        if not sum([pat in obj for pat in filter]):
         | 
| 459 | 
            +
                            continue
         | 
| 460 | 
            +
                    sys.stdout.write(str(obj) + " ")
         | 
| 461 | 
            +
                    sys.stdout.flush()
         | 
| 462 | 
            +
                    t1 = clock()
         | 
| 463 | 
            +
                    doctest.run_docstring_examples(globs[obj], {}, verbose=("-v" in sys.argv))
         | 
| 464 | 
            +
                    t2 = clock()
         | 
| 465 | 
            +
                    print(round(t2-t1, 3))
         | 
| 466 | 
            +
             | 
| 467 | 
            +
            if __name__ == '__main__':
         | 
| 468 | 
            +
                doctests()
         | 
    	
        lib/python3.11/site-packages/mpmath/__pycache__/__init__.cpython-311.pyc
    ADDED
    
    | Binary file (14.7 kB). View file | 
|  | 
    	
        lib/python3.11/site-packages/mpmath/__pycache__/ctx_base.cpython-311.pyc
    ADDED
    
    | Binary file (24.4 kB). View file | 
|  | 
    	
        lib/python3.11/site-packages/mpmath/__pycache__/ctx_fp.cpython-311.pyc
    ADDED
    
    | Binary file (13 kB). View file | 
|  | 
    	
        lib/python3.11/site-packages/mpmath/__pycache__/ctx_iv.cpython-311.pyc
    ADDED
    
    | Binary file (38.7 kB). View file | 
|  | 
    	
        lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp.cpython-311.pyc
    ADDED
    
    | Binary file (71.2 kB). View file | 
|  | 
    	
        lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp_python.cpython-311.pyc
    ADDED
    
    | Binary file (61.2 kB). View file | 
|  | 
    	
        lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc
    ADDED
    
    | Binary file (285 kB). View file | 
|  | 

