Commit f14e74e by reach-vb (HF staff)
Parent: 42a2b88

ce304fafe19161978ad512b385c65426bad519e5a0b8fb3f0659eace3d2ea3cc
Files changed (50):
  1. lib/python3.11/site-packages/mlx-0.0.7.dist-info/INSTALLER +1 -0
  2. lib/python3.11/site-packages/mlx-0.0.7.dist-info/LICENSE +21 -0
  3. lib/python3.11/site-packages/mlx-0.0.7.dist-info/METADATA +122 -0
  4. lib/python3.11/site-packages/mlx-0.0.7.dist-info/RECORD +199 -0
  5. lib/python3.11/site-packages/mlx-0.0.7.dist-info/REQUESTED +0 -0
  6. lib/python3.11/site-packages/mlx-0.0.7.dist-info/WHEEL +5 -0
  7. lib/python3.11/site-packages/mlx-0.0.7.dist-info/top_level.txt +1 -0
  8. lib/python3.11/site-packages/mlx/nn/layers/base.py +532 -0
  9. lib/python3.11/site-packages/mlx/nn/layers/containers.py +24 -0
  10. lib/python3.11/site-packages/mlx/nn/layers/convolution.py +126 -0
  11. lib/python3.11/site-packages/mlx/nn/layers/dropout.py +137 -0
  12. lib/python3.11/site-packages/mlx/nn/layers/embedding.py +30 -0
  13. lib/python3.11/site-packages/mlx/nn/layers/linear.py +141 -0
  14. lib/python3.11/site-packages/mlx/nn/layers/normalization.py +368 -0
  15. lib/python3.11/site-packages/mlx/nn/layers/positional_encoding.py +199 -0
  16. lib/python3.11/site-packages/mlx/nn/layers/quantized.py +125 -0
  17. lib/python3.11/site-packages/mlx/nn/layers/transformer.py +354 -0
  18. lib/python3.11/site-packages/mlx/nn/losses.py +374 -0
  19. lib/python3.11/site-packages/mlx/nn/utils.py +33 -0
  20. lib/python3.11/site-packages/mlx/optimizers.py +500 -0
  21. lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfig.cmake +57 -0
  22. lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfigVersion.cmake +65 -0
  23. lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets-release.cmake +19 -0
  24. lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets.cmake +107 -0
  25. lib/python3.11/site-packages/mlx/share/cmake/MLX/extension.cmake +56 -0
  26. lib/python3.11/site-packages/mlx/utils.py +145 -0
  27. lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/INSTALLER +1 -0
  28. lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/LICENSE +19 -0
  29. lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/METADATA +253 -0
  30. lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/RECORD +16 -0
  31. lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/REQUESTED +0 -0
  32. lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/WHEEL +4 -0
  33. lib/python3.11/site-packages/more_itertools/__init__.py +6 -0
  34. lib/python3.11/site-packages/more_itertools/__init__.pyi +2 -0
  35. lib/python3.11/site-packages/more_itertools/__pycache__/__init__.cpython-311.pyc +0 -0
  36. lib/python3.11/site-packages/more_itertools/__pycache__/more.cpython-311.pyc +0 -0
  37. lib/python3.11/site-packages/more_itertools/__pycache__/recipes.cpython-311.pyc +0 -0
  38. lib/python3.11/site-packages/more_itertools/more.py +0 -0
  39. lib/python3.11/site-packages/more_itertools/more.pyi +684 -0
  40. lib/python3.11/site-packages/more_itertools/py.typed +0 -0
  41. lib/python3.11/site-packages/more_itertools/recipes.py +977 -0
  42. lib/python3.11/site-packages/more_itertools/recipes.pyi +122 -0
  43. lib/python3.11/site-packages/mpmath/__init__.py +468 -0
  44. lib/python3.11/site-packages/mpmath/__pycache__/__init__.cpython-311.pyc +0 -0
  45. lib/python3.11/site-packages/mpmath/__pycache__/ctx_base.cpython-311.pyc +0 -0
  46. lib/python3.11/site-packages/mpmath/__pycache__/ctx_fp.cpython-311.pyc +0 -0
  47. lib/python3.11/site-packages/mpmath/__pycache__/ctx_iv.cpython-311.pyc +0 -0
  48. lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp.cpython-311.pyc +0 -0
  49. lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp_python.cpython-311.pyc +0 -0
  50. lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc +0 -0
lib/python3.11/site-packages/mlx-0.0.7.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
+pip
lib/python3.11/site-packages/mlx-0.0.7.dist-info/LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright © 2023 Apple Inc.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
lib/python3.11/site-packages/mlx-0.0.7.dist-info/METADATA ADDED
@@ -0,0 +1,122 @@
+Metadata-Version: 2.1
+Name: mlx
+Version: 0.0.7
+Summary: A framework for machine learning on Apple silicon.
+Author: MLX Contributors
+Author-email: [email protected]
+Requires-Python: >=3.8
+Description-Content-Type: text/markdown
+License-File: LICENSE
+Provides-Extra: dev
+Requires-Dist: pre-commit ; extra == 'dev'
+Requires-Dist: pybind11-stubgen ; extra == 'dev'
+Provides-Extra: testing
+Requires-Dist: numpy ; extra == 'testing'
+Requires-Dist: torch ; extra == 'testing'
+
+# MLX
+
+[**Quickstart**](#quickstart) | [**Installation**](#installation) |
+[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
+[**Examples**](#examples)
+
+[![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
+
+MLX is an array framework for machine learning on Apple silicon, brought to you
+by Apple machine learning research.
+
+Some key features of MLX include:
+
+- **Familiar APIs**: MLX has a Python API that closely follows NumPy.
+  MLX also has a fully featured C++ API, which closely mirrors the Python API.
+  MLX has higher-level packages like `mlx.nn` and `mlx.optimizers` with APIs
+  that closely follow PyTorch to simplify building more complex models.
+
+- **Composable function transformations**: MLX supports composable function
+  transformations for automatic differentiation, automatic vectorization,
+  and computation graph optimization.
+
+- **Lazy computation**: Computations in MLX are lazy. Arrays are only
+  materialized when needed.
+
+- **Dynamic graph construction**: Computation graphs in MLX are constructed
+  dynamically. Changing the shapes of function arguments does not trigger
+  slow compilations, and debugging is simple and intuitive.
+
+- **Multi-device**: Operations can run on any of the supported devices
+  (currently the CPU and the GPU).
+
+- **Unified memory**: A notable difference between MLX and other frameworks
+  is the *unified memory model*. Arrays in MLX live in shared memory.
+  Operations on MLX arrays can be performed on any of the supported
+  device types without transferring data.
+
+MLX is designed by machine learning researchers for machine learning
+researchers. The framework is intended to be user-friendly, but still efficient
+to train and deploy models. The design of the framework itself is also
+conceptually simple. We intend to make it easy for researchers to extend and
+improve MLX with the goal of quickly exploring new ideas.
+
+The design of MLX is inspired by frameworks like
+[NumPy](https://numpy.org/doc/stable/index.html),
+[PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), and
+[ArrayFire](https://arrayfire.org/).
+
+## Examples
+
+The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
+variety of examples, including:
+
+- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
+- Large-scale text generation with
+  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
+  finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
+- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
+- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
+
+## Quickstart
+
+See the [quick start
+guide](https://ml-explore.github.io/mlx/build/html/quick_start.html)
+in the documentation.
+
+## Installation
+
+MLX is available on [PyPI](https://pypi.org/project/mlx/). To install the Python API, run:
+
+```
+pip install mlx
+```
+
+Check out the
+[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
+for more information on building the C++ and Python APIs from source.
+
+## Contributing
+
+Check out the [contribution guidelines](CONTRIBUTING.md) for more information
+on contributing to MLX. See the
+[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
+information on building from source, and running tests.
+
+We are grateful for all of [our
+contributors](ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
+to MLX and wish to be acknowledged, please add your name to the list in your
+pull request.
+
+## Citing MLX
+
+The MLX software suite was initially developed with equal contribution by Awni
+Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
+MLX useful in your research and wish to cite it, please use the following
+BibTeX entry:
+
+```
+@software{mlx2023,
+  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
+  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
+  url = {https://github.com/ml-explore},
+  version = {0.0},
+  year = {2023},
+}
+```
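The README above makes two concrete claims: computation is lazy, and function transformations compose. A minimal sketch of both, for orientation only (it is not part of the commit; it assumes only public `mlx.core` calls such as `mx.eval` and `mx.grad`):

```python
import mlx.core as mx

a = mx.random.normal(shape=(4, 4))
b = a @ a.T + 1.0  # builds the graph lazily; nothing is computed yet
mx.eval(b)         # materializes the result on the default device

# Composable transformation: grad() returns a new function.
f = lambda x: (x ** 2).sum()
df = mx.grad(f)
print(df(mx.array([1.0, 2.0, 3.0])))  # gradient of sum(x^2) is 2x
```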
lib/python3.11/site-packages/mlx-0.0.7.dist-info/RECORD ADDED
@@ -0,0 +1,199 @@
+mlx-0.0.7.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
+mlx-0.0.7.dist-info/LICENSE,sha256=zPq3zLLqMG9xUxyMp3u1VQdgbNkHaLHjK4tSq1tIzwE,1066
+mlx-0.0.7.dist-info/METADATA,sha256=CK9nqz-nrSgrFZpRYpc86IYysL_W7zObqiiczNHD5OA,4731
+mlx-0.0.7.dist-info/RECORD,,
+mlx-0.0.7.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+mlx-0.0.7.dist-info/WHEEL,sha256=9CmlYCCxe9UEGaJDl4r_tPKbyfF_OdzBb5_jMO3h1CA,110
+mlx-0.0.7.dist-info/top_level.txt,sha256=NCVZzyms-obPlDq2Sa--m3oyfnnmLma38KCnE6gmcFE,4
+mlx/__pycache__/_reprlib_fix.cpython-311.pyc,,
+mlx/__pycache__/extension.cpython-311.pyc,,
+mlx/__pycache__/optimizers.cpython-311.pyc,,
+mlx/__pycache__/utils.cpython-311.pyc,,
+mlx/_reprlib_fix.py,sha256=YoRQib4PSzl1FIMvs4BwkvipAwOrfufqtV89vNF4Yhc,518
+mlx/core.cpython-311-darwin.so,sha256=4GMiWnLS2sZGsrIwB9mAryue5Lztf-SttAn_AzadmCc,1098840
+mlx/extension.py,sha256=tqKrGeP1ElK9UfBMHUeOq56qWnMMxZrt7DVVM8QZ7yU,3769
+mlx/include/metal_cpp/Foundation/Foundation.hpp,sha256=pYbEkzTrb2O0jJPzI0m9cBNvVSDTh_t39Dv7h0ZDma0,1808
+mlx/include/metal_cpp/Foundation/NSArray.hpp,sha256=8nhS0gF060n7taK9aIynkuMPxlbKgqgijsj61yvhrPQ,4818
+mlx/include/metal_cpp/Foundation/NSAutoreleasePool.hpp,sha256=K8_qNgYgyeP49WUSsz90Lo41qTb-x5Xjs23T9UNSkkU,3302
+mlx/include/metal_cpp/Foundation/NSBundle.hpp,sha256=y0Le2v43bCAdIdrSClvLpRtCL8c4qdB3WK59Gz_AYYs,16845
+mlx/include/metal_cpp/Foundation/NSData.hpp,sha256=rAPEZQn1W05O_sZ6LYexLhvv1CmllqyFWddGIrxGJ8s,2198
+mlx/include/metal_cpp/Foundation/NSDate.hpp,sha256=353UgslEWyV-D1_T7_Sr1fCP4nVjixwJK2QGnRZQKgg,2085
+mlx/include/metal_cpp/Foundation/NSDefines.hpp,sha256=c8BfIDwW9yCKiofyatyeLJ9YvnHO87PsBqSQtciN3tY,2203
+mlx/include/metal_cpp/Foundation/NSDictionary.hpp,sha256=Lt7EBR6ppPFhhnoZq9UqMETHf_L1k9A0gs637ZTEQSU,5738
+mlx/include/metal_cpp/Foundation/NSEnumerator.hpp,sha256=P4mkVmpWN0kdXnqqTMkyIovckr4_IZAoFzyg2IZDRs4,3141
+mlx/include/metal_cpp/Foundation/NSError.hpp,sha256=rFqCSIXVsWOsH7Ikegv2irMF3aoMn7fM6teB9VipMgc,7802
+mlx/include/metal_cpp/Foundation/NSLock.hpp,sha256=OjnoMx18m4E89eDPliPMw3cG8LtAqNsoBSxSO0l9DoY,4350
+mlx/include/metal_cpp/Foundation/NSNotification.hpp,sha256=rpAgI4upmIG2inWZ9ACYy-Q8za7muZ2WmUtA5l8WGHw,4696
+mlx/include/metal_cpp/Foundation/NSNumber.hpp,sha256=4a2fEU1Fm_wrWOs4zULhYQne0WymnRGsDeSGq1bDbuQ,22098
+mlx/include/metal_cpp/Foundation/NSObjCRuntime.hpp,sha256=SHYqTLWQ2F8T4dfXJ98OlATOexXQ3NnXzmFyVpb4wFM,1665
+mlx/include/metal_cpp/Foundation/NSObject.hpp,sha256=T_64vByLI4n3HIE05MVJZteUIAdN-y0oogsczYT-JTE,11105
+mlx/include/metal_cpp/Foundation/NSPrivate.hpp,sha256=JYewPlM0BKHLnmQiItPlYuc1HNAO-akJjhXnEKUZO00,20217
+mlx/include/metal_cpp/Foundation/NSProcessInfo.hpp,sha256=1NLI7-3mR8cseE_jE6c-D2_pRJKXA9fRl4V2uRZtiDI,16642
+mlx/include/metal_cpp/Foundation/NSRange.hpp,sha256=7VFC565QzvzfF1CI91osTozl-d1mM9UToCcGyXmf5cc,3147
+mlx/include/metal_cpp/Foundation/NSSet.hpp,sha256=0rORF6rpK2-R1AQCgT4CMNnGr18xZ8P5x82LEQLQ2sc,3368
+mlx/include/metal_cpp/Foundation/NSSharedPtr.hpp,sha256=208e_0hJddLllsT16pJa4rQAyFqA1uYuez5gWyRS6EU,8496
+mlx/include/metal_cpp/Foundation/NSString.hpp,sha256=AqRvSkbHPncbC36TJuR-4beEVT7ZPwnrVjmyoassaL8,11001
+mlx/include/metal_cpp/Foundation/NSTypes.hpp,sha256=9IaLleaqYbdjlMvT_mqFlYDlzg4vF_8MQQTO2Sdp6vo,1893
+mlx/include/metal_cpp/Foundation/NSURL.hpp,sha256=b1oGmKqL5F1WrWSaCK3Iqs0vxL8utVbOgtYW4RZUe_I,3668
+mlx/include/metal_cpp/LICENSE.txt,sha256=mgnj5ca95epXBMNZliB1O16xkRb5_WScN7W8RJeiXOo,11344
+mlx/include/metal_cpp/Metal/MTLAccelerationStructure.hpp,sha256=jqK65WOzpPUNE6H-XYb1PC3myk6vBIQr4fRrClCDZYs,85372
+mlx/include/metal_cpp/Metal/MTLAccelerationStructureCommandEncoder.hpp,sha256=EC1flOCenKN-CNLIbHcpli7kP1PQGxkGUU8ZdMYQ4Qs,16289
+mlx/include/metal_cpp/Metal/MTLAccelerationStructureTypes.hpp,sha256=I29afo7zt4fiCY2ip3vPN72lhb6Whgjpx0lcHabfib4,5800
+mlx/include/metal_cpp/Metal/MTLArgument.hpp,sha256=_gQiklV3kRM2ukcYg0JCz3rUsK-zUOiZMYrp_thHjXU,23515
+mlx/include/metal_cpp/Metal/MTLArgumentEncoder.hpp,sha256=PaV9gzYeC1JLqDdcAJK9H6DaPVV2rclNWvNPK70KUWs,10739
+mlx/include/metal_cpp/Metal/MTLBinaryArchive.hpp,sha256=I8a_7I2WI-ftil5s4bT8N5rnUeRUt2HtsM2O9LW_t2M,5132
+mlx/include/metal_cpp/Metal/MTLBlitCommandEncoder.hpp,sha256=6QF2GR7c4XQXfA5COwCb2Q6xRXVMkYnO1gmUlrNw05o,16482
+mlx/include/metal_cpp/Metal/MTLBlitPass.hpp,sha256=zPt4836TDUuDPyRCyF7tJBEtEMpoWziJLlSRgWleURs,6982
+mlx/include/metal_cpp/Metal/MTLBuffer.hpp,sha256=IDVEI2qXwq6rU_tQsPkOpma-IV7othYmJIT2QCOBsic,3542
+mlx/include/metal_cpp/Metal/MTLCaptureManager.hpp,sha256=ANj6r5G9Gc1l9WikJQmZZlOWvfViJclzyPwKF59UvUg,7524
+mlx/include/metal_cpp/Metal/MTLCaptureScope.hpp,sha256=DU5_HrnsaHBttsqHfXCk-oPr8ifhMdP-DiYpS1G-qC0,3733
+mlx/include/metal_cpp/Metal/MTLCommandBuffer.hpp,sha256=29LVLOZlzB8G0bVfnHm867TXc0FPqbM7eXbNrZ2zRzE,17642
+mlx/include/metal_cpp/Metal/MTLCommandEncoder.hpp,sha256=IQGAhZMikxwU1IINtvw3MSTLwlbQKRWuTNZnZAw_0m8,2946
+mlx/include/metal_cpp/Metal/MTLCommandQueue.hpp,sha256=Uz74PTJAhqMi8XgLrj1RZOio-FYxLnTO9Msdy1dE7RA,2960
+mlx/include/metal_cpp/Metal/MTLComputeCommandEncoder.hpp,sha256=Z2E_Ny8OJj4gvC_YZtku5sni2t4o9Mbm7xvJk2tpgHc,17157
+mlx/include/metal_cpp/Metal/MTLComputePass.hpp,sha256=VOyFJlwaB89gmrtLdmnAz7gl97a5HbTYo-M6icWbN5M,7792
+mlx/include/metal_cpp/Metal/MTLComputePipeline.hpp,sha256=WTBUdVq1sxvHocezHO5wpVmaa94Xnthq27kSUCedcMY,15017
+mlx/include/metal_cpp/Metal/MTLCounters.hpp,sha256=RmlZCNFZIVPC4ue6FoIgYVq1Weq2Nu1GQD-G0yIwimg,9056
+mlx/include/metal_cpp/Metal/MTLDefines.hpp,sha256=1MG6A6zwLk99Yknu7PAvBvIMefVhumr1kFoKxQ-6xKk,1900
+mlx/include/metal_cpp/Metal/MTLDepthStencil.hpp,sha256=nP5OtW6QGUK7tzrXtJWJlch9bXkDYPYgp5BjMAptH6E,9631
+mlx/include/metal_cpp/Metal/MTLDevice.hpp,sha256=HyHg83T53RuLU1y_ZthBvEIzpOi1Lu0nXptaJfdrhwc,63880
+mlx/include/metal_cpp/Metal/MTLDrawable.hpp,sha256=07X3GSg7uJEWoOMWxI77cIWdr2NaBfJEXOtPJZr4M44,3193
+mlx/include/metal_cpp/Metal/MTLDynamicLibrary.hpp,sha256=yBUYfobLLHd00UOGALdsEF0FBiF1hUDr7G3FdKy6uMY,2606
+mlx/include/metal_cpp/Metal/MTLEvent.hpp,sha256=W8QBQWIUn-DcuBYzg0NB0L1W_Zs6G6QuQCeKUKM7tIo,4976
+mlx/include/metal_cpp/Metal/MTLFence.hpp,sha256=DuxPaIkQN9fWi1E6RHg8YrKNeB2nxGOZMuHXd8TiE9c,1723
+mlx/include/metal_cpp/Metal/MTLFunctionConstantValues.hpp,sha256=dz9s7_TYqjTJbthwEJTDoqdXsZ34JE2k4BERCYzWqWQ,3088
+mlx/include/metal_cpp/Metal/MTLFunctionDescriptor.hpp,sha256=JkzBrxhlg4_ezt2NLomw-MsGuBqdV9embOekJow-mks,5410
+mlx/include/metal_cpp/Metal/MTLFunctionHandle.hpp,sha256=HykCHezEwao-pmH6eL57qOyNt_G5RYsrsofxFDyHjiQ,1842
+mlx/include/metal_cpp/Metal/MTLFunctionLog.hpp,sha256=ZiIPh0Ef5oeiYIWsFZJGnOXZj2bejYFDe_3NVoRT2nY,3262
+mlx/include/metal_cpp/Metal/MTLFunctionStitching.hpp,sha256=fEipcQvZ2d8GeyPyn0AiyMup4doS9Yf-uzZRFpjkfyI,11407
+mlx/include/metal_cpp/Metal/MTLHeaderBridge.hpp,sha256=qYNlYMWNAx4z4iNnMyrqdQr_NJ5md2n1HMSxmP5ykRM,105217
+mlx/include/metal_cpp/Metal/MTLHeap.hpp,sha256=O9zMAFEKDGKAt710IDSwdsW3mYSyM7Mt2IVvJNJnAq0,11624
+mlx/include/metal_cpp/Metal/MTLIOCommandBuffer.hpp,sha256=OweZPLAmFoJFypxGIWisfb6Bzlvgv5qIYB2MVu6jBLw,7417
+mlx/include/metal_cpp/Metal/MTLIOCommandQueue.hpp,sha256=Al9BpGDKaSu5zQblMphnr9EXvrVee2ZaZWMba6L7QSs,7465
+mlx/include/metal_cpp/Metal/MTLIOCompressor.hpp,sha256=s7UjSrJLyfQt69RzWBY692G1audMRzjIVR-7GjkZFOQ,3043
+mlx/include/metal_cpp/Metal/MTLIndirectCommandBuffer.hpp,sha256=J3H4HH0p8qq_DYU8pQC_w4qDmYtNqj6FzoOo_AVGfxA,9715
+mlx/include/metal_cpp/Metal/MTLIndirectCommandEncoder.hpp,sha256=AW1Qzyggw9k9EqS71CvlFUHu4lag6nvvzK-LnV4oG0k,11239
+mlx/include/metal_cpp/Metal/MTLIntersectionFunctionTable.hpp,sha256=NUP7rHPWulH7M2M1wvzlyjjF7S0BJvG-_XFn4nWehcg,8181
+mlx/include/metal_cpp/Metal/MTLLibrary.hpp,sha256=le2P__N8XA4RCN62L2ZB-F9uo2p8BjQ-A6j0kEOqzws,23904
+mlx/include/metal_cpp/Metal/MTLLinkedFunctions.hpp,sha256=FmUoNjOGSegryBa0PDa1qcxEF0RUESixbd7OPyU-qDM,3867
+mlx/include/metal_cpp/Metal/MTLParallelRenderCommandEncoder.hpp,sha256=uHyC4bf8BjFUz6xEkFg79Lqzz9CV-0NNnNoQ-U3xfsA,3899
+mlx/include/metal_cpp/Metal/MTLPipeline.hpp,sha256=1EaeWaT8YNfvPaX0gqktNMOLtojBF6cVMlZiFT7vjyg,3775
+mlx/include/metal_cpp/Metal/MTLPixelFormat.hpp,sha256=DYqej_xuztbD0lhsioWTYWO2Sg7jRShyJ7izoVlex5Y,5910
+mlx/include/metal_cpp/Metal/MTLPrivate.hpp,sha256=vPTld95A83b1nIVT4qwsdcrttngMK5L5Ow-0HTr-zm4,6415
+mlx/include/metal_cpp/Metal/MTLRasterizationRate.hpp,sha256=LtcYTnjLSTaxTw98WdUg_9R9BaAozvXOTHcsnmSqbTo,15447
+mlx/include/metal_cpp/Metal/MTLRenderCommandEncoder.hpp,sha256=JNisXL9uSMMm_5B2vkVfuI8dgrMMMIrC8-E7P0pjdeo,60813
+mlx/include/metal_cpp/Metal/MTLRenderPass.hpp,sha256=YQCYdLOnOWsUgIbQimBDFo4l_FmbZfeHIJFADL3NnRQ,32752
+mlx/include/metal_cpp/Metal/MTLRenderPipeline.hpp,sha256=TuBYsiWg4hpumgaKccAn8TVAeLw26l_FdVGa6JsaMLA,72929
+mlx/include/metal_cpp/Metal/MTLResource.hpp,sha256=Be28y-6-9GvFuugCUhnykiTUqbMqYgVovDjUdpIrFDA,5244
+mlx/include/metal_cpp/Metal/MTLResourceStateCommandEncoder.hpp,sha256=eOjzuG-vDtQLJ8QOzhiqxXxiMEShXyn9cn-Uhmtigr8,5330
+mlx/include/metal_cpp/Metal/MTLResourceStatePass.hpp,sha256=N5sMbFCjfEdf10_3VZJfQvDl3nLuydqrKnUgiJzynRU,7585
+mlx/include/metal_cpp/Metal/MTLSampler.hpp,sha256=0rifEvFw4V8_kMmusAU4NXiZeZbCSX_L4x0K6QyEkog,10838
+mlx/include/metal_cpp/Metal/MTLStageInputOutputDescriptor.hpp,sha256=LfDoa6Dj5WQFS2Qg_FQRCkVgBpbbfjBdFtn1SIzsCyI,13120
+mlx/include/metal_cpp/Metal/MTLTexture.hpp,sha256=jM-mSxqfGBNXE6xwzx1M7CC8GOiyungRuop-6uTuf-M,24807
+mlx/include/metal_cpp/Metal/MTLTypes.hpp,sha256=9_le5gsR93J2awZRAk_Rep_yAWKShAQOgI6faZKD_tg,4389
+mlx/include/metal_cpp/Metal/MTLVersion.hpp,sha256=iaG4OsKd3JlcsoLEvyAx_5TdCRup1ykH0FUYqbu6Lxo,1509
+mlx/include/metal_cpp/Metal/MTLVertexDescriptor.hpp,sha256=4oRbxIiJV2eafs-SqTXTVgPttIC22KtTZ0t4AXnLlgc,12017
+mlx/include/metal_cpp/Metal/MTLVisibleFunctionTable.hpp,sha256=P5Cv_bIIfJ4ktplCeZ3fAhHtLEXngF5cAlQgAr55c-0,3833
+mlx/include/metal_cpp/Metal/Metal.hpp,sha256=Bm5ldz-FxywsGmAdjZ8U2X5hlYjwfzHUjKCoW38T-w8,3190
+mlx/include/metal_cpp/MetalFX/MTLFXDefines.hpp,sha256=Ms80cxZbVVAFYDogYahvmzyVZqkDezrwXKXB_79m3KE,2104
+mlx/include/metal_cpp/MetalFX/MTLFXPrivate.hpp,sha256=LcubGmMx98T8OduyEawEx8swCEo_n5Z3ZfwuxwvAtz0,15147
+mlx/include/metal_cpp/MetalFX/MTLFXSpatialScaler.hpp,sha256=8QCySJ9V5JfuTNAHujkvgKQJ_hdM-yV8cqdw5v8l1Ns,18385
+mlx/include/metal_cpp/MetalFX/MTLFXTemporalScaler.hpp,sha256=57Qfhq3YqYOd0qwLfOBvz_viu55-Pmor0_IF99UWMAg,32887
+mlx/include/metal_cpp/MetalFX/MetalFX.hpp,sha256=Vm0W_ycCRb0UOOvLIPYm5gIEHW1nikG5vTvoKfTgAtE,1350
+mlx/include/metal_cpp/QuartzCore/CADefines.hpp,sha256=q6k-jUxe5uulNnsoqkM61OuyEVSJO4W8gIAVM6mSbwE,1895
+mlx/include/metal_cpp/QuartzCore/CAMetalDrawable.hpp,sha256=RE4_Cp9FacJMh7AlqZw-IhelLMXvUv_YpXVLQ2aNVrU,2363
+mlx/include/metal_cpp/QuartzCore/CAMetalLayer.hpp,sha256=i4WjUgWUsQVRPj8dxk4aBC06TKaXsKxRdgNNIgNe0pk,5502
+mlx/include/metal_cpp/QuartzCore/CAPrivate.hpp,sha256=aFEVW34Ih_Gs1aGv5Rml0pNaFxDwWLP6qBQWw3N6EzA,5010
+mlx/include/metal_cpp/QuartzCore/QuartzCore.hpp,sha256=-fQqOpndpszueKWJ4U_jovMID_s0QxCOJHpQoUj8JBE,1346
+mlx/include/metal_cpp/README.md,sha256=ZVbXv3dwSy-n_EmPpYlB2cJXkUCZXdE0-jJsY9pHsd0,14584
+mlx/include/metal_cpp/SingleHeader/MakeSingleHeader.py,sha256=JqlbeaangDoU2dRPyOfKgU5jkbmUghucuumqPEbZOWc,8984
+mlx/include/metal_cpp/SingleHeader/__pycache__/MakeSingleHeader.cpython-311.pyc,,
+mlx/include/mlx/3rdparty/pocketfft.h,sha256=Dq6iEwS_MkY5_xECLuY_r0io-rZK_lw8LR7---HX36Y,110508
+mlx/include/mlx/allocator.h,sha256=r0AgIuBbvq67OeO8eUD8cb4a72HDF6-PISAnOMfC-kM,1519
+mlx/include/mlx/array.h,sha256=D_dkEyq5zbBSc8kK5KNgQbAikvihW9Xn_1HlwI9uaps,11348
+mlx/include/mlx/backend/accelerate/utils.h,sha256=ryQ9t4EI6UkMBCE48MVcGUSstbRDlwM9YI__vs4dTzY,752
+mlx/include/mlx/backend/common/arange.h,sha256=J44RZZx_DB7DSELCtDFHBThHfLopu_JWj_5ew_q0ROw,1823
+mlx/include/mlx/backend/common/binary.h,sha256=uvS_NdciYF-6VHhKrDNFL1Sd2Ag8hn_kClpYmcu6E9E,15169
+mlx/include/mlx/backend/common/copy.h,sha256=5-b2CRTRoyCJ4WL5RkPAQXaLKaEn8pz-XojHGs4u1Rc,690
+mlx/include/mlx/backend/common/erf.h,sha256=Wm6jXR40EgAmyh0zbGOvRbZh58LVhe3_P1-hwN9GPuo,274
+mlx/include/mlx/backend/common/reduce.h,sha256=iR0xKDMZSEDrf2W0TLLXuqxVesf-nOioGu0MDGalT-0,11138
+mlx/include/mlx/backend/common/threefry.h,sha256=sXTdUFRAgEXf789u7IhWIj7n-Gtzwo4HiQXDA6BwFyU,572
+mlx/include/mlx/backend/common/unary.h,sha256=5vAiPfmNjXxMohP6RqZiOngCXcXa76RtrLDEyTeQobs,3356
+mlx/include/mlx/backend/common/utils.h,sha256=4pS1LCC-Tux3x4Y0A3LeTyTLGDSxbjapNkNFWUNjrc8,610
+mlx/include/mlx/backend/metal/allocator.h,sha256=lluCLBSXYbp6WGtl9zKsikpMHrptRjeZJ0Cf70zFkoc,1491
+mlx/include/mlx/backend/metal/copy.h,sha256=jXW0_jDwX97NmnwigeZwovS_MZ1PeZix2iuZf7f9Xts,400
+mlx/include/mlx/backend/metal/device.h,sha256=1Q0xSLn21dTzD5rWanoBBg1eMqOPf2958Z0rJL9GJag,2196
+mlx/include/mlx/backend/metal/kernels/atomic.h,sha256=0Iph_jNZoTyTrR-iEXXDzIwFoRWOkhqndX6IHJxFhKg,9217
+mlx/include/mlx/backend/metal/kernels/bf16.h,sha256=YdmOaznu1O1UNTyGlnpnEkvi1r44EouYNBAZ4SAX2yA,11904
+mlx/include/mlx/backend/metal/kernels/bf16_math.h,sha256=VRM_iZpBcY3Sg3yxkNPzjU30x32zcmd0diIagA1D5l4,26337
+mlx/include/mlx/backend/metal/kernels/complex.h,sha256=XDYKDiBdeM5KGpfpmwXmelZ00LmFCdxiGOIGoUMEeyI,3526
+mlx/include/mlx/backend/metal/kernels/conv_params.h,sha256=pzJn2uZ8TqY1WXRyGRx-MzBYBOfMSyhiDPgu8LMSElk,592
+mlx/include/mlx/backend/metal/kernels/defines.h,sha256=Bkplnhk_KYkqUzpB9TSlYUuXyucJQZYNyW9AR55g8fM,477
+mlx/include/mlx/backend/metal/kernels/erf.h,sha256=k1CfVsb-rEQaXGYJcAbgVR90VwEY5w_88BhJCSsj97k,2736
+mlx/include/mlx/backend/metal/kernels/gemm/conv.h,sha256=ilmFKKdzy8qwn1Zt0cu43f_mKsMLhIJy6DgIdx5ZU5I,14236
+mlx/include/mlx/backend/metal/kernels/gemm/gemm.h,sha256=x80JGVUliIDkFyzoGGd5iC0veHZVPx_BFkeVdGzgVOk,16184
+mlx/include/mlx/backend/metal/kernels/reduce.h,sha256=RgBsV24CtUU_dXCNzDP8sojJdjRz2bdF3PqB8UNJrlA,3554
+mlx/include/mlx/backend/metal/kernels/utils.h,sha256=mRGC9B2cjzCZT2kxY__EM0QmRQQBL8GSHC0rg7_4cJo,7202
+mlx/include/mlx/backend/metal/matmul.h,sha256=n32FxTg1zL-_oIaPRrs9QcI5TywhMemcm2KJ4zF-n0Q,592
+mlx/include/mlx/backend/metal/metal.h,sha256=9kgNkbx0Pqy-W3o7stP0cjvSAcoU0I20dZEsKmsSAS4,552
+mlx/include/mlx/backend/metal/mps/gemm.h,sha256=501lddlggxogceZI3wkWK98yIZurE0qgVxx0ICHyBxs,11307
+mlx/include/mlx/backend/metal/utils.h,sha256=ape52Rl1et96MA4H_a0cbVveLvhASJlszqX6mxl-nis,4205
+mlx/include/mlx/device.h,sha256=KtG7G03J5T0Ad2zn931OyRGfCa1XKn3ZyEUoIBY0No4,563
+mlx/include/mlx/dtype.h,sha256=jrOMl7xEhOZ7Dn0udDfi-Mq8QbT3atbd7CL0SZ4Yxjk,2575
+mlx/include/mlx/fft.h,sha256=MpW4UMjetc81i8pt3NJnTCvtXuaP92CX3w2iatYZHTw,4179
+mlx/include/mlx/graph_utils.h,sha256=23pv6xyLYTl-l2o9Z3DqlcvGVGZEA-uIV29jamcK-1Q,593
+mlx/include/mlx/io/load.h,sha256=vl7Z0pSJTjgG7KkSLXbe5oLjPhnVkWR9XdvC-ZCf_fk,2508
+mlx/include/mlx/io/safetensor.h,sha256=ro-grE_iXiGbCgygGFSeqsvxD_DDiwKckCsyb424e-g,624
+mlx/include/mlx/linalg.h,sha256=-CJ8NhHmycVsT29ObUmVZbxd2gwCpj8qj2RvOv-Glto,1771
+mlx/include/mlx/mlx.h,sha256=raZfU4bjGTSRAZa8P1XkoxHHeiooTZVI9C63ydGwm8c,296
+mlx/include/mlx/ops.h,sha256=322gdgz9l7pz-M0Pgpo0T6LbAe_w0RMV-A1-jRN_Le0,32224
+mlx/include/mlx/primitives.h,sha256=uJhmcwsXyseQlfeJBDjLCNrhPM91vursiw8_R79yUGg,44227
+mlx/include/mlx/random.h,sha256=aNoyWf-zAg9WKgqfe1yBpOvxedygIUZMQBgsd3kNnzQ,4887
+mlx/include/mlx/scheduler.h,sha256=sZX0BU1J15e_IOzBsv95KwAQmU4wrayd3b6ClvxGA4M,3864
+mlx/include/mlx/stream.h,sha256=7BvNtylObXbyIHZ0gCrbMb1bN3DUq7PfzEa-bIvqEXE,686
+mlx/include/mlx/transforms.h,sha256=Qt8pjvxAPYMoa8Dz2WUcJBdeKEFKNdczV0O3UTE_XmY,6527
+mlx/include/mlx/transforms_impl.h,sha256=RttUce4KHg3PLUFc9aRVjxenCIztlBSTlQq5LnVxKVg,542
+mlx/include/mlx/types/bf16.h,sha256=fLnTB0fbL2vIs_CCvWhrr7SfDnFUEpDPmg7RNmrRnzE,6753
+mlx/include/mlx/types/complex.h,sha256=oaTl_rCth3TO33sT2Db6SX--n9lIfE_dw7rjE06Qqmc,2878
+mlx/include/mlx/types/fp16.h,sha256=j0E7-PKkVkCT-bOMRdBEYdvxUdo9rf6w80dNSNCYNaQ,8322
+mlx/include/mlx/types/half_types.h,sha256=-D-ilDAh5shp1xp4lx0J87y9iFfBs43gLNi-XqzmLWA,1367
+mlx/include/mlx/utils.h,sha256=cchUDEk5qbWNIPHf_cRwbsxnczz2_moUr_Dipz-cfhM,1524
+mlx/lib/libmlx.dylib,sha256=ir7-RqHznJKyhGSBTwWnMPqYmbF3V3A0A8bvNi8GrJM,12420704
+mlx/lib/mlx.metallib,sha256=Lu30EADtJwKD2hHYibsQGqTIjG-PDsaP5rBApb5CRQE,59495531
+mlx/nn/__init__.py,sha256=b-1hqkAnzpthKJNKgRGVHsltNEqOU7URFMoY2Z46Buw,126
+mlx/nn/__pycache__/__init__.cpython-311.pyc,,
+mlx/nn/__pycache__/losses.cpython-311.pyc,,
+mlx/nn/__pycache__/utils.cpython-311.pyc,,
+mlx/nn/layers/__init__.py,sha256=2Ge23mO7W5VOMkogmcY2dqTKRRD3lY0g_WhVDP-9Hi0,1249
+mlx/nn/layers/__pycache__/__init__.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/activations.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/base.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/containers.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/convolution.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/dropout.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/embedding.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/linear.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/normalization.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/positional_encoding.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/quantized.cpython-311.pyc,,
+mlx/nn/layers/__pycache__/transformer.cpython-311.pyc,,
+mlx/nn/layers/activations.py,sha256=ukopw8ZFpSIQijXscov80Gv_0j0jKh0Y4Z3qrUFzpkk,11826
+mlx/nn/layers/base.py,sha256=iFp3PVGUhT9pAaWNnWtNY-yZ-uCQxvH-O1CYbKoGCug,19485
+mlx/nn/layers/containers.py,sha256=Ke8gPBPZvrtrjK1qrbRNGtSV4rvBkYTKVBW2pQXrNvY,618
+mlx/nn/layers/convolution.py,sha256=ZP43OBZTKF3ui_2Xrtx7VvAlLic_jcn2oL-MPLU7Rys,4060
+mlx/nn/layers/dropout.py,sha256=1GR5uWU82eFd296WuPB7CqNcNj3acEIlS3_4ZAxnx-g,4529
+mlx/nn/layers/embedding.py,sha256=5G7TvVyEaHTlQCASNVfBGea_1ywDVNqDkxHP2U-JEcs,878
+mlx/nn/layers/linear.py,sha256=ekgkElIUlfdWZNRDZkoyQFTotzyjNpD_0v7Ac8pnako,4137
+mlx/nn/layers/normalization.py,sha256=m92TwFfNQ4NbQlxtlTfJsWZlBJtO0_tR0M8H2d9wZEs,12021
+mlx/nn/layers/positional_encoding.py,sha256=_4zbYLMGYI0EVq4b6eitmB2hgBXblcrJJYsq0KCSafs,6611
+mlx/nn/layers/quantized.py,sha256=FA_NzLBheO2i5AClMWCkaFrNXJGScxlwvumqaJTIxAA,4136
+mlx/nn/layers/transformer.py,sha256=IUTNGvVoaKLS7Gwes_yK3SqgX6voPgMqk7nmdE4nyI0,12235
+mlx/nn/losses.py,sha256=1yWCxt-QSv4lGV9U6KJyeTzybQWW4LPZhipkLFx7hWk,12158
+mlx/nn/utils.py,sha256=tQJ62bu2eAAhzYqV7QriQtNGrYSz85XGxHolEPkGiGs,999
+mlx/optimizers.py,sha256=krPDwdaMy1JTlaVBgRHlg7PNqyxiVRkfNEbEf_wV-5g,17090
+mlx/share/cmake/MLX/MLXConfig.cmake,sha256=02CzwjTAHLTyLzFMtE7LLzyX_CSHreASjFyiYEpfIp4,1884
+mlx/share/cmake/MLX/MLXConfigVersion.cmake,sha256=p6TWi4EVJlFiJx_BxHZqpPBhkKaze00W3S-JRcRHD78,2762
+mlx/share/cmake/MLX/MLXTargets-release.cmake,sha256=bUjlsca3AGGUHl1X9lfm3CiQiSFeYTkm2c7GLzITeJs,788
+mlx/share/cmake/MLX/MLXTargets.cmake,sha256=vS5sI59URF-Pz8SPKMDVbyTPq1ENZ3zoR_VPzWToXkw,4705
+mlx/share/cmake/MLX/extension.cmake,sha256=q09-X3ckTrFCZOc5UNyoqjUC94OQQSXXWW7Bmg7sxok,1637
+mlx/utils.py,sha256=F-bmWn6a4VFRUvnSeZPvIFwT8Z6IHEtgEAfcLiGBrz8,4667
lib/python3.11/site-packages/mlx-0.0.7.dist-info/REQUESTED ADDED
(empty file)
lib/python3.11/site-packages/mlx-0.0.7.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
+Wheel-Version: 1.0
+Generator: bdist_wheel (0.41.2)
+Root-Is-Purelib: false
+Tag: cp311-cp311-macosx_14_0_arm64
+
lib/python3.11/site-packages/mlx-0.0.7.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
+mlx
lib/python3.11/site-packages/mlx/nn/layers/base.py ADDED
@@ -0,0 +1,532 @@
+# Copyright © 2023 Apple Inc.
+
+import textwrap
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import mlx.core as mx
+from mlx.utils import tree_flatten, tree_unflatten
+
+
+class Module(dict):
+    """Base class for building neural networks with MLX.
+
+    All the layers provided in :mod:`mlx.nn.layers` subclass this class and
+    your models should do the same.
+
+    A ``Module`` can contain other ``Module`` instances or :class:`mlx.core.array`
+    instances in arbitrary nesting of Python lists or dicts. The ``Module``
+    then allows recursively extracting all the :class:`mlx.core.array` instances
+    using :meth:`mlx.nn.Module.parameters`.
+
+    In addition, the ``Module`` has the concept of trainable and non-trainable
+    parameters (called "frozen"). When using :func:`mlx.nn.value_and_grad`
+    the gradients are returned only with respect to the trainable parameters.
+    All arrays in a module are trainable unless they are added in the "frozen"
+    set by calling :meth:`freeze`.
+
+    .. code-block:: python
+
+        import mlx.core as mx
+        import mlx.nn as nn
+
+        class MyMLP(nn.Module):
+            def __init__(self, in_dims: int, out_dims: int, hidden_dims: int = 16):
+                super().__init__()
+
+                self.in_proj = nn.Linear(in_dims, hidden_dims)
+                self.out_proj = nn.Linear(hidden_dims, out_dims)
+
+            def __call__(self, x):
+                x = self.in_proj(x)
+                x = mx.maximum(x, 0)
+                return self.out_proj(x)
+
+        model = MyMLP(2, 1)
+
+        # All the model parameters are created but since MLX is lazy by
+        # default, they are not evaluated yet. Calling `mx.eval` actually
+        # allocates memory and initializes the parameters.
+        mx.eval(model.parameters())
+
+        # Setting a parameter to a new value is as simple as accessing that
+        # parameter and assigning a new array to it.
+        model.in_proj.weight = model.in_proj.weight * 2
+        mx.eval(model.parameters())
+    """
+
+    def __init__(self):
+        """Should be called by the subclasses of ``Module``."""
+        self._no_grad = set()
+        self._training = True
+
+    @property
+    def training(self):
+        """Boolean indicating if the model is in training mode."""
+        return self._training
+
+    def _extra_repr(self):
+        return ""
+
+    def __repr__(self):
+        children = tree_flatten(self.children(), is_leaf=self.is_module)
+        value = f"{type(self).__name__}({self._extra_repr()}"
+        for k, v in children:
+            value += "\n"
+            value += textwrap.indent(f"({k}): {repr(v)}", prefix="  ")
+        if children:
+            value += "\n"
+        value += ")"
+
+        return value
+
+    def __getattr__(self, key: str):
+        if key in self:
+            return self[key]
+        else:
+            raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
+
+    def __setattr__(self, key: str, val: Any):
+        self[key] = val
+
+    def load_weights(
+        self,
+        file_or_weights: Union[str, List[Tuple[str, mx.array]]],
+        strict: bool = True,
+    ):
+        """
+        Update the model's weights from a ``.npz`` file or a list.
+
+        Args:
+            file_or_weights (str or list(tuple(str, mx.array))): The path to
+                the weights ``.npz`` file or a list of pairs of parameter names
+                and arrays.
+            strict (bool, optional): If ``True`` then checks that the provided
+                weights exactly match the parameters of the model. Otherwise,
+                only the weights actually contained in the model are loaded and
+                shapes are not checked. Default: ``True``.
+
+        Example:
+
+            .. code-block:: python
+
+                import mlx.core as mx
+                import mlx.nn as nn
+                model = nn.Linear(10, 10)
+
+                # Load from file
+                model.load_weights("weights.npz")
+
+                # Load from list
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                    ("bias", mx.zeros((10,))),
+                ]
+                model.load_weights(weights)
+
+                # Missing weight
+                weights = [
+                    ("weight", mx.random.uniform(shape=(10, 10))),
+                ]
+
+                # Raises a ValueError exception
+                model.load_weights(weights)
+
+                # Ok, only updates the weight but not the bias
+                model.load_weights(weights, strict=False)
+        """
+        weights = file_or_weights
+        if isinstance(weights, str):
+            weights = list(mx.load(weights).items())
+
+        if strict:
+            new_weights = dict(weights)
+            curr_weights = dict(tree_flatten(self.parameters()))
+            if extras := (new_weights.keys() - curr_weights.keys()):
+                extras = " ".join(extras)
+                raise ValueError(f"Received parameters not in model: {extras}.")
+            if missing := (curr_weights.keys() - new_weights.keys()):
+                missing = " ".join(missing)
+                raise ValueError(f"Missing parameters: {missing}.")
+            for k, v in curr_weights.items():
+                v_new = new_weights[k]
+                if not isinstance(v_new, mx.array):
+                    raise ValueError(
+                        "Expected mx.array but received "
+                        f"{type(v_new)} for parameter {k}"
+                    )
+                if v_new.shape != v.shape:
+                    raise ValueError(
+                        f"Expected shape {v.shape} but received "
+                        f"shape {v_new.shape} for parameter {k}"
+                    )
+
+        self.update(tree_unflatten(weights))
+
+    def save_weights(self, file: str):
+        """
+        Save the model's weights to a ``.npz`` file.
+        """
+        mx.savez(file, **dict(tree_flatten(self.parameters())))
+
+    @staticmethod
+    def is_module(value):
+        return isinstance(value, Module)
+
+    @staticmethod
+    def valid_child_filter(module, key, value):
+        return isinstance(value, (dict, list))
+
+    @staticmethod
+    def valid_parameter_filter(module, key, value):
+        return isinstance(value, (dict, list, mx.array)) and not key.startswith("_")
+
+    @staticmethod
+    def trainable_parameter_filter(module, key, value):
+        return (
+            Module.valid_parameter_filter(module, key, value)
+            and key not in module._no_grad
+        )
+
+    def filter_and_map(
+        self,
+        filter_fn: Callable[["mlx.nn.Module", str, Any], bool],
+        map_fn: Optional[Callable] = None,
+        is_leaf_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
+    ):
+        """Recursively filter the contents of the module using ``filter_fn``,
+        namely only select keys and values where ``filter_fn`` returns true.
+
+        This is used to implement :meth:`parameters` and :meth:`trainable_parameters`
+        but it can also be used to extract any subset of the module's parameters.
+
+        Args:
+            filter_fn (Callable): Given a value, the key in which it is found
+                and the containing module, decide whether to keep the value or
+                drop it.
+            map_fn (Callable, optional): Optionally transform the value before
+                returning it.
+            is_leaf_fn (Callable, optional): Given a value, the key in which it
+                is found and the containing module, decide if it is a leaf.
+
+        Returns:
+            A dictionary containing the contents of the module recursively filtered.
+        """
+
+        map_fn = map_fn or (lambda x: x)
+        is_leaf_fn = is_leaf_fn or (
+            lambda m, k, v: not isinstance(v, (Module, dict, list))
+        )
+
+        def unwrap(vk, v):
+            if is_leaf_fn(self, vk, v):
+                return map_fn(v)
+
+            if isinstance(v, Module):
+                return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)
+
+            if isinstance(v, dict):
+                nd = {}
+                for k, v in v.items():
+                    tk = f"{vk}.{k}"
+                    nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
+                return nd
+
+            if isinstance(v, list):
+                nl = []
+                for i, vi in enumerate(v):
+                    tk = f"{vk}.{i}"
+                    nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
+                return nl
+
+            raise RuntimeError("Unexpected leaf found while traversing the module")
+
+        return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
+
+    def parameters(self):
+        """Recursively return all the :class:`mlx.core.array` members of this Module
+        as a dict of dicts and lists."""
+        return self.filter_and_map(self.valid_parameter_filter)
+
+    def trainable_parameters(self):
+        """Recursively return all the non-frozen :class:`mlx.core.array` members of
+        this Module as a dict of dicts and lists."""
+        return self.filter_and_map(self.trainable_parameter_filter)
+
+    def children(self):
+        """Return the direct descendants of this Module instance."""
+        return self.filter_and_map(
+            self.valid_child_filter, is_leaf_fn=lambda m, k, v: isinstance(v, Module)
+        )
+
+    def leaf_modules(self):
+        """Return the submodules that do not contain other modules."""
+
+        def _is_leaf_module(m, k, v):
+            return isinstance(v, Module) and len(tree_flatten(v.children())) == 0
+
+        return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
+
+    def update(self, parameters: dict):
+        """Replace the parameters of this Module with the provided ones in the
+        dict of dicts and lists.
+
+        Commonly used by the optimizer to change the model to the updated
+        (optimized) parameters. Also used by :meth:`mlx.nn.value_and_grad` to set the
+        tracers in the model in order to compute gradients.
+
+        The passed in parameters dictionary need not be a full dictionary
+        similar to :meth:`parameters`. Only the provided locations will be
+        updated.
+
+        Args:
+            parameters (dict): A complete or partial dictionary of the module's
+                parameters.
+        """
+
+        def apply(dst, parameters):
+            if isinstance(parameters, dict):
+                for k in parameters:
+                    if k in dst:
+                        current_value = dst[k]
+                        new_value = parameters[k]
+                        if isinstance(current_value, mx.array):
+                            dst[k] = new_value
+                        elif isinstance(current_value, Module):
+                            current_value.update(new_value)
+                        elif isinstance(current_value, (dict, list)):
+                            apply(current_value, new_value)
+            elif isinstance(parameters, list):
+                for i in range(len(dst)):
+                    current_value = dst[i]
+                    new_value = parameters[i]
+                    if isinstance(current_value, mx.array):
+                        dst[i] = new_value
+                    elif isinstance(current_value, Module):
+                        current_value.update(new_value)
+                    elif isinstance(current_value, (dict, list)):
+                        apply(current_value, new_value)
+
+        apply(self, parameters)
+
+    def apply(
+        self,
+        map_fn: Callable[[mx.array], mx.array],
+        filter_fn: Optional[Callable[["mlx.nn.Module", str, Any], bool]] = None,
+    ):
+        """Map all the parameters using the provided ``map_fn`` and immediately
+        update the module with the mapped parameters.
+
+        For instance running ``model.apply(lambda x: x.astype(mx.float16))``
+        casts all parameters to 16 bit floats.
+
+        Args:
+            map_fn (Callable): Maps an array to another array
+            filter_fn (Callable, optional): Filter to select which arrays to
+                map (default: :meth:`Module.valid_parameter_filter`).
+        """
+        filter_fn = filter_fn or Module.valid_parameter_filter
+        self.update(self.filter_and_map(filter_fn, map_fn))
+
+    def update_modules(self, modules: dict):
+        """Replace the child modules of this :class:`Module` instance with the
+        provided ones in the dict of dicts and lists.
+
+        It is the equivalent of :meth:`Module.update` but for modules instead
+        of parameters and allows us to flexibly edit complex architectures by
+        programmatically swapping layers.
+
+        The passed in modules dictionary need not be a full dictionary
+        similar to :meth:`parameters`. Only the provided locations will be
+        updated.
+
+        Args:
+            modules (dict): A complete or partial dictionary of the module's
+                submodules.
+        """
+
+        def apply(dst, modules):
+            if isinstance(modules, dict):
+                for k in modules:
+                    if k in dst:
+                        current_value = dst[k]
+                        new_value = modules[k]
+                        if self.is_module(current_value) and self.is_module(new_value):
+                            dst[k] = new_value
+                        elif isinstance(current_value, (dict, list)):
+                            apply(current_value, new_value)
+            elif isinstance(modules, list):
+                for i in range(len(dst)):
+                    current_value = dst[i]
+                    new_value = modules[i]
+                    if self.is_module(current_value) and self.is_module(new_value):
+                        dst[i] = new_value
+                    elif isinstance(current_value, (dict, list)):
+                        apply(current_value, new_value)
+
+        apply(self, modules)
+
+    def apply_to_modules(self, apply_fn: Callable[[str, "mlx.nn.Module"], Any]):
+        """Apply a function to all the modules in this instance (including this
+        instance).
+
+        Args:
+            apply_fn (Callable): The function to apply to the modules.
+        """
+        module_stack = [("", self)]
+        while module_stack:
+            prefix, mod = module_stack.pop()
+            apply_fn(prefix, mod)
+            prefix = "." + prefix if prefix else ""
+            module_stack.extend(
+                tree_flatten(mod.children(), prefix=prefix, is_leaf=self.is_module)
+            )
+
+    def modules(self):
+        """Return a list with all the modules in this instance.
+
+        Returns:
+            A list of :class:`mlx.nn.Module` instances.
+        """
+        modulelist = []
+        self.apply_to_modules(lambda k, m: modulelist.append(m))
+        return modulelist
+
+    def named_modules(self):
+        """Return a list with all the modules in this instance and their name
+        with dot notation.
+
+        Returns:
+            A list of tuples (str, :class:`mlx.nn.Module`).
+        """
+        modulelist = []
+        self.apply_to_modules(lambda k, m: modulelist.append((k, m)))
+        return modulelist
+
+    def _validate_keys(self, keys, strict):
+        keys = keys if isinstance(keys, list) else [keys]
+        if strict:
+            for k in keys:
+                if k not in self:
+                    raise KeyError(f"Module doesn't contain member {k}.")
+        return keys
+
+    def freeze(
+        self,
+        *,
+        recurse: bool = True,
+        keys: Optional[Union[str, List[str]]] = None,
+        strict: bool = False,
+    ):
+        """Freeze the Module's parameters or some of them. Freezing a parameter means not
+        computing gradients for it.
+
+        This function is idempotent, i.e. freezing a frozen model is a no-op.
+
+        Example:
+            For instance to only train the attention parameters from a Transformer:
+
+            .. code-block:: python
+
+                model = nn.Transformer()
+                model.freeze()
+                model.apply_to_modules(lambda k, v: v.unfreeze() if k.endswith("attention") else None)
+
+        Args:
+            recurse (bool, optional): If True then freeze the parameters of the
+                submodules as well. Default: ``True``.
+            keys (str or list[str], optional): If provided then only these
+                parameters will be frozen otherwise all the parameters of a
+                module. For instance freeze all biases by calling
+                ``module.freeze(keys="bias")``.
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
+        """
+
+        def _freeze_impl(_, m):
+            local_keys = keys
+            if local_keys is None:
+                local_keys = tree_flatten(
+                    m.filter_and_map(
+                        lambda m, k, v: (not isinstance(v, Module))
+                        and m.valid_parameter_filter(m, k, v)
+                    )
+                )
+                local_keys = [k for (k, v) in local_keys]
+
+            local_keys = m._validate_keys(local_keys, strict)
+            m._no_grad.update(local_keys)
+
+        if recurse:
+            self.apply_to_modules(_freeze_impl)
+        else:
+            _freeze_impl("", self)
+
+    def unfreeze(
+        self,
+        *,
+        recurse: bool = True,
+        keys: Optional[Union[str, List[str]]] = None,
+        strict: bool = False,
+    ):
+        """Unfreeze the Module's parameters or some of them.
+
+        This function is idempotent, i.e. unfreezing a model that is not frozen is
+        a no-op.
+
+        Example:
+
+            For instance to only train the biases of a Transformer one can do:
+
+            .. code-block:: python
+
+                model = nn.Transformer()
+                model.freeze()
+                model.unfreeze(keys="bias")
+
+        Args:
+            recurse (bool, optional): If True then unfreeze the parameters of the
+                submodules as well. Default: ``True``.
+            keys (str or list[str], optional): If provided then only these
+                parameters will be unfrozen otherwise all the parameters of a
+                module. For instance unfreeze all biases by calling
+                ``module.unfreeze(keys="bias")``.
+            strict (bool, optional): If set to ``True`` validate that the passed keys exist.
+                Default: ``False``.
+        """
+
+        def _unfreeze_impl(_, m):
+            if keys is None:
+                m._no_grad.clear()
+
+            else:
+                local_keys = m._validate_keys(keys, strict)
+                m._no_grad.difference_update(local_keys)
+
+        if recurse:
+            self.apply_to_modules(_unfreeze_impl)
+        else:
+            _unfreeze_impl("", self)
+
+    def train(self, mode: bool = True):
+        """Set the model in or out of training mode.
+
+        Training mode only applies to certain layers. For example
+        :obj:`Dropout` applies a random mask in training mode, but is the
+        identity in evaluation mode.
+
+        Args:
+            mode (bool): Indicate if the model should be in training or
+                evaluation mode. Default: ``True``.
+        """
+
+        def _set_train(_, m):
+            m._training = mode
+
+        self.apply_to_modules(_set_train)
+
+    def eval(self):
+        """Set the model to evaluation mode.
+
+        See :func:`train`.
+        """
+        self.train(False)
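The `Module` docstrings above (`freeze`, `unfreeze`, `apply`, `trainable_parameters`) compose into a common fine-tuning pattern. An illustrative sketch, not part of the commit, using only methods defined in this file plus `nn.Linear` and `tree_flatten` from the same package:

```python
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

layer = nn.Linear(4, 2)

layer.freeze()               # freeze both weight and bias
layer.unfreeze(keys="bias")  # per the unfreeze docstring: train only the bias

# Cast every parameter to float16, as shown in the apply() docstring.
layer.apply(lambda x: x.astype(mx.float16))

# Only the bias remains trainable.
print([k for k, _ in tree_flatten(layer.trainable_parameters())])  # ['bias']
```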
lib/python3.11/site-packages/mlx/nn/layers/containers.py ADDED
@@ -0,0 +1,24 @@
+# Copyright © 2023 Apple Inc.
+
+from mlx.nn.layers.base import Module
+
+
+class Sequential(Module):
+    """A layer that calls the passed callables in order.
+
+    We can pass either modules or plain callables to the Sequential module. If
+    our functions have learnable parameters they should be implemented as
+    ``nn.Module`` instances.
+
+    Args:
+        modules (tuple of Callables): The modules to call in order
+    """
+
+    def __init__(self, *modules):
+        super().__init__()
+        self.layers = list(modules)
+
+    def __call__(self, x):
+        for m in self.layers:
+            x = m(x)
+        return x
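As the `Sequential` docstring notes, modules and plain callables can be mixed, since parameter-free functions need not be `Module` instances. A short illustrative usage (not part of the commit):

```python
import mlx.core as mx
import mlx.nn as nn

# Mix a module with a plain callable.
mlp = nn.Sequential(
    nn.Linear(4, 8),
    lambda x: mx.maximum(x, 0),  # parameter-free ReLU as a plain callable
    nn.Linear(8, 1),
)
y = mlp(mx.zeros((2, 4)))  # batch of 2, one output each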
lib/python3.11/site-packages/mlx/nn/layers/convolution.py ADDED
@@ -0,0 +1,126 @@
+# Copyright © 2023 Apple Inc.
+
+import math
+from typing import Union
+
+import mlx.core as mx
+from mlx.nn.layers.base import Module
+
+
+class Conv1d(Module):
+    """Applies a 1-dimensional convolution over the multi-channel input sequence.
+
+    The channels are expected to be last, i.e. the input shape should be ``NLC`` where:
+        - ``N`` is the batch dimension
+        - ``L`` is the sequence length
+        - ``C`` is the number of input channels
+
+    Args:
+        in_channels (int): The number of input channels
+        out_channels (int): The number of output channels
+        kernel_size (int): The size of the convolution filters
+        stride (int, optional): The stride when applying the filter.
+            Default: 1.
+        padding (int, optional): How many positions to 0-pad the input with.
+            Default: 0.
+        bias (bool, optional): If ``True`` add a learnable bias to the output.
+            Default: ``True``
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        stride: int = 1,
+        padding: int = 0,
+        bias: bool = True,
+    ):
+        super().__init__()
+
+        scale = math.sqrt(1 / (in_channels * kernel_size))
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(out_channels, kernel_size, in_channels),
+        )
+        if bias:
+            self.bias = mx.zeros((out_channels,))
+
+        self.padding = padding
+        self.stride = stride
+
+    def _extra_repr(self):
+        return (
+            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
+            f"kernel_size={self.weight.shape[1]}, stride={self.stride}, "
+            f"padding={self.padding}, bias={'bias' in self}"
+        )
+
+    def __call__(self, x):
+        y = mx.conv1d(x, self.weight, self.stride, self.padding)
+        if "bias" in self:
+            y = y + self.bias
+        return y
+
+
+class Conv2d(Module):
+    """Applies a 2-dimensional convolution over the multi-channel input image.
+
+    The channels are expected to be last, i.e. the input shape should be ``NHWC`` where:
+        - ``N`` is the batch dimension
+        - ``H`` is the input image height
+        - ``W`` is the input image width
+        - ``C`` is the number of input channels
+
+    Args:
+        in_channels (int): The number of input channels.
+        out_channels (int): The number of output channels.
+        kernel_size (int or tuple): The size of the convolution filters.
+        stride (int or tuple, optional): The size of the stride when
+            applying the filter. Default: 1.
+        padding (int or tuple, optional): How many positions to 0-pad
+            the input with. Default: 0.
+        bias (bool, optional): If ``True`` add a learnable bias to the
+            output. Default: ``True``
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, tuple],
+        stride: Union[int, tuple] = 1,
+        padding: Union[int, tuple] = 0,
+        bias: bool = True,
+    ):
+        super().__init__()
+
+        kernel_size, stride, padding = map(
+            lambda x: (x, x) if isinstance(x, int) else x,
+            (kernel_size, stride, padding),
+        )
+        scale = math.sqrt(1 / (in_channels * kernel_size[0] * kernel_size[1]))
+        self.weight = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(out_channels, *kernel_size, in_channels),
+        )
+        if bias:
+            self.bias = mx.zeros((out_channels,))
+
+        self.padding = padding
+        self.stride = stride
+
+    def _extra_repr(self):
+        return (
+            f"{self.weight.shape[-1]}, {self.weight.shape[0]}, "
+            f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, "
+            f"padding={self.padding}, bias={'bias' in self}"
+        )
+
+    def __call__(self, x):
+        y = mx.conv2d(x, self.weight, self.stride, self.padding)
+        if "bias" in self:
+            y = y + self.bias
+        return y
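Both convolution layers expect channels-last input, which is easy to get wrong when coming from PyTorch's NCHW convention. An illustrative sketch (not part of the commit):

```python
import mlx.core as mx
import mlx.nn as nn

# Channels-last layout, as the docstrings specify: (N, H, W, C).
x = mx.zeros((1, 28, 28, 3))
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
y = conv(x)
print(y.shape)  # (1, 28, 28, 16): H and W preserved by padding=1
```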
lib/python3.11/site-packages/mlx/nn/layers/dropout.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import mlx.core as mx
4
+ from mlx.nn.layers.base import Module
5
+
6
+
7
+ class Dropout(Module):
8
+ r"""Randomly zero a portion of the elements during training.
9
+
10
+ The remaining elements are multiplied with :math:`\frac{1}{1-p}` where
11
+ :math:`p` is the probability of zeroing an element. This is done so the
12
+ expected value of a given element will remain the same.
13
+
14
+ Args:
15
+ p (float): The probability to zero an element
16
+ """
17
+
18
+ def __init__(self, p: float = 0.5):
19
+ super().__init__()
20
+
21
+ if p < 0 or p >= 1:
22
+ raise ValueError(f"The dropout probability {p} is not in [0, 1)")
23
+
24
+ self._p_1 = 1 - p
25
+
26
+ def _extra_repr(self):
27
+ return f"p={1-self._p_1}"
28
+
29
+ def __call__(self, x):
30
+ if self._p_1 == 1 or not self.training:
31
+ return x
32
+
33
+ mask = mx.random.bernoulli(self._p_1, x.shape)
34
+
35
+ return (1 / self._p_1) * mask * x
36
+
37
+
38
+ class Dropout2d(Module):
39
+ r"""Apply 2D channel-wise dropout during training.
40
+
41
+ Randomly zero out entire channels independently with probability :math:`p`.
42
+ This layer expects the channels to be last, i.e. the input shape should be
43
+ ``NWHC`` or ``WHC`` where:``N`` is the batch dimension,``H`` is the input
44
+ image height,``W`` is the input image width, and``C`` is the number of
45
+ input channels
46
+
47
+ The remaining channels are scaled by :math:`\frac{1}{1-p}` to
48
+ maintain the expected value of each element. Unlike traditional dropout,
49
+ which zeros individual entries, this layer zeros entire channels. This is
50
+ beneficial for early convolution layers where adjacent pixels are
51
+ correlated. In such cases, traditional dropout may not effectively
52
+ regularize activations. For more details, see [1].
53
+
54
+ [1]: Tompson, J., Goroshin, R., Jain, A., LeCun, Y. and Bregler, C., 2015.
55
+ Efficient Object Localization Using Convolutional Networks. CVPR 2015.
56
+
57
+ Args:
58
+ p (float): Probability of zeroing a channel during training.
59
+ """
60
+
61
+ def __init__(self, p: float = 0.5):
62
+ super().__init__()
63
+
64
+ if p < 0 or p >= 1:
65
+ raise ValueError(f"The dropout probability {p} is not in [0, 1)")
66
+
67
+ self._p_1 = 1 - p
68
+
69
+ def _extra_repr(self):
70
+ return f"p={1-self._p_1}"
71
+
72
+ def __call__(self, x):
73
+ if x.ndim not in (3, 4):
74
+ raise ValueError(
75
+ f"Received input with {x.ndim} dimensions. Expected 3 or 4 dimensions."
76
+ )
77
+
78
+ if self._p_1 == 1 or not self.training:
79
+ return x
80
+
81
+ # Dropout is applied on the whole channel
82
+ # 3D input: (1, 1, C)
83
+ # 4D input: (B, 1, 1, C)
84
+ mask_shape = list(x.shape)
85
+ mask_shape[-2] = mask_shape[-3] = 1
86
+
87
+ mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
88
+ return (1 / self._p_1) * mask * x
89
+
90
+
91
+ class Dropout3d(Module):
92
+ r"""Apply 3D channel-wise dropout during training.
93
+
94
+ Randomly zero out entire channels independently with probability :math:`p`.
95
+ This layer expects the channels to be last, i.e., the input shape should be
96
+ `NDHWC` or `DHWC` where: `N` is the batch dimension, `D` is the depth,
97
+ `H` is the input image height, `W` is the input image width, and `C` is
98
+ the number of input channels.
99
+
100
+ The remaining channels are scaled by :math:`\frac{1}{1-p}` to
101
+ maintain the expected value of each element. Unlike traditional dropout,
102
+ which zeros individual entries, this layer zeros entire channels. This is
103
+ often beneficial for convolutional layers processing 3D data, like in
104
+ medical imaging or video processing.
105
+
106
+ Args:
107
+ p (float): Probability of zeroing a channel during training.
108
+ """
109
+
110
+ def __init__(self, p: float = 0.5):
111
+ super().__init__()
112
+
113
+ if p < 0 or p >= 1:
114
+ raise ValueError(f"The dropout probability {p} is not in [0, 1)")
115
+
116
+ self._p_1 = 1 - p
117
+
118
+ def _extra_repr(self):
119
+ return f"p={1-self._p_1}"
120
+
121
+ def __call__(self, x):
122
+ if x.ndim not in (4, 5):
123
+ raise ValueError(
124
+ f"Received input with {x.ndim} dimensions. Expected 4 or 5 dimensions."
125
+ )
126
+
127
+ if self._p_1 == 1 or not self.training:
128
+ return x
129
+
130
+ # Dropout is applied on the whole channel
131
+ # 4D input: (1, 1, 1, C)
132
+ # 5D input: (B, 1, 1, 1, C)
133
+ mask_shape = list(x.shape)
134
+ mask_shape[-2] = mask_shape[-3] = mask_shape[-4] = 1
135
+
136
+ mask = mx.random.bernoulli(p=self._p_1, shape=mask_shape)
137
+ return (1 / self._p_1) * mask * x
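The comments in the ``__call__`` methods describe per-channel masks; a short sketch of the training-time behavior (assuming the ``train()`` helper from ``mlx.nn.layers.base.Module`` and re-export under ``mlx.nn``):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    x = mx.ones((2, 8, 8, 4))   # NHWC
    drop = nn.Dropout2d(p=0.5)
    drop.train()                # dropout layers are no-ops outside training mode
    y = drop(x)                 # whole channels are zeroed; survivors are scaled by 1/(1-p) = 2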
lib/python3.11/site-packages/mlx/nn/layers/embedding.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import math
4
+
5
+ import mlx.core as mx
6
+ from mlx.nn.layers.base import Module
7
+
8
+
9
+ class Embedding(Module):
10
+ """Implements a simple lookup table that maps each input integer to a
11
+ high-dimensional vector.
12
+
13
+ Typically used to embed discrete tokens for processing by neural networks.
14
+
15
+ Args:
16
+ num_embeddings (int): The number of possible discrete tokens to embed,
17
+ usually called the vocabulary size.
18
+ dims (int): The dimensionality of the embeddings.
19
+ """
20
+
21
+ def __init__(self, num_embeddings: int, dims: int):
22
+ super().__init__()
23
+ scale = math.sqrt(1 / dims)
24
+ self.weight = mx.random.normal((num_embeddings, dims)) * scale
25
+
26
+ def _extra_repr(self):
27
+ return f"{self.weight.shape[0]}, {self.weight.shape[1]}"
28
+
29
+ def __call__(self, x):
30
+ return self.weight[x]
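Since the lookup is plain indexing into ``weight``, integer arrays of any shape embed elementwise; a minimal sketch (assuming re-export under ``mlx.nn``):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    emb = nn.Embedding(num_embeddings=100, dims=16)
    tokens = mx.array([[3, 7, 7], [0, 1, 2]])  # (batch, sequence)
    vectors = emb(tokens)                      # shape (2, 3, 16)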
lib/python3.11/site-packages/mlx/nn/layers/linear.py ADDED
@@ -0,0 +1,141 @@
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import math
4
+ from typing import Any
5
+
6
+ import mlx.core as mx
7
+ from mlx.nn.layers.base import Module
8
+
9
+
10
+ class Identity(Module):
11
+ r"""A placeholder identity operator that is argument-insensitive.
12
+
13
+ Args:
14
+ args: any argument (unused)
15
+ kwargs: any keyword argument (unused)
16
+ """
17
+
18
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
19
+ super().__init__()
20
+
21
+ def __call__(self, x: mx.array) -> mx.array:
22
+ return x
23
+
24
+
25
+ class Linear(Module):
26
+ r"""Applies an affine transformation to the input.
27
+
28
+ Concretely:
29
+
30
+ .. math::
31
+
32
+ y = x W^\top + b
33
+
34
+ where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b`
35
+ has shape ``[output_dims]``.
36
+
37
+ The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
38
+ where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.
39
+
40
+ Args:
41
+ input_dims (int): The dimensionality of the input features
42
+ output_dims (int): The dimensionality of the output features
43
+ bias (bool, optional): If set to ``False`` then the layer will
44
+ not use a bias. Default is ``True``.
45
+ """
46
+
47
+ def __init__(self, input_dims: int, output_dims: int, bias: bool = True) -> None:
48
+ super().__init__()
49
+ scale = math.sqrt(1.0 / input_dims)
50
+ self.weight = mx.random.uniform(
51
+ low=-scale,
52
+ high=scale,
53
+ shape=(output_dims, input_dims),
54
+ )
55
+ if bias:
56
+ self.bias = mx.random.uniform(
57
+ low=-scale,
58
+ high=scale,
59
+ shape=(output_dims,),
60
+ )
61
+
62
+ def _extra_repr(self) -> str:
63
+ return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
64
+
65
+ def __call__(self, x: mx.array) -> mx.array:
66
+ x = x @ self.weight.T
67
+ if "bias" in self:
68
+ x = x + self.bias
69
+ return x
70
+
71
+
72
+ class Bilinear(Module):
73
+ r"""Applies a bilinear transformation to the inputs.
74
+
75
+ Concretely:
76
+
77
+ .. math::
78
+
79
+ y_i = x_1^\top W_i x_2 + b_i
80
+
81
+ where:
82
+ :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims]``,
83
+ and :math:`i` indexes the output dimension.
84
+
85
+ The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
86
+ where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.
87
+
88
+ Args:
89
+ input1_dims (int): The dimensionality of the input1 features
90
+ input2_dims (int): The dimensionality of the input2 features
91
+ output_dims (int): The dimensionality of the output features
92
+ bias (bool, optional): If set to ``False`` then the layer will
93
+ not use a bias. Default is ``True``.
94
+ """
95
+
96
+ def __init__(
97
+ self, input1_dims: int, input2_dims: int, output_dims: int, bias: bool = True
98
+ ) -> None:
99
+ super().__init__()
100
+ scale = math.sqrt(1.0 / input1_dims)
101
+ self.weight = mx.random.uniform(
102
+ low=-scale,
103
+ high=scale,
104
+ shape=(output_dims, input2_dims, input1_dims),
105
+ )
106
+ if bias:
107
+ self.bias = mx.random.uniform(
108
+ low=-scale,
109
+ high=scale,
110
+ shape=(output_dims,),
111
+ )
112
+
113
+ def _extra_repr(self) -> str:
114
+ out, in2, in1 = self.weight.shape
115
+ return (
116
+ f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
117
+ f"bias={'bias' in self}"
118
+ )
119
+
120
+ def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
121
+ # Normalize shapes
122
+ out, in2, in1 = self.weight.shape
123
+ xshape = x1.shape[:-1]
124
+ x1 = x1.reshape(-1, in1)
125
+ x2 = x2.reshape(-1, 1, in2)
126
+
127
+ # Perform the bilinear transformation
128
+ w = self.weight.reshape(out * in2, in1)
129
+ y = x1 @ w.T
130
+ y = y.reshape(-1, out, in2).swapaxes(-2, -1)
131
+ y = x2 @ y
132
+ y = y.squeeze(1)
133
+
134
+ # Reset the shape
135
+ y = y.reshape(*xshape, out)
136
+
137
+ # Apply the bias
138
+ if "bias" in self:
139
+ y = y + self.bias
140
+
141
+ return y
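The reshape/transpose sequence in ``Bilinear.__call__`` evaluates one bilinear form per output dimension; a small shape sketch (assuming re-export under ``mlx.nn``):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    bl = nn.Bilinear(input1_dims=4, input2_dims=5, output_dims=3)
    x1 = mx.random.normal((7, 4))
    x2 = mx.random.normal((7, 5))
    y = bl(x1, x2)  # shape (7, 3): one bilinear form per output dimension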
lib/python3.11/site-packages/mlx/nn/layers/normalization.py ADDED
@@ -0,0 +1,368 @@
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ from typing import Tuple
4
+
5
+ import mlx.core as mx
6
+ from mlx.nn.layers.base import Module
7
+
8
+
9
+ class InstanceNorm(Module):
10
+ r"""Applies instance normalization [1] on the inputs.
11
+
12
+ Computes
13
+
14
+ .. math::
15
+
16
+ y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta,
17
+
18
+ where :math:`\gamma` and :math:`\beta` are learned per feature dimension
19
+ parameters initialized at 1 and 0 respectively. Both are of size :attr:`dims`,
20
+ if :attr:`affine` is ``True``.
21
+
22
+ Args:
23
+ dims (int): The number of features of the input.
24
+ eps (float): A value added to the denominator for numerical stability. Default: ``1e-5``.
25
+ affine (bool): If ``True``, learn an affine transform to apply after the normalization. Default: ``False``.
26
+
27
+ Shape:
28
+ - Input: :math:`(..., C)` where :math:`C` is equal to :attr:`dims`.
29
+ - Output: Same shape as the input.
30
+
31
+ Examples:
32
+ >>> import mlx.core as mx
33
+ >>> import mlx.nn as nn
34
+ >>> x = mx.random.normal((8, 4, 4, 16))
35
+ >>> inorm = nn.InstanceNorm(dims=16)
36
+ >>> output = inorm(x)
37
+
38
+ References:
39
+ [1]: https://arxiv.org/abs/1607.08022
40
+ """
41
+
42
+ def __init__(
43
+ self,
44
+ dims: int,
45
+ eps: float = 1e-5,
46
+ affine: bool = False,
47
+ ):
48
+ super().__init__()
49
+ if affine:
50
+ self.weight = mx.ones((dims,))
51
+ self.bias = mx.zeros((dims,))
52
+ self.dims = dims
53
+ self.eps = eps
54
+
55
+ def _extra_repr(self):
56
+ return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
57
+
58
+ def __call__(self, x: mx.array) -> mx.array:
59
+ reduction_axes = tuple(range(1, x.ndim - 1))
60
+ # Compute stats
61
+ mean = mx.mean(x, axis=reduction_axes, keepdims=True)
62
+ var = mx.var(x, axis=reduction_axes, keepdims=True)
63
+ # Normalize
64
+ x = (x - mean) * mx.rsqrt(var + self.eps)
65
+ # Scale and shift if necessary
66
+ return (self.weight * x + self.bias) if "weight" in self else x
67
+
68
+
69
+ class LayerNorm(Module):
70
+ r"""Applies layer normalization [1] on the inputs.
71
+
72
+ Computes
73
+
74
+ .. math::
75
+
76
+ y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
77
+
78
+ where :math:`\gamma` and :math:`\beta` are learned per feature dimension
79
+ parameters initialized at 1 and 0 respectively.
80
+
81
+ [1]: https://arxiv.org/abs/1607.06450
82
+
83
+ Args:
84
+ dims (int): The feature dimension of the input to normalize over
85
+ eps (float): A small additive constant for numerical stability
86
+ affine (bool): If ``True``, learn an affine transform to apply after the
88
+ normalization.
88
+ """
89
+
90
+ def __init__(self, dims: int, eps: float = 1e-5, affine: bool = True):
91
+ super().__init__()
92
+ if affine:
93
+ self.bias = mx.zeros((dims,))
94
+ self.weight = mx.ones((dims,))
95
+ self.eps = eps
96
+ self.dims = dims
97
+
98
+ def _extra_repr(self):
99
+ return f"{self.dims}, eps={self.eps}, affine={'weight' in self}"
100
+
101
+ def __call__(self, x):
102
+ means = mx.mean(x, axis=-1, keepdims=True)
103
+ var = mx.var(x, axis=-1, keepdims=True)
104
+ x = (x - means) * mx.rsqrt(var + self.eps)
105
+ return (self.weight * x + self.bias) if "weight" in self else x
106
+
107
+
108
+ class RMSNorm(Module):
109
+ r"""Applies Root Mean Square normalization [1] to the inputs.
110
+
111
+ Computes
112
+
113
+ .. math::
114
+
115
+ y = \frac{x}{\sqrt{E[x^2] + \epsilon}} \gamma
116
+
117
+ where :math:`\gamma` is a learned per feature dimension parameter initialized at
118
+ 1.
119
+
120
+ [1]: https://arxiv.org/abs/1910.07467
121
+
122
+ Args:
123
+ dims (int): The feature dimension of the input to normalize over
124
+ eps (float): A small additive constant for numerical stability
125
+ """
126
+
127
+ def __init__(self, dims: int, eps: float = 1e-5):
128
+ super().__init__()
129
+ self.weight = mx.ones((dims,))
130
+ self.eps = eps
131
+
132
+ def _extra_repr(self):
133
+ return f"{self.weight.shape[0]}, eps={self.eps}"
134
+
135
+ def __call__(self, x):
136
+ # S is 1/sqrt(N) where N is the size of the features of x and is used
137
+ # to compute a numerically more stable RMS of x by multiplying with S
138
+ # first and summing.
139
+ #
140
+ # This way we prefer underflow over overflow which is controlled with
141
+ # the parameter epsilon anyway.
142
+ S = 1 / x.shape[-1] ** 0.5
143
+
144
+ n = (x * S).square().sum(axis=-1, keepdims=True)
145
+ n = mx.rsqrt(n + self.eps)
146
+
147
+ return self.weight * x * n
148
+
149
+
150
+ class GroupNorm(Module):
151
+ r"""Applies Group Normalization [1] to the inputs.
152
+
153
+ Computes the same normalization as layer norm, namely
154
+
155
+ .. math::
156
+
157
+ y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
158
+
159
+ where :math:`\gamma` and :math:`\beta` are learned per feature dimension
160
+ parameters initialized at 1 and 0 respectively. However, the mean and
161
+ variance are computed over the spatial dimensions and each group of
162
+ features. In particular, the input is split into ``num_groups`` groups across the
163
+ feature dimension.
164
+
165
+ The feature dimension is assumed to be the last dimension and the dimensions
166
+ that precede it (except the first) are considered the spatial dimensions.
167
+
168
+ [1]: https://arxiv.org/abs/1803.08494
169
+
170
+ Args:
171
+ num_groups (int): Number of groups to separate the features into
172
+ dims (int): The feature dimensions of the input to normalize over
173
+ eps (float): A small additive constant for numerical stability
174
+ affine (bool): If ``True``, learn an affine transform to apply after the
175
+ normalization.
176
+ pytorch_compatible (bool): If ``True``, perform the group normalization in
177
+ the same order/grouping as PyTorch.
178
+ """
179
+
180
+ def __init__(
181
+ self,
182
+ num_groups: int,
183
+ dims: int,
184
+ eps: float = 1e-5,
185
+ affine: bool = True,
186
+ pytorch_compatible: bool = False,
187
+ ):
188
+ super().__init__()
189
+ if affine:
190
+ self.bias = mx.zeros((dims,))
191
+ self.weight = mx.ones((dims,))
192
+ self.num_groups = num_groups
193
+ self.dims = dims
194
+ self.eps = eps
195
+ self.pytorch_compatible = pytorch_compatible
196
+
197
+ def _extra_repr(self):
198
+ return (
199
+ f"{self.num_groups}, {self.dims}, eps={self.eps}, "
200
+ f"affine={'weight' in self}, pytorch_compatible={self.pytorch_compatible}"
201
+ )
202
+
203
+ def _pytorch_compatible_group_norm(self, x):
204
+ num_groups = self.num_groups
205
+ batch, *rest, dims = x.shape
206
+
207
+ # Split into groups
208
+ x = x.reshape(batch, -1, num_groups, dims // num_groups)
209
+ x = x.transpose(0, 1, 3, 2).reshape(batch, -1, num_groups)
210
+
211
+ # Normalize
212
+ means = mx.mean(x, axis=1, keepdims=True)
213
+ var = mx.var(x, axis=1, keepdims=True)
214
+ x = (x - means) * mx.rsqrt(var + self.eps)
215
+ x = x.reshape(batch, -1, dims // num_groups, num_groups)
216
+ x = x.transpose(0, 1, 3, 2).reshape(batch, *rest, dims)
217
+
218
+ return x
219
+
220
+ def _group_norm(self, x):
221
+ num_groups = self.num_groups
222
+ batch, *rest, dims = x.shape
223
+
224
+ # Split into groups
225
+ x = x.reshape(batch, -1, num_groups)
226
+
227
+ # Normalize
228
+ means = mx.mean(x, axis=1, keepdims=True)
229
+ var = mx.var(x, axis=1, keepdims=True)
230
+ x = (x - means) * mx.rsqrt(var + self.eps)
231
+ x = x.reshape(batch, *rest, dims)
232
+
233
+ return x
234
+
235
+ def __call__(self, x):
236
+ group_norm = (
237
+ self._pytorch_compatible_group_norm
238
+ if self.pytorch_compatible
239
+ else self._group_norm
240
+ )
241
+ x = group_norm(x)
242
+ return (self.weight * x + self.bias) if "weight" in self else x
243
+
244
+
245
+ class BatchNorm(Module):
246
+ r"""Applies Batch Normalization over a 2D or 3D input.
247
+
248
+ Computes
249
+
250
+ .. math::
251
+
252
+ y = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} \gamma + \beta,
253
+
254
+ where :math:`\gamma` and :math:`\beta` are learned per feature dimension
255
+ parameters initialized at 1 and 0 respectively.
256
+
257
+ The input shape is specified as ``NC`` or ``NLC``, where ``N`` is the
258
+ batch, ``C`` is the number of features or channels, and ``L`` is the
259
+ sequence length. The output has the same shape as the input. For
260
+ four-dimensional arrays, the shape is ``NHWC``, where ``H`` and ``W`` are
261
+ the height and width respectively.
262
+
263
+ For more information on Batch Normalization, see the original paper `Batch
264
+ Normalization: Accelerating Deep Network Training by Reducing Internal
265
+ Covariate Shift <https://arxiv.org/abs/1502.03167>`_.
266
+
267
+ Args:
268
+ num_features (int): The feature dimension to normalize over.
269
+ eps (float, optional): A small additive constant for numerical
270
+ stability. Default: ``1e-5``.
271
+ momentum (float, optional): The momentum for updating the running
272
+ mean and variance. Default: ``0.1``.
273
+ affine (bool, optional): If ``True``, apply a learned affine
274
+ transformation after the normalization. Default: ``True``.
275
+ track_running_stats (bool, optional): If ``True``, track the
276
+ running mean and variance. Default: ``True``.
277
+
278
+ Examples:
279
+ >>> import mlx.core as mx
280
+ >>> import mlx.nn as nn
281
+ >>> x = mx.random.normal((5, 4))
282
+ >>> bn = nn.BatchNorm(num_features=4, affine=True)
283
+ >>> output = bn(x)
284
+ """
285
+
286
+ def __init__(
287
+ self,
288
+ num_features: int,
289
+ eps: float = 1e-5,
290
+ momentum: float = 0.1,
291
+ affine: bool = True,
292
+ track_running_stats: bool = True,
293
+ ):
294
+ super().__init__()
295
+
296
+ self.num_features = num_features
297
+ self.eps = eps
298
+ self.momentum = momentum
299
+ self.track_running_stats = track_running_stats
300
+
301
+ if affine:
302
+ self.weight = mx.ones((num_features,))
303
+ self.bias = mx.zeros((num_features,))
304
+
305
+ if self.track_running_stats:
306
+ self.running_mean = mx.zeros((num_features,))
307
+ self.running_var = mx.ones((num_features,))
308
+ self.freeze(keys=["running_mean", "running_var"], recurse=False)
309
+
310
+ def unfreeze(self, *args, **kwargs):
311
+ """Wrap unfreeze to make sure that running_mean and var are always
312
+ frozen parameters."""
313
+ super().unfreeze(*args, **kwargs)
314
+ self.freeze(keys=["running_mean", "running_var"], recurse=False)
315
+
316
+ def _extra_repr(self):
317
+ return (
318
+ f"{self.num_features}, eps={self.eps}, "
319
+ f"momentum={self.momentum}, affine={'weight' in self}, "
320
+ f"track_running_stats={self.track_running_stats}"
321
+ )
322
+
323
+ def _calc_stats(self, x: mx.array) -> Tuple[mx.array, mx.array]:
324
+ """
325
+ Calculate the mean and variance of the input tensor across the batch
326
+ and spatial dimensions.
327
+
328
+ Args:
329
+ x (array): Input tensor.
330
+
331
+ Returns:
332
+ tuple: Tuple containing mean and variance.
333
+ """
334
+ reduction_axes = tuple(range(0, x.ndim - 1))
335
+
336
+ mean = mx.mean(x, axis=reduction_axes, keepdims=True)
337
+ var = mx.var(x, axis=reduction_axes, keepdims=True)
338
+
339
+ return mean, var
340
+
341
+ def __call__(self, x: mx.array) -> mx.array:
342
+ """
343
+ Forward pass of BatchNorm.
344
+
345
+ Args:
346
+ x (array): Input tensor.
347
+
348
+ Returns:
349
+ array: Normalized output tensor.
350
+ """
351
+ if x.ndim < 2 or x.ndim > 4:
352
+ raise ValueError(
353
+ f"Expected input tensor to have 2, 3 or 4 dimensions, but got {x.ndim}"
354
+ )
355
+
356
+ # Calculate the mean and variance used to normalize the input x. If we
357
+ # are in training mode update the running stats if needed.
358
+ mean, var = self._calc_stats(x)
359
+ if self.training and self.track_running_stats:
360
+ mu = self.momentum
361
+ self.running_mean = (1 - mu) * self.running_mean + mu * mean
362
+ self.running_var = (1 - mu) * self.running_var + mu * var
363
+ elif self.track_running_stats:
364
+ mean = self.running_mean
365
+ var = self.running_var
366
+
367
+ x = (x - mean) * mx.rsqrt(var + self.eps)
368
+ return (self.weight * x + self.bias) if "weight" in self else x
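``BatchNorm`` switches between batch statistics and the tracked running statistics depending on ``training``; a short sketch (assuming the ``train()``/``eval()`` helpers from ``Module`` and re-export under ``mlx.nn``):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    bn = nn.BatchNorm(num_features=4)
    x = mx.random.normal((16, 4))

    bn.train()
    y_train = bn(x)  # normalized with batch mean/var; running stats are updated

    bn.eval()
    y_eval = bn(x)   # normalized with the frozen running mean/var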
lib/python3.11/site-packages/mlx/nn/layers/positional_encoding.py ADDED
@@ -0,0 +1,199 @@
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import math
4
+ from typing import Optional
5
+
6
+ import mlx.core as mx
7
+ from mlx.nn.layers.base import Module
8
+
9
+
10
+ class RoPE(Module):
11
+ """Implements the rotary positional encoding.
12
+
13
+ The traditional implementation rotates consecutive pairs of elements in the
14
+ feature dimension while the default implementation rotates pairs with
15
+ stride half the feature dimensions for efficiency.
16
+
17
+ For more details see `RoFormer: Enhanced Transformer with Rotary Position
18
+ Embedding <https://arxiv.org/abs/2104.09864>`_.
19
+
20
+ Args:
21
+ dims (int): The feature dimensions to be rotated. If the input feature
22
+ dimension is larger than ``dims`` then the rest is left unchanged.
23
+ traditional (bool, optional): If set to ``True``, choose the traditional
24
+ implementation which is slightly less efficient. Default: ``False``.
25
+ base (float, optional): The base used to compute angular frequency for
26
+ each dimension in the positional encodings. Default: ``10000``.
27
+ scale (float, optional): The scale used to scale the positions. Default: ``1.0``.
28
+ """
29
+
30
+ def __init__(
31
+ self,
32
+ dims: int,
33
+ traditional: bool = False,
34
+ base: float = 10000,
35
+ scale: float = 1.0,
36
+ ):
37
+ super().__init__()
38
+ self.dims = dims
39
+ self.traditional = traditional
40
+ self.base = base
41
+ self.scale = scale
42
+
43
+ def _extra_repr(self):
44
+ return f"{self.dims}, traditional={self.traditional}"
45
+
46
+ def _compute_rope(self, costheta, sintheta, x):
47
+ x1 = x[..., : self.dims // 2]
48
+ x2 = x[..., self.dims // 2 : self.dims]
49
+ rx1 = x1 * costheta - x2 * sintheta
50
+ rx2 = x1 * sintheta + x2 * costheta
51
+
52
+ if self.dims < x.shape[-1]:
53
+ rx = mx.concatenate([rx1, rx2, x[..., self.dims :]], axis=-1)
54
+ else:
55
+ rx = mx.concatenate([rx1, rx2], axis=-1)
56
+
57
+ return rx
58
+
59
+ def _compute_traditional_rope(self, costheta, sintheta, x):
60
+ x1 = x[..., ::2]
61
+ x2 = x[..., 1::2]
62
+ rx1 = x1 * costheta - x2 * sintheta
63
+ rx2 = x1 * sintheta + x2 * costheta
64
+
65
+ if self.dims < x.shape[-1]:
66
+ raise NotImplementedError(
67
+ "RoPE doesn't implement partial traditional application"
68
+ )
69
+
70
+ rx = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
71
+
72
+ return rx
73
+
74
+ def __call__(self, x, offset: int = 0):
75
+ shape = x.shape
76
+ x = mx.reshape(x, (-1, shape[-2], shape[-1]))
77
+ N = x.shape[1] + offset
78
+ costheta, sintheta = RoPE.create_cos_sin_theta(
79
+ N, self.dims, offset=offset, base=self.base, scale=self.scale, dtype=x.dtype
80
+ )
81
+
82
+ rope = (
83
+ self._compute_traditional_rope if self.traditional else self._compute_rope
84
+ )
85
+ rx = rope(costheta, sintheta, x)
86
+
87
+ return mx.reshape(rx, shape)
88
+
89
+ @staticmethod
90
+ def create_cos_sin_theta(
91
+ N: int,
92
+ D: int,
93
+ offset: int = 0,
94
+ base: float = 10000,
95
+ scale: float = 1.0,
96
+ dtype=mx.float32,
97
+ ):
98
+ D = D // 2
99
+ positions = mx.arange(offset, N, dtype=dtype) * scale
100
+ freqs = mx.exp(-mx.arange(0.0, D, dtype=dtype) * (math.log(base) / D))
101
+ theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
102
+ return mx.cos(theta), mx.sin(theta)
103
+
104
+
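The ``offset`` argument lets positions continue from where a cache left off during autoregressive decoding; a minimal sketch (assuming re-export under ``mlx.nn``):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    rope = nn.RoPE(dims=64)
    q = mx.random.normal((1, 8, 10, 64))   # (batch, heads, sequence, head_dim)
    q_rot = rope(q)                        # rotated with positions 0..9
    q_new = mx.random.normal((1, 8, 1, 64))
    q_new_rot = rope(q_new, offset=10)     # the next token gets position 10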
105
+ class SinusoidalPositionalEncoding(Module):
106
+ r"""Implements sinusoidal positional encoding.
107
+
108
+ For more details see the paper `Attention Is All You Need
109
+ <https://arxiv.org/abs/1706.03762>`_.
110
+
111
+ Args:
112
+ dims (int): The dimensionality of the resulting positional embeddings.
113
+ min_freq (float, optional): The minimum frequency expected. Default:
114
+ ``0.0001``.
115
+ max_freq (float, optional): The maximum frequency expected. Default:
116
+ ``1``.
117
+ scale (float, optional): A multiplicative scale for the embeddings.
118
+ Default: ``sqrt(2/dims)``.
119
+ cos_first (bool, optional): If ``True`` embed using ``[cos(x); sin(x)]``
120
+ instead of the reverse. Default: ``False``.
121
+ full_turns (bool, optional): If ``True`` multiply the frequencies with
122
+ :math:`2\pi`. Default: ``False``.
123
+ """
124
+
125
+ def __init__(
126
+ self,
127
+ dims: int,
128
+ min_freq: float = 0.0001,
129
+ max_freq: float = 1,
130
+ scale: Optional[float] = None,
131
+ cos_first: bool = False,
132
+ full_turns: bool = False,
133
+ ):
134
+ super().__init__()
135
+
136
+ one_zero = 1 - mx.arange(0, dims // 2) / (dims // 2 - 1)
137
+ min_freq = math.log(min_freq)
138
+ max_freq = math.log(max_freq)
139
+
140
+ # Start with underscore so it is not included in the parameters
141
+ self._sigmas = mx.exp(one_zero * (max_freq - min_freq) + min_freq)
142
+ if full_turns:
143
+ self._sigmas = self._sigmas * (2 * math.pi)
144
+
145
+ # Save some constants that define the implementation
146
+ self.scale = scale or (2 / dims) ** 0.5
147
+ self.cos_first = cos_first
148
+
149
+ def __call__(self, x):
150
+ y = x[..., None] * self._sigmas
151
+ cosy = mx.cos(y)
152
+ siny = mx.sin(y)
153
+
154
+ if self.cos_first:
155
+ y = mx.concatenate([cosy, siny], axis=-1)
156
+ else:
157
+ y = mx.concatenate([siny, cosy], axis=-1)
158
+
159
+ if self.scale != 1:
160
+ y = y * self.scale
161
+
162
+ return y
163
+
164
+
165
+ class ALiBi(Module):
166
+ @staticmethod
167
+ def create_alibi_matrix(
168
+ q_sequence_length: int,
169
+ k_sequence_length: int,
170
+ num_heads: int,
171
+ offset: int,
172
+ dtype=mx.float32,
173
+ ):
174
+ x1 = mx.arange(offset, q_sequence_length)
175
+ x2 = mx.arange(0, k_sequence_length)
176
+ distance_matrix = -mx.abs(
177
+ mx.expand_dims(x1[:, None] - x2[None, :], axis=(0, 1))
178
+ )
179
+ alibi_slope = ALiBi.create_alibi_slope(num_heads=num_heads)
180
+ alibi_mask = (distance_matrix * alibi_slope).astype(dtype)
181
+ return alibi_mask
182
+
183
+ @staticmethod
184
+ def create_alibi_slope(num_heads):
185
+ x = (2**8) ** (1 / num_heads)
186
+ out = mx.power(x, -mx.arange(1, num_heads + 1))
187
+ return mx.expand_dims(out, axis=(-1, -2))
188
+
189
+ def __call__(self, attention_scores, offset=0, mask=None):
190
+ alibi_mask = ALiBi.create_alibi_matrix(
191
+ q_sequence_length=attention_scores.shape[-2] + offset,
192
+ k_sequence_length=attention_scores.shape[-1],
193
+ num_heads=attention_scores.shape[1],
194
+ offset=offset,
195
+ dtype=attention_scores.dtype,
196
+ )
197
+ if mask is not None:
198
+ alibi_mask = alibi_mask + mask
199
+ return attention_scores + alibi_mask
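Unlike the two encodings above, ``ALiBi`` is applied to the attention scores rather than to the inputs; a shape sketch (assuming re-export under ``mlx.nn``):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    alibi = nn.ALiBi()
    scores = mx.random.normal((2, 8, 10, 10))  # (batch, heads, queries, keys)
    biased = alibi(scores)                     # same shape, distance penalties added per head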
lib/python3.11/site-packages/mlx/nn/layers/quantized.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import math
4
+
5
+ import mlx.core as mx
6
+ from mlx.nn.layers.base import Module
7
+ from mlx.nn.layers.linear import Linear
8
+ from mlx.utils import tree_flatten, tree_map
9
+
10
+
11
+ class QuantizedLinear(Module):
12
+ """Applies an affine transformation to the input using a quantized weight matrix.
13
+
14
+ It is the quantized equivalent of :class:`mlx.nn.Linear`. For now its
15
+ parameters are frozen and will not be included in any gradient computation
16
+ but this will probably change in the future.
17
+
18
+ QuantizedLinear also provides two useful classmethods to convert linear
19
+ layers to QuantizedLinear layers.
20
+
21
+ - :meth:`from_linear` returns a QuantizedLinear layer that applies the same
22
+ linear transformation up to the quantization error.
23
+ - :meth:`quantize_module` swaps all the linear layers of the passed module
24
+ with QuantizedLinear ones.
25
+
26
+ Args:
27
+ input_dims (int): The dimensionality of the input features
28
+ output_dims (int): The dimensionality of the output features
29
+ bias (bool, optional): If set to ``False`` then the layer will not use
30
+ a bias. Default: ``True``.
31
+ group_size (int, optional): The group size to use for the quantized
32
+ weight. See :func:`~mlx.core.quantize`. Default: ``64``.
33
+ bits (int, optional): The bit width to use for the quantized weight.
34
+ See :func:`~mlx.core.quantize`. Default: ``4``.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ input_dims: int,
40
+ output_dims: int,
41
+ bias: bool = True,
42
+ group_size: int = 64,
43
+ bits: int = 4,
44
+ ):
45
+ super().__init__()
46
+
47
+ # Quantization config
48
+ self.group_size = group_size
49
+ self.bits = bits
50
+
51
+ # Initialize the quantized weight
52
+ scale = math.sqrt(1 / input_dims)
53
+ weight = mx.random.uniform(
54
+ low=-scale,
55
+ high=scale,
56
+ shape=(output_dims, input_dims),
57
+ )
58
+ self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits)
59
+
60
+ # And bias if needed
61
+ if bias:
62
+ self.bias = mx.zeros((output_dims,))
63
+
64
+ # Freeze this model's parameters
65
+ self.freeze()
66
+
67
+ def unfreeze(self, *args, **kwargs):
68
+ """Wrap unfreeze so that we unfreeze any layers we might contain but
69
+ our parameters will remain frozen."""
70
+ super().unfreeze(*args, **kwargs)
71
+ self.freeze(recurse=False)
72
+
73
+ def _extra_repr(self):
74
+ out_dims, in_dims = self.weight.shape
75
+ in_dims *= 32 // self.bits
76
+ return (
77
+ f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self},"
78
+ f"group_size={self.group_size}, bits={self.bits}"
79
+ )
80
+
81
+ def __call__(self, x):
82
+ x = mx.quantized_matmul(
83
+ x,
84
+ self.weight,
85
+ scales=self.scales,
86
+ biases=self.biases,
87
+ transpose=True,
88
+ group_size=self.group_size,
89
+ bits=self.bits,
90
+ )
91
+ if "bias" in self:
92
+ x = x + self.bias
93
+ return x
94
+
95
+ @classmethod
96
+ def from_linear(cls, linear_layer: Module, group_size: int = 64, bits: int = 4):
97
+ """Create a QuantizedLinear layer from the parameters of a provided
98
+ linear layer."""
99
+ output_dims, input_dims = linear_layer.weight.shape
100
+ ql = cls(input_dims, output_dims, False, group_size, bits)
101
+ ql.weight, ql.scales, ql.biases = mx.quantize(
102
+ linear_layer.weight, group_size, bits
103
+ )
104
+ if "bias" in linear_layer:
105
+ ql.bias = linear_layer.bias
106
+
107
+ return ql
108
+
109
+ @classmethod
110
+ def quantize_module(
111
+ cls,
112
+ model: Module,
113
+ group_size: int = 64,
114
+ bits: int = 4,
115
+ linear_class_predicate=lambda m: isinstance(m, Linear),
116
+ ):
117
+ def _quantize_if_linear(m):
118
+ if linear_class_predicate(m):
119
+ return cls.from_linear(m, group_size, bits)
120
+ else:
121
+ return m
122
+
123
+ leaves = model.leaf_modules()
124
+ leaves = tree_map(_quantize_if_linear, leaves, is_leaf=Module.is_module)
125
+ model.update_modules(leaves)
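``quantize_module`` rewrites a model in place; a minimal sketch (assuming ``Sequential`` from the ``containers.py`` file above is re-exported under ``mlx.nn``; input dims must be divisible by ``group_size``):

.. code-block:: python

    import mlx.nn as nn
    from mlx.nn.layers.quantized import QuantizedLinear

    model = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 64))
    QuantizedLinear.quantize_module(model, group_size=64, bits=4)
    # both nn.Linear layers are now frozen QuantizedLinear layers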
lib/python3.11/site-packages/mlx/nn/layers/transformer.py ADDED
@@ -0,0 +1,354 @@
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import math
4
+ from typing import Any, Callable, Optional
5
+
6
+ import mlx.core as mx
7
+ from mlx.nn.layers.activations import relu
8
+ from mlx.nn.layers.base import Module
9
+ from mlx.nn.layers.dropout import Dropout
10
+ from mlx.nn.layers.linear import Linear
11
+ from mlx.nn.layers.normalization import LayerNorm
12
+
13
+
14
+ class MultiHeadAttention(Module):
15
+ """Implements the scaled dot product attention with multiple heads.
16
+
17
+ Given inputs for queries, keys and values the ``MultiHeadAttention``
18
+ produces new values by aggregating information from the input values
19
+ according to the similarities of the input queries and keys.
20
+
21
+ All inputs as well as the output are linearly projected without biases by
22
+ default.
23
+
24
+ ``MultiHeadAttention`` also takes an optional additive attention mask that
25
+ should be broadcastable with ``(batch, num_heads, # queries, # keys)``. The
26
+ mask should have ``-inf`` or very large negative numbers at the positions
27
+ that should *not* be attended to.
28
+
29
+ Args:
30
+ dims (int): The model dimensions. This is also the default
31
+ value for the queries, keys, values, and the output.
32
+ num_heads (int): The number of attention heads to use.
33
+ query_input_dims (int, optional): The input dimensions of the queries.
34
+ Default: ``dims``.
35
+ key_input_dims (int, optional): The input dimensions of the keys.
36
+ Default: ``dims``.
37
+ value_input_dims (int, optional): The input dimensions of the values.
38
+ Default: ``key_input_dims``.
39
+ value_dims (int, optional): The dimensions of the values after the
40
+ projection. Default: ``dims``.
41
+ value_output_dims (int, optional): The dimensions the new values will
42
+ be projected to. Default: ``dims``.
43
+ bias (bool, optional): Whether or not to use a bias in the projections.
44
+ Default: ``False``.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dims: int,
50
+ num_heads: int,
51
+ query_input_dims: Optional[int] = None,
52
+ key_input_dims: Optional[int] = None,
53
+ value_input_dims: Optional[int] = None,
54
+ value_dims: Optional[int] = None,
55
+ value_output_dims: Optional[int] = None,
56
+ bias: bool = False,
57
+ ):
58
+ super().__init__()
59
+
60
+ if (dims % num_heads) != 0:
61
+ raise ValueError(
62
+ "The input feature dimensions should be divisible by the "
63
+ f"number of heads ({dims} % {num_heads}) != 0"
64
+ )
65
+
66
+ query_input_dims = query_input_dims or dims
67
+ key_input_dims = key_input_dims or dims
68
+ value_input_dims = value_input_dims or key_input_dims
69
+ value_dims = value_dims or dims
70
+ value_output_dims = value_output_dims or dims
71
+
72
+ self.num_heads = num_heads
73
+ self.query_proj = Linear(query_input_dims, dims, bias=bias)
74
+ self.key_proj = Linear(key_input_dims, dims, bias=bias)
75
+ self.value_proj = Linear(value_input_dims, value_dims, bias=bias)
76
+ self.out_proj = Linear(value_dims, value_output_dims, bias=bias)
77
+
78
+ def __call__(self, queries, keys, values, mask=None):
79
+ queries = self.query_proj(queries)
80
+ keys = self.key_proj(keys)
81
+ values = self.value_proj(values)
82
+
83
+ num_heads = self.num_heads
84
+ B, L, D = queries.shape
85
+ _, S, _ = keys.shape
86
+ queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
87
+ keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
88
+ values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
89
+
90
+ # Dimensions are [batch x num heads x sequence x hidden dim]
91
+ scale = math.sqrt(1 / queries.shape[-1])
92
+ scores = (queries * scale) @ keys
93
+ if mask is not None:
94
+ scores = scores + mask.astype(scores.dtype)
95
+ scores = mx.softmax(scores, axis=-1)
96
+ values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
97
+
98
+ return self.out_proj(values_hat)
99
+
100
+ @staticmethod
101
+ def create_additive_causal_mask(N: int, dtype: mx.Dtype = mx.float32):
102
+ indices = mx.arange(N)
103
+ mask = indices[:, None] < indices[None]
104
+ # usually inf but 1e9 is as good and softmax(full(1e9)) != nan
105
+ # TODO: Should replace this with finfo(dtype).min
106
+ mask = mask.astype(dtype) * -1e9
107
+ return mask
108
+
109
+
110
+ class TransformerEncoderLayer(Module):
111
+ def __init__(
112
+ self,
113
+ dims: int,
114
+ num_heads: int,
115
+ mlp_dims: Optional[int] = None,
116
+ dropout: float = 0.0,
117
+ activation: Callable[[Any], Any] = relu,
118
+ norm_first: bool = False,
119
+ ):
120
+ super().__init__()
121
+ mlp_dims = mlp_dims or dims * 4
122
+ self.attention = MultiHeadAttention(dims, num_heads)
123
+ self.ln1 = LayerNorm(dims)
124
+ self.ln2 = LayerNorm(dims)
125
+ self.linear1 = Linear(dims, mlp_dims)
126
+ self.linear2 = Linear(mlp_dims, dims)
127
+ self.dropout1 = Dropout(dropout)
128
+ self.dropout2 = Dropout(dropout)
129
+ self.activation = activation
130
+ self.norm_first = norm_first
131
+
132
+ def __call__(self, x, mask):
133
+ if self.norm_first:
134
+ y = self.ln1(x)
135
+ y = self.attention(y, y, y, mask)
136
+ y = self.dropout1(y)
137
+ x = x + y
138
+
139
+ y = self.ln2(x)
140
+ y = self.linear1(y)
141
+ y = self.activation(y)
142
+ y = self.dropout2(y)
143
+ y = self.linear2(y)
144
+ y = x + y
145
+
146
+ else:
147
+ y = self.attention(x, x, x, mask)
148
+ y = self.dropout1(y)
149
+ y = self.ln1(x + y)
150
+
151
+ y = self.linear1(y)
152
+ y = self.activation(y)
153
+ y = self.dropout2(y)
154
+ y = self.linear2(y)
155
+ y = self.ln2(x + y)
156
+
157
+ return y
158
+
159
+
160
+ class TransformerEncoder(Module):
161
+ def __init__(
162
+ self,
163
+ num_layers: int,
164
+ dims: int,
165
+ num_heads: int,
166
+ mlp_dims: Optional[int] = None,
167
+ dropout: float = 0.0,
168
+ activation=relu,
169
+ norm_first: bool = False,
170
+ ):
171
+ super().__init__()
172
+ self.layers = [
173
+ TransformerEncoderLayer(
174
+ dims, num_heads, mlp_dims, dropout, activation, norm_first
175
+ )
176
+ for i in range(num_layers)
177
+ ]
178
+ self.ln = LayerNorm(dims)
179
+
180
+ def __call__(self, x, mask):
181
+ for l in self.layers:
182
+ x = l(x, mask)
183
+ return self.ln(x)
184
+
185
+
186
+ class TransformerDecoderLayer(Module):
187
+ def __init__(
188
+ self,
189
+ dims: int,
190
+ num_heads: int,
191
+ mlp_dims: Optional[int] = None,
192
+ dropout: float = 0.0,
193
+ activation: Callable[[Any], Any] = relu,
194
+ norm_first: bool = False,
195
+ ):
196
+ super().__init__()
197
+ mlp_dims = mlp_dims or dims * 4
198
+ self.self_attention = MultiHeadAttention(dims, num_heads)
199
+ self.cross_attention = MultiHeadAttention(dims, num_heads)
200
+ self.ln1 = LayerNorm(dims)
201
+ self.ln2 = LayerNorm(dims)
202
+ self.ln3 = LayerNorm(dims)
203
+ self.linear1 = Linear(dims, mlp_dims)
204
+ self.linear2 = Linear(mlp_dims, dims)
205
+ self.dropout1 = Dropout(dropout)
206
+ self.dropout2 = Dropout(dropout)
207
+ self.dropout3 = Dropout(dropout)
208
+ self.activation = activation
209
+ self.norm_first = norm_first
210
+
211
+ def __call__(self, x, memory, x_mask, memory_mask):
212
+ if self.norm_first:
213
+ y = self.ln1(x)
214
+ y = self.self_attention(y, y, y, x_mask)
215
+ y = self.dropout1(y)
216
+ x = x + y
217
+
218
+ y = self.ln2(x)
219
+ y = self.cross_attention(y, memory, memory, memory_mask)
220
+ y = self.dropout2(y)
221
+ x = x + y
222
+
223
+ y = self.ln3(x)
224
+ y = self.linear1(y)
225
+ y = self.activation(y)
226
+ y = self.dropout3(y)
227
+ y = self.linear2(y)
228
+ y = x + y
229
+
230
+ else:
231
+ y = self.self_attention(x, x, x, x_mask)
232
+ y = self.dropout1(y)
233
+ x = self.ln1(x + y)
234
+
235
+ y = self.cross_attention(y, memory, memory, memory_mask)
236
+ y = self.dropout2(y)
237
+ x = self.ln1(x + y)
238
+
239
+ y = self.linear1(x)
240
+ y = self.activation(y)
241
+ y = self.dropout3(y)
242
+ y = self.linear2(y)
243
+ y = self.ln3(x + y)
244
+
245
+ return y
246
+
247
+
248
+ class TransformerDecoder(Module):
249
+ def __init__(
250
+ self,
251
+ num_layers: int,
252
+ dims: int,
253
+ num_heads: int,
254
+ mlp_dims: Optional[int] = None,
255
+ dropout: float = 0.0,
256
+ activation=relu,
257
+ norm_first: bool = False,
258
+ ):
259
+ super().__init__()
260
+ self.layers = [
261
+ TransformerDecoderLayer(
262
+ dims, num_heads, mlp_dims, dropout, activation, norm_first
263
+ )
264
+ for i in range(num_layers)
265
+ ]
266
+ self.ln = LayerNorm(dims)
267
+
268
+ def __call__(self, x, memory, x_mask, memory_mask):
269
+ for l in self.layers:
270
+ x = l(x, memory, x_mask, memory_mask)
271
+ return self.ln(x)
272
+
273
+
274
+ class Transformer(Module):
275
+ """
276
+ Implements a standard Transformer model.
277
+
278
+ The implementation is based on `Attention Is All You Need
279
+ <https://arxiv.org/abs/1706.03762>`_.
280
+
281
+ The Transformer model contains an encoder and a decoder. The encoder
282
+ processes the input sequence and the decoder generates the output sequence.
283
+ The interaction between encoder and decoder happens through the attention
284
+ mechanism.
285
+
286
+ Args:
287
+ dims (int, optional): The number of expected features in the
288
+ encoder/decoder inputs. Default: ``512``.
289
+ num_heads (int, optional): The number of attention heads. Default:
290
+ ``8``.
291
+ num_encoder_layers (int, optional): The number of encoder layers in the
292
+ Transformer encoder. Default: ``6``.
293
+ num_decoder_layers (int, optional): The number of decoder layers in the
294
+ Transformer decoder. Default: ``6``.
295
+ mlp_dims (int, optional): The hidden dimension of the MLP block in each
296
+ Transformer layer. Defaults to ``4*dims`` if not provided. Default:
297
+ ``None``.
298
+ dropout (float, optional): The dropout value for the Transformer
299
+ encoder and decoder. Dropout is used after each attention layer and
300
+ the activation in the MLP layer. Default: ``0.0``.
301
+ activation (function, optional): the activation function for the MLP
302
+ hidden layer. Default: :func:`mlx.nn.relu`.
303
+ custom_encoder (nn.Module, optional): A custom encoder to replace the
304
+ standard Transformer encoder. Default: ``None``.
305
+ custom_decoder (nn.Module, optional): A custom decoder to replace the
306
+ standard Transformer decoder. Default: ``None``.
307
+ norm_first (bool, optional): if ``True``, encoder and decoder layers
308
+ will perform layer normalization before attention and MLP
309
+ operations, otherwise after. Default: ``False``.
310
+ """
311
+
312
+ def __init__(
313
+ self,
314
+ dims: int = 512,
315
+ num_heads: int = 8,
316
+ num_encoder_layers: int = 6,
317
+ num_decoder_layers: int = 6,
318
+ mlp_dims: Optional[int] = None,
319
+ dropout: float = 0.0,
320
+ activation: Callable[[Any], Any] = relu,
321
+ custom_encoder: Optional[Any] = None,
322
+ custom_decoder: Optional[Any] = None,
323
+ norm_first: bool = False,
324
+ ):
325
+ super().__init__()
326
+ if custom_encoder is not None:
327
+ self.encoder = custom_encoder
328
+ else:
329
+ self.encoder = TransformerEncoder(
330
+ num_encoder_layers,
331
+ dims,
332
+ num_heads,
333
+ mlp_dims,
334
+ dropout,
335
+ activation,
336
+ norm_first,
337
+ )
338
+
339
+ if custom_decoder is not None:
340
+ self.decoder = custom_decoder
341
+ else:
342
+ self.decoder = TransformerDecoder(
343
+ num_decoder_layers,
344
+ dims,
345
+ num_heads,
346
+ mlp_dims,
347
+ dropout,
348
+ activation,
349
+ norm_first,
350
+ )
351
+
352
+ def __call__(self, src, tgt, src_mask, tgt_mask, memory_mask):
353
+ memory = self.encoder(src, src_mask)
354
+ return self.decoder(tgt, memory, tgt_mask, memory_mask)
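A full forward pass, reusing the additive causal mask helper from ``MultiHeadAttention`` on the decoder side; a minimal sketch (assuming re-export under ``mlx.nn``):

.. code-block:: python

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Transformer(dims=64, num_heads=4,
                           num_encoder_layers=2, num_decoder_layers=2)
    src = mx.random.normal((1, 12, 64))
    tgt = mx.random.normal((1, 9, 64))
    tgt_mask = nn.MultiHeadAttention.create_additive_causal_mask(9)
    out = model(src, tgt, None, tgt_mask, None)  # shape (1, 9, 64)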
lib/python3.11/site-packages/mlx/nn/losses.py ADDED
@@ -0,0 +1,374 @@
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import math
4
+
5
+ import mlx.core as mx
6
+ from mlx.nn.layers.base import Module
7
+
8
+
9
+ def cross_entropy(
10
+ logits: mx.array,
11
+ targets: mx.array,
12
+ weights: mx.array = None,
13
+ axis: int = -1,
14
+ label_smoothing: float = 0.0,
15
+ reduction: str = "none",
16
+ ) -> mx.array:
17
+ """
18
+ Computes the cross entropy loss.
19
+
20
+ Args:
21
+ logits (array): The unnormalized predicted logits.
22
+ targets (array): The target values, as class indices.
23
+ weights (array, optional): Weights for each target. Default: ``None``.
24
+ axis (int, optional): The axis over which to compute softmax. Default: ``-1``.
25
+ label_smoothing (float, optional): Label smoothing factor. Default: ``0``.
26
+ reduction (str, optional): Specifies the reduction to apply to the output:
27
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
28
+
29
+ Returns:
30
+ array: The computed cross entropy loss.
31
+ """
32
+ if label_smoothing < 0 or label_smoothing >= 1:
33
+ raise ValueError(f"Label smoothing must in [0, 1), got {label_smoothing}.")
34
+
35
+ score = mx.take_along_axis(logits, targets[..., None], axis).squeeze(-1)
36
+ logsumexp_logits = mx.logsumexp(logits, axis=axis)
37
+ if label_smoothing > 0:
38
+ # Adjust the true class score with label smoothing
39
+ adjusted_score = (1 - label_smoothing) * score
40
+
41
+ # Calculate the mean logit across the classes for smoothed loss
42
+ mean_logits = logits.mean(axis=axis)
43
+ smoothed_loss = -mean_logits * label_smoothing
44
+
45
+ # Combine the adjusted score and smoothed loss with the logsumexp logits
46
+ loss = logsumexp_logits - adjusted_score + smoothed_loss
47
+ else:
48
+ loss = logsumexp_logits - score
49
+
50
+ # Apply weights if provided
51
+ if weights is not None:
52
+ if weights.shape != targets.shape:
53
+ raise ValueError(
54
+ f"Weights with shape {weights.shape} is not the same as "
55
+ f"targets with shape {targets.shape}."
56
+ )
57
+ loss *= weights
58
+
59
+ # Apply reduction
60
+ return _reduce(loss, reduction)
61
+
62
+
63
+ def binary_cross_entropy(
64
+ logits: mx.array, targets: mx.array, reduction: str = "none"
65
+ ) -> mx.array:
66
+ """
67
+ Computes the binary cross entropy loss.
68
+
69
+ Args:
70
+ logits (array): The unnormalized (pre-sigmoid) predicted logits.
71
+ targets (array): The binary target values in {0, 1}.
72
+ reduction (str, optional): Specifies the reduction to apply to the output:
73
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
74
+
75
+ Returns:
76
+ array: The computed binary cross entropy loss.
77
+ Examples:
78
+ >>> import mlx.core as mx
79
+ >>> import mlx.nn as nn
80
+ >>> inputs = mx.array([0.105361, 0.223144, 1.20397, 0.916291])
81
+ >>> targets = mx.array([0, 0, 1, 1])
82
+ >>> loss = nn.losses.binary_cross_entropy(inputs, targets, "mean")
83
+ >>> loss
84
+ array([0.612192], dtype=float32)
85
+ """
86
+ loss = mx.logaddexp(0.0, logits) - targets * logits
87
+ return _reduce(loss, reduction)
88
+
89
+
90
+ def l1_loss(
91
+ predictions: mx.array, targets: mx.array, reduction: str = "mean"
92
+ ) -> mx.array:
93
+ """
94
+ Computes the L1 loss.
95
+
96
+ Args:
97
+ predictions (array): The predicted values.
98
+ targets (array): The target values.
99
+ reduction (str, optional): Specifies the reduction to apply to the output:
100
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
101
+
102
+ Returns:
103
+ array: The computed L1 loss.
104
+ """
105
+ if predictions.shape != targets.shape:
106
+ raise ValueError(
107
+ f"Predictions shape {predictions.shape} does not match "
108
+ f"targets shape {targets.shape}."
109
+ )
110
+ loss = mx.abs(predictions - targets)
111
+
112
+ return _reduce(loss, reduction)
113
+
114
+
115
+ def mse_loss(
116
+ predictions: mx.array, targets: mx.array, reduction: str = "mean"
117
+ ) -> mx.array:
118
+ """
119
+ Computes the mean squared error loss.
120
+
121
+ Args:
122
+ predictions (array): The predicted values.
123
+ targets (array): The target values.
124
+ reduction (str, optional): Specifies the reduction to apply to the output:
125
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
126
+
127
+ Returns:
128
+ array: The computed mean squared error loss.
129
+ """
130
+ if predictions.shape != targets.shape:
131
+ raise ValueError(
132
+ f"Predictions shape {predictions.shape} does not match "
133
+ f"targets shape {targets.shape}."
134
+ )
135
+
136
+ loss = mx.square(predictions - targets)
137
+ return _reduce(loss, reduction)
138
+
139
+
140
+ def nll_loss(
141
+ inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
142
+ ) -> mx.array:
143
+ """
144
+ Computes the negative log likelihood loss.
145
+
146
+ Args:
147
+ inputs (array): The predicted distribution in log space.
148
+ targets (array): The target values.
149
+ axis (int, optional): The distribution axis. Default: ``-1``.
150
+ reduction (str, optional): Specifies the reduction to apply to the output:
151
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
152
+
153
+ Returns:
154
+ array: The computed NLL loss.
155
+ """
156
+ loss = -mx.take_along_axis(inputs, targets[..., None], axis).squeeze(-1)
157
+
158
+ return _reduce(loss, reduction)
159
+
160
+
161
+ def kl_div_loss(
162
+ inputs: mx.array, targets: mx.array, axis: int = -1, reduction: str = "none"
163
+ ) -> mx.array:
164
+ """
165
+ Computes the Kullback-Leibler divergence loss.
166
+
167
+ Computes the following when ``reduction == 'none'``:
168
+
169
+ .. code-block:: python
170
+
171
+ (mx.exp(targets) * (targets - inputs)).sum(axis)
172
+
173
+ Args:
174
+ inputs (array): Log probabilities for the predicted distribution.
175
+ targets (array): Log probabilities for the target distribution.
176
+ axis (int, optional): The distribution axis. Default: ``-1``.
177
+ reduction (str, optional): Specifies the reduction to apply to the output:
178
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
179
+
180
+ Returns:
181
+ array: The computed Kullback-Leibler divergence loss.
182
+ """
183
+ loss = mx.sum(mx.exp(targets) * (targets - inputs), axis)
184
+
185
+ return _reduce(loss, reduction)
186
+
187
+
188
+ def smooth_l1_loss(
189
+ predictions: mx.array, targets: mx.array, beta: float = 1.0, reduction: str = "mean"
190
+ ) -> mx.array:
191
+ r"""
192
+ Computes the smooth L1 loss.
193
+
194
+ The smooth L1 loss is a variant of the L1 loss which replaces the absolute
195
+ difference with a squared difference when the absolute difference is less
196
+ than ``beta``.
197
+
198
+ The formula for the smooth L1 Loss is:
199
+
200
+ .. math::
201
+
202
+ l =
203
+ \begin{cases}
204
+ \frac{0.5 (x - y)^2}{\beta}, & \text{if } |x - y| < \beta \\
205
+ |x - y| - 0.5 \beta, & \text{otherwise}
206
+ \end{cases}
207
+
208
+ Args:
209
+ predictions (array): Predicted values.
210
+ targets (array): Ground truth values.
211
+ beta (float, optional): The threshold after which the loss changes
212
+ from the squared to the absolute difference. Default: ``1.0``.
213
+ reduction (str, optional): Specifies the reduction to apply to the output:
214
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``.
215
+
216
+ Returns:
217
+ array: The computed smooth L1 loss.
218
+ """
219
+ if predictions.shape != targets.shape:
220
+ raise ValueError(
221
+ f"Predictions shape {predictions.shape} does not match "
222
+ f"targets shape {targets.shape}."
223
+ )
224
+
225
+ diff = predictions - targets
226
+ loss = mx.where(
227
+ diff < beta, 0.5 * mx.square(diff) / beta, mx.abs(diff) - 0.5 * beta
228
+ )
229
+
230
+ return _reduce(loss, reduction)
231
+
232
+
233
+ def triplet_loss(
234
+ anchors: mx.array,
235
+ positives: mx.array,
236
+ negatives: mx.array,
237
+ axis: int = -1,
238
+ p: int = 2,
239
+ margin: float = 1.0,
240
+ eps: float = 1e-6,
241
+ reduction: str = "none",
242
+ ) -> mx.array:
243
+ r"""
244
+ Computes the triplet loss for a set of anchor, positive, and negative samples.
245
+ The margin is represented by :math:`\alpha` in the formula below.
246
+
247
+ .. math::
248
+
249
+ L_{\text{triplet}} = \max\left(\|A - P\|_p - \|A - N\|_p + \alpha, 0\right)
250
+
251
+ Args:
252
+ anchors (array): The anchor samples.
253
+ positives (array): The positive samples.
254
+ negatives (array): The negative samples.
255
+ axis (int, optional): The distribution axis. Default: ``-1``.
256
+ p (int, optional): The norm degree for pairwise distance. Default: ``2``.
257
+ margin (float, optional): Margin for the triplet loss. Defaults to ``1.0``.
258
+ eps (float, optional): Small positive constant to prevent numerical instability. Defaults to ``1e-6``.
259
+ reduction (str, optional): Specifies the reduction to apply to the output:
260
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
261
+
262
+ Returns:
263
+ array: Computed triplet loss. If reduction is ``"none"``, returns an array of the same shape as the input;
264
+ if reduction is ``"mean"`` or ``"sum"``, returns a scalar array.
265
+ """
266
+ loss = mx.maximum(
267
+ mx.sqrt(mx.power(anchors - positives, p).sum(axis) + eps)
268
+ - mx.sqrt(mx.power(anchors - negatives, p).sum(axis) + eps)
269
+ + margin,
270
+ 0,
271
+ )
272
+ return _reduce(loss, reduction)
273
+
274
+
275
+ def _reduce(loss: mx.array, reduction: str = "none"):
276
+ if reduction == "mean":
277
+ return mx.mean(loss)
278
+ elif reduction == "sum":
279
+ return mx.sum(loss)
280
+ elif reduction == "none":
281
+ return loss
282
+ else:
283
+ raise ValueError("Invalid reduction. Must be 'none', 'mean', or 'sum'.")
284
+
285
+
286
+ def hinge_loss(
287
+ inputs: mx.array, targets: mx.array, reduction: str = "none"
288
+ ) -> mx.array:
289
+ r"""
290
+ Computes the hinge loss between inputs and targets.
291
+
292
+ .. math::
293
+
294
+ \text{hinge}(y, y_{\text{pred}}) = \max(0, 1 - y \cdot y_{\text{pred}})
295
+
296
+
297
+ Args:
298
+ inputs (array): The predicted values.
299
+ targets (array): The target values. They should be -1 or 1.
300
+ reduction (str, optional): Specifies the reduction to apply to the output:
301
+ ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
302
+
303
+ Returns:
304
+ array: The computed hinge loss.
305
+ """
306
+ loss = mx.maximum(1 - inputs * targets, 0)
307
+
308
+ return _reduce(loss, reduction)
309
+ 
+ 
+ def huber_loss(
+     inputs: mx.array, targets: mx.array, delta: float = 1.0, reduction: str = "none"
+ ) -> mx.array:
+     r"""
+     Computes the Huber loss between inputs and targets.
+ 
+     .. math::
+ 
+        L_{\delta}(a) =
+        \left\{ \begin{array}{ll}
+            \frac{1}{2} a^2 & \text{for } |a| \leq \delta, \\
+            \delta \left( |a| - \frac{1}{2} \delta \right) & \text{otherwise.}
+        \end{array} \right.
+ 
+     Args:
+         inputs (array): The predicted values.
+         targets (array): The target values.
+         delta (float, optional): The threshold at which to change between L1 and L2 loss.
+             Default: ``1.0``.
+         reduction (str, optional): Specifies the reduction to apply to the output:
+             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+ 
+     Returns:
+         array: The computed Huber loss.
+     """
+     errors = inputs - targets
+     abs_errors = mx.abs(errors)
+     # Split |error| into a quadratic part (capped at delta) and a linear remainder.
+     quadratic = mx.minimum(abs_errors, delta)
+     linear = abs_errors - quadratic
+     loss = 0.5 * quadratic**2 + delta * linear
+ 
+     return _reduce(loss, reduction)
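The two Huber regimes can be checked by hand with ``delta = 1.0``; a minimal sketch assuming the ``mlx.nn.losses`` import path:

    import mlx.core as mx
    import mlx.nn as nn

    preds = mx.array([0.5, 3.0])
    targets = mx.array([0.0, 0.0])

    # |error| <= delta: quadratic, 0.5 * 0.5**2      = 0.125
    # |error| >  delta: linear,    1.0 * (3.0 - 0.5) = 2.5
    nn.losses.huber_loss(preds, targets, delta=1.0)  # [0.125, 2.5]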
+ 
+ 
+ def log_cosh_loss(
+     inputs: mx.array, targets: mx.array, reduction: str = "none"
+ ) -> mx.array:
+     r"""
+     Computes the log cosh loss between inputs and targets.
+ 
+     Logcosh behaves like an L2 loss for small errors, ensuring stable gradients,
+     and like an L1 loss for large errors, reducing sensitivity to outliers. This
+     dual behavior offers a balanced, robust approach for regression tasks.
+ 
+     .. math::
+ 
+        \text{logcosh}(y_{\text{true}}, y_{\text{pred}}) =
+            \frac{1}{n} \sum_{i=1}^{n}
+            \log(\cosh(y_{\text{pred}}^{(i)} - y_{\text{true}}^{(i)}))
+ 
+     Args:
+         inputs (array): The predicted values.
+         targets (array): The target values.
+         reduction (str, optional): Specifies the reduction to apply to the output:
+             ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'none'``.
+ 
+     Returns:
+         array: The computed log cosh loss.
+     """
+     errors = inputs - targets
+     # log(cosh(x)) == logaddexp(x, -x) - log(2); the logaddexp form cannot overflow.
+     loss = mx.logaddexp(errors, -errors) - math.log(2)
+ 
+     return _reduce(loss, reduction)
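The implementation relies on the identity log(cosh(x)) = logaddexp(x, -x) - log(2), which follows from cosh(x) = (e^x + e^-x) / 2. The logaddexp form matters numerically; a sketch of the difference, assuming ``mx.cosh`` and ``mx.log`` are available in ``mlx.core``:

    import math
    import mlx.core as mx

    x = mx.array([0.1, 10.0, 100.0])
    stable = mx.logaddexp(x, -x) - math.log(2)  # well-behaved for large |x|
    naive = mx.log(mx.cosh(x))                  # cosh(100.0) overflows in float32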
lib/python3.11/site-packages/mlx/nn/utils.py ADDED
@@ -0,0 +1,33 @@
+ # Copyright © 2023 Apple Inc.
+ 
+ from typing import Callable
+ 
+ import mlx.core as mx
+ 
+ 
+ def value_and_grad(model: "mlx.nn.Module", fn: Callable):
+     """Transform the passed function ``fn`` to a function that computes both
+     the value of ``fn`` and its gradients with respect to the model's
+     trainable parameters.
+ 
+     Args:
+         model (mlx.nn.Module): The model whose trainable parameters to compute
+             gradients for
+         fn (Callable): The scalar function to compute gradients for
+ 
+     Returns:
+         A callable that returns the value of ``fn`` and the gradients wrt the
+         trainable parameters of ``model``
+     """
+ 
+     def inner_fn(params, *args, **kwargs):
+         model.update(params)
+         return fn(*args, **kwargs)
+ 
+     value_grad_fn = mx.value_and_grad(inner_fn)
+ 
+     def wrapped_value_grad_fn(*args, **kwargs):
+         value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)
+         return value, grad
+ 
+     return wrapped_value_grad_fn
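A minimal sketch of how ``value_and_grad`` is meant to be used (the loss path ``nn.losses.huber_loss`` is assumed from the losses module above):

    import mlx.core as mx
    import mlx.nn as nn

    model = nn.Linear(10, 1)

    def loss_fn(X, y):
        return nn.losses.huber_loss(model(X), y, reduction="mean")

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    X, y = mx.random.normal((32, 10)), mx.random.normal((32, 1))
    loss, grads = loss_and_grad_fn(X, y)  # grads mirrors model.parameters()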
lib/python3.11/site-packages/mlx/optimizers.py ADDED
@@ -0,0 +1,500 @@
+ # Copyright © 2023 Apple Inc.
+ 
+ import math
+ from typing import List
+ 
+ import mlx.core as mx
+ from mlx.utils import tree_map
+ 
+ 
+ class OptimizerState(dict):
+     """The optimizer state implements a recursively defined
+     :class:`collections.defaultdict`, namely a missing key in an optimizer
+     state is an :class:`OptimizerState`.
+ 
+     .. note::
+        :meth:`OptimizerState.get` in contrast to a normal dictionary also sets
+        the key to the ``default`` value if the ``key`` was not present in the
+        dictionary.
+     """
+ 
+     def __getitem__(self, key):
+         if key not in self:
+             self[key] = OptimizerState()
+         return super().__getitem__(key)
+ 
+     def get(self, key, default):
+         """If ``key`` doesn't exist set its value to ``default`` and then return it."""
+         if key not in self:
+             self[key] = default
+         return super().__getitem__(key)
+ 
+ 
+ class Optimizer:
+     """The base class for all optimizers. It allows us to implement an
+     optimizer on a per-parameter basis and apply it to a parameter tree.
+ 
+     Attributes:
+         state (OptimizerState): It holds the optimizer's state dictionary.
+     """
+ 
+     def __init__(self):
+         self.state = OptimizerState()
+ 
+     def update(self, model: "mlx.nn.Module", gradients: dict):
+         """Apply the gradients to the parameters of the model and update the
+         model with the new parameters.
+ 
+         Args:
+             model (mlx.nn.Module): An mlx module to be updated.
+             gradients (dict): A Python tree of gradients, most likely computed
+                 via :func:`mlx.nn.value_and_grad`.
+         """
+         model.update(self.apply_gradients(gradients, model))
+ 
+     def apply_gradients(self, gradients: dict, model: dict):
+         """Apply the gradients to the parameters and return the updated parameters.
+ 
+         Can be used to update a model via
+         ``model.update(opt.apply_gradients(grads, model))`` which is precisely
+         how :meth:`Optimizer.update` is implemented.
+ 
+         Args:
+             gradients (dict): A Python tree of gradients.
+             model (dict): A Python tree of parameters. It can be a superset of
+                 the gradients. In that case the returned python tree
+                 will be of the same structure as the gradients.
+         """
+         return tree_map(self.apply_single, gradients, model, self.state)
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """To be extended by the children classes to implement each optimizer's
+         update."""
+         raise NotImplementedError()
+ 
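Because ``apply_single`` receives one gradient, one parameter, and one per-parameter state at a time, a new optimizer only has to express its scalar update rule. A hypothetical minimal subclass, for illustration only:

    import mlx.core as mx
    from mlx.optimizers import Optimizer, OptimizerState

    class PlainGD(Optimizer):
        """Hypothetical example: vanilla gradient descent."""

        def __init__(self, learning_rate: float):
            super().__init__()
            self.learning_rate = learning_rate

        def apply_single(
            self, gradient: mx.array, parameter: mx.array, state: OptimizerState
        ):
            # Vanilla gradient descent keeps no per-parameter state.
            return parameter - self.learning_rate * gradient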
+ 
+ class SGD(Optimizer):
+     r"""Stochastic gradient descent optimizer.
+ 
+     Updates a parameter :math:`w` with a gradient :math:`g` as follows
+ 
+     .. math::
+ 
+        v_{t+1} &= \mu v_t + (1 - \tau) g_t \\
+        w_{t+1} &= w_t - \lambda v_{t+1}
+ 
+     Args:
+         learning_rate (float): The learning rate :math:`\lambda`.
+         momentum (float, optional): The momentum strength :math:`\mu`. Default: ``0``
+         weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0``
+         dampening (float, optional): Dampening for momentum :math:`\tau`. Default: ``0``
+         nesterov (bool, optional): Enables Nesterov momentum. Default: ``False``
+     """
+ 
+     def __init__(
+         self,
+         learning_rate: float,
+         momentum: float = 0.0,
+         weight_decay: float = 0.0,
+         dampening: float = 0.0,
+         nesterov: bool = False,
+     ):
+         if nesterov and (momentum <= 0 or dampening != 0):
+             raise ValueError(
+                 "Nesterov momentum requires a momentum and zero dampening."
+             )
+         super().__init__()
+ 
+         self.learning_rate = learning_rate
+         self.momentum = momentum
+         self.weight_decay = weight_decay
+         self.dampening = dampening
+         self.nesterov = nesterov
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """Performs the SGD parameter update and stores :math:`v` in the
+         optimizer state."""
+         # Apply the L2 penalty first so that weight decay also takes effect
+         # when momentum is disabled.
+         if self.weight_decay != 0:
+             gradient += self.weight_decay * parameter
+ 
+         if self.momentum <= 0:
+             return parameter - self.learning_rate * gradient
+ 
+         v = state.get("v", mx.zeros_like(gradient))
+         v = self.momentum * v
+         if self.dampening > 0:
+             v += (1 - self.dampening) * gradient
+         else:
+             v += gradient
+ 
+         if self.nesterov:
+             update = gradient + self.momentum * v
+         else:
+             update = v
+         state["v"] = v
+         return parameter - self.learning_rate * update
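The recursive-defaultdict behavior of the state can be seen by driving ``apply_single`` directly (normally ``update`` does this through ``tree_map``); a small numeric sketch:

    import mlx.core as mx
    from mlx.optimizers import SGD

    opt = SGD(learning_rate=0.1, momentum=0.9)
    p, g = mx.array(1.0), mx.array(1.0)

    # opt.state["p"] materializes a fresh OptimizerState on first access.
    p = opt.apply_single(g, p, opt.state["p"])  # v = 1.0, p -> 0.9
    p = opt.apply_single(g, p, opt.state["p"])  # v = 1.9, p -> 0.71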
+ 
+ 
+ class RMSprop(Optimizer):
+     r"""Implementation of the RMSprop optimizer [1].
+ 
+     [1]: Tieleman, T. and Hinton, G. 2012. Lecture 6.5-rmsprop, coursera: Neural networks for machine learning
+ 
+     .. math::
+ 
+        v_{t+1} &= \alpha v_t + (1 - \alpha) g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+ 
+     Args:
+         learning_rate (float): The learning rate :math:`\lambda`.
+         alpha (float, optional): The smoothing constant :math:`\alpha`.
+             Default: ``0.99``
+         eps (float, optional): The term :math:`\epsilon` added to the denominator
+             to improve numerical stability. Default: ``1e-8``
+     """
+ 
+     def __init__(self, learning_rate: float, alpha: float = 0.99, eps: float = 1e-8):
+         super().__init__()
+ 
+         self.learning_rate = learning_rate
+         self.alpha = alpha
+         self.eps = eps
+ 
+         if self.alpha < 0.0:
+             raise ValueError(
+                 f"RMSprop alpha should be >=0, {self.alpha} was provided instead"
+             )
+         if self.eps < 0.0:
+             raise ValueError(
+                 f"RMSprop epsilon should be >=0, {self.eps} was provided instead"
+             )
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """Performs the RMSprop parameter update and stores :math:`v` in the optimizer state."""
+         lr = self.learning_rate
+         alpha = self.alpha
+         eps = self.eps
+ 
+         v = state.get("v", mx.zeros_like(gradient))
+         v = alpha * v + (1 - alpha) * mx.square(gradient)
+         state["v"] = v
+ 
+         return parameter - lr * gradient / (mx.sqrt(v) + eps)
+ 
+ 
+ class Adagrad(Optimizer):
+     r"""Implementation of the Adagrad optimizer [1].
+ 
+     Our Adagrad implementation follows the original paper. In detail,
+ 
+     [1]: Duchi, J., Hazan, E. and Singer, Y., 2011. Adaptive subgradient methods
+     for online learning and stochastic optimization. JMLR 2011.
+ 
+     .. math::
+ 
+        v_{t+1} &= v_t + g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{g_t}{\sqrt{v_{t+1}} + \epsilon}
+ 
+     Args:
+         learning_rate (float): The learning rate :math:`\lambda`.
+         eps (float, optional): The term :math:`\epsilon` added to the
+             denominator to improve numerical stability. Default: ``1e-8``
+     """
+ 
+     def __init__(self, learning_rate: float, eps: float = 1e-8):
+         super().__init__()
+ 
+         self.learning_rate = learning_rate
+         self.eps = eps
+ 
+         if self.eps < 0.0:
+             raise ValueError(
+                 f"Adagrad epsilon should be >=0, {self.eps} was provided instead"
+             )
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """Performs the Adagrad parameter update and stores :math:`v` in the
+         optimizer state."""
+         lr = self.learning_rate
+         eps = self.eps
+ 
+         v = state.get("v", mx.zeros_like(gradient))
+         v = v + mx.square(gradient)
+         state["v"] = v
+ 
+         return parameter - lr * gradient / (mx.sqrt(v) + eps)
+ 
+ 
+ class AdaDelta(Optimizer):
+     r"""Implementation of the AdaDelta optimizer with a learning rate [1].
+ 
+     Our AdaDelta implementation follows the original paper. In detail,
+ 
+     [1]: Zeiler, M.D., 2012. ADADELTA: an adaptive learning rate method. arXiv preprint arXiv:1212.5701.
+ 
+     .. math::
+ 
+        v_{t+1} &= \rho v_t + (1 - \rho) g_t^2 \\
+        \Delta w_{t+1} &= \frac{\sqrt{u_t + \epsilon}}{\sqrt{v_{t+1} + \epsilon}} g_t \\
+        u_{t+1} &= \rho u_t + (1 - \rho) \Delta w_{t+1}^2 \\
+        w_{t+1} &= w_t - \lambda \Delta w_{t+1}
+ 
+     Args:
+         learning_rate (float): The learning rate :math:`\lambda`.
+         rho (float, optional): The coefficient :math:`\rho` used for computing a
+             running average of squared gradients. Default: ``0.9``
+         eps (float, optional): The term :math:`\epsilon` added to the denominator to improve
+             numerical stability. Default: ``1e-6``
+     """
+ 
+     def __init__(self, learning_rate: float, rho: float = 0.9, eps: float = 1e-6):
+         super().__init__()
+ 
+         self.learning_rate = learning_rate
+         self.rho = rho
+         self.eps = eps
+         if self.rho < 0.0:
+             raise ValueError(
+                 f"AdaDelta rho should be >=0, {self.rho} was provided instead"
+             )
+         if self.eps < 0.0:
+             raise ValueError(
+                 f"AdaDelta epsilon should be >=0, {self.eps} was provided instead"
+             )
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """Performs the AdaDelta parameter update and stores :math:`v` and
+         :math:`u` in the optimizer state."""
+         lr = self.learning_rate
+         rho = self.rho
+         eps = self.eps
+ 
+         v = state.get("v", mx.zeros_like(gradient))
+         # "u" accumulates the squared updates; it must be read and written
+         # under the same state key.
+         u = state.get("u", mx.zeros_like(gradient))
+ 
+         v = rho * v + (1 - rho) * mx.square(gradient)
+         d = mx.sqrt(u + eps) / mx.sqrt(v + eps) * gradient
+         u = rho * u + (1 - rho) * mx.square(d)
+ 
+         state["v"] = v
+         state["u"] = u
+ 
+         return parameter - lr * d
+ 
+ 
+ class Adam(Optimizer):
+     r"""Implementation of the Adam optimizer [1].
+ 
+     Our Adam implementation follows the original paper and omits the bias
+     correction in the first and second moment estimates. In detail,
+ 
+     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
+     optimization. ICLR 2015.
+ 
+     .. math::
+ 
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon}
+ 
+     Args:
+         learning_rate (float): The learning rate :math:`\lambda`.
+         betas (Tuple[float, float], optional): The coefficients
+             :math:`(\beta_1, \beta_2)` used for computing running averages of the
+             gradient and its square. Default: ``(0.9, 0.999)``
+         eps (float, optional): The term :math:`\epsilon` added to the
+             denominator to improve numerical stability. Default: ``1e-8``
+     """
+ 
+     def __init__(
+         self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+     ):
+         super().__init__()
+ 
+         self.learning_rate = learning_rate
+         self.betas = betas
+         self.eps = eps
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """Performs the Adam parameter update and stores :math:`v` and
+         :math:`m` in the optimizer state."""
+         lr = self.learning_rate
+         b1, b2 = self.betas
+         eps = self.eps
+ 
+         # The moments are initialized from the first gradient rather than
+         # zeros, which stands in for the paper's bias correction.
+         m = state.get("m", gradient)
+         v = state.get("v", mx.square(gradient))
+         m = b1 * m + (1 - b1) * gradient
+         v = b2 * v + (1 - b2) * mx.square(gradient)
+         state["m"] = m
+         state["v"] = v
+ 
+         return parameter - lr * m / (mx.sqrt(v) + eps)
+ 
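For contrast with the convention above, the paper's bias-corrected update divides the moments by 1 - beta**t. A hypothetical sketch of what that variant would look like on top of this class (not part of this package):

    import mlx.core as mx
    from mlx.optimizers import Adam, OptimizerState

    class AdamWithBiasCorrection(Adam):
        """Hypothetical variant restoring the bias correction from the paper."""

        def apply_single(
            self, gradient: mx.array, parameter: mx.array, state: OptimizerState
        ):
            b1, b2 = self.betas
            m = state.get("m", mx.zeros_like(gradient))
            v = state.get("v", mx.zeros_like(gradient))
            t = state.get("t", 0) + 1
            m = b1 * m + (1 - b1) * gradient
            v = b2 * v + (1 - b2) * mx.square(gradient)
            state["m"], state["v"], state["t"] = m, v, t
            m_hat = m / (1 - b1**t)  # first-moment bias correction
            v_hat = v / (1 - b2**t)  # second-moment bias correction
            return parameter - self.learning_rate * m_hat / (mx.sqrt(v_hat) + self.eps)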
+ 
+ class AdamW(Adam):
+     r"""Implementation of the AdamW optimizer [1].
+ 
+     Following the above convention, in contrast with [1], we do not use bias
+     correction in the first and second moments for AdamW. We update the weights
+     with a weight_decay (:math:`\lambda`) value:
+ 
+     [1]: Loshchilov, I. and Hutter, F., 2019. Decoupled weight decay
+     regularization. ICLR 2019.
+ 
+     .. math::
+ 
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        v_{t+1} &= \beta_2 v_t + (1 - \beta_2) g_t^2 \\
+        w_{t+1} &= w_t - \alpha \left(\frac{m_{t+1}}{\sqrt{v_{t+1}} + \epsilon} + \lambda w_t\right)
+ 
+     Args:
+         learning_rate (float): The learning rate :math:`\alpha`.
+         betas (Tuple[float, float], optional): The coefficients
+             :math:`(\beta_1, \beta_2)` used for computing running averages of the
+             gradient and its square. Default: ``(0.9, 0.999)``
+         eps (float, optional): The term :math:`\epsilon` added to the
+             denominator to improve numerical stability. Default: ``1e-8``
+         weight_decay (float, optional): The weight decay :math:`\lambda`.
+             Default: ``0.01``.
+     """
+ 
+     def __init__(
+         self,
+         learning_rate: float,
+         betas: List[float] = [0.9, 0.999],
+         eps: float = 1e-8,
+         weight_decay: float = 0.01,
+     ):
+         super().__init__(learning_rate=learning_rate, betas=betas, eps=eps)
+         self.weight_decay = weight_decay
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """Performs the AdamW parameter update by shrinking the parameter
+         with the decoupled weight decay before passing it to the Adam update.
+         """
+ 
+         return super().apply_single(
+             gradient, parameter * (1 - self.learning_rate * self.weight_decay), state
+         )
+ 
+ 
+ class Adamax(Adam):
+     r"""Implementation of the Adamax optimizer. It is a variant of Adam based
+     on the infinity norm [1].
+ 
+     Our Adamax implementation follows the original paper and omits the bias
+     correction in the first and second moment estimates. In detail,
+ 
+     [1]: Kingma, D.P. and Ba, J., 2015. Adam: A method for stochastic
+     optimization. ICLR 2015.
+ 
+     .. math::
+ 
+        m_{t+1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        v_{t+1} &= \max(\beta_2 v_t, |g_t|) \\
+        w_{t+1} &= w_t - \lambda \frac{m_{t+1}}{v_{t+1} + \epsilon}
+ 
+     Args:
+         learning_rate (float): The learning rate :math:`\lambda`.
+         betas (Tuple[float, float], optional): The coefficients
+             :math:`(\beta_1, \beta_2)` used for computing running averages of the
+             gradient and its square. Default: ``(0.9, 0.999)``
+         eps (float, optional): The term :math:`\epsilon` added to the
+             denominator to improve numerical stability. Default: ``1e-8``
+     """
+ 
+     def __init__(
+         self, learning_rate: float, betas: List[float] = [0.9, 0.999], eps: float = 1e-8
+     ):
+         super().__init__(learning_rate, betas, eps)
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """Performs the Adamax parameter update and stores :math:`v` and
+         :math:`m` in the optimizer state."""
+         lr = self.learning_rate
+         b1, b2 = self.betas
+         eps = self.eps
+ 
+         m = state.get("m", mx.zeros_like(gradient))
+         v = state.get("v", mx.zeros_like(gradient))
+ 
+         m = b1 * m + (1 - b1) * gradient
+         v = mx.maximum(b2 * v, mx.abs(gradient))
+         state["m"] = m
+         state["v"] = v
+ 
+         return parameter - lr * m / (v + eps)
+ 
+ 
+ class Lion(Optimizer):
+     r"""Implementation of the Lion optimizer [1].
+ 
+     Since updates are computed through the sign operation, they tend to
+     have larger norm than for other optimizers such as SGD and Adam.
+     We recommend a learning rate that is 3-10x smaller than AdamW and a
+     weight decay 3-10x larger than AdamW to maintain the strength
+     (lr * wd). Our Lion implementation follows the original paper. In
+     detail,
+ 
+     [1]: Chen, X. Symbolic Discovery of Optimization Algorithms. arXiv
+     preprint arXiv:2302.06675.
+ 
+     .. math::
+ 
+        c_{t + 1} &= \beta_1 m_t + (1 - \beta_1) g_t \\
+        m_{t + 1} &= \beta_2 m_t + (1 - \beta_2) g_t \\
+        w_{t + 1} &= w_t - \eta (\text{sign}(c_{t + 1}) + \lambda w_t)
+ 
+     Args:
+         learning_rate (float): The learning rate :math:`\eta`.
+         betas (Tuple[float, float], optional): The coefficients
+             :math:`(\beta_1, \beta_2)` used for computing the gradient
+             momentum and update direction. Default: ``(0.9, 0.99)``
+         weight_decay (float, optional): The weight decay :math:`\lambda`. Default: ``0.0``
+     """
+ 
+     def __init__(
+         self,
+         learning_rate: float,
+         betas: List[float] = [0.9, 0.99],
+         weight_decay: float = 0.0,
+     ):
+         super().__init__()
+ 
+         self.learning_rate = learning_rate
+         self.betas = betas
+         self.weight_decay = weight_decay
+ 
+     def apply_single(
+         self, gradient: mx.array, parameter: mx.array, state: OptimizerState
+     ):
+         """Performs the Lion parameter update and stores :math:`m`
+         in the optimizer state."""
+         lr = self.learning_rate
+         b1, b2 = self.betas
+         weight_decay = self.weight_decay
+ 
+         m = state.get("m", gradient)
+         c = b1 * m + (1 - b1) * gradient
+         state["m"] = b2 * m + (1 - b2) * gradient
+         if weight_decay > 0:
+             parameter = (1 - lr * weight_decay) * parameter
+         return parameter - lr * mx.sign(c)
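Putting the pieces together, a minimal end-to-end training step with these optimizers (random data, and the ``nn.losses`` path assumed as above):

    import mlx.core as mx
    import mlx.nn as nn
    from mlx.optimizers import AdamW

    model = nn.Linear(10, 1)
    optimizer = AdamW(learning_rate=1e-3)

    def loss_fn(X, y):
        return nn.losses.huber_loss(model(X), y, reduction="mean")

    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    for step in range(100):
        X, y = mx.random.normal((32, 10)), mx.random.normal((32, 1))
        loss, grads = loss_and_grad_fn(X, y)
        optimizer.update(model, grads)  # writes the new parameters into the model
        mx.eval(model.parameters())     # force the lazy graph to evaluate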
lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfig.cmake ADDED
@@ -0,0 +1,57 @@
+ # Find MLX
+ #
+ # Defines the following variables:
+ #
+ #   MLX_FOUND            : True if MLX is found
+ #   MLX_INCLUDE_DIRS     : Include directory
+ #   MLX_LIBRARIES        : Libraries to link against
+ #   MLX_CXX_FLAGS        : Additional compiler flags
+ #   MLX_BUILD_ACCELERATE : True if MLX was built with accelerate
+ #   MLX_BUILD_METAL      : True if MLX was built with metal
+ 
+ 
+ ####### Expanded from @PACKAGE_INIT@ by configure_package_config_file() #######
+ ####### Any changes to this file will be overwritten by the next CMake run ####
+ ####### The input file was mlx.pc.in ########
+ 
+ get_filename_component(PACKAGE_PREFIX_DIR "${CMAKE_CURRENT_LIST_DIR}/../../../" ABSOLUTE)
+ 
+ macro(set_and_check _var _file)
+   set(${_var} "${_file}")
+   if(NOT EXISTS "${_file}")
+     message(FATAL_ERROR "File or directory ${_file} referenced by variable ${_var} does not exist !")
+   endif()
+ endmacro()
+ 
+ ####################################################################################
+ 
+ include(${PACKAGE_PREFIX_DIR}/share/cmake/MLX/MLXTargets.cmake)
+ include(${PACKAGE_PREFIX_DIR}/share/cmake/MLX/extension.cmake)
+ 
+ set_and_check(MLX_LIBRARY_DIRS ${PACKAGE_PREFIX_DIR}/lib)
+ set_and_check(MLX_INCLUDE_DIRS ${PACKAGE_PREFIX_DIR}/include)
+ set(MLX_LIBRARIES mlx)
+ 
+ find_library(MLX_LIBRARY mlx PATHS ${MLX_LIBRARY_DIRS})
+ 
+ if (ON)
+   set(MLX_BUILD_ACCELERATE ON)
+   set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -DACCELERATE_NEW_LAPACK)
+ endif()
+ 
+ if (ON)
+   set(MLX_BUILD_METAL ON)
+   set(MLX_CXX_FLAGS ${MLX_CXX_FLAGS} -D_METAL_)
+   set_and_check(MLX_INCLUDE_DIRS
+     ${MLX_INCLUDE_DIRS}
+     ${PACKAGE_PREFIX_DIR}/include/metal_cpp
+   )
+ endif()
+ 
+ set_target_properties(mlx PROPERTIES
+   CXX_STANDARD 17
+   INTERFACE_COMPILE_OPTIONS "${MLX_CXX_FLAGS}"
+ )
+ 
+ include(FindPackageHandleStandardArgs)
+ find_package_handle_standard_args(MLX DEFAULT_MSG MLX_LIBRARY MLX_INCLUDE_DIRS)
lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXConfigVersion.cmake ADDED
@@ -0,0 +1,65 @@
+ # This is a basic version file for the Config-mode of find_package().
+ # It is used by write_basic_package_version_file() as input file for configure_file()
+ # to create a version-file which can be installed along a config.cmake file.
+ #
+ # The created file sets PACKAGE_VERSION_EXACT if the current version string and
+ # the requested version string are exactly the same and it sets
+ # PACKAGE_VERSION_COMPATIBLE if the current version is >= requested version,
+ # but only if the requested major version is the same as the current one.
+ # The variable CVF_VERSION must be set before calling configure_file().
+ 
+ 
+ set(PACKAGE_VERSION "0.0.7")
+ 
+ if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION)
+   set(PACKAGE_VERSION_COMPATIBLE FALSE)
+ else()
+ 
+   if("0.0.7" MATCHES "^([0-9]+)\\.")
+     set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}")
+     if(NOT CVF_VERSION_MAJOR VERSION_EQUAL 0)
+       string(REGEX REPLACE "^0+" "" CVF_VERSION_MAJOR "${CVF_VERSION_MAJOR}")
+     endif()
+   else()
+     set(CVF_VERSION_MAJOR "0.0.7")
+   endif()
+ 
+   if(PACKAGE_FIND_VERSION_RANGE)
+     # both endpoints of the range must have the expected major version
+     math(EXPR CVF_VERSION_MAJOR_NEXT "${CVF_VERSION_MAJOR} + 1")
+     if(NOT PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
+         OR ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX_MAJOR STREQUAL CVF_VERSION_MAJOR)
+           OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND NOT PACKAGE_FIND_VERSION_MAX VERSION_LESS_EQUAL CVF_VERSION_MAJOR_NEXT)))
+       set(PACKAGE_VERSION_COMPATIBLE FALSE)
+     elseif(PACKAGE_FIND_VERSION_MIN_MAJOR STREQUAL CVF_VERSION_MAJOR
+         AND ((PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "INCLUDE" AND PACKAGE_VERSION VERSION_LESS_EQUAL PACKAGE_FIND_VERSION_MAX)
+           OR (PACKAGE_FIND_VERSION_RANGE_MAX STREQUAL "EXCLUDE" AND PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION_MAX)))
+       set(PACKAGE_VERSION_COMPATIBLE TRUE)
+     else()
+       set(PACKAGE_VERSION_COMPATIBLE FALSE)
+     endif()
+   else()
+     if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR)
+       set(PACKAGE_VERSION_COMPATIBLE TRUE)
+     else()
+       set(PACKAGE_VERSION_COMPATIBLE FALSE)
+     endif()
+ 
+     if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION)
+       set(PACKAGE_VERSION_EXACT TRUE)
+     endif()
+   endif()
+ endif()
+ 
+ 
+ # if the installed or the using project don't have CMAKE_SIZEOF_VOID_P set, ignore it:
+ if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "8" STREQUAL "")
+   return()
+ endif()
+ 
+ # check that the installed version has the same 32/64bit-ness as the one which is currently searching:
+ if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "8")
+   math(EXPR installedBits "8 * 8")
+   set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)")
+   set(PACKAGE_VERSION_UNSUITABLE TRUE)
+ endif()
lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets-release.cmake ADDED
@@ -0,0 +1,19 @@
+ #----------------------------------------------------------------
+ # Generated CMake target import file for configuration "Release".
+ #----------------------------------------------------------------
+ 
+ # Commands may need to know the format version.
+ set(CMAKE_IMPORT_FILE_VERSION 1)
+ 
+ # Import target "mlx" for configuration "Release"
+ set_property(TARGET mlx APPEND PROPERTY IMPORTED_CONFIGURATIONS RELEASE)
+ set_target_properties(mlx PROPERTIES
+   IMPORTED_LOCATION_RELEASE "${_IMPORT_PREFIX}/lib/libmlx.dylib"
+   IMPORTED_SONAME_RELEASE "@rpath/libmlx.dylib"
+ )
+ 
+ list(APPEND _cmake_import_check_targets mlx)
+ list(APPEND _cmake_import_check_files_for_mlx "${_IMPORT_PREFIX}/lib/libmlx.dylib")
+ 
+ # Commands beyond this point should not need to know the version.
+ set(CMAKE_IMPORT_FILE_VERSION)
lib/python3.11/site-packages/mlx/share/cmake/MLX/MLXTargets.cmake ADDED
@@ -0,0 +1,107 @@
+ # Generated by CMake
+ 
+ if("${CMAKE_MAJOR_VERSION}.${CMAKE_MINOR_VERSION}" LESS 2.8)
+   message(FATAL_ERROR "CMake >= 2.8.0 required")
+ endif()
+ if(CMAKE_VERSION VERSION_LESS "2.8.3")
+   message(FATAL_ERROR "CMake >= 2.8.3 required")
+ endif()
+ cmake_policy(PUSH)
+ cmake_policy(VERSION 2.8.3...3.24)
+ #----------------------------------------------------------------
+ # Generated CMake target import file.
+ #----------------------------------------------------------------
+ 
+ # Commands may need to know the format version.
+ set(CMAKE_IMPORT_FILE_VERSION 1)
+ 
+ # Protect against multiple inclusion, which would fail when already imported targets are added once more.
+ set(_cmake_targets_defined "")
+ set(_cmake_targets_not_defined "")
+ set(_cmake_expected_targets "")
+ foreach(_cmake_expected_target IN ITEMS mlx)
+   list(APPEND _cmake_expected_targets "${_cmake_expected_target}")
+   if(TARGET "${_cmake_expected_target}")
+     list(APPEND _cmake_targets_defined "${_cmake_expected_target}")
+   else()
+     list(APPEND _cmake_targets_not_defined "${_cmake_expected_target}")
+   endif()
+ endforeach()
+ unset(_cmake_expected_target)
+ if(_cmake_targets_defined STREQUAL _cmake_expected_targets)
+   unset(_cmake_targets_defined)
+   unset(_cmake_targets_not_defined)
+   unset(_cmake_expected_targets)
+   unset(CMAKE_IMPORT_FILE_VERSION)
+   cmake_policy(POP)
+   return()
+ endif()
+ if(NOT _cmake_targets_defined STREQUAL "")
+   string(REPLACE ";" ", " _cmake_targets_defined_text "${_cmake_targets_defined}")
+   string(REPLACE ";" ", " _cmake_targets_not_defined_text "${_cmake_targets_not_defined}")
+   message(FATAL_ERROR "Some (but not all) targets in this export set were already defined.\nTargets Defined: ${_cmake_targets_defined_text}\nTargets not yet defined: ${_cmake_targets_not_defined_text}\n")
+ endif()
+ unset(_cmake_targets_defined)
+ unset(_cmake_targets_not_defined)
+ unset(_cmake_expected_targets)
+ 
+ 
+ # Compute the installation prefix relative to this file.
+ get_filename_component(_IMPORT_PREFIX "${CMAKE_CURRENT_LIST_FILE}" PATH)
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+ get_filename_component(_IMPORT_PREFIX "${_IMPORT_PREFIX}" PATH)
+ if(_IMPORT_PREFIX STREQUAL "/")
+   set(_IMPORT_PREFIX "")
+ endif()
+ 
+ # Create imported target mlx
+ add_library(mlx SHARED IMPORTED)
+ 
+ set_target_properties(mlx PROPERTIES
+   INTERFACE_INCLUDE_DIRECTORIES "${_IMPORT_PREFIX}/include/metal_cpp;${_IMPORT_PREFIX}/include/json;${_IMPORT_PREFIX}/include;${_IMPORT_PREFIX}/include"
+   INTERFACE_LINK_LIBRARIES "/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Metal.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Foundation.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/QuartzCore.framework;/Applications/Xcode-beta.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX14.0.sdk/System/Library/Frameworks/Accelerate.framework"
+ )
+ 
+ if(CMAKE_VERSION VERSION_LESS 2.8.12)
+   message(FATAL_ERROR "This file relies on consumers using CMake 2.8.12 or greater.")
+ endif()
+ 
+ # Load information for each installed configuration.
+ file(GLOB _cmake_config_files "${CMAKE_CURRENT_LIST_DIR}/MLXTargets-*.cmake")
+ foreach(_cmake_config_file IN LISTS _cmake_config_files)
+   include("${_cmake_config_file}")
+ endforeach()
+ unset(_cmake_config_file)
+ unset(_cmake_config_files)
+ 
+ # Cleanup temporary variables.
+ set(_IMPORT_PREFIX)
+ 
+ # Loop over all imported files and verify that they actually exist
+ foreach(_cmake_target IN LISTS _cmake_import_check_targets)
+   foreach(_cmake_file IN LISTS "_cmake_import_check_files_for_${_cmake_target}")
+     if(NOT EXISTS "${_cmake_file}")
+       message(FATAL_ERROR "The imported target \"${_cmake_target}\" references the file
+ \"${_cmake_file}\"
+ but this file does not exist. Possible reasons include:
+ * The file was deleted, renamed, or moved to another location.
+ * An install or uninstall procedure did not complete successfully.
+ * The installation package was faulty and contained
+ \"${CMAKE_CURRENT_LIST_FILE}\"
+ but not all the files it references.
+ ")
+     endif()
+   endforeach()
+   unset(_cmake_file)
+   unset("_cmake_import_check_files_for_${_cmake_target}")
+ endforeach()
+ unset(_cmake_target)
+ unset(_cmake_import_check_targets)
+ 
+ # This file does not depend on other imported targets which have
+ # been exported from the same project but in a separate export set.
+ 
+ # Commands beyond this point should not need to know the version.
+ set(CMAKE_IMPORT_FILE_VERSION)
+ cmake_policy(POP)
lib/python3.11/site-packages/mlx/share/cmake/MLX/extension.cmake ADDED
@@ -0,0 +1,56 @@
+ include(CMakeParseArguments)
+ 
+ ###############################################################################
+ # Build metal library
+ #
+ # Adds a custom target ${TARGET} to build ${OUTPUT_DIRECTORY}/{TITLE}.metallib
+ # from list ${SOURCES}, including list ${INCLUDE_DIRS}, depends on list ${DEPS}
+ #
+ # Args:
+ #   TARGET: Custom target to be added for the metal library
+ #   TITLE: Name of the .metallib
+ #   OUTPUT_DIRECTORY: Where to place ${TITLE}.metallib
+ #   SOURCES: List of source files
+ #   INCLUDE_DIRS: List of include dirs
+ #   DEPS: List of dependency files (like headers)
+ #
+ macro(mlx_build_metallib)
+   # Parse args
+   set(oneValueArgs TARGET TITLE OUTPUT_DIRECTORY)
+   set(multiValueArgs SOURCES INCLUDE_DIRS DEPS)
+   cmake_parse_arguments(
+     MTLLIB
+     ""
+     "${oneValueArgs}"
+     "${multiValueArgs}"
+     ${ARGN}
+   )
+ 
+   # Set output
+   set(MTLLIB_BUILD_TARGET "${MTLLIB_OUTPUT_DIRECTORY}/${MTLLIB_TITLE}.metallib")
+ 
+   # Collect compile options
+   set(MTLLIB_COMPILE_OPTIONS -Wall -Wextra -fno-fast-math)
+ 
+   # Prepare metallib build command
+   add_custom_command(
+     OUTPUT ${MTLLIB_BUILD_TARGET}
+     COMMAND xcrun -sdk macosx metal
+             "$<LIST:TRANSFORM,${MTLLIB_INCLUDE_DIRS},PREPEND,-I>"
+             ${MTLLIB_COMPILE_OPTIONS}
+             ${MTLLIB_SOURCES}
+             -o ${MTLLIB_BUILD_TARGET}
+     DEPENDS ${MTLLIB_DEPS} ${MTLLIB_SOURCES}
+     COMMAND_EXPAND_LISTS
+     COMMENT "Building ${MTLLIB_TITLE}.metallib"
+     VERBATIM
+   )
+ 
+   # Add metallib custom target
+   add_custom_target(
+     ${MTLLIB_TARGET}
+     DEPENDS
+     ${MTLLIB_BUILD_TARGET}
+   )
+ 
+ endmacro(mlx_build_metallib)
lib/python3.11/site-packages/mlx/utils.py ADDED
@@ -0,0 +1,145 @@
+ # Copyright © 2023 Apple Inc.
+ 
+ from collections import defaultdict
+ 
+ 
+ def tree_map(fn, tree, *rest, is_leaf=None):
+     """Applies ``fn`` to the leaves of the python tree ``tree`` and
+     returns a new collection with the results.
+ 
+     If ``rest`` is provided, every item is assumed to be a superset of ``tree``
+     and the corresponding leaves are provided as extra positional arguments to
+     ``fn``. In that respect, :meth:`tree_map` is closer to :func:`itertools.starmap`
+     than to :func:`map`.
+ 
+     The keyword argument ``is_leaf`` decides what constitutes a leaf from
+     ``tree`` similar to :func:`tree_flatten`.
+ 
+     .. code-block:: python
+ 
+        import mlx.nn as nn
+        from mlx.utils import tree_map
+ 
+        model = nn.Linear(10, 10)
+        print(model.parameters().keys())
+        # dict_keys(['weight', 'bias'])
+ 
+        # square the parameters
+        model.update(tree_map(lambda x: x*x, model.parameters()))
+ 
+     Args:
+         fn (Callable): The function that processes the leaves of the tree
+         tree (Any): The main python tree that will be iterated upon
+         rest (Tuple[Any]): Extra trees to be iterated together with tree
+         is_leaf (Optional[Callable]): An optional callable that returns True if
+             the passed object is considered a leaf or False otherwise.
+ 
+     Returns:
+         A python tree with the new values returned by ``fn``.
+     """
+     if is_leaf is not None and is_leaf(tree):
+         return fn(tree, *rest)
+     elif isinstance(tree, (list, tuple)):
+         TreeType = type(tree)
+         return TreeType(
+             tree_map(fn, child, *(r[i] for r in rest), is_leaf=is_leaf)
+             for i, child in enumerate(tree)
+         )
+     elif isinstance(tree, dict):
+         return {
+             k: tree_map(fn, child, *(r[k] for r in rest), is_leaf=is_leaf)
+             for k, child in tree.items()
+         }
+     else:
+         return fn(tree, *rest)
+ 
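The ``rest`` trees are what make ``tree_map`` useful for optimizers: gradients, parameters, and state are zipped leaf by leaf, exactly as in ``Optimizer.apply_gradients``. A small sketch:

    import mlx.core as mx
    from mlx.utils import tree_map

    grads = {"w": mx.array([0.1, 0.1]), "b": mx.array(1.0)}
    params = {"w": mx.array([1.0, 2.0]), "b": mx.array(0.5)}

    # fn receives (leaf_of_first_tree, leaf_of_each_rest_tree) per leaf.
    new_params = tree_map(lambda g, p: p - 0.1 * g, grads, params)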
+ 
+ def tree_flatten(tree, prefix="", is_leaf=None):
+     """Flattens a python tree to a list of key, value tuples.
+ 
+     The keys use dot notation to encode trees of arbitrary depth and
+     complexity.
+ 
+     .. code-block:: python
+ 
+        from mlx.utils import tree_flatten
+ 
+        print(tree_flatten([[[0]]]))
+        # [("0.0.0", 0)]
+ 
+        print(tree_flatten([[[0]]], ".hello"))
+        # [("hello.0.0.0", 0)]
+ 
+     .. note::
+        Dictionaries should have keys that are valid python identifiers.
+ 
+     Args:
+         tree (Any): The python tree to be flattened.
+         prefix (str): A prefix to use for the keys. The first character is
+             always discarded.
+         is_leaf (Callable): An optional callable that returns True if the
+             passed object is considered a leaf or False otherwise.
+ 
+     Returns:
+         List[Tuple[str, Any]]: The flat representation of the python tree.
+     """
+     flat_tree = []
+ 
+     if is_leaf is None or not is_leaf(tree):
+         if isinstance(tree, (list, tuple)):
+             for i, t in enumerate(tree):
+                 flat_tree.extend(tree_flatten(t, f"{prefix}.{i}", is_leaf))
+             return flat_tree
+         if isinstance(tree, dict):
+             for k, t in tree.items():
+                 flat_tree.extend(tree_flatten(t, f"{prefix}.{k}", is_leaf))
+             return flat_tree
+ 
+     return [(prefix[1:], tree)]
+ 
+ 
+ def tree_unflatten(tree):
+     """Recreate a python tree from its flat representation.
+ 
+     .. code-block:: python
+ 
+        from mlx.utils import tree_unflatten
+ 
+        d = tree_unflatten([("hello.world", 42)])
+        print(d)
+        # {"hello": {"world": 42}}
+ 
+     Args:
+         tree (List[Tuple[str, Any]]): The flat representation of a python tree.
+             For instance as returned by :meth:`tree_flatten`.
+ 
+     Returns:
+         A python tree.
+     """
+     if len(tree) == 1 and tree[0][0] == "":
+         return tree[0][1]
+ 
+     try:
+         int(tree[0][0].split(".", maxsplit=1)[0])
+         is_list = True
+     except ValueError:
+         is_list = False
+ 
+     # collect children
+     children = defaultdict(list)
+     for key, value in tree:
+         current_idx, *next_idx = key.split(".", maxsplit=1)
+         next_idx = "" if not next_idx else next_idx[0]
+         children[current_idx].append((next_idx, value))
+ 
+     # recursively map them to the original container
+     if is_list:
+         keys = sorted((int(idx), idx) for idx in children.keys())
+         l = []
+         for i, k in keys:
+             # if i <= len(l), no {} will be appended.
+             l.extend([{} for _ in range(i - len(l))])
+             l.append(tree_unflatten(children[k]))
+         return l
+     else:
+         return {k: tree_unflatten(v) for k, v in children.items()}
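``tree_flatten`` and ``tree_unflatten`` are inverses for the containers they support, which is what makes the dotted keys safe to store and reload; a small round-trip check:

    from mlx.utils import tree_flatten, tree_unflatten

    tree = {"layers": [{"w": 1, "b": 2}, {"w": 3, "b": 4}]}
    flat = tree_flatten(tree)
    # [('layers.0.w', 1), ('layers.0.b', 2), ('layers.1.w', 3), ('layers.1.b', 4)]
    assert tree_unflatten(flat) == tree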
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
+ pip
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/LICENSE ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2012 Erik Rose
+ 
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
+ this software and associated documentation files (the "Software"), to deal in
+ the Software without restriction, including without limitation the rights to
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+ of the Software, and to permit persons to whom the Software is furnished to do
+ so, subject to the following conditions:
+ 
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+ 
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/METADATA ADDED
@@ -0,0 +1,253 @@
+ Metadata-Version: 2.1
+ Name: more-itertools
+ Version: 10.1.0
+ Summary: More routines for operating on iterables, beyond itertools
+ Keywords: itertools,iterator,iteration,filter,peek,peekable,chunk,chunked
+ Author-email: Erik Rose <[email protected]>
+ Requires-Python: >=3.8
+ Description-Content-Type: text/x-rst
+ Classifier: Development Status :: 5 - Production/Stable
+ Classifier: Intended Audience :: Developers
+ Classifier: Natural Language :: English
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3 :: Only
+ Classifier: Programming Language :: Python :: Implementation :: CPython
+ Classifier: Programming Language :: Python :: Implementation :: PyPy
+ Classifier: Topic :: Software Development :: Libraries
+ Project-URL: Homepage, https://github.com/more-itertools/more-itertools
+ 
+ ==============
+ More Itertools
+ ==============
+ 
+ .. image:: https://readthedocs.org/projects/more-itertools/badge/?version=latest
+    :target: https://more-itertools.readthedocs.io/en/stable/
+ 
+ Python's ``itertools`` library is a gem - you can compose elegant solutions
+ for a variety of problems with the functions it provides. In ``more-itertools``
+ we collect additional building blocks, recipes, and routines for working with
+ Python iterables.
+ 
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Grouping | `chunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked>`_, |
+ | | `ichunked <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ichunked>`_, |
+ | | `chunked_even <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.chunked_even>`_, |
+ | | `sliced <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliced>`_, |
+ | | `constrained_batches <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.constrained_batches>`_, |
+ | | `distribute <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distribute>`_, |
+ | | `divide <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.divide>`_, |
+ | | `split_at <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_at>`_, |
+ | | `split_before <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_before>`_, |
+ | | `split_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_after>`_, |
+ | | `split_into <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_into>`_, |
+ | | `split_when <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.split_when>`_, |
+ | | `bucket <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.bucket>`_, |
+ | | `unzip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unzip>`_, |
+ | | `batched <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.batched>`_, |
+ | | `grouper <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.grouper>`_, |
+ | | `partition <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partition>`_, |
+ | | `transpose <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.transpose>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Lookahead and lookback | `spy <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.spy>`_, |
+ | | `peekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.peekable>`_, |
+ | | `seekable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.seekable>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Windowing | `windowed <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed>`_, |
+ | | `substrings <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings>`_, |
+ | | `substrings_indexes <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.substrings_indexes>`_, |
+ | | `stagger <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.stagger>`_, |
+ | | `windowed_complete <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.windowed_complete>`_, |
+ | | `pairwise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pairwise>`_, |
+ | | `triplewise <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.triplewise>`_, |
+ | | `sliding_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sliding_window>`_, |
+ | | `subslices <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.subslices>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Augmenting | `count_cycle <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.count_cycle>`_, |
+ | | `intersperse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.intersperse>`_, |
+ | | `padded <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.padded>`_, |
+ | | `repeat_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeat_each>`_, |
+ | | `mark_ends <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.mark_ends>`_, |
+ | | `repeat_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeat_last>`_, |
+ | | `adjacent <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.adjacent>`_, |
+ | | `groupby_transform <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.groupby_transform>`_, |
+ | | `pad_none <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.pad_none>`_, |
+ | | `ncycles <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ncycles>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Combining | `collapse <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.collapse>`_, |
+ | | `sort_together <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sort_together>`_, |
+ | | `interleave <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave>`_, |
+ | | `interleave_longest <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_longest>`_, |
+ | | `interleave_evenly <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.interleave_evenly>`_, |
+ | | `zip_offset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_offset>`_, |
+ | | `zip_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_equal>`_, |
+ | | `zip_broadcast <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.zip_broadcast>`_, |
+ | | `dotproduct <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.dotproduct>`_, |
+ | | `convolve <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.convolve>`_, |
+ | | `flatten <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.flatten>`_, |
+ | | `roundrobin <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.roundrobin>`_, |
+ | | `prepend <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.prepend>`_, |
+ | | `value_chain <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.value_chain>`_, |
+ | | `partial_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partial_product>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Summarizing | `ilen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.ilen>`_, |
+ | | `unique_to_each <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_to_each>`_, |
+ | | `sample <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sample>`_, |
+ | | `consecutive_groups <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consecutive_groups>`_, |
+ | | `run_length <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.run_length>`_, |
+ | | `map_reduce <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_reduce>`_, |
+ | | `exactly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.exactly_n>`_, |
+ | | `is_sorted <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.is_sorted>`_, |
+ | | `all_equal <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_equal>`_, |
+ | | `all_unique <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.all_unique>`_, |
+ | | `minmax <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.minmax>`_, |
+ | | `first_true <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first_true>`_, |
+ | | `quantify <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.quantify>`_, |
+ | | `iequals <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iequals>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Selecting | `islice_extended <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.islice_extended>`_, |
+ | | `first <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.first>`_, |
+ | | `last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.last>`_, |
+ | | `one <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.one>`_, |
+ | | `only <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.only>`_, |
+ | | `strictly_n <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strictly_n>`_, |
+ | | `strip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.strip>`_, |
+ | | `lstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.lstrip>`_, |
+ | | `rstrip <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rstrip>`_, |
+ | | `filter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.filter_except>`_, |
+ | | `map_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_except>`_, |
+ | | `nth_or_last <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_or_last>`_, |
+ | | `unique_in_window <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_in_window>`_, |
+ | | `before_and_after <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.before_and_after>`_, |
+ | | `nth <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth>`_, |
+ | | `take <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.take>`_, |
+ | | `tail <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tail>`_, |
+ | | `unique_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_everseen>`_, |
+ | | `unique_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.unique_justseen>`_, |
+ | | `duplicates_everseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_everseen>`_, |
+ | | `duplicates_justseen <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.duplicates_justseen>`_, |
+ | | `longest_common_prefix <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.longest_common_prefix>`_, |
+ | | `takewhile_inclusive <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.takewhile_inclusive>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Combinatorics | `distinct_permutations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_permutations>`_, |
+ | | `distinct_combinations <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.distinct_combinations>`_, |
+ | | `circular_shifts <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.circular_shifts>`_, |
+ | | `partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.partitions>`_, |
+ | | `set_partitions <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.set_partitions>`_, |
+ | | `product_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.product_index>`_, |
+ | | `combination_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.combination_index>`_, |
+ | | `permutation_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.permutation_index>`_, |
+ | | `combination_with_replacement_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.combination_with_replacement_index>`_, |
+ | | `gray_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.gray_product>`_, |
+ | | `outer_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.outer_product>`_, |
+ | | `powerset <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.powerset>`_, |
+ | | `random_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_product>`_, |
+ | | `random_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_permutation>`_, |
+ | | `random_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination>`_, |
+ | | `random_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.random_combination_with_replacement>`_, |
+ | | `nth_product <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_product>`_, |
+ | | `nth_permutation <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_permutation>`_, |
+ | | `nth_combination <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination>`_, |
+ | | `nth_combination_with_replacement <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.nth_combination_with_replacement>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Wrapping | `always_iterable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_iterable>`_, |
+ | | `always_reversible <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.always_reversible>`_, |
+ | | `countable <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.countable>`_, |
+ | | `consumer <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consumer>`_, |
+ | | `with_iter <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.with_iter>`_, |
+ | | `iter_except <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_except>`_ |
+ +------------------------+------------------------------------------------------------------------------------------------------------------------------------+
+ | Others | `locate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.locate>`_, |
+ | | `rlocate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.rlocate>`_, |
+ | | `replace <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.replace>`_, |
+ | | `numeric_range <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.numeric_range>`_, |
+ | | `side_effect <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.side_effect>`_, |
+ | | `iterate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iterate>`_, |
+ | | `difference <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.difference>`_, |
+ | | `make_decorator <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.make_decorator>`_, |
+ | | `SequenceView <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.SequenceView>`_, |
+ | | `time_limited <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.time_limited>`_, |
+ | | `map_if <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.map_if>`_, |
+ | | `iter_index <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.iter_index>`_, |
+ | | `consume <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.consume>`_, |
+ | | `tabulate <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.tabulate>`_, |
+ | | `repeatfunc <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.repeatfunc>`_, |
+ | | `polynomial_from_roots <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_from_roots>`_, |
+ | | `polynomial_eval <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_eval>`_, |
+ | | `polynomial_derivative <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.polynomial_derivative>`_, |
182
+ | | `sieve <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sieve>`_, |
183
+ | | `factor <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.factor>`_, |
184
+ | | `matmul <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.matmul>`_, |
185
+ | | `sum_of_squares <https://more-itertools.readthedocs.io/en/stable/api.html#more_itertools.sum_of_squares>`_ |
186
+ +------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
187
+
188
+
189
+ Getting started
190
+ ===============
191
+
192
+ To get started, install the library with `pip <https://pip.pypa.io/en/stable/>`_:
193
+
194
+ .. code-block:: shell
195
+
196
+ pip install more-itertools
197
+
198
+ The recipes from the `itertools docs <https://docs.python.org/3/library/itertools.html#itertools-recipes>`_
199
+ are included in the top-level package:
200
+
201
+ .. code-block:: python
202
+
203
+ >>> from more_itertools import flatten
204
+ >>> iterable = [(0, 1), (2, 3)]
205
+ >>> list(flatten(iterable))
206
+ [0, 1, 2, 3]
207
+
208
+ Several new recipes are available as well:
209
+
210
+ .. code-block:: python
211
+
212
+ >>> from more_itertools import chunked
213
+ >>> iterable = [0, 1, 2, 3, 4, 5, 6, 7, 8]
214
+ >>> list(chunked(iterable, 3))
215
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
216
+
217
+ >>> from more_itertools import spy
218
+ >>> iterable = (x * x for x in range(1, 6))
219
+ >>> head, iterable = spy(iterable, n=3)
220
+ >>> list(head)
221
+ [1, 4, 9]
222
+ >>> list(iterable)
223
+ [1, 4, 9, 16, 25]
224
+
225
+
226
+
227
+ For the full listing of functions, see the `API documentation <https://more-itertools.readthedocs.io/en/stable/api.html>`_.
228
+
229
+
230
+ Links elsewhere
231
+ ===============
232
+
233
+ Blog posts about ``more-itertools``:
234
+
235
+ * `Yo, I heard you like decorators <https://www.bbayles.com/index/decorator_factory>`__
236
+ * `Tour of Python Itertools <https://martinheinz.dev/blog/16>`__ (`Alternate <https://dev.to/martinheinz/tour-of-python-itertools-4122>`__)
237
+ * `Real-World Python More Itertools <https://www.gidware.com/real-world-more-itertools/>`_
238
+
239
+
240
+ Development
241
+ ===========
242
+
243
+ ``more-itertools`` is maintained by `@erikrose <https://github.com/erikrose>`_
244
+ and `@bbayles <https://github.com/bbayles>`_, with help from `many others <https://github.com/more-itertools/more-itertools/graphs/contributors>`_.
245
+ If you have a problem or suggestion, please file a bug or pull request in this
246
+ repository. Thanks for contributing!
247
+
248
+
249
+ Version History
250
+ ===============
251
+
252
+ The version history can be found in `documentation <https://more-itertools.readthedocs.io/en/stable/versions.html>`_.
253
+
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/RECORD ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ more_itertools-10.1.0.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
2
+ more_itertools-10.1.0.dist-info/LICENSE,sha256=CfHIyelBrz5YTVlkHqm4fYPAyw_QB-te85Gn4mQ8GkY,1053
3
+ more_itertools-10.1.0.dist-info/METADATA,sha256=s6T6Rg5Sq9hJE6KhX38MrPxWh8X0d9Fsdld4qdbrQVQ,33830
4
+ more_itertools-10.1.0.dist-info/RECORD,,
5
+ more_itertools-10.1.0.dist-info/REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ more_itertools-10.1.0.dist-info/WHEEL,sha256=rSgq_JpHF9fHR1lx53qwg_1-2LypZE_qmcuXbVUq948,81
7
+ more_itertools/__init__.py,sha256=weQyJxnCVH2EuAaoxC8ZEX-ViZmpHm5kLHO05O8LfRM,149
8
+ more_itertools/__init__.pyi,sha256=5B3eTzON1BBuOLob1vCflyEb2lSd6usXQQ-Cv-hXkeA,43
9
+ more_itertools/__pycache__/__init__.cpython-311.pyc,,
10
+ more_itertools/__pycache__/more.cpython-311.pyc,,
11
+ more_itertools/__pycache__/recipes.cpython-311.pyc,,
12
+ more_itertools/more.py,sha256=6XrPO3vvd3miZb6ygQtY5SoJeLk2PAwfEQOlZHwZgXw,140481
13
+ more_itertools/more.pyi,sha256=nbYRtg-0YR0XMnFOGMzRzbgEOtNC_cfkUaOMOmNVWMA,20726
14
+ more_itertools/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ more_itertools/recipes.py,sha256=GNzJluz0ZYBZhF_H3bC7YkNCm9fZzOoBiQvh1WNNDOU,26354
16
+ more_itertools/recipes.pyi,sha256=77IGo2ZqPtp9IkRu5MzQBgatLBl52vO3MRrFvPDcHzM,4237
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/REQUESTED ADDED
File without changes
lib/python3.11/site-packages/more_itertools-10.1.0.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Wheel-Version: 1.0
2
+ Generator: flit 3.8.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
lib/python3.11/site-packages/more_itertools/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """More routines for operating on iterables, beyond itertools"""
2
+
3
+ from .more import * # noqa
4
+ from .recipes import * # noqa
5
+
6
+ __version__ = '10.1.0'
lib/python3.11/site-packages/more_itertools/__init__.pyi ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .more import *
2
+ from .recipes import *
lib/python3.11/site-packages/more_itertools/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (394 Bytes). View file
 
lib/python3.11/site-packages/more_itertools/__pycache__/more.cpython-311.pyc ADDED
Binary file (179 kB). View file
 
lib/python3.11/site-packages/more_itertools/__pycache__/recipes.cpython-311.pyc ADDED
Binary file (37.2 kB). View file
 
lib/python3.11/site-packages/more_itertools/more.py ADDED
The diff for this file is too large to render. See raw diff
 
lib/python3.11/site-packages/more_itertools/more.pyi ADDED
@@ -0,0 +1,684 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stubs for more_itertools.more"""
2
+ from __future__ import annotations
3
+
4
+ from types import TracebackType
5
+ from typing import (
6
+ Any,
7
+ Callable,
8
+ Container,
9
+ ContextManager,
10
+ Generic,
11
+ Hashable,
12
+ Iterable,
13
+ Iterator,
14
+ overload,
15
+ Reversible,
16
+ Sequence,
17
+ Sized,
18
+ Type,
19
+ TypeVar,
20
+ type_check_only,
21
+ )
22
+ from typing_extensions import Protocol
23
+
24
+ # Type and type variable definitions
25
+ _T = TypeVar('_T')
26
+ _T1 = TypeVar('_T1')
27
+ _T2 = TypeVar('_T2')
28
+ _U = TypeVar('_U')
29
+ _V = TypeVar('_V')
30
+ _W = TypeVar('_W')
31
+ _T_co = TypeVar('_T_co', covariant=True)
32
+ _GenFn = TypeVar('_GenFn', bound=Callable[..., Iterator[object]])
33
+ _Raisable = BaseException | Type[BaseException]
34
+
35
+ @type_check_only
36
+ class _SizedIterable(Protocol[_T_co], Sized, Iterable[_T_co]): ...
37
+
38
+ @type_check_only
39
+ class _SizedReversible(Protocol[_T_co], Sized, Reversible[_T_co]): ...
40
+
41
+ @type_check_only
42
+ class _SupportsSlicing(Protocol[_T_co]):
43
+ def __getitem__(self, __k: slice) -> _T_co: ...
44
+
45
+ def chunked(
46
+ iterable: Iterable[_T], n: int | None, strict: bool = ...
47
+ ) -> Iterator[list[_T]]: ...
48
+ @overload
49
+ def first(iterable: Iterable[_T]) -> _T: ...
50
+ @overload
51
+ def first(iterable: Iterable[_T], default: _U) -> _T | _U: ...
52
+ @overload
53
+ def last(iterable: Iterable[_T]) -> _T: ...
54
+ @overload
55
+ def last(iterable: Iterable[_T], default: _U) -> _T | _U: ...
56
+ @overload
57
+ def nth_or_last(iterable: Iterable[_T], n: int) -> _T: ...
58
+ @overload
59
+ def nth_or_last(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
60
+
61
+ class peekable(Generic[_T], Iterator[_T]):
62
+ def __init__(self, iterable: Iterable[_T]) -> None: ...
63
+ def __iter__(self) -> peekable[_T]: ...
64
+ def __bool__(self) -> bool: ...
65
+ @overload
66
+ def peek(self) -> _T: ...
67
+ @overload
68
+ def peek(self, default: _U) -> _T | _U: ...
69
+ def prepend(self, *items: _T) -> None: ...
70
+ def __next__(self) -> _T: ...
71
+ @overload
72
+ def __getitem__(self, index: int) -> _T: ...
73
+ @overload
74
+ def __getitem__(self, index: slice) -> list[_T]: ...
75
+
76
+ def consumer(func: _GenFn) -> _GenFn: ...
77
+ def ilen(iterable: Iterable[object]) -> int: ...
78
+ def iterate(func: Callable[[_T], _T], start: _T) -> Iterator[_T]: ...
79
+ def with_iter(
80
+ context_manager: ContextManager[Iterable[_T]],
81
+ ) -> Iterator[_T]: ...
82
+ def one(
83
+ iterable: Iterable[_T],
84
+ too_short: _Raisable | None = ...,
85
+ too_long: _Raisable | None = ...,
86
+ ) -> _T: ...
87
+ def raise_(exception: _Raisable, *args: Any) -> None: ...
88
+ def strictly_n(
89
+ iterable: Iterable[_T],
90
+ n: int,
91
+ too_short: _GenFn | None = ...,
92
+ too_long: _GenFn | None = ...,
93
+ ) -> list[_T]: ...
94
+ def distinct_permutations(
95
+ iterable: Iterable[_T], r: int | None = ...
96
+ ) -> Iterator[tuple[_T, ...]]: ...
97
+ def intersperse(
98
+ e: _U, iterable: Iterable[_T], n: int = ...
99
+ ) -> Iterator[_T | _U]: ...
100
+ def unique_to_each(*iterables: Iterable[_T]) -> list[list[_T]]: ...
101
+ @overload
102
+ def windowed(
103
+ seq: Iterable[_T], n: int, *, step: int = ...
104
+ ) -> Iterator[tuple[_T | None, ...]]: ...
105
+ @overload
106
+ def windowed(
107
+ seq: Iterable[_T], n: int, fillvalue: _U, step: int = ...
108
+ ) -> Iterator[tuple[_T | _U, ...]]: ...
109
+ def substrings(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
110
+ def substrings_indexes(
111
+ seq: Sequence[_T], reverse: bool = ...
112
+ ) -> Iterator[tuple[Sequence[_T], int, int]]: ...
113
+
114
+ class bucket(Generic[_T, _U], Container[_U]):
115
+ def __init__(
116
+ self,
117
+ iterable: Iterable[_T],
118
+ key: Callable[[_T], _U],
119
+ validator: Callable[[object], object] | None = ...,
120
+ ) -> None: ...
121
+ def __contains__(self, value: object) -> bool: ...
122
+ def __iter__(self) -> Iterator[_U]: ...
123
+ def __getitem__(self, value: object) -> Iterator[_T]: ...
124
+
125
+ def spy(
126
+ iterable: Iterable[_T], n: int = ...
127
+ ) -> tuple[list[_T], Iterator[_T]]: ...
128
+ def interleave(*iterables: Iterable[_T]) -> Iterator[_T]: ...
129
+ def interleave_longest(*iterables: Iterable[_T]) -> Iterator[_T]: ...
130
+ def interleave_evenly(
131
+ iterables: list[Iterable[_T]], lengths: list[int] | None = ...
132
+ ) -> Iterator[_T]: ...
133
+ def collapse(
134
+ iterable: Iterable[Any],
135
+ base_type: type | None = ...,
136
+ levels: int | None = ...,
137
+ ) -> Iterator[Any]: ...
138
+ @overload
139
+ def side_effect(
140
+ func: Callable[[_T], object],
141
+ iterable: Iterable[_T],
142
+ chunk_size: None = ...,
143
+ before: Callable[[], object] | None = ...,
144
+ after: Callable[[], object] | None = ...,
145
+ ) -> Iterator[_T]: ...
146
+ @overload
147
+ def side_effect(
148
+ func: Callable[[list[_T]], object],
149
+ iterable: Iterable[_T],
150
+ chunk_size: int,
151
+ before: Callable[[], object] | None = ...,
152
+ after: Callable[[], object] | None = ...,
153
+ ) -> Iterator[_T]: ...
154
+ def sliced(
155
+ seq: _SupportsSlicing[_T], n: int, strict: bool = ...
156
+ ) -> Iterator[_T]: ...
157
+ def split_at(
158
+ iterable: Iterable[_T],
159
+ pred: Callable[[_T], object],
160
+ maxsplit: int = ...,
161
+ keep_separator: bool = ...,
162
+ ) -> Iterator[list[_T]]: ...
163
+ def split_before(
164
+ iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
165
+ ) -> Iterator[list[_T]]: ...
166
+ def split_after(
167
+ iterable: Iterable[_T], pred: Callable[[_T], object], maxsplit: int = ...
168
+ ) -> Iterator[list[_T]]: ...
169
+ def split_when(
170
+ iterable: Iterable[_T],
171
+ pred: Callable[[_T, _T], object],
172
+ maxsplit: int = ...,
173
+ ) -> Iterator[list[_T]]: ...
174
+ def split_into(
175
+ iterable: Iterable[_T], sizes: Iterable[int | None]
176
+ ) -> Iterator[list[_T]]: ...
177
+ @overload
178
+ def padded(
179
+ iterable: Iterable[_T],
180
+ *,
181
+ n: int | None = ...,
182
+ next_multiple: bool = ...,
183
+ ) -> Iterator[_T | None]: ...
184
+ @overload
185
+ def padded(
186
+ iterable: Iterable[_T],
187
+ fillvalue: _U,
188
+ n: int | None = ...,
189
+ next_multiple: bool = ...,
190
+ ) -> Iterator[_T | _U]: ...
191
+ @overload
192
+ def repeat_last(iterable: Iterable[_T]) -> Iterator[_T]: ...
193
+ @overload
194
+ def repeat_last(iterable: Iterable[_T], default: _U) -> Iterator[_T | _U]: ...
195
+ def distribute(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
196
+ @overload
197
+ def stagger(
198
+ iterable: Iterable[_T],
199
+ offsets: _SizedIterable[int] = ...,
200
+ longest: bool = ...,
201
+ ) -> Iterator[tuple[_T | None, ...]]: ...
202
+ @overload
203
+ def stagger(
204
+ iterable: Iterable[_T],
205
+ offsets: _SizedIterable[int] = ...,
206
+ longest: bool = ...,
207
+ fillvalue: _U = ...,
208
+ ) -> Iterator[tuple[_T | _U, ...]]: ...
209
+
210
+ class UnequalIterablesError(ValueError):
211
+ def __init__(self, details: tuple[int, int, int] | None = ...) -> None: ...
212
+
213
+ @overload
214
+ def zip_equal(__iter1: Iterable[_T1]) -> Iterator[tuple[_T1]]: ...
215
+ @overload
216
+ def zip_equal(
217
+ __iter1: Iterable[_T1], __iter2: Iterable[_T2]
218
+ ) -> Iterator[tuple[_T1, _T2]]: ...
219
+ @overload
220
+ def zip_equal(
221
+ __iter1: Iterable[_T],
222
+ __iter2: Iterable[_T],
223
+ __iter3: Iterable[_T],
224
+ *iterables: Iterable[_T],
225
+ ) -> Iterator[tuple[_T, ...]]: ...
226
+ @overload
227
+ def zip_offset(
228
+ __iter1: Iterable[_T1],
229
+ *,
230
+ offsets: _SizedIterable[int],
231
+ longest: bool = ...,
232
+ fillvalue: None = None,
233
+ ) -> Iterator[tuple[_T1 | None]]: ...
234
+ @overload
235
+ def zip_offset(
236
+ __iter1: Iterable[_T1],
237
+ __iter2: Iterable[_T2],
238
+ *,
239
+ offsets: _SizedIterable[int],
240
+ longest: bool = ...,
241
+ fillvalue: None = None,
242
+ ) -> Iterator[tuple[_T1 | None, _T2 | None]]: ...
243
+ @overload
244
+ def zip_offset(
245
+ __iter1: Iterable[_T],
246
+ __iter2: Iterable[_T],
247
+ __iter3: Iterable[_T],
248
+ *iterables: Iterable[_T],
249
+ offsets: _SizedIterable[int],
250
+ longest: bool = ...,
251
+ fillvalue: None = None,
252
+ ) -> Iterator[tuple[_T | None, ...]]: ...
253
+ @overload
254
+ def zip_offset(
255
+ __iter1: Iterable[_T1],
256
+ *,
257
+ offsets: _SizedIterable[int],
258
+ longest: bool = ...,
259
+ fillvalue: _U,
260
+ ) -> Iterator[tuple[_T1 | _U]]: ...
261
+ @overload
262
+ def zip_offset(
263
+ __iter1: Iterable[_T1],
264
+ __iter2: Iterable[_T2],
265
+ *,
266
+ offsets: _SizedIterable[int],
267
+ longest: bool = ...,
268
+ fillvalue: _U,
269
+ ) -> Iterator[tuple[_T1 | _U, _T2 | _U]]: ...
270
+ @overload
271
+ def zip_offset(
272
+ __iter1: Iterable[_T],
273
+ __iter2: Iterable[_T],
274
+ __iter3: Iterable[_T],
275
+ *iterables: Iterable[_T],
276
+ offsets: _SizedIterable[int],
277
+ longest: bool = ...,
278
+ fillvalue: _U,
279
+ ) -> Iterator[tuple[_T | _U, ...]]: ...
280
+ def sort_together(
281
+ iterables: Iterable[Iterable[_T]],
282
+ key_list: Iterable[int] = ...,
283
+ key: Callable[..., Any] | None = ...,
284
+ reverse: bool = ...,
285
+ ) -> list[tuple[_T, ...]]: ...
286
+ def unzip(iterable: Iterable[Sequence[_T]]) -> tuple[Iterator[_T], ...]: ...
287
+ def divide(n: int, iterable: Iterable[_T]) -> list[Iterator[_T]]: ...
288
+ def always_iterable(
289
+ obj: object,
290
+ base_type: type | tuple[type | tuple[Any, ...], ...] | None = ...,
291
+ ) -> Iterator[Any]: ...
292
+ def adjacent(
293
+ predicate: Callable[[_T], bool],
294
+ iterable: Iterable[_T],
295
+ distance: int = ...,
296
+ ) -> Iterator[tuple[bool, _T]]: ...
297
+ @overload
298
+ def groupby_transform(
299
+ iterable: Iterable[_T],
300
+ keyfunc: None = None,
301
+ valuefunc: None = None,
302
+ reducefunc: None = None,
303
+ ) -> Iterator[tuple[_T, Iterator[_T]]]: ...
304
+ @overload
305
+ def groupby_transform(
306
+ iterable: Iterable[_T],
307
+ keyfunc: Callable[[_T], _U],
308
+ valuefunc: None,
309
+ reducefunc: None,
310
+ ) -> Iterator[tuple[_U, Iterator[_T]]]: ...
311
+ @overload
312
+ def groupby_transform(
313
+ iterable: Iterable[_T],
314
+ keyfunc: None,
315
+ valuefunc: Callable[[_T], _V],
316
+ reducefunc: None,
317
+ ) -> Iterable[tuple[_T, Iterable[_V]]]: ...
318
+ @overload
319
+ def groupby_transform(
320
+ iterable: Iterable[_T],
321
+ keyfunc: Callable[[_T], _U],
322
+ valuefunc: Callable[[_T], _V],
323
+ reducefunc: None,
324
+ ) -> Iterable[tuple[_U, Iterator[_V]]]: ...
325
+ @overload
326
+ def groupby_transform(
327
+ iterable: Iterable[_T],
328
+ keyfunc: None,
329
+ valuefunc: None,
330
+ reducefunc: Callable[[Iterator[_T]], _W],
331
+ ) -> Iterable[tuple[_T, _W]]: ...
332
+ @overload
333
+ def groupby_transform(
334
+ iterable: Iterable[_T],
335
+ keyfunc: Callable[[_T], _U],
336
+ valuefunc: None,
337
+ reducefunc: Callable[[Iterator[_T]], _W],
338
+ ) -> Iterable[tuple[_U, _W]]: ...
339
+ @overload
340
+ def groupby_transform(
341
+ iterable: Iterable[_T],
342
+ keyfunc: None,
343
+ valuefunc: Callable[[_T], _V],
344
+ reducefunc: Callable[[Iterable[_V]], _W],
345
+ ) -> Iterable[tuple[_T, _W]]: ...
346
+ @overload
347
+ def groupby_transform(
348
+ iterable: Iterable[_T],
349
+ keyfunc: Callable[[_T], _U],
350
+ valuefunc: Callable[[_T], _V],
351
+ reducefunc: Callable[[Iterable[_V]], _W],
352
+ ) -> Iterable[tuple[_U, _W]]: ...
353
+
354
+ class numeric_range(Generic[_T, _U], Sequence[_T], Hashable, Reversible[_T]):
355
+ @overload
356
+ def __init__(self, __stop: _T) -> None: ...
357
+ @overload
358
+ def __init__(self, __start: _T, __stop: _T) -> None: ...
359
+ @overload
360
+ def __init__(self, __start: _T, __stop: _T, __step: _U) -> None: ...
361
+ def __bool__(self) -> bool: ...
362
+ def __contains__(self, elem: object) -> bool: ...
363
+ def __eq__(self, other: object) -> bool: ...
364
+ @overload
365
+ def __getitem__(self, key: int) -> _T: ...
366
+ @overload
367
+ def __getitem__(self, key: slice) -> numeric_range[_T, _U]: ...
368
+ def __hash__(self) -> int: ...
369
+ def __iter__(self) -> Iterator[_T]: ...
370
+ def __len__(self) -> int: ...
371
+ def __reduce__(
372
+ self,
373
+ ) -> tuple[Type[numeric_range[_T, _U]], tuple[_T, _T, _U]]: ...
374
+ def __repr__(self) -> str: ...
375
+ def __reversed__(self) -> Iterator[_T]: ...
376
+ def count(self, value: _T) -> int: ...
377
+ def index(self, value: _T) -> int: ... # type: ignore
378
+
379
+ def count_cycle(
380
+ iterable: Iterable[_T], n: int | None = ...
381
+ ) -> Iterable[tuple[int, _T]]: ...
382
+ def mark_ends(
383
+ iterable: Iterable[_T],
384
+ ) -> Iterable[tuple[bool, bool, _T]]: ...
385
+ def locate(
386
+ iterable: Iterable[object],
387
+ pred: Callable[..., Any] = ...,
388
+ window_size: int | None = ...,
389
+ ) -> Iterator[int]: ...
390
+ def lstrip(
391
+ iterable: Iterable[_T], pred: Callable[[_T], object]
392
+ ) -> Iterator[_T]: ...
393
+ def rstrip(
394
+ iterable: Iterable[_T], pred: Callable[[_T], object]
395
+ ) -> Iterator[_T]: ...
396
+ def strip(
397
+ iterable: Iterable[_T], pred: Callable[[_T], object]
398
+ ) -> Iterator[_T]: ...
399
+
400
+ class islice_extended(Generic[_T], Iterator[_T]):
401
+ def __init__(self, iterable: Iterable[_T], *args: int | None) -> None: ...
402
+ def __iter__(self) -> islice_extended[_T]: ...
403
+ def __next__(self) -> _T: ...
404
+ def __getitem__(self, index: slice) -> islice_extended[_T]: ...
405
+
406
+ def always_reversible(iterable: Iterable[_T]) -> Iterator[_T]: ...
407
+ def consecutive_groups(
408
+ iterable: Iterable[_T], ordering: Callable[[_T], int] = ...
409
+ ) -> Iterator[Iterator[_T]]: ...
410
+ @overload
411
+ def difference(
412
+ iterable: Iterable[_T],
413
+ func: Callable[[_T, _T], _U] = ...,
414
+ *,
415
+ initial: None = ...,
416
+ ) -> Iterator[_T | _U]: ...
417
+ @overload
418
+ def difference(
419
+ iterable: Iterable[_T], func: Callable[[_T, _T], _U] = ..., *, initial: _U
420
+ ) -> Iterator[_U]: ...
421
+
422
+ class SequenceView(Generic[_T], Sequence[_T]):
423
+ def __init__(self, target: Sequence[_T]) -> None: ...
424
+ @overload
425
+ def __getitem__(self, index: int) -> _T: ...
426
+ @overload
427
+ def __getitem__(self, index: slice) -> Sequence[_T]: ...
428
+ def __len__(self) -> int: ...
429
+
430
+ class seekable(Generic[_T], Iterator[_T]):
431
+ def __init__(
432
+ self, iterable: Iterable[_T], maxlen: int | None = ...
433
+ ) -> None: ...
434
+ def __iter__(self) -> seekable[_T]: ...
435
+ def __next__(self) -> _T: ...
436
+ def __bool__(self) -> bool: ...
437
+ @overload
438
+ def peek(self) -> _T: ...
439
+ @overload
440
+ def peek(self, default: _U) -> _T | _U: ...
441
+ def elements(self) -> SequenceView[_T]: ...
442
+ def seek(self, index: int) -> None: ...
443
+ def relative_seek(self, count: int) -> None: ...
444
+
445
+ class run_length:
446
+ @staticmethod
447
+ def encode(iterable: Iterable[_T]) -> Iterator[tuple[_T, int]]: ...
448
+ @staticmethod
449
+ def decode(iterable: Iterable[tuple[_T, int]]) -> Iterator[_T]: ...
450
+
451
+ def exactly_n(
452
+ iterable: Iterable[_T], n: int, predicate: Callable[[_T], object] = ...
453
+ ) -> bool: ...
454
+ def circular_shifts(iterable: Iterable[_T]) -> list[tuple[_T, ...]]: ...
455
+ def make_decorator(
456
+ wrapping_func: Callable[..., _U], result_index: int = ...
457
+ ) -> Callable[..., Callable[[Callable[..., Any]], Callable[..., _U]]]: ...
458
+ @overload
459
+ def map_reduce(
460
+ iterable: Iterable[_T],
461
+ keyfunc: Callable[[_T], _U],
462
+ valuefunc: None = ...,
463
+ reducefunc: None = ...,
464
+ ) -> dict[_U, list[_T]]: ...
465
+ @overload
466
+ def map_reduce(
467
+ iterable: Iterable[_T],
468
+ keyfunc: Callable[[_T], _U],
469
+ valuefunc: Callable[[_T], _V],
470
+ reducefunc: None = ...,
471
+ ) -> dict[_U, list[_V]]: ...
472
+ @overload
473
+ def map_reduce(
474
+ iterable: Iterable[_T],
475
+ keyfunc: Callable[[_T], _U],
476
+ valuefunc: None = ...,
477
+ reducefunc: Callable[[list[_T]], _W] = ...,
478
+ ) -> dict[_U, _W]: ...
479
+ @overload
480
+ def map_reduce(
481
+ iterable: Iterable[_T],
482
+ keyfunc: Callable[[_T], _U],
483
+ valuefunc: Callable[[_T], _V],
484
+ reducefunc: Callable[[list[_V]], _W],
485
+ ) -> dict[_U, _W]: ...
486
+ def rlocate(
487
+ iterable: Iterable[_T],
488
+ pred: Callable[..., object] = ...,
489
+ window_size: int | None = ...,
490
+ ) -> Iterator[int]: ...
491
+ def replace(
492
+ iterable: Iterable[_T],
493
+ pred: Callable[..., object],
494
+ substitutes: Iterable[_U],
495
+ count: int | None = ...,
496
+ window_size: int = ...,
497
+ ) -> Iterator[_T | _U]: ...
498
+ def partitions(iterable: Iterable[_T]) -> Iterator[list[list[_T]]]: ...
499
+ def set_partitions(
500
+ iterable: Iterable[_T], k: int | None = ...
501
+ ) -> Iterator[list[list[_T]]]: ...
502
+
503
+ class time_limited(Generic[_T], Iterator[_T]):
504
+ def __init__(
505
+ self, limit_seconds: float, iterable: Iterable[_T]
506
+ ) -> None: ...
507
+ def __iter__(self) -> islice_extended[_T]: ...
508
+ def __next__(self) -> _T: ...
509
+
510
+ @overload
511
+ def only(
512
+ iterable: Iterable[_T], *, too_long: _Raisable | None = ...
513
+ ) -> _T | None: ...
514
+ @overload
515
+ def only(
516
+ iterable: Iterable[_T], default: _U, too_long: _Raisable | None = ...
517
+ ) -> _T | _U: ...
518
+ def ichunked(iterable: Iterable[_T], n: int) -> Iterator[Iterator[_T]]: ...
519
+ def distinct_combinations(
520
+ iterable: Iterable[_T], r: int
521
+ ) -> Iterator[tuple[_T, ...]]: ...
522
+ def filter_except(
523
+ validator: Callable[[Any], object],
524
+ iterable: Iterable[_T],
525
+ *exceptions: Type[BaseException],
526
+ ) -> Iterator[_T]: ...
527
+ def map_except(
528
+ function: Callable[[Any], _U],
529
+ iterable: Iterable[_T],
530
+ *exceptions: Type[BaseException],
531
+ ) -> Iterator[_U]: ...
532
+ def map_if(
533
+ iterable: Iterable[Any],
534
+ pred: Callable[[Any], bool],
535
+ func: Callable[[Any], Any],
536
+ func_else: Callable[[Any], Any] | None = ...,
537
+ ) -> Iterator[Any]: ...
538
+ def sample(
539
+ iterable: Iterable[_T],
540
+ k: int,
541
+ weights: Iterable[float] | None = ...,
542
+ ) -> list[_T]: ...
543
+ def is_sorted(
544
+ iterable: Iterable[_T],
545
+ key: Callable[[_T], _U] | None = ...,
546
+ reverse: bool = False,
547
+ strict: bool = False,
548
+ ) -> bool: ...
549
+
550
+ class AbortThread(BaseException):
551
+ pass
552
+
553
+ class callback_iter(Generic[_T], Iterator[_T]):
554
+ def __init__(
555
+ self,
556
+ func: Callable[..., Any],
557
+ callback_kwd: str = ...,
558
+ wait_seconds: float = ...,
559
+ ) -> None: ...
560
+ def __enter__(self) -> callback_iter[_T]: ...
561
+ def __exit__(
562
+ self,
563
+ exc_type: Type[BaseException] | None,
564
+ exc_value: BaseException | None,
565
+ traceback: TracebackType | None,
566
+ ) -> bool | None: ...
567
+ def __iter__(self) -> callback_iter[_T]: ...
568
+ def __next__(self) -> _T: ...
569
+ def _reader(self) -> Iterator[_T]: ...
570
+ @property
571
+ def done(self) -> bool: ...
572
+ @property
573
+ def result(self) -> Any: ...
574
+
575
+ def windowed_complete(
576
+ iterable: Iterable[_T], n: int
577
+ ) -> Iterator[tuple[_T, ...]]: ...
578
+ def all_unique(
579
+ iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
580
+ ) -> bool: ...
581
+ def nth_product(index: int, *args: Iterable[_T]) -> tuple[_T, ...]: ...
582
+ def nth_combination_with_replacement(
583
+ iterable: Iterable[_T], r: int, index: int
584
+ ) -> tuple[_T, ...]: ...
585
+ def nth_permutation(
586
+ iterable: Iterable[_T], r: int, index: int
587
+ ) -> tuple[_T, ...]: ...
588
+ def value_chain(*args: _T | Iterable[_T]) -> Iterable[_T]: ...
589
+ def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
590
+ def combination_index(
591
+ element: Iterable[_T], iterable: Iterable[_T]
592
+ ) -> int: ...
593
+ def combination_with_replacement_index(
594
+ element: Iterable[_T], iterable: Iterable[_T]
595
+ ) -> int: ...
596
+ def permutation_index(
597
+ element: Iterable[_T], iterable: Iterable[_T]
598
+ ) -> int: ...
599
+ def repeat_each(iterable: Iterable[_T], n: int = ...) -> Iterator[_T]: ...
600
+
601
+ class countable(Generic[_T], Iterator[_T]):
602
+ def __init__(self, iterable: Iterable[_T]) -> None: ...
603
+ def __iter__(self) -> countable[_T]: ...
604
+ def __next__(self) -> _T: ...
605
+
606
+ def chunked_even(iterable: Iterable[_T], n: int) -> Iterator[list[_T]]: ...
607
+ def zip_broadcast(
608
+ *objects: _T | Iterable[_T],
609
+ scalar_types: type | tuple[type | tuple[Any, ...], ...] | None = ...,
610
+ strict: bool = ...,
611
+ ) -> Iterable[tuple[_T, ...]]: ...
612
+ def unique_in_window(
613
+ iterable: Iterable[_T], n: int, key: Callable[[_T], _U] | None = ...
614
+ ) -> Iterator[_T]: ...
615
+ def duplicates_everseen(
616
+ iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
617
+ ) -> Iterator[_T]: ...
618
+ def duplicates_justseen(
619
+ iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
620
+ ) -> Iterator[_T]: ...
621
+
622
+ class _SupportsLessThan(Protocol):
623
+ def __lt__(self, __other: Any) -> bool: ...
624
+
625
+ _SupportsLessThanT = TypeVar("_SupportsLessThanT", bound=_SupportsLessThan)
626
+
627
+ @overload
628
+ def minmax(
629
+ iterable_or_value: Iterable[_SupportsLessThanT], *, key: None = None
630
+ ) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
631
+ @overload
632
+ def minmax(
633
+ iterable_or_value: Iterable[_T], *, key: Callable[[_T], _SupportsLessThan]
634
+ ) -> tuple[_T, _T]: ...
635
+ @overload
636
+ def minmax(
637
+ iterable_or_value: Iterable[_SupportsLessThanT],
638
+ *,
639
+ key: None = None,
640
+ default: _U,
641
+ ) -> _U | tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
642
+ @overload
643
+ def minmax(
644
+ iterable_or_value: Iterable[_T],
645
+ *,
646
+ key: Callable[[_T], _SupportsLessThan],
647
+ default: _U,
648
+ ) -> _U | tuple[_T, _T]: ...
649
+ @overload
650
+ def minmax(
651
+ iterable_or_value: _SupportsLessThanT,
652
+ __other: _SupportsLessThanT,
653
+ *others: _SupportsLessThanT,
654
+ ) -> tuple[_SupportsLessThanT, _SupportsLessThanT]: ...
655
+ @overload
656
+ def minmax(
657
+ iterable_or_value: _T,
658
+ __other: _T,
659
+ *others: _T,
660
+ key: Callable[[_T], _SupportsLessThan],
661
+ ) -> tuple[_T, _T]: ...
662
+ def longest_common_prefix(
663
+ iterables: Iterable[Iterable[_T]],
664
+ ) -> Iterator[_T]: ...
665
+ def iequals(*iterables: Iterable[object]) -> bool: ...
666
+ def constrained_batches(
667
+ iterable: Iterable[object],
668
+ max_size: int,
669
+ max_count: int | None = ...,
670
+ get_len: Callable[[_T], object] = ...,
671
+ strict: bool = ...,
672
+ ) -> Iterator[tuple[_T]]: ...
673
+ def gray_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
674
+ def partial_product(*iterables: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
675
+ def takewhile_inclusive(
676
+ predicate: Callable[[_T], bool], iterable: Iterable[_T]
677
+ ) -> Iterator[_T]: ...
678
+ def outer_product(
679
+ func: Callable[[_T, _U], _V],
680
+ xs: Iterable[_T],
681
+ ys: Iterable[_U],
682
+ *args: Any,
683
+ **kwargs: Any,
684
+ ) -> Iterator[tuple[_V, ...]]: ...
lib/python3.11/site-packages/more_itertools/py.typed ADDED
File without changes
lib/python3.11/site-packages/more_itertools/recipes.py ADDED
@@ -0,0 +1,977 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Imported from the recipes section of the itertools documentation.
2
+
3
+ All functions taken from the recipes section of the itertools library docs
4
+ [1]_.
5
+ Some backward-compatible usability improvements have been made.
6
+
7
+ .. [1] http://docs.python.org/library/itertools.html#recipes
8
+
9
+ """
10
+ import math
11
+ import operator
12
+
13
+ from collections import deque
14
+ from collections.abc import Sized
15
+ from functools import partial, reduce
16
+ from itertools import (
17
+ chain,
18
+ combinations,
19
+ compress,
20
+ count,
21
+ cycle,
22
+ groupby,
23
+ islice,
24
+ product,
25
+ repeat,
26
+ starmap,
27
+ tee,
28
+ zip_longest,
29
+ )
30
+ from random import randrange, sample, choice
31
+
32
+ __all__ = [
33
+ 'all_equal',
34
+ 'batched',
35
+ 'before_and_after',
36
+ 'consume',
37
+ 'convolve',
38
+ 'dotproduct',
39
+ 'first_true',
40
+ 'factor',
41
+ 'flatten',
42
+ 'grouper',
43
+ 'iter_except',
44
+ 'iter_index',
45
+ 'matmul',
46
+ 'ncycles',
47
+ 'nth',
48
+ 'nth_combination',
49
+ 'padnone',
50
+ 'pad_none',
51
+ 'pairwise',
52
+ 'partition',
53
+ 'polynomial_eval',
54
+ 'polynomial_from_roots',
55
+ 'polynomial_derivative',
56
+ 'powerset',
57
+ 'prepend',
58
+ 'quantify',
59
+ 'random_combination_with_replacement',
60
+ 'random_combination',
61
+ 'random_permutation',
62
+ 'random_product',
63
+ 'repeatfunc',
64
+ 'roundrobin',
65
+ 'sieve',
66
+ 'sliding_window',
67
+ 'subslices',
68
+ 'sum_of_squares',
69
+ 'tabulate',
70
+ 'tail',
71
+ 'take',
72
+ 'transpose',
73
+ 'triplewise',
74
+ 'unique_everseen',
75
+ 'unique_justseen',
76
+ ]
77
+
78
+ _marker = object()
79
+
80
+
81
+ # zip with strict is available for Python 3.10+
82
+ try:
83
+ zip(strict=True)
84
+ except TypeError:
85
+ _zip_strict = zip
86
+ else:
87
+ _zip_strict = partial(zip, strict=True)
88
+
89
+ # math.sumprod is available for Python 3.12+
90
+ _sumprod = getattr(math, 'sumprod', lambda x, y: dotproduct(x, y))
91
+
92
+
93
+ def take(n, iterable):
94
+ """Return first *n* items of the iterable as a list.
95
+
96
+ >>> take(3, range(10))
97
+ [0, 1, 2]
98
+
99
+ If there are fewer than *n* items in the iterable, all of them are
100
+ returned.
101
+
102
+ >>> take(10, range(3))
103
+ [0, 1, 2]
104
+
105
+ """
106
+ return list(islice(iterable, n))
107
+
108
+
109
+ def tabulate(function, start=0):
110
+ """Return an iterator over the results of ``func(start)``,
111
+ ``func(start + 1)``, ``func(start + 2)``...
112
+
113
+ *func* should be a function that accepts one integer argument.
114
+
115
+ If *start* is not specified it defaults to 0. It will be incremented each
116
+ time the iterator is advanced.
117
+
118
+ >>> square = lambda x: x ** 2
119
+ >>> iterator = tabulate(square, -3)
120
+ >>> take(4, iterator)
121
+ [9, 4, 1, 0]
122
+
123
+ """
124
+ return map(function, count(start))
125
+
126
+
127
+ def tail(n, iterable):
128
+ """Return an iterator over the last *n* items of *iterable*.
129
+
130
+ >>> t = tail(3, 'ABCDEFG')
131
+ >>> list(t)
132
+ ['E', 'F', 'G']
133
+
134
+ """
135
+ # If the given iterable has a length, then we can use islice to get its
136
+ # final elements. Note that if the iterable is not actually Iterable,
137
+ # either islice or deque will throw a TypeError. This is why we don't
138
+ # check if it is Iterable.
139
+ if isinstance(iterable, Sized):
140
+ yield from islice(iterable, max(0, len(iterable) - n), None)
141
+ else:
142
+ yield from iter(deque(iterable, maxlen=n))
143
+
144
+
145
+ def consume(iterator, n=None):
146
+ """Advance *iterable* by *n* steps. If *n* is ``None``, consume it
147
+ entirely.
148
+
149
+ Efficiently exhausts an iterator without returning values. Defaults to
150
+ consuming the whole iterator, but an optional second argument may be
151
+ provided to limit consumption.
152
+
153
+ >>> i = (x for x in range(10))
154
+ >>> next(i)
155
+ 0
156
+ >>> consume(i, 3)
157
+ >>> next(i)
158
+ 4
159
+ >>> consume(i)
160
+ >>> next(i)
161
+ Traceback (most recent call last):
162
+ File "<stdin>", line 1, in <module>
163
+ StopIteration
164
+
165
+ If the iterator has fewer items remaining than the provided limit, the
166
+ whole iterator will be consumed.
167
+
168
+ >>> i = (x for x in range(3))
169
+ >>> consume(i, 5)
170
+ >>> next(i)
171
+ Traceback (most recent call last):
172
+ File "<stdin>", line 1, in <module>
173
+ StopIteration
174
+
175
+ """
176
+ # Use functions that consume iterators at C speed.
177
+ if n is None:
178
+ # feed the entire iterator into a zero-length deque
179
+ deque(iterator, maxlen=0)
180
+ else:
181
+ # advance to the empty slice starting at position n
182
+ next(islice(iterator, n, n), None)
183
+
184
+
185
+ def nth(iterable, n, default=None):
186
+ """Returns the nth item or a default value.
187
+
188
+ >>> l = range(10)
189
+ >>> nth(l, 3)
190
+ 3
191
+ >>> nth(l, 20, "zebra")
192
+ 'zebra'
193
+
194
+ """
195
+ return next(islice(iterable, n, None), default)
196
+
197
+
198
+ def all_equal(iterable):
199
+ """
200
+ Returns ``True`` if all the elements are equal to each other.
201
+
202
+ >>> all_equal('aaaa')
203
+ True
204
+ >>> all_equal('aaab')
205
+ False
206
+
207
+ """
208
+ g = groupby(iterable)
209
+ return next(g, True) and not next(g, False)
210
+
211
+
212
+ def quantify(iterable, pred=bool):
213
+ """Return the how many times the predicate is true.
214
+
215
+ >>> quantify([True, False, True])
216
+ 2
217
+
218
+ """
219
+ return sum(map(pred, iterable))
220
+
221
+
222
+ def pad_none(iterable):
223
+ """Returns the sequence of elements and then returns ``None`` indefinitely.
224
+
225
+ >>> take(5, pad_none(range(3)))
226
+ [0, 1, 2, None, None]
227
+
228
+ Useful for emulating the behavior of the built-in :func:`map` function.
229
+
230
+ See also :func:`padded`.
231
+
232
+ """
233
+ return chain(iterable, repeat(None))
234
+
235
+
236
+ padnone = pad_none
237
+
238
+
239
+ def ncycles(iterable, n):
240
+ """Returns the sequence elements *n* times
241
+
242
+ >>> list(ncycles(["a", "b"], 3))
243
+ ['a', 'b', 'a', 'b', 'a', 'b']
244
+
245
+ """
246
+ return chain.from_iterable(repeat(tuple(iterable), n))
247
+
248
+
249
+ def dotproduct(vec1, vec2):
250
+ """Returns the dot product of the two iterables.
251
+
252
+ >>> dotproduct([10, 10], [20, 20])
253
+ 400
254
+
255
+ """
256
+ return sum(map(operator.mul, vec1, vec2))
257
+
258
+
259
+ def flatten(listOfLists):
260
+ """Return an iterator flattening one level of nesting in a list of lists.
261
+
262
+ >>> list(flatten([[0, 1], [2, 3]]))
263
+ [0, 1, 2, 3]
264
+
265
+ See also :func:`collapse`, which can flatten multiple levels of nesting.
266
+
267
+ """
268
+ return chain.from_iterable(listOfLists)
269
+
270
+
271
+ def repeatfunc(func, times=None, *args):
272
+ """Call *func* with *args* repeatedly, returning an iterable over the
273
+ results.
274
+
275
+ If *times* is specified, the iterable will terminate after that many
276
+ repetitions:
277
+
278
+ >>> from operator import add
279
+ >>> times = 4
280
+ >>> args = 3, 5
281
+ >>> list(repeatfunc(add, times, *args))
282
+ [8, 8, 8, 8]
283
+
284
+ If *times* is ``None`` the iterable will not terminate:
285
+
286
+ >>> from random import randrange
287
+ >>> times = None
288
+ >>> args = 1, 11
289
+ >>> take(6, repeatfunc(randrange, times, *args)) # doctest:+SKIP
290
+ [2, 4, 8, 1, 8, 4]
291
+
292
+ """
293
+ if times is None:
294
+ return starmap(func, repeat(args))
295
+ return starmap(func, repeat(args, times))
296
+
297
+
298
+ def _pairwise(iterable):
299
+ """Returns an iterator of paired items, overlapping, from the original
300
+
301
+ >>> take(4, pairwise(count()))
302
+ [(0, 1), (1, 2), (2, 3), (3, 4)]
303
+
304
+ On Python 3.10 and above, this is an alias for :func:`itertools.pairwise`.
305
+
306
+ """
307
+ a, b = tee(iterable)
308
+ next(b, None)
309
+ return zip(a, b)
310
+
311
+
312
+ try:
313
+ from itertools import pairwise as itertools_pairwise
314
+ except ImportError:
315
+ pairwise = _pairwise
316
+ else:
317
+
318
+ def pairwise(iterable):
319
+ return itertools_pairwise(iterable)
320
+
321
+ pairwise.__doc__ = _pairwise.__doc__
322
+
323
+
324
+ class UnequalIterablesError(ValueError):
325
+ def __init__(self, details=None):
326
+ msg = 'Iterables have different lengths'
327
+ if details is not None:
328
+ msg += (': index 0 has length {}; index {} has length {}').format(
329
+ *details
330
+ )
331
+
332
+ super().__init__(msg)
333
+
334
+
335
+ def _zip_equal_generator(iterables):
336
+ for combo in zip_longest(*iterables, fillvalue=_marker):
337
+ for val in combo:
338
+ if val is _marker:
339
+ raise UnequalIterablesError()
340
+ yield combo
341
+
342
+
343
+ def _zip_equal(*iterables):
344
+ # Check whether the iterables are all the same size.
345
+ try:
346
+ first_size = len(iterables[0])
347
+ for i, it in enumerate(iterables[1:], 1):
348
+ size = len(it)
349
+ if size != first_size:
350
+ raise UnequalIterablesError(details=(first_size, i, size))
351
+ # All sizes are equal, we can use the built-in zip.
352
+ return zip(*iterables)
353
+ # If any one of the iterables didn't have a length, start reading
354
+ # them until one runs out.
355
+ except TypeError:
356
+ return _zip_equal_generator(iterables)
357
+
358
+
359
+ def grouper(iterable, n, incomplete='fill', fillvalue=None):
360
+ """Group elements from *iterable* into fixed-length groups of length *n*.
361
+
362
+ >>> list(grouper('ABCDEF', 3))
363
+ [('A', 'B', 'C'), ('D', 'E', 'F')]
364
+
365
+ The keyword arguments *incomplete* and *fillvalue* control what happens for
366
+ iterables whose length is not a multiple of *n*.
367
+
368
+ When *incomplete* is `'fill'`, the last group will contain instances of
369
+ *fillvalue*.
370
+
371
+ >>> list(grouper('ABCDEFG', 3, incomplete='fill', fillvalue='x'))
372
+ [('A', 'B', 'C'), ('D', 'E', 'F'), ('G', 'x', 'x')]
373
+
374
+ When *incomplete* is `'ignore'`, the last group will not be emitted.
375
+
376
+ >>> list(grouper('ABCDEFG', 3, incomplete='ignore', fillvalue='x'))
377
+ [('A', 'B', 'C'), ('D', 'E', 'F')]
378
+
379
+ When *incomplete* is `'strict'`, a subclass of `ValueError` will be raised.
380
+
381
+ >>> it = grouper('ABCDEFG', 3, incomplete='strict')
382
+ >>> list(it) # doctest: +IGNORE_EXCEPTION_DETAIL
383
+ Traceback (most recent call last):
384
+ ...
385
+ UnequalIterablesError
386
+
387
+ """
388
+ args = [iter(iterable)] * n
389
+ if incomplete == 'fill':
390
+ return zip_longest(*args, fillvalue=fillvalue)
391
+ if incomplete == 'strict':
392
+ return _zip_equal(*args)
393
+ if incomplete == 'ignore':
394
+ return zip(*args)
395
+ else:
396
+ raise ValueError('Expected fill, strict, or ignore')
397
+
398
+
399
+ def roundrobin(*iterables):
400
+ """Yields an item from each iterable, alternating between them.
401
+
402
+ >>> list(roundrobin('ABC', 'D', 'EF'))
403
+ ['A', 'D', 'E', 'B', 'F', 'C']
404
+
405
+ This function produces the same output as :func:`interleave_longest`, but
406
+ may perform better for some inputs (in particular when the number of
407
+ iterables is small).
408
+
409
+ """
410
+ # Recipe credited to George Sakkis
411
+ pending = len(iterables)
412
+ nexts = cycle(iter(it).__next__ for it in iterables)
413
+ while pending:
414
+ try:
415
+ for next in nexts:
416
+ yield next()
417
+ except StopIteration:
418
+ pending -= 1
419
+ nexts = cycle(islice(nexts, pending))
420
+
421
+
422
+ def partition(pred, iterable):
423
+ """
424
+ Returns a 2-tuple of iterables derived from the input iterable.
425
+ The first yields the items that have ``pred(item) == False``.
426
+ The second yields the items that have ``pred(item) == True``.
427
+
428
+ >>> is_odd = lambda x: x % 2 != 0
429
+ >>> iterable = range(10)
430
+ >>> even_items, odd_items = partition(is_odd, iterable)
431
+ >>> list(even_items), list(odd_items)
432
+ ([0, 2, 4, 6, 8], [1, 3, 5, 7, 9])
433
+
434
+ If *pred* is None, :func:`bool` is used.
435
+
436
+ >>> iterable = [0, 1, False, True, '', ' ']
437
+ >>> false_items, true_items = partition(None, iterable)
438
+ >>> list(false_items), list(true_items)
439
+ ([0, False, ''], [1, True, ' '])
440
+
441
+ """
442
+ if pred is None:
443
+ pred = bool
444
+
445
+ t1, t2, p = tee(iterable, 3)
446
+ p1, p2 = tee(map(pred, p))
447
+ return (compress(t1, map(operator.not_, p1)), compress(t2, p2))
448
+
449
+
450
+ def powerset(iterable):
451
+ """Yields all possible subsets of the iterable.
452
+
453
+ >>> list(powerset([1, 2, 3]))
454
+ [(), (1,), (2,), (3,), (1, 2), (1, 3), (2, 3), (1, 2, 3)]
455
+
456
+ :func:`powerset` will operate on iterables that aren't :class:`set`
457
+ instances, so repeated elements in the input will produce repeated elements
458
+ in the output. Use :func:`unique_everseen` on the input to avoid generating
459
+ duplicates:
460
+
461
+ >>> seq = [1, 1, 0]
462
+ >>> list(powerset(seq))
463
+ [(), (1,), (1,), (0,), (1, 1), (1, 0), (1, 0), (1, 1, 0)]
464
+ >>> from more_itertools import unique_everseen
465
+ >>> list(powerset(unique_everseen(seq)))
466
+ [(), (1,), (0,), (1, 0)]
467
+
468
+ """
469
+ s = list(iterable)
470
+ return chain.from_iterable(combinations(s, r) for r in range(len(s) + 1))
471
+
472
+
473
+ def unique_everseen(iterable, key=None):
474
+ """
475
+ Yield unique elements, preserving order.
476
+
477
+ >>> list(unique_everseen('AAAABBBCCDAABBB'))
478
+ ['A', 'B', 'C', 'D']
479
+ >>> list(unique_everseen('ABBCcAD', str.lower))
480
+ ['A', 'B', 'C', 'D']
481
+
482
+ Sequences with a mix of hashable and unhashable items can be used.
483
+ The function will be slower (i.e., `O(n^2)`) for unhashable items.
484
+
485
+ Remember that ``list`` objects are unhashable - you can use the *key*
486
+ parameter to transform the list to a tuple (which is hashable) to
487
+ avoid a slowdown.
488
+
489
+ >>> iterable = ([1, 2], [2, 3], [1, 2])
490
+ >>> list(unique_everseen(iterable)) # Slow
491
+ [[1, 2], [2, 3]]
492
+ >>> list(unique_everseen(iterable, key=tuple)) # Faster
493
+ [[1, 2], [2, 3]]
494
+
495
+ Similary, you may want to convert unhashable ``set`` objects with
496
+ ``key=frozenset``. For ``dict`` objects,
497
+ ``key=lambda x: frozenset(x.items())`` can be used.
498
+
499
+ """
500
+ seenset = set()
501
+ seenset_add = seenset.add
502
+ seenlist = []
503
+ seenlist_add = seenlist.append
504
+ use_key = key is not None
505
+
506
+ for element in iterable:
507
+ k = key(element) if use_key else element
508
+ try:
509
+ if k not in seenset:
510
+ seenset_add(k)
511
+ yield element
512
+ except TypeError:
513
+ if k not in seenlist:
514
+ seenlist_add(k)
515
+ yield element
516
+
517
+
518
+ def unique_justseen(iterable, key=None):
519
+ """Yields elements in order, ignoring serial duplicates
520
+
521
+ >>> list(unique_justseen('AAAABBBCCDAABBB'))
522
+ ['A', 'B', 'C', 'D', 'A', 'B']
523
+ >>> list(unique_justseen('ABBCcAD', str.lower))
524
+ ['A', 'B', 'C', 'A', 'D']
525
+
526
+ """
527
+ return map(next, map(operator.itemgetter(1), groupby(iterable, key)))
528
+
529
+
530
+ def iter_except(func, exception, first=None):
531
+ """Yields results from a function repeatedly until an exception is raised.
532
+
533
+ Converts a call-until-exception interface to an iterator interface.
534
+ Like ``iter(func, sentinel)``, but uses an exception instead of a sentinel
535
+ to end the loop.
536
+
537
+ >>> l = [0, 1, 2]
538
+ >>> list(iter_except(l.pop, IndexError))
539
+ [2, 1, 0]
540
+
541
+ Multiple exceptions can be specified as a stopping condition:
542
+
543
+ >>> l = [1, 2, 3, '...', 4, 5, 6]
544
+ >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
545
+ [7, 6, 5]
546
+ >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
547
+ [4, 3, 2]
548
+ >>> list(iter_except(lambda: 1 + l.pop(), (IndexError, TypeError)))
549
+ []
550
+
551
+ """
552
+ try:
553
+ if first is not None:
554
+ yield first()
555
+ while 1:
556
+ yield func()
557
+ except exception:
558
+ pass
559
+
560
+
561
+ def first_true(iterable, default=None, pred=None):
562
+ """
563
+ Returns the first true value in the iterable.
564
+
565
+ If no true value is found, returns *default*
566
+
567
+ If *pred* is not None, returns the first item for which
568
+ ``pred(item) == True`` .
569
+
570
+ >>> first_true(range(10))
571
+ 1
572
+ >>> first_true(range(10), pred=lambda x: x > 5)
573
+ 6
574
+ >>> first_true(range(10), default='missing', pred=lambda x: x > 9)
575
+ 'missing'
576
+
577
+ """
578
+ return next(filter(pred, iterable), default)
579
+
580
+
581
+ def random_product(*args, repeat=1):
582
+ """Draw an item at random from each of the input iterables.
583
+
584
+ >>> random_product('abc', range(4), 'XYZ') # doctest:+SKIP
585
+ ('c', 3, 'Z')
586
+
587
+ If *repeat* is provided as a keyword argument, that many items will be
588
+ drawn from each iterable.
589
+
590
+ >>> random_product('abcd', range(4), repeat=2) # doctest:+SKIP
591
+ ('a', 2, 'd', 3)
592
+
593
+ This equivalent to taking a random selection from
594
+ ``itertools.product(*args, **kwarg)``.
595
+
596
+ """
597
+ pools = [tuple(pool) for pool in args] * repeat
598
+ return tuple(choice(pool) for pool in pools)
599
+
600
+
601
+ def random_permutation(iterable, r=None):
602
+ """Return a random *r* length permutation of the elements in *iterable*.
603
+
604
+ If *r* is not specified or is ``None``, then *r* defaults to the length of
605
+ *iterable*.
606
+
607
+ >>> random_permutation(range(5)) # doctest:+SKIP
608
+ (3, 4, 0, 1, 2)
609
+
610
+ This equivalent to taking a random selection from
611
+ ``itertools.permutations(iterable, r)``.
612
+
613
+ """
614
+ pool = tuple(iterable)
615
+ r = len(pool) if r is None else r
616
+ return tuple(sample(pool, r))
617
+
618
+
619
+ def random_combination(iterable, r):
620
+ """Return a random *r* length subsequence of the elements in *iterable*.
621
+
622
+ >>> random_combination(range(5), 3) # doctest:+SKIP
623
+ (2, 3, 4)
624
+
625
+ This equivalent to taking a random selection from
626
+ ``itertools.combinations(iterable, r)``.
627
+
628
+ """
629
+ pool = tuple(iterable)
630
+ n = len(pool)
631
+ indices = sorted(sample(range(n), r))
632
+ return tuple(pool[i] for i in indices)
633
+
634
+
635
+ def random_combination_with_replacement(iterable, r):
636
+ """Return a random *r* length subsequence of elements in *iterable*,
637
+ allowing individual elements to be repeated.
638
+
639
+ >>> random_combination_with_replacement(range(3), 5) # doctest:+SKIP
640
+ (0, 0, 1, 2, 2)
641
+
642
+ This equivalent to taking a random selection from
643
+ ``itertools.combinations_with_replacement(iterable, r)``.
644
+
645
+ """
646
+ pool = tuple(iterable)
647
+ n = len(pool)
648
+ indices = sorted(randrange(n) for i in range(r))
649
+ return tuple(pool[i] for i in indices)
650
+
651
+
652
+ def nth_combination(iterable, r, index):
653
+ """Equivalent to ``list(combinations(iterable, r))[index]``.
654
+
655
+ The subsequences of *iterable* that are of length *r* can be ordered
656
+ lexicographically. :func:`nth_combination` computes the subsequence at
657
+ sort position *index* directly, without computing the previous
658
+ subsequences.
659
+
660
+ >>> nth_combination(range(5), 3, 5)
661
+ (0, 3, 4)
662
+
663
+ ``ValueError`` will be raised If *r* is negative or greater than the length
664
+ of *iterable*.
665
+ ``IndexError`` will be raised if the given *index* is invalid.
666
+ """
667
+ pool = tuple(iterable)
668
+ n = len(pool)
669
+ if (r < 0) or (r > n):
670
+ raise ValueError
671
+
672
+ c = 1
673
+ k = min(r, n - r)
674
+ for i in range(1, k + 1):
675
+ c = c * (n - k + i) // i
676
+
677
+ if index < 0:
678
+ index += c
679
+
680
+ if (index < 0) or (index >= c):
681
+ raise IndexError
682
+
683
+ result = []
684
+ while r:
685
+ c, n, r = c * r // n, n - 1, r - 1
686
+ while index >= c:
687
+ index -= c
688
+ c, n = c * (n - r) // n, n - 1
689
+ result.append(pool[-1 - n])
690
+
691
+ return tuple(result)
692
+
693
+
694
+ def prepend(value, iterator):
695
+ """Yield *value*, followed by the elements in *iterator*.
696
+
697
+ >>> value = '0'
698
+ >>> iterator = ['1', '2', '3']
699
+ >>> list(prepend(value, iterator))
700
+ ['0', '1', '2', '3']
701
+
702
+ To prepend multiple values, see :func:`itertools.chain`
703
+ or :func:`value_chain`.
704
+
705
+ """
706
+ return chain([value], iterator)
707
+
708
+
709
+ def convolve(signal, kernel):
+     """Convolve the iterable *signal* with the iterable *kernel*.
+
+     >>> signal = (1, 2, 3, 4, 5)
+     >>> kernel = [3, 2, 1]
+     >>> list(convolve(signal, kernel))
+     [3, 8, 14, 20, 26, 14, 5]
+
+     Note: the input arguments are not interchangeable, as the *kernel*
+     is immediately consumed and stored.
+
+     """
+     # This implementation intentionally doesn't match the one in the itertools
+     # documentation.
+     kernel = tuple(kernel)[::-1]
+     n = len(kernel)
+     window = deque([0], maxlen=n) * n
+     for x in chain(signal, repeat(0, n - 1)):
+         window.append(x)
+         yield _sumprod(kernel, window)
+
+
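A uniform kernel is a common use of `convolve`: it produces running sums or averages. A sketch (editor's illustration, not part of the vendored file):

    # Illustrative sketch: a 3-point moving average; the leading and trailing
    # outputs taper because the signal is zero-padded at both ends.
    from more_itertools import convolve

    data = [10, 20, 30, 40]
    print(list(convolve(data, [1 / 3] * 3)))  # middle values are 20.0 and 30.0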
+ def before_and_after(predicate, it):
+     """A variant of :func:`takewhile` that allows complete access to the
+     remainder of the iterator.
+
+     >>> it = iter('ABCdEfGhI')
+     >>> all_upper, remainder = before_and_after(str.isupper, it)
+     >>> ''.join(all_upper)
+     'ABC'
+     >>> ''.join(remainder) # takewhile() would lose the 'd'
+     'dEfGhI'
+
+     Note that the first iterator must be fully consumed before the second
+     iterator can generate valid results.
+     """
+     it = iter(it)
+     transition = []
+
+     def true_iterator():
+         for elem in it:
+             if predicate(elem):
+                 yield elem
+             else:
+                 transition.append(elem)
+                 return
+
+     # Note: this is different from itertools recipes to allow nesting
+     # before_and_after remainders into before_and_after again. See tests
+     # for an example.
+     remainder_iterator = chain(transition, it)
+
+     return true_iterator(), remainder_iterator
+
+
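The in-code comment above mentions nesting; because the remainder is a plain chained iterator, it can be fed straight back into `before_and_after`. A sketch (editor's illustration, not part of the vendored file):

    # Illustrative sketch: split a stream at successive transitions.
    from more_itertools import before_and_after

    upper, rest = before_and_after(str.isupper, iter('ABCdefGHI'))
    print(''.join(upper))                 # 'ABC' (drain this iterator first)
    lower, rest = before_and_after(str.islower, rest)
    print(''.join(lower), ''.join(rest))  # 'def' 'GHI'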
+ def triplewise(iterable):
+     """Return overlapping triplets from *iterable*.
+
+     >>> list(triplewise('ABCDE'))
+     [('A', 'B', 'C'), ('B', 'C', 'D'), ('C', 'D', 'E')]
+
+     """
+     for (a, _), (b, c) in pairwise(pairwise(iterable)):
+         yield a, b, c
+
+
+ def sliding_window(iterable, n):
+     """Return a sliding window of width *n* over *iterable*.
+
+     >>> list(sliding_window(range(6), 4))
+     [(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5)]
+
+     If *iterable* has fewer than *n* items, then nothing is yielded:
+
+     >>> list(sliding_window(range(3), 4))
+     []
+
+     For a variant with more features, see :func:`windowed`.
+     """
+     it = iter(iterable)
+     window = deque(islice(it, n - 1), maxlen=n)
+     for x in it:
+         window.append(x)
+         yield tuple(window)
+
+
+ def subslices(iterable):
+     """Return all contiguous non-empty subslices of *iterable*.
+
+     >>> list(subslices('ABC'))
+     [['A'], ['A', 'B'], ['A', 'B', 'C'], ['B'], ['B', 'C'], ['C']]
+
+     This is similar to :func:`substrings`, but emits items in a different
+     order.
+     """
+     seq = list(iterable)
+     slices = starmap(slice, combinations(range(len(seq) + 1), 2))
+     return map(operator.getitem, repeat(seq), slices)
+
+
+ def polynomial_from_roots(roots):
+     """Compute a polynomial's coefficients from its roots.
+
+     >>> roots = [5, -4, 3] # (x - 5) * (x + 4) * (x - 3)
+     >>> polynomial_from_roots(roots) # x^3 - 4 * x^2 - 17 * x + 60
+     [1, -4, -17, 60]
+     """
+     factors = zip(repeat(1), map(operator.neg, roots))
+     return list(reduce(convolve, factors, [1]))
+
+
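The `reduce(convolve, ...)` trick works because multiplying polynomials is the same as convolving their coefficient lists; each root *r* contributes a factor with coefficients ``[1, -r]``. A sketch (editor's illustration, not part of the vendored file):

    # Illustrative sketch: (x - 5) * (x + 4) = x^2 - x - 20.
    from more_itertools import convolve

    print(list(convolve([1, -5], [1, 4])))  # [1, -1, -20]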
+ def iter_index(iterable, value, start=0):
+     """Yield the index of each place in *iterable* that *value* occurs,
+     beginning with index *start*.
+
+     See :func:`locate` for a more general means of finding the indexes
+     associated with particular values.
+
+     >>> list(iter_index('AABCADEAF', 'A'))
+     [0, 1, 4, 7]
+     """
+     try:
+         seq_index = iterable.index
+     except AttributeError:
+         # Slow path for general iterables
+         it = islice(iterable, start, None)
+         i = start - 1
+         try:
+             while True:
+                 i = i + operator.indexOf(it, value) + 1
+                 yield i
+         except ValueError:
+             pass
+     else:
+         # Fast path for sequences
+         i = start - 1
+         try:
+             while True:
+                 i = seq_index(value, i + 1)
+                 yield i
+         except ValueError:
+             pass
+
+
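Both branches yield identical indexes; the fast path defers to the sequence's own ``.index``, while the slow path counts through the iterator with ``operator.indexOf``. A sketch (editor's illustration, not part of the vendored file):

    # Illustrative sketch: sequence path and iterator path agree, including
    # the *start* offset.
    from more_itertools import iter_index

    s = 'AABCADEAF'
    assert list(iter_index(s, 'A', start=1)) == [1, 4, 7]
    assert list(iter_index(iter(s), 'A', start=1)) == [1, 4, 7]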
+ def sieve(n):
+     """Yield the primes less than n.
+
+     >>> list(sieve(30))
+     [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
+     """
+     data = bytearray((0, 1)) * (n // 2)
+     data[:3] = 0, 0, 0
+     limit = math.isqrt(n) + 1
+     for p in compress(range(limit), data):
+         data[p * p : n : p + p] = bytes(len(range(p * p, n, p + p)))
+     data[2] = 1
+     return iter_index(data, 1) if n > 2 else iter([])
+
+
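The bytearray holds one flag per integer, and striking out multiples is a single slice assignment, which is what makes this sieve fast in pure Python. A cross-check against naive trial division (editor's illustration, not part of the vendored file):

    # Illustrative sketch: the sieve agrees with a naive primality test.
    from more_itertools import sieve

    def is_prime(k):
        return k > 1 and all(k % d for d in range(2, int(k ** 0.5) + 1))

    assert list(sieve(100)) == [k for k in range(100) if is_prime(k)]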
+ def _batched(iterable, n):
+     """Batch data into tuples of length *n*. The last batch may be shorter.
+
+     >>> list(batched('ABCDEFG', 3))
+     [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
+
+     On Python 3.12 and above, this is an alias for :func:`itertools.batched`.
+     """
+     if n < 1:
+         raise ValueError('n must be at least one')
+     it = iter(iterable)
+     while True:
+         batch = tuple(islice(it, n))
+         if not batch:
+             break
+         yield batch
+
+
+ try:
+     from itertools import batched as itertools_batched
+ except ImportError:
+     batched = _batched
+ else:
+
+     def batched(iterable, n):
+         return itertools_batched(iterable, n)
+
+     batched.__doc__ = _batched.__doc__
+
+
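The try/except prefers the C implementation of ``itertools.batched`` on Python 3.12+ and falls back to the pure-Python generator elsewhere; the thin wrapper exists so the docstring can be attached without mutating the stdlib function. Typical usage (editor's illustration, not part of the vendored file):

    # Illustrative sketch: fixed-size chunking; the last chunk may be short.
    from more_itertools import batched

    for chunk in batched(range(10), 4):
        print(chunk)  # (0, 1, 2, 3), then (4, 5, 6, 7), then (8, 9)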
+ def transpose(it):
+     """Swap the rows and columns of the input.
+
+     >>> list(transpose([(1, 2, 3), (11, 22, 33)]))
+     [(1, 11), (2, 22), (3, 33)]
+
+     The caller should ensure that the dimensions of the input are compatible.
+     If the input is empty, no output will be produced.
+     """
+     return _zip_strict(*it)
+
+
+ def matmul(m1, m2):
+     """Multiply two matrices.
+
+     >>> list(matmul([(7, 5), (3, 5)], [(2, 5), (7, 9)]))
+     [(49, 80), (41, 60)]
+
+     The caller should ensure that the dimensions of the input matrices are
+     compatible with each other.
+     """
+     n = len(m2[0])
+     return batched(starmap(_sumprod, product(m1, transpose(m2))), n)
+
+
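Each output entry is the dot product of a row of *m1* with a column of *m2*: ``product(m1, transpose(m2))`` enumerates the row/column pairs in row-major order, ``_sumprod`` takes each dot product, and ``batched`` regroups the flat stream into rows of width *n*. A verification sketch (editor's illustration, not part of the vendored file):

    # Illustrative sketch: compare against an explicit row-by-column loop.
    from more_itertools import matmul

    m1, m2 = [(7, 5), (3, 5)], [(2, 5), (7, 9)]
    expected = [
        tuple(sum(a * b for a, b in zip(row, col)) for col in zip(*m2))
        for row in m1
    ]
    assert [tuple(r) for r in matmul(m1, m2)] == expected  # [(49, 80), (41, 60)]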
+ def factor(n):
+     """Yield the prime factors of n.
+
+     >>> list(factor(360))
+     [2, 2, 2, 3, 3, 5]
+     """
+     for prime in sieve(math.isqrt(n) + 1):
+         while True:
+             if n % prime:
+                 break
+             yield prime
+             n //= prime
+             if n == 1:
+                 return
+     if n > 1:
+         yield n
+
+
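Trial division only needs primes up to the square root of *n*; whatever is left over at the end is itself prime, which is what the final ``if n > 1`` handles. A round-trip sketch (editor's illustration, not part of the vendored file):

    # Illustrative sketch: the factors multiply back to the original number.
    from math import prod
    from more_itertools import factor

    n = 2023
    fs = list(factor(n))
    assert fs == [7, 17, 17] and prod(fs) == n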
+ def polynomial_eval(coefficients, x):
+     """Evaluate a polynomial at a specific value.
+
+     Example: evaluating x^3 - 4 * x^2 - 17 * x + 60 at x = 2.5:
+
+     >>> coefficients = [1, -4, -17, 60]
+     >>> x = 2.5
+     >>> polynomial_eval(coefficients, x)
+     8.125
+     """
+     n = len(coefficients)
+     if n == 0:
+         return x * 0  # coerce zero to the type of x
+     powers = map(pow, repeat(x), reversed(range(n)))
+     return _sumprod(coefficients, powers)
+
+
+ def sum_of_squares(it):
+     """Return the sum of the squares of the input values.
+
+     >>> sum_of_squares([10, 20, 30])
+     1400
+     """
+     return _sumprod(*tee(it))
+
+
+ def polynomial_derivative(coefficients):
+     """Compute the first derivative of a polynomial.
+
+     Example: evaluating the derivative of x^3 - 4 * x^2 - 17 * x + 60
+
+     >>> coefficients = [1, -4, -17, 60]
+     >>> derivative_coefficients = polynomial_derivative(coefficients)
+     >>> derivative_coefficients
+     [3, -8, -17]
+     """
+     n = len(coefficients)
+     powers = reversed(range(1, n))
+     return list(map(operator.mul, coefficients, powers))
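With coefficients in descending order, the power rule pairs each coefficient with its exponent, so the derivative drops the constant term and scales the rest. A numeric spot-check (editor's illustration, not part of the vendored file):

    # Illustrative sketch: a central difference approximates the symbolic
    # derivative produced by polynomial_derivative.
    from more_itertools import polynomial_derivative, polynomial_eval

    coeffs = [1, -4, -17, 60]                # x^3 - 4x^2 - 17x + 60
    dcoeffs = polynomial_derivative(coeffs)  # [3, -8, -17]
    x, h = 2.5, 1e-6
    numeric = (polynomial_eval(coeffs, x + h) - polynomial_eval(coeffs, x - h)) / (2 * h)
    assert abs(numeric - polynomial_eval(dcoeffs, x)) < 1e-5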
lib/python3.11/site-packages/more_itertools/recipes.pyi ADDED
@@ -0,0 +1,122 @@
+ """Stubs for more_itertools.recipes"""
+ from __future__ import annotations
+
+ from typing import (
+     Any,
+     Callable,
+     Iterable,
+     Iterator,
+     overload,
+     Sequence,
+     Type,
+     TypeVar,
+ )
+
+ # Type and type variable definitions
+ _T = TypeVar('_T')
+ _U = TypeVar('_U')
+
+ def take(n: int, iterable: Iterable[_T]) -> list[_T]: ...
+ def tabulate(
+     function: Callable[[int], _T], start: int = ...
+ ) -> Iterator[_T]: ...
+ def tail(n: int, iterable: Iterable[_T]) -> Iterator[_T]: ...
+ def consume(iterator: Iterable[_T], n: int | None = ...) -> None: ...
+ @overload
+ def nth(iterable: Iterable[_T], n: int) -> _T | None: ...
+ @overload
+ def nth(iterable: Iterable[_T], n: int, default: _U) -> _T | _U: ...
+ def all_equal(iterable: Iterable[object]) -> bool: ...
+ def quantify(
+     iterable: Iterable[_T], pred: Callable[[_T], bool] = ...
+ ) -> int: ...
+ def pad_none(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
+ def padnone(iterable: Iterable[_T]) -> Iterator[_T | None]: ...
+ def ncycles(iterable: Iterable[_T], n: int) -> Iterator[_T]: ...
+ def dotproduct(vec1: Iterable[object], vec2: Iterable[object]) -> object: ...
+ def flatten(listOfLists: Iterable[Iterable[_T]]) -> Iterator[_T]: ...
+ def repeatfunc(
+     func: Callable[..., _U], times: int | None = ..., *args: Any
+ ) -> Iterator[_U]: ...
+ def pairwise(iterable: Iterable[_T]) -> Iterator[tuple[_T, _T]]: ...
+ def grouper(
+     iterable: Iterable[_T],
+     n: int,
+     incomplete: str = ...,
+     fillvalue: _U = ...,
+ ) -> Iterator[tuple[_T | _U, ...]]: ...
+ def roundrobin(*iterables: Iterable[_T]) -> Iterator[_T]: ...
+ def partition(
+     pred: Callable[[_T], object] | None, iterable: Iterable[_T]
+ ) -> tuple[Iterator[_T], Iterator[_T]]: ...
+ def powerset(iterable: Iterable[_T]) -> Iterator[tuple[_T, ...]]: ...
+ def unique_everseen(
+     iterable: Iterable[_T], key: Callable[[_T], _U] | None = ...
+ ) -> Iterator[_T]: ...
+ def unique_justseen(
+     iterable: Iterable[_T], key: Callable[[_T], object] | None = ...
+ ) -> Iterator[_T]: ...
+ @overload
+ def iter_except(
+     func: Callable[[], _T],
+     exception: Type[BaseException] | tuple[Type[BaseException], ...],
+     first: None = ...,
+ ) -> Iterator[_T]: ...
+ @overload
+ def iter_except(
+     func: Callable[[], _T],
+     exception: Type[BaseException] | tuple[Type[BaseException], ...],
+     first: Callable[[], _U],
+ ) -> Iterator[_T | _U]: ...
+ @overload
+ def first_true(
+     iterable: Iterable[_T], *, pred: Callable[[_T], object] | None = ...
+ ) -> _T | None: ...
+ @overload
+ def first_true(
+     iterable: Iterable[_T],
+     default: _U,
+     pred: Callable[[_T], object] | None = ...,
+ ) -> _T | _U: ...
+ def random_product(
+     *args: Iterable[_T], repeat: int = ...
+ ) -> tuple[_T, ...]: ...
+ def random_permutation(
+     iterable: Iterable[_T], r: int | None = ...
+ ) -> tuple[_T, ...]: ...
+ def random_combination(iterable: Iterable[_T], r: int) -> tuple[_T, ...]: ...
+ def random_combination_with_replacement(
+     iterable: Iterable[_T], r: int
+ ) -> tuple[_T, ...]: ...
+ def nth_combination(
+     iterable: Iterable[_T], r: int, index: int
+ ) -> tuple[_T, ...]: ...
+ def prepend(value: _T, iterator: Iterable[_U]) -> Iterator[_T | _U]: ...
+ def convolve(signal: Iterable[_T], kernel: Iterable[_T]) -> Iterator[_T]: ...
+ def before_and_after(
+     predicate: Callable[[_T], bool], it: Iterable[_T]
+ ) -> tuple[Iterator[_T], Iterator[_T]]: ...
+ def triplewise(iterable: Iterable[_T]) -> Iterator[tuple[_T, _T, _T]]: ...
+ def sliding_window(
+     iterable: Iterable[_T], n: int
+ ) -> Iterator[tuple[_T, ...]]: ...
+ def subslices(iterable: Iterable[_T]) -> Iterator[list[_T]]: ...
+ def polynomial_from_roots(roots: Sequence[_T]) -> list[_T]: ...
+ def iter_index(
+     iterable: Iterable[object],
+     value: Any,
+     start: int | None = ...,
+ ) -> Iterator[int]: ...
+ def sieve(n: int) -> Iterator[int]: ...
+ def batched(
+     iterable: Iterable[_T],
+     n: int,
+ ) -> Iterator[tuple[_T]]: ...
+ def transpose(
+     it: Iterable[Iterable[_T]],
+ ) -> Iterator[tuple[_T, ...]]: ...
+ def matmul(m1: Sequence[_T], m2: Sequence[_T]) -> Iterator[tuple[_T]]: ...
+ def factor(n: int) -> Iterator[int]: ...
+ def polynomial_eval(coefficients: Sequence[_T], x: _U) -> _U: ...
+ def sum_of_squares(it: Iterable[_T]) -> _T: ...
+ def polynomial_derivative(coefficients: Sequence[_T]) -> list[_T]: ...
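The paired ``@overload`` declarations in these stubs let a type checker narrow return types depending on which arguments are supplied, for example with ``nth`` (editor's illustration, not part of the vendored file):

    # Illustrative sketch: supplying a default removes the Optional.
    from more_itertools import nth

    a = nth([1, 2, 3], 5)      # a type checker infers: int | None
    b = nth([1, 2, 3], 5, -1)  # a type checker infers: int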
lib/python3.11/site-packages/mpmath/__init__.py ADDED
@@ -0,0 +1,468 @@
+ __version__ = '1.3.0'
+
+ from .usertools import monitor, timing
+
+ from .ctx_fp import FPContext
+ from .ctx_mp import MPContext
+ from .ctx_iv import MPIntervalContext
+
+ fp = FPContext()
+ mp = MPContext()
+ iv = MPIntervalContext()
+
+ fp._mp = mp
+ mp._mp = mp
+ iv._mp = mp
+ mp._fp = fp
+ fp._fp = fp
+ mp._iv = iv
+ fp._iv = iv
+ iv._iv = iv
+
+ # XXX: extremely bad pickle hack
+ from . import ctx_mp as _ctx_mp
+ _ctx_mp._mpf_module.mpf = mp.mpf
+ _ctx_mp._mpf_module.mpc = mp.mpc
+
+ make_mpf = mp.make_mpf
+ make_mpc = mp.make_mpc
+
+ extraprec = mp.extraprec
+ extradps = mp.extradps
+ workprec = mp.workprec
+ workdps = mp.workdps
+ autoprec = mp.autoprec
+ maxcalls = mp.maxcalls
+ memoize = mp.memoize
+
+ mag = mp.mag
+
+ bernfrac = mp.bernfrac
+
+ qfrom = mp.qfrom
+ mfrom = mp.mfrom
+ kfrom = mp.kfrom
+ taufrom = mp.taufrom
+ qbarfrom = mp.qbarfrom
+ ellipfun = mp.ellipfun
+ jtheta = mp.jtheta
+ kleinj = mp.kleinj
+ eta = mp.eta
+
+ qp = mp.qp
+ qhyper = mp.qhyper
+ qgamma = mp.qgamma
+ qfac = mp.qfac
+
+ nint_distance = mp.nint_distance
+
+ plot = mp.plot
+ cplot = mp.cplot
+ splot = mp.splot
+
+ odefun = mp.odefun
+
+ jacobian = mp.jacobian
+ findroot = mp.findroot
+ multiplicity = mp.multiplicity
+
+ isinf = mp.isinf
+ isnan = mp.isnan
+ isnormal = mp.isnormal
+ isint = mp.isint
+ isfinite = mp.isfinite
+ almosteq = mp.almosteq
+ nan = mp.nan
+ rand = mp.rand
+
+ absmin = mp.absmin
+ absmax = mp.absmax
+
+ fraction = mp.fraction
+
+ linspace = mp.linspace
+ arange = mp.arange
+
+ mpmathify = convert = mp.convert
+ mpc = mp.mpc
+
+ mpi = iv._mpi
+
+ nstr = mp.nstr
+ nprint = mp.nprint
+ chop = mp.chop
+
+ fneg = mp.fneg
+ fadd = mp.fadd
+ fsub = mp.fsub
+ fmul = mp.fmul
+ fdiv = mp.fdiv
+ fprod = mp.fprod
+
+ quad = mp.quad
+ quadgl = mp.quadgl
+ quadts = mp.quadts
+ quadosc = mp.quadosc
+ quadsubdiv = mp.quadsubdiv
+
+ invertlaplace = mp.invertlaplace
+ invlaptalbot = mp.invlaptalbot
+ invlapstehfest = mp.invlapstehfest
+ invlapdehoog = mp.invlapdehoog
+
+ pslq = mp.pslq
+ identify = mp.identify
+ findpoly = mp.findpoly
+
+ richardson = mp.richardson
+ shanks = mp.shanks
+ levin = mp.levin
+ cohen_alt = mp.cohen_alt
+ nsum = mp.nsum
+ nprod = mp.nprod
+ difference = mp.difference
+ diff = mp.diff
+ diffs = mp.diffs
+ diffs_prod = mp.diffs_prod
+ diffs_exp = mp.diffs_exp
+ diffun = mp.diffun
+ differint = mp.differint
+ taylor = mp.taylor
+ pade = mp.pade
+ polyval = mp.polyval
+ polyroots = mp.polyroots
+ fourier = mp.fourier
+ fourierval = mp.fourierval
+ sumem = mp.sumem
+ sumap = mp.sumap
+ chebyfit = mp.chebyfit
+ limit = mp.limit
+
+ matrix = mp.matrix
+ eye = mp.eye
+ diag = mp.diag
+ zeros = mp.zeros
+ ones = mp.ones
+ hilbert = mp.hilbert
+ randmatrix = mp.randmatrix
+ swap_row = mp.swap_row
+ extend = mp.extend
+ norm = mp.norm
+ mnorm = mp.mnorm
+
+ lu_solve = mp.lu_solve
+ lu = mp.lu
+ qr = mp.qr
+ unitvector = mp.unitvector
+ inverse = mp.inverse
+ residual = mp.residual
+ qr_solve = mp.qr_solve
+ cholesky = mp.cholesky
+ cholesky_solve = mp.cholesky_solve
+ det = mp.det
+ cond = mp.cond
+ hessenberg = mp.hessenberg
+ schur = mp.schur
+ eig = mp.eig
+ eig_sort = mp.eig_sort
+ eigsy = mp.eigsy
+ eighe = mp.eighe
+ eigh = mp.eigh
+ svd_r = mp.svd_r
+ svd_c = mp.svd_c
+ svd = mp.svd
+ gauss_quadrature = mp.gauss_quadrature
+
+ expm = mp.expm
+ sqrtm = mp.sqrtm
+ powm = mp.powm
+ logm = mp.logm
+ sinm = mp.sinm
+ cosm = mp.cosm
+
+ mpf = mp.mpf
+ j = mp.j
+ exp = mp.exp
+ expj = mp.expj
+ expjpi = mp.expjpi
+ ln = mp.ln
+ im = mp.im
+ re = mp.re
+ inf = mp.inf
+ ninf = mp.ninf
+ sign = mp.sign
+
+ eps = mp.eps
+ pi = mp.pi
+ ln2 = mp.ln2
+ ln10 = mp.ln10
+ phi = mp.phi
+ e = mp.e
+ euler = mp.euler
+ catalan = mp.catalan
+ khinchin = mp.khinchin
+ glaisher = mp.glaisher
+ apery = mp.apery
+ degree = mp.degree
+ twinprime = mp.twinprime
+ mertens = mp.mertens
+
+ ldexp = mp.ldexp
+ frexp = mp.frexp
+
+ fsum = mp.fsum
+ fdot = mp.fdot
+
+ sqrt = mp.sqrt
+ cbrt = mp.cbrt
+ exp = mp.exp
+ ln = mp.ln
+ log = mp.log
+ log10 = mp.log10
+ power = mp.power
+ cos = mp.cos
+ sin = mp.sin
+ tan = mp.tan
+ cosh = mp.cosh
+ sinh = mp.sinh
+ tanh = mp.tanh
+ acos = mp.acos
+ asin = mp.asin
+ atan = mp.atan
+ asinh = mp.asinh
+ acosh = mp.acosh
+ atanh = mp.atanh
+ sec = mp.sec
+ csc = mp.csc
+ cot = mp.cot
+ sech = mp.sech
+ csch = mp.csch
+ coth = mp.coth
+ asec = mp.asec
+ acsc = mp.acsc
+ acot = mp.acot
+ asech = mp.asech
+ acsch = mp.acsch
+ acoth = mp.acoth
+ cospi = mp.cospi
+ sinpi = mp.sinpi
+ sinc = mp.sinc
+ sincpi = mp.sincpi
+ cos_sin = mp.cos_sin
+ cospi_sinpi = mp.cospi_sinpi
+ fabs = mp.fabs
+ re = mp.re
+ im = mp.im
+ conj = mp.conj
+ floor = mp.floor
+ ceil = mp.ceil
+ nint = mp.nint
+ frac = mp.frac
+ root = mp.root
+ nthroot = mp.nthroot
+ hypot = mp.hypot
+ fmod = mp.fmod
+ ldexp = mp.ldexp
+ frexp = mp.frexp
+ sign = mp.sign
+ arg = mp.arg
+ phase = mp.phase
+ polar = mp.polar
+ rect = mp.rect
+ degrees = mp.degrees
+ radians = mp.radians
+ atan2 = mp.atan2
+ fib = mp.fib
+ fibonacci = mp.fibonacci
+ lambertw = mp.lambertw
+ zeta = mp.zeta
+ altzeta = mp.altzeta
+ gamma = mp.gamma
+ rgamma = mp.rgamma
+ factorial = mp.factorial
+ fac = mp.fac
+ fac2 = mp.fac2
+ beta = mp.beta
+ betainc = mp.betainc
+ psi = mp.psi
+ #psi0 = mp.psi0
+ #psi1 = mp.psi1
+ #psi2 = mp.psi2
+ #psi3 = mp.psi3
+ polygamma = mp.polygamma
+ digamma = mp.digamma
+ #trigamma = mp.trigamma
+ #tetragamma = mp.tetragamma
+ #pentagamma = mp.pentagamma
+ harmonic = mp.harmonic
+ bernoulli = mp.bernoulli
+ bernfrac = mp.bernfrac
+ stieltjes = mp.stieltjes
+ hurwitz = mp.hurwitz
+ dirichlet = mp.dirichlet
+ bernpoly = mp.bernpoly
+ eulerpoly = mp.eulerpoly
+ eulernum = mp.eulernum
+ polylog = mp.polylog
+ clsin = mp.clsin
+ clcos = mp.clcos
+ gammainc = mp.gammainc
+ gammaprod = mp.gammaprod
+ binomial = mp.binomial
+ rf = mp.rf
+ ff = mp.ff
+ hyper = mp.hyper
+ hyp0f1 = mp.hyp0f1
+ hyp1f1 = mp.hyp1f1
+ hyp1f2 = mp.hyp1f2
+ hyp2f1 = mp.hyp2f1
+ hyp2f2 = mp.hyp2f2
+ hyp2f0 = mp.hyp2f0
+ hyp2f3 = mp.hyp2f3
+ hyp3f2 = mp.hyp3f2
+ hyperu = mp.hyperu
+ hypercomb = mp.hypercomb
+ meijerg = mp.meijerg
+ appellf1 = mp.appellf1
+ appellf2 = mp.appellf2
+ appellf3 = mp.appellf3
+ appellf4 = mp.appellf4
+ hyper2d = mp.hyper2d
+ bihyper = mp.bihyper
+ erf = mp.erf
+ erfc = mp.erfc
+ erfi = mp.erfi
+ erfinv = mp.erfinv
+ npdf = mp.npdf
+ ncdf = mp.ncdf
+ expint = mp.expint
+ e1 = mp.e1
+ ei = mp.ei
+ li = mp.li
+ ci = mp.ci
+ si = mp.si
+ chi = mp.chi
+ shi = mp.shi
+ fresnels = mp.fresnels
+ fresnelc = mp.fresnelc
+ airyai = mp.airyai
+ airybi = mp.airybi
+ airyaizero = mp.airyaizero
+ airybizero = mp.airybizero
+ scorergi = mp.scorergi
+ scorerhi = mp.scorerhi
+ ellipk = mp.ellipk
+ ellipe = mp.ellipe
+ ellipf = mp.ellipf
+ ellippi = mp.ellippi
+ elliprc = mp.elliprc
+ elliprj = mp.elliprj
+ elliprf = mp.elliprf
+ elliprd = mp.elliprd
+ elliprg = mp.elliprg
+ agm = mp.agm
+ jacobi = mp.jacobi
+ chebyt = mp.chebyt
+ chebyu = mp.chebyu
+ legendre = mp.legendre
+ legenp = mp.legenp
+ legenq = mp.legenq
+ hermite = mp.hermite
+ pcfd = mp.pcfd
+ pcfu = mp.pcfu
+ pcfv = mp.pcfv
+ pcfw = mp.pcfw
+ gegenbauer = mp.gegenbauer
+ laguerre = mp.laguerre
+ spherharm = mp.spherharm
+ besselj = mp.besselj
+ j0 = mp.j0
+ j1 = mp.j1
+ besseli = mp.besseli
+ bessely = mp.bessely
+ besselk = mp.besselk
+ besseljzero = mp.besseljzero
+ besselyzero = mp.besselyzero
+ hankel1 = mp.hankel1
+ hankel2 = mp.hankel2
+ struveh = mp.struveh
+ struvel = mp.struvel
+ angerj = mp.angerj
+ webere = mp.webere
+ lommels1 = mp.lommels1
+ lommels2 = mp.lommels2
+ whitm = mp.whitm
+ whitw = mp.whitw
+ ber = mp.ber
+ bei = mp.bei
+ ker = mp.ker
+ kei = mp.kei
+ coulombc = mp.coulombc
+ coulombf = mp.coulombf
+ coulombg = mp.coulombg
+ barnesg = mp.barnesg
+ superfac = mp.superfac
+ hyperfac = mp.hyperfac
+ loggamma = mp.loggamma
+ siegeltheta = mp.siegeltheta
+ siegelz = mp.siegelz
+ grampoint = mp.grampoint
+ zetazero = mp.zetazero
+ riemannr = mp.riemannr
+ primepi = mp.primepi
+ primepi2 = mp.primepi2
+ primezeta = mp.primezeta
+ bell = mp.bell
+ polyexp = mp.polyexp
+ expm1 = mp.expm1
+ log1p = mp.log1p
+ powm1 = mp.powm1
+ unitroots = mp.unitroots
+ cyclotomic = mp.cyclotomic
+ mangoldt = mp.mangoldt
+ secondzeta = mp.secondzeta
+ nzeros = mp.nzeros
+ backlunds = mp.backlunds
+ lerchphi = mp.lerchphi
+ stirling1 = mp.stirling1
+ stirling2 = mp.stirling2
+ squarew = mp.squarew
+ trianglew = mp.trianglew
+ sawtoothw = mp.sawtoothw
+ unit_triangle = mp.unit_triangle
+ sigmoid = mp.sigmoid
+
+ # be careful when changing this name, don't use test*!
+ def runtests():
+     """
+     Run all mpmath tests and print output.
+     """
+     import os.path
+     from inspect import getsourcefile
+     from .tests import runtests as tests
+     testdir = os.path.dirname(os.path.abspath(getsourcefile(tests)))
+     importdir = os.path.abspath(testdir + '/../..')
+     tests.testit(importdir, testdir)
+
+ def doctests(filter=[]):
+     import sys
+     from timeit import default_timer as clock
+     for i, arg in enumerate(sys.argv):
+         if '__init__.py' in arg:
+             filter = [sn for sn in sys.argv[i+1:] if not sn.startswith("-")]
+             break
+     import doctest
+     globs = globals().copy()
+     for obj in globs: #sorted(globs.keys()):
+         if filter:
+             if not sum([pat in obj for pat in filter]):
+                 continue
+         sys.stdout.write(str(obj) + " ")
+         sys.stdout.flush()
+         t1 = clock()
+         doctest.run_docstring_examples(globs[obj], {}, verbose=("-v" in sys.argv))
+         t2 = clock()
+         print(round(t2-t1, 3))
+
+ if __name__ == '__main__':
+     doctests()
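The three context objects wired together at the top of this file expose one API under different arithmetic: ``fp`` uses machine floats, ``mp`` arbitrary precision, and ``iv`` interval arithmetic. A sketch (editor's illustration, not part of the vendored file):

    # Illustrative sketch: the same operation under each context.
    import mpmath

    mpmath.mp.dps = 30          # 30 significant digits for the mp context
    print(mpmath.sqrt(2))       # arbitrary-precision result
    print(mpmath.fp.sqrt(2))    # plain float result
    print(mpmath.iv.sqrt(2))    # interval guaranteed to contain sqrt(2)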
lib/python3.11/site-packages/mpmath/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (14.7 kB).
 
lib/python3.11/site-packages/mpmath/__pycache__/ctx_base.cpython-311.pyc ADDED
Binary file (24.4 kB).
 
lib/python3.11/site-packages/mpmath/__pycache__/ctx_fp.cpython-311.pyc ADDED
Binary file (13 kB).
 
lib/python3.11/site-packages/mpmath/__pycache__/ctx_iv.cpython-311.pyc ADDED
Binary file (38.7 kB).
 
lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp.cpython-311.pyc ADDED
Binary file (71.2 kB).
 
lib/python3.11/site-packages/mpmath/__pycache__/ctx_mp_python.cpython-311.pyc ADDED
Binary file (61.2 kB).
 
lib/python3.11/site-packages/mpmath/__pycache__/function_docs.cpython-311.pyc ADDED
Binary file (285 kB).