ce304fafe19161978ad512b385c65426bad519e5a0b8fb3f0659eace3d2ea3cc
Browse files- lib/python3.11/site-packages/mlx-0.0.7.dist-info/INSTALLER +1 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/LICENSE +21 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/METADATA +122 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/RECORD +199 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/REQUESTED +0 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/WHEEL +5 -0
- lib/python3.11/site-packages/mlx-0.0.7.dist-info/top_level.txt +1 -0
- lib/python3.11/site-packages/mlx/nn/layers/base.py +532 -0
- lib/python3.11/site-packages/mlx/nn/layers/containers.py +24 -0
- lib/python3.11/site-packages/mlx/nn/layers/convolution.py +126 -0
- lib/python3.11/site-packages/mlx/nn/layers/dropout.py +137 -0
- lib/python3.11/site-packages/mlx/nn/layers/embedding.py +30 -0
- lib/python3.11/site-packages/mlx/nn/layers/linear.py +141 -0
- lib/python3.11/site-packages/mlx/nn/layers/normalization.py +368 -0
- lib/python3.11/site-packages/mlx/nn/layers/positional_encoding.py +199 -0
- lib/python3.11/site-packages/mlx/nn/layers/quantized.py +125 -0
- lib/python3.11/site-packages/mlx/nn/layers/transformer.py +354 -0
- lib/python3.11/site-packages/mlx/nn/losses.py +374 -0
- lib/python3.11/site-packages/mlx/nn/utils.py +33 -0
- lib/python3.11/site-packages/mlx/optimizers.py +500 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfig.cmake +57 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets.cmake +107 -0
- lib/python3.11/site-packages/mlx/share/cmake/MLX/extension.cmake +56 -0
- lib/python3.11/site-packages/mlx/utils.py +145 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/INSTALLER +1 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/LICENSE +19 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/METADATA +253 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/RECORD +16 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/REQUESTED +0 -0
- lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/WHEEL +4 -0
- lib/python3.11/site-packages/more_itertools/__init__.py +6 -0
- lib/python3.11/site-packages/more_itertools/__init__.pyi +2 -0
- lib/python3.11/site-packages/more_itertools/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/more_itertools/__pycache__/more.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/more_itertools/__pycache__/recipes.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/more_itertools/more.py +0 -0
- lib/python3.11/site-packages/more_itertools/more.pyi +684 -0
- lib/python3.11/site-packages/more_itertools/py.typed +0 -0
- lib/python3.11/site-packages/more_itertools/recipes.py +977 -0
- lib/python3.11/site-packages/more_itertools/recipes.pyi +122 -0
- lib/python3.11/site-packages/mpmath/__init__.py +468 -0
- lib/python3.11/site-packages/mpmath/__pycache__/__init__.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_base.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_fp.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_iv.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp_python.cpython-311.pyc +0 -0
- lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc +0 -0
lib/python3.11/site-packages/mlx-0.0.7.dist-info/INSTALLER
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
pip
|
lib/python3.11/site-packages/mlx-0.0.7.dist-info/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright © 2023 Apple Inc.
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
lib/python3.11/site-packages/mlx-0.0.7.dist-info/METADATA
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: mlx
|
3 |
+
Version: 0.0.7
|
4 |
+
Summary: A framework for machine learning on Apple silicon.
|
5 |
+
Author: MLX Contributors
|
6 |
+
Author-email: [email protected]
|
7 |
+
Requires-Python: >=3.8
|
8 |
+
Description-Content-Type: text/markdown
|
9 |
+
License-File: LICENSE
|
10 |
+
Provides-Extra: dev
|
11 |
+
Requires-Dist: pre-commit ; extra == 'dev'
|
12 |
+
Requires-Dist: pybind11-stubgen ; extra == 'dev'
|
13 |
+
Provides-Extra: testing
|
14 |
+
Requires-Dist: numpy ; extra == 'testing'
|
15 |
+
Requires-Dist: torch ; extra == 'testing'
|
16 |
+
|
17 |
+
# MLX
|
18 |
+
|
19 |
+
[**Quickstart**](#quickstart) | [**Installation**](#installation) |
|
20 |
+
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
|
21 |
+
[**Examples**](#examples)
|
22 |
+
|
23 |
+
[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
|
24 |
+
|
25 |
+
MLX is an array framework for machine learning on Apple silicon, brought to you
|
26 |
+
by Apple machine learning research.
|
27 |
+
|
28 |
+
Some key features of MLX include:
|
29 |
+
|
30 |
+
- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
|
31 |
+
MLX also has a fully featured C++ API, which closely mirrors the Python API.
|
32 |
+
MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
|
33 |
+
that closely follow PyTorch to simplify building more complex models.
|
34 |
+
|
35 |
+
- **Composable function transformations**: MLX supports composable function
|
36 |
+
transformations for automatic differentiation, automatic vectorization,
|
37 |
+
and computation graph optimization.
|
38 |
+
|
39 |
+
- **Lazy computation**: Computations in MLX are lazy. Arrays are only
|
40 |
+
materialized when needed.
|
41 |
+
|
42 |
+
- **Dynamic graph construction**: Computation graphs in MLX are constructed
|
43 |
+
dynamically. Changing the shapes of function arguments does not trigger
|
44 |
+
slow compilations, and debugging is simple and intuitive.
|
45 |
+
|
46 |
+
- **Multi-device**: Operations can run on any of the supported devices
|
47 |
+
(currently the CPU and the GPU).
|
48 |
+
|
49 |
+
- **Unified memory**: A notable difference from MLX and other frameworks
|
50 |
+
is the *unified memory model*. Arrays in MLX live in shared memory.
|
51 |
+
Operations on MLX arrays can be performed on any of the supported
|
52 |
+
device types without transferring data.
|
53 |
+
|
54 |
+
MLX is designed by machine learning researchers for machine learning
|
55 |
+
researchers. The framework is intended to be user-friendly, but still efficient
|
56 |
+
to train and deploy models. The design of the framework itself is also
|
57 |
+
conceptually simple. We intend to make it easy for researchers to extend and
|
58 |
+
improve MLX with the goal of quickly exploring new ideas.
|
59 |
+
|
60 |
+
The design of MLX is inspired by frameworks like
|
61 |
+
[NumPy](https://numpy.org/doc/stable/index.html),
|
62 |
+
[PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), and
|
63 |
+
[ArrayFire](https://arrayfire.org/).
|
64 |
+
|
65 |
+
## Examples
|
66 |
+
|
67 |
+
The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
|
68 |
+
variety of examples, including:
|
69 |
+
|
70 |
+
- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
|
71 |
+
- Large-scale text generation with
|
72 |
+
[LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
|
73 |
+
finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
|
74 |
+
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
|
75 |
+
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
|
76 |
+
|
77 |
+
## Quickstart
|
78 |
+
|
79 |
+
See the [quick start
|
80 |
+
guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
|
81 |
+
in the documentation.
|
82 |
+
|
83 |
+
## Installation
|
84 |
+
|
85 |
+
MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
|
86 |
+
|
87 |
+
```
|
88 |
+
pip install mlx
|
89 |
+
```
|
90 |
+
|
91 |
+
Checkout the
|
92 |
+
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
|
93 |
+
for more information on building the C++ and Python APIs from source.
|
94 |
+
|
95 |
+
## Contributing
|
96 |
+
|
97 |
+
Check out the [contribution guidelines](CONTRIBUTING.md) for more information
|
98 |
+
on contributing to MLX. See the
|
99 |
+
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
|
100 |
+
information on building from source, and running tests.
|
101 |
+
|
102 |
+
We are grateful for all of [our
|
103 |
+
contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
|
104 |
+
to MLX and wish to be acknowledged, please add your name to the list in your
|
105 |
+
pull request.
|
106 |
+
|
107 |
+
## Citing MLX
|
108 |
+
|
109 |
+
The MLX software suite was initially developed with equal contribution by Awni
|
110 |
+
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
|
111 |
+
MLX useful in your research and wish to cite it, please use the following
|
112 |
+
BibTex entry:
|
113 |
+
|
114 |
+
```
|
115 |
+
@software{mlx2023,
|
116 |
+
author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
|
117 |
+
title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
|
118 |
+
url = {https://github.com/ml-explore},
|
119 |
+
version = {0.0},
|
120 |
+
year = {2023},
|
121 |
+
}
|
122 |
+
```
|
lib/python3.11/site-packages/mlx-0.0.7.dist-info/RECORD
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mlx-0.0.7.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
2 |
+
mlx-0.0.7.dist-info/LICENSE,sha256=zPq3zLLqMG9xUxyMp3u1VQdgbNkHaLHjK4tSq1tIzwE,1066
|
3 |
+
mlx-0.0.7.dist-info/METADATA,sha256=CK9nqz-nrSgrFZpRYpc86IYysL_W7zObqiiczNHD5OA,4731
|
4 |
+
mlx-0.0.7.dist-info/RECORD,,
|
5 |
+
mlx-0.0.7.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6 |
+
mlx-0.0.7.dist-info/WHEEL,sha256=9CmlYCCxe9UEGaJDl4r_tPKbyfF_OdzBb5_jMO3h1CA,110
|
7 |
+
mlx-0.0.7.dist-info/top_level.txt,sha256=NCVZzyms-obPlDq2Sa--m3oyfnnmLma38KCnE6gmcFE,4
|
8 |
+
mlx/__pycache__/_reprlib_fix.cpython-311.pyc,,
|
9 |
+
mlx/__pycache__/extension.cpython-311.pyc,,
|
10 |
+
mlx/__pycache__/optimizers.cpython-311.pyc,,
|
11 |
+
mlx/__pycache__/utils.cpython-311.pyc,,
|
12 |
+
mlx/_reprlib_fix.py,sha256=YoRQib4PSzl1FIMvs4BwkvipAwOrfufqtV89vNF4Yhc,518
|
13 |
+
mlx/core.cpython-311-darwin.so,sha256=4GMiWnLS2sZGsrIwB9mAryue5Lztf-SttAn_AzadmCc,1098840
|
14 |
+
mlx/extension.py,sha256=tqKrGeP1ElK9UfBMHUeOq56qWnMMxZrt7DVVM8QZ7yU,3769
|
15 |
+
mlx/include/metal_cpp/Foundation/Foundation.hpp,sha256=pYbEkzTrb2O0jJPzI0m9cBNvVSDTh_t39Dv7h0ZDma0,1808
|
16 |
+
mlx/include/metal_cpp/Foundation/NSArray.hpp,sha256=8nhS0gF060n7taK9aIynkuMPxlbKgqgijsj61yvhrPQ,4818
|
17 |
+
mlx/include/metal_cpp/Foundation/NSAutoreleasePool.hpp,sha256=K8_qNgYgyeP49WUSsz90Lo41qTb-x5Xjs23T9UNSkkU,3302
|
18 |
+
mlx/include/metal_cpp/Foundation/NSBundle.hpp,sha256=y0Le2v43bCAdIdrSClvLpRtCL8c4qdB3WK59Gz_AYYs,16845
|
19 |
+
mlx/include/metal_cpp/Foundation/NSData.hpp,sha256=rAPEZQn1W05O_sZ6LYexLhvv1CmllqyFWddGIrxGJ8s,2198
|
20 |
+
mlx/include/metal_cpp/Foundation/NSDate.hpp,sha256=353UgslEWyV-D1_T7_Sr1fCP4nVjixwJK2QGnRZQKgg,2085
|
21 |
+
mlx/include/metal_cpp/Foundation/NSDefines.hpp,sha256=c8BfIDwW9yCKiofyatyeLJ9YvnHO87PsBqSQtciN3tY,2203
|
22 |
+
mlx/include/metal_cpp/Foundation/NSDictionary.hpp,sha256=Lt7EBR6ppPFhhnoZq9UqMETHf_L1k9A0gs637ZTEQSU,5738
|
23 |
+
mlx/include/metal_cpp/Foundation/NSEnumerator.hpp,sha256=P4mkVmpWN0kdXnqqTMkyIovckr4_IZAoFzyg2IZDRs4,3141
|
24 |
+
mlx/include/metal_cpp/Foundation/NSError.hpp,sha256=rFqCSIXVsWOsH7Ikegv2irMF3aoMn7fM6teB9VipMgc,7802
|
25 |
+
mlx/include/metal_cpp/Foundation/NSLock.hpp,sha256=OjnoMx18m4E89eDPliPMw3cG8LtAqNsoBSxSO0l9DoY,4350
|
26 |
+
mlx/include/metal_cpp/Foundation/NSNotification.hpp,sha256=rpAgI4upmIG2inWZ9ACYy-Q8za7muZ2WmUtA5l8WGHw,4696
|
27 |
+
mlx/include/metal_cpp/Foundation/NSNumber.hpp,sha256=4a2fEU1Fm_wrWOs4zULhYQne0WymnRGsDeSGq1bDbuQ,22098
|
28 |
+
mlx/include/metal_cpp/Foundation/NSObjCRuntime.hpp,sha256=SHYqTLWQ2F8T4dfXJ98OlATOexXQ3NnXzmFyVpb4wFM,1665
|
29 |
+
mlx/include/metal_cpp/Foundation/NSObject.hpp,sha256=T_64vByLI4n3HIE05MVJZteUIAdN-y0oogsczYT-JTE,11105
|
30 |
+
mlx/include/metal_cpp/Foundation/NSPrivate.hpp,sha256=JYewPlM0BKHLnmQiItPlYuc1HNAO-akJjhXnEKUZO00,20217
|
31 |
+
mlx/include/metal_cpp/Foundation/NSProcessInfo.hpp,sha256=1NLI7-3mR8cseE_jE6c-D2_pRJKXA9fRl4V2uRZtiDI,16642
|
32 |
+
mlx/include/metal_cpp/Foundation/NSRange.hpp,sha256=7VFC565QzvzfF1CI91osTozl-d1mM9UToCcGyXmf5cc,3147
|
33 |
+
mlx/include/metal_cpp/Foundation/NSSet.hpp,sha256=0rORF6rpK2-R1AQCgT4CMNnGr18xZ8P5x82LEQLQ2sc,3368
|
34 |
+
mlx/include/metal_cpp/Foundation/NSSharedPtr.hpp,sha256=208e_0hJddLllsT16pJa4rQAyFqA1uYuez5gWyRS6EU,8496
|
35 |
+
mlx/include/metal_cpp/Foundation/NSString.hpp,sha256=AqRvSkbHPncbC36TJuR-4beEVT7ZPwnrVjmyoassaL8,11001
|
36 |
+
mlx/include/metal_cpp/Foundation/NSTypes.hpp,sha256=9IaLleaqYbdjlMvT_mqFlYDlzg4vF_8MQQTO2Sdp6vo,1893
|
37 |
+
mlx/include/metal_cpp/Foundation/NSURL.hpp,sha256=b1oGmKqL5F1WrWSaCK3Iqs0vxL8utVbOgtYW4RZUe_I,3668
|
38 |
+
mlx/include/metal_cpp/LICENSE.txt,sha256=mgnj5ca95epXBMNZliB1O16xkRb5_WScN7W8RJeiXOo,11344
|
39 |
+
mlx/include/metal_cpp/Metal/MTLAccelerationStructure.hpp,sha256=jqK65WOzpPUNE6H-XYb1PC3myk6vBIQr4fRrClCDZYs,85372
|
40 |
+
mlx/include/metal_cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp,sha256=EC1flOCenKN-CNLIbHcpli7kP1PQGxkGUU8ZdMYQ4Qs,16289
|
41 |
+
mlx/include/metal_cpp/Metal/MTLAccelerationStructureTypes.hpp,sha256=I29afo7zt4fiCY2ip3vPN72lhb6Whgjpx0lcHabfib4,5800
|
42 |
+
mlx/include/metal_cpp/Metal/MTLArgument.hpp,sha256=_gQiklV3kRM2ukcYg0JCz3rUsK-zUOiZMYrp_thHjXU,23515
|
43 |
+
mlx/include/metal_cpp/Metal/MTLArgumentEncoder.hpp,sha256=PaV9gzYeC1JLqDdcAJK9H6DaPVV2rclNWvNPK70KUWs,10739
|
44 |
+
mlx/include/metal_cpp/Metal/MTLBinaryArchive.hpp,sha256=I8a_7I2WI-ftil5s4bT8N5rnUeRUt2HtsM2O9LW_t2M,5132
|
45 |
+
mlx/include/metal_cpp/Metal/MTLBlitCommandEncoder.hpp,sha256=6QF2GR7c4XQXfA5COwCb2Q6xRXVMkYnO1gmUlrNw05o,16482
|
46 |
+
mlx/include/metal_cpp/Metal/MTLBlitPass.hpp,sha256=zPt4836TDUuDPyRCyF7tJBEtEMpoWziJLlSRgWleURs,6982
|
47 |
+
mlx/include/metal_cpp/Metal/MTLBuffer.hpp,sha256=IDVEI2qXwq6rU_tQsPkOpma-IV7othYmJIT2QCOBsic,3542
|
48 |
+
mlx/include/metal_cpp/Metal/MTLCaptureManager.hpp,sha256=ANj6r5G9Gc1l9WikJQmZZlOWvfViJclzyPwKF59UvUg,7524
|
49 |
+
mlx/include/metal_cpp/Metal/MTLCaptureScope.hpp,sha256=DU5_HrnsaHBttsqHfXCk-oPr8ifhMdP-DiYpS1G-qC0,3733
|
50 |
+
mlx/include/metal_cpp/Metal/MTLCommandBuffer.hpp,sha256=29LVLOZlzB8G0bVfnHm867TXc0FPqbM7eXbNrZ2zRzE,17642
|
51 |
+
mlx/include/metal_cpp/Metal/MTLCommandEncoder.hpp,sha256=IQGAhZMikxwU1IINtvw3MSTLwlbQKRWuTNZnZAw_0m8,2946
|
52 |
+
mlx/include/metal_cpp/Metal/MTLCommandQueue.hpp,sha256=Uz74PTJAhqMi8XgLrj1RZOio-FYxLnTO9Msdy1dE7RA,2960
|
53 |
+
mlx/include/metal_cpp/Metal/MTLComputeCommandEncoder.hpp,sha256=Z2E_Ny8OJj4gvC_YZtku5sni2t4o9Mbm7xvJk2tpgHc,17157
|
54 |
+
mlx/include/metal_cpp/Metal/MTLComputePass.hpp,sha256=VOyFJlwaB89gmrtLdmnAz7gl97a5HbTYo-M6icWbN5M,7792
|
55 |
+
mlx/include/metal_cpp/Metal/MTLComputePipeline.hpp,sha256=WTBUdVq1sxvHocezHO5wpVmaa94Xnthq27kSUCedcMY,15017
|
56 |
+
mlx/include/metal_cpp/Metal/MTLCounters.hpp,sha256=RmlZCNFZIVPC4ue6FoIgYVq1Weq2Nu1GQD-G0yIwimg,9056
|
57 |
+
mlx/include/metal_cpp/Metal/MTLDefines.hpp,sha256=1MG6A6zwLk99Yknu7PAvBvIMefVhumr1kFoKxQ-6xKk,1900
|
58 |
+
mlx/include/metal_cpp/Metal/MTLDepthStencil.hpp,sha256=nP5OtW6QGUK7tzrXtJWJlch9bXkDYPYgp5BjMAptH6E,9631
|
59 |
+
mlx/include/metal_cpp/Metal/MTLDevice.hpp,sha256=HyHg83T53RuLU1y_ZthBvEIzpOi1Lu0nXptaJfdrhwc,63880
|
60 |
+
mlx/include/metal_cpp/Metal/MTLDrawable.hpp,sha256=07X3GSg7uJEWoOMWxI77cIWdr2NaBfJEXOtPJZr4M44,3193
|
61 |
+
mlx/include/metal_cpp/Metal/MTLDynamicLibrary.hpp,sha256=yBUYfobLLHd00UOGALdsEF0FBiF1hUDr7G3FdKy6uMY,2606
|
62 |
+
mlx/include/metal_cpp/Metal/MTLEvent.hpp,sha256=W8QBQWIUn-DcuBYzg0NB0L1W_Zs6G6QuQCeKUKM7tIo,4976
|
63 |
+
mlx/include/metal_cpp/Metal/MTLFence.hpp,sha256=DuxPaIkQN9fWi1E6RHg8YrKNeB2nxGOZMuHXd8TiE9c,1723
|
64 |
+
mlx/include/metal_cpp/Metal/MTLFunctionConstantValues.hpp,sha256=dz9s7_TYqjTJbthwEJTDoqdXsZ34JE2k4BERCYzWqWQ,3088
|
65 |
+
mlx/include/metal_cpp/Metal/MTLFunctionDescriptor.hpp,sha256=JkzBrxhlg4_ezt2NLomw-MsGuBqdV9embOekJow-mks,5410
|
66 |
+
mlx/include/metal_cpp/Metal/MTLFunctionHandle.hpp,sha256=HykCHezEwao-pmH6eL57qOyNt_G5RYsrsofxFDyHjiQ,1842
|
67 |
+
mlx/include/metal_cpp/Metal/MTLFunctionLog.hpp,sha256=ZiIPh0Ef5oeiYIWsFZJGnOXZj2bejYFDe_3NVoRT2nY,3262
|
68 |
+
mlx/include/metal_cpp/Metal/MTLFunctionStitching.hpp,sha256=fEipcQvZ2d8GeyPyn0AiyMup4doS9Yf-uzZRFpjkfyI,11407
|
69 |
+
mlx/include/metal_cpp/Metal/MTLHeaderBridge.hpp,sha256=qYNlYMWNAx4z4iNnMyrqdQr_NJ5md2n1HMSxmP5ykRM,105217
|
70 |
+
mlx/include/metal_cpp/Metal/MTLHeap.hpp,sha256=O9zMAFEKDGKAt710IDSwdsW3mYSyM7Mt2IVvJNJnAq0,11624
|
71 |
+
mlx/include/metal_cpp/Metal/MTLIOCommandBuffer.hpp,sha256=OweZPLAmFoJFypxGIWisfb6Bzlvgv5qIYB2MVu6jBLw,7417
|
72 |
+
mlx/include/metal_cpp/Metal/MTLIOCommandQueue.hpp,sha256=Al9BpGDKaSu5zQblMphnr9EXvrVee2ZaZWMba6L7QSs,7465
|
73 |
+
mlx/include/metal_cpp/Metal/MTLIOCompressor.hpp,sha256=s7UjSrJLyfQt69RzWBY692G1audMRzjIVR-7GjkZFOQ,3043
|
74 |
+
mlx/include/metal_cpp/Metal/MTLIndirectCommandBuffer.hpp,sha256=J3H4HH0p8qq_DYU8pQC_w4qDmYtNqj6FzoOo_AVGfxA,9715
|
75 |
+
mlx/include/metal_cpp/Metal/MTLIndirectCommandEncoder.hpp,sha256=AW1Qzyggw9k9EqS71CvlFUHu4lag6nvvzK-LnV4oG0k,11239
|
76 |
+
mlx/include/metal_cpp/Metal/MTLIntersectionFunctionTable.hpp,sha256=NUP7rHPWulH7M2M1wvzlyjjF7S0BJvG-_XFn4nWehcg,8181
|
77 |
+
mlx/include/metal_cpp/Metal/MTLLibrary.hpp,sha256=le2P__N8XA4RCN62L2ZB-F9uo2p8BjQ-A6j0kEOqzws,23904
|
78 |
+
mlx/include/metal_cpp/Metal/MTLLinkedFunctions.hpp,sha256=FmUoNjOGSegryBa0PDa1qcxEF0RUESixbd7OPyU-qDM,3867
|
79 |
+
mlx/include/metal_cpp/Metal/MTLParallelRenderCommandEncoder.hpp,sha256=uHyC4bf8BjFUz6xEkFg79Lqzz9CV-0NNnNoQ-U3xfsA,3899
|
80 |
+
mlx/include/metal_cpp/Metal/MTLPipeline.hpp,sha256=1EaeWaT8YNfvPaX0gqktNMOLtojBF6cVMlZiFT7vjyg,3775
|
81 |
+
mlx/include/metal_cpp/Metal/MTLPixelFormat.hpp,sha256=DYqej_xuztbD0lhsioWTYWO2Sg7jRShyJ7izoVlex5Y,5910
|
82 |
+
mlx/include/metal_cpp/Metal/MTLPrivate.hpp,sha256=vPTld95A83b1nIVT4qwsdcrttngMK5L5Ow-0HTr-zm4,6415
|
83 |
+
mlx/include/metal_cpp/Metal/MTLRasterizationRate.hpp,sha256=LtcYTnjLSTaxTw98WdUg_9R9BaAozvXOTHcsnmSqbTo,15447
|
84 |
+
mlx/include/metal_cpp/Metal/MTLRenderCommandEncoder.hpp,sha256=JNisXL9uSMMm_5B2vkVfuI8dgrMMMIrC8-E7P0pjdeo,60813
|
85 |
+
mlx/include/metal_cpp/Metal/MTLRenderPass.hpp,sha256=YQCYdLOnOWsUgIbQimBDFo4l_FmbZfeHIJFADL3NnRQ,32752
|
86 |
+
mlx/include/metal_cpp/Metal/MTLRenderPipeline.hpp,sha256=TuBYsiWg4hpumgaKccAn8TVAeLw26l_FdVGa6JsaMLA,72929
|
87 |
+
mlx/include/metal_cpp/Metal/MTLResource.hpp,sha256=Be28y-6-9GvFuugCUhnykiTUqbMqYgVovDjUdpIrFDA,5244
|
88 |
+
mlx/include/metal_cpp/Metal/MTLResourceStateCommandEncoder.hpp,sha256=eOjzuG-vDtQLJ8QOzhiqxXxiMEShXyn9cn-Uhmtigr8,5330
|
89 |
+
mlx/include/metal_cpp/Metal/MTLResourceStatePass.hpp,sha256=N5sMbFCjfEdf10_3VZJfQvDl3nLuydqrKnUgiJzynRU,7585
|
90 |
+
mlx/include/metal_cpp/Metal/MTLSampler.hpp,sha256=0rifEvFw4V8_kMmusAU4NXiZeZbCSX_L4x0K6QyEkog,10838
|
91 |
+
mlx/include/metal_cpp/Metal/MTLStageInputOutputDescriptor.hpp,sha256=LfDoa6Dj5WQFS2Qg_FQRCkVgBpbbfjBdFtn1SIzsCyI,13120
|
92 |
+
mlx/include/metal_cpp/Metal/MTLTexture.hpp,sha256=jM-mSxqfGBNXE6xwzx1M7CC8GOiyungRuop-6uTuf-M,24807
|
93 |
+
mlx/include/metal_cpp/Metal/MTLTypes.hpp,sha256=9_le5gsR93J2awZRAk_Rep_yAWKShAQOgI6faZKD_tg,4389
|
94 |
+
mlx/include/metal_cpp/Metal/MTLVersion.hpp,sha256=iaG4OsKd3JlcsoLEvyAx_5TdCRup1ykH0FUYqbu6Lxo,1509
|
95 |
+
mlx/include/metal_cpp/Metal/MTLVertexDescriptor.hpp,sha256=4oRbxIiJV2eafs-SqTXTVgPttIC22KtTZ0t4AXnLlgc,12017
|
96 |
+
mlx/include/metal_cpp/Metal/MTLVisibleFunctionTable.hpp,sha256=P5Cv_bIIfJ4ktplCeZ3fAhHtLEXngF5cAlQgAr55c-0,3833
|
97 |
+
mlx/include/metal_cpp/Metal/Metal.hpp,sha256=Bm5ldz-FxywsGmAdjZ8U2X5hlYjwfzHUjKCoW38T-w8,3190
|
98 |
+
mlx/include/metal_cpp/MetalFX/MTLFXDefines.hpp,sha256=Ms80cxZbVVAFYDogYahvmzyVZqkDezrwXKXB_79m3KE,2104
|
99 |
+
mlx/include/metal_cpp/MetalFX/MTLFXPrivate.hpp,sha256=LcubGmMx98T8OduyEawEx8swCEo_n5Z3ZfwuxwvAtz0,15147
|
100 |
+
mlx/include/metal_cpp/MetalFX/MTLFXSpatialScaler.hpp,sha256=8QCySJ9V5JfuTNAHujkvgKQJ_hdM-yV8cqdw5v8l1Ns,18385
|
101 |
+
mlx/include/metal_cpp/MetalFX/MTLFXTemporalScaler.hpp,sha256=57Qfhq3YqYOd0qwLfOBvz_viu55-Pmor0_IF99UWMAg,32887
|
102 |
+
mlx/include/metal_cpp/MetalFX/MetalFX.hpp,sha256=Vm0W_ycCRb0UOOvLIPYm5gIEHW1nikG5vTvoKfTgAtE,1350
|
103 |
+
mlx/include/metal_cpp/QuartzCore/CADefines.hpp,sha256=q6k-jUxe5uulNnsoqkM61OuyEVSJO4W8gIAVM6mSbwE,1895
|
104 |
+
mlx/include/metal_cpp/QuartzCore/CAMetalDrawable.hpp,sha256=RE4_Cp9FacJMh7AlqZw-IhelLMXvUv_YpXVLQ2aNVrU,2363
|
105 |
+
mlx/include/metal_cpp/QuartzCore/CAMetalLayer.hpp,sha256=i4WjUgWUsQVRPj8dxk4aBC06TKaXsKxRdgNNIgNe0pk,5502
|
106 |
+
mlx/include/metal_cpp/QuartzCore/CAPrivate.hpp,sha256=aFEVW34Ih_Gs1aGv5Rml0pNaFxDwWLP6qBQWw3N6EzA,5010
|
107 |
+
mlx/include/metal_cpp/QuartzCore/QuartzCore.hpp,sha256=-fQqOpndpszueKWJ4U_jovMID_s0QxCOJHpQoUj8JBE,1346
|
108 |
+
mlx/include/metal_cpp/README.md,sha256=ZVbXv3dwSy-n_EmPpYlB2cJXkUCZXdE0-jJsY9pHsd0,14584
|
109 |
+
mlx/include/metal_cpp/SingleHeader/MakeSingleHeader.py,sha256=JqlbeaangDoU2dRPyOfKgU5jkbmUghucuumqPEbZOWc,8984
|
110 |
+
mlx/include/metal_cpp/SingleHeader/__pycache__/MakeSingleHeader.cpython-311.pyc,,
|
111 |
+
mlx/include/mlx/3rdparty/pocketfft.h,sha256=Dq6iEwS_MkY5_xECLuY_r0io-rZK_lw8LR7---HX36Y,110508
|
112 |
+
mlx/include/mlx/allocator.h,sha256=r0AgIuBbvq67OeO8eUD8cb4a72HDF6-PISAnOMfC-kM,1519
|
113 |
+
mlx/include/mlx/array.h,sha256=D_dkEyq5zbBSc8kK5KNgQbAikvihW9Xn_1HlwI9uaps,11348
|
114 |
+
mlx/include/mlx/backend/accelerate/utils.h,sha256=ryQ9t4EI6UkMBCE48MVcGUSstbRDlwM9YI__vs4dTzY,752
|
115 |
+
mlx/include/mlx/backend/common/arange.h,sha256=J44RZZx_DB7DSELCtDFHBThHfLopu_JWj_5ew_q0ROw,1823
|
116 |
+
mlx/include/mlx/backend/common/binary.h,sha256=uvS_NdciYF-6VHhKrDNFL1Sd2Ag8hn_kClpYmcu6E9E,15169
|
117 |
+
mlx/include/mlx/backend/common/copy.h,sha256=5-b2CRTRoyCJ4WL5RkPAQXaLKaEn8pz-XojHGs4u1Rc,690
|
118 |
+
mlx/include/mlx/backend/common/erf.h,sha256=Wm6jXR40EgAmyh0zbGOvRbZh58LVhe3_P1-hwN9GPuo,274
|
119 |
+
mlx/include/mlx/backend/common/reduce.h,sha256=iR0xKDMZSEDrf2W0TLLXuqxVesf-nOioGu0MDGalT-0,11138
|
120 |
+
mlx/include/mlx/backend/common/threefry.h,sha256=sXTdUFRAgEXf789u7IhWIj7n-Gtzwo4HiQXDA6BwFyU,572
|
121 |
+
mlx/include/mlx/backend/common/unary.h,sha256=5vAiPfmNjXxMohP6RqZiOngCXcXa76RtrLDEyTeQobs,3356
|
122 |
+
mlx/include/mlx/backend/common/utils.h,sha256=4pS1LCC-Tux3x4Y0A3LeTyTLGDSxbjapNkNFWUNjrc8,610
|
123 |
+
mlx/include/mlx/backend/metal/allocator.h,sha256=lluCLBSXYbp6WGtl9zKsikpMHrptRjeZJ0Cf70zFkoc,1491
|
124 |
+
mlx/include/mlx/backend/metal/copy.h,sha256=jXW0_jDwX97NmnwigeZwovS_MZ1PeZix2iuZf7f9Xts,400
|
125 |
+
mlx/include/mlx/backend/metal/device.h,sha256=1Q0xSLn21dTzD5rWanoBBg1eMqOPf2958Z0rJL9GJag,2196
|
126 |
+
mlx/include/mlx/backend/metal/kernels/atomic.h,sha256=0Iph_jNZoTyTrR-iEXXDzIwFoRWOkhqndX6IHJxFhKg,9217
|
127 |
+
mlx/include/mlx/backend/metal/kernels/bf16.h,sha256=YdmOaznu1O1UNTyGlnpnEkvi1r44EouYNBAZ4SAX2yA,11904
|
128 |
+
mlx/include/mlx/backend/metal/kernels/bf16_math.h,sha256=VRM_iZpBcY3Sg3yxkNPzjU30x32zcmd0diIagA1D5l4,26337
|
129 |
+
mlx/include/mlx/backend/metal/kernels/complex.h,sha256=XDYKDiBdeM5KGpfpmwXmelZ00LmFCdxiGOIGoUMEeyI,3526
|
130 |
+
mlx/include/mlx/backend/metal/kernels/conv_params.h,sha256=pzJn2uZ8TqY1WXRyGRx-MzBYBOfMSyhiDPgu8LMSElk,592
|
131 |
+
mlx/include/mlx/backend/metal/kernels/defines.h,sha256=Bkplnhk_KYkqUzpB9TSlYUuXyucJQZYNyW9AR55g8fM,477
|
132 |
+
mlx/include/mlx/backend/metal/kernels/erf.h,sha256=k1CfVsb-rEQaXGYJcAbgVR90VwEY5w_88BhJCSsj97k,2736
|
133 |
+
mlx/include/mlx/backend/metal/kernels/gemm/conv.h,sha256=ilmFKKdzy8qwn1Zt0cu43f_mKsMLhIJy6DgIdx5ZU5I,14236
|
134 |
+
mlx/include/mlx/backend/metal/kernels/gemm/gemm.h,sha256=x80JGVUliIDkFyzoGGd5iC0veHZVPx_BFkeVdGzgVOk,16184
|
135 |
+
mlx/include/mlx/backend/metal/kernels/reduce.h,sha256=RgBsV24CtUU_dXCNzDP8sojJdjRz2bdF3PqB8UNJrlA,3554
|
136 |
+
mlx/include/mlx/backend/metal/kernels/utils.h,sha256=mRGC9B2cjzCZT2kxY__EM0QmRQQBL8GSHC0rg7_4cJo,7202
|
137 |
+
mlx/include/mlx/backend/metal/matmul.h,sha256=n32FxTg1zL-_oIaPRrs9QcI5TywhMemcm2KJ4zF-n0Q,592
|
138 |
+
mlx/include/mlx/backend/metal/metal.h,sha256=9kgNkbx0Pqy-W3o7stP0cjvSAcoU0I20dZEsKmsSAS4,552
|
139 |
+
mlx/include/mlx/backend/metal/mps/gemm.h,sha256=501lddlggxogceZI3wkWK98yIZurE0qgVxx0ICHyBxs,11307
|
140 |
+
mlx/include/mlx/backend/metal/utils.h,sha256=ape52Rl1et96MA4H_a0cbVveLvhASJlszqX6mxl-nis,4205
|
141 |
+
mlx/include/mlx/device.h,sha256=KtG7G03J5T0Ad2zn931OyRGfCa1XKn3ZyEUoIBY0No4,563
|
142 |
+
mlx/include/mlx/dtype.h,sha256=jrOMl7xEhOZ7Dn0udDfi-Mq8QbT3atbd7CL0SZ4Yxjk,2575
|
143 |
+
mlx/include/mlx/fft.h,sha256=MpW4UMjetc81i8pt3NJnTCvtXuaP92CX3w2iatYZHTw,4179
|
144 |
+
mlx/include/mlx/graph_utils.h,sha256=23pv6xyLYTl-l2o9Z3DqlcvGVGZEA-uIV29jamcK-1Q,593
|
145 |
+
mlx/include/mlx/io/load.h,sha256=vl7Z0pSJTjgG7KkSLXbe5oLjPhnVkWR9XdvC-ZCf_fk,2508
|
146 |
+
mlx/include/mlx/io/safetensor.h,sha256=ro-grE_iXiGbCgygGFSeqsvxD_DDiwKckCsyb424e-g,624
|
147 |
+
mlx/include/mlx/linalg.h,sha256=-CJ8NhHmycVsT29ObUmVZbxd2gwCpj8qj2RvOv-Glto,1771
|
148 |
+
mlx/include/mlx/mlx.h,sha256=raZfU4bjGTSRAZa8P1XkoxHHeiooTZVI9C63ydGwm8c,296
|
149 |
+
mlx/include/mlx/ops.h,sha256=322gdgz9l7pz-M0Pgpo0T6LbAe_w0RMV-A1-jRN_Le0,32224
|
150 |
+
mlx/include/mlx/primitives.h,sha256=uJhmcwsXyseQlfeJBDjLCNrhPM91vursiw8_R79yUGg,44227
|
151 |
+
mlx/include/mlx/random.h,sha256=aNoyWf-zAg9WKgqfe1yBpOvxedygIUZMQBgsd3kNnzQ,4887
|
152 |
+
mlx/include/mlx/scheduler.h,sha256=sZX0BU1J15e_IOzBsv95KwAQmU4wrayd3b6ClvxGA4M,3864
|
153 |
+
mlx/include/mlx/stream.h,sha256=7BvNtylObXbyIHZ0gCrbMb1bN3DUq7PfzEa-bIvqEXE,686
|
154 |
+
mlx/include/mlx/transforms.h,sha256=Qt8pjvxAPYMoa8Dz2WUcJBdeKEFKNdczV0O3UTE_XmY,6527
|
155 |
+
mlx/include/mlx/transforms_impl.h,sha256=RttUce4KHg3PLUFc9aRVjxenCIztlBSTlQq5LnVxKVg,542
|
156 |
+
mlx/include/mlx/types/bf16.h,sha256=fLnTB0fbL2vIs_CCvWhrr7SfDnFUEpDPmg7RNmrRnzE,6753
|
157 |
+
mlx/include/mlx/types/complex.h,sha256=oaTl_rCth3TO33sT2Db6SX--n9lIfE_dw7rjE06Qqmc,2878
|
158 |
+
mlx/include/mlx/types/fp16.h,sha256=j0E7-PKkVkCT-bOMRdBEYdvxUdo9rf6w80dNSNCYNaQ,8322
|
159 |
+
mlx/include/mlx/types/half_types.h,sha256=-D-ilDAh5shp1xp4lx0J87y9iFfBs43gLNi-XqzmLWA,1367
|
160 |
+
mlx/include/mlx/utils.h,sha256=cchUDEk5qbWNIPHf_cRwbsxnczz2_moUr_Dipz-cfhM,1524
|
161 |
+
mlx/lib/libmlx.dylib,sha256=ir7-RqHznJKyhGSBTwWnMPqYmbF3V3A0A8bvNi8GrJM,12420704
|
162 |
+
mlx/lib/mlx.metallib,sha256=Lu30EADtJwKD2hHYibsQGqTIjG-PDsaP5rBApb5CRQE,59495531
|
163 |
+
mlx/nn/__init__.py,sha256=b-1hqkAnzpthKJNKgRGVHsltNEqOU7URFMoY2Z46Buw,126
|
164 |
+
mlx/nn/__pycache__/__init__.cpython-311.pyc,,
|
165 |
+
mlx/nn/__pycache__/losses.cpython-311.pyc,,
|
166 |
+
mlx/nn/__pycache__/utils.cpython-311.pyc,,
|
167 |
+
mlx/nn/layers/__init__.py,sha256=2Ge23mO7W5VOMkogmcY2dqTKRRD3lY0g_WhVDP-9Hi0,1249
|
168 |
+
mlx/nn/layers/__pycache__/__init__.cpython-311.pyc,,
|
169 |
+
mlx/nn/layers/__pycache__/activations.cpython-311.pyc,,
|
170 |
+
mlx/nn/layers/__pycache__/base.cpython-311.pyc,,
|
171 |
+
mlx/nn/layers/__pycache__/containers.cpython-311.pyc,,
|
172 |
+
mlx/nn/layers/__pycache__/convolution.cpython-311.pyc,,
|
173 |
+
mlx/nn/layers/__pycache__/dropout.cpython-311.pyc,,
|
174 |
+
mlx/nn/layers/__pycache__/embedding.cpython-311.pyc,,
|
175 |
+
mlx/nn/layers/__pycache__/linear.cpython-311.pyc,,
|
176 |
+
mlx/nn/layers/__pycache__/normalization.cpython-311.pyc,,
|
177 |
+
mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc,,
|
178 |
+
mlx/nn/layers/__pycache__/quantized.cpython-311.pyc,,
|
179 |
+
mlx/nn/layers/__pycache__/transformer.cpython-311.pyc,,
|
180 |
+
mlx/nn/layers/activations.py,sha256=ukopw8ZFpSIQijXscov80Gv_0j0jKh0Y4Z3qrUFzpkk,11826
|
181 |
+
mlx/nn/layers/base.py,sha256=iFp3PVGUhT9pAaWNnWtNY-yZ-uCQxvH-O1CYbKoGCug,19485
|
182 |
+
mlx/nn/layers/containers.py,sha256=Ke8gPBPZvrtrjK1qrbRNGtSV4rvBkYTKVBW2pQXrNvY,618
|
183 |
+
mlx/nn/layers/convolution.py,sha256=ZP43OBZTKF3ui_2Xrtx7VvAlLic_jcn2oL-MPLU7Rys,4060
|
184 |
+
mlx/nn/layers/dropout.py,sha256=1GR5uWU82eFd296WuPB7CqNcNj3acEIlS3_4ZAxnx-g,4529
|
185 |
+
mlx/nn/layers/embedding.py,sha256=5G7TvVyEaHTlQCASNVfBGea_1ywDVNqDkxHP2U-JEcs,878
|
186 |
+
mlx/nn/layers/linear.py,sha256=ekgkElIUlfdWZNRDZkoyQFTotzyjNpD_0v7Ac8pnako,4137
|
187 |
+
mlx/nn/layers/normalization.py,sha256=m92TwFfNQ4NbQlxtlTfJsWZlBJtO0_tR0M8H2d9wZEs,12021
|
188 |
+
mlx/nn/layers/positional_encoding.py,sha256=_4zbYLMGYI0EVq4b6eitmB2hgBXblcrJJYsq0KCSafs,6611
|
189 |
+
mlx/nn/layers/quantized.py,sha256=FA_NzLBheO2i5AClMWCkaFrNXJGScxlwvumqaJTIxAA,4136
|
190 |
+
mlx/nn/layers/transformer.py,sha256=IUTNGvVoaKLS7Gwes_yK3SqgX6voPgMqk7nmdE4nyI0,12235
|
191 |
+
mlx/nn/losses.py,sha256=1yWCxt-QSv4lGV9U6KJyeTzybQWW4LPZhipkLFx7hWk,12158
|
192 |
+
mlx/nn/utils.py,sha256=tQJ62bu2eAAhzYqV7QriQtNGrYSz85XGxHolEPkGiGs,999
|
193 |
+
mlx/optimizers.py,sha256=krPDwdaMy1JTlaVBgRHlg7PNqyxiVRkfNEbEf_wV-5g,17090
|
194 |
+
mlx/share/cmake/MLX/MLXConfig.cmake,sha256=02CzwjTAHLTyLzFMtE7LLzyX_CSHreASjFyiYEpfIp4,1884
|
195 |
+
mlx/share/cmake/MLX/MLXConfigVersion.cmake,sha256=p6TWi4EVJlFiJx_BxHZqpPBhkKaze00W3S-JRcRHD78,2762
|
196 |
+
mlx/share/cmake/MLX/MLXTargets-release.cmake,sha256=bUjlsca3AGGUHl1X9lfm3CiQiSFeYTkm2c7GLzITeJs,788
|
197 |
+
mlx/share/cmake/MLX/MLXTargets.cmake,sha256=vS5sI59URF-Pz8SPKMDVbyTPq1ENZ3zoR_VPzWToXkw,4705
|
198 |
+
mlx/share/cmake/MLX/extension.cmake,sha256=q09-X3ckTrFCZOc5UNyoqjUC94OQQSXXWW7Bmg7sxok,1637
|
199 |
+
mlx/utils.py,sha256=F-bmWn6a4VFRUvnSeZPvIFwT8Z6IHEtgEAfcLiGBrz8,4667
|
lib/python3.11/site-packages/mlx-0.0.7.dist-info/REQUESTED
ADDED
File without changes
|
lib/python3.11/site-packages/mlx-0.0.7.dist-info/WHEEL
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Wheel-Version: 1.0
|
2 |
+
Generator: bdist_wheel (0.41.2)
|
3 |
+
Root-Is-Purelib: false
|
4 |
+
Tag: cp311-cp311-macosx_14_0_arm64
|
5 |
+
|
lib/python3.11/site-packages/mlx-0.0.7.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
mlx
|
lib/python3.11/site-packages/mlx/nn/layers/base.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import textwrap
|
4 |
+
from typing import Any, Callable, List, Optional, Tuple, Union
|
5 |
+
|
6 |
+
import mlx.core as mx
|
7 |
+
from mlx.utils import tree_flatten, tree_unflatten
|
8 |
+
|
9 |
+
|
10 |
+
class Module(dict):
    """Base class for building neural networks with MLX.

    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
    your models should do the same.

    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
    instances in arbitrary nesting of python lists or dicts. The ``Module``
    then allows recursively extracting all the :class:`mlx.core.array` instances
    using :meth:`mlx.nn.Module.parameters`.

    In addition, the ``Module`` has the concept of trainable and non trainable
    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
    the gradients are returned only with respect to the trainable parameters.
    All arrays in a module are trainable unless they are added in the "frozen"
    set by calling :meth:`freeze`.

    .. code-block:: python

        import mlx.core as mx
        import mlx.nn as nn

        class MyMLP(nn.Module):
            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
                super().__init__()

                self.in_proj = nn.Linear(in_dims, hidden_dims)
                self.out_proj = nn.Linear(hidden_dims, out_dims)

            def __call__(self, x):
                x = self.in_proj(x)
                x = mx.maximum(x, 0)
                return self.out_proj(x)

        model = MyMLP(2, 1)

        # All the model parameters are created but since MLX is lazy by
        # default, they are not evaluated yet. Calling `mx.eval` actually
        # allocates memory and initializes the parameters.
        mx.eval(model.parameters())

        # Setting a parameter to a new value is as simple as accessing that
        # parameter and assigning a new array to it.
        model.in_proj.weight = model.in_proj.weight * 2
        mx.eval(model.parameters())
    """

    def __init__(self):
        """Should be called by the subclasses of ``Module``."""
        # Names of direct members excluded from gradient computation
        # (populated by :meth:`freeze`, cleared by :meth:`unfreeze`).
        self._no_grad = set()
        # Modules start in training mode; toggled via :meth:`train`/:meth:`eval`.
        self._training = True

    @property
    def training(self):
        """Boolean indicating if the model is in training mode."""
        return self._training

    def _extra_repr(self):
        # Subclasses override this to inject configuration details into repr().
        return ""

    def __repr__(self):
        # Render this module followed by its (indented) direct children.
        children = tree_flatten(self.children(), is_leaf=self.is_module)
        value = f"{type(self).__name__}({self._extra_repr()}"
        for k, v in children:
            value += "\n"
            value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
        if children:
            value += "\n"
        value += ")"

        return value

    def __getattr__(self, key: str):
        # Attribute access is aliased to dict item access: members set via
        # ``self.name = value`` live in the underlying dict.
        if key in self:
            return self[key]
        else:
            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")

    def __setattr__(self, key: str, val: Any):
        # Store every attribute (parameters, submodules, plain config) as a
        # dict item so the tree-traversal utilities can find them.
        self[key] = val

    def load_weights(
        self,
        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
        strict: bool = True,
    ):
        """
        Update the model's weights from a ``.npz`` or a list.

        Args:
            file_or_weights (str or list(tuple(str, mx.array))): The path to
                the weights ``.npz`` file or a list of pairs of parameter names
                and arrays.
            strict (bool, optional): If ``True`` then checks that the provided
                weights exactly match the parameters of the model. Otherwise,
                only the weights actually contained in the model are loaded and
                shapes are not checked. Default: ``True``.

        Raises:
            ValueError: In strict mode, when the provided weights have extra or
                missing keys, a non-array value, or a mismatched shape.

        Example:

        .. code-block:: python

            import mlx.core as mx
            import mlx.nn as nn
            model = nn.Linear(10, 10)

            # Load from file
            model.load_weights("weights.npz")

            # Load from list
            weights = [
                ("weight", mx.random.uniform(shape=(10, 10))),
                ("bias",  mx.zeros((10,))),
            ]
            model.load_weights(weights)

            # Missing weight
            weights = [
                ("weight", mx.random.uniform(shape=(10, 10))),
            ]

            # Raises a ValueError exception
            model.load_weights(weights)

            # Ok, only updates the weight but not the bias
            model.load_weights(weights, strict=False)
        """
        weights = file_or_weights
        if isinstance(weights, str):
            # A path was given: load the archive into (name, array) pairs.
            weights = list(mx.load(weights).items())

        if strict:
            # Validate the weight set against the current parameter tree
            # before mutating anything.
            new_weights = dict(weights)
            curr_weights = dict(tree_flatten(self.parameters()))
            if extras := (new_weights.keys() - curr_weights.keys()):
                extras = " ".join(extras)
                raise ValueError(f"Received parameters not in model: {extras}.")
            if missing := (curr_weights.keys() - new_weights.keys()):
                missing = " ".join(missing)
                raise ValueError(f"Missing parameters: {missing}.")
            for k, v in curr_weights.items():
                v_new = new_weights[k]
                if not isinstance(v_new, mx.array):
                    raise ValueError(
                        "Expected mx.array but received "
                        f"{type(v_new)} for parameter {k}"
                    )
                if v_new.shape != v.shape:
                    raise ValueError(
                        f"Expected shape {v.shape} but received "
                        f" shape {v_new.shape} for parameter {k}"
                    )

        self.update(tree_unflatten(weights))

    def save_weights(self, file: str):
        """
        Save the model's weights to a ``.npz`` file.
        """
        mx.savez(file, **dict(tree_flatten(self.parameters())))

    @staticmethod
    def is_module(value):
        # True when ``value`` is itself a Module (used as a tree-leaf test).
        return isinstance(value, Module)

    @staticmethod
    def valid_child_filter(module, key, value):
        # Containers (dicts/lists) may hold submodules and are traversed.
        return isinstance(value, (dict, list))

    @staticmethod
    def valid_parameter_filter(module, key, value):
        # Parameters are arrays/containers whose key is not "private"
        # (underscore-prefixed members like _no_grad are bookkeeping).
        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")

    @staticmethod
    def trainable_parameter_filter(module, key, value):
        # A trainable parameter is a valid parameter that was not frozen.
        return (
            Module.valid_parameter_filter(module, key, value)
            and key not in module._no_grad
        )

    def filter_and_map(
        self,
        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
        map_fn: Optional[Callable] = None,
        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Recursively filter the contents of the module using ``filter_fn``,
        namely only select keys and values where ``filter_fn`` returns true.

        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
        but it can also be used to extract any subset of the module's parameters.

        Args:
            filter_fn (Callable): Given a value, the key in which it is found
                and the containing module, decide whether to keep the value or
                drop it.
            map_fn (Callable, optional): Optionally transform the value before
                returning it.
            is_leaf_fn (Callable, optional): Given a value, the key in which it
                is found and the containing module decide if it is a leaf.

        Returns:
            A dictionary containing the contents of the module recursively filtered
        """

        map_fn = map_fn or (lambda x: x)
        is_leaf_fn = is_leaf_fn or (
            lambda m, k, v: not isinstance(v, (Module, dict, list))
        )

        def unwrap(vk, v):
            # ``vk`` is the dotted key path of ``v`` relative to this module.
            if is_leaf_fn(self, vk, v):
                return map_fn(v)

            if isinstance(v, Module):
                # Delegate to the submodule so nested filters see it as root.
                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

            if isinstance(v, dict):
                nd = {}
                for k, v in v.items():
                    tk = f"{vk}.{k}"
                    # Filtered-out entries are kept as empty dicts so the
                    # tree structure (and key paths) stays intact.
                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
                return nd

            if isinstance(v, list):
                nl = []
                for i, vi in enumerate(v):
                    tk = f"{vk}.{i}"
                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
                return nl

            raise RuntimeError("Unexpected leaf found while traversing the module")

        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}

    def parameters(self):
        """Recursively return all the :class:`mlx.core.array` members of this Module
        as a dict of dicts and lists."""
        return self.filter_and_map(self.valid_parameter_filter)

    def trainable_parameters(self):
        """Recursively return all the non frozen :class:`mlx.core.array` members of
        this Module as a dict of dicts and lists."""
        return self.filter_and_map(self.trainable_parameter_filter)

    def children(self):
        """Return the direct descendants of this Module instance."""
        return self.filter_and_map(
            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
        )

    def leaf_modules(self):
        """Return the submodules that do not contain other modules."""

        def _is_leaf_module(m, k, v):
            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0

        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)

    def update(self, parameters: dict):
        """Replace the parameters of this Module with the provided ones in the
        dict of dicts and lists.

        Commonly used by the optimizer to change the model to the updated
        (optimized) parameters. Also used by the :meth:`mlx.nn.value_and_grad` to set the
        tracers in the model in order to compute gradients.

        The passed in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            parameters (dict): A complete or partial dictionary of the modules
                parameters.
        """

        def apply(dst, parameters):
            # Walk ``parameters`` and ``dst`` in lockstep, replacing arrays
            # in place and recursing into modules/containers.
            if isinstance(parameters, dict):
                for k in parameters:
                    if k in dst:
                        current_value = dst[k]
                        new_value = parameters[k]
                        if isinstance(current_value, mx.array):
                            dst[k] = new_value
                        elif isinstance(current_value, Module):
                            current_value.update(new_value)
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(parameters, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = parameters[i]
                    if isinstance(current_value, mx.array):
                        dst[i] = new_value
                    elif isinstance(current_value, Module):
                        current_value.update(new_value)
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, parameters)

    def apply(
        self,
        map_fn: Callable[[mx.array], mx.array],
        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
    ):
        """Map all the parameters using the provided ``map_fn`` and immediately
        update the module with the mapped parameters.

        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
        casts all parameters to 16 bit floats.

        Args:
            map_fn (Callable): Maps an array to another array
            filter_fn (Callable, optional): Filter to select which arrays to
                map (default: :meth:`Module.valid_parameter_filter`).
        """
        filter_fn = filter_fn or Module.valid_parameter_filter
        self.update(self.filter_and_map(filter_fn, map_fn))

    def update_modules(self, modules: dict):
        """Replace the child modules of this :class:`Module` instance with the
        provided ones in the dict of dicts and lists.

        It is the equivalent of :meth:`Module.update` but for modules instead
        of parameters and allows us to flexibly edit complex architectures by
        programmatically swapping layers.

        The passed in parameters dictionary need not be a full dictionary
        similar to :meth:`parameters`. Only the provided locations will be
        updated.

        Args:
            modules (dict): A complete or partial dictionary of the modules
                submodules.
        """

        def apply(dst, modules):
            # Only module-for-module swaps are performed; anything else is
            # traversed (containers) or silently skipped.
            if isinstance(modules, dict):
                for k in modules:
                    if k in dst:
                        current_value = dst[k]
                        new_value = modules[k]
                        if self.is_module(current_value) and self.is_module(new_value):
                            dst[k] = new_value
                        elif isinstance(current_value, (dict, list)):
                            apply(current_value, new_value)
            elif isinstance(modules, list):
                for i in range(len(dst)):
                    current_value = dst[i]
                    new_value = modules[i]
                    if self.is_module(current_value) and self.is_module(new_value):
                        dst[i] = new_value
                    elif isinstance(current_value, (dict, list)):
                        apply(current_value, new_value)

        apply(self, modules)

    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
        """Apply a function to all the modules in this instance (including this
        instance).

        Args:
            apply_fn (Callable): The function to apply to the modules.
        """
        # Iterative DFS over the module tree; keys are dotted paths rooted
        # at "" for this instance.
        module_stack = [("", self)]
        while module_stack:
            prefix, mod = module_stack.pop()
            apply_fn(prefix, mod)
            prefix = "." + prefix if prefix else ""
            module_stack.extend(
                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
            )

    def modules(self):
        """Return a list with all the modules in this instance.

        Returns:
            A list of :class:`mlx.nn.Module` instances.
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append(m))
        return modulelist

    def named_modules(self):
        """Return a list with all the modules in this instance and their name
        with dot notation.

        Returns:
            A list of tuples (str, :class:`mlx.nn.Module`).
        """
        modulelist = []
        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
        return modulelist

    def _validate_keys(self, keys, strict):
        # Normalize a single key to a list and, in strict mode, make sure
        # every key actually names a member of this module.
        keys = keys if isinstance(keys, list) else [keys]
        if strict:
            for k in keys:
                if k not in self:
                    raise KeyError(f"Module doesn't contain member {k}.")
        return keys

    def freeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Freeze the Module's parameters or some of them. Freezing a parameter means not
        computing gradients for it.

        This function is idempotent i.e. freezing a frozen model is a no-op.

        Example:
            For instance to only train the attention parameters from a Transformer:

        .. code-block:: python

            model = nn.Transformer()
            model.freeze()
            model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)

        Args:
            recurse (bool, optional): If True then freeze the parameters of the
                submodules as well. Default: ``True``.
            keys (str or list[str], optional): If provided then only these
                parameters will be frozen otherwise all the parameters of a
                module. For instance freeze all biases by calling
                ``module.freeze(keys="bias")``.
            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
                Default: ``False``.
        """

        def _freeze_impl(_, m):
            local_keys = keys
            if local_keys is None:
                # No keys given: freeze every direct (non-module) parameter
                # of this module.
                local_keys = tree_flatten(
                    m.filter_and_map(
                        lambda m, k, v: (not isinstance(v, Module))
                        and m.valid_parameter_filter(m, k, v)
                    )
                )
                local_keys = [k for (k, v) in local_keys]

            local_keys = m._validate_keys(local_keys, strict)
            m._no_grad.update(local_keys)

        if recurse:
            self.apply_to_modules(_freeze_impl)
        else:
            _freeze_impl("", self)

    def unfreeze(
        self,
        *,
        recurse: bool = True,
        keys: Optional[Union[str, List[str]]] = None,
        strict: bool = False,
    ):
        """Unfreeze the Module's parameters or some of them.

        This function is idempotent ie unfreezing a model that is not frozen is
        a noop.

        Example:

            For instance to only train the biases of a Transformer one can do:

        .. code-block:: python

            model = nn.Transformer()
            model.freeze()
            model.unfreeze(keys="bias")

        Args:
            recurse (bool, optional): If True then unfreeze the parameters of the
                submodules as well. Default: ``True``.
            keys (str or list[str], optional): If provided then only these
                parameters will be unfrozen otherwise all the parameters of a
                module. For instance unfreeze all biases by calling
                ``module.unfreeze(keys="bias")``.
            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
                Default: ``False``.
        """

        def _unfreeze_impl(_, m):
            if keys is None:
                # No keys given: unfreeze everything in this module.
                m._no_grad.clear()

            else:
                local_keys = m._validate_keys(keys, strict)
                m._no_grad.difference_update(local_keys)

        if recurse:
            self.apply_to_modules(_unfreeze_impl)
        else:
            _unfreeze_impl("", self)

    def train(self, mode: bool = True):
        """Set the model in or out of training mode.

        Training mode only applies to certain layers. For example
        :obj:`Dropout` applies a random mask in training mode, but is the
        identity in evaluation mode.

        Args:
            mode (bool): Indicate if the model should be in training or
                evaluation mode. Default: ``True``.
        """

        def _set_train(_, m):
            m._training = mode

        self.apply_to_modules(_set_train)

    def eval(self):
        """Set the model to evaluation mode.

        See :func:`train`.
        """
        self.train(False)
|
lib/python3.11/site-packages/mlx/nn/layers/containers.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
from mlx.nn.layers.base import Module
|
4 |
+
|
5 |
+
|
6 |
+
class Sequential(Module):
    """A layer that calls the passed callables in order.

    We can pass either modules or plain callables to the Sequential module. If
    our functions have learnable parameters they should be implemented as
    ``nn.Module`` instances.

    Args:
        modules (tuple of Callables): The modules to call in order
    """

    def __init__(self, *modules):
        super().__init__()
        # Store as a list so the base class traversal discovers submodules.
        self.layers = list(modules)

    def __call__(self, x):
        # Thread the input through each layer in registration order.
        out = x
        for layer in self.layers:
            out = layer(out)
        return out
|
lib/python3.11/site-packages/mlx/nn/layers/convolution.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Union
|
5 |
+
|
6 |
+
import mlx.core as mx
|
7 |
+
from mlx.nn.layers.base import Module
|
8 |
+
|
9 |
+
|
10 |
+
class Conv1d(Module):
    """Applies a 1-dimensional convolution over the multi-channel input sequence.

    The channels are expected to be last i.e. the input shape should be ``NLC`` where:
        - ``N`` is the batch dimension
        - ``L`` is the sequence length
        - ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels
        out_channels (int): The number of output channels
        kernel_size (int): The size of the convolution filters
        stride (int, optional): The stride when applying the filter.
            Default: 1.
        padding (int, optional): How many positions to 0-pad the input with.
            Default: 0.
        bias (bool, optional): If ``True`` add a learnable bias to the output.
            Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ):
        super().__init__()

        # Uniform initialization in [-k, k] with k = 1/sqrt(fan_in),
        # where fan_in = in_channels * kernel_size.
        init_bound = math.sqrt(1 / (in_channels * kernel_size))
        # Weight layout: (out_channels, kernel_size, in_channels).
        self.weight = mx.random.uniform(
            low=-init_bound,
            high=init_bound,
            shape=(out_channels, kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        in_ch = self.weight.shape[-1]
        out_ch = self.weight.shape[0]
        ksize = self.weight.shape[1]
        return (
            f"{in_ch}, {out_ch}, "
            f"kernel_size={ksize}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        out = mx.conv1d(x, self.weight, self.stride, self.padding)
        # ``bias`` only exists when the layer was built with bias=True.
        return out + self.bias if "bias" in self else out
|
65 |
+
|
66 |
+
|
67 |
+
class Conv2d(Module):
    """Applies a 2-dimensional convolution over the multi-channel input image.

    The channels are expected to be last i.e. the input shape should be ``NHWC`` where:
        - ``N`` is the batch dimension
        - ``H`` is the input image height
        - ``W`` is the input image width
        - ``C`` is the number of input channels

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        kernel_size (int or tuple): The size of the convolution filters.
        stride (int or tuple, optional): The size of the stride when
            applying the filter. Default: 1.
        padding (int or tuple, optional): How many positions to 0-pad
            the input with. Default: 0.
        bias (bool, optional): If ``True`` add a learnable bias to the
            output. Default: ``True``
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        bias: bool = True,
    ):
        super().__init__()

        # Normalize scalar arguments to (height, width) pairs.
        kernel_size, stride, padding = map(
            lambda x: (x, x) if isinstance(x, int) else x,
            (kernel_size, stride, padding),
        )
        # Uniform initialization in [-k, k] with k = 1/sqrt(fan_in),
        # where fan_in = in_channels * kH * kW.
        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
        # Weight layout: (out_channels, kH, kW, in_channels).
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(out_channels, *kernel_size, in_channels),
        )
        if bias:
            self.bias = mx.zeros((out_channels,))

        self.padding = padding
        self.stride = stride

    def _extra_repr(self):
        # Fix: the weight is (out_channels, kH, kW, in_channels), so the
        # spatial kernel is shape[1:3]; the previous slice [1:2] silently
        # dropped the kernel width from the repr.
        return (
            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
            f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
            f"padding={self.padding}, bias={'bias' in self}"
        )

    def __call__(self, x):
        y = mx.conv2d(x, self.weight, self.stride, self.padding)
        if "bias" in self:
            y = y + self.bias
        return y
|
lib/python3.11/site-packages/mlx/nn/layers/dropout.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import mlx.core as mx
|
4 |
+
from mlx.nn.layers.base import Module
|
5 |
+
|
6 |
+
|
7 |
+
class Dropout(Module):
    r"""Randomly zero a portion of the elements during training.

    The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
    :math:`p` is the probability of zeroing an element. This is done so the
    expected value of a given element will remain the same.

    Args:
        p (float): The probability to zero an element
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError(f"The dropout probability {p} is not in [0, 1)")

        # Cache the keep-probability 1 - p; it is used both for sampling
        # the mask and for the rescaling factor.
        self._p_1 = 1 - p

    def _extra_repr(self):
        drop_p = 1 - self._p_1
        return f"p={drop_p}"

    def __call__(self, x):
        # Identity outside training, or when nothing is ever dropped (p == 0).
        if not self.training or self._p_1 == 1:
            return x

        # Bernoulli keep-mask, rescaled so E[output] == input.
        keep = mx.random.bernoulli(self._p_1, x.shape)
        return (1 / self._p_1) * keep * x
|
36 |
+
|
37 |
+
|
38 |
+
class Dropout2d(Module):
    r"""Apply 2D channel-wise dropout during training.

    Randomly zero out entire channels independently with probability :math:`p`.
    This layer expects the channels to be last, i.e. the input shape should be
    ``NWHC`` or ``WHC`` where: ``N`` is the batch dimension, ``H`` is the input
    image height, ``W`` is the input image width, and ``C`` is the number of
    input channels

    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
    maintain the expected value of each element. Unlike traditional dropout,
    which zeros individual entries, this layer zeros entire channels. This is
    beneficial for early convolution layers where adjacent pixels are
    correlated. In such case, traditional dropout may not effectively
    regularize activations. For more details, see [1].

    [1]: Thompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler C., 2015.
    Efficient Object Localization Using Convolutional Networks. CVPR 2015.

    Args:
        p (float): Probability of zeroing a channel during training.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError(f"The dropout probability {p} is not in [0, 1)")

        # Keep-probability 1 - p, shared by mask sampling and rescaling.
        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        if x.ndim not in (3, 4):
            raise ValueError(
                f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
            )

        if self._p_1 == 1 or not self.training:
            return x

        # Dropout is applied on the whole channel
        # 3D input: (1, 1, C)
        # 4D input: (B, 1, 1, C)
        # Fix: copy the shape before mutating it (matches Dropout3d).
        # Mutating ``x.shape`` directly fails when it is a tuple and is
        # unsafe when it aliases internal state.
        mask_shape = list(x.shape)
        mask_shape[-2] = mask_shape[-3] = 1

        mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
        return (1 / self._p_1) * mask * x
|
89 |
+
|
90 |
+
|
91 |
+
class Dropout3d(Module):
    r"""Apply 3D channel-wise dropout during training.

    Randomly zero out entire channels independently with probability :math:`p`.
    This layer expects the channels to be last, i.e., the input shape should be
    `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
    `H` is the input image height, `W` is the input image width, and `C` is
    the number of input channels.

    The remaining channels are scaled by :math:`\frac{1}{1-p}` to
    maintain the expected value of each element. Unlike traditional dropout,
    which zeros individual entries, this layer zeros entire channels. This is
    often beneficial for convolutional layers processing 3D data, like in
    medical imaging or video processing.

    Args:
        p (float): Probability of zeroing a channel during training.
    """

    def __init__(self, p: float = 0.5):
        super().__init__()

        if p < 0 or p >= 1:
            raise ValueError(f"The dropout probability {p} is not in [0, 1)")

        # Keep-probability, cached so the forward pass avoids recomputing 1 - p.
        self._p_1 = 1 - p

    def _extra_repr(self):
        return f"p={1-self._p_1}"

    def __call__(self, x):
        if x.ndim not in (4, 5):
            raise ValueError(
                f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
            )

        # No-op when nothing would be dropped, or when evaluating.
        if self._p_1 == 1 or not self.training:
            return x

        # Build a mask that broadcasts across depth/height/width so each
        # channel is kept or dropped as a unit:
        #   4D input: (1, 1, 1, C)
        #   5D input: (B, 1, 1, 1, C)
        keep_shape = list(x.shape)
        keep_shape[-2] = keep_shape[-3] = keep_shape[-4] = 1

        keep = mx.random.bernoulli(p=self._p_1, shape=keep_shape)
        return (1 / self._p_1) * keep * x
|
lib/python3.11/site-packages/mlx/nn/layers/embedding.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import mlx.core as mx
|
6 |
+
from mlx.nn.layers.base import Module
|
7 |
+
|
8 |
+
|
9 |
+
class Embedding(Module):
    """Implements a simple lookup table that maps each input integer to a
    high-dimensional vector.

    Typically used to embed discrete tokens for processing by neural networks.

    Args:
        num_embeddings (int): How many possible discrete tokens can we embed.
            Usually called the vocabulary size.
        dims (int): The dimensionality of the embeddings.
    """

    def __init__(self, num_embeddings: int, dims: int):
        super().__init__()
        # Scale the normal initialization by 1/sqrt(dims) so embedding vectors
        # start with roughly unit-order norms.
        init_scale = math.sqrt(1 / dims)
        self.weight = mx.random.normal((num_embeddings, dims)) * init_scale

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, {self.weight.shape[1]}"

    def __call__(self, x):
        # Integer fancy-indexing performs the table lookup.
        return self.weight[x]
|
lib/python3.11/site-packages/mlx/nn/layers/linear.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Any
|
5 |
+
|
6 |
+
import mlx.core as mx
|
7 |
+
from mlx.nn.layers.base import Module
|
8 |
+
|
9 |
+
|
10 |
+
class Identity(Module):
    r"""A placeholder identity operator that is argument-insensitive.

    Args:
        args: any argument (unused)
        kwargs: any keyword argument (unused)
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Every constructor argument is accepted and deliberately ignored so
        # Identity can stand in for any other layer.
        super().__init__()

    def __call__(self, x: mx.array) -> mx.array:
        # Pass the input through unchanged.
        return x
|
23 |
+
|
24 |
+
|
25 |
+
class Linear(Module):
    r"""Applies an affine transformation to the input.

    Concretely:

    .. math::

        y = x W^\top + b

    where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has
    shape ``[output_dims]``.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-{k}, {k})`, where :math:`k = \frac{1}{\sqrt{D_i}}` and
    :math:`D_i` is equal to ``input_dims``.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
        super().__init__()
        # Uniform-initialization bound: 1 / sqrt(fan_in).
        scale = math.sqrt(1.0 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        x = x @ self.weight.T
        # The bias parameter only exists when the layer was built with bias=True.
        if "bias" in self:
            x = x + self.bias
        return x
|
70 |
+
|
71 |
+
|
72 |
+
class Bilinear(Module):
    r"""Applies a bilinear transformation to the inputs.

    Concretely:

    .. math::

        y_i = x_1^\top W_i x_2 + b_i

    where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``,
    :math:`b` has shape ``[output_dims]``, and :math:`i` indexes the output
    dimension.

    The values are initialized from the uniform distribution
    :math:`\mathcal{U}(-{k}, {k})`, where :math:`k = \frac{1}{\sqrt{D_1}}` and
    :math:`D_1` is ``input1_dims``.

    Args:
        input1_dims (int): The dimensionality of the input1 features
        input2_dims (int): The dimensionality of the input2 features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will
            not use a bias. Default is ``True``.
    """

    def __init__(
        self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
    ) -> None:
        super().__init__()
        # Uniform-initialization bound follows the fan-in of the first input.
        scale = math.sqrt(1.0 / input1_dims)
        # The weight is stored internally as (out, in2, in1) so the forward
        # pass can fold (out, in2) into a single matmul dimension.
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input2_dims, input1_dims),
        )
        if bias:
            self.bias = mx.random.uniform(
                low=-scale,
                high=scale,
                shape=(output_dims,),
            )

    def _extra_repr(self) -> str:
        out, in2, in1 = self.weight.shape
        return (
            f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
            f"bias={'bias' in self}"
        )

    def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
        out, in2, in1 = self.weight.shape

        # Flatten all leading dims to one batch axis so the transform becomes
        # two batched matrix multiplies.
        lead_shape = x1.shape[:-1]
        x1 = x1.reshape(-1, in1)
        x2 = x2.reshape(-1, 1, in2)

        # First contraction: fold (out, in2) into one axis and multiply by x1,
        # then bring in2 forward for the second contraction against x2.
        w = self.weight.reshape(out * in2, in1)
        y = (x1 @ w.T).reshape(-1, out, in2).swapaxes(-2, -1)
        y = (x2 @ y).squeeze(1)

        # Restore the original leading dimensions.
        y = y.reshape(*lead_shape, out)

        # The bias parameter only exists when built with bias=True.
        if "bias" in self:
            y = y + self.bias

        return y
|
lib/python3.11/site-packages/mlx/nn/layers/normalization.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
from typing import Tuple
|
4 |
+
|
5 |
+
import mlx.core as mx
|
6 |
+
from mlx.nn.layers.base import Module
|
7 |
+
|
8 |
+
|
9 |
+
class InstanceNorm(Module):
    r"""Applies instance normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,
    if :attr:`affine` is ``True``.

    Args:
        dims (int): The number of features of the input.
        eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
        affine (bool): Default: ``False``.

    Shape:
        - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
        - Output: Same shape as the input.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> x = mx.random.normal((8, 4, 4, 16))
        >>> inorm = nn.InstanceNorm(dims=16)
        >>> output = inorm(x)

    References:
        [1]: https://arxiv.org/abs/1607.08022
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = False):
        super().__init__()
        # The affine parameters only exist when requested; their presence is
        # later detected with `"weight" in self`.
        if affine:
            self.weight = mx.ones((dims,))
            self.bias = mx.zeros((dims,))
        self.dims = dims
        self.eps = eps

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x: mx.array) -> mx.array:
        # Normalize over every axis except the batch (first) and feature
        # (last) axes.
        spatial_axes = tuple(range(1, x.ndim - 1))
        mean = mx.mean(x, axis=spatial_axes, keepdims=True)
        var = mx.var(x, axis=spatial_axes, keepdims=True)
        normed = (x - mean) * mx.rsqrt(var + self.eps)
        # Scale and shift only when the affine parameters exist.
        if "weight" in self:
            return self.weight * normed + self.bias
        return normed
|
67 |
+
|
68 |
+
|
69 |
+
class LayerNorm(Module):
    r"""Applies layer normalization [1] on the inputs.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    [1]: https://arxiv.org/abs/1607.06450

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization
    """

    def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
        super().__init__()
        # Affine parameters exist only when requested; their presence is
        # detected later with `"weight" in self`.
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.eps = eps
        self.dims = dims

    def _extra_repr(self):
        return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"

    def __call__(self, x):
        # Normalize over the last (feature) axis only. Note eps is added to
        # the variance *inside* the root, matching the docstring formula.
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
|
106 |
+
|
107 |
+
|
108 |
+
class RMSNorm(Module):
    r"""Applies Root Mean Square normalization [1] to the inputs.

    Computes

    .. math::

        y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma

    where :math:`\gamma` is a learned per feature dimension parameter initialized at
    1.

    [1]: https://arxiv.org/abs/1910.07467

    Args:
        dims (int): The feature dimension of the input to normalize over
        eps (float): A small additive constant for numerical stability
    """

    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _extra_repr(self):
        return f"{self.weight.shape[0]}, eps={self.eps}"

    def __call__(self, x):
        # Pre-scale by 1/sqrt(N) (N = feature size) before squaring and
        # summing; this computes the mean of squares in a way that prefers
        # underflow over overflow, and underflow is already handled by the
        # epsilon term.
        inv_sqrt_n = 1 / x.shape[-1] ** 0.5
        mean_sq = (x * inv_sqrt_n).square().sum(axis=-1, keepdims=True)
        return self.weight * x * mx.rsqrt(mean_sq + self.eps)
|
148 |
+
|
149 |
+
|
150 |
+
class GroupNorm(Module):
    r"""Applies Group Normalization [1] to the inputs.

    Computes the same normalization as layer norm, namely

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively. However, the mean and
    variance are computed over the spatial dimensions and each group of
    features. In particular, the input is split into num_groups across the
    feature dimension.

    The feature dimension is assumed to be the last dimension and the dimensions
    that precede it (except the first) are considered the spatial dimensions.

    [1]: https://arxiv.org/abs/1803.08494

    Args:
        num_groups (int): Number of groups to separate the features into
        dims (int): The feature dimensions of the input to normalize over
        eps (float): A small additive constant for numerical stability
        affine (bool): If True learn an affine transform to apply after the
            normalization.
        pytorch_compatible (bool): If True perform the group normalization in
            the same order/grouping as PyTorch.
    """

    def __init__(
        self,
        num_groups: int,
        dims: int,
        eps: float = 1e-5,
        affine: bool = True,
        pytorch_compatible: bool = False,
    ):
        super().__init__()
        # Affine parameters exist only when requested.
        if affine:
            self.bias = mx.zeros((dims,))
            self.weight = mx.ones((dims,))
        self.num_groups = num_groups
        self.dims = dims
        self.eps = eps
        self.pytorch_compatible = pytorch_compatible

    def _extra_repr(self):
        return (
            f"{self.num_groups}, {self.dims}, eps={self.eps}, "
            f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
        )

    def _pytorch_compatible_group_norm(self, x):
        # Groups are contiguous blocks of features, as in PyTorch.
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups: move the group axis last so axis 1 gathers all
        # spatial positions and in-group features together.
        x = x.reshape(batch, -1, num_groups, dims // num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)

        # Normalize each group independently.
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)

        # Undo the grouping to restore the original layout.
        x = x.reshape(batch, -1, dims // num_groups, num_groups)
        x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)

        return x

    def _group_norm(self, x):
        # Groups are strided across the feature axis (every num_groups-th
        # feature belongs to the same group).
        num_groups = self.num_groups
        batch, *rest, dims = x.shape

        # Split into groups
        x = x.reshape(batch, -1, num_groups)

        # Normalize
        means = mx.mean(x, axis=1, keepdims=True)
        var = mx.var(x, axis=1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + self.eps)
        x = x.reshape(batch, *rest, dims)

        return x

    def __call__(self, x):
        group_norm = (
            self._pytorch_compatible_group_norm
            if self.pytorch_compatible
            else self._group_norm
        )
        x = group_norm(x)
        # Affine parameters exist only when constructed with affine=True.
        return (self.weight * x + self.bias) if "weight" in self else x
|
243 |
+
|
244 |
+
|
245 |
+
class BatchNorm(Module):
    r"""Applies Batch Normalization over a 2D or 3D input.

    Computes

    .. math::

        y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,

    where :math:`\gamma` and :math:`\beta` are learned per feature dimension
    parameters initialized at 1 and 0 respectively.

    The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
    batch, ``C`` is the number of features or channels, and ``L`` is the
    sequence length. The output has the same shape as the input. For
    four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
    the height and width respectively.

    For more information on Batch Normalization, see the original paper `Batch
    Normalization: Accelerating Deep Network Training by Reducing Internal
    Covariate Shift <https://arxiv.org/abs/1502.03167>`_.

    Args:
        num_features (int): The feature dimension to normalize over.
        eps (float, optional): A small additive constant for numerical
            stability. Default: ``1e-5``.
        momentum (float, optional): The momentum for updating the running
            mean and variance. Default: ``0.1``.
        affine (bool, optional): If ``True``, apply a learned affine
            transformation after the normalization. Default: ``True``.
        track_running_stats (bool, optional): If ``True``, track the
            running mean and variance. Default: ``True``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> x = mx.random.normal((5, 4))
        >>> bn = nn.BatchNorm(num_features=4, affine=True)
        >>> output = bn(x)
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-5,
        momentum: float = 0.1,
        affine: bool = True,
        track_running_stats: bool = True,
    ):
        super().__init__()

        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.track_running_stats = track_running_stats

        if affine:
            self.weight = mx.ones((num_features,))
            self.bias = mx.zeros((num_features,))

        if self.track_running_stats:
            # The running statistics are buffers, not trainable parameters,
            # so freeze exactly these keys (non-recursively) right away.
            self.running_mean = mx.zeros((num_features,))
            self.running_var = mx.ones((num_features,))
            self.freeze(keys=["running_mean", "running_var"], recurse=False)

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze to make sure that running_mean and var are always
        frozen parameters."""
        super().unfreeze(*args, **kwargs)
        self.freeze(keys=["running_mean", "running_var"], recurse=False)

    def _extra_repr(self):
        return (
            f"{self.num_features}, eps={self.eps}, "
            f"momentum={self.momentum}, affine={'weight' in self}, "
            f"track_running_stats={self.track_running_stats}"
        )

    def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
        """
        Calculate the mean and variance of the input tensor across the batch
        and spatial dimensions.

        Args:
            x (array): Input tensor.

        Returns:
            tuple: Tuple containing mean and variance.
        """
        # Reduce over every axis except the trailing feature axis.
        reduction_axes = tuple(range(0, x.ndim - 1))

        mean = mx.mean(x, axis=reduction_axes, keepdims=True)
        var = mx.var(x, axis=reduction_axes, keepdims=True)

        return mean, var

    def __call__(self, x: mx.array) -> mx.array:
        """
        Forward pass of BatchNorm.

        Args:
            x (array): Input tensor.

        Returns:
            array: Normalized output tensor.

        Raises:
            ValueError: If ``x`` does not have 2, 3 or 4 dimensions.
        """
        if x.ndim < 2 or x.ndim > 4:
            raise ValueError(
                f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
            )

        # Calculate the mean and variance used to normalize the input x. If we
        # are in training mode update the running stats if needed.
        mean, var = self._calc_stats(x)
        if self.training and self.track_running_stats:
            # Exponential moving average of the batch statistics.
            mu = self.momentum
            self.running_mean = (1 - mu) * self.running_mean + mu * mean
            self.running_var = (1 - mu) * self.running_var + mu * var
        elif self.track_running_stats:
            # Evaluation mode: normalize with the accumulated running stats.
            mean = self.running_mean
            var = self.running_var

        x = (x - mean) * mx.rsqrt(var + self.eps)
        return (self.weight * x + self.bias) if "weight" in self else x
|
lib/python3.11/site-packages/mlx/nn/layers/positional_encoding.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Optional
|
5 |
+
|
6 |
+
import mlx.core as mx
|
7 |
+
from mlx.nn.layers.base import Module
|
8 |
+
|
9 |
+
|
10 |
+
class RoPE(Module):
    """Implements the rotary positional encoding.

    The traditional implementation rotates consecutive pairs of elements in the
    feature dimension while the default implementation rotates pairs with
    stride half the feature dimensions for efficiency.

    For more details see `RoFormer: Enhanced Transformer with Rotary Position
    Embedding <https://arxiv.org/abs/2104.09864>`_.

    Args:
        dims (int): The feature dimensions to be rotated. If the input feature
            is larger than dims then the rest is left unchanged.
        traditional (bool, optional): If set to True choose the traditional
            implementation which is slightly less efficient. Default: ``False``.
        base (float, optional): The base used to compute angular frequency for
            each dimension in the positional encodings. Default: ``10000``.
        scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
    """

    def __init__(
        self,
        dims: int,
        traditional: bool = False,
        base: float = 10000,
        scale: float = 1.0,
    ):
        super().__init__()
        self.dims = dims
        self.traditional = traditional
        self.base = base
        self.scale = scale

    def _extra_repr(self):
        return f"{self.dims}, traditional={self.traditional}"

    def _compute_rope(self, costheta, sintheta, x):
        # Split the rotated features into two halves with stride dims // 2.
        half = self.dims // 2
        a = x[..., :half]
        b = x[..., half : self.dims]
        rotated_a = a * costheta - b * sintheta
        rotated_b = a * sintheta + b * costheta

        # Features beyond self.dims pass through unrotated.
        if self.dims < x.shape[-1]:
            return mx.concatenate([rotated_a, rotated_b, x[..., self.dims :]], axis=-1)
        return mx.concatenate([rotated_a, rotated_b], axis=-1)

    def _compute_traditional_rope(self, costheta, sintheta, x):
        # Rotate consecutive (even, odd) feature pairs.
        even = x[..., ::2]
        odd = x[..., 1::2]
        rotated_even = even * costheta - odd * sintheta
        rotated_odd = even * sintheta + odd * costheta

        if self.dims < x.shape[-1]:
            raise NotImplementedError(
                "RoPE doesn't implement partial traditional application"
            )

        # Interleave the rotated pairs back along a trailing axis; the caller
        # reshapes the result back to the input's shape.
        return mx.concatenate(
            [rotated_even[..., None], rotated_odd[..., None]], axis=-1
        )

    def __call__(self, x, offset: int = 0):
        shape = x.shape
        # Collapse all leading dims to one batch axis: (B, T, D).
        x = mx.reshape(x, (-1, shape[-2], shape[-1]))
        N = x.shape[1] + offset
        costheta, sintheta = RoPE.create_cos_sin_theta(
            N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
        )

        if self.traditional:
            rx = self._compute_traditional_rope(costheta, sintheta, x)
        else:
            rx = self._compute_rope(costheta, sintheta, x)

        return mx.reshape(rx, shape)

    @staticmethod
    def create_cos_sin_theta(
        N: int,
        D: int,
        offset: int = 0,
        base: float = 10000,
        scale: float = 1.0,
        dtype=mx.float32,
    ):
        # Angles are the outer product of scaled positions [offset, N) with
        # the geometric frequency ladder base^(-2i/D) for i in [0, D // 2).
        half_D = D // 2
        positions = mx.arange(offset, N, dtype=dtype) * scale
        freqs = mx.exp(-mx.arange(0.0, half_D, dtype=dtype) * (math.log(base) / half_D))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        return mx.cos(theta), mx.sin(theta)
|
103 |
+
|
104 |
+
|
105 |
+
class SinusoidalPositionalEncoding(Module):
    r"""Implements sinusoidal positional encoding.

    For more details see the paper `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_.

    Args:
        dims (int): The dimensionality of the resulting positional embeddings.
        min_freq (float, optional): The minimum frequency expected. Default:
            ``0.0001``.
        max_freq (float, optional): The maximum frequency expected. Default:
            ``1``.
        scale (float, optional): A multiplicative scale for the embeddings.
            Default: ``sqrt(dims//2)``.
        cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
            instead of the reverse. Default: ``False``.
        full_turns (bool, optional): If ``True`` multiply the frequencies with
            :math:`2\pi`. Default: ``False``.
    """

    def __init__(
        self,
        dims: int,
        min_freq: float = 0.0001,
        max_freq: float = 1,
        scale: Optional[float] = None,
        cos_first: bool = False,
        full_turns: bool = False,
    ):
        super().__init__()

        # Interpolate the dims // 2 frequencies geometrically between
        # max_freq and min_freq by working in log space.
        half = dims // 2
        one_zero = 1 - mx.arange(0, half) / (half - 1)
        log_min = math.log(min_freq)
        log_max = math.log(max_freq)

        # The leading underscore keeps _sigmas out of the trainable parameters.
        self._sigmas = mx.exp(one_zero * (log_max - log_min) + log_min)
        if full_turns:
            self._sigmas = self._sigmas * (2 * math.pi)

        # Constants that control the output layout and magnitude.
        self.scale = scale or (2 / dims) ** 0.5
        self.cos_first = cos_first

    def __call__(self, x):
        # Outer product of input positions with the frequency ladder.
        y = x[..., None] * self._sigmas
        if self.cos_first:
            first, second = mx.cos(y), mx.sin(y)
        else:
            first, second = mx.sin(y), mx.cos(y)
        y = mx.concatenate([first, second], axis=-1)

        # Skip the multiply when the scale is the identity.
        return y * self.scale if self.scale != 1 else y
|
163 |
+
|
164 |
+
|
165 |
+
class ALiBi(Module):
    """Adds Attention-with-Linear-Biases (ALiBi) terms to attention scores."""

    @staticmethod
    def create_alibi_matrix(
        q_sequence_length: int,
        k_sequence_length: int,
        num_heads: int,
        offset: int,
        dtype=mx.float32,
    ):
        # Negative absolute distance between query and key positions, expanded
        # to (1, 1, Lq, Lk) so it broadcasts over batch and head axes.
        q_pos = mx.arange(offset, q_sequence_length)
        k_pos = mx.arange(0, k_sequence_length)
        distance_matrix = -mx.abs(
            mx.expand_dims(q_pos[:, None] - k_pos[None, :], axis=(0, 1))
        )
        slope = ALiBi.create_alibi_slope(num_heads=num_heads)
        return (distance_matrix * slope).astype(dtype)

    @staticmethod
    def create_alibi_slope(num_heads):
        # Per-head slopes form a geometric sequence ratio^-1, ratio^-2, ...
        # with ratio = (2^8)^(1/num_heads), shaped (num_heads, 1, 1).
        ratio = (2**8) ** (1 / num_heads)
        slopes = mx.power(ratio, -mx.arange(1, num_heads + 1))
        return mx.expand_dims(slopes, axis=(-1, -2))

    def __call__(self, attention_scores, offset=0, mask=None):
        bias = ALiBi.create_alibi_matrix(
            q_sequence_length=attention_scores.shape[-2] + offset,
            k_sequence_length=attention_scores.shape[-1],
            num_heads=attention_scores.shape[1],
            offset=offset,
            dtype=attention_scores.dtype,
        )
        # Fold any caller-supplied mask into the additive bias.
        if mask is not None:
            bias = bias + mask
        return attention_scores + bias
|
lib/python3.11/site-packages/mlx/nn/layers/quantized.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import mlx.core as mx
|
6 |
+
from mlx.nn.layers.base import Module
|
7 |
+
from mlx.nn.layers.linear import Linear
|
8 |
+
from mlx.utils import tree_flatten, tree_map
|
9 |
+
|
10 |
+
|
11 |
+
class QuantizedLinear(Module):
    """Applies an affine transformation to the input using a quantized weight matrix.

    It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its
    parameters are frozen and will not be included in any gradient computation
    but this will probably change in the future.

    QuantizedLinear also provides two useful classmethods to convert linear
    layers to QuantizedLinear layers.

    - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
      linear transformation up to the quantization error.
    - :meth:`quantize_module` swaps all the linear layers of the passed module
      with QuantizedLinear ones.

    Args:
        input_dims (int): The dimensionality of the input features
        output_dims (int): The dimensionality of the output features
        bias (bool, optional): If set to ``False`` then the layer will not use
            a bias. (default: True).
        group_size (int, optional): The group size to use for the quantized
            weight. See :func:`~mlx.core.quantize`. (default: 64)
        bits (int, optional): The bit width to use for the quantized weight.
            See :func:`~mlx.core.quantize`. (default: 4)
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        bias: bool = True,
        group_size: int = 64,
        bits: int = 4,
    ):
        super().__init__()

        # Quantization config
        self.group_size = group_size
        self.bits = bits

        # Initialize a float weight then immediately quantize it; only the
        # quantized representation (weight/scales/biases) is stored.
        scale = math.sqrt(1 / input_dims)
        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)

        # And bias if needed
        if bias:
            self.bias = mx.zeros((output_dims,))

        # Freeze this model's parameters
        self.freeze()

    def unfreeze(self, *args, **kwargs):
        """Wrap unfreeze so that we unfreeze any layers we might contain but
        our parameters will remain frozen."""
        super().unfreeze(*args, **kwargs)
        self.freeze(recurse=False)

    def _extra_repr(self):
        out_dims, in_dims = self.weight.shape
        # The stored weight packs (32 // bits) quantized values per uint32,
        # so recover the logical input dimension for display.
        in_dims *= 32 // self.bits
        # Fixed: a space was missing after the comma, which produced
        # "bias=True,group_size=..." in the repr.
        return (
            f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, "
            f"group_size={self.group_size}, bits={self.bits}"
        )

    def __call__(self, x):
        x = mx.quantized_matmul(
            x,
            self.weight,
            scales=self.scales,
            biases=self.biases,
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
        )
        if "bias" in self:
            x = x + self.bias
        return x

    @classmethod
    def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
        """Create a QuantizedLinear layer from the parameters of a provided
        linear layer."""
        output_dims, input_dims = linear_layer.weight.shape
        # Construct without a bias; copy the source layer's bias below if any.
        ql = cls(input_dims, output_dims, False, group_size, bits)
        ql.weight, ql.scales, ql.biases = mx.quantize(
            linear_layer.weight, group_size, bits
        )
        if "bias" in linear_layer:
            ql.bias = linear_layer.bias

        return ql

    @classmethod
    def quantize_module(
        cls,
        model: Module,
        group_size: int = 64,
        bits: int = 4,
        linear_class_predicate=lambda m: isinstance(m, Linear),
    ):
        """Replace, in place, every leaf module matching
        ``linear_class_predicate`` with a :class:`QuantizedLinear` equivalent."""

        def _quantize_if_linear(m):
            if linear_class_predicate(m):
                return cls.from_linear(m, group_size, bits)
            else:
                return m

        leaves = model.leaf_modules()
        leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
        model.update_modules(leaves)
|
lib/python3.11/site-packages/mlx/nn/layers/transformer.py
ADDED
@@ -0,0 +1,354 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import math
|
4 |
+
from typing import Any, Callable, Optional
|
5 |
+
|
6 |
+
import mlx.core as mx
|
7 |
+
from mlx.nn.layers.activations import relu
|
8 |
+
from mlx.nn.layers.base import Module
|
9 |
+
from mlx.nn.layers.dropout import Dropout
|
10 |
+
from mlx.nn.layers.linear import Linear
|
11 |
+
from mlx.nn.layers.normalization import LayerNorm
|
12 |
+
|
13 |
+
|
14 |
+
class MultiHeadAttention(Module):
    """Implements the scaled dot product attention with multiple heads.

    Given inputs for queries, keys and values the ``MultiHeadAttention``
    produces new values by aggregating information from the input values
    according to the similarities of the input queries and keys.

    All inputs as well as the output are linearly projected without biases by
    default.

    ``MultiHeadAttention`` also takes an optional additive attention mask that
    should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The
    mask should have ``-inf`` or very large negative numbers at the positions
    that should *not* be attended to.

    Args:
        dims (int): The model dimensions. This is also the default
            value for the queries, keys, values, and the output.
        num_heads (int): The number of attention heads to use.
        query_input_dims (int, optional): The input dimensions of the queries.
            Default: ``dims``.
        key_input_dims (int, optional): The input dimensions of the keys.
            Default: ``dims``.
        value_input_dims (int, optional): The input dimensions of the values.
            Default: ``key_input_dims``.
        value_dims (int, optional): The dimensions of the values after the
            projection. Default: ``dims``.
        value_output_dims (int, optional): The dimensions the new values will
            be projected to. Default: ``dims``.
        bias (bool, optional): Whether or not to use a bias in the projections.
            Default: ``False``.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        query_input_dims: Optional[int] = None,
        key_input_dims: Optional[int] = None,
        value_input_dims: Optional[int] = None,
        value_dims: Optional[int] = None,
        value_output_dims: Optional[int] = None,
        bias: bool = False,
    ):
        super().__init__()

        # Heads split the feature dimension evenly, so it must divide exactly.
        if (dims % num_heads) != 0:
            raise ValueError(
                "The input feature dimensions should be divisible by the "
                f"number of heads ({dims} % {num_heads}) != 0"
            )

        # Unspecified dimensions fall back to the model dimension (values
        # default to the key input dimension, matching common usage).
        query_input_dims = query_input_dims or dims
        key_input_dims = key_input_dims or dims
        value_input_dims = value_input_dims or key_input_dims
        value_dims = value_dims or dims
        value_output_dims = value_output_dims or dims

        self.num_heads = num_heads
        self.query_proj = Linear(query_input_dims, dims, bias=bias)
        self.key_proj = Linear(key_input_dims, dims, bias=bias)
        self.value_proj = Linear(value_input_dims, value_dims, bias=bias)
        self.out_proj = Linear(value_dims, value_output_dims, bias=bias)

    def __call__(self, queries, keys, values, mask=None):
        """Compute attention over ``values`` and project the result.

        Inputs are (batch, sequence, features); the optional ``mask`` is
        added to the scores before the softmax.
        """
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, D = queries.shape
        _, S, _ = keys.shape
        # Split features into heads. Keys are transposed to
        # (B, heads, head_dim, S) so that queries @ keys yields scores.
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

        # Dimensions are [batch x num heads x sequence x hidden dim]
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys
        if mask is not None:
            scores = scores + mask.astype(scores.dtype)
        scores = mx.softmax(scores, axis=-1)
        # Merge the heads back into a single feature dimension.
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        return self.out_proj(values_hat)

    @staticmethod
    def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
        """Return an (N, N) additive mask that blocks attention to future
        positions (strictly upper-triangular entries get a large negative)."""
        indices = mx.arange(N)
        mask = indices[:, None] < indices[None]
        # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
        # TODO: Should replace this with finfo(dtype).min
        mask = mask.astype(dtype) * -1e9
        return mask
|
108 |
+
|
109 |
+
|
110 |
+
class TransformerEncoderLayer(Module):
    """A single Transformer encoder layer: self-attention followed by a
    two-layer MLP, each wrapped in a residual connection.

    Args:
        dims (int): Feature dimension of the inputs and outputs.
        num_heads (int): Number of attention heads.
        mlp_dims (int, optional): Hidden dimension of the MLP block.
            Defaults to ``4 * dims``.
        dropout (float, optional): Dropout applied after attention and
            after the MLP activation. Default: ``0.0``.
        activation (callable, optional): MLP activation. Default: ``relu``.
        norm_first (bool, optional): If ``True`` use pre-norm (normalize
            before each sublayer), otherwise post-norm. Default: ``False``.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation: Callable[[Any], Any] = relu,
        norm_first: bool = False,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.activation = activation
        self.norm_first = norm_first

    def __call__(self, x, mask):
        if self.norm_first:
            # Pre-norm: normalize, run the sublayer, then add the residual.
            y = self.ln1(x)
            y = self.attention(y, y, y, mask)
            y = self.dropout1(y)
            x = x + y

            y = self.ln2(x)
            y = self.linear1(y)
            y = self.activation(y)
            y = self.dropout2(y)
            y = self.linear2(y)
            y = x + y

        else:
            # Post-norm: run the sublayer, add the residual, then normalize.
            y = self.attention(x, x, x, mask)
            y = self.dropout1(y)
            y = self.ln1(x + y)

            y = self.linear1(y)
            y = self.activation(y)
            y = self.dropout2(y)
            y = self.linear2(y)
            # NOTE(review): the MLP residual here reuses the original input
            # ``x`` rather than the post-attention activations — confirm this
            # matches the intended post-norm formulation.
            y = self.ln2(x + y)

        return y
|
158 |
+
|
159 |
+
|
160 |
+
class TransformerEncoder(Module):
    """A stack of ``num_layers`` :class:`TransformerEncoderLayer` modules
    followed by a final :class:`LayerNorm`."""

    def __init__(
        self,
        num_layers: int,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation=relu,
        norm_first: bool = False,
    ):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(
                TransformerEncoderLayer(
                    dims, num_heads, mlp_dims, dropout, activation, norm_first
                )
            )
        self.layers = stack
        self.ln = LayerNorm(dims)

    def __call__(self, x, mask):
        # Thread the input through every encoder layer, then normalize.
        for layer in self.layers:
            x = layer(x, mask)
        return self.ln(x)
|
184 |
+
|
185 |
+
|
186 |
+
class TransformerDecoderLayer(Module):
    """A single Transformer decoder layer: masked self-attention, cross
    attention over the encoder ``memory``, and a two-layer MLP, each with a
    residual connection.

    Args:
        dims (int): Feature dimension of the inputs and outputs.
        num_heads (int): Number of attention heads.
        mlp_dims (int, optional): Hidden dimension of the MLP block.
            Defaults to ``4 * dims``.
        dropout (float, optional): Dropout applied after each sublayer.
            Default: ``0.0``.
        activation (callable, optional): MLP activation. Default: ``relu``.
        norm_first (bool, optional): If ``True`` use pre-norm, otherwise
            post-norm. Default: ``False``.
    """

    def __init__(
        self,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation: Callable[[Any], Any] = relu,
        norm_first: bool = False,
    ):
        super().__init__()
        mlp_dims = mlp_dims or dims * 4
        self.self_attention = MultiHeadAttention(dims, num_heads)
        self.cross_attention = MultiHeadAttention(dims, num_heads)
        self.ln1 = LayerNorm(dims)
        self.ln2 = LayerNorm(dims)
        self.ln3 = LayerNorm(dims)
        self.linear1 = Linear(dims, mlp_dims)
        self.linear2 = Linear(mlp_dims, dims)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.dropout3 = Dropout(dropout)
        self.activation = activation
        self.norm_first = norm_first

    def __call__(self, x, memory, x_mask, memory_mask):
        if self.norm_first:
            # Pre-norm: normalize, run the sublayer, then add the residual.
            y = self.ln1(x)
            y = self.self_attention(y, y, y, x_mask)
            y = self.dropout1(y)
            x = x + y

            y = self.ln2(x)
            y = self.cross_attention(y, memory, memory, memory_mask)
            y = self.dropout2(y)
            x = x + y

            y = self.ln3(x)
            y = self.linear1(y)
            y = self.activation(y)
            y = self.dropout3(y)
            y = self.linear2(y)
            y = x + y

        else:
            # Post-norm: run the sublayer, add the residual, then normalize.
            y = self.self_attention(x, x, x, x_mask)
            y = self.dropout1(y)
            x = self.ln1(x + y)

            # Fixed: cross-attention must query with the normalized residual
            # stream ``x``, not the raw self-attention output ``y``, and the
            # second residual is normalized with ``ln2`` (``ln1`` was
            # mistakenly applied twice, leaving ``ln2`` unused).
            y = self.cross_attention(x, memory, memory, memory_mask)
            y = self.dropout2(y)
            x = self.ln2(x + y)

            y = self.linear1(x)
            y = self.activation(y)
            y = self.dropout3(y)
            y = self.linear2(y)
            y = self.ln3(x + y)

        return y
|
246 |
+
|
247 |
+
|
248 |
+
class TransformerDecoder(Module):
    """A stack of ``num_layers`` :class:`TransformerDecoderLayer` modules
    followed by a final :class:`LayerNorm`."""

    def __init__(
        self,
        num_layers: int,
        dims: int,
        num_heads: int,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation=relu,
        norm_first: bool = False,
    ):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(
                TransformerDecoderLayer(
                    dims, num_heads, mlp_dims, dropout, activation, norm_first
                )
            )
        self.layers = stack
        self.ln = LayerNorm(dims)

    def __call__(self, x, memory, x_mask, memory_mask):
        # Thread the target stream through every decoder layer, attending to
        # the encoder ``memory`` at each step, then normalize.
        for layer in self.layers:
            x = layer(x, memory, x_mask, memory_mask)
        return self.ln(x)
|
272 |
+
|
273 |
+
|
274 |
+
class Transformer(Module):
    """
    Implements a standard Transformer model.

    The implementation is based on `Attention Is All You Need
    <https://arxiv.org/abs/1706.03762>`_.

    The Transformer model contains an encoder and a decoder. The encoder
    processes the input sequence and the decoder generates the output sequence.
    The interaction between encoder and decoder happens through the attention
    mechanism.

    Args:
        dims (int, optional): The number of expected features in the
            encoder/decoder inputs. Default: ``512``.
        num_heads (int, optional): The number of attention heads. Default:
            ``8``.
        num_encoder_layers (int, optional): The number of encoder layers in the
            Transformer encoder. Default: ``6``.
        num_decoder_layers (int, optional): The number of decoder layers in the
            Transformer decoder. Default: ``6``.
        mlp_dims (int, optional): The hidden dimension of the MLP block in each
            Transformer layer. Defaults to ``4*dims`` if not provided. Default:
            ``None``.
        dropout (float, optional): The dropout value for the Transformer
            encoder and decoder. Dropout is used after each attention layer and
            the activation in the MLP layer. Default: ``0.0``.
        activation (function, optional): the activation function for the MLP
            hidden layer. Default: :func:`mlx.nn.relu`.
        custom_encoder (nn.Module, optional): A custom encoder to replace the
            standard Transformer encoder. Default: ``None``.
        custom_decoder (nn.Module, optional): A custom decoder to replace the
            standard Transformer decoder. Default: ``None``.
        norm_first (bool, optional): if ``True``, encoder and decoder layers
            will perform layer normalization before attention and MLP
            operations, otherwise after. Default: ``False``.
    """

    def __init__(
        self,
        dims: int = 512,
        num_heads: int = 8,
        num_encoder_layers: int = 6,
        num_decoder_layers: int = 6,
        mlp_dims: Optional[int] = None,
        dropout: float = 0.0,
        activation: Callable[[Any], Any] = relu,
        custom_encoder: Optional[Any] = None,
        custom_decoder: Optional[Any] = None,
        norm_first: bool = False,
    ):
        super().__init__()
        # A user-provided encoder/decoder takes precedence over the standard
        # stacks; it must accept the same call signature.
        if custom_encoder is not None:
            self.encoder = custom_encoder
        else:
            self.encoder = TransformerEncoder(
                num_encoder_layers,
                dims,
                num_heads,
                mlp_dims,
                dropout,
                activation,
                norm_first,
            )

        if custom_decoder is not None:
            self.decoder = custom_decoder
        else:
            self.decoder = TransformerDecoder(
                num_decoder_layers,
                dims,
                num_heads,
                mlp_dims,
                dropout,
                activation,
                norm_first,
            )

    def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
        """Encode ``src`` then decode ``tgt`` against the encoded memory."""
        memory = self.encoder(src, src_mask)
        return self.decoder(tgt, memory, tgt_mask, memory_mask)
|
lib/python3.11/site-packages/mlx/nn/losses.py
ADDED
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
import math
|
4 |
+
|
5 |
+
import mlx.core as mx
|
6 |
+
from mlx.nn.layers.base import Module
|
7 |
+
|
8 |
+
|
9 |
+
def cross_entropy(
    logits: mx.array,
    targets: mx.array,
    weights: mx.array = None,
    axis: int = -1,
    label_smoothing: float = 0.0,
    reduction: str = "none",
) -> mx.array:
    """
    Computes the cross entropy loss.

    Args:
        logits (array): The unnormalized predicted logits.
        targets (array): The target values, as class indices.
        weights (array, optional): Weights for each target. Default: ``None``.
        axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
        label_smoothing (float, optional): Label smoothing factor. Default: ``0``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed cross entropy loss.
    """
    if label_smoothing < 0 or label_smoothing >= 1:
        raise ValueError(f"Label smoothing must in [0, 1), got {label_smoothing}.")

    # Logit of the true class for every example; combined with logsumexp this
    # gives -log softmax(logits)[target] without materializing the softmax.
    score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
    logsumexp_logits = mx.logsumexp(logits, axis=axis)
    if label_smoothing > 0:
        # Adjust the true class score with label smoothing
        adjusted_score = (1 - label_smoothing) * score

        # Calculate the mean logit across the classes for smoothed loss
        mean_logits = logits.mean(axis=axis)
        smoothed_loss = -mean_logits * label_smoothing

        # Combine the adjusted score and smoothed loss with the logsumexp logits
        # (equivalent to cross entropy against a target distribution that puts
        # (1 - smoothing) on the true class and spreads smoothing uniformly).
        loss = logsumexp_logits - adjusted_score + smoothed_loss
    else:
        loss = logsumexp_logits - score

    # Apply weights if provided
    if weights is not None:
        if weights.shape != targets.shape:
            raise ValueError(
                f"Weights with shape {weights.shape} is not the same as "
                f"targets with shape {targets.shape}."
            )
        loss *= weights

    # Apply reduction
    return _reduce(loss, reduction)
|
61 |
+
|
62 |
+
|
63 |
+
def binary_cross_entropy(
    logits: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
    """
    Computes the binary cross entropy loss.

    Args:
        logits (array): The unnormalized (pre-sigmoid) predicted logits.
        targets (array): The binary target values in {0, 1}.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed binary cross entropy loss.
    """
    # Numerically stable form of -t*log(sigmoid(z)) - (1-t)*log(1-sigmoid(z)):
    # log(1 + exp(z)) - t*z, computed via logaddexp.
    per_element = mx.logaddexp(0.0, logits) - targets * logits
    return _reduce(per_element, reduction)
|
88 |
+
|
89 |
+
|
90 |
+
def l1_loss(
    predictions: mx.array, targets: mx.array, reduction: str = "mean"
) -> mx.array:
    """
    Computes the L1 loss.

    Args:
        predictions (array): The predicted values.
        targets (array): The target values.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.

    Returns:
        array: The computed L1 loss.
    """
    if predictions.shape != targets.shape:
        raise ValueError(
            f"Predictions shape {predictions.shape} does not match "
            f"targets shape {targets.shape}."
        )

    absolute_errors = mx.abs(predictions - targets)
    return _reduce(absolute_errors, reduction)
|
113 |
+
|
114 |
+
|
115 |
+
def mse_loss(
    predictions: mx.array, targets: mx.array, reduction: str = "mean"
) -> mx.array:
    """
    Computes the mean squared error loss.

    Args:
        predictions (array): The predicted values.
        targets (array): The target values.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.

    Returns:
        array: The computed mean squared error loss.
    """
    if predictions.shape != targets.shape:
        raise ValueError(
            f"Predictions shape {predictions.shape} does not match "
            f"targets shape {targets.shape}."
        )

    squared_errors = mx.square(predictions - targets)
    return _reduce(squared_errors, reduction)
|
138 |
+
|
139 |
+
|
140 |
+
def nll_loss(
    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
) -> mx.array:
    """
    Computes the negative log likelihood loss.

    Args:
        inputs (array): The predicted distribution in log space.
        targets (array): The target values.
        axis (int, optional): The distribution axis. Default: ``-1``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed NLL loss.
    """
    # Pick out the log-probability of each target class, then negate.
    target_log_probs = mx.take_along_axis(inputs, targets[..., None], axis)
    loss = -target_log_probs.squeeze(-1)
    return _reduce(loss, reduction)
|
159 |
+
|
160 |
+
|
161 |
+
def kl_div_loss(
    inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
) -> mx.array:
    """
    Computes the Kullback-Leibler divergence loss.

    Computes the following when ``reduction == 'none'``:

    .. code-block:: python

        mx.exp(targets) * (targets - inputs).sum(axis)

    Args:
        inputs (array): Log probabilities for the predicted distribution.
        targets (array): Log probabilities for the target distribution.
        axis (int, optional): The distribution axis. Default: ``-1``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed Kullback-Leibler divergence loss.
    """
    # KL(target || input) with both distributions given in log space:
    # sum over the class axis of p * (log p - log q), where p = exp(targets).
    pointwise = mx.exp(targets) * (targets - inputs)
    divergence = mx.sum(pointwise, axis)
    return _reduce(divergence, reduction)
|
186 |
+
|
187 |
+
|
188 |
+
def smooth_l1_loss(
    predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
) -> mx.array:
    r"""
    Computes the smooth L1 loss.

    The smooth L1 loss is a variant of the L1 loss which replaces the absolute
    difference with a squared difference when the absolute difference is less
    than ``beta``.

    The formula for the smooth L1 Loss is:

    .. math::

       l =
          \begin{cases}
            0.5 (x - y)^2 / \beta, & \text{ if } & |x - y| < \beta \\
            |x - y| - 0.5 \beta, & & \text{otherwise}
          \end{cases}

    Args:
        predictions (array): Predicted values.
        targets (array): Ground truth values.
        beta (float, optional): The threshold after which the loss changes
            from the squared to the absolute difference. Default: ``1.0``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.

    Returns:
        array: The computed smooth L1 loss.
    """
    if predictions.shape != targets.shape:
        raise ValueError(
            f"Predictions shape {predictions.shape} does not match "
            f"targets shape {targets.shape}."
        )

    diff = predictions - targets
    # Fixed: the branch must be selected on |diff|, not the signed diff.
    # With the signed comparison, large *negative* errors satisfied
    # ``diff < beta`` and took the quadratic branch, producing huge losses
    # instead of the intended linear regime.
    loss = mx.where(
        mx.abs(diff) < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
    )

    return _reduce(loss, reduction)
|
231 |
+
|
232 |
+
|
233 |
+
def triplet_loss(
    anchors: mx.array,
    positives: mx.array,
    negatives: mx.array,
    axis: int = -1,
    p: int = 2,
    margin: float = 1.0,
    eps: float = 1e-6,
    reduction: str = "none",
) -> mx.array:
    r"""
    Computes the triplet loss for a set of anchor, positive, and negative samples.
    Margin is represented with alpha in the math section.

    .. math::

       L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)

    Args:
        anchors (array): The anchor samples.
        positives (array): The positive samples.
        negatives (array): The negative samples.
        axis (int, optional): The distribution axis. Default: ``-1``.
        p (int, optional): The norm degree for pairwise distance. Default: ``2``.
        margin (float, optional): Margin for the triplet loss. Defaults to ``1.0``.
        eps (float, optional): Small positive constant to prevent numerical instability. Defaults to ``1e-6``.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: Computed triplet loss. If reduction is "none", returns a tensor of the same shape as input;
        if reduction is "mean" or "sum", returns a scalar tensor.
    """
    # Distances are sqrt(sum((a - b)^p) + eps); eps keeps the sqrt gradient
    # finite at zero distance.
    # NOTE(review): the outer root is always a square root, so this is only a
    # true p-norm for p == 2 — confirm whether other p values are intended.
    loss = mx.maximum(
        mx.sqrt(mx.power(anchors - positives, p).sum(axis) + eps)
        - mx.sqrt(mx.power(anchors - negatives, p).sum(axis) + eps)
        + margin,
        0,
    )
    return _reduce(loss, reduction)
|
273 |
+
|
274 |
+
|
275 |
+
def _reduce(loss: mx.array, reduction: str = "none"):
    """Apply the requested reduction ('none', 'mean' or 'sum') to *loss*."""
    if reduction == "none":
        return loss
    if reduction == "sum":
        return mx.sum(loss)
    if reduction == "mean":
        return mx.mean(loss)
    raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
|
284 |
+
|
285 |
+
|
286 |
+
def hinge_loss(
    inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
    r"""
    Computes the hinge loss between inputs and targets.

    .. math::

       \text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})


    Args:
        inputs (array): The predicted values.
        targets (array): The target values. They should be -1 or 1.
        reduction (str, optional): Specifies the reduction to apply to the output:
            ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed hinge loss.
    """
    # Positive where the prediction is on the wrong side of the margin.
    margin_violation = 1 - inputs * targets
    clamped = mx.maximum(margin_violation, 0)
    return _reduce(clamped, reduction)
|
309 |
+
|
310 |
+
|
311 |
+
def huber_loss(
    inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
) -> mx.array:
    r"""Computes the Huber loss between inputs and targets.

    .. math::

        L_{\delta}(a) =
        \left\{ \begin{array}{ll}
            \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
            \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
        \end{array} \right.

    Args:
        inputs (array): The predicted values.
        targets (array): The target values.
        delta (float, optional): The threshold at which to change between the
          quadratic (L2) and linear (L1) regimes. Default: ``1.0``.
        reduction (str, optional): Specifies the reduction to apply to the output:
          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed Huber loss.
    """
    residuals = mx.abs(inputs - targets)
    # Split each |residual| into the part capped at delta (penalized
    # quadratically) and the overflow above delta (penalized linearly).
    capped = mx.minimum(residuals, delta)
    overflow = residuals - capped
    loss = 0.5 * capped**2 + delta * overflow
    return _reduce(loss, reduction)
|
343 |
+
|
344 |
+
|
345 |
+
def log_cosh_loss(
    inputs: mx.array, targets: mx.array, reduction: str = "none"
) -> mx.array:
    r"""Computes the log cosh loss between inputs and targets.

    Behaves like an L2 loss for small errors (stable gradients) and like an
    L1 loss for large errors (reduced sensitivity to outliers), which makes
    it a balanced, robust choice for regression tasks.

    .. math::

       \text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
            \frac{1}{n} \sum_{i=1}^{n}
            \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))


    Args:
        inputs (array): The predicted values.
        targets (array): The target values.
        reduction (str, optional): Specifies the reduction to apply to the output:
          ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.

    Returns:
        array: The computed log cosh loss.
    """
    diffs = inputs - targets
    # log(cosh(x)) == logaddexp(x, -x) - log(2); this form stays numerically
    # stable for large |x| where cosh itself would overflow.
    loss = mx.logaddexp(diffs, -diffs) - math.log(2)
    return _reduce(loss, reduction)
|
lib/python3.11/site-packages/mlx/nn/utils.py
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
from typing import Callable
|
4 |
+
|
5 |
+
import mlx.core as mx
|
6 |
+
|
7 |
+
|
8 |
+
def value_and_grad(model: "mlx.nn.Module", fn: Callable):
    """Transform the passed function ``fn`` to a function that computes the
    gradients of ``fn`` wrt the model's trainable parameters and also its
    value.

    Args:
        model (mlx.nn.Module): The model whose trainable parameters to compute
                               gradients for
        fn (Callable): The scalar function to compute gradients for

    Returns:
        A callable that returns the value of ``fn`` and the gradients wrt the
        trainable parameters of ``model``
    """

    def compute(params, *args, **kwargs):
        # Load `params` into the model so that `fn` evaluates against them;
        # this makes `fn` a function of `params` from mx's point of view.
        model.update(params)
        return fn(*args, **kwargs)

    raw_value_and_grad = mx.value_and_grad(compute)

    def wrapper(*args, **kwargs):
        # Differentiate with respect to the model's current trainable
        # parameters, captured at call time.
        return raw_value_and_grad(model.trainable_parameters(), *args, **kwargs)

    return wrapper
|
lib/python3.11/site-packages/mlx/optimizers.py
ADDED
@@ -0,0 +1,500 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
# Copyright © 2023 Apple Inc.

import math
from typing import List, Tuple

import mlx.core as mx
from mlx.utils import tree_map
|
8 |
+
|
9 |
+
|
10 |
+
class OptimizerState(dict):
    """The optimizer state implements a recursively defined
    :class:`collections.defaultdict`, namely a missing key in an optimizer
    state is an :class:`OptimizerState`.

    .. note::
       :meth:`OptimizerState.get` in contrast to a normal dictionary also sets
       the key to the ``default`` value if the ``key`` was not present in the
       dictionary.
    """

    def __getitem__(self, key):
        # Auto-vivify: a missing key materializes (and stores) a nested
        # OptimizerState, like defaultdict(OptimizerState).
        try:
            return super().__getitem__(key)
        except KeyError:
            child = OptimizerState()
            self[key] = child
            return child

    def get(self, key, default):
        """If ``key`` doesn't exist set its value to ``default`` and then return it."""
        # dict.setdefault has exactly the store-if-missing-then-return
        # semantics this method documents.
        return self.setdefault(key, default)
|
31 |
+
|
32 |
+
|
33 |
+
class Optimizer:
    """The base class for all optimizers. It allows us to implement an
    optimizer on a per-parameter basis and apply it to a parameter tree.

    Attributes:
        state (OptimizerState): It holds the optimizer's state dictionary.
    """

    def __init__(self):
        self.state = OptimizerState()

    def update(self, model: "mlx.nn.Module", gradients: dict):
        """Apply the gradients to the parameters of the model and update the
        model with the new parameters.

        Args:
            model (mlx.nn.Module): An mlx module to be updated.
            gradients (dict): A Python tree of gradients, most likely computed
                              via :func:`mlx.nn.value_and_grad`.
        """
        new_parameters = self.apply_gradients(gradients, model)
        model.update(new_parameters)

    def apply_gradients(self, gradients: dict, model: dict):
        """Apply the gradients to the parameters and return the updated parameters.

        Can be used to update a model via
        ``model.update(opt.apply_gradients(grads, model))`` which is precisely
        how :meth:`Optimizer.update` is implemented.

        Args:
            gradients (dict): A Python tree of gradients.
            model (dict): A Python tree of parameters. It can be a superset of
                          the gradients. In that case the returned python tree
                          will be of the same structure as the gradients.
        """
        # tree_map walks the gradient tree and pairs each leaf with the
        # matching parameter and per-parameter state.
        return tree_map(self.apply_single, gradients, model, self.state)

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """To be extended by the children classes to implement each optimizer's
        update."""
        raise NotImplementedError()
|
76 |
+
|
77 |
+
|
78 |
+
class SGD(Optimizer):
    r"""Stochastic gradient descent optimizer.

    Updates a parameter :math:`w` with a gradient :math:`g` as follows

    .. math::

        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
        w_{t+1} &= w_t - \lambda v_{t+1}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
        dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
        nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``
    """

    def __init__(
        self,
        learning_rate: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        dampening: float = 0.0,
        nesterov: bool = False,
    ):
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                "Nesterov momentum requires a momentum and zero dampening."
            )
        super().__init__()

        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the SGD parameter update and stores :math:`v` in the
        optimizer state.

        Fix: weight decay is applied before the momentum branch. Previously
        the ``momentum <= 0`` early return skipped the decay term entirely,
        so ``weight_decay`` was silently ignored for plain (momentum-less)
        SGD.
        """
        if self.weight_decay != 0:
            # L2 penalty: fold the decay term into the gradient. Rebind
            # rather than ``+=`` so the caller's gradient is not updated
            # in place.
            gradient = gradient + self.weight_decay * parameter

        if self.momentum <= 0:
            # Plain gradient descent; no momentum buffer is kept.
            return parameter - self.learning_rate * gradient

        v = self.momentum * state.get("v", mx.zeros_like(gradient))
        if self.dampening > 0:
            v = v + (1 - self.dampening) * gradient
        else:
            v = v + gradient

        if self.nesterov:
            update = gradient + self.momentum * v
        else:
            update = v
        state["v"] = v
        return parameter - self.learning_rate * update
|
141 |
+
|
142 |
+
|
143 |
+
class RMSprop(Optimizer):
    r"""Implementation of the RMSprop optimizer [1].

    [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning

    .. math::

        v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        alpha (float, optional): The smoothing constant :math:`\alpha`.
          Default: ``0.99``
        eps (float, optional): The term :math:`\epsilon` added to the denominator
          to improve numerical stability. Default: ``1e-8``
    """

    def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
        super().__init__()

        self.learning_rate = learning_rate
        self.alpha = alpha
        self.eps = eps

        # Validate the hyperparameters so misconfigurations fail fast.
        if self.alpha < 0.0:
            raise ValueError(
                f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
            )
        if self.eps < 0.0:
            raise ValueError(
                f"RMSprop epsilon should be >0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
        # Exponential moving average of the squared gradients.
        avg_sq = state.get("v", mx.zeros_like(gradient))
        avg_sq = self.alpha * avg_sq + (1 - self.alpha) * mx.square(gradient)
        state["v"] = avg_sq

        step = self.learning_rate * gradient / (mx.sqrt(avg_sq) + self.eps)
        return parameter - step
|
190 |
+
|
191 |
+
|
192 |
+
class Adagrad(Optimizer):
    r"""Implementation of the Adagrad optimizer [1].

    Our Adagrad implementation follows the original paper. In detail,

    [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
    for online learning and stochastic optimization. JMLR 2011.

    .. math::

        v_{t+1} &= v_t + g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(self, learning_rate: float, eps: float = 1e-8):
        super().__init__()

        self.learning_rate = learning_rate
        self.eps = eps

        # Fail fast on an invalid epsilon.
        if self.eps < 0.0:
            raise ValueError(
                f"Adagrad epsilon should be >0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adagrad parameter update and stores :math:`v` in the
        optimizer state."""
        # Monotonically growing sum of squared gradients; the effective step
        # size therefore only shrinks over time.
        accum = state.get("v", mx.zeros_like(gradient)) + mx.square(gradient)
        state["v"] = accum

        return parameter - self.learning_rate * gradient / (
            mx.sqrt(accum) + self.eps
        )
|
235 |
+
|
236 |
+
|
237 |
+
class AdaDelta(Optimizer):
    r"""Implementation of the AdaDelta optimizer with learning rate[1].

    Our AdaDelta implementation follows the original paper. In detail,

    [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.

    .. math::

        v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
        \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
        u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
        w_{t+1} &= w_t - \lambda \Delta w_{t+1}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        rho (float, optional): The coefficient :math:`\rho` used for computing a
            running average of squared gradients. Default: ``0.9``
        eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
          numerical stability. Default: `1e-8`
    """

    def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
        super().__init__()

        self.learning_rate = learning_rate
        self.rho = rho
        self.eps = eps
        if self.rho < 0.0:
            raise ValueError(
                f"AdaDelta rho should be >=0, {self.rho} was provided instead"
            )
        if self.eps < 0.0:
            raise ValueError(
                f"AdaDelta epsilon should be >0, {self.eps} was provided instead"
            )

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the AdaDelta parameter update and stores :math:`v` and
        :math:`u` in the optimizer state.

        Fix: the running average of squared updates :math:`u` is now read
        from state key ``"u"`` — the same key it is written to. Previously
        it was read from ``"s"``, so the value stored under ``"u"`` was
        never read back and :math:`u_t` stayed zero forever.
        """
        lr = self.learning_rate
        rho = self.rho
        eps = self.eps

        v = state.get("v", mx.zeros_like(gradient))
        u = state.get("u", mx.zeros_like(gradient))

        # Running average of squared gradients (denominator).
        v = rho * v + (1 - rho) * mx.square(gradient)
        # Parameter-space update, scaled by the ratio of accumulated
        # update and gradient magnitudes.
        d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
        # Running average of squared updates (numerator of the next step).
        u = rho * u + (1 - rho) * mx.square(d)

        state["v"] = v
        state["u"] = u

        return parameter - lr * d
|
294 |
+
|
295 |
+
|
296 |
+
class Adam(Optimizer):
    r"""Implementation of the Adam optimizer [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing running averages of the
          gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self,
        learning_rate: float,
        # A tuple default instead of the mutable-list anti-pattern; the
        # docstring already advertises Tuple[float, float]. Callers passing
        # a list still work since betas is only unpacked.
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.eps = eps

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adam parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        eps = self.eps

        # First step seeds the moments from the gradient itself rather than
        # zeros (no bias correction — see the class docstring).
        m = state.get("m", gradient)
        v = state.get("v", mx.square(gradient))
        m = b1 * m + (1 - b1) * gradient
        v = b2 * v + (1 - b2) * mx.square(gradient)
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (mx.sqrt(v) + eps)
|
346 |
+
|
347 |
+
|
348 |
+
class AdamW(Adam):
    r"""Implementation of the AdamW optimizer [1].

    Following the above convention, in contrast with [1], we do not use bias
    correction in the first and second moments for AdamW. We update the weights
    with a weight_decay (:math:`\lambda`) value:

    [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
    regularization. ICLR 2019.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
        w_{t+1} &= w_t - \alpha (\frac{m_{t+1}}{\sqrt{v_{t+1} + \epsilon}} + \lambda w_t)

    Args:
        learning_rate (float): The learning rate :math:`\alpha`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing running averages of the
          gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
        weight_decay (float, optional): The weight decay :math:`\lambda`.
          Default: ``0``.
    """

    def __init__(
        self,
        learning_rate: float,
        # Tuple default instead of a shared mutable list (anti-pattern);
        # behavior is unchanged since betas is only unpacked.
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.01,
    ):
        super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the AdamW parameter update by modifying the parameters
        passed into Adam.
        """
        # Decoupled weight decay: shrink the parameter first, then run the
        # plain Adam step on the decayed parameter.
        return super().apply_single(
            gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
        )
|
395 |
+
|
396 |
+
|
397 |
+
class Adamax(Adam):
    r"""Implementation of the Adamax optimizer. It is a variant of Adam based
    on the infinity norm [1].

    Our Adam implementation follows the original paper and omits the bias
    correction in the first and second moment estimates. In detail,

    [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
    optimization. ICLR 2015.

    .. math::

        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
        v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}

    Args:
        learning_rate (float): The learning rate :math:`\lambda`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing running averages of the
          gradient and its square. Default: ``(0.9, 0.999)``
        eps (float, optional): The term :math:`\epsilon` added to the
          denominator to improve numerical stability. Default: ``1e-8``
    """

    def __init__(
        self,
        learning_rate: float,
        # Tuple default instead of a shared mutable list (anti-pattern);
        # behavior is unchanged since betas is only unpacked.
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
    ):
        super().__init__(learning_rate, betas, eps)

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Adamax parameter update and stores :math:`v` and
        :math:`m` in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        eps = self.eps

        m = state.get("m", mx.zeros_like(gradient))
        v = state.get("v", mx.zeros_like(gradient))

        m = b1 * m + (1 - b1) * gradient
        # Infinity-norm accumulator: elementwise max of the decayed running
        # value and the current gradient magnitude.
        v = mx.maximum(b2 * v, mx.abs(gradient))
        state["m"] = m
        state["v"] = v

        return parameter - lr * m / (v + eps)
|
445 |
+
|
446 |
+
|
447 |
+
class Lion(Optimizer):
    r"""Implementation of the Lion optimizer [1].

    Since updates are computed through the sign operation, they tend to
    have larger norm than for other optimizers such as SGD and Adam.
    We recommend a learning rate that is 3-10x smaller than AdamW and a
    weight decay 3-10x larger than AdamW to maintain the strength
    (lr * wd). Our Lion implementation follows the original paper. In
    detail,

    [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
    preprint arXiv:2302.06675.

    .. math::

        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t
        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t
        w_{t + 1} &= w_t - \eta (\text{sign}(c_t) + \lambda w_t)

    Args:
        learning_rate (float): The learning rate :math:`\eta`.
        betas (Tuple[float, float], optional): The coefficients
          :math:`(\beta_1, \beta_2)` used for computing the gradient
          momentum and update direction. Default: ``(0.9, 0.99)``
        weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
    """

    def __init__(
        self,
        learning_rate: float,
        # Tuple default instead of a shared mutable list (anti-pattern);
        # behavior is unchanged since betas is only unpacked.
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        super().__init__()

        self.learning_rate = learning_rate
        self.betas = betas
        self.weight_decay = weight_decay

    def apply_single(
        self, gradient: mx.array, parameter: mx.array, state: OptimizerState
    ):
        """Performs the Lion parameter update and stores :math:`m`
        in the optimizer state."""
        lr = self.learning_rate
        b1, b2 = self.betas
        weight_decay = self.weight_decay

        m = state.get("m", gradient)
        # Update direction uses b1; the stored momentum uses b2.
        c = b1 * m + (1 - b1) * gradient
        state["m"] = b2 * m + (1 - b2) * gradient
        if weight_decay > 0:
            # Decoupled weight decay, applied before the sign step.
            parameter = (1 - lr * weight_decay) * parameter
        return parameter - lr * mx.sign(c)
|
lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfig.cmake
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Find MLX
#
# Defines the following variables:
#
#   MLX_FOUND            : True if MLX is found
#   MLX_INCLUDE_DIRS     : Include directory
#   MLX_LIBRARIES        : Libraries to link against
#   MLX_CXX_FLAGS        : Additional compiler flags
#   MLX_BUILD_ACCELERATE : True if MLX was built with accelerate
#   MLX_BUILD_METAL      : True if MLX was built with metal


####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() #######
####### Any changes to this file will be overwritten by the next CMake run ####
####### The input file was mlx.pc.in                            ########

# Package root, three levels up from this file's directory
# (share/cmake/MLX -> site-packages/mlx).
get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)

# Set ${_var} to ${_file} and abort configuration if the path does not exist.
macro(set_and_check _var _file)
  set(${_var} "${_file}")
  if(NOT EXISTS "${_file}")
    message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !")
  endif()
endmacro()

####################################################################################

include(${PACKAGE_PREFIX_DIR}/share/cmake/MLX/MLXTargets.cmake)
include(${PACKAGE_PREFIX_DIR}/share/cmake/MLX/extension.cmake)

set_and_check(MLX_LIBRARY_DIRS ${PACKAGE_PREFIX_DIR}/lib)
set_and_check(MLX_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
set(MLX_LIBRARIES mlx)

find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS})

# NOTE(review): the literal ON conditions below were presumably substituted
# from the MLX_BUILD_ACCELERATE / MLX_BUILD_METAL options when this file was
# generated — confirm against the generating CMakeLists before editing.
if (ON)
  set(MLX_BUILD_ACCELERATE ON)
  set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK)
endif()

if (ON)
  set(MLX_BUILD_METAL ON)
  set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
  set_and_check(MLX_INCLUDE_DIRS
    ${MLX_INCLUDE_DIRS}
    ${PACKAGE_PREFIX_DIR}/include/metal_cpp
  )
endif()

# Propagate the flags to consumers of the imported mlx target.
set_target_properties(mlx PROPERTIES
    CXX_STANDARD 17
    INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}"
)

include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
|
lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfigVersion.cmake
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This is a basic version file for the Config-mode of find_package().
# It is used by write_basic_package_version_file() as input file for configure_file()
# to create a version-file which can be installed along a config.cmake file.
#
# The created file sets PACKAGE_VERSION_EXACT if the current version string and
# the requested version string are exactly the same and it sets
# PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
# but only if the requested major version is the same as the current one.
# The variable CVF_VERSION must be set before calling configure_file().
#
# NOTE(review): machine-generated ("SameMajorVersion" compatibility policy);
# regenerate rather than hand-edit.


set(PACKAGE_VERSION "0.0.7")

if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
  set(PACKAGE_VERSION_COMPATIBLE FALSE)
else()

  # Extract the installed major version ("0" for 0.0.7).
  if("0.0.7" MATCHES "^([0-9]+)\\.")
    set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}")
    if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0)
      string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}")
    endif()
  else()
    set(CVF_VERSION_MAJOR "0.0.7")
  endif()

  if(PACKAGE_FIND_VERSION_RANGE)
    # both endpoints of the range must have the expected major version
    math (EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1")
    if (NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
        OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR)
          OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT)))
      set(PACKAGE_VERSION_COMPATIBLE FALSE)
    elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
        AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX)
        OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX)))
      set(PACKAGE_VERSION_COMPATIBLE TRUE)
    else()
      set(PACKAGE_VERSION_COMPATIBLE FALSE)
    endif()
  else()
    if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR)
      set(PACKAGE_VERSION_COMPATIBLE TRUE)
    else()
      set(PACKAGE_VERSION_COMPATIBLE FALSE)
    endif()

    if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
      set(PACKAGE_VERSION_EXACT TRUE)
    endif()
  endif()
endif()


# if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
# NOTE(review): the literal "8" below is presumably the CMAKE_SIZEOF_VOID_P
# captured at generation time (64-bit build) — confirm against the generator.
if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
  return()
endif()

# check that the installed version has the same 32/64bit-ness as the one which is currently searching:
if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
  math(EXPR installedBits "8 * 8")
  set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
  set(PACKAGE_VERSION_UNSUITABLE TRUE)
endif()
|
lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets-release.cmake
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#----------------------------------------------------------------
# Generated CMake target import file for configuration "Release".
#----------------------------------------------------------------
# NOTE(review): machine-generated; presumably included by MLXTargets.cmake,
# which sets _IMPORT_PREFIX before loading per-configuration files — confirm.

# Commands may need to know the format version.
set(CMAKE_IMPORT_FILE_VERSION 1)

# Import target "mlx" for configuration "Release"
set_property(TARGET mlx APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
set_target_properties(mlx PROPERTIES
  IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libmlx.dylib"
  IMPORTED_SONAME_RELEASE "@rpath/libmlx.dylib"
  )

list(APPEND _cmake_import_check_targets mlx )
list(APPEND _cmake_import_check_files_for_mlx "${_IMPORT_PREFIX}/lib/libmlx.dylib" )

# Commands beyond this point should not need to know the version.
set(CMAKE_IMPORT_FILE_VERSION)
|
lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets.cmake
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Generated by CMake

if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
   message(FATAL_ERROR "CMake >= 2.8.0 required")
endif()
if(CMAKE_VERSION VERSION_LESS "2.8.3")
   message(FATAL_ERROR "CMake >= 2.8.3 required")
endif()
cmake_policy(PUSH)
cmake_policy(VERSION 2.8.3...3.24)
#----------------------------------------------------------------
# Generated CMake target import file.
#----------------------------------------------------------------

# Commands may need to know the format version.
set(CMAKE_IMPORT_FILE_VERSION 1)

# Protect against multiple inclusion, which would fail when already imported targets are added once more.
# If ALL expected targets already exist, this file returns early (re-inclusion is a no-op);
# if only SOME exist, that is an inconsistent state and a fatal error is raised below.
set(_cmake_targets_defined "")
set(_cmake_targets_not_defined "")
set(_cmake_expected_targets "")
foreach(_cmake_expected_target IN ITEMS mlx)
  list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
  if(TARGET "${_cmake_expected_target}")
    list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
  else()
    list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
  endif()
endforeach()
unset(_cmake_expected_target)
if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
  unset(_cmake_targets_defined)
  unset(_cmake_targets_not_defined)
  unset(_cmake_expected_targets)
  unset(CMAKE_IMPORT_FILE_VERSION)
  cmake_policy(POP)
  return()
endif()
if(NOT _cmake_targets_defined STREQUAL "")
  string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
  string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
  message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
endif()
unset(_cmake_targets_defined)
unset(_cmake_targets_not_defined)
unset(_cmake_expected_targets)


# Compute the installation prefix relative to this file.
# Four PATH strips walk up from .../share/cmake/MLX/MLXTargets.cmake to the install root.
get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
if(_IMPORT_PREFIX STREQUAL "/")
  set(_IMPORT_PREFIX "")
endif()

# Create imported target mlx
add_library(mlx SHARED IMPORTED)

# NOTE(review): the framework entries below are absolute paths into an
# Xcode-beta MacOSX14.0 SDK captured at export time — they will not resolve on
# machines with a different Xcode/SDK layout; confirm before relying on this export.
set_target_properties(mlx PROPERTIES
  INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include/metal_cpp;${_IMPORT_PREFIX}/include/json;${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
  INTERFACE_LINK_LIBRARIES "/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Metal.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Foundation.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/QuartzCore.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Accelerate.framework"
)

if(CMAKE_VERSION VERSION_LESS 2.8.12)
  message(FATAL_ERROR "This file relies on consumers using CMake 2.8.12 or greater.")
endif()

# Load information for each installed configuration.
# Each MLXTargets-<config>.cmake sets per-configuration properties on the
# imported target and appends to the _cmake_import_check_* lists checked below.
file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/MLXTargets-*.cmake")
foreach(_cmake_config_file IN LISTS _cmake_config_files)
  include("${_cmake_config_file}")
endforeach()
unset(_cmake_config_file)
unset(_cmake_config_files)

# Cleanup temporary variables.
set(_IMPORT_PREFIX)

# Loop over all imported files and verify that they actually exist
foreach(_cmake_target IN LISTS _cmake_import_check_targets)
  foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
    if(NOT EXISTS "${_cmake_file}")
      message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
\"${_cmake_file}\"
but this file does not exist.  Possible reasons include:
* The file was deleted, renamed, or moved to another location.
* An install or uninstall procedure did not complete successfully.
* The installation package was faulty and contained
\"${CMAKE_CURRENT_LIST_FILE}\"
but not all the files it references.
")
    endif()
  endforeach()
  unset(_cmake_file)
  unset("_cmake_import_check_files_for_${_cmake_target}")
endforeach()
unset(_cmake_target)
unset(_cmake_import_check_targets)

# This file does not depend on other imported targets which have
# been exported from the same project but in a separate export set.

# Commands beyond this point should not need to know the version.
set(CMAKE_IMPORT_FILE_VERSION)
cmake_policy(POP)
|
lib/python3.11/site-packages/mlx/share/cmake/MLX/extension.cmake
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
include(CMakeParseArguments)

###############################################################################
# Build metal library
#
# Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
# from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
#
# Args:
#   TARGET: Custom target to be added for the metal library
#   TITLE: Name of the .metallib
#   OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
#   SOURCES: List of source files
#   INCLUDE_DIRS: List of include dirs
#   DEPS: List of dependency files (like headers)
#
# NOTE: this is a macro, so the MTLLIB_* variables it sets leak into the
# caller's scope; avoid reusing those names around call sites.
macro(mlx_build_metallib)
  # Parse args: no option (flag) arguments, hence the empty "" below.
  set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
  set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
  cmake_parse_arguments(
    MTLLIB
    ""
    "${oneValueArgs}"
    "${multiValueArgs}"
    ${ARGN}
  )

  # Set output
  set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")

  # Collect compile options
  set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)

  # Prepare metallib build command.
  # The $<LIST:TRANSFORM,...> generator expression turns each include dir
  # into a -I<dir> flag; COMMAND_EXPAND_LISTS splits it into separate
  # command-line arguments.
  # NOTE(review): $<LIST:...> requires a recent CMake (3.27+) — confirm the
  # project's minimum CMake version supports it.
  add_custom_command(
    OUTPUT ${MTLLIB_BUILD_TARGET}
    COMMAND xcrun -sdk macosx metal
            "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
            ${MTLLIB_COMPILE_OPTIONS}
            ${MTLLIB_SOURCES}
            -o ${MTLLIB_BUILD_TARGET}
    DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
    COMMAND_EXPAND_LISTS
    COMMENT "Building ${MTLLIB_TITLE}.metallib"
    VERBATIM
  )

  # Add metallib custom target so consumers can depend on the build by name.
  add_custom_target(
    ${MTLLIB_TARGET}
    DEPENDS
    ${MTLLIB_BUILD_TARGET}
  )

endmacro(mlx_build_metallib)
|
lib/python3.11/site-packages/mlx/utils.py
ADDED
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright © 2023 Apple Inc.
|
2 |
+
|
3 |
+
from collections import defaultdict
|
4 |
+
|
5 |
+
|
6 |
+
def tree_map(fn, tree, *rest, is_leaf=None):
    """Apply ``fn`` to every leaf of the python tree ``tree`` and return a
    new tree of the results.

    When extra trees are passed via ``rest``, each must mirror the structure
    of ``tree``; the matching leaves are handed to ``fn`` as additional
    positional arguments (so ``fn`` is called starmap-style).

    The optional ``is_leaf`` callable decides what counts as a leaf, exactly
    as in :func:`tree_flatten`.

    .. code-block:: python

        import mlx.nn as nn
        from mlx.utils import tree_map

        model = nn.Linear(10, 10)
        print(model.parameters().keys())
        # dict_keys(['weight', 'bias'])

        # square the parameters
        model.update(tree_map(lambda x: x*x, model.parameters()))

    Args:
        fn (Callable): The function that processes the leaves of the tree
        tree (Any): The main python tree that will be iterated upon
        rest (Tuple[Any]): Extra trees to be iterated together with tree
        is_leaf (Optional[Callable]): An optional callable that returns True if
            the passed object is considered a leaf or False otherwise.

    Returns:
        A python tree with the new values returned by ``fn``.
    """
    # Explicit leaf predicate wins over the structural checks below.
    if is_leaf is not None and is_leaf(tree):
        return fn(tree, *rest)

    if isinstance(tree, (list, tuple)):
        # Rebuild with the same concrete sequence type (list vs tuple).
        container = type(tree)
        return container(
            tree_map(fn, node, *(extra[idx] for extra in rest), is_leaf=is_leaf)
            for idx, node in enumerate(tree)
        )

    if isinstance(tree, dict):
        mapped = {}
        for key, node in tree.items():
            mapped[key] = tree_map(
                fn, node, *(extra[key] for extra in rest), is_leaf=is_leaf
            )
        return mapped

    # Anything that is not a list/tuple/dict is treated as a leaf.
    return fn(tree, *rest)
|
55 |
+
|
56 |
+
|
57 |
+
def tree_flatten(tree, prefix="", is_leaf=None):
|
58 |
+
"""Flattens a python tree to a list of key, value tuples.
|
59 |
+
|
60 |
+
The keys are using the dot notation to define trees of arbitrary depth and
|
61 |
+
complexity.
|
62 |
+
|
63 |
+
.. code-block:: python
|
64 |
+
|
65 |
+
from mlx.utils import tree_flatten
|
66 |
+
|
67 |
+
print(tree_flatten([[[0]]]))
|
68 |
+
# [("0.0.0", 0)]
|
69 |
+
|
70 |
+
print(tree_flatten([[[0]]], ".hello"))
|
71 |
+
# [("hello.0.0.0", 0)]
|
72 |
+
|
73 |
+
.. note::
|
74 |
+
Dictionaries should have keys that are valid python identifiers.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
tree (Any): The python tree to be flattened.
|
78 |
+
prefix (str): A prefix to use for the keys. The first character is
|
79 |
+
always discarded.
|
80 |
+
is_leaf (Callable): An optional callable that returns True if the
|
81 |
+
passed object is considered a leaf or False otherwise.
|
82 |
+
|
83 |
+
Returns:
|
84 |
+
List[Tuple[str, Any]]: The flat representation of the python tree.
|
85 |
+
"""
|
86 |
+
flat_tree = []
|
87 |
+
|
88 |
+
if is_leaf is None or not is_leaf(tree):
|
89 |
+
if isinstance(tree, (list, tuple)):
|
90 |
+
for i, t in enumerate(tree):
|
91 |
+
flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
|
92 |
+
return flat_tree
|
93 |
+
if isinstance(tree, dict):
|
94 |
+
for k, t in tree.items():
|
95 |
+
flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
|
96 |
+
return flat_tree
|
97 |
+
|
98 |
+
return [(prefix[1:], tree)]
|
99 |
+
|
100 |
+
|
101 |
+
def tree_unflatten(tree):
    """Recreate a python tree from its flat representation.

    .. code-block:: python

        from mlx.utils import tree_unflatten

        d = tree_unflatten([("hello.world", 42)])
        print(d)
        # {"hello": {"world": 42}}

    Args:
        tree (List[Tuple[str, Any]]): The flat representation of a python tree.
            For instance as returned by :meth:`tree_flatten`.

    Returns:
        A python tree.
    """
    # Base case: a single entry with an empty key is the leaf value itself.
    if len(tree) == 1 and tree[0][0] == "":
        return tree[0][1]

    # A purely numeric first path component signals a list; otherwise a dict.
    try:
        int(tree[0][0].split(".", maxsplit=1)[0])
        is_list = True
    except ValueError:
        is_list = False

    # Group entries by their first path component, keeping the remainder of
    # each key for the recursive call.
    children = defaultdict(list)
    for key, value in tree:
        head, *tail = key.split(".", maxsplit=1)
        children[head].append((tail[0] if tail else "", value))

    if not is_list:
        return {key: tree_unflatten(sub) for key, sub in children.items()}

    # Rebuild the list in index order, padding any missing indices with {}.
    ordered = sorted((int(idx), idx) for idx in children)
    items = []
    for position, idx in ordered:
        # If position <= len(items), nothing is appended here.
        items.extend({} for _ in range(position - len(items)))
        items.append(tree_unflatten(children[idx]))
    return items
|
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/INSTALLER
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
pip
|
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/LICENSE
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2012 Erik Rose
|
2 |
+
|
3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
4 |
+
this software and associated documentation files (the "Software"), to deal in
|
5 |
+
the Software without restriction, including without limitation the rights to
|
6 |
+
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
|
7 |
+
of the Software, and to permit persons to whom the Software is furnished to do
|
8 |
+
so, subject to the following conditions:
|
9 |
+
|
10 |
+
The above copyright notice and this permission notice shall be included in all
|
11 |
+
copies or substantial portions of the Software.
|
12 |
+
|
13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
14 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
15 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
16 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
17 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
18 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
19 |
+
SOFTWARE.
|
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/METADATA
ADDED
@@ -0,0 +1,253 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: more-itertools
|
3 |
+
Version: 10.1.0
|
4 |
+
Summary: More routines for operating on iterables, beyond itertools
|
5 |
+
Keywords: itertools,iterator,iteration,filter,peek,peekable,chunk,chunked
|
6 |
+
Author-email: Erik Rose <[email protected]>
|
7 |
+
Requires-Python: >=3.8
|
8 |
+
Description-Content-Type: text/x-rst
|
9 |
+
Classifier: Development Status :: 5 - Production/Stable
|
10 |
+
Classifier: Intended Audience :: Developers
|
11 |
+
Classifier: Natural Language :: English
|
12 |
+
Classifier: License :: OSI Approved :: MIT License
|
13 |
+
Classifier: Programming Language :: Python :: 3
|
14 |
+
Classifier: Programming Language :: Python :: 3.8
|
15 |
+
Classifier: Programming Language :: Python :: 3.9
|
16 |
+
Classifier: Programming Language :: Python :: 3.10
|
17 |
+
Classifier: Programming Language :: Python :: 3.11
|
18 |
+
Classifier: Programming Language :: Python :: 3 :: Only
|
19 |
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
20 |
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
21 |
+
Classifier: Topic :: Software Development :: Libraries
|
22 |
+
Project-URL: Homepage, https://github.com/more-itertools/more-itertools
|
23 |
+
|
24 |
+
==============
|
25 |
+
More Itertools
|
26 |
+
==============
|
27 |
+
|
28 |
+
.. image:: https://readthedocs.org/projects/more-itertools/badge/?version=latest
|
29 |
+
:target: https://more-itertools.readthedocs.io/en/stable/
|
30 |
+
|
31 |
+
Python's ``itertools`` library is a gem - you can compose elegant solutions
|
32 |
+
for a variety of problems with the functions it provides. In ``more-itertools``
|
33 |
+
we collect additional building blocks, recipes, and routines for working with
|
34 |
+
Python iterables.
|
35 |
+
|
36 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
37 |
+
| Grouping | `chunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked>`_, |
|
38 |
+
| | `ichunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ichunked>`_, |
|
39 |
+
| | `chunked_even <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked_even>`_, |
|
40 |
+
| | `sliced <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliced>`_, |
|
41 |
+
| | `constrained_batches <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.constrained_batches>`_, |
|
42 |
+
| | `distribute <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute>`_, |
|
43 |
+
| | `divide <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.divide>`_, |
|
44 |
+
| | `split_at <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_at>`_, |
|
45 |
+
| | `split_before <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_before>`_, |
|
46 |
+
| | `split_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_after>`_, |
|
47 |
+
| | `split_into <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_into>`_, |
|
48 |
+
| | `split_when <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_when>`_, |
|
49 |
+
| | `bucket <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.bucket>`_, |
|
50 |
+
| | `unzip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unzip>`_, |
|
51 |
+
| | `batched <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.batched>`_, |
|
52 |
+
| | `grouper <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.grouper>`_, |
|
53 |
+
| | `partition <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partition>`_, |
|
54 |
+
| | `transpose <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.transpose>`_ |
|
55 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
56 |
+
| Lookahead and lookback | `spy <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.spy>`_, |
|
57 |
+
| | `peekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.peekable>`_, |
|
58 |
+
| | `seekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.seekable>`_ |
|
59 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
60 |
+
| Windowing | `windowed <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed>`_, |
|
61 |
+
| | `substrings <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings>`_, |
|
62 |
+
| | `substrings_indexes <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings_indexes>`_, |
|
63 |
+
| | `stagger <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.stagger>`_, |
|
64 |
+
| | `windowed_complete <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed_complete>`_, |
|
65 |
+
| | `pairwise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pairwise>`_, |
|
66 |
+
| | `triplewise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.triplewise>`_, |
|
67 |
+
| | `sliding_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliding_window>`_, |
|
68 |
+
| | `subslices <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.subslices>`_ |
|
69 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
70 |
+
| Augmenting | `count_cycle <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.count_cycle>`_, |
|
71 |
+
| | `intersperse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse>`_, |
|
72 |
+
| | `padded <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padded>`_, |
|
73 |
+
| | `repeat_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeat_each>`_, |
|
74 |
+
| | `mark_ends <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.mark_ends>`_, |
|
75 |
+
| | `repeat_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeat_last>`_, |
|
76 |
+
| | `adjacent <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.adjacent>`_, |
|
77 |
+
| | `groupby_transform <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.groupby_transform>`_, |
|
78 |
+
| | `pad_none <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pad_none>`_, |
|
79 |
+
| | `ncycles <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ncycles>`_ |
|
80 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
81 |
+
| Combining | `collapse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collapse>`_, |
|
82 |
+
| | `sort_together <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sort_together>`_, |
|
83 |
+
| | `interleave <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave>`_, |
|
84 |
+
| | `interleave_longest <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_longest>`_, |
|
85 |
+
| | `interleave_evenly <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_evenly>`_, |
|
86 |
+
| | `zip_offset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_offset>`_, |
|
87 |
+
| | `zip_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_equal>`_, |
|
88 |
+
| | `zip_broadcast <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_broadcast>`_, |
|
89 |
+
| | `dotproduct <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.dotproduct>`_, |
|
90 |
+
| | `convolve <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.convolve>`_, |
|
91 |
+
| | `flatten <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.flatten>`_, |
|
92 |
+
| | `roundrobin <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.roundrobin>`_, |
|
93 |
+
| | `prepend <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.prepend>`_, |
|
94 |
+
| | `value_chain <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.value_chain>`_, |
|
95 |
+
| | `partial_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partial_product>`_ |
|
96 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
97 |
+
| Summarizing | `ilen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ilen>`_, |
|
98 |
+
| | `unique_to_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_to_each>`_, |
|
99 |
+
| | `sample <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sample>`_, |
|
100 |
+
| | `consecutive_groups <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consecutive_groups>`_, |
|
101 |
+
| | `run_length <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.run_length>`_, |
|
102 |
+
| | `map_reduce <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_reduce>`_, |
|
103 |
+
| | `exactly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.exactly_n>`_, |
|
104 |
+
| | `is_sorted <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.is_sorted>`_, |
|
105 |
+
| | `all_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_equal>`_, |
|
106 |
+
| | `all_unique <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_unique>`_, |
|
107 |
+
| | `minmax <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.minmax>`_, |
|
108 |
+
| | `first_true <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first_true>`_, |
|
109 |
+
| | `quantify <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.quantify>`_, |
|
110 |
+
| | `iequals <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iequals>`_ |
|
111 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
112 |
+
| Selecting | `islice_extended <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.islice_extended>`_, |
|
113 |
+
| | `first <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first>`_, |
|
114 |
+
| | `last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.last>`_, |
|
115 |
+
| | `one <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.one>`_, |
|
116 |
+
| | `only <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.only>`_, |
|
117 |
+
| | `strictly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strictly_n>`_, |
|
118 |
+
| | `strip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strip>`_, |
|
119 |
+
| | `lstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.lstrip>`_, |
|
120 |
+
| | `rstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rstrip>`_, |
|
121 |
+
| | `filter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.filter_except>`_, |
|
122 |
+
| | `map_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_except>`_, |
|
123 |
+
| | `nth_or_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_or_last>`_, |
|
124 |
+
| | `unique_in_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_in_window>`_, |
|
125 |
+
| | `before_and_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.before_and_after>`_, |
|
126 |
+
| | `nth <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth>`_, |
|
127 |
+
| | `take <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.take>`_, |
|
128 |
+
| | `tail <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tail>`_, |
|
129 |
+
| | `unique_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_everseen>`_, |
|
130 |
+
| | `unique_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_justseen>`_, |
|
131 |
+
| | `duplicates_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_everseen>`_, |
|
132 |
+
| | `duplicates_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_justseen>`_, |
|
133 |
+
| | `longest_common_prefix <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.longest_common_prefix>`_, |
|
134 |
+
| | `takewhile_inclusive <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.takewhile_inclusive>`_ |
|
135 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
136 |
+
| Combinatorics | `distinct_permutations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_permutations>`_, |
|
137 |
+
| | `distinct_combinations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_combinations>`_, |
|
138 |
+
| | `circular_shifts <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.circular_shifts>`_, |
|
139 |
+
| | `partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partitions>`_, |
|
140 |
+
| | `set_partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.set_partitions>`_, |
|
141 |
+
| | `product_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.product_index>`_, |
|
142 |
+
| | `combination_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.combination_index>`_, |
|
143 |
+
| | `permutation_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.permutation_index>`_, |
|
144 |
+
| | `combination_with_replacement_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.combination_with_replacement_index>`_, |
|
145 |
+
| | `gray_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.gray_product>`_, |
|
146 |
+
| | `outer_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.outer_product>`_, |
|
147 |
+
| | `powerset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.powerset>`_, |
|
148 |
+
| | `random_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_product>`_, |
|
149 |
+
| | `random_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_permutation>`_, |
|
150 |
+
| | `random_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination>`_, |
|
151 |
+
| | `random_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination_with_replacement>`_, |
|
152 |
+
| | `nth_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_product>`_, |
|
153 |
+
| | `nth_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_permutation>`_, |
|
154 |
+
| | `nth_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination>`_, |
|
155 |
+
| | `nth_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination_with_replacement>`_ |
|
156 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
157 |
+
| Wrapping | `always_iterable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_iterable>`_, |
|
158 |
+
| | `always_reversible <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_reversible>`_, |
|
159 |
+
| | `countable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.countable>`_, |
|
160 |
+
| | `consumer <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consumer>`_, |
|
161 |
+
| | `with_iter <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.with_iter>`_, |
|
162 |
+
| | `iter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_except>`_ |
|
163 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
164 |
+
| Others | `locate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.locate>`_, |
|
165 |
+
| | `rlocate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rlocate>`_, |
|
166 |
+
| | `replace <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.replace>`_, |
|
167 |
+
| | `numeric_range <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.numeric_range>`_, |
|
168 |
+
| | `side_effect <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.side_effect>`_, |
|
169 |
+
| | `iterate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iterate>`_, |
|
170 |
+
| | `difference <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.difference>`_, |
|
171 |
+
| | `make_decorator <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.make_decorator>`_, |
|
172 |
+
| | `SequenceView <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.SequenceView>`_, |
|
173 |
+
| | `time_limited <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.time_limited>`_, |
|
174 |
+
| | `map_if <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_if>`_, |
|
175 |
+
| | `iter_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_index>`_, |
|
176 |
+
| | `consume <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consume>`_, |
|
177 |
+
| | `tabulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tabulate>`_, |
|
178 |
+
| | `repeatfunc <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeatfunc>`_, |
|
179 |
+
| | `polynomial_from_roots <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_from_roots>`_, |
|
180 |
+
| | `polynomial_eval <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_eval>`_, |
|
181 |
+
| | `polynomial_derivative <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_derivative>`_, |
|
182 |
+
| | `sieve <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sieve>`_, |
|
183 |
+
| | `factor <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.factor>`_, |
|
184 |
+
| | `matmul <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.matmul>`_, |
|
185 |
+
| | `sum_of_squares <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sum_of_squares>`_ |
|
186 |
+
+------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|
187 |
+
|
188 |
+
|
189 |
+
Getting started
|
190 |
+
===============
|
191 |
+
|
192 |
+
To get started, install the library with `pip <https://pip.pypa.io/en/stable/>`_:
|
193 |
+
|
194 |
+
.. code-block:: shell
|
195 |
+
|
196 |
+
pip install more-itertools
|
197 |
+
|
198 |
+
The recipes from the `itertools docs <https://docs.python.org/3/library/itertools.html#itertools-recipes>`_
|
199 |
+
are included in the top-level package:
|
200 |
+
|
201 |
+
.. code-block:: python
|
202 |
+
|
203 |
+
>>> from more_itertools import flatten
|
204 |
+
>>> iterable = [(0, 1), (2, 3)]
|
205 |
+
>>> list(flatten(iterable))
|
206 |
+
[0, 1, 2, 3]
|
207 |
+
|
208 |
+
Several new recipes are available as well:
|
209 |
+
|
210 |
+
.. code-block:: python
|
211 |
+
|
212 |
+
>>> from more_itertools import chunked
|
213 |
+
>>> iterable = [0, 1, 2, 3, 4, 5, 6, 7, 8]
|
214 |
+
>>> list(chunked(iterable, 3))
|
215 |
+
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
|
216 |
+
|
217 |
+
>>> from more_itertools import spy
|
218 |
+
>>> iterable = (x * x for x in range(1, 6))
|
219 |
+
>>> head, iterable = spy(iterable, n=3)
|
220 |
+
>>> list(head)
|
221 |
+
[1, 4, 9]
|
222 |
+
>>> list(iterable)
|
223 |
+
[1, 4, 9, 16, 25]
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
For the full listing of functions, see the `API documentation <https://more-itertools.readthedocs.io/en/stable/api.html>`_.
|
228 |
+
|
229 |
+
|
230 |
+
Links elsewhere
|
231 |
+
===============
|
232 |
+
|
233 |
+
Blog posts about ``more-itertools``:
|
234 |
+
|
235 |
+
* `Yo, I heard you like decorators <https://www.bbayles.com/index/decorator_factory>`__
|
236 |
+
* `Tour of Python Itertools <https://martinheinz.dev/blog/16>`__ (`Alternate <https://dev.to/martinheinz/tour-of-python-itertools-4122>`__)
|
237 |
+
* `Real-World Python More Itertools <https://www.gidware.com/real-world-more-itertools/>`_
|
238 |
+
|
239 |
+
|
240 |
+
Development
|
241 |
+
===========
|
242 |
+
|
243 |
+
``more-itertools`` is maintained by `@erikrose <https://github.com/erikrose>`_
|
244 |
+
and `@bbayles <https://github.com/bbayles>`_, with help from `many others <https://github.com/more-itertools/more-itertools/graphs/contributors>`_.
|
245 |
+
If you have a problem or suggestion, please file a bug or pull request in this
|
246 |
+
repository. Thanks for contributing!
|
247 |
+
|
248 |
+
|
249 |
+
Version History
|
250 |
+
===============
|
251 |
+
|
252 |
+
The version history can be found in `documentation <https://more-itertools.readthedocs.io/en/stable/versions.html>`_.
|
253 |
+
|
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/RECORD
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
more_itertools-10.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
2 |
+
more_itertools-10.1.0.dist-info/LICENSE,sha256=CfHIyelBrz5YTVlkHqm4fYPAyw_QB-te85Gn4mQ8GkY,1053
|
3 |
+
more_itertools-10.1.0.dist-info/METADATA,sha256=s6T6Rg5Sq9hJE6KhX38MrPxWh8X0d9Fsdld4qdbrQVQ,33830
|
4 |
+
more_itertools-10.1.0.dist-info/RECORD,,
|
5 |
+
more_itertools-10.1.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6 |
+
more_itertools-10.1.0.dist-info/WHEEL,sha256=rSgq_JpHF9fHR1lx53qwg_1-2LypZE_qmcuXbVUq948,81
|
7 |
+
more_itertools/__init__.py,sha256=weQyJxnCVH2EuAaoxC8ZEX-ViZmpHm5kLHO05O8LfRM,149
|
8 |
+
more_itertools/__init__.pyi,sha256=5B3eTzON1BBuOLob1vCflyEb2lSd6usXQQ-Cv-hXkeA,43
|
9 |
+
more_itertools/__pycache__/__init__.cpython-311.pyc,,
|
10 |
+
more_itertools/__pycache__/more.cpython-311.pyc,,
|
11 |
+
more_itertools/__pycache__/recipes.cpython-311.pyc,,
|
12 |
+
more_itertools/more.py,sha256=6XrPO3vvd3miZb6ygQtY5SoJeLk2PAwfEQOlZHwZgXw,140481
|
13 |
+
more_itertools/more.pyi,sha256=nbYRtg-0YR0XMnFOGMzRzbgEOtNC_cfkUaOMOmNVWMA,20726
|
14 |
+
more_itertools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
15 |
+
more_itertools/recipes.py,sha256=GNzJluz0ZYBZhF_H3bC7YkNCm9fZzOoBiQvh1WNNDOU,26354
|
16 |
+
more_itertools/recipes.pyi,sha256=77IGo2ZqPtp9IkRu5MzQBgatLBl52vO3MRrFvPDcHzM,4237
|
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/REQUESTED
ADDED
File without changes
|
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/WHEEL
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Wheel-Version: 1.0
|
2 |
+
Generator: flit 3.8.0
|
3 |
+
Root-Is-Purelib: true
|
4 |
+
Tag: py3-none-any
|
lib/python3.11/site-packages/more_itertools/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""More routines for operating on iterables, beyond itertools"""
|
2 |
+
|
3 |
+
from .more import * # noqa
|
4 |
+
from .recipes import * # noqa
|
5 |
+
|
6 |
+
__version__ = '10.1.0'
|
lib/python3.11/site-packages/more_itertools/__init__.pyi
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .more import *
|
2 |
+
from .recipes import *
|
lib/python3.11/site-packages/more_itertools/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (394 Bytes). View file
|
|
lib/python3.11/site-packages/more_itertools/__pycache__/more.cpython-311.pyc
ADDED
Binary file (179 kB). View file
|
|
lib/python3.11/site-packages/more_itertools/__pycache__/recipes.cpython-311.pyc
ADDED
Binary file (37.2 kB). View file
|
|
lib/python3.11/site-packages/more_itertools/more.py
ADDED
The diff for this file is too large to render.
See raw diff
|
|
lib/python3.11/site-packages/more_itertools/more.pyi
ADDED
@@ -0,0 +1,684 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stubs for more_itertools.more"""
|
2 |
+
from __future__ import annotations
|
3 |
+
|
4 |
+
from types import TracebackType
|
5 |
+
from typing import (
|
6 |
+
Any,
|
7 |
+
Callable,
|
8 |
+
Container,
|
9 |
+
ContextManager,
|
10 |
+
Generic,
|
11 |
+
Hashable,
|
12 |
+
Iterable,
|
13 |
+
Iterator,
|
14 |
+
overload,
|
15 |
+
Reversible,
|
16 |
+
Sequence,
|
17 |
+
Sized,
|
18 |
+
Type,
|
19 |
+
TypeVar,
|
20 |
+
type_check_only,
|
21 |
+
)
|
22 |
+
from typing_extensions import Protocol
|
23 |
+
|
24 |
+
# Type and type variable definitions
|
25 |
+
_T = TypeVar('_T')
|
26 |
+
_T1 = TypeVar('_T1')
|
27 |
+
_T2 = TypeVar('_T2')
|
28 |
+
_U = TypeVar('_U')
|
29 |
+
_V = TypeVar('_V')
|
30 |
+
_W = TypeVar('_W')
|
31 |
+
_T_co = TypeVar('_T_co', covariant=True)
|
32 |
+
_GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
|
33 |
+
_Raisable = BaseException | Type[BaseException]
|
34 |
+
|
35 |
+
@type_check_only
|
36 |
+
class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ...
|
37 |
+
|
38 |
+
@type_check_only
|
39 |
+
class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ...
|
40 |
+
|
41 |
+
@type_check_only
|
42 |
+
class _SupportsSlicing(Protocol[_T_co]):
|
43 |
+
def __getitem__(self, __k: slice) -> _T_co: ...
|
44 |
+
|
45 |
+
def chunked(
|
46 |
+
iterable: Iterable[_T], n: int | None, strict: bool = ...
|
47 |
+
) -> Iterator[list[_T]]: ...
|
48 |
+
@overload
|
49 |
+
def first(iterable: Iterable[_T]) -> _T: ...
|
50 |
+
@overload
|
51 |
+
def first(iterable: Iterable[_T], default: _U) -> _T | _U: ...
|
52 |
+
@overload
|
53 |
+
def last(iterable: Iterable[_T]) -> _T: ...
|
54 |
+
@overload
|
55 |
+
def last(iterable: Iterable[_T], default: _U) -> _T | _U: ...
|
56 |
+
@overload
|
57 |
+
def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ...
|
58 |
+
@overload
|
59 |
+
def nth_or_last(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
|
60 |
+
|
61 |
+
class peekable(Generic[_T], Iterator[_T]):
|
62 |
+
def __init__(self, iterable: Iterable[_T]) -> None: ...
|
63 |
+
def __iter__(self) -> peekable[_T]: ...
|
64 |
+
def __bool__(self) -> bool: ...
|
65 |
+
@overload
|
66 |
+
def peek(self) -> _T: ...
|
67 |
+
@overload
|
68 |
+
def peek(self, default: _U) -> _T | _U: ...
|
69 |
+
def prepend(self, *items: _T) -> None: ...
|
70 |
+
def __next__(self) -> _T: ...
|
71 |
+
@overload
|
72 |
+
def __getitem__(self, index: int) -> _T: ...
|
73 |
+
@overload
|
74 |
+
def __getitem__(self, index: slice) -> list[_T]: ...
|
75 |
+
|
76 |
+
def consumer(func: _GenFn) -> _GenFn: ...
|
77 |
+
def ilen(iterable: Iterable[object]) -> int: ...
|
78 |
+
def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
|
79 |
+
def with_iter(
|
80 |
+
context_manager: ContextManager[Iterable[_T]],
|
81 |
+
) -> Iterator[_T]: ...
|
82 |
+
def one(
|
83 |
+
iterable: Iterable[_T],
|
84 |
+
too_short: _Raisable | None = ...,
|
85 |
+
too_long: _Raisable | None = ...,
|
86 |
+
) -> _T: ...
|
87 |
+
def raise_(exception: _Raisable, *args: Any) -> None: ...
|
88 |
+
def strictly_n(
|
89 |
+
iterable: Iterable[_T],
|
90 |
+
n: int,
|
91 |
+
too_short: _GenFn | None = ...,
|
92 |
+
too_long: _GenFn | None = ...,
|
93 |
+
) -> list[_T]: ...
|
94 |
+
def distinct_permutations(
|
95 |
+
iterable: Iterable[_T], r: int | None = ...
|
96 |
+
) -> Iterator[tuple[_T, ...]]: ...
|
97 |
+
def intersperse(
|
98 |
+
e: _U, iterable: Iterable[_T], n: int = ...
|
99 |
+
) -> Iterator[_T | _U]: ...
|
100 |
+
def unique_to_each(*iterables: Iterable[_T]) -> list[list[_T]]: ...
|
101 |
+
@overload
|
102 |
+
def windowed(
|
103 |
+
seq: Iterable[_T], n: int, *, step: int = ...
|
104 |
+
) -> Iterator[tuple[_T | None, ...]]: ...
|
105 |
+
@overload
|
106 |
+
def windowed(
|
107 |
+
seq: Iterable[_T], n: int, fillvalue: _U, step: int = ...
|
108 |
+
) -> Iterator[tuple[_T | _U, ...]]: ...
|
109 |
+
def substrings(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
|
110 |
+
def substrings_indexes(
|
111 |
+
seq: Sequence[_T], reverse: bool = ...
|
112 |
+
) -> Iterator[tuple[Sequence[_T], int, int]]: ...
|
113 |
+
|
114 |
+
class bucket(Generic[_T, _U], Container[_U]):
|
115 |
+
def __init__(
|
116 |
+
self,
|
117 |
+
iterable: Iterable[_T],
|
118 |
+
key: Callable[[_T], _U],
|
119 |
+
validator: Callable[[object], object] | None = ...,
|
120 |
+
) -> None: ...
|
121 |
+
def __contains__(self, value: object) -> bool: ...
|
122 |
+
def __iter__(self) -> Iterator[_U]: ...
|
123 |
+
def __getitem__(self, value: object) -> Iterator[_T]: ...
|
124 |
+
|
125 |
+
def spy(
|
126 |
+
iterable: Iterable[_T], n: int = ...
|
127 |
+
) -> tuple[list[_T], Iterator[_T]]: ...
|
128 |
+
def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ...
|
129 |
+
def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ...
|
130 |
+
def interleave_evenly(
|
131 |
+
iterables: list[Iterable[_T]], lengths: list[int] | None = ...
|
132 |
+
) -> Iterator[_T]: ...
|
133 |
+
def collapse(
|
134 |
+
iterable: Iterable[Any],
|
135 |
+
base_type: type | None = ...,
|
136 |
+
levels: int | None = ...,
|
137 |
+
) -> Iterator[Any]: ...
|
138 |
+
@overload
|
139 |
+
def side_effect(
|
140 |
+
func: Callable[[_T], object],
|
141 |
+
iterable: Iterable[_T],
|
142 |
+
chunk_size: None = ...,
|
143 |
+
before: Callable[[], object] | None = ...,
|
144 |
+
after: Callable[[], object] | None = ...,
|
145 |
+
) -> Iterator[_T]: ...
|
146 |
+
@overload
|
147 |
+
def side_effect(
|
148 |
+
func: Callable[[list[_T]], object],
|
149 |
+
iterable: Iterable[_T],
|
150 |
+
chunk_size: int,
|
151 |
+
before: Callable[[], object] | None = ...,
|
152 |
+
after: Callable[[], object] | None = ...,
|
153 |
+
) -> Iterator[_T]: ...
|
154 |
+
def sliced(
|
155 |
+
seq: _SupportsSlicing[_T], n: int, strict: bool = ...
|
156 |
+
) -> Iterator[_T]: ...
|
157 |
+
def split_at(
|
158 |
+
iterable: Iterable[_T],
|
159 |
+
pred: Callable[[_T], object],
|
160 |
+
maxsplit: int = ...,
|
161 |
+
keep_separator: bool = ...,
|
162 |
+
) -> Iterator[list[_T]]: ...
|
163 |
+
def split_before(
|
164 |
+
iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
|
165 |
+
) -> Iterator[list[_T]]: ...
|
166 |
+
def split_after(
|
167 |
+
iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
|
168 |
+
) -> Iterator[list[_T]]: ...
|
169 |
+
def split_when(
|
170 |
+
iterable: Iterable[_T],
|
171 |
+
pred: Callable[[_T, _T], object],
|
172 |
+
maxsplit: int = ...,
|
173 |
+
) -> Iterator[list[_T]]: ...
|
174 |
+
def split_into(
|
175 |
+
iterable: Iterable[_T], sizes: Iterable[int | None]
|
176 |
+
) -> Iterator[list[_T]]: ...
|
177 |
+
@overload
|
178 |
+
def padded(
|
179 |
+
iterable: Iterable[_T],
|
180 |
+
*,
|
181 |
+
n: int | None = ...,
|
182 |
+
next_multiple: bool = ...,
|
183 |
+
) -> Iterator[_T | None]: ...
|
184 |
+
@overload
|
185 |
+
def padded(
|
186 |
+
iterable: Iterable[_T],
|
187 |
+
fillvalue: _U,
|
188 |
+
n: int | None = ...,
|
189 |
+
next_multiple: bool = ...,
|
190 |
+
) -> Iterator[_T | _U]: ...
|
191 |
+
@overload
|
192 |
+
def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ...
|
193 |
+
@overload
|
194 |
+
def repeat_last(iterable: Iterable[_T], default: _U) -> Iterator[_T | _U]: ...
|
195 |
+
def distribute(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
|
196 |
+
@overload
|
197 |
+
def stagger(
|
198 |
+
iterable: Iterable[_T],
|
199 |
+
offsets: _SizedIterable[int] = ...,
|
200 |
+
longest: bool = ...,
|
201 |
+
) -> Iterator[tuple[_T | None, ...]]: ...
|
202 |
+
@overload
|
203 |
+
def stagger(
|
204 |
+
iterable: Iterable[_T],
|
205 |
+
offsets: _SizedIterable[int] = ...,
|
206 |
+
longest: bool = ...,
|
207 |
+
fillvalue: _U = ...,
|
208 |
+
) -> Iterator[tuple[_T | _U, ...]]: ...
|
209 |
+
|
210 |
+
class UnequalIterablesError(ValueError):
|
211 |
+
def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ...
|
212 |
+
|
213 |
+
@overload
|
214 |
+
def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ...
|
215 |
+
@overload
|
216 |
+
def zip_equal(
|
217 |
+
__iter1: Iterable[_T1], __iter2: Iterable[_T2]
|
218 |
+
) -> Iterator[tuple[_T1, _T2]]: ...
|
219 |
+
@overload
|
220 |
+
def zip_equal(
|
221 |
+
__iter1: Iterable[_T],
|
222 |
+
__iter2: Iterable[_T],
|
223 |
+
__iter3: Iterable[_T],
|
224 |
+
*iterables: Iterable[_T],
|
225 |
+
) -> Iterator[tuple[_T, ...]]: ...
|
226 |
+
@overload
|
227 |
+
def zip_offset(
|
228 |
+
__iter1: Iterable[_T1],
|
229 |
+
*,
|
230 |
+
offsets: _SizedIterable[int],
|
231 |
+
longest: bool = ...,
|
232 |
+
fillvalue: None = None,
|
233 |
+
) -> Iterator[tuple[_T1 | None]]: ...
|
234 |
+
@overload
|
235 |
+
def zip_offset(
|
236 |
+
__iter1: Iterable[_T1],
|
237 |
+
__iter2: Iterable[_T2],
|
238 |
+
*,
|
239 |
+
offsets: _SizedIterable[int],
|
240 |
+
longest: bool = ...,
|
241 |
+
fillvalue: None = None,
|
242 |
+
) -> Iterator[tuple[_T1 | None, _T2 | None]]: ...
|
243 |
+
@overload
|
244 |
+
def zip_offset(
|
245 |
+
__iter1: Iterable[_T],
|
246 |
+
__iter2: Iterable[_T],
|
247 |
+
__iter3: Iterable[_T],
|
248 |
+
*iterables: Iterable[_T],
|
249 |
+
offsets: _SizedIterable[int],
|
250 |
+
longest: bool = ...,
|
251 |
+
fillvalue: None = None,
|
252 |
+
) -> Iterator[tuple[_T | None, ...]]: ...
|
253 |
+
@overload
|
254 |
+
def zip_offset(
|
255 |
+
__iter1: Iterable[_T1],
|
256 |
+
*,
|
257 |
+
offsets: _SizedIterable[int],
|
258 |
+
longest: bool = ...,
|
259 |
+
fillvalue: _U,
|
260 |
+
) -> Iterator[tuple[_T1 | _U]]: ...
|
261 |
+
@overload
|
262 |
+
def zip_offset(
|
263 |
+
__iter1: Iterable[_T1],
|
264 |
+
__iter2: Iterable[_T2],
|
265 |
+
*,
|
266 |
+
offsets: _SizedIterable[int],
|
267 |
+
longest: bool = ...,
|
268 |
+
fillvalue: _U,
|
269 |
+
) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ...
|
270 |
+
@overload
|
271 |
+
def zip_offset(
|
272 |
+
__iter1: Iterable[_T],
|
273 |
+
__iter2: Iterable[_T],
|
274 |
+
__iter3: Iterable[_T],
|
275 |
+
*iterables: Iterable[_T],
|
276 |
+
offsets: _SizedIterable[int],
|
277 |
+
longest: bool = ...,
|
278 |
+
fillvalue: _U,
|
279 |
+
) -> Iterator[tuple[_T | _U, ...]]: ...
|
280 |
+
def sort_together(
|
281 |
+
iterables: Iterable[Iterable[_T]],
|
282 |
+
key_list: Iterable[int] = ...,
|
283 |
+
key: Callable[..., Any] | None = ...,
|
284 |
+
reverse: bool = ...,
|
285 |
+
) -> list[tuple[_T, ...]]: ...
|
286 |
+
def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ...
|
287 |
+
def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
|
288 |
+
def always_iterable(
|
289 |
+
obj: object,
|
290 |
+
base_type: type | tuple[type | tuple[Any, ...], ...] | None = ...,
|
291 |
+
) -> Iterator[Any]: ...
|
292 |
+
def adjacent(
|
293 |
+
predicate: Callable[[_T], bool],
|
294 |
+
iterable: Iterable[_T],
|
295 |
+
distance: int = ...,
|
296 |
+
) -> Iterator[tuple[bool, _T]]: ...
|
297 |
+
@overload
|
298 |
+
def groupby_transform(
|
299 |
+
iterable: Iterable[_T],
|
300 |
+
keyfunc: None = None,
|
301 |
+
valuefunc: None = None,
|
302 |
+
reducefunc: None = None,
|
303 |
+
) -> Iterator[tuple[_T, Iterator[_T]]]: ...
|
304 |
+
@overload
|
305 |
+
def groupby_transform(
|
306 |
+
iterable: Iterable[_T],
|
307 |
+
keyfunc: Callable[[_T], _U],
|
308 |
+
valuefunc: None,
|
309 |
+
reducefunc: None,
|
310 |
+
) -> Iterator[tuple[_U, Iterator[_T]]]: ...
|
311 |
+
@overload
|
312 |
+
def groupby_transform(
|
313 |
+
iterable: Iterable[_T],
|
314 |
+
keyfunc: None,
|
315 |
+
valuefunc: Callable[[_T], _V],
|
316 |
+
reducefunc: None,
|
317 |
+
) -> Iterable[tuple[_T, Iterable[_V]]]: ...
|
318 |
+
@overload
|
319 |
+
def groupby_transform(
|
320 |
+
iterable: Iterable[_T],
|
321 |
+
keyfunc: Callable[[_T], _U],
|
322 |
+
valuefunc: Callable[[_T], _V],
|
323 |
+
reducefunc: None,
|
324 |
+
) -> Iterable[tuple[_U, Iterator[_V]]]: ...
|
325 |
+
@overload
|
326 |
+
def groupby_transform(
|
327 |
+
iterable: Iterable[_T],
|
328 |
+
keyfunc: None,
|
329 |
+
valuefunc: None,
|
330 |
+
reducefunc: Callable[[Iterator[_T]], _W],
|
331 |
+
) -> Iterable[tuple[_T, _W]]: ...
|
332 |
+
@overload
|
333 |
+
def groupby_transform(
|
334 |
+
iterable: Iterable[_T],
|
335 |
+
keyfunc: Callable[[_T], _U],
|
336 |
+
valuefunc: None,
|
337 |
+
reducefunc: Callable[[Iterator[_T]], _W],
|
338 |
+
) -> Iterable[tuple[_U, _W]]: ...
|
339 |
+
@overload
|
340 |
+
def groupby_transform(
|
341 |
+
iterable: Iterable[_T],
|
342 |
+
keyfunc: None,
|
343 |
+
valuefunc: Callable[[_T], _V],
|
344 |
+
reducefunc: Callable[[Iterable[_V]], _W],
|
345 |
+
) -> Iterable[tuple[_T, _W]]: ...
|
346 |
+
@overload
|
347 |
+
def groupby_transform(
|
348 |
+
iterable: Iterable[_T],
|
349 |
+
keyfunc: Callable[[_T], _U],
|
350 |
+
valuefunc: Callable[[_T], _V],
|
351 |
+
reducefunc: Callable[[Iterable[_V]], _W],
|
352 |
+
) -> Iterable[tuple[_U, _W]]: ...
|
353 |
+
|
354 |
+
class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]):
|
355 |
+
@overload
|
356 |
+
def __init__(self, __stop: _T) -> None: ...
|
357 |
+
@overload
|
358 |
+
def __init__(self, __start: _T, __stop: _T) -> None: ...
|
359 |
+
@overload
|
360 |
+
def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ...
|
361 |
+
def __bool__(self) -> bool: ...
|
362 |
+
def __contains__(self, elem: object) -> bool: ...
|
363 |
+
def __eq__(self, other: object) -> bool: ...
|
364 |
+
@overload
|
365 |
+
def __getitem__(self, key: int) -> _T: ...
|
366 |
+
@overload
|
367 |
+
def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ...
|
368 |
+
def __hash__(self) -> int: ...
|
369 |
+
def __iter__(self) -> Iterator[_T]: ...
|
370 |
+
def __len__(self) -> int: ...
|
371 |
+
def __reduce__(
|
372 |
+
self,
|
373 |
+
) -> tuple[Type[numeric_range[_T, _U]], tuple[_T, _T, _U]]: ...
|
374 |
+
def __repr__(self) -> str: ...
|
375 |
+
def __reversed__(self) -> Iterator[_T]: ...
|
376 |
+
def count(self, value: _T) -> int: ...
|
377 |
+
def index(self, value: _T) -> int: ... # type: ignore
|
378 |
+
|
379 |
+
def count_cycle(
|
380 |
+
iterable: Iterable[_T], n: int | None = ...
|
381 |
+
) -> Iterable[tuple[int, _T]]: ...
|
382 |
+
def mark_ends(
|
383 |
+
iterable: Iterable[_T],
|
384 |
+
) -> Iterable[tuple[bool, bool, _T]]: ...
|
385 |
+
def locate(
|
386 |
+
iterable: Iterable[object],
|
387 |
+
pred: Callable[..., Any] = ...,
|
388 |
+
window_size: int | None = ...,
|
389 |
+
) -> Iterator[int]: ...
|
390 |
+
def lstrip(
|
391 |
+
iterable: Iterable[_T], pred: Callable[[_T], object]
|
392 |
+
) -> Iterator[_T]: ...
|
393 |
+
def rstrip(
|
394 |
+
iterable: Iterable[_T], pred: Callable[[_T], object]
|
395 |
+
) -> Iterator[_T]: ...
|
396 |
+
def strip(
|
397 |
+
iterable: Iterable[_T], pred: Callable[[_T], object]
|
398 |
+
) -> Iterator[_T]: ...
|
399 |
+
|
400 |
+
class islice_extended(Generic[_T], Iterator[_T]):
|
401 |
+
def __init__(self, iterable: Iterable[_T], *args: int | None) -> None: ...
|
402 |
+
def __iter__(self) -> islice_extended[_T]: ...
|
403 |
+
def __next__(self) -> _T: ...
|
404 |
+
def __getitem__(self, index: slice) -> islice_extended[_T]: ...
|
405 |
+
|
406 |
+
def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ...
|
407 |
+
def consecutive_groups(
|
408 |
+
iterable: Iterable[_T], ordering: Callable[[_T], int] = ...
|
409 |
+
) -> Iterator[Iterator[_T]]: ...
|
410 |
+
@overload
|
411 |
+
def difference(
|
412 |
+
iterable: Iterable[_T],
|
413 |
+
func: Callable[[_T, _T], _U] = ...,
|
414 |
+
*,
|
415 |
+
initial: None = ...,
|
416 |
+
) -> Iterator[_T | _U]: ...
|
417 |
+
@overload
|
418 |
+
def difference(
|
419 |
+
iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U
|
420 |
+
) -> Iterator[_U]: ...
|
421 |
+
|
422 |
+
class SequenceView(Generic[_T], Sequence[_T]):
|
423 |
+
def __init__(self, target: Sequence[_T]) -> None: ...
|
424 |
+
@overload
|
425 |
+
def __getitem__(self, index: int) -> _T: ...
|
426 |
+
@overload
|
427 |
+
def __getitem__(self, index: slice) -> Sequence[_T]: ...
|
428 |
+
def __len__(self) -> int: ...
|
429 |
+
|
430 |
+
class seekable(Generic[_T], Iterator[_T]):
|
431 |
+
def __init__(
|
432 |
+
self, iterable: Iterable[_T], maxlen: int | None = ...
|
433 |
+
) -> None: ...
|
434 |
+
def __iter__(self) -> seekable[_T]: ...
|
435 |
+
def __next__(self) -> _T: ...
|
436 |
+
def __bool__(self) -> bool: ...
|
437 |
+
@overload
|
438 |
+
def peek(self) -> _T: ...
|
439 |
+
@overload
|
440 |
+
def peek(self, default: _U) -> _T | _U: ...
|
441 |
+
def elements(self) -> SequenceView[_T]: ...
|
442 |
+
def seek(self, index: int) -> None: ...
|
443 |
+
def relative_seek(self, count: int) -> None: ...
|
444 |
+
|
445 |
+
class run_length:
|
446 |
+
@staticmethod
|
447 |
+
def encode(iterable: Iterable[_T]) -> Iterator[tuple[_T, int]]: ...
|
448 |
+
@staticmethod
|
449 |
+
def decode(iterable: Iterable[tuple[_T, int]]) -> Iterator[_T]: ...
|
450 |
+
|
451 |
+
def exactly_n(
|
452 |
+
iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ...
|
453 |
+
) -> bool: ...
|
454 |
+
def circular_shifts(iterable: Iterable[_T]) -> list[tuple[_T, ...]]: ...
|
455 |
+
def make_decorator(
|
456 |
+
wrapping_func: Callable[..., _U], result_index: int = ...
|
457 |
+
) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ...
|
458 |
+
@overload
|
459 |
+
def map_reduce(
|
460 |
+
iterable: Iterable[_T],
|
461 |
+
keyfunc: Callable[[_T], _U],
|
462 |
+
valuefunc: None = ...,
|
463 |
+
reducefunc: None = ...,
|
464 |
+
) -> dict[_U, list[_T]]: ...
|
465 |
+
@overload
|
466 |
+
def map_reduce(
|
467 |
+
iterable: Iterable[_T],
|
468 |
+
keyfunc: Callable[[_T], _U],
|
469 |
+
valuefunc: Callable[[_T], _V],
|
470 |
+
reducefunc: None = ...,
|
471 |
+
) -> dict[_U, list[_V]]: ...
|
472 |
+
@overload
|
473 |
+
def map_reduce(
|
474 |
+
iterable: Iterable[_T],
|
475 |
+
keyfunc: Callable[[_T], _U],
|
476 |
+
valuefunc: None = ...,
|
477 |
+
reducefunc: Callable[[list[_T]], _W] = ...,
|
478 |
+
) -> dict[_U, _W]: ...
|
479 |
+
@overload
|
480 |
+
def map_reduce(
|
481 |
+
iterable: Iterable[_T],
|
482 |
+
keyfunc: Callable[[_T], _U],
|
483 |
+
valuefunc: Callable[[_T], _V],
|
484 |
+
reducefunc: Callable[[list[_V]], _W],
|
485 |
+
) -> dict[_U, _W]: ...
|
486 |
+
def rlocate(
|
487 |
+
iterable: Iterable[_T],
|
488 |
+
pred: Callable[..., object] = ...,
|
489 |
+
window_size: int | None = ...,
|
490 |
+
) -> Iterator[int]: ...
|
491 |
+
def replace(
|
492 |
+
iterable: Iterable[_T],
|
493 |
+
pred: Callable[..., object],
|
494 |
+
substitutes: Iterable[_U],
|
495 |
+
count: int | None = ...,
|
496 |
+
window_size: int = ...,
|
497 |
+
) -> Iterator[_T | _U]: ...
|
498 |
+
def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ...
|
499 |
+
def set_partitions(
|
500 |
+
iterable: Iterable[_T], k: int | None = ...
|
501 |
+
) -> Iterator[list[list[_T]]]: ...
|
502 |
+
|
503 |
+
class time_limited(Generic[_T], Iterator[_T]):
    # Iterator wrapper that stops yielding after *limit_seconds* have elapsed.
    def __init__(
        self, limit_seconds: float, iterable: Iterable[_T]
    ) -> None: ...
    # NOTE(review): stub declares __iter__ -> islice_extended[_T], not
    # time_limited[_T] — unusual for an Iterator subclass; confirm against
    # the runtime implementation.
    def __iter__(self) -> islice_extended[_T]: ...
    def __next__(self) -> _T: ...
|
509 |
+
|
510 |
+
@overload
|
511 |
+
def only(
|
512 |
+
iterable: Iterable[_T], *, too_long: _Raisable | None = ...
|
513 |
+
) -> _T | None: ...
|
514 |
+
@overload
|
515 |
+
def only(
|
516 |
+
iterable: Iterable[_T], default: _U, too_long: _Raisable | None = ...
|
517 |
+
) -> _T | _U: ...
|
518 |
+
def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ...
|
519 |
+
def distinct_combinations(
|
520 |
+
iterable: Iterable[_T], r: int
|
521 |
+
) -> Iterator[tuple[_T, ...]]: ...
|
522 |
+
def filter_except(
|
523 |
+
validator: Callable[[Any], object],
|
524 |
+
iterable: Iterable[_T],
|
525 |
+
*exceptions: Type[BaseException],
|
526 |
+
) -> Iterator[_T]: ...
|
527 |
+
def map_except(
|
528 |
+
function: Callable[[Any], _U],
|
529 |
+
iterable: Iterable[_T],
|
530 |
+
*exceptions: Type[BaseException],
|
531 |
+
) -> Iterator[_U]: ...
|
532 |
+
def map_if(
|
533 |
+
iterable: Iterable[Any],
|
534 |
+
pred: Callable[[Any], bool],
|
535 |
+
func: Callable[[Any], Any],
|
536 |
+
func_else: Callable[[Any], Any] | None = ...,
|
537 |
+
) -> Iterator[Any]: ...
|
538 |
+
def sample(
|
539 |
+
iterable: Iterable[_T],
|
540 |
+
k: int,
|
541 |
+
weights: Iterable[float] | None = ...,
|
542 |
+
) -> list[_T]: ...
|
543 |
+
def is_sorted(
|
544 |
+
iterable: Iterable[_T],
|
545 |
+
key: Callable[[_T], _U] | None = ...,
|
546 |
+
reverse: bool = False,
|
547 |
+
strict: bool = False,
|
548 |
+
) -> bool: ...
|
549 |
+
|
550 |
+
class AbortThread(BaseException):
    # Derives from BaseException rather than Exception — presumably so that
    # broad ``except Exception`` handlers in user callbacks do not swallow
    # it; confirm against the implementation module.
    pass
|
552 |
+
|
553 |
+
class callback_iter(Generic[_T], Iterator[_T]):
    # Adapter that turns a callback-style *func* into an iterator; also a
    # context manager (its __exit__ may suppress exceptions -> bool | None).
    def __init__(
        self,
        func: Callable[..., Any],
        callback_kwd: str = ...,
        wait_seconds: float = ...,
    ) -> None: ...
    def __enter__(self) -> callback_iter[_T]: ...
    def __exit__(
        self,
        exc_type: Type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> bool | None: ...
    def __iter__(self) -> callback_iter[_T]: ...
    def __next__(self) -> _T: ...
    # Internal generator draining the callback queue (leading underscore:
    # not part of the public API).
    def _reader(self) -> Iterator[_T]: ...
    @property
    def done(self) -> bool: ...
    @property
    def result(self) -> Any: ...
|
574 |
+
|
575 |
+
def windowed_complete(
|
576 |
+
iterable: Iterable[_T], n: int
|
577 |
+
) -> Iterator[tuple[_T, ...]]: ...
|
578 |
+
def all_unique(
|
579 |
+
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
|
580 |
+
) -> bool: ...
|
581 |
+
def nth_product(index: int, *args: Iterable[_T]) -> tuple[_T, ...]: ...
|
582 |
+
def nth_combination_with_replacement(
|
583 |
+
iterable: Iterable[_T], r: int, index: int
|
584 |
+
) -> tuple[_T, ...]: ...
|
585 |
+
def nth_permutation(
|
586 |
+
iterable: Iterable[_T], r: int, index: int
|
587 |
+
) -> tuple[_T, ...]: ...
|
588 |
+
def value_chain(*args: _T | Iterable[_T]) -> Iterable[_T]: ...
|
589 |
+
def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
|
590 |
+
def combination_index(
|
591 |
+
element: Iterable[_T], iterable: Iterable[_T]
|
592 |
+
) -> int: ...
|
593 |
+
def combination_with_replacement_index(
|
594 |
+
element: Iterable[_T], iterable: Iterable[_T]
|
595 |
+
) -> int: ...
|
596 |
+
def permutation_index(
|
597 |
+
element: Iterable[_T], iterable: Iterable[_T]
|
598 |
+
) -> int: ...
|
599 |
+
def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ...
|
600 |
+
|
601 |
+
class countable(Generic[_T], Iterator[_T]):
    # Pass-through iterator wrapper (counts consumed items at runtime).
    def __init__(self, iterable: Iterable[_T]) -> None: ...
    def __iter__(self) -> countable[_T]: ...
    def __next__(self) -> _T: ...
|
605 |
+
|
606 |
+
def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ...
|
607 |
+
def zip_broadcast(
|
608 |
+
*objects: _T | Iterable[_T],
|
609 |
+
scalar_types: type | tuple[type | tuple[Any, ...], ...] | None = ...,
|
610 |
+
strict: bool = ...,
|
611 |
+
) -> Iterable[tuple[_T, ...]]: ...
|
612 |
+
def unique_in_window(
|
613 |
+
iterable: Iterable[_T], n: int, key: Callable[[_T], _U] | None = ...
|
614 |
+
) -> Iterator[_T]: ...
|
615 |
+
def duplicates_everseen(
|
616 |
+
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
|
617 |
+
) -> Iterator[_T]: ...
|
618 |
+
def duplicates_justseen(
|
619 |
+
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
|
620 |
+
) -> Iterator[_T]: ...
|
621 |
+
|
622 |
+
class _SupportsLessThan(Protocol):
    # Structural type: anything orderable with ``<`` (used by minmax below).
    def __lt__(self, __other: Any) -> bool: ...

# TypeVar bound to the protocol so minmax can return the input's own type.
_SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan)
|
626 |
+
|
627 |
+
@overload
|
628 |
+
def minmax(
|
629 |
+
iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None
|
630 |
+
) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
|
631 |
+
@overload
|
632 |
+
def minmax(
|
633 |
+
iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan]
|
634 |
+
) -> tuple[_T, _T]: ...
|
635 |
+
@overload
|
636 |
+
def minmax(
|
637 |
+
iterable_or_value: Iterable[_SupportsLessThanT],
|
638 |
+
*,
|
639 |
+
key: None = None,
|
640 |
+
default: _U,
|
641 |
+
) -> _U | tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
|
642 |
+
@overload
|
643 |
+
def minmax(
|
644 |
+
iterable_or_value: Iterable[_T],
|
645 |
+
*,
|
646 |
+
key: Callable[[_T], _SupportsLessThan],
|
647 |
+
default: _U,
|
648 |
+
) -> _U | tuple[_T, _T]: ...
|
649 |
+
@overload
|
650 |
+
def minmax(
|
651 |
+
iterable_or_value: _SupportsLessThanT,
|
652 |
+
__other: _SupportsLessThanT,
|
653 |
+
*others: _SupportsLessThanT,
|
654 |
+
) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
|
655 |
+
@overload
|
656 |
+
def minmax(
|
657 |
+
iterable_or_value: _T,
|
658 |
+
__other: _T,
|
659 |
+
*others: _T,
|
660 |
+
key: Callable[[_T], _SupportsLessThan],
|
661 |
+
) -> tuple[_T, _T]: ...
|
662 |
+
def longest_common_prefix(
|
663 |
+
iterables: Iterable[Iterable[_T]],
|
664 |
+
) -> Iterator[_T]: ...
|
665 |
+
def iequals(*iterables: Iterable[object]) -> bool: ...
|
666 |
+
def constrained_batches(
|
667 |
+
iterable: Iterable[object],
|
668 |
+
max_size: int,
|
669 |
+
max_count: int | None = ...,
|
670 |
+
get_len: Callable[[_T], object] = ...,
|
671 |
+
strict: bool = ...,
|
672 |
+
) -> Iterator[tuple[_T]]: ...
|
673 |
+
def gray_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
|
674 |
+
def partial_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
|
675 |
+
def takewhile_inclusive(
|
676 |
+
predicate: Callable[[_T], bool], iterable: Iterable[_T]
|
677 |
+
) -> Iterator[_T]: ...
|
678 |
+
def outer_product(
|
679 |
+
func: Callable[[_T, _U], _V],
|
680 |
+
xs: Iterable[_T],
|
681 |
+
ys: Iterable[_U],
|
682 |
+
*args: Any,
|
683 |
+
**kwargs: Any,
|
684 |
+
) -> Iterator[tuple[_V, ...]]: ...
|
lib/python3.11/site-packages/more_itertools/py.typed
ADDED
File without changes
|
lib/python3.11/site-packages/more_itertools/recipes.py
ADDED
@@ -0,0 +1,977 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Imported from the recipes section of the itertools documentation.
|
2 |
+
|
3 |
+
All functions taken from the recipes section of the itertools library docs
|
4 |
+
[1]_.
|
5 |
+
Some backward-compatible usability improvements have been made.
|
6 |
+
|
7 |
+
.. [1] http://docs.python.org/library/itertools.html#recipes
|
8 |
+
|
9 |
+
"""
|
10 |
+
import math
|
11 |
+
import operator
|
12 |
+
|
13 |
+
from collections import deque
|
14 |
+
from collections.abc import Sized
|
15 |
+
from functools import partial, reduce
|
16 |
+
from itertools import (
|
17 |
+
chain,
|
18 |
+
combinations,
|
19 |
+
compress,
|
20 |
+
count,
|
21 |
+
cycle,
|
22 |
+
groupby,
|
23 |
+
islice,
|
24 |
+
product,
|
25 |
+
repeat,
|
26 |
+
starmap,
|
27 |
+
tee,
|
28 |
+
zip_longest,
|
29 |
+
)
|
30 |
+
from random import randrange, sample, choice
|
31 |
+
|
32 |
+
# Public names exported by this recipes module.
__all__ = [
    'all_equal',
    'batched',
    'before_and_after',
    'consume',
    'convolve',
    'dotproduct',
    'first_true',
    'factor',
    'flatten',
    'grouper',
    'iter_except',
    'iter_index',
    'matmul',
    'ncycles',
    'nth',
    'nth_combination',
    'padnone',
    'pad_none',
    'pairwise',
    'partition',
    'polynomial_eval',
    'polynomial_from_roots',
    'polynomial_derivative',
    'powerset',
    'prepend',
    'quantify',
    'random_combination_with_replacement',
    'random_combination',
    'random_permutation',
    'random_product',
    'repeatfunc',
    'roundrobin',
    'sieve',
    'sliding_window',
    'subslices',
    'sum_of_squares',
    'tabulate',
    'tail',
    'take',
    'transpose',
    'triplewise',
    'unique_everseen',
    'unique_justseen',
]

# Unique sentinel object used internally to detect "no value" (e.g. as a
# zip_longest fill value that cannot collide with user data).
_marker = object()


# zip with strict is available for Python 3.10+
# Calling zip(strict=True) with no iterables raises TypeError on older
# Pythons, which is how we detect support here.
try:
    zip(strict=True)
except TypeError:
    _zip_strict = zip
else:
    _zip_strict = partial(zip, strict=True)

# math.sumprod is available for Python 3.12+; fall back to the pure-Python
# dotproduct defined later in this module (resolved lazily at call time).
_sumprod = getattr(math, 'sumprod', lambda x, y: dotproduct(x, y))
|
91 |
+
|
92 |
+
|
93 |
+
def take(n, iterable):
    """Collect the first *n* items of *iterable* into a list.

    >>> take(3, range(10))
    [0, 1, 2]

    If the iterable holds fewer than *n* items, every available item is
    returned:

    >>> take(10, range(3))
    [0, 1, 2]

    """
    # islice stops early on short inputs, so no explicit length check needed.
    first_n = islice(iterable, n)
    return list(first_n)
|
107 |
+
|
108 |
+
|
109 |
+
def tabulate(function, start=0):
    """Lazily yield ``function(start)``, ``function(start + 1)``, ...

    *function* should accept a single integer argument.  *start* defaults
    to 0 and is incremented on every step, without bound.

    >>> square = lambda x: x ** 2
    >>> iterator = tabulate(square, -3)
    >>> take(4, iterator)
    [9, 4, 1, 0]

    """
    for argument in count(start):
        yield function(argument)
|
125 |
+
|
126 |
+
|
127 |
+
def tail(n, iterable):
    """Yield the final *n* items of *iterable*.

    >>> t = tail(3, 'ABCDEFG')
    >>> list(t)
    ['E', 'F', 'G']

    """
    # A sized input lets us slice straight to the end; anything else is
    # funnelled through a bounded deque, which keeps only the last n items.
    # Non-iterable arguments make islice/deque raise TypeError, so no
    # explicit Iterable check is performed here.
    if isinstance(iterable, Sized):
        start = max(0, len(iterable) - n)
        yield from islice(iterable, start, None)
    else:
        yield from deque(iterable, maxlen=n)
|
143 |
+
|
144 |
+
|
145 |
+
def consume(iterator, n=None):
    """Advance *iterator* by *n* steps; with *n* of ``None``, drain it fully.

    Efficiently exhausts an iterator without returning values.

    >>> i = (x for x in range(10))
    >>> next(i)
    0
    >>> consume(i, 3)
    >>> next(i)
    4
    >>> consume(i)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    If fewer than *n* items remain, the iterator is simply exhausted.

    >>> i = (x for x in range(3))
    >>> consume(i, 5)
    >>> next(i)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    StopIteration

    """
    # Both branches consume at C speed rather than with a Python-level loop.
    if n is not None:
        # Step to the empty slice that begins at position n.
        next(islice(iterator, n, n), None)
        return
    # A zero-length deque discards everything fed into it.
    deque(iterator, maxlen=0)
|
183 |
+
|
184 |
+
|
185 |
+
def nth(iterable, n, default=None):
    """Return item *n* of *iterable*, or *default* if it is too short.

    >>> l = range(10)
    >>> nth(l, 3)
    3
    >>> nth(l, 20, "zebra")
    'zebra'

    """
    rest = islice(iterable, n, None)
    return next(rest, default)
|
196 |
+
|
197 |
+
|
198 |
+
def all_equal(iterable):
    """
    Return ``True`` when every element equals every other (or when empty).

    >>> all_equal('aaaa')
    True
    >>> all_equal('aaab')
    False

    """
    # groupby collapses equal runs: a uniform (or empty) input produces at
    # most one group, so a second group proves inequality.
    groups = groupby(iterable)
    return bool(next(groups, True)) and not next(groups, False)
|
210 |
+
|
211 |
+
|
212 |
+
def quantify(iterable, pred=bool):
    """Count how many times *pred* is true over *iterable*.

    >>> quantify([True, False, True])
    2

    """
    # Sums pred's return values directly (True counts as 1, False as 0),
    # matching the classic itertools recipe.
    return sum(pred(item) for item in iterable)
|
220 |
+
|
221 |
+
|
222 |
+
def pad_none(iterable):
    """Yield the items of *iterable*, then ``None`` forever.

    >>> take(5, pad_none(range(3)))
    [0, 1, 2, None, None]

    Useful for emulating the behavior of the built-in :func:`map` function.

    See also :func:`padded`.

    """
    yield from iterable
    yield from repeat(None)


# Backward-compatible alias for the historical spelling.
padnone = pad_none
|
237 |
+
|
238 |
+
|
239 |
+
def ncycles(iterable, n):
    """Repeat the full sequence of *iterable* *n* times.

    >>> list(ncycles(["a", "b"], 3))
    ['a', 'b', 'a', 'b', 'a', 'b']

    """
    # Materialize once so one-shot iterators can be replayed.
    saved = tuple(iterable)
    return chain.from_iterable(repeat(saved, n))
|
247 |
+
|
248 |
+
|
249 |
+
def dotproduct(vec1, vec2):
    """Compute the dot product of two iterables (stops at the shorter one).

    >>> dotproduct([10, 10], [20, 20])
    400

    """
    return sum(a * b for a, b in zip(vec1, vec2))
|
257 |
+
|
258 |
+
|
259 |
+
def flatten(listOfLists):
    """Flatten exactly one level of nesting.

    >>> list(flatten([[0, 1], [2, 3]]))
    [0, 1, 2, 3]

    See also :func:`collapse`, which can flatten multiple levels of nesting.

    """
    for sublist in listOfLists:
        yield from sublist
|
269 |
+
|
270 |
+
|
271 |
+
def repeatfunc(func, times=None, *args):
    """Call ``func(*args)`` repeatedly, yielding each result.

    With *times* given, stop after that many calls:

    >>> from operator import add
    >>> times = 4
    >>> args = 3, 5
    >>> list(repeatfunc(add, times, *args))
    [8, 8, 8, 8]

    With *times* of ``None``, never stop:

    >>> from random import randrange
    >>> times = None
    >>> args = 1, 11
    >>> take(6, repeatfunc(randrange, times, *args))  # doctest:+SKIP
    [2, 4, 8, 1, 8, 4]

    """
    # repeat() without a count is infinite; starmap unpacks args each call.
    arg_stream = repeat(args) if times is None else repeat(args, times)
    return starmap(func, arg_stream)
|
296 |
+
|
297 |
+
|
298 |
+
def _pairwise(iterable):
|
299 |
+
"""Returns an iterator of paired items, overlapping, from the original
|
300 |
+
|
301 |
+
>>> take(4, pairwise(count()))
|
302 |
+
[(0, 1), (1, 2), (2, 3), (3, 4)]
|
303 |
+
|
304 |
+
On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
|
305 |
+
|
306 |
+
"""
|
307 |
+
a, b = tee(iterable)
|
308 |
+
next(b, None)
|
309 |
+
return zip(a, b)
|
310 |
+
|
311 |
+
|
312 |
+
# Prefer the C implementation added to itertools in Python 3.10; fall back
# to the pure-Python _pairwise above on older interpreters.
try:
    from itertools import pairwise as itertools_pairwise
except ImportError:
    pairwise = _pairwise
else:

    def pairwise(iterable):
        # Thin wrapper around the C version — presumably kept so the module
        # exports an ordinary Python function on every version; confirm.
        return itertools_pairwise(iterable)

    # Reuse the documented contract from the pure-Python implementation.
    pairwise.__doc__ = _pairwise.__doc__
|
322 |
+
|
323 |
+
|
324 |
+
class UnequalIterablesError(ValueError):
    """Raised by strict-zipping helpers when input lengths differ."""

    def __init__(self, details=None):
        msg = 'Iterables have different lengths'
        # *details*, when given, is a (first_length, index, length) triple
        # identifying which input disagreed with the first one.
        if details is not None:
            msg += (': index 0 has length {}; index {} has length {}').format(
                *details
            )

        super().__init__(msg)
|
333 |
+
|
334 |
+
|
335 |
+
def _zip_equal_generator(iterables):
    # Lazy fallback for _zip_equal: zip the inputs, raising as soon as any
    # one of them is exhausted before the others.  Identity comparison
    # against the module sentinel avoids false positives from __eq__.
    for values in zip_longest(*iterables, fillvalue=_marker):
        if any(value is _marker for value in values):
            raise UnequalIterablesError()
        yield values
|
341 |
+
|
342 |
+
|
343 |
+
def _zip_equal(*iterables):
    """Zip *iterables*, raising UnequalIterablesError on a length mismatch."""
    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                raise UnequalIterablesError(details=(first_size, i, size))
        # All sizes are equal, we can use the built-in zip.
        return zip(*iterables)
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)
|
357 |
+
|
358 |
+
|
359 |
+
def grouper(iterable, n, incomplete='fill', fillvalue=None):
    """Break *iterable* into fixed-width groups of *n* items.

    >>> list(grouper('ABCDEF', 3))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    *incomplete* selects what happens when the length is not a multiple of
    *n*.  With ``'fill'`` the final group is padded with *fillvalue*:

    >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]

    With ``'ignore'`` the short final group is dropped:

    >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
    [('A', 'B', 'C'), ('D', 'E', 'F')]

    With ``'strict'`` a subclass of ``ValueError`` is raised:

    >>> it = grouper('ABCDEFG', 3, incomplete='strict')
    >>> list(it)  # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    UnequalIterablesError

    """
    # n references to one iterator: each zip step pulls n successive items.
    iterators = [iter(iterable)] * n
    if incomplete == 'fill':
        return zip_longest(*iterators, fillvalue=fillvalue)
    if incomplete == 'strict':
        return _zip_equal(*iterators)
    if incomplete == 'ignore':
        return zip(*iterators)
    raise ValueError('Expected fill, strict, or ignore')
|
397 |
+
|
398 |
+
|
399 |
+
def roundrobin(*iterables):
    """Yield an item from each iterable in turn, alternating between them.

    >>> list(roundrobin('ABC', 'D', 'EF'))
    ['A', 'D', 'E', 'B', 'F', 'C']

    This function produces the same output as :func:`interleave_longest`, but
    may perform better for some inputs (in particular when the number of
    iterables is small).

    """
    # Recipe credited to George Sakkis.
    pending = len(iterables)
    # Cycle over each iterable's bound __next__ method so each pass around
    # the cycle yields one item from every still-active iterable.
    nexts = cycle(iter(it).__next__ for it in iterables)
    while pending:
        try:
            # Renamed from `next` to avoid shadowing the builtin next().
            for get_next in nexts:
                yield get_next()
        except StopIteration:
            # One iterable ran dry: rebuild the cycle without its slot.
            pending -= 1
            nexts = cycle(islice(nexts, pending))
|
420 |
+
|
421 |
+
|
422 |
+
def partition(pred, iterable):
    """
    Split *iterable* into a 2-tuple of iterables by *pred*.

    The first yields items where ``pred(item) == False``; the second yields
    items where ``pred(item) == True``.

    >>> is_odd = lambda x: x % 2 != 0
    >>> iterable = range(10)
    >>> even_items, odd_items = partition(is_odd, iterable)
    >>> list(even_items), list(odd_items)
    ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])

    If *pred* is None, :func:`bool` is used.

    >>> iterable = [0, 1, False, True, '', ' ']
    >>> false_items, true_items = partition(None, iterable)
    >>> list(false_items), list(true_items)
    ([0, False, ''], [1, True, ' '])

    """
    if pred is None:
        pred = bool

    # Three copies: one per output stream plus one to evaluate pred once
    # per item (its results are themselves tee'd for both compress calls).
    stream_false, stream_true, to_test = tee(iterable, 3)
    flags_false, flags_true = tee(map(pred, to_test))
    falses = compress(stream_false, map(operator.not_, flags_false))
    trues = compress(stream_true, flags_true)
    return falses, trues
|
448 |
+
|
449 |
+
|
450 |
+
def powerset(iterable):
    """Yield every subset of *iterable*, smallest first.

    >>> list(powerset([1, 2, 3]))
    [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]

    Repeated input elements produce repeated subsets — :func:`powerset`
    does not deduplicate.  Apply :func:`unique_everseen` first if needed:

    >>> seq = [1, 1, 0]
    >>> list(powerset(seq))
    [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
    >>> from more_itertools import unique_everseen
    >>> list(powerset(unique_everseen(seq)))
    [(), (1,), (0,), (1, 0)]

    """
    pool = list(iterable)
    subset_sizes = range(len(pool) + 1)
    return chain.from_iterable(combinations(pool, size) for size in subset_sizes)
|
471 |
+
|
472 |
+
|
473 |
+
def unique_everseen(iterable, key=None):
    """
    Yield each distinct element once, in first-seen order.

    >>> list(unique_everseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D']
    >>> list(unique_everseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'D']

    Hashable and unhashable items may be mixed; unhashable items fall back
    to a linear scan, making the function O(n^2) for them.

    ``list`` objects are unhashable — pass ``key=tuple`` to regain the fast
    path:

    >>> iterable = ([1, 2], [2, 3], [1, 2])
    >>> list(unique_everseen(iterable))  # Slow
    [[1, 2], [2, 3]]
    >>> list(unique_everseen(iterable, key=tuple))  # Faster
    [[1, 2], [2, 3]]

    Similarly, ``key=frozenset`` works for sets, and
    ``key=lambda x: frozenset(x.items())`` for dicts.

    """
    hashable_seen = set()
    unhashable_seen = []

    for element in iterable:
        marker = element if key is None else key(element)
        try:
            # Fast path: O(1) set membership for hashable markers.
            if marker not in hashable_seen:
                hashable_seen.add(marker)
                yield element
        except TypeError:
            # Unhashable marker: fall back to a linear list scan.
            if marker not in unhashable_seen:
                unhashable_seen.append(marker)
                yield element
|
516 |
+
|
517 |
+
|
518 |
+
def unique_justseen(iterable, key=None):
    """Yield elements in order, collapsing consecutive duplicates.

    >>> list(unique_justseen('AAAABBBCCDAABBB'))
    ['A', 'B', 'C', 'D', 'A', 'B']
    >>> list(unique_justseen('ABBCcAD', str.lower))
    ['A', 'B', 'C', 'A', 'D']

    """
    # groupby batches equal runs; taking the first item of each group (the
    # original element, not the key) drops the serial duplicates.
    return (next(group) for _, group in groupby(iterable, key))
|
528 |
+
|
529 |
+
|
530 |
+
def iter_except(func, exception, first=None):
    """Repeatedly call *func*, yielding results until *exception* is raised.

    Converts a call-until-exception interface into an iterator, like
    ``iter(func, sentinel)`` but keyed on an exception rather than a
    sentinel value.

    >>> l = [0, 1, 2]
    >>> list(iter_except(l.pop, IndexError))
    [2, 1, 0]

    A tuple of exceptions also works as the stopping condition:

    >>> l = [1, 2, 3, '...', 4, 5, 6]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [7, 6, 5]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    [4, 3, 2]
    >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
    []

    """
    try:
        # *first* is an optional one-time initializer called before the loop.
        if first is not None:
            yield first()
        while True:
            yield func()
    except exception:
        # The designated exception is the normal termination signal.
        pass
|
559 |
+
|
560 |
+
|
561 |
+
def first_true(iterable, default=None, pred=None):
    """
    Return the first truthy value of *iterable*, or *default* if none.

    With *pred* given, return the first item for which
    ``pred(item) == True`` instead.

    >>> first_true(range(10))
    1
    >>> first_true(range(10), pred=lambda x: x > 5)
    6
    >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
    'missing'

    """
    # filter(None, ...) keeps truthy items, so both modes share one path.
    matches = filter(pred, iterable)
    return next(matches, default)
|
579 |
+
|
580 |
+
|
581 |
+
def random_product(*args, repeat=1):
    """Pick one random item from each input iterable.

    >>> random_product('abc', range(4), 'XYZ')  # doctest:+SKIP
    ('c', 3, 'Z')

    With keyword argument *repeat*, draw that many items from each:

    >>> random_product('abcd', range(4), repeat=2)  # doctest:+SKIP
    ('a', 2, 'd', 3)

    Equivalent to drawing one element at random from
    ``itertools.product(*args, **kwarg)``.

    """
    # Materialize every input once, duplicated `repeat` times, then make
    # one independent random choice per pool.
    all_pools = [tuple(pool) for pool in args] * repeat
    return tuple(choice(pool) for pool in all_pools)
|
599 |
+
|
600 |
+
|
601 |
+
def random_permutation(iterable, r=None):
    """Return a random length-*r* permutation of *iterable*'s elements.

    *r* defaults to the full length of *iterable* when omitted or ``None``.

    >>> random_permutation(range(5))  # doctest:+SKIP
    (3, 4, 0, 1, 2)

    Equivalent to drawing one element at random from
    ``itertools.permutations(iterable, r)``.

    """
    pool = tuple(iterable)
    size = len(pool) if r is None else r
    # random.sample picks without replacement, which is exactly a permutation.
    return tuple(sample(pool, size))
|
617 |
+
|
618 |
+
|
619 |
+
def random_combination(iterable, r):
    """Return a random *r* length subsequence of the elements in *iterable*.

    >>> random_combination(range(5), 3)  # doctest:+SKIP
    (2, 3, 4)

    This equivalent to taking a random selection from
    ``itertools.combinations(iterable, r)``.

    """
    elements = tuple(iterable)
    # Choosing positions (not values) and sorting them preserves the
    # original relative order, matching itertools.combinations output.
    chosen = sorted(sample(range(len(elements)), r))
    return tuple(elements[i] for i in chosen)
|
633 |
+
|
634 |
+
|
635 |
+
def random_combination_with_replacement(iterable, r):
    """Return a random *r* length subsequence of elements in *iterable*,
    allowing individual elements to be repeated.

    >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
    (0, 0, 1, 2, 2)

    This equivalent to taking a random selection from
    ``itertools.combinations_with_replacement(iterable, r)``.

    """
    elements = tuple(iterable)
    size = len(elements)
    # Independent draws may repeat an index; sorting matches the ordering
    # that itertools.combinations_with_replacement would produce.
    picks = sorted(randrange(size) for _ in range(r))
    return tuple(elements[i] for i in picks)
|
650 |
+
|
651 |
+
|
652 |
+
def nth_combination(iterable, r, index):
    """Equivalent to ``list(combinations(iterable, r))[index]``.

    The subsequences of *iterable* that are of length *r* can be ordered
    lexicographically. :func:`nth_combination` computes the subsequence at
    sort position *index* directly, without computing the previous
    subsequences.

    >>> nth_combination(range(5), 3, 5)
    (0, 3, 4)

    ``ValueError`` will be raised If *r* is negative or greater than the length
    of *iterable*.
    ``IndexError`` will be raised if the given *index* is invalid.
    """
    pool = tuple(iterable)
    n = len(pool)
    if not 0 <= r <= n:
        raise ValueError

    # total = C(n, r), built multiplicatively over the smaller of r, n - r.
    total = 1
    k = min(r, n - r)
    for i in range(1, k + 1):
        total = total * (n - k + i) // i

    # Negative indexes count from the end, as with list indexing.
    if index < 0:
        index += total

    if not 0 <= index < total:
        raise IndexError

    # Walk the combinatorial number system: at each step decide which
    # pool element comes next by comparing index against sub-counts.
    result = []
    while r:
        total, n, r = total * r // n, n - 1, r - 1
        while index >= total:
            index -= total
            total, n = total * (n - r) // n, n - 1
        result.append(pool[-1 - n])

    return tuple(result)
|
692 |
+
|
693 |
+
|
694 |
+
def prepend(value, iterator):
    """Yield *value*, followed by the elements in *iterator*.

    >>> value = '0'
    >>> iterator = ['1', '2', '3']
    >>> list(prepend(value, iterator))
    ['0', '1', '2', '3']

    To prepend multiple values, see :func:`itertools.chain`
    or :func:`value_chain`.

    """
    yield value
    yield from iterator
|
707 |
+
|
708 |
+
|
709 |
+
def convolve(signal, kernel):
    """Convolve the iterable *signal* with the iterable *kernel*.

    >>> signal = (1, 2, 3, 4, 5)
    >>> kernel = [3, 2, 1]
    >>> list(convolve(signal, kernel))
    [3, 8, 14, 20, 26, 14, 5]

    Note: the input arguments are not interchangeable, as the *kernel*
    is immediately consumed and stored.

    """
    # This implementation intentionally doesn't match the one in the
    # itertools documentation.
    # Reverse the kernel so that a plain elementwise dot product with the
    # sliding window gives the convolution at each position.
    taps = tuple(kernel)[::-1]
    width = len(taps)
    # A zero-filled window of fixed width; append() slides it forward.
    window = deque([0] * width, maxlen=width)
    # Pad the tail with zeros so the kernel slides fully past the signal.
    for value in chain(signal, repeat(0, width - 1)):
        window.append(value)
        yield _sumprod(taps, window)
|
729 |
+
|
730 |
+
|
731 |
+
def before_and_after(predicate, it):
    """A variant of :func:`takewhile` that allows complete access to the
    remainder of the iterator.

    >>> it = iter('ABCdEfGhI')
    >>> all_upper, remainder = before_and_after(str.isupper, it)
    >>> ''.join(all_upper)
    'ABC'
    >>> ''.join(remainder)  # takewhile() would lose the 'd'
    'dEfGhI'

    Note that the first iterator must be fully consumed before the second
    iterator can generate valid results.
    """
    stream = iter(it)
    # Holds the single element that failed the predicate, so the second
    # iterator can replay it before the rest of the stream.
    crossover = []

    def leading():
        for item in stream:
            if not predicate(item):
                crossover.append(item)
                return
            yield item

    # Note: this is different from itertools recipes to allow nesting
    # before_and_after remainders into before_and_after again. See tests
    # for an example.
    trailing = chain(crossover, stream)

    return leading(), trailing
|
762 |
+
|
763 |
+
|
764 |
+
def triplewise(iterable):
|
765 |
+
"""Return overlapping triplets from *iterable*.
|
766 |
+
|
767 |
+
>>> list(triplewise('ABCDE'))
|
768 |
+
[('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
|
769 |
+
|
770 |
+
"""
|
771 |
+
for (a, _), (b, c) in pairwise(pairwise(iterable)):
|
772 |
+
yield a, b, c
|
773 |
+
|
774 |
+
|
775 |
+
def sliding_window(iterable, n):
    """Return a sliding window of width *n* over *iterable*.

    >>> list(sliding_window(range(6), 4))
    [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]

    If *iterable* has fewer than *n* items, then nothing is yielded:

    >>> list(sliding_window(range(3), 4))
    []

    For a variant with more features, see :func:`windowed`.
    """
    source = iter(iterable)
    # Pre-fill with the first n - 1 items; the loop body completes each
    # window with one more item before yielding, so short inputs yield
    # nothing.
    window = deque(islice(source, n - 1), maxlen=n)
    for item in source:
        window.append(item)
        yield tuple(window)
|
793 |
+
|
794 |
+
|
795 |
+
def subslices(iterable):
    """Return all contiguous non-empty subslices of *iterable*.

    >>> list(subslices('ABC'))
    [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]

    This is similar to :func:`substrings`, but emits items in a different
    order.
    """
    seq = list(iterable)
    # Every (start, stop) pair with start < stop names one non-empty slice;
    # combinations() emits them in lexicographic order.
    bounds = combinations(range(len(seq) + 1), 2)
    return (seq[start:stop] for start, stop in bounds)
|
807 |
+
|
808 |
+
|
809 |
+
def polynomial_from_roots(roots):
    """Compute a polynomial's coefficients from its roots.

    >>> roots = [5, -4, 3]  # (x - 5) * (x + 4) * (x - 3)
    >>> polynomial_from_roots(roots)  # x^3 - 4 * x^2 - 17 * x + 60
    [1, -4, -17, 60]
    """
    # Each root r contributes a linear factor with coefficients (1, -r);
    # convolving coefficient sequences multiplies the polynomials.
    factors = ((1, -root) for root in roots)
    return list(reduce(convolve, factors, [1]))
|
818 |
+
|
819 |
+
|
820 |
+
def iter_index(iterable, value, start=0):
    """Yield the index of each place in *iterable* that *value* occurs,
    beginning with index *start*.

    See :func:`locate` for a more general means of finding the indexes
    associated with particular values.

    >>> list(iter_index('AABCADEAF', 'A'))
    [0, 1, 4, 7]
    """
    try:
        seq_index = iterable.index
    except AttributeError:
        # Slow path for general iterables: operator.indexOf consumes the
        # stream up to the next match and raises ValueError at exhaustion.
        stream = islice(iterable, start, None)
        pos = start - 1
        try:
            while True:
                pos += operator.indexOf(stream, value) + 1
                yield pos
        except ValueError:
            pass
    else:
        # Fast path for sequences: delegate the search to the sequence's
        # own C-level index() method.
        pos = start - 1
        try:
            while True:
                pos = seq_index(value, pos + 1)
                yield pos
        except ValueError:
            pass
|
851 |
+
|
852 |
+
|
853 |
+
def sieve(n):
    """Yield the primes less than n.

    >>> list(sieve(30))
    [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
    """
    # One flag byte per integer below n, initialized 0,1,0,1,... so all
    # even candidates start cleared and only odd numbers are live.
    data = bytearray((0, 1)) * (n // 2)
    # Clear the flags for 0, 1, and 2; 2's flag is restored after sieving
    # so the even-stride strikeouts below never have to special-case it.
    data[:3] = 0, 0, 0
    limit = math.isqrt(n) + 1
    # compress() yields the indexes whose flag is still set — the primes
    # found so far.  For each, strike out its odd multiples starting at
    # p*p with stride 2p (even multiples are already clear).
    for p in compress(range(limit), data):
        data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
    # Restore 2, the only even prime.
    data[2] = 1
    # iter_index() turns the surviving flags back into prime values.
    return iter_index(data, 1) if n > 2 else iter([])
|
866 |
+
|
867 |
+
|
868 |
+
def _batched(iterable, n):
|
869 |
+
"""Batch data into lists of length *n*. The last batch may be shorter.
|
870 |
+
|
871 |
+
>>> list(batched('ABCDEFG', 3))
|
872 |
+
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
|
873 |
+
|
874 |
+
On Python 3.12 and above, this is an alias for :func:`itertools.batched`.
|
875 |
+
"""
|
876 |
+
if n < 1:
|
877 |
+
raise ValueError('n must be at least one')
|
878 |
+
it = iter(iterable)
|
879 |
+
while True:
|
880 |
+
batch = tuple(islice(it, n))
|
881 |
+
if not batch:
|
882 |
+
break
|
883 |
+
yield batch
|
884 |
+
|
885 |
+
|
886 |
+
# Prefer the C-accelerated itertools.batched (Python 3.12+); fall back to
# the pure-Python _batched above on older interpreters.
try:
    from itertools import batched as itertools_batched
except ImportError:
    batched = _batched
else:

    # Thin wrapper so the public name has the signature and docstring of
    # _batched rather than the builtin's.
    def batched(iterable, n):
        return itertools_batched(iterable, n)

    batched.__doc__ = _batched.__doc__
|
896 |
+
|
897 |
+
|
898 |
+
def transpose(it):
    """Swap the rows and columns of the input.

    >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
    [(1, 11), (2, 22), (3, 33)]

    The caller should ensure that the dimensions of the input are compatible.
    If the input is empty, no output will be produced.
    """
    # Unpack the rows and zip them into columns.  _zip_strict is defined
    # elsewhere in this module; by its name it presumably errors on ragged
    # rows (zip(..., strict=True)) — confirm at its definition.
    return _zip_strict(*it)
|
908 |
+
|
909 |
+
|
910 |
+
def matmul(m1, m2):
    """Multiply two matrices.
    >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
    [(49, 80), (41, 60)]

    The caller should ensure that the dimensions of the input matrices are
    compatible with each other.
    """
    # Each result row has one entry per column of m2.
    n = len(m2[0])
    # product(m1, transpose(m2)) pairs every row of m1 with every column of
    # m2 in row-major order; _sumprod computes each dot product, and
    # batched() regroups the flat stream of entries into rows of length n.
    return batched(starmap(_sumprod, product(m1, transpose(m2))), n)
|
920 |
+
|
921 |
+
|
922 |
+
def factor(n):
    """Yield the prime factors of n.
    >>> list(factor(360))
    [2, 2, 2, 3, 3, 5]
    """
    # Trial division by the primes up to sqrt(n).
    for prime in sieve(math.isqrt(n) + 1):
        # Divide out this prime with multiplicity.
        while n % prime == 0:
            yield prime
            n //= prime
        if n == 1:
            return
    if n > 1:
        # The remaining cofactor is itself prime (it has no factor
        # below its own square root).
        yield n
|
937 |
+
|
938 |
+
|
939 |
+
def polynomial_eval(coefficients, x):
    """Evaluate a polynomial at a specific value.

    Example: evaluating x^3 - 4 * x^2 - 17 * x + 60 at x = 2.5:

    >>> coefficients = [1, -4, -17, 60]
    >>> x = 2.5
    >>> polynomial_eval(coefficients, x)
    8.125
    """
    count = len(coefficients)
    if not count:
        return x * 0  # coerce zero to the type of x
    # Coefficients are highest-degree first, so exponents run n-1 .. 0.
    powers = (x ** exponent for exponent in reversed(range(count)))
    return _sumprod(coefficients, powers)
|
954 |
+
|
955 |
+
|
956 |
+
def sum_of_squares(it):
    """Return the sum of the squares of the input values.

    >>> sum_of_squares([10, 20, 30])
    1400
    """
    # tee() lets a one-shot iterator be consumed twice; the dot product of
    # a sequence with itself is the sum of its squares.
    return _sumprod(*tee(it))
|
963 |
+
|
964 |
+
|
965 |
+
def polynomial_derivative(coefficients):
    """Compute the first derivative of a polynomial.

    Example: evaluating the derivative of x^3 - 4 * x^2 - 17 * x + 60

    >>> coefficients = [1, -4, -17, 60]
    >>> derivative_coefficients = polynomial_derivative(coefficients)
    >>> derivative_coefficients
    [3, -8, -17]
    """
    # Power rule: d/dx of c * x^k is (c * k) * x^(k-1).  zip() stops at
    # the shorter input, which drops the constant term automatically.
    count = len(coefficients)
    exponents = reversed(range(1, count))
    return [coeff * k for coeff, k in zip(coefficients, exponents)]
|
lib/python3.11/site-packages/more_itertools/recipes.pyi
ADDED
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stubs for more_itertools.recipes"""
|
2 |
+
from __future__ import annotations
|
3 |
+
|
4 |
+
from typing import (
|
5 |
+
Any,
|
6 |
+
Callable,
|
7 |
+
Iterable,
|
8 |
+
Iterator,
|
9 |
+
overload,
|
10 |
+
Sequence,
|
11 |
+
Type,
|
12 |
+
TypeVar,
|
13 |
+
)
|
14 |
+
|
15 |
+
# Type and type variable definitions
|
16 |
+
_T = TypeVar('_T')
|
17 |
+
_U = TypeVar('_U')
|
18 |
+
|
19 |
+
def take(n: int, iterable: Iterable[_T]) -> list[_T]: ...
|
20 |
+
def tabulate(
|
21 |
+
function: Callable[[int], _T], start: int = ...
|
22 |
+
) -> Iterator[_T]: ...
|
23 |
+
def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ...
|
24 |
+
def consume(iterator: Iterable[_T], n: int | None = ...) -> None: ...
|
25 |
+
@overload
|
26 |
+
def nth(iterable: Iterable[_T], n: int) -> _T | None: ...
|
27 |
+
@overload
|
28 |
+
def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
|
29 |
+
def all_equal(iterable: Iterable[object]) -> bool: ...
|
30 |
+
def quantify(
|
31 |
+
iterable: Iterable[_T], pred: Callable[[_T], bool] = ...
|
32 |
+
) -> int: ...
|
33 |
+
def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
|
34 |
+
def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
|
35 |
+
def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ...
|
36 |
+
def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ...
|
37 |
+
def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ...
|
38 |
+
def repeatfunc(
|
39 |
+
func: Callable[..., _U], times: int | None = ..., *args: Any
|
40 |
+
) -> Iterator[_U]: ...
|
41 |
+
def pairwise(iterable: Iterable[_T]) -> Iterator[tuple[_T, _T]]: ...
|
42 |
+
def grouper(
|
43 |
+
iterable: Iterable[_T],
|
44 |
+
n: int,
|
45 |
+
incomplete: str = ...,
|
46 |
+
fillvalue: _U = ...,
|
47 |
+
) -> Iterator[tuple[_T | _U, ...]]: ...
|
48 |
+
def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ...
|
49 |
+
def partition(
|
50 |
+
pred: Callable[[_T], object] | None, iterable: Iterable[_T]
|
51 |
+
) -> tuple[Iterator[_T], Iterator[_T]]: ...
|
52 |
+
def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
|
53 |
+
def unique_everseen(
|
54 |
+
iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
|
55 |
+
) -> Iterator[_T]: ...
|
56 |
+
def unique_justseen(
|
57 |
+
iterable: Iterable[_T], key: Callable[[_T], object] | None = ...
|
58 |
+
) -> Iterator[_T]: ...
|
59 |
+
@overload
|
60 |
+
def iter_except(
|
61 |
+
func: Callable[[], _T],
|
62 |
+
exception: Type[BaseException] | tuple[Type[BaseException], ...],
|
63 |
+
first: None = ...,
|
64 |
+
) -> Iterator[_T]: ...
|
65 |
+
@overload
|
66 |
+
def iter_except(
|
67 |
+
func: Callable[[], _T],
|
68 |
+
exception: Type[BaseException] | tuple[Type[BaseException], ...],
|
69 |
+
first: Callable[[], _U],
|
70 |
+
) -> Iterator[_T | _U]: ...
|
71 |
+
@overload
|
72 |
+
def first_true(
|
73 |
+
iterable: Iterable[_T], *, pred: Callable[[_T], object] | None = ...
|
74 |
+
) -> _T | None: ...
|
75 |
+
@overload
|
76 |
+
def first_true(
|
77 |
+
iterable: Iterable[_T],
|
78 |
+
default: _U,
|
79 |
+
pred: Callable[[_T], object] | None = ...,
|
80 |
+
) -> _T | _U: ...
|
81 |
+
def random_product(
|
82 |
+
*args: Iterable[_T], repeat: int = ...
|
83 |
+
) -> tuple[_T, ...]: ...
|
84 |
+
def random_permutation(
|
85 |
+
iterable: Iterable[_T], r: int | None = ...
|
86 |
+
) -> tuple[_T, ...]: ...
|
87 |
+
def random_combination(iterable: Iterable[_T], r: int) -> tuple[_T, ...]: ...
|
88 |
+
def random_combination_with_replacement(
|
89 |
+
iterable: Iterable[_T], r: int
|
90 |
+
) -> tuple[_T, ...]: ...
|
91 |
+
def nth_combination(
|
92 |
+
iterable: Iterable[_T], r: int, index: int
|
93 |
+
) -> tuple[_T, ...]: ...
|
94 |
+
def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[_T | _U]: ...
|
95 |
+
def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ...
|
96 |
+
def before_and_after(
|
97 |
+
predicate: Callable[[_T], bool], it: Iterable[_T]
|
98 |
+
) -> tuple[Iterator[_T], Iterator[_T]]: ...
|
99 |
+
def triplewise(iterable: Iterable[_T]) -> Iterator[tuple[_T, _T, _T]]: ...
|
100 |
+
def sliding_window(
|
101 |
+
iterable: Iterable[_T], n: int
|
102 |
+
) -> Iterator[tuple[_T, ...]]: ...
|
103 |
+
def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ...
|
104 |
+
def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ...
|
105 |
+
def iter_index(
|
106 |
+
iterable: Iterable[object],
|
107 |
+
value: Any,
|
108 |
+
start: int | None = ...,
|
109 |
+
) -> Iterator[int]: ...
|
110 |
+
def sieve(n: int) -> Iterator[int]: ...
|
111 |
+
# Fixed return annotation: batches are variable-length tuples, so the
# element type is tuple[_T, ...], not the single-element tuple tuple[_T].
def batched(
    iterable: Iterable[_T],
    n: int,
) -> Iterator[tuple[_T, ...]]: ...
|
115 |
+
def transpose(
|
116 |
+
it: Iterable[Iterable[_T]],
|
117 |
+
) -> Iterator[tuple[_T, ...]]: ...
|
118 |
+
def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ...
|
119 |
+
def factor(n: int) -> Iterator[int]: ...
|
120 |
+
def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ...
|
121 |
+
def sum_of_squares(it: Iterable[_T]) -> _T: ...
|
122 |
+
def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ...
|
lib/python3.11/site-packages/mpmath/__init__.py
ADDED
@@ -0,0 +1,468 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__version__ = '1.3.0'
|
2 |
+
|
3 |
+
from .usertools import monitor, timing
|
4 |
+
|
5 |
+
from .ctx_fp import FPContext
|
6 |
+
from .ctx_mp import MPContext
|
7 |
+
from .ctx_iv import MPIntervalContext
|
8 |
+
|
9 |
+
fp = FPContext()
|
10 |
+
mp = MPContext()
|
11 |
+
iv = MPIntervalContext()
|
12 |
+
|
13 |
+
fp._mp = mp
|
14 |
+
mp._mp = mp
|
15 |
+
iv._mp = mp
|
16 |
+
mp._fp = fp
|
17 |
+
fp._fp = fp
|
18 |
+
mp._iv = iv
|
19 |
+
fp._iv = iv
|
20 |
+
iv._iv = iv
|
21 |
+
|
22 |
+
# XXX: extremely bad pickle hack
|
23 |
+
from . import ctx_mp as _ctx_mp
|
24 |
+
_ctx_mp._mpf_module.mpf = mp.mpf
|
25 |
+
_ctx_mp._mpf_module.mpc = mp.mpc
|
26 |
+
|
27 |
+
make_mpf = mp.make_mpf
|
28 |
+
make_mpc = mp.make_mpc
|
29 |
+
|
30 |
+
extraprec = mp.extraprec
|
31 |
+
extradps = mp.extradps
|
32 |
+
workprec = mp.workprec
|
33 |
+
workdps = mp.workdps
|
34 |
+
autoprec = mp.autoprec
|
35 |
+
maxcalls = mp.maxcalls
|
36 |
+
memoize = mp.memoize
|
37 |
+
|
38 |
+
mag = mp.mag
|
39 |
+
|
40 |
+
bernfrac = mp.bernfrac
|
41 |
+
|
42 |
+
qfrom = mp.qfrom
|
43 |
+
mfrom = mp.mfrom
|
44 |
+
kfrom = mp.kfrom
|
45 |
+
taufrom = mp.taufrom
|
46 |
+
qbarfrom = mp.qbarfrom
|
47 |
+
ellipfun = mp.ellipfun
|
48 |
+
jtheta = mp.jtheta
|
49 |
+
kleinj = mp.kleinj
|
50 |
+
eta = mp.eta
|
51 |
+
|
52 |
+
qp = mp.qp
|
53 |
+
qhyper = mp.qhyper
|
54 |
+
qgamma = mp.qgamma
|
55 |
+
qfac = mp.qfac
|
56 |
+
|
57 |
+
nint_distance = mp.nint_distance
|
58 |
+
|
59 |
+
plot = mp.plot
|
60 |
+
cplot = mp.cplot
|
61 |
+
splot = mp.splot
|
62 |
+
|
63 |
+
odefun = mp.odefun
|
64 |
+
|
65 |
+
jacobian = mp.jacobian
|
66 |
+
findroot = mp.findroot
|
67 |
+
multiplicity = mp.multiplicity
|
68 |
+
|
69 |
+
isinf = mp.isinf
|
70 |
+
isnan = mp.isnan
|
71 |
+
isnormal = mp.isnormal
|
72 |
+
isint = mp.isint
|
73 |
+
isfinite = mp.isfinite
|
74 |
+
almosteq = mp.almosteq
|
75 |
+
nan = mp.nan
|
76 |
+
rand = mp.rand
|
77 |
+
|
78 |
+
absmin = mp.absmin
|
79 |
+
absmax = mp.absmax
|
80 |
+
|
81 |
+
fraction = mp.fraction
|
82 |
+
|
83 |
+
linspace = mp.linspace
|
84 |
+
arange = mp.arange
|
85 |
+
|
86 |
+
mpmathify = convert = mp.convert
|
87 |
+
mpc = mp.mpc
|
88 |
+
|
89 |
+
mpi = iv._mpi
|
90 |
+
|
91 |
+
nstr = mp.nstr
|
92 |
+
nprint = mp.nprint
|
93 |
+
chop = mp.chop
|
94 |
+
|
95 |
+
fneg = mp.fneg
|
96 |
+
fadd = mp.fadd
|
97 |
+
fsub = mp.fsub
|
98 |
+
fmul = mp.fmul
|
99 |
+
fdiv = mp.fdiv
|
100 |
+
fprod = mp.fprod
|
101 |
+
|
102 |
+
quad = mp.quad
|
103 |
+
quadgl = mp.quadgl
|
104 |
+
quadts = mp.quadts
|
105 |
+
quadosc = mp.quadosc
|
106 |
+
quadsubdiv = mp.quadsubdiv
|
107 |
+
|
108 |
+
invertlaplace = mp.invertlaplace
|
109 |
+
invlaptalbot = mp.invlaptalbot
|
110 |
+
invlapstehfest = mp.invlapstehfest
|
111 |
+
invlapdehoog = mp.invlapdehoog
|
112 |
+
|
113 |
+
pslq = mp.pslq
|
114 |
+
identify = mp.identify
|
115 |
+
findpoly = mp.findpoly
|
116 |
+
|
117 |
+
richardson = mp.richardson
|
118 |
+
shanks = mp.shanks
|
119 |
+
levin = mp.levin
|
120 |
+
cohen_alt = mp.cohen_alt
|
121 |
+
nsum = mp.nsum
|
122 |
+
nprod = mp.nprod
|
123 |
+
difference = mp.difference
|
124 |
+
diff = mp.diff
|
125 |
+
diffs = mp.diffs
|
126 |
+
diffs_prod = mp.diffs_prod
|
127 |
+
diffs_exp = mp.diffs_exp
|
128 |
+
diffun = mp.diffun
|
129 |
+
differint = mp.differint
|
130 |
+
taylor = mp.taylor
|
131 |
+
pade = mp.pade
|
132 |
+
polyval = mp.polyval
|
133 |
+
polyroots = mp.polyroots
|
134 |
+
fourier = mp.fourier
|
135 |
+
fourierval = mp.fourierval
|
136 |
+
sumem = mp.sumem
|
137 |
+
sumap = mp.sumap
|
138 |
+
chebyfit = mp.chebyfit
|
139 |
+
limit = mp.limit
|
140 |
+
|
141 |
+
matrix = mp.matrix
|
142 |
+
eye = mp.eye
|
143 |
+
diag = mp.diag
|
144 |
+
zeros = mp.zeros
|
145 |
+
ones = mp.ones
|
146 |
+
hilbert = mp.hilbert
|
147 |
+
randmatrix = mp.randmatrix
|
148 |
+
swap_row = mp.swap_row
|
149 |
+
extend = mp.extend
|
150 |
+
norm = mp.norm
|
151 |
+
mnorm = mp.mnorm
|
152 |
+
|
153 |
+
lu_solve = mp.lu_solve
|
154 |
+
lu = mp.lu
|
155 |
+
qr = mp.qr
|
156 |
+
unitvector = mp.unitvector
|
157 |
+
inverse = mp.inverse
|
158 |
+
residual = mp.residual
|
159 |
+
qr_solve = mp.qr_solve
|
160 |
+
cholesky = mp.cholesky
|
161 |
+
cholesky_solve = mp.cholesky_solve
|
162 |
+
det = mp.det
|
163 |
+
cond = mp.cond
|
164 |
+
hessenberg = mp.hessenberg
|
165 |
+
schur = mp.schur
|
166 |
+
eig = mp.eig
|
167 |
+
eig_sort = mp.eig_sort
|
168 |
+
eigsy = mp.eigsy
|
169 |
+
eighe = mp.eighe
|
170 |
+
eigh = mp.eigh
|
171 |
+
svd_r = mp.svd_r
|
172 |
+
svd_c = mp.svd_c
|
173 |
+
svd = mp.svd
|
174 |
+
gauss_quadrature = mp.gauss_quadrature
|
175 |
+
|
176 |
+
expm = mp.expm
|
177 |
+
sqrtm = mp.sqrtm
|
178 |
+
powm = mp.powm
|
179 |
+
logm = mp.logm
|
180 |
+
sinm = mp.sinm
|
181 |
+
cosm = mp.cosm
|
182 |
+
|
183 |
+
mpf = mp.mpf
|
184 |
+
j = mp.j
|
185 |
+
exp = mp.exp
|
186 |
+
expj = mp.expj
|
187 |
+
expjpi = mp.expjpi
|
188 |
+
ln = mp.ln
|
189 |
+
im = mp.im
|
190 |
+
re = mp.re
|
191 |
+
inf = mp.inf
|
192 |
+
ninf = mp.ninf
|
193 |
+
sign = mp.sign
|
194 |
+
|
195 |
+
eps = mp.eps
|
196 |
+
pi = mp.pi
|
197 |
+
ln2 = mp.ln2
|
198 |
+
ln10 = mp.ln10
|
199 |
+
phi = mp.phi
|
200 |
+
e = mp.e
|
201 |
+
euler = mp.euler
|
202 |
+
catalan = mp.catalan
|
203 |
+
khinchin = mp.khinchin
|
204 |
+
glaisher = mp.glaisher
|
205 |
+
apery = mp.apery
|
206 |
+
degree = mp.degree
|
207 |
+
twinprime = mp.twinprime
|
208 |
+
mertens = mp.mertens
|
209 |
+
|
210 |
+
ldexp = mp.ldexp
|
211 |
+
frexp = mp.frexp
|
212 |
+
|
213 |
+
fsum = mp.fsum
|
214 |
+
fdot = mp.fdot
|
215 |
+
|
216 |
+
sqrt = mp.sqrt
|
217 |
+
cbrt = mp.cbrt
|
218 |
+
exp = mp.exp
|
219 |
+
ln = mp.ln
|
220 |
+
log = mp.log
|
221 |
+
log10 = mp.log10
|
222 |
+
power = mp.power
|
223 |
+
cos = mp.cos
|
224 |
+
sin = mp.sin
|
225 |
+
tan = mp.tan
|
226 |
+
cosh = mp.cosh
|
227 |
+
sinh = mp.sinh
|
228 |
+
tanh = mp.tanh
|
229 |
+
acos = mp.acos
|
230 |
+
asin = mp.asin
|
231 |
+
atan = mp.atan
|
232 |
+
asinh = mp.asinh
|
233 |
+
acosh = mp.acosh
|
234 |
+
atanh = mp.atanh
|
235 |
+
sec = mp.sec
|
236 |
+
csc = mp.csc
|
237 |
+
cot = mp.cot
|
238 |
+
sech = mp.sech
|
239 |
+
csch = mp.csch
|
240 |
+
coth = mp.coth
|
241 |
+
asec = mp.asec
|
242 |
+
acsc = mp.acsc
|
243 |
+
acot = mp.acot
|
244 |
+
asech = mp.asech
|
245 |
+
acsch = mp.acsch
|
246 |
+
acoth = mp.acoth
|
247 |
+
cospi = mp.cospi
|
248 |
+
sinpi = mp.sinpi
|
249 |
+
sinc = mp.sinc
|
250 |
+
sincpi = mp.sincpi
|
251 |
+
cos_sin = mp.cos_sin
|
252 |
+
cospi_sinpi = mp.cospi_sinpi
|
253 |
+
fabs = mp.fabs
|
254 |
+
re = mp.re
|
255 |
+
im = mp.im
|
256 |
+
conj = mp.conj
|
257 |
+
floor = mp.floor
|
258 |
+
ceil = mp.ceil
|
259 |
+
nint = mp.nint
|
260 |
+
frac = mp.frac
|
261 |
+
root = mp.root
|
262 |
+
nthroot = mp.nthroot
|
263 |
+
hypot = mp.hypot
|
264 |
+
fmod = mp.fmod
|
265 |
+
ldexp = mp.ldexp
|
266 |
+
frexp = mp.frexp
|
267 |
+
sign = mp.sign
|
268 |
+
arg = mp.arg
|
269 |
+
phase = mp.phase
|
270 |
+
polar = mp.polar
|
271 |
+
rect = mp.rect
|
272 |
+
degrees = mp.degrees
|
273 |
+
radians = mp.radians
|
274 |
+
atan2 = mp.atan2
|
275 |
+
fib = mp.fib
|
276 |
+
fibonacci = mp.fibonacci
|
277 |
+
lambertw = mp.lambertw
|
278 |
+
zeta = mp.zeta
|
279 |
+
altzeta = mp.altzeta
|
280 |
+
gamma = mp.gamma
|
281 |
+
rgamma = mp.rgamma
|
282 |
+
factorial = mp.factorial
|
283 |
+
fac = mp.fac
|
284 |
+
fac2 = mp.fac2
|
285 |
+
beta = mp.beta
|
286 |
+
betainc = mp.betainc
|
287 |
+
psi = mp.psi
|
288 |
+
#psi0 = mp.psi0
|
289 |
+
#psi1 = mp.psi1
|
290 |
+
#psi2 = mp.psi2
|
291 |
+
#psi3 = mp.psi3
|
292 |
+
polygamma = mp.polygamma
|
293 |
+
digamma = mp.digamma
|
294 |
+
#trigamma = mp.trigamma
|
295 |
+
#tetragamma = mp.tetragamma
|
296 |
+
#pentagamma = mp.pentagamma
|
297 |
+
harmonic = mp.harmonic
|
298 |
+
bernoulli = mp.bernoulli
|
299 |
+
bernfrac = mp.bernfrac
|
300 |
+
stieltjes = mp.stieltjes
|
301 |
+
hurwitz = mp.hurwitz
|
302 |
+
dirichlet = mp.dirichlet
|
303 |
+
bernpoly = mp.bernpoly
|
304 |
+
eulerpoly = mp.eulerpoly
|
305 |
+
eulernum = mp.eulernum
|
306 |
+
polylog = mp.polylog
|
307 |
+
clsin = mp.clsin
|
308 |
+
clcos = mp.clcos
|
309 |
+
gammainc = mp.gammainc
|
310 |
+
gammaprod = mp.gammaprod
|
311 |
+
binomial = mp.binomial
|
312 |
+
rf = mp.rf
|
313 |
+
ff = mp.ff
|
314 |
+
hyper = mp.hyper
|
315 |
+
hyp0f1 = mp.hyp0f1
|
316 |
+
hyp1f1 = mp.hyp1f1
|
317 |
+
hyp1f2 = mp.hyp1f2
|
318 |
+
hyp2f1 = mp.hyp2f1
|
319 |
+
hyp2f2 = mp.hyp2f2
|
320 |
+
hyp2f0 = mp.hyp2f0
|
321 |
+
hyp2f3 = mp.hyp2f3
|
322 |
+
hyp3f2 = mp.hyp3f2
|
323 |
+
hyperu = mp.hyperu
|
324 |
+
hypercomb = mp.hypercomb
|
325 |
+
meijerg = mp.meijerg
|
326 |
+
appellf1 = mp.appellf1
|
327 |
+
appellf2 = mp.appellf2
|
328 |
+
appellf3 = mp.appellf3
|
329 |
+
appellf4 = mp.appellf4
|
330 |
+
hyper2d = mp.hyper2d
|
331 |
+
bihyper = mp.bihyper
|
332 |
+
erf = mp.erf
|
333 |
+
erfc = mp.erfc
|
334 |
+
erfi = mp.erfi
|
335 |
+
erfinv = mp.erfinv
|
336 |
+
npdf = mp.npdf
|
337 |
+
ncdf = mp.ncdf
|
338 |
+
expint = mp.expint
|
339 |
+
e1 = mp.e1
|
340 |
+
ei = mp.ei
|
341 |
+
li = mp.li
|
342 |
+
ci = mp.ci
|
343 |
+
si = mp.si
|
344 |
+
chi = mp.chi
|
345 |
+
shi = mp.shi
|
346 |
+
fresnels = mp.fresnels
|
347 |
+
fresnelc = mp.fresnelc
|
348 |
+
airyai = mp.airyai
|
349 |
+
airybi = mp.airybi
|
350 |
+
airyaizero = mp.airyaizero
|
351 |
+
airybizero = mp.airybizero
|
352 |
+
scorergi = mp.scorergi
|
353 |
+
scorerhi = mp.scorerhi
|
354 |
+
ellipk = mp.ellipk
|
355 |
+
ellipe = mp.ellipe
|
356 |
+
ellipf = mp.ellipf
|
357 |
+
ellippi = mp.ellippi
|
358 |
+
elliprc = mp.elliprc
|
359 |
+
elliprj = mp.elliprj
|
360 |
+
elliprf = mp.elliprf
|
361 |
+
elliprd = mp.elliprd
|
362 |
+
elliprg = mp.elliprg
|
363 |
+
agm = mp.agm
|
364 |
+
jacobi = mp.jacobi
|
365 |
+
chebyt = mp.chebyt
|
366 |
+
chebyu = mp.chebyu
|
367 |
+
legendre = mp.legendre
|
368 |
+
legenp = mp.legenp
|
369 |
+
legenq = mp.legenq
|
370 |
+
hermite = mp.hermite
|
371 |
+
pcfd = mp.pcfd
|
372 |
+
pcfu = mp.pcfu
|
373 |
+
pcfv = mp.pcfv
|
374 |
+
pcfw = mp.pcfw
|
375 |
+
gegenbauer = mp.gegenbauer
|
376 |
+
laguerre = mp.laguerre
|
377 |
+
spherharm = mp.spherharm
|
378 |
+
besselj = mp.besselj
|
379 |
+
j0 = mp.j0
|
380 |
+
j1 = mp.j1
|
381 |
+
besseli = mp.besseli
|
382 |
+
bessely = mp.bessely
|
383 |
+
besselk = mp.besselk
|
384 |
+
besseljzero = mp.besseljzero
|
385 |
+
besselyzero = mp.besselyzero
|
386 |
+
hankel1 = mp.hankel1
|
387 |
+
hankel2 = mp.hankel2
|
388 |
+
struveh = mp.struveh
|
389 |
+
struvel = mp.struvel
|
390 |
+
angerj = mp.angerj
|
391 |
+
webere = mp.webere
|
392 |
+
lommels1 = mp.lommels1
|
393 |
+
lommels2 = mp.lommels2
|
394 |
+
whitm = mp.whitm
|
395 |
+
whitw = mp.whitw
|
396 |
+
ber = mp.ber
|
397 |
+
bei = mp.bei
|
398 |
+
ker = mp.ker
|
399 |
+
kei = mp.kei
|
400 |
+
coulombc = mp.coulombc
|
401 |
+
coulombf = mp.coulombf
|
402 |
+
coulombg = mp.coulombg
|
403 |
+
barnesg = mp.barnesg
|
404 |
+
superfac = mp.superfac
|
405 |
+
hyperfac = mp.hyperfac
|
406 |
+
loggamma = mp.loggamma
|
407 |
+
siegeltheta = mp.siegeltheta
|
408 |
+
siegelz = mp.siegelz
|
409 |
+
grampoint = mp.grampoint
|
410 |
+
zetazero = mp.zetazero
|
411 |
+
riemannr = mp.riemannr
|
412 |
+
primepi = mp.primepi
|
413 |
+
primepi2 = mp.primepi2
|
414 |
+
primezeta = mp.primezeta
|
415 |
+
bell = mp.bell
|
416 |
+
polyexp = mp.polyexp
|
417 |
+
expm1 = mp.expm1
|
418 |
+
log1p = mp.log1p
|
419 |
+
powm1 = mp.powm1
|
420 |
+
unitroots = mp.unitroots
|
421 |
+
cyclotomic = mp.cyclotomic
|
422 |
+
mangoldt = mp.mangoldt
|
423 |
+
secondzeta = mp.secondzeta
|
424 |
+
nzeros = mp.nzeros
|
425 |
+
backlunds = mp.backlunds
|
426 |
+
lerchphi = mp.lerchphi
|
427 |
+
stirling1 = mp.stirling1
|
428 |
+
stirling2 = mp.stirling2
|
429 |
+
squarew = mp.squarew
|
430 |
+
trianglew = mp.trianglew
|
431 |
+
sawtoothw = mp.sawtoothw
|
432 |
+
unit_triangle = mp.unit_triangle
|
433 |
+
sigmoid = mp.sigmoid
|
434 |
+
|
435 |
+
# be careful when changing this name, don't use test*!
|
436 |
+
def runtests():
    """
    Run all mpmath tests and print output.
    """
    import os.path
    from inspect import getsourcefile
    from .tests import runtests as tests
    # Locate the installed test directory from the test module's own source
    # file, so this works regardless of where mpmath is installed.
    testdir = os.path.dirname(os.path.abspath(getsourcefile(tests)))
    # The import root (e.g. site-packages) is two levels above the tests.
    importdir = os.path.abspath(testdir + '/../..')
    tests.testit(importdir, testdir)
|
446 |
+
|
447 |
+
def doctests(filter=None):
    """Run the doctests of this module's global objects and print timings.

    *filter* is an optional list of substring patterns; when given, only
    objects whose name contains at least one pattern are tested.  When the
    module is run directly, patterns are taken from the command-line
    arguments that follow ``__init__.py`` and do not start with ``-``
    (and ``-v`` enables verbose doctest output).
    """
    # NOTE: ``filter`` deliberately keeps its (builtin-shadowing) name so
    # existing keyword callers keep working.  The default is None rather
    # than a mutable ``[]`` to avoid sharing one list across calls.
    import sys
    from timeit import default_timer as clock
    if filter is None:
        filter = []
    for i, arg in enumerate(sys.argv):
        if '__init__.py' in arg:
            filter = [sn for sn in sys.argv[i+1:] if not sn.startswith("-")]
            break
    import doctest
    globs = globals().copy()
    for obj in globs:
        # Skip names that match none of the requested patterns.
        if filter and not any(pat in obj for pat in filter):
            continue
        sys.stdout.write(str(obj) + " ")
        sys.stdout.flush()
        t1 = clock()
        doctest.run_docstring_examples(globs[obj], {}, verbose=("-v" in sys.argv))
        t2 = clock()
        print(round(t2 - t1, 3))
|
466 |
+
|
467 |
+
if __name__ == '__main__':
|
468 |
+
doctests()
|
lib/python3.11/site-packages/mpmath/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (14.7 kB). View file
|
|
lib/python3.11/site-packages/mpmath/__pycache__/ctx_base.cpython-311.pyc
ADDED
Binary file (24.4 kB). View file
|
|
lib/python3.11/site-packages/mpmath/__pycache__/ctx_fp.cpython-311.pyc
ADDED
Binary file (13 kB). View file
|
|
lib/python3.11/site-packages/mpmath/__pycache__/ctx_iv.cpython-311.pyc
ADDED
Binary file (38.7 kB). View file
|
|
lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp.cpython-311.pyc
ADDED
Binary file (71.2 kB). View file
|
|
lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp_python.cpython-311.pyc
ADDED
Binary file (61.2 kB). View file
|
|
lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc
ADDED
Binary file (285 kB). View file
|
|