Initial model upload
This view is limited to 50 files because it contains too many changes.
- LICENSE +201 -0
- README.md +497 -3
- __pycache__/train.cpython-310.pyc +0 -0
- example/7B.yaml +37 -0
- finetune/__init__.py +0 -0
- finetune/__pycache__/__init__.cpython-310.pyc +0 -0
- finetune/__pycache__/__init__.cpython-38.pyc +0 -0
- finetune/__pycache__/args.cpython-310.pyc +0 -0
- finetune/__pycache__/args.cpython-38.pyc +0 -0
- finetune/__pycache__/checkpointing.cpython-310.pyc +0 -0
- finetune/__pycache__/checkpointing.cpython-38.pyc +0 -0
- finetune/__pycache__/distributed.cpython-310.pyc +0 -0
- finetune/__pycache__/distributed.cpython-38.pyc +0 -0
- finetune/__pycache__/eval.cpython-310.pyc +0 -0
- finetune/__pycache__/loss.cpython-310.pyc +0 -0
- finetune/__pycache__/mixed_precision.cpython-310.pyc +0 -0
- finetune/__pycache__/utils.cpython-310.pyc +0 -0
- finetune/__pycache__/wrapped_model.cpython-310.pyc +0 -0
- finetune/args.py +116 -0
- finetune/checkpointing.py +246 -0
- finetune/data/__init__.py +0 -0
- finetune/data/__pycache__/__init__.cpython-310.pyc +0 -0
- finetune/data/__pycache__/__init__.cpython-38.pyc +0 -0
- finetune/data/__pycache__/args.cpython-310.pyc +0 -0
- finetune/data/__pycache__/args.cpython-38.pyc +0 -0
- finetune/data/__pycache__/data_loader.cpython-310.pyc +0 -0
- finetune/data/__pycache__/dataset.cpython-310.pyc +0 -0
- finetune/data/__pycache__/dataset.cpython-38.pyc +0 -0
- finetune/data/__pycache__/exceptions.cpython-310.pyc +0 -0
- finetune/data/__pycache__/exceptions.cpython-38.pyc +0 -0
- finetune/data/__pycache__/tokenize.cpython-310.pyc +0 -0
- finetune/data/__pycache__/tokenize.cpython-38.pyc +0 -0
- finetune/data/args.py +61 -0
- finetune/data/data_loader.py +126 -0
- finetune/data/dataset.py +475 -0
- finetune/data/exceptions.py +56 -0
- finetune/data/tokenize.py +355 -0
- finetune/distributed.py +59 -0
- finetune/eval.py +77 -0
- finetune/loss.py +16 -0
- finetune/mixed_precision.py +47 -0
- finetune/monitoring/__init__.py +0 -0
- finetune/monitoring/__pycache__/__init__.cpython-310.pyc +0 -0
- finetune/monitoring/__pycache__/metrics_logger.cpython-310.pyc +0 -0
- finetune/monitoring/__pycache__/utils.cpython-310.pyc +0 -0
- finetune/monitoring/metrics_logger.py +226 -0
- finetune/monitoring/utils.py +34 -0
- finetune/utils.py +83 -0
- finetune/wrapped_model.py +227 -0
- huggingface.ipynb +40 -0
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md
CHANGED
@@ -1,3 +1,497 @@
# Mistral-finetune

<a target="_blank" href="https://colab.research.google.com/github/mistralai/mistral-finetune/blob/main/tutorials/mistral_finetune_7b.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

`mistral-finetune` is a light-weight codebase that enables memory-efficient and performant finetuning of Mistral's models.
It is based on [LoRA](https://arxiv.org/abs/2106.09685), a training paradigm where most weights are frozen and only 1-2% additional weights in the form of low-rank matrix perturbations are trained.

For maximum efficiency it is recommended to use an A100 or H100 GPU. The codebase is optimized
for multi-GPU-single-node training setups, but for smaller models, such as the 7B, a single GPU suffices.

> **Note**
>
> - The goal of this repository is to provide a simple, guided entrypoint to finetune Mistral models.
> As such, it is fairly opinionated (especially around data formatting) and does not aim at being exhaustive
> across multiple model architectures or hardware types.
> For more generic approaches, you can check out some other great projects like
> [torchtune](https://pytorch.org/torchtune/stable/overview.html).

## Installation

To get started with Mistral LoRA fine-tuning, follow these steps:

1. Clone this repository:
```
cd $HOME && git clone https://github.com/mistralai/mistral-finetune.git
```

2. Install all required dependencies:
```
cd mistral-finetune
pip install -r requirements.txt
```

## Model download

We recommend fine-tuning one of the official Mistral models, which you can download here:

| Model | Link | Checksum |
|----------------|---------------------------------------------------------------------------------------------------------|-----------------------------------|
| 7B Base V3 | [7B Base](https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar) | `0663b293810d7571dad25dae2f2a5806` |
| 7B Instruct v3 | [7B Instruct v3](https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-Instruct-v0.3.tar) | `80b71fcb6416085bcb4efad86dfb4d52` |
| 8x7B Base V1 | [8x7B Base](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) | (HF link) |
| 8x7B Instruct V1 | [8x7B Instruct](https://models.mistralcdn.com/mixtral-8x7b-v0-1/Mixtral-8x7B-v0.1-Instruct.tar) | `8e2d3930145dc43d3084396f49d38a3f` |
| 8x22B Instruct V3 | [8x22B Instruct](https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-Instruct-v0.3.tar) | `471a02a6902706a2f1e44a693813855b` |
| 8x22B Base V3 | [8x22B Base](https://models.mistralcdn.com/mixtral-8x22b-v0-3/mixtral-8x22B-v0.3.tar) | `a2fa75117174f87d1197e3a4eb50371a` |

**Important Notice**: For 8x7B Base V1 and 8x7B Instruct V1, it is necessary to use our v3 tokenizer and extend the vocabulary size to 32768 prior to fine-tuning. For detailed instructions on this process, please refer to the ["Model extension"](https://github.com/mistralai/mistral-finetune?tab=readme-ov-file#model-extension) section.

E.g., to download the 7B base model you can run the following command:
```sh
mkdir -p ${HOME}/mistral_models
cd ${HOME} && wget https://models.mistralcdn.com/mistral-7b-v0-3/mistral-7B-v0.3.tar
tar -xf mistral-7B-v0.3.tar -C mistral_models
```
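
After downloading, you can optionally check the archive against the checksum column of the table above. A minimal sketch, assuming the listed checksums are MD5 hex digests of the downloaded `.tar` archives:

```py
import hashlib
from pathlib import Path

def md5_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large model archives don't need to fit in memory."""
    h = hashlib.md5()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

archive = Path.home() / "mistral-7B-v0.3.tar"
expected = "0663b293810d7571dad25dae2f2a5806"  # 7B Base V3 checksum from the table above
assert md5_of(archive) == expected, "Checksum mismatch - consider re-downloading the archive."
```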

Make sure to modify your training script and add the path to the downloaded
folder as `model_id_or_path`.

E.g., modify [example/7B.yaml](https://github.com/mistralai/mistral-finetune/blob/main/example/7B.yaml) to include the absolute path to `$HOME/mistral_models/7B`:

```
model_id_or_path: "/Users/johndoe/mistral_models/7B"
```

## Prepare dataset

To ensure effective training, `mistral-finetune` has strict
requirements for how the training data must be formatted.

All data files must be stored in JSON Lines (`.jsonl`) format.

You can build two types of data files:

### _Pretrain_:

Pretrain data corresponds to plain text data stored in the `"text"` key. E.g.:

```jsonl
{"text": "Text contained in document n°1"}
{"text": "Text contained in document n°2"}
```
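
If your raw corpus lives in memory as a list of strings, a minimal sketch for producing such a file could look like the following (the filename and documents are illustrative):

```py
import json

documents = ["Text contained in document n°1", "Text contained in document n°2"]

with open("pretrain_data.jsonl", "w", encoding="utf-8") as f:
    for text in documents:
        # one JSON object per line, matching the format shown above
        f.write(json.dumps({"text": text}, ensure_ascii=False) + "\n")
```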

### _Instruct_:

Currently two different types of instruction following data are supported:

- _Instruct_: conversational data stored in the `"messages"` key in the form of a list. Each list item is a dictionary containing the `"content"` and `"role"` keys. `"role"` is one of "user", "assistant" or "system_prompt". The loss will only be computed if "role" == "assistant". An assistant message may additionally carry a `"weight"` key; setting `"weight": 0` excludes that message from the loss (see the second document in the example below and the sketch that follows it). E.g.:

```jsonl
{
  "messages": [
    {
      "role": "user",
      "content": "User interaction n°1 contained in document n°1"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°1 contained in document n°1"
    },
    {
      "role": "user",
      "content": "User interaction n°2 contained in document n°1"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°2 contained in document n°1"
    }
  ]
}
{
  "messages": [
    {
      "role": "user",
      "content": "User interaction n°1 contained in document n°2"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°1 contained in document n°2"
    },
    {
      "role": "user",
      "content": "User interaction n°2 contained in document n°2"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°2 contained in document n°2",
      "weight": 0, # don't train on n°2
    },
    {
      "role": "user",
      "content": "User interaction n°3 contained in document n°2"
    },
    {
      "role": "assistant",
      "content": "Bot interaction n°3 contained in document n°2"
    }
  ]
}
```
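
To make the loss rule above concrete, here is a small, hypothetical sketch (not the actual training code of `mistral-finetune`) that extracts the messages of a single `.jsonl` record that would contribute to the loss:

```py
import json

def trainable_messages(jsonl_line: str) -> list:
    """Keep only assistant messages, skipping those marked with "weight": 0."""
    record = json.loads(jsonl_line)
    return [
        m["content"]
        for m in record["messages"]
        if m["role"] == "assistant" and m.get("weight", 1) != 0
    ]

line = '{"messages": [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!", "weight": 0}, {"role": "user", "content": "Ping"}, {"role": "assistant", "content": "Pong"}]}'
print(trainable_messages(line))  # ['Pong']
```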

- _Function calling_: conversational data stored in the `"messages"` key in the form of a list. Each list item is a dictionary containing the `"role"` and `"content"` or `"tool_calls"` keys. `"role"` is one of "user", "assistant", "system_prompt", or "tool". The loss will only be computed if "role" == "assistant".

**Note**: In function calling the `"id"` of `"tool_calls"` and the `"tool_call_id"` are randomly generated strings of exactly 9 chars. We recommend generating these automatically
in a data preparation script as is done [here](https://github.com/mistralai/mistral-finetune/blob/208b25c0f7299bb78d06cea25b82adee03834319/utils/reformat_data_glaive.py#L74); a small sketch is also shown after the example below.

E.g.:

```jsonl
{
  "messages": [
    {
      "role": "system",
      "content": "You are an helpful assistant who has access to the following functions to help the user, you can use the functions if needed"
    },
    {
      "role": "user",
      "content": "Can you help me generate an anagram of the word \"listen\"?"
    },
    {
      "role": "assistant",
      "tool_calls": [
        {
          "id": "TX92Jm8Zi",
          "type": "function",
          "function": {
            "name": "generate_anagram",
            "arguments": "{\"word\": \"listen\"}"
          }
        }
      ]
    },
    {
      "role": "tool",
      "content": "{\"anagram\": \"silent\"}",
      "tool_call_id": "TX92Jm8Zi"
    },
    {
      "role": "assistant",
      "content": "The anagram of the word \"listen\" is \"silent\"."
    },
    {
      "role": "user",
      "content": "That's amazing! Can you generate an anagram for the word \"race\"?"
    },
    {
      "role": "assistant",
      "tool_calls": [
        {
          "id": "3XhQnxLsT",
          "type": "function",
          "function": {
            "name": "generate_anagram",
            "arguments": "{\"word\": \"race\"}"
          }
        }
      ]
    }
  ],
  "tools": [
    {
      "type": "function",
      "function": {
        "name": "generate_anagram",
        "description": "Generate an anagram of a given word",
        "parameters": {
          "type": "object",
          "properties": {
            "word": {
              "type": "string",
              "description": "The word to generate an anagram of"
            }
          },
          "required": [
            "word"
          ]
        }
      }
    }
  ]
}
```
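
As referenced in the note above, here is a minimal sketch of how such 9-character ids could be generated in your own data-preparation script (the helper name is illustrative; the linked `reformat_data_glaive.py` does something similar):

```py
import random
import string

def generate_tool_call_id(length: int = 9) -> str:
    """Return a random alphanumeric id of exactly `length` characters."""
    alphabet = string.ascii_letters + string.digits
    return "".join(random.choices(alphabet, k=length))

print(generate_tool_call_id())  # e.g. an id in the style of 'TX92Jm8Zi'
```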

## Verify dataset

Before starting a training run you should verify that your dataset is correctly formatted and get an
estimation of the training time. You can do so by using the [./utils/validate_data](https://github.com/mistralai/mistral-finetune/blob/main/utils/validate_data.py) script.

Note that this step is crucial to ensure that the data is correctly formatted.

### Instruction following

Let's go over a simple example to train a model on instruction following:

- 1. **Load a chunk of [UltraChat_200k](https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k)**

Create the data folder and navigate to the folder.
```sh
cd $HOME && mkdir -p data && cd $HOME/data
```

Load the data into a Pandas Dataframe.

**Note**: Make sure to have pandas and pyarrow installed (`pip install pandas pyarrow`).

```py
import pandas as pd

df = pd.read_parquet('https://huggingface.co/datasets/HuggingFaceH4/ultrachat_200k/resolve/main/data/test_gen-00000-of-00001-3d4cd8309148a71f.parquet')
```
- 2. Split into train and eval

```py
df_train=df.sample(frac=0.95,random_state=200)
df_eval=df.drop(df_train.index)
```

- 3. Save data to jsonl

```py
df_train.to_json("ultrachat_chunk_train.jsonl", orient="records", lines=True)
df_eval.to_json("ultrachat_chunk_eval.jsonl", orient="records", lines=True)
```

- 4. Modify your training yaml to include the ultrachat dataset and verify the yaml

Modify [example/7B.yaml](https://github.com/mistralai/mistral-finetune/blob/main/example/7B.yaml) to include the absolute path to `$HOME/data/ultrachat_chunk_train.jsonl` as well as a dataset mixing weight for training and `$HOME/data/ultrachat_chunk_eval.jsonl` for eval, *e.g.*

```
data:
  instruct_data: "/Users/johndoe/data/ultrachat_chunk_train.jsonl"
  eval_instruct_data: "/Users/johndoe/data/ultrachat_chunk_eval.jsonl"
```

Now you can verify your training yaml to make sure the data is correctly formatted and to get an estimate of your training time.

```
cd $HOME/mistral-finetune
python -m utils.validate_data --train_yaml example/7B.yaml
```

Upon completion you should see an error report with many of the following errors:

```
The data in line 1412 of dataset /Users/johndoe/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user
The data in line 1413 of dataset /Users/johndoe/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user
The data in line 1414 of dataset /Users/johndoe/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user
The data in line 1415 of dataset /Users/johndoe/data/ultrachat_chunk_eval.jsonl is incorrectly formated.Expected last role to be one of: [assistant] but got user
```

Many conversations seem to end with the 'user' role, which is unnecessary since we only train on 'assistant' messages; keeping them would only add data that is never trained on.

You can make use of [./utils/reformat_data.py](https://github.com/mistralai/mistral-finetune/blob/main/utils/reformat_data.py) to correct the data:

```
cd $HOME/mistral-finetune
python -m utils.reformat_data $HOME/data/ultrachat_chunk_train.jsonl
python -m utils.reformat_data $HOME/data/ultrachat_chunk_eval.jsonl
```

You should see that a couple of samples will be skipped.

- 5. Potentially change number of training steps

Upon correction of the dataset, run the script again:

```
cd $HOME/mistral-finetune
python -m utils.validate_data --train_yaml example/7B.yaml
```

You should get a summary of the data input and training parameters:

```
Train States
 --------------------
{
   "expected": {
       "eta": "00:52:44",
       "data_tokens": 25169147,
       "train_tokens": 131072000,
       "epochs": "5.21",
       "max_steps": 500,
       "data_tokens_per_dataset": {
           "/Users/johndoe/data/ultrachat_chunk_train.jsonl": "25169147.0"
       },
       "train_tokens_per_dataset": {
           "/Users/johndoe/data/ultrachat_chunk_train.jsonl": "131072000.0"
       },
       "epochs_per_dataset": {
           "/Users/johndoe/data/ultrachat_chunk_train.jsonl": "5.2"
       }
   },
}
```

Having `max_steps` set to 500 would lead to iterating through the dataset roughly 5 times, which is reasonable but might
be a bit too much. A recommended setting (`max_steps: 300`, see the [Start training](#start-training) section below) would only take about 30 min on an 8xH100 cluster.

### Function calling

Next let's go over a more advanced use case to fine-tune a model on function calling.
Function calling requires the data to be in the format as [explained above](#instruct). Let's go over an example.

- 1. **Load a chat-formatted version of the [Glaive function calling dataset](https://huggingface.co/datasets/Locutusque/function-calling-chatml)**

Create the data folder and navigate to the folder.
```sh
cd $HOME && mkdir -p data && cd $HOME/data
```

Load the data into a Pandas Dataframe.

**Note**: Make sure to have pandas and pyarrow installed (`pip install pandas pyarrow`).

```py
import pandas as pd

df = pd.read_parquet('https://huggingface.co/datasets/Locutusque/function-calling-chatml/resolve/main/data/train-00000-of-00001-f0b56c6983b4a78f.parquet')
```
- 2. Split into train and eval

```py
df_train=df.sample(frac=0.95,random_state=200)
df_eval=df.drop(df_train.index)
```

- 3. Save data to jsonl

```py
df_train.to_json("glaive_train.jsonl", orient="records", lines=True)
df_eval.to_json("glaive_eval.jsonl", orient="records", lines=True)
```

- 4. Reformat dataset

As one can see, the dataset does not follow the required function calling format, so it will need to be reformatted. Among other things, `"from"` should be renamed to `"user"` and superfluous `"\n"` characters should be removed.
For this dataset you can make use of [`./utils/reformat_data_glaive.py`](https://github.com/mistralai/mistral-finetune/blob/main/utils/reformat_data_glaive.py):

```
cd $HOME/mistral-finetune
python -m utils.reformat_data_glaive $HOME/data/glaive_train.jsonl
python -m utils.reformat_data_glaive $HOME/data/glaive_eval.jsonl
```

Running this command will make sure that most samples are in the correct format.

**Note**: It is impossible to write reformatting scripts that work for all kinds of datasets.
If you have datasets that don't yet follow the required format above, you will most probably have to
create a reformatting script yourself (mistral-chat or chat-gpt is your best friend here!).

- 5. Validate dataset

You can now validate the dataset by setting `data.instruct_data` and `data.eval_instruct_data` to
`$HOME/data/glaive_train.jsonl` and `$HOME/data/glaive_eval.jsonl` in `example/7B.yaml` respectively.

The reformatted datasets still have some errors, which can be removed with `--create_corrected`. For this, make sure to add
`--create_corrected` as follows:

```
cd $HOME/mistral-finetune
python -m utils.validate_data --train_yaml example/7B.yaml --create_corrected
```

Running this command will show a couple of errors and save two new datasets `$HOME/data/glaive_train.jsonl.corrected` and `$HOME/data/glaive_eval.jsonl.corrected`. Make sure to use these two datasets in `example/7B.yaml` and run the command again. Now the dataset should be correctly formatted!

## Start training

Having followed the [dataset verification section](#verify-dataset), we can now start training.
For faster training, we recommend setting `max_steps` to only 300. Make sure to define `run_dir` as your experiment folder and optionally set `wandb.project` to a Weights & Biases project for logging, *e.g.*:
```
max_steps: 300
run_dir: "/Users/johndoe/ultra_chat_test"
wandb.project: ultra_chat
```

Optionally you can also set other `wandb` parameters, such as `wandb.key` (see the configuration section below).

Save the training configuration and start training! Make sure to set `--nproc-per-node` to the number of available GPUs.

```
cd $HOME/mistral-finetune
torchrun --nproc-per-node 8 --master_port $RANDOM -m train example/7B.yaml
```

Training on ultra-chat should take around 30 min on an 8xH100 node and the resulting weights should give an MT Bench score around 6.3.

Training on glaive should take around 1 h on an 8xH100 node and the resulting weights should work nicely for function calling.

## Customizing training configuration

The example configuration [example/7B.yaml](https://github.com/mistralai/mistral-finetune/blob/main/example/7B.yaml) defines reasonable parameters for learning rate, weight decay, etc., but you are advised to
customize these settings for your use case.

Generally, a training configuration should fill the following parameters:

- `model_id_or_path` defines the model to start training from. This can be a path to a pre-trained model or a local model directory.
- `run_dir` defines the directory where training checkpoints and metrics are stored.
- `seq_len` defines the sequence length for training. This is the maximum length of input sequences the model will process. Samples are packed to reach a length of `seq_len` for maximum training efficiency.
- `batch_size` defines the number of training examples used per GPU. **Note**: The overall effective batch size (in tokens) across all GPUs equals `num_gpus` x `batch_size` x `seq_len`.
- `max_steps` defines the maximum number of training steps. This is the total number of iterations the training process will run. It can be adjusted based on the specific needs of your training scenario. The total number of tokens seen during training is `max_steps` x `num_gpus` x `batch_size` x `seq_len`; see the worked example below.
- `optim.lr` defines the learning rate. This is the initial learning rate for the optimizer.
- `optim.weight_decay` defines weight decay. Weight decay is a regularization technique used to prevent overfitting by penalizing large weights. We recommend leaving it at 0.1.
- `optim.pct_start` defines the percentage of the total training steps used for the learning rate warm-up phase before it starts to decrease. It corresponds to `pct_start` of PyTorch's OneCycleLR.
- `lora.rank` defines the size of the LoRA (Low-Rank Adaptation) adapters. We recommend 64 or less, which adjusts the rank of the low-rank decomposition used in LoRA.
- `seed` defines the random seed for initialization and data shuffling/sampling. Setting a seed ensures reproducibility of results.
- `log_freq` defines the logging frequency. This specifies how often (in steps) to log training metrics.
- `data.instruct_data` is the path to the instruction data used for training. This field has to be filled with one or multiple data sources in the format as explained above. Each data source should either be a path to a jsonl file or a path to a directory containing jsonl files, followed by a weighting to define the importance of this dataset: `<path/to/data_source>:<weight>`. E.g.: `data.instruct_data: "/path/to/data1.jsonl:5.,/path/to/data2.jsonl:1.,/path/to/dir_of_jsonls:1."`
- `data.data` is an optional path to additional pretraining data in the format as explained above. Note that this field can be left blank.
- `data.eval_instruct_data` is an optional path to evaluation instruction data to run cross-validation every `eval_freq` steps. Cross-validation metrics are displayed as `loss` and `perplexity`.
- `eval_freq` defines how often (in steps) to evaluate the model. This specifies the interval at which the model is evaluated on the validation set.
- `no_eval` is a flag to enable or disable intermediate evaluation. Setting it to False enables periodic evaluation during training.
- `ckpt_freq` defines how often (in steps) to save checkpoints. This specifies the interval at which the model's state is saved.
- `ckpt_only_lora` defines whether to only save the trained LoRA checkpoints or whether the trained LoRA should directly be merged into the base model and saved. **Note**: When setting `ckpt_only_lora=False` make sure that you have enough CPU and GPU memory to save the full model on a single process (this is usually only possible for the 7B model).
- `wandb.key` is used to pass your Weights & Biases (wandb) API key for logging. This allows you to log training metrics to the wandb dashboard.
- `wandb.project` defines the wandb project name. This is where the training run will be logged in the wandb interface.

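As a worked example of the token arithmetic above, the following values are chosen so that the result matches the `train_tokens` figure reported by `validate_data` earlier in this README (the GPU count of 8 is an assumption matching the 8xH100 node used throughout):

```py
num_gpus = 8        # GPUs on the node (assumed 8xH100)
batch_size = 1      # training examples per GPU
seq_len = 32768     # tokens per example (samples are packed up to seq_len)
max_steps = 500

tokens_per_step = num_gpus * batch_size * seq_len   # effective batch size in tokens
total_train_tokens = max_steps * tokens_per_step

print(f"{tokens_per_step:,} tokens per step")       # 262,144
print(f"{total_train_tokens:,} training tokens")    # 131,072,000
```
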
## Inference

Once your model is trained, you should try it out in inference. We recommend using [mistral-inference](https://github.com/mistralai/mistral-inference).

Make sure to have `mistral_inference` correctly installed:
```
pip install mistral_inference
```

Assuming your `lora.safetensors` is saved under `$HOME/ultra_chat_test/checkpoints/checkpoint_000300/consolidated/lora.safetensors`, you can chat with the model using `mistral_inference`, *e.g.*:

```sh
mistral-chat $HOME/mistral_models/7B --max_tokens 256 --temperature 1.0 --instruct --lora_path $HOME/ultra_chat_test/checkpoints/checkpoint_000300/consolidated/lora.safetensors
```

## Model extension

**Important**: Note that one can only fine-tune Mistral models that are compatible with the v3 tokenizer, which entails that the models have a vocabulary size of 32768 - not 32000. One can however easily extend older versions with a vocabulary size of 32000 to a vocabulary size of 32768 by using:
```
python -m utils.extend_model_vocab --original_model_ckpt /folder/to/old/model --extended_model_ckpt /folder/to/extended/model
```

Once the extension has worked, one can fine-tune using the newly created model checkpoint in `/folder/to/extended/model`.

## FAQ:

> - What's the best practice of fine-tuning MoEs?

We see a higher degree of performance variance when fine-tuning MoE models. It's not unusual to find that fine-tuning MoE models with different seeds can lead to a high variance in performance. We did not observe such a high variance with dense models. Therefore, we suggest running multiple instances of the same fine-tuning process on MoE models and selecting the one that performs best.

> - How can I determine the number of tokens used during the model training process?

You can use the following script to find out: https://github.com/mistralai/mistral-finetune/blob/main/utils/validate_data.py. This script accepts a .yaml training file as input and returns the number of tokens the model is being trained on.

> - What should I do if I encounter a CUDA out-of-memory error?

One possible solution is to reduce the batch size per GPU. The per-GPU batch size in tokens is equal to `seq_len` x `batch_size`. Try setting `batch_size` to 1 and reducing `seq_len`. You can define `batch_size` and `seq_len` in the .yaml file.
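
For intuition, here is a quick calculation of how much those two knobs shrink the per-GPU token count per step; the "current" values are the ones from `example/7B.yaml`, the "reduced" ones are a hypothetical OOM-friendly setting (actual memory usage also depends on the model size and on gradient checkpointing):

```py
def tokens_per_gpu_step(seq_len: int, batch_size: int) -> int:
    return seq_len * batch_size

current = tokens_per_gpu_step(seq_len=32768, batch_size=2)  # values from example/7B.yaml
reduced = tokens_per_gpu_step(seq_len=8192, batch_size=1)   # hypothetical reduced setting

print(current, reduced)                                 # 65536 8192
print(f"{current // reduced}x fewer tokens per step")   # 8x fewer tokens per step
```
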
__pycache__/train.cpython-310.pyc
ADDED
Binary file (6.27 kB)
example/7B.yaml
ADDED
@@ -0,0 +1,37 @@
# data
data:
  instruct_data: "/root/data/mol_instructions_train.jsonl"  # Fill this with the path to your training data
  data: ""  # Optionally fill with pretraining data
  eval_instruct_data: ""  # Optionally fill with evaluation data

# model
model_id_or_path: "/root/mistral_models/7B-v0.3"  # Path to downloaded model
lora:
  rank: 64

# optim
seq_len: 32768
batch_size: 2
#TODO try other values
max_steps: 500
optim:
  lr: 5.e-5
  weight_decay: 0.05
  pct_start: 0.05

# other
seed: 99
log_freq: 1
eval_freq: 100
no_eval: True
ckpt_freq: 100

ckpt_only_lora: False  # Save only trained LoRA adapters. Set to `False` to merge LoRA adapter into the base model and save full fine-tuned model

run_dir: "/root/mistral-finetune/runseed99"

wandb:
  project: "CHEMISTral7b-ft"
  offline: False  # Set to True if you want to use wandb in offline mode
  key: "aaf77f83a4e316f6a8b47fa975ab6b5e73c7c8df"  # Optionally set your WandB API key
  run_name: "runseed99"  # Optionally name your WandB run
finetune/__init__.py
ADDED
File without changes
finetune/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (136 Bytes)
finetune/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (134 Bytes)
finetune/__pycache__/args.cpython-310.pyc
ADDED
Binary file (3.81 kB)
finetune/__pycache__/args.cpython-38.pyc
ADDED
Binary file (3.79 kB)
finetune/__pycache__/checkpointing.cpython-310.pyc
ADDED
Binary file (8.73 kB)
finetune/__pycache__/checkpointing.cpython-38.pyc
ADDED
Binary file (8.67 kB)
finetune/__pycache__/distributed.cpython-310.pyc
ADDED
Binary file (2.02 kB)
finetune/__pycache__/distributed.cpython-38.pyc
ADDED
Binary file (2.05 kB)
finetune/__pycache__/eval.cpython-310.pyc
ADDED
Binary file (2.24 kB)
finetune/__pycache__/loss.cpython-310.pyc
ADDED
Binary file (569 Bytes)
finetune/__pycache__/mixed_precision.cpython-310.pyc
ADDED
Binary file (1.6 kB)
finetune/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (2.94 kB)
finetune/__pycache__/wrapped_model.cpython-310.pyc
ADDED
Binary file (7.49 kB)
finetune/args.py
ADDED
@@ -0,0 +1,116 @@
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

from simple_parsing.helpers import Serializable

from model.args import LoraArgs

from .data.args import DataArgs


@dataclass
class OptimArgs(Serializable):
    lr: float = 3e-4
    weight_decay: float = 0.1
    pct_start: float = 0.3


@dataclass
class WandbArgs(Serializable):
    project: Optional[str] = None  # Fill this argument to use wandb.
    offline: bool = False
    key: Optional[str] = None
    run_name: Optional[str] = None

    def __post_init__(self) -> None:
        if self.project is not None:
            try:
                import wandb  # noqa: F401
            except ImportError:
                raise ImportError("`wandb` not installed. Either make sure `wandb` is installed or set `wandb:project` to None.")

            if len(self.project) == 0:
                raise ValueError("`wandb.project` must not be an empty string.")

@dataclass
class MLFlowArgs(Serializable):
    tracking_uri: Optional[str] = None
    experiment_name: Optional[str] = None

    def __post_init__(self) -> None:
        if self.tracking_uri is not None:
            try:
                import mlflow  # noqa: F401
            except ImportError:
                raise ImportError("`mlflow` not installed. Either make sure `mlflow` is installed or set `mlflow.tracking_uri` to None.")

            if self.experiment_name is None:
                raise ValueError("If `mlflow.tracking_uri` is set, `mlflow.experiment_name` must be set as well.")



@dataclass
class TrainArgs(Serializable):
    data: DataArgs

    # if specified, instruct_tokenizer and model will be loaded
    model_id_or_path: str  # Path to the directory containing the initial model or model id: "mistral-small"

    run_dir: str  # Path to the directory where everything will be saved. It needs to be empty.
    # Name of the wandb run, if None it will be set to the name of the run_dir.

    optim: OptimArgs = field(default_factory=OptimArgs)
    seed: int = 0
    # Number of steps to accumulate gradients before calling doing an optimizer step.
    num_microbatches: int = 1

    seq_len: int = 2048  # Number of tokens per batch per device.
    batch_size: int = 1
    max_norm: float = 1.0  # Gradient clipping.
    max_steps: int = 100  # Number of training steps.
    log_freq: int = 1  # Number of steps between each logging.

    # Number of steps between each checkpoint saving. If inferior to 1, only the last checkpoint will be saved.
    ckpt_freq: int = 0
    ckpt_only_lora: bool = True
    # If True, no checkpoint will be saved. This is useful for development.
    no_ckpt: bool = False
    num_ckpt_keep: Optional[int] = 3
    eval_freq: int = 0
    no_eval: bool = True

    # Efficiency
    # Determines whether gradient checkpointing should be utilized or not during the training process. Gradient checkpointing can be beneficial in reducing memory usage at the cost of slightly longer training times.
    checkpoint: bool = True

    world_size: Optional[int] = field(init=False, default=None)

    # logging
    wandb: WandbArgs = field(default_factory=WandbArgs)
    mlflow: MLFlowArgs = field(default_factory=MLFlowArgs)

    # LoRA
    lora: Optional[LoraArgs] = field(default_factory=LoraArgs)

    def __post_init__(self) -> None:
        assert getattr(self, "world_size", None) is None
        self.world_size = int(os.environ.get("WORLD_SIZE", -1))

        if self.wandb.offline:
            command = f"cd {self.run_dir}; wandb sync --sync-all"
            logging.info(f"to sync wandb offline, run: {command}")

        assert self.num_microbatches >= 1

        assert self.num_ckpt_keep is None or self.num_ckpt_keep >= 1

        if self.model_id_or_path is not None:
            Path(self.model_id_or_path).exists()

        if not self.ckpt_only_lora:
            logging.warning(
                "You are have disabled `ckpt_only_lora` and are thus merging the trained LoRA checkpoint into the base model upon checkpointing. This might lead to OOM erros - make sure you have enough CPU and GPU memory."
            )
finetune/checkpointing.py
ADDED
@@ -0,0 +1,246 @@
import json
import logging
import shutil
from pathlib import Path
from typing import Dict, List, Optional, Union

import safetensors.torch
import torch
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase
from torch.distributed import barrier
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from model.transformer import LoRALinear

from .distributed import get_rank, get_world_size
from .utils import TrainState

logger = logging.getLogger("checkpointing")


def main_logger_info(message: str) -> None:
    if get_rank() == 0:
        logger.info(message)


class Checkpointer:
    """A class to save PyTorch model and optimizer states"""

    def __init__(
        self,
        model: FullyShardedDataParallel,
        state: TrainState,
        run_dir: Union[Path, str],
        optimizer: Optional[torch.optim.Optimizer] = None,
        num_ckpt_keep: Optional[int] = None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.state = state
        self.run_dir = Path(run_dir)
        self.rank = get_rank()
        self.num_ckpt_keep = num_ckpt_keep

    @property
    def ckpt_dir(self) -> Path:
        return self.run_dir / "checkpoints"

    @property
    def dst_dir(self) -> Path:
        return self.ckpt_dir / f"checkpoint_{self.state.step:06d}" / "consolidated"

    @staticmethod
    def consolidated_path(
        ckpt_dir: Path, use_safetensors: bool, save_only_lora: Optional[bool] = False
    ) -> Path:
        suffix = "safetensors" if use_safetensors else "00.pth"
        prefix = "lora" if save_only_lora else "consolidated"

        return ckpt_dir / f"{prefix}.{suffix}"

    @staticmethod
    def _tmp(ckpt_dir: Path) -> Path:
        return ckpt_dir.with_name(f"tmp.{ckpt_dir.name}")

    def write_params_info(self, tmp_dst: Path):
        params_path = tmp_dst / "params.json"
        with open(params_path, "w") as f:
            model_args = self.model.args.to_dict()

            f.write(json.dumps(model_args, indent=4))

    def delete_old_ckpts(self) -> List[Path]:
        all_saved_ckpts = [d for d in self.ckpt_dir.iterdir() if d.is_dir()]

        # Sort directories by creation time (oldest to newest)
        all_saved_ckpts.sort(key=lambda x: x.stat().st_ctime, reverse=True)

        ckpts_to_delete = all_saved_ckpts[self.num_ckpt_keep :]

        for ckpt_to_delete in ckpts_to_delete:
            try:
                shutil.rmtree(ckpt_to_delete)
                main_logger_info(f"Deleted ckpt: {ckpt_to_delete}")
            except OSError as e:
                main_logger_info(f"Error deleting directory {ckpt_to_delete}: {e}")

        return ckpts_to_delete

    @staticmethod
    def get_lora_states(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        return {k: v for k, v in state_dict.items() if "lora" in k}

    @staticmethod
    def get_non_lora_states(
        state_dict: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        return {
            k: v
            for k, v in state_dict.items()
            if not any(l_key in k for l_key in ["lora", "frozen"])
        }

    @torch.no_grad()
    def retrieve_save_states(
        self, save_only_lora: bool, save_dtype: torch.dtype
    ) -> Dict[str, torch.Tensor]:
        if save_only_lora:
            assert (
                self.model.args.lora.enable
            ), "Cannot save LoRA checkpoint as LoRA training is not enabled."

        # remove all potential hooks
        for module in self.model.modules():
            if isinstance(module, LoRALinear) and hasattr(module, "_merge_lora_handle"):
                module._merge_lora_handle.remove()  # type: ignore

        # merge weights if we don't just save LoRA
        if not save_only_lora:

            def merge_lora(
                m: torch.nn.Module,
                destination: Dict[str, torch.Tensor],
                prefix: str,
                *args,
            ):
                weight = m.merge_weight()  # type: ignore
                destination[prefix + "weight"] = weight

            for module in self.model.modules():
                if isinstance(module, LoRALinear):
                    module._merge_lora_handle = module._register_state_dict_hook(
                        merge_lora
                    )

        offload_to_cpu = get_world_size() > 1
        if save_only_lora:

            def is_trainable_fsdp(
                module: Union[torch.nn.Module, FullyShardedDataParallel]
            ):
                is_fsdp = isinstance(module, FullyShardedDataParallel)
                all_params_have_grads = is_fsdp and all(
                    p.requires_grad is True for p in module.parameters()
                )

                # need to make sure only lowest fsdp wrap is used
                is_leaf_node = is_fsdp and len(list(module.module.children())) == 0  # type: ignore

                return is_fsdp and all_params_have_grads and is_leaf_node

            # extract all modules with only trainable weights
            modules = {
                k: m for k, m in self.model.named_modules() if is_trainable_fsdp(m)
            }

            states = {}
            for key, module in modules.items():
                assert isinstance(
                    module, FullyShardedDataParallel
                ), "`module` should be an instance of `FullyShardedDataParallel`"
                parent_prefix = key.replace("_fsdp_wrapped_module.", "").replace(
                    "_checkpoint_wrapped_module.", ""
                )
                with module.summon_full_params(
                    module, writeback=True, offload_to_cpu=offload_to_cpu
                ):
                    states.update(
                        {
                            f"{parent_prefix}.{k}": v.to(dtype=save_dtype)
                            for k, v in module.state_dict().items()
                        }
                    )
        else:
            # make sure you have enough CPU RAM available to save the full model
            assert isinstance(
                self.model, FullyShardedDataParallel
            ), "`self.model` should be an instance of `FullyShardedDataParallel`"
            with self.model.summon_full_params(
                self.model, writeback=True, offload_to_cpu=offload_to_cpu
            ):
                states = self.get_non_lora_states(self.model.state_dict())
                states = {k: v.to(dtype=save_dtype) for k, v in states.items()}

        states = dict(sorted(states.items()))
        return states

    @staticmethod
    def save_tokenizer(instruct_tokenizer: InstructTokenizerBase, tmp_dst: Path):
        serialized_spm = instruct_tokenizer.tokenizer._model.serialized_model_proto()  # type: ignore

        tokenizer_path = tmp_dst / "tokenizer.model.v3"

        with open(tokenizer_path, "wb") as f:
            f.write(serialized_spm)

    @torch.no_grad()
    def save_checkpoint(
        self,
        save_only_lora: bool,
        dtype: torch.dtype = torch.float16,
        instruct_tokenizer: Optional[InstructTokenizerBase] = None,
    ):
        tmp_dst = self._tmp(self.dst_dir)
        main_logger_info(
            f"Dumping checkpoint in {self.dst_dir} using tmp name: {tmp_dst.name}"
        )

        assert not self.dst_dir.exists(), f"dst exists {self.dst_dir}"
        tmp_dst.mkdir(parents=True, exist_ok=True)

        states: Dict[str, torch.Tensor] = self.retrieve_save_states(
|
212 |
+
save_only_lora, dtype
|
213 |
+
)
|
214 |
+
|
215 |
+
barrier()
|
216 |
+
|
217 |
+
if self.rank == 0:
|
218 |
+
# save checkpoint in tmp path
|
219 |
+
safetensors.torch.save_file(
|
220 |
+
states,
|
221 |
+
self.consolidated_path(
|
222 |
+
tmp_dst, use_safetensors=True, save_only_lora=save_only_lora
|
223 |
+
), # always use safetensors for checkpointing
|
224 |
+
)
|
225 |
+
|
226 |
+
self.write_params_info(tmp_dst)
|
227 |
+
|
228 |
+
# save tokenizer
|
229 |
+
if instruct_tokenizer is not None:
|
230 |
+
self.save_tokenizer(instruct_tokenizer, tmp_dst)
|
231 |
+
|
232 |
+
assert not self.dst_dir.exists(), f"should not happen! {self.dst_dir}"
|
233 |
+
tmp_dst.rename(self.dst_dir)
|
234 |
+
|
235 |
+
logger.info(
|
236 |
+
f"Done dumping checkpoint in {self.dst_dir} for step: {self.state.step}"
|
237 |
+
)
|
238 |
+
|
239 |
+
# delete last n checkpoints
|
240 |
+
if self.num_ckpt_keep is not None:
|
241 |
+
ckpts_to_delete = self.delete_old_ckpts()
|
242 |
+
logger.info(
|
243 |
+
f"Done deleting checkpoints {', '.join([str(c) for c in ckpts_to_delete])}"
|
244 |
+
)
|
245 |
+
|
246 |
+
main_logger_info("Done!")
|
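For orientation, here is a minimal sketch (run directory, step and flags are hypothetical, not from the repository) of the on-disk naming that `dst_dir` and `consolidated_path` produce:

# Illustration only: how checkpoint paths are composed (hypothetical run_dir and step).
from pathlib import Path

run_dir = Path("/tmp/run")
step = 300
ckpt_dir = run_dir / "checkpoints" / f"checkpoint_{step:06d}" / "consolidated"

# mirrors Checkpointer.consolidated_path(...)
use_safetensors, save_only_lora = True, True
suffix = "safetensors" if use_safetensors else "00.pth"
prefix = "lora" if save_only_lora else "consolidated"
print(ckpt_dir / f"{prefix}.{suffix}")  # /tmp/run/checkpoints/checkpoint_000300/consolidated/lora.safetensors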
finetune/data/__init__.py
ADDED
File without changes
finetune/data/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (141 Bytes)
finetune/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (139 Bytes)
finetune/data/__pycache__/args.cpython-310.pyc
ADDED
Binary file (1.34 kB)
finetune/data/__pycache__/args.cpython-38.pyc
ADDED
Binary file (1.33 kB)
finetune/data/__pycache__/data_loader.cpython-310.pyc
ADDED
Binary file (4.26 kB)
finetune/data/__pycache__/dataset.cpython-310.pyc
ADDED
Binary file (11.1 kB)
finetune/data/__pycache__/dataset.cpython-38.pyc
ADDED
Binary file (11.1 kB)
finetune/data/__pycache__/exceptions.cpython-310.pyc
ADDED
Binary file (2.57 kB)
finetune/data/__pycache__/exceptions.cpython-38.pyc
ADDED
Binary file (2.91 kB)
finetune/data/__pycache__/tokenize.cpython-310.pyc
ADDED
Binary file (10.2 kB)
finetune/data/__pycache__/tokenize.cpython-38.pyc
ADDED
Binary file (10.3 kB)
finetune/data/args.py
ADDED
@@ -0,0 +1,61 @@
import logging
from dataclasses import dataclass, field

from simple_parsing.helpers import Serializable

logger = logging.getLogger("data")


@dataclass()
class InstructArgs(Serializable):
    shuffle: bool = True

    # For function calling training examples, only the last tool call
    # of the assistant message can be used for training. Therefore,
    # we chunk longer function calling conversations into multiple
    # training samples so that no data is lost. E.g.:
    # [[
    #   UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1,
    #   UserMessage_2, AssistantToolCallMessage_2, ToolMessage_2, AssistantMessage_2
    # ]]
    # => is chunked into two training samples:
    # [[
    #   UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1
    # ],
    # [
    #   UserMessage_1, AssistantToolCallMessage_1, ToolMessage_1, AssistantMessage_1,
    #   UserMessage_2, AssistantToolCallMessage_2, ToolMessage_2, AssistantMessage_2
    # ]]
    # NOTE: Only if your data is already pre-chunked should this argument be set to False.
    dynamic_chunk_fn_call: bool = True


@dataclass()
class DataArgs(Serializable):
    # The data arguments `data` and `instruct_data` are strings in the format
    # "data_source_dir_1:weight_1,data_source_dir_2:weight_2,...". The weights
    # are used to sample the data sources. If the weights over the concatenation
    # of `data` and `instruct_data` do not sum to 1, they are normalized.
    # The data source folders must contain jsonl files. If the value is an
    # empty string, no data is used for the corresponding data type.
    data: str = (
        ""  # Each line in the jsonl files inside the data source directories must be a dictionary with a "text" key. See README for more details. Can be left empty.
    )
    shuffle: bool = False
    instruct_data: str = (
        ""  # Each line in the jsonl files inside the data source directories must be a dictionary with an "interactions" key. See README for more details. Can be left empty.
    )
    eval_instruct_data: str = (
        ""  # Each line in the jsonl files inside the data source directories must be a dictionary with an "interactions" key. See README for more details. Can be left empty.
    )
    instruct: InstructArgs = field(default_factory=InstructArgs)

    def __post_init__(self) -> None:
        if (
            self.instruct.shuffle is False
            and self.instruct.dynamic_chunk_fn_call is True
        ):
            raise ValueError(
                "Make sure to either enable `data.instruct.shuffle=True` or `data.instruct.dynamic_chunk_fn_call=False`. Dynamic chunking is only possible if data is loaded and shuffled before training."
            )
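A minimal usage sketch of `DataArgs` showing the weighted data-source string format described in the comments above; the folder paths and weights are hypothetical:

# Hypothetical example: two instruct data sources sampled with 80/20 weights.
from finetune.data.args import DataArgs, InstructArgs

data_args = DataArgs(
    data="",  # no pretraining "text" data
    instruct_data="/data/chat_sft:0.8,/data/function_calls:0.2",
    eval_instruct_data="/data/chat_sft_eval",
    instruct=InstructArgs(shuffle=True, dynamic_chunk_fn_call=True),
)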
finetune/data/data_loader.py
ADDED
@@ -0,0 +1,126 @@
import dataclasses
from typing import Any, Iterator, List, Optional

import numpy as np
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase

from .args import DataArgs
from .dataset import build_dataset


@dataclasses.dataclass
class Batch:
    x: np.ndarray
    y: np.ndarray
    sizes: List[int]
    y_mask: Optional[np.ndarray] = None
    is_pad_only: bool = False

    def __post_init__(self):
        assert self.x.ndim == 1
        assert self.x.shape == self.y.shape
        assert self.x.dtype == np.int64
        assert self.y.dtype == np.int64
        assert isinstance(self.sizes, list)
        assert sum(self.sizes) == self.x.size == self.y.size

        if self.y_mask is not None:
            assert self.y_mask.size == self.y.size, (self.y_mask.shape, self.y.shape)
            assert self.y_mask.dtype == bool
            assert sum(self.sizes) == self.y_mask.size
            assert not self.y_mask.all()
            assert self.y_mask.any()

        if self.is_pad_only:
            assert np.sum(np.abs(self.y)) == 0
            assert np.sum(np.abs(self.x)) == 0
            assert self.y_mask is None
            # create all 0's mask for pad samples
            self.y_mask = np.zeros_like(self.x)


@dataclasses.dataclass
class BatchList:
    x: List[List[int]] = dataclasses.field(default_factory=list)
    y: List[List[int]] = dataclasses.field(default_factory=list)
    sizes: List[List[int]] = dataclasses.field(default_factory=list)
    y_mask: List[Optional[List[int]]] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        assert self.x == [], "`BatchList` has to be empty at init."
        assert self.y == [], "`BatchList` has to be empty at init."
        assert self.sizes == [], "`BatchList` has to be empty at init."
        assert self.y_mask == [], "`BatchList` has to be empty at init."

    def __len__(self) -> int:
        return len(self.x)

    def add(self, x: List[int], y: List[int], sizes: List[int], y_mask: Optional[List[int]] = None):
        self.x.append(x)
        self.y.append(y)
        self.sizes.append(sizes)
        self.y_mask.append(y_mask)

    def empty(self):
        self.x = []
        self.y = []
        self.sizes = []
        self.y_mask = []

    @staticmethod
    def flatten_to_numpy(list_of_lists: List[List[Any]], dtype: np.dtype) -> np.array:
        return np.array([el for sublist in list_of_lists for el in sublist], dtype=dtype)

    def create_batch(self) -> Batch:
        x_np: np.array = self.flatten_to_numpy(self.x, dtype=np.int64)
        y_np: np.array = self.flatten_to_numpy(self.y, dtype=np.int64)
        sizes = sum(self.sizes, [])  # noqa

        y_mask_np: Optional[np.array] = self.flatten_to_numpy(self.y_mask, dtype=bool)
        y_mask_np = None if y_mask_np.all() else y_mask_np

        return Batch(x_np, y_np, sizes, y_mask_np)


def build_data_loader(
    instruct_tokenizer: InstructTokenizerBase,
    args: DataArgs,
    batch_size: int,
    seq_len: int,
    seed: Optional[int],
    rank: int,
    world_size: int,
    is_eval: bool,
) -> Iterator[Batch]:
    pretrain_data = args.data if not is_eval else ""
    instruct_data = args.instruct_data if not is_eval else args.eval_instruct_data

    dataset = build_dataset(
        pretrain_data=pretrain_data,
        instruct_data=instruct_data,
        instruct_args=args.instruct,
        instruct_tokenizer=instruct_tokenizer,
        seq_len=seq_len,
        seed=seed,
        rank=rank,
        world_size=world_size,
        is_eval=is_eval,
        shuffle_pretrain=args.shuffle,
    )

    batch_list = BatchList()
    for sample in dataset:
        assert all(s >= 0 for s in sample.sizes)

        batch_list.add(sample.x, sample.y, sample.sizes, sample.mask)

        if len(batch_list) == batch_size:
            batch: Batch = batch_list.create_batch()
            yield batch

            batch_list.empty()
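A small illustration (token ids are made up) of how `BatchList` flattens per-sample lists into one `Batch` via `create_batch`:

# Illustration with made-up token ids of the BatchList -> Batch flattening.
from finetune.data.data_loader import BatchList

batch_list = BatchList()
batch_list.add(x=[1, 2, 3], y=[2, 3, 4], sizes=[3], y_mask=[False, True, True])
batch_list.add(x=[5, 6], y=[6, 7], sizes=[2], y_mask=[True, True])

batch = batch_list.create_batch()
assert batch.x.tolist() == [1, 2, 3, 5, 6]
assert batch.sizes == [3, 2]
assert batch.y_mask.tolist() == [False, True, True, True, True]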
finetune/data/dataset.py
ADDED
@@ -0,0 +1,475 @@
import dataclasses
import itertools
import json
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union

import numpy as np
import torch.distributed as dist
from mistral_common.protocol.instruct.messages import (
    FinetuningAssistantMessage,
    SystemMessage,
)
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase

from finetune.distributed import get_rank

from .args import InstructArgs
from .tokenize import (
    Mask,
    SampleType,
    TokenSample,
    TrainingInstructSample,
    build_instruct_sample,
    encode,
)

logger = logging.getLogger("dataset")


_LOADED_DATASETS: Dict[Path, List[str]] = {}


def main_logger_info(message: str) -> None:
    if dist.is_initialized() and get_rank() == 0:
        logger.info(message)


def load_file(path: Path, world_size: int, rank: int) -> List[str]:
    lines = []
    with path.open() as f:
        for idx, line in enumerate(f):
            if not idx % world_size == rank:
                continue
            lines.append(line)
    return lines


def maybe_load_local_dataset(
    path: Path, chunk: bool, rank: int, world_size: int, instruct_tokenizer: InstructTokenizerBase, sample_type: SampleType
) -> List[TokenSample]:
    global _LOADED_DATASETS

    if path in _LOADED_DATASETS:
        return _LOADED_DATASETS[path]

    main_logger_info(f"Loading {path} ...")
    lines: List[str] = load_file(path, rank=rank, world_size=world_size)

    if chunk:
        lines += maybe_chunk_lines(lines)

    tokens_list: List[TokenSample] = []
    for line in lines:
        data = json.loads(line)

        token_sample: TokenSample = encode(
            data,
            instruct_tokenizer=instruct_tokenizer,
            as_type=sample_type,
        )
        tokens_list.append(token_sample)

    main_logger_info(f"{path} loaded and tokenized.")
    _LOADED_DATASETS[path] = tokens_list

    return _LOADED_DATASETS[path]


@dataclass
class DataDir:
    path: Path
    sample_type: SampleType

    @property
    def jsonl_files(self):
        assert self.path.exists(), f"Make sure that {self.path} exists"
        jsonl_files = list(self.path.rglob("*jsonl"))
        assert (
            len(jsonl_files) > 0
        ), f"{self.path} does not seem to have any files ending with '.jsonl'"
        return jsonl_files


@dataclass
class DataFile:
    path: Path
    sample_type: SampleType

    @property
    def jsonl_files(self):
        assert self.path.exists(), f"Make sure that {self.path} exists"
        return [self.path]


def parse_data_sources(
    pretrain_data: str,
    instruct_data: str,
) -> Tuple[List[Union[DataDir, DataFile]], List[float]]:
    seen: Set[str] = set()
    sources: List[Union[DataDir, DataFile]] = []
    weights: List[float] = []
    for sample_sources, sample_type in [
        (pretrain_data, SampleType.PRETRAIN),
        (instruct_data, SampleType.INSTRUCT),
    ]:
        for source in sample_sources.strip().split(","):
            if not source:
                continue

            source_items = source.strip().split(":")
            if len(source_items) == 1:
                path_ = source_items[0]
                weight = 1.0
            elif len(source_items) == 2:
                path_, weight_ = source_items
                weight = float(weight_)
            else:
                raise ValueError(
                    f"{source} is not correctly formatted. Make sure to format each data source as <path/to/data>:<weight> or just <path/to/data>"
                )

            assert (
                path_ not in seen
            ), f"{path_} seems to be duplicated. Make sure to only add it once."
            assert (
                weight > 0
            ), f"Make sure to define strictly positive data sampling weights, not {weight}"

            data: Union[DataDir, DataFile]
            if Path(path_).is_dir():
                data = DataDir(path=Path(path_), sample_type=sample_type)
            elif Path(path_).is_file():
                data = DataFile(path=Path(path_), sample_type=sample_type)
            else:
                raise FileNotFoundError(
                    f"The path {path_} does not exist. Make sure {path_} is either a file or directory that contains training data."
                )

            sources.append(data)
            weights.append(weight)

            seen.add(path_)

    sum_weights = sum(weights)
    n_weights = [weight / sum_weights for weight in weights]

    assert min(n_weights) > 0
    assert (
        abs(1 - sum(n_weights)) < 1e-8
    ), f"Defined data sampling weights {weights} must sum to 1."
    return sources, n_weights


@dataclasses.dataclass()
class SequenceMaskAndSizes:
    """
    Concatenation of samples to reach a given size
    """

    x: List[int]
    y: List[int]
    mask: Mask
    sizes: List[int]

    def __post_init__(self):
        assert sum(self.sizes) == len(self.x) == len(self.y) == len(self.mask)


def sequence_iterator(
    ds_it: Iterator[TokenSample],
    seq_len: int,
    is_finite: bool,
) -> Iterator[SequenceMaskAndSizes]:
    """
    Creates sequences of length `seq_len` from the dataset iterator by concatenating samples.
    """
    x_buffer: List[int] = []
    y_buffer: List[int] = []
    mask_buffer: Mask = []

    sizes: List[int] = []
    n_missing = seq_len
    for sample in ds_it:
        assert 0 <= len(x_buffer) < seq_len, len(x_buffer)
        assert n_missing == seq_len - len(
            x_buffer
        ), f"n_missing: {n_missing} | seq_len - len(x_buffer) {seq_len - len(x_buffer)}"

        tokens, mask = sample.tokens, sample.masks[1:]
        x, y = tokens[:-1], tokens[1:]
        cur_pos = 0

        while cur_pos < len(x):
            size = len(x[cur_pos : cur_pos + n_missing])

            curr_mask = mask[cur_pos : cur_pos + n_missing]
            if not any(curr_mask):
                cur_pos += size
                # we have a sequence with a mask filled with False
                continue

            x_buffer.extend(x[cur_pos : cur_pos + n_missing])
            y_buffer.extend(y[cur_pos : cur_pos + n_missing])
            mask_buffer.extend(curr_mask)
            n_missing -= size
            sizes.append(size)

            cur_pos += size

            if n_missing == 0:
                assert len(mask_buffer) == len(x_buffer) == seq_len == len(y_buffer)
                assert sum(sizes) == seq_len
                # we don't want to yield sequences with a mask filled with False
                if any(mask_buffer):
                    yield SequenceMaskAndSizes(
                        x=x_buffer,
                        y=y_buffer,
                        mask=mask_buffer,
                        sizes=sizes,
                    )
                x_buffer, y_buffer = [], []
                mask_buffer = []
                sizes = []
                n_missing = seq_len

    if is_finite:
        # if dataloader is in eval, pad to seq length
        if any(mask_buffer):
            mask_buffer.extend(n_missing * [False])
            x_buffer.extend(n_missing * [0])
            y_buffer.extend(n_missing * [0])
            sizes.append(n_missing)

            yield SequenceMaskAndSizes(
                x=x_buffer,
                y=y_buffer,
                mask=mask_buffer,
                sizes=sizes,
            )


def build_dataset(
    pretrain_data: str,
    instruct_data: str,
    instruct_args: InstructArgs,
    instruct_tokenizer: InstructTokenizerBase,
    seq_len: int,
    seed: Optional[int],
    rank: int,
    world_size: int,
    is_eval: bool,
    shuffle_pretrain: bool = False,
) -> Iterator[SequenceMaskAndSizes]:
    sources, probabilities = parse_data_sources(
        pretrain_data=pretrain_data, instruct_data=instruct_data
    )

    def do_shuffle(source: Union[DataDir, DataFile]) -> bool:
        shuffle = {
            SampleType.PRETRAIN: shuffle_pretrain,
            SampleType.INSTRUCT: instruct_args.shuffle,
        }[source.sample_type]

        return not is_eval and shuffle

    dataset_iterators = [
        get_dataset_iterator(
            source,
            instruct_args=instruct_args,
            instruct_tokenizer=instruct_tokenizer,
            rank=rank,
            world_size=world_size,
            is_finite=is_eval,
            seed=seed,
            shuffle_at_epoch=do_shuffle(source),
        )
        for source in sources
    ]

    sequence_iterators = [
        sequence_iterator(
            ds_it=it,
            seq_len=seq_len,
            is_finite=is_eval,
        )
        for it in dataset_iterators
    ]

    if is_eval:
        combined_iterator = itertools.chain.from_iterable(sequence_iterators)
    else:
        # make sure random_seed is different per rank and original seed
        random_seed = np.array((seed, rank))
        rng = np.random.RandomState(seed=random_seed)
        combined_iterator = interleave_iterators(
            sequence_iterators, probabilities=probabilities, rng=rng
        )

    return combined_iterator


def get_rng(seed: int, rank: int) -> np.random.RandomState:
    random_seed = np.array((seed, rank))
    rng = np.random.RandomState(seed=random_seed)
    return rng


def get_dataset_iterator(
    source: Union[DataDir, DataFile],
    instruct_args: InstructArgs,
    instruct_tokenizer: InstructTokenizerBase,
    rank: int,
    world_size: int,
    is_finite: bool,
    seed: Optional[int],
    shuffle_at_epoch: bool,
) -> Iterator[TokenSample]:
    jsonl_files = source.jsonl_files
    rng: Optional[np.random.RandomState] = (
        get_rng(seed, rank) if seed is not None else None
    )

    chunk_dataset = (
        instruct_args.dynamic_chunk_fn_call
        and source.sample_type == SampleType.INSTRUCT
    )

    if not is_finite:
        # train mode
        while True:
            for jsonl_file in jsonl_files:
                if shuffle_at_epoch:
                    assert rng is not None, "`seed` has to be passed when shuffling"
                    # will preload all data into RAM, shuffle and yield
                    yield from preload_and_yield(
                        jsonl_file,
                        chunk_dataset=chunk_dataset,
                        rank=rank,
                        world_size=world_size,
                        rng=rng,
                        instruct_tokenizer=instruct_tokenizer,
                        sample_type=source.sample_type,
                    )
                else:
                    # will read data on-the-fly and yield
                    main_logger_info(f"Lazily loading {jsonl_file} ...")
                    yield from lazy_load_and_yield(
                        jsonl_file,
                        rank=rank,
                        world_size=world_size,
                        instruct_tokenizer=instruct_tokenizer,
                        sample_type=source.sample_type,
                    )
    else:
        # eval mode
        for jsonl_file in jsonl_files:
            # No need to shuffle for eval
            yield from lazy_load_and_yield(
                jsonl_file,
                rank=rank,
                world_size=world_size,
                instruct_tokenizer=instruct_tokenizer,
                sample_type=source.sample_type,
            )


def preload_and_yield(
    jsonl_file: Path,
    chunk_dataset: bool,
    rank: int,
    world_size: int,
    rng: np.random.RandomState,
    instruct_tokenizer: InstructTokenizerBase,
    sample_type: SampleType,
) -> Iterator[TokenSample]:
    # only instruct data has to be chunked
    # load dataset if not already loaded. Make sure to only load 1/world_size dataset
    tokens_list = maybe_load_local_dataset(
        jsonl_file, chunk=chunk_dataset, rank=rank, world_size=world_size, instruct_tokenizer=instruct_tokenizer, sample_type=sample_type
    )

    if sample_type == SampleType.PRETRAIN:
        assert chunk_dataset is False, "Pretrain data should not have chunking enabled."

    main_logger_info(f"Shuffling {jsonl_file} ...")
    rng.shuffle(tokens_list)

    for token_sample in tokens_list:
        yield token_sample


def lazy_load_and_yield(
    jsonl_file: Path,
    rank: int,
    world_size: int,
    instruct_tokenizer: InstructTokenizerBase,
    sample_type: SampleType,
):
    with jsonl_file.open() as file_handle:
        for idx, line in enumerate(file_handle):
            if not idx % world_size == rank:
                continue

            data = json.loads(line)
            yield encode(
                data,
                instruct_tokenizer=instruct_tokenizer,
                as_type=sample_type,
            )


def maybe_chunk_lines(lines: List[str]) -> List[str]:
    extra_lines: List[str] = []
    for line in lines:
        data = json.loads(line)
        # multi-turn fn call data will be chunked and shorter conversations are added additionally
        maybe_chunked_lines = maybe_chunk_data(data)
        extra_lines.extend([json.dumps(line) for line in maybe_chunked_lines])

    return extra_lines


def maybe_chunk_data(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    # think about always allowing both open-ai and non-open-ai data
    sample = build_instruct_sample(data)

    def num_assistant_messages(sample: TrainingInstructSample) -> int:
        return len(
            [m for m in sample.messages if isinstance(m, FinetuningAssistantMessage)]
        )

    chunk_data = []
    while sample.only_last is True and num_assistant_messages(sample) > 1:
        assert sample == build_instruct_sample(sample.dict())
        last_message = sample.messages.pop()

        # 1. First pop until and including last assistant message
        system_message = None
        while not isinstance(last_message, FinetuningAssistantMessage):
            last_message = sample.messages.pop()
            if isinstance(last_message, SystemMessage):
                system_message = last_message

        # 2. Second pop until and excluding last assistant message
        prev_last_message = sample.messages[-1]
        while not isinstance(prev_last_message, FinetuningAssistantMessage):
            last_message = sample.messages.pop()
            if isinstance(last_message, SystemMessage):
                system_message = last_message

            prev_last_message = sample.messages[-1]

        # if system_message is not None, append again
        if system_message is not None:
            sample.messages.append(system_message)
        chunk_data.append(sample.dict())

    return chunk_data


def interleave_iterators(iterators: List[Iterator], probabilities, rng):
    while True:
        it_id = rng.choice(range(len(iterators)), p=probabilities)
        yield next(iterators[it_id])
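A quick sketch of the weight normalization done by `parse_data_sources`; the paths and weights are hypothetical:

# E.g. for "path_a:3.0,path_b:1.0" (hypothetical sources)
weights = [3.0, 1.0]
sum_weights = sum(weights)
n_weights = [w / sum_weights for w in weights]  # [0.75, 0.25] -> sampling probabilities for interleave_iterators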
finetune/data/exceptions.py
ADDED
@@ -0,0 +1,56 @@
class MessageFormatError(Exception):
    def __init__(self, message, data):
        self._message = message
        self._begin_data = data[:20]
        super().__init__()

    def __str__(self):
        return f"A message starting with {self._begin_data} is incorrectly formatted. " + self._message


class ToolCallFormatError(Exception):
    def __init__(self, message, data):
        self._message = message
        self._begin_data = data[:20]
        super().__init__()

    def __str__(self):
        return f"A tool call assistant message starting with {self._begin_data} of the conversation is incorrectly formatted. " + self._message


class FunctionFormatError(Exception):
    def __init__(self, message, data):
        self._message = message
        self._begin_data = data[:20]
        super().__init__()

    def __str__(self):
        return (
            f"A function of the conversation starting with {self._begin_data} is incorrectly formatted. "
            + self._message
        )


class ConversationFormatError(Exception):
    def __init__(self, message, data):
        self._message = message
        self._begin_data = data[:20]
        super().__init__()

    def __str__(self):
        return (
            f"A conversation starting with {self._begin_data} is incorrectly formatted. " + self._message
        )


class UnrecognizedRoleError(Exception):
    def __init__(self, role, allowed_roles):
        self._role = role
        self._allowed_roles = allowed_roles
        super().__init__()

    def __str__(self):
        return (
            f"The following role: {self._role} is not recognized. "
            f"Make sure that each role is one of {self._allowed_roles}."
        )
finetune/data/tokenize.py
ADDED
@@ -0,0 +1,355 @@
import logging
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional, Union

from mistral_common.protocol.instruct.messages import (
    FinetuningAssistantMessage,
    Roles,
    SystemMessage,
    ToolMessage,
    UserMessage,
)
from mistral_common.protocol.instruct.tool_calls import (
    Function,
    FunctionCall,
    Tool,
    ToolCall,
)
from mistral_common.protocol.instruct.validator import (
    MistralRequestValidatorV3,
    ValidationMode,
)
from mistral_common.tokens.instruct.request import InstructRequest
from mistral_common.tokens.tokenizers.base import Tokenizer
from mistral_common.tokens.tokenizers.sentencepiece import InstructTokenizerBase

from .exceptions import (
    ConversationFormatError,
    FunctionFormatError,
    MessageFormatError,
    ToolCallFormatError,
    UnrecognizedRoleError,
)

logger = logging.getLogger("tokenize")

Sequence = List[int]
Mask = List[bool]


class TrainingInstructSample(InstructRequest):
    available_tools: Optional[List[Tool]] = None
    only_last: bool = False


@dataclass()
class TokenSample:
    tokens: Sequence
    masks: Mask


class SampleType(str, Enum):
    PRETRAIN = "pretrain"
    INSTRUCT = "instruct"


def encode(
    data: Dict[str, Any],
    instruct_tokenizer: InstructTokenizerBase,
    as_type: SampleType,
) -> TokenSample:
    sample: Union[str, TrainingInstructSample]
    if as_type == SampleType.PRETRAIN:
        sample = get_pretrain_sample(data)
    elif as_type == SampleType.INSTRUCT:
        sample = build_instruct_sample(data)

    return tokenize(sample=sample, instruct_tokenizer=instruct_tokenizer)


def get_pretrain_sample(data: Dict[str, Any]) -> str:
    content_keys = ["text", "content"]
    assert not all(
        k in data for k in content_keys
    ), "Make sure to have either 'text' or 'content' in your data. Not both."
    assert any(
        data.get(k) is not None for k in content_keys
    ), f"Must have one of 'text' or 'content' in your data. Only have {data.keys()}"

    # get first non-None value
    sample = None
    for key in content_keys:
        sample = data[key] if key in data else sample

    assert isinstance(sample, str), sample

    return sample


def build_instruct_sample(data: Dict[str, Any]) -> TrainingInstructSample:
    messages: List[
        SystemMessage | UserMessage | FinetuningAssistantMessage | ToolMessage
    ] = []
    # optional data fields that might be set
    available_tools: Optional[List[Tool]] = data.get("available_tools")
    system_prompt = data.get("system_prompt")

    messages_keys = ["messages", "interactions"]
    content_keys = ["content", "text"]  # both are accepted
    allowed_roles = [role.value for role in Roles]

    if not any(messages_key in data for messages_key in messages_keys):
        err = f"The conversation does not contain one of '{', '.join(messages_keys)}' key, but only {', '.join(data.keys())}. Make sure that the conversation includes one of '{', '.join(messages_keys)}'."
        raise ConversationFormatError(err, str(data))

    if all(messages_key in data for messages_key in messages_keys):
        err = f"The conversation cannot contain both of '{', '.join(messages_keys)}' key, but only one of the two."
        raise ConversationFormatError(err, str(data))

    # get first non-None value
    data_messages: Optional[List[Dict[str, Any]]] = None
    for key in messages_keys:
        data_messages = data[key] if key in data else data_messages

    assert data_messages is not None, "data_messages can't be None"

    if "available_tools" in data and "tools" in data:
        err = "The conversation contains both an `available_tools` and `tools` key. You can only have one."
        raise ConversationFormatError(err, str(data))

    if data.get("tools", None) is not None and len(data["tools"]) > 0:
        available_tools = _parse_available_tools(data["tools"])
    elif (
        data.get("available_tools", None) is not None
        and len(data["available_tools"]) > 0
    ):
        available_tools = _parse_available_tools(data["available_tools"])

    for data_message in data_messages:
        is_tool_call = data_message.get("tool_calls") is not None

        if "role" not in data_message:
            err = f"A message does not contain a 'role' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'role'."
            raise MessageFormatError(err, str(data))

        role = data_message["role"]

        if all(key in data_message for key in content_keys):
            err = f"A {role} message contains both a 'text' and 'content' key. Make sure that there is only one of the two."
            raise MessageFormatError(err, str(data))

        content: Optional[str] = None
        for key in content_keys:
            content = content if content is not None else data_message.get(key)

        # non-function call message must have content
        if not is_tool_call and content is None:
            err = f"A {role} message does not contain one of '{content_keys}' key, but only {', '.join(data_message.keys())}. Make sure that the message includes one of '{content_keys}' keys."
            raise MessageFormatError(err, str(data))

        if role not in allowed_roles:
            raise UnrecognizedRoleError(role, allowed_roles)

        if data_message["role"] == "user":
            assert content is not None
            messages.append(UserMessage(content=content))
        elif data_message["role"] == "assistant":
            tool_calls: Optional[List[ToolCall]] = None

            if is_tool_call:
                tool_calls = _parse_tool_calls(data_message["tool_calls"])

            weight = data_message.get("weight")
            messages.append(
                FinetuningAssistantMessage(
                    content=content, tool_calls=tool_calls, weight=weight
                )
            )
        elif data_message["role"] == "system":
            if system_prompt is not None:
                err = "Multiple messages with role 'system' encountered. Only one is allowed."
                raise MessageFormatError(err, str(data))

            system_prompt = content
        elif data_message["role"] == "tool":
            assert content is not None
            tool_message = _parse_tool_message(content, data_message)
            messages.append(tool_message)

    # validate created messages
    validator = MistralRequestValidatorV3(ValidationMode.finetuning)
    validator.validate_messages(messages)
    validator._validate_tools(available_tools or [])

    # whether to train only on last assistant message
    only_last = data.get("only_last", False) or available_tools is not None

    return TrainingInstructSample(
        messages=messages,
        system_prompt=system_prompt,
        available_tools=available_tools,
        only_last=only_last,
    )


def _parse_available_tools(tools: List[Dict[str, Any]]) -> List[Tool]:
    available_tools = []
    for tool in tools:
        if "function" not in tool:
            raise FunctionFormatError(
                "A tool dict does not have a 'function' key.", str(tool)
            )

        func_data = tool["function"]

        for key in ["name", "description", "parameters"]:
            if key not in func_data:
                raise FunctionFormatError(
                    f"A function dict does not have a {key} key.", str(func_data)
                )

        if not isinstance(func_data["parameters"], dict):
            raise FunctionFormatError(
                f"A function 'parameters' key has to be of type dict, but is {type(func_data['parameters'])}. If the function has no parameters pass an empty dict.", str(func_data)
            )

        description = func_data["description"]
        function = Function(
            name=func_data["name"],
            description=description,
            parameters=func_data["parameters"],
        )

        available_tools.append(Tool(function=function))
    return available_tools


def _parse_tool_calls(calls: List[Dict[str, Any]]) -> List[ToolCall]:
    for key in ["id", "function"]:
        if not all(key in call for call in calls):
            err = f"A tool call of an assistant message does not have a {key} key"
            raise ToolCallFormatError(err, str(calls))

    for key in ["name", "arguments"]:
        if not all(key in call["function"] for call in calls):
            err = (
                f"A tool call function of an assistant message does not have a {key} key"
            )
            raise ToolCallFormatError(err, str(calls))

    if not all(isinstance(call["function"]["arguments"], str) for call in calls):
        err = "A tool call function of an assistant message does not have an 'arguments' key of type str"
        raise ToolCallFormatError(err, str(calls))

    tool_calls = [
        ToolCall(
            id=call["id"],
            function=FunctionCall(
                name=call["function"]["name"],
                arguments=call["function"]["arguments"],
            ),
        )
        for call in calls
    ]
    return tool_calls


def _parse_tool_message(content: str, data_message: Dict[str, Any]) -> ToolMessage:
    if "tool_call_id" not in data_message:
        err = f"A tool message does not contain a 'tool_call_id' key, but only {', '.join(data_message.keys())}. Make sure that the message includes the key 'tool_call_id'."
        raise MessageFormatError(err, str(data_message))

    call_id = data_message["tool_call_id"]
    # name is deprecated in v3, but we'll add it nevertheless for now
    name = data_message.get("name")

    return ToolMessage(content=content, tool_call_id=call_id, name=name)


def tokenize(
    sample: Union[str, TrainingInstructSample],
    instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
    if isinstance(sample, str):
        tokenizer: Tokenizer = instruct_tokenizer.tokenizer
        return tokenize_pretrain(sample, tokenizer)
    elif isinstance(sample, TrainingInstructSample):
        return tokenize_instruct(sample, instruct_tokenizer)

    raise ValueError(
        f"`sample` has to be either of type `str` or `TrainingInstructSample`, not {type(sample)}."
    )


def tokenize_pretrain(sample: str, tokenizer: Tokenizer) -> TokenSample:
    tokens = tokenizer.encode(sample, bos=True, eos=True)
    masks = [True] * len(tokens)
    return TokenSample(tokens, masks)


def tokenize_instruct(
    sample: TrainingInstructSample,
    instruct_tokenizer: InstructTokenizerBase,
) -> TokenSample:
    tokens: List[int] = instruct_tokenizer.start()
    masks: List[bool] = [False]

    mask_all_but_last = sample.only_last

    # find first and last user message
    user_messages = [
        i for i, msg in enumerate(sample.messages) if isinstance(msg, UserMessage)
    ]
    first_user_idx = user_messages[0] if user_messages else -1
    last_user_idx = user_messages[-1] if user_messages else -1

    for msg_idx, message in enumerate(sample.messages):
        if isinstance(message, UserMessage):
            curr_tokens = instruct_tokenizer.encode_user_message(
                message,
                available_tools=sample.available_tools,
                is_last=msg_idx == last_user_idx,
                is_first=msg_idx == first_user_idx,
                system_prompt=sample.system_prompt,
            )
            curr_masks = [False] * len(curr_tokens)  # only predict bot answers
        elif isinstance(message, ToolMessage):
            curr_tokens = instruct_tokenizer.encode_tool_message(
                message, is_before_last_user_message=msg_idx < last_user_idx
            )
            curr_masks = [False] * len(curr_tokens)  # only predict bot answers
        elif isinstance(message, FinetuningAssistantMessage):
            is_last_message = msg_idx == (len(sample.messages) - 1)

            # we don't want to predict a random call id
            message = maybe_remove_call_id(message, is_last_message=is_last_message)

            curr_tokens = instruct_tokenizer.encode_assistant_message(
                message, is_before_last_user_message=False
            )

            is_weighted = message.weight is None or message.weight == 1
            is_relevant = (not mask_all_but_last) or is_last_message
            if is_weighted and is_relevant:
                curr_masks = [True] * len(curr_tokens)  # only predict bot answers
            else:
                # in function calling we only backprop through last message
                curr_masks = [False] * len(curr_tokens)

        tokens.extend(curr_tokens)
        masks.extend(curr_masks)

    return TokenSample(tokens, masks)


def maybe_remove_call_id(message: FinetuningAssistantMessage, is_last_message: bool):
    if message.tool_calls is None or not is_last_message:
        return message

    # remove call id
    message.tool_calls = [
        ToolCall(function=call.function) for call in message.tool_calls
    ]

    return message
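A minimal sketch of one instruct jsonl record that `build_instruct_sample` / `encode` would accept; the content is made up:

# Made-up example of one jsonl line for instruct data.
import json

record = {
    "interactions": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "Paris."},
    ]
}
jsonl_line = json.dumps(record)  # one such JSON object per line in the dataset files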
finetune/distributed.py
ADDED
@@ -0,0 +1,59 @@
import logging
import os
from functools import lru_cache
from typing import List, Union

import torch
import torch.distributed as dist

logger = logging.getLogger("distributed")

BACKEND = "nccl"


@lru_cache()
def get_rank() -> int:
    return dist.get_rank()


@lru_cache()
def get_world_size() -> int:
    return dist.get_world_size()


def visible_devices() -> List[int]:
    return [int(d) for d in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]


def set_device():
    logger.info(f"torch.cuda.device_count: {torch.cuda.device_count()}")
    logger.info(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
    logger.info(f"local rank: {int(os.environ['LOCAL_RANK'])}")

    assert torch.cuda.is_available()

    assert len(visible_devices()) == torch.cuda.device_count()

    if torch.cuda.device_count() == 1:
        # gpus-per-task set to 1
        torch.cuda.set_device(0)
        return

    local_rank = int(os.environ["LOCAL_RANK"])
    logger.info(f"Set cuda device to {local_rank}")

    assert 0 <= local_rank < torch.cuda.device_count(), (
        local_rank,
        torch.cuda.device_count(),
    )
    torch.cuda.set_device(local_rank)


def avg_aggregate(metric: Union[float, int]) -> Union[float, int]:
    buffer = torch.tensor([metric], dtype=torch.float32, device="cuda")
    dist.all_reduce(buffer, op=dist.ReduceOp.SUM)
    return buffer[0].item() / get_world_size()


def is_torchrun() -> bool:
    return "TORCHELASTIC_RESTART_COUNT" in os.environ
finetune/eval.py
ADDED
@@ -0,0 +1,77 @@
import logging
from typing import List

import numpy as np
import torch.cuda
import torch.distributed as dist
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from .data.data_loader import Batch
from .distributed import get_rank, get_world_size
from .loss import compute_loss_with_mask
from .utils import TrainState

logger = logging.getLogger("eval")


def main_logger_info(message: str) -> None:
    if get_rank() == 0:
        logger.info(message)


def evaluate(
    model: FullyShardedDataParallel,
    batches: List[Batch],
    state: TrainState,
):
    # Create fake samples to make FSDP happy for unbalanced data
    num_samples = torch.tensor([len(batches)], device="cuda", dtype=torch.long)
    all_num_samples = [torch.zeros_like(num_samples) for _ in range(get_world_size())]

    torch.distributed.all_gather(all_num_samples, num_samples)

    total_num_samples = int(torch.tensor(all_num_samples).sum().item())
    max_num_samples = int(torch.tensor(all_num_samples).max().item())

    for _ in range(max_num_samples - int(num_samples.item())):
        pad_x = np.zeros_like(batches[-1].x)
        pad_y = np.zeros_like(batches[-1].y)
        pad_sizes = batches[-1].sizes.copy()

        pad_batch = Batch(pad_x, pad_y, pad_sizes, is_pad_only=True)
        batches.append(pad_batch)

    # eval mode!
    model.eval()

    eval_loss = torch.tensor(0.0).cuda()
    main_logger_info("Start eval...")
    for batch in batches:
        x = torch.from_numpy(batch.x).cuda()
        y = torch.from_numpy(batch.y).cuda()
        y_mask = (
            torch.from_numpy(batch.y_mask).cuda() if batch.y_mask is not None else None
        )

        with torch.no_grad():
            output = model(
                input_ids=x,
                seqlens=batch.sizes,
            )

            # a None mask means all positions are trainable; an all-zero mask marks a pad-only batch
            if y_mask is None or y_mask.sum() > 0:
                eval_loss += compute_loss_with_mask(output, y, y_mask)

            assert batch.is_pad_only or y.abs().sum() != 0, "Pad sample is used to compute loss."

    # sum loss
    main_logger_info("Eval finished!")

    dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
    eval_loss /= total_num_samples

    state.this_eval_loss = eval_loss.item()
    state.this_eval_perplexity = (2**eval_loss).item()

    # train mode!
    model.train()
finetune/loss.py
ADDED
@@ -0,0 +1,16 @@
from typing import Optional

import torch
from torch.nn import functional as F


def compute_loss_with_mask(
    logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor]
):
    if target_mask is None:
        return F.cross_entropy(logits, target, reduction="mean")

    mb_loss = F.cross_entropy(logits, target, reduction="none")
    mb_loss = torch.sum(mb_loss * target_mask) / torch.sum(target_mask)

    return mb_loss
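A quick sketch of `compute_loss_with_mask` on hypothetical shapes (6 flattened tokens, a vocabulary of 32000):

# Hypothetical shapes, for illustration only.
import torch

from finetune.loss import compute_loss_with_mask

logits = torch.randn(6, 32000)
target = torch.randint(0, 32000, (6,))
target_mask = torch.tensor([False, True, True, True, False, True])
loss = compute_loss_with_mask(logits, target, target_mask)  # mean CE over masked positions only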
finetune/mixed_precision.py
ADDED
@@ -0,0 +1,47 @@
from typing import Iterable

import torch


def prepare_mixed_precision(
    params: Iterable[torch.nn.Parameter],
    param_dtype: torch.dtype,
    optim_dtype: torch.dtype,
):
    """Appends a freshly allocated fp32 tensor copy of all params to parameters that can be updated."""
    with torch.no_grad():
        for p in params:
            if p.requires_grad:
                # Mixed precision: save an fp32 master copy on every param that requires a grad
                p._mp_param = torch.empty_like(p, dtype=optim_dtype)  # type: ignore
                p._mp_param.copy_(p.to(optim_dtype))  # type: ignore

            p.data = p.data.to(param_dtype)


def upcast_mixed_precision(
    params: Iterable[torch.nn.Parameter], optim_dtype: torch.dtype
):
    """Make sure to run this function BEFORE optimizer.step() so that all weights and optimizer states are updated in fp32 in .step()"""
    with torch.no_grad():
        for p in params:
            if p.requires_grad and p.grad is not None:
                # store the original low-precision tensor in p._temp
                p._temp = p.data  # type: ignore
                # upcast data for the optimizer step
                p.data = p._mp_param  # type: ignore
                p.grad = p.grad.to(optim_dtype)


def downcast_mixed_precision(
    params: Iterable[torch.nn.Parameter], param_dtype: torch.dtype
):
    """Make sure to run this function AFTER optimizer.step() as optimizer.step() will update data underlying p.data and p._mp_param pointers"""
    with torch.no_grad():
        for p in params:
            if p.requires_grad and p.grad is not None:
                # copy the updated fp32 weights back into the low-precision (e.g. bfloat16) tensor
                p._temp.copy_(p.data)  # type: ignore
                # point p.data back at the low-precision tensor
                p.data = p._temp  # type: ignore
                p.grad = p.grad.to(param_dtype)
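A minimal sketch of the intended call order, under the assumption that the three helpers above are in scope (the toy model, optimizer, and dtypes are not from this repo): prepare the fp32 master copies once, upcast right before optimizer.step(), and downcast right after it.

import torch

# Sketch only: a toy CPU model to show ordering; assumes your PyTorch build supports bfloat16 matmul.
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Allocate fp32 master copies once, and cast the live weights to bf16.
prepare_mixed_precision(model.parameters(), param_dtype=torch.bfloat16, optim_dtype=torch.float32)

for _ in range(3):
    loss = model(torch.randn(4, 16, dtype=torch.bfloat16)).float().sum()
    loss.backward()
    # swap the fp32 masters in BEFORE the step ...
    upcast_mixed_precision(model.parameters(), optim_dtype=torch.float32)
    optimizer.step()
    # ... and copy the updated fp32 weights back into bf16 AFTER it.
    downcast_mixed_precision(model.parameters(), param_dtype=torch.bfloat16)
    optimizer.zero_grad()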
finetune/monitoring/__init__.py
ADDED
File without changes
finetune/monitoring/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (147 Bytes).
finetune/monitoring/__pycache__/metrics_logger.cpython-310.pyc
ADDED
Binary file (5.46 kB).
finetune/monitoring/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (1.27 kB).
finetune/monitoring/metrics_logger.py
ADDED
@@ -0,0 +1,226 @@
import json
import logging
import os
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, Optional, Union

from torch.utils.tensorboard import SummaryWriter

from finetune.args import MLFlowArgs, TrainArgs, WandbArgs
from finetune.utils import TrainState

logger = logging.getLogger("metrics_logger")

GB = 1024**3


def get_train_logs(
    state: TrainState,
    loss: float,
    lr: float,
    peak_allocated_mem: float,
    allocated_mem: float,
    train_args: TrainArgs,
) -> Dict[str, Union[float, int]]:
    metrics = {
        "lr": lr,
        "step": state.step,
        "loss": loss,
        "percent_done": 100 * state.step / train_args.max_steps,
        "peak_allocated_mem": peak_allocated_mem / GB,
        "allocated_mem": allocated_mem / GB,
        "wps": state.wps,
        "avg_wps": state.avg_wps,
        "eta_in_seconds": state.eta,
    }

    return metrics


def get_eval_logs(
    step: int,
    train_loss: float,
    perplexity: Optional[float],
    eval_loss: Optional[float],
) -> Dict[str, Union[float, int]]:
    eval_dict = {"step": step, "train_loss": train_loss}

    if perplexity is not None:
        eval_dict["perplexity"] = perplexity

    if eval_loss is not None:
        eval_dict["eval_loss"] = eval_loss
    return eval_dict


def train_log_msg(
    state: TrainState, logs: Dict[str, Union[float, int]], loss: float
) -> str:
    metrics: Dict[str, Union[float, int, datetime]] = dict(logs)  # shallow copy
    metrics.pop("eta_in_seconds")

    metrics["eta"] = datetime.now() + timedelta(seconds=state.eta)
    metrics["step"] = state.step
    metrics["loss"] = loss

    parts = []
    for key, fmt, new_name in [
        ("step", "06", None),
        ("percent_done", "03.1f", "done (%)"),
        ("loss", ".3f", None),
        ("lr", ".1e", None),
        ("peak_allocated_mem", ".1f", "peak_alloc_mem (GB)"),
        ("allocated_mem", ".1f", "alloc_mem (GB)"),
        ("wps", ".1f", "words_per_second"),
        ("avg_wps", ".1f", "avg_words_per_second"),
        ("eta", "%Y-%m-%d %H:%M:%S", "ETA"),
    ]:
        name = key if new_name is None else new_name
        try:
            parts.append(f"{name}: {metrics[key]:>{fmt}}")
        except KeyError:
            logger.error(f"{key} not found in {sorted(metrics.keys())}")
            raise

    return " - ".join(parts)


def eval_log_msg(logs: Dict[str, Union[float, int]]) -> str:
    parts = []
    for key, fmt, new_name in [
        ("step", "06", None),
        ("perplexity", ".3f", "eval_perplexity"),
        ("eval_loss", ".3f", None),
        ("train_loss", ".3f", None),
    ]:
        name = key if new_name is None else new_name
        if key in logs:
            parts.append(f"{name}: {logs[key]:>{fmt}}")

    return " - ".join(parts)


class MetricsLogger:
    def __init__(
        self,
        dst_dir: Path,
        tag: str,
        is_master: bool,
        wandb_args: WandbArgs,
        mlflow_args: MLFlowArgs,
        config: Optional[Dict[str, Any]] = None,
    ):
        self.dst_dir = dst_dir
        self.tag = tag
        self.is_master = is_master
        self.jsonl_path = dst_dir / f"metrics.{tag}.jsonl"
        self.tb_dir = dst_dir / "tb"
        self.summary_writer: Optional[SummaryWriter] = None

        if not self.is_master:
            return

        filename_suffix = f".{tag}"
        self.tb_dir.mkdir(exist_ok=True)
        self.summary_writer = SummaryWriter(
            log_dir=str(self.tb_dir),
            max_queue=1000,
            filename_suffix=filename_suffix,
        )
        self.is_wandb = wandb_args.project is not None
        self.is_mlflow = mlflow_args.tracking_uri is not None

        if self.is_wandb:
            import wandb

            if wandb_args.key is not None:
                wandb.login(key=wandb_args.key)
            if wandb_args.offline:
                os.environ["WANDB_MODE"] = "offline"
            if wandb.run is None:
                logger.info("initializing wandb")
                wandb.init(
                    config=config,
                    dir=dst_dir,
                    project=wandb_args.project,
                    job_type="training",
                    name=wandb_args.run_name or dst_dir.name,
                    resume=False,
                )

            self.wandb_log = wandb.log

        if self.is_mlflow:
            import mlflow

            mlflow.set_tracking_uri(mlflow_args.tracking_uri)
            mlflow.set_experiment(mlflow_args.experiment_name or dst_dir.name)

            if tag == "train":
                mlflow.start_run()

            self.mlflow_log = mlflow.log_metric

    def log(self, metrics: Dict[str, Union[float, int]], step: int):
        if not self.is_master:
            return

        metrics_to_ignore = {"step"}
        assert self.summary_writer is not None
        for key, value in metrics.items():
            if key in metrics_to_ignore:
                continue
            assert isinstance(value, (int, float)), (key, value)
            self.summary_writer.add_scalar(
                tag=f"{self.tag}.{key}", scalar_value=value, global_step=step
            )

            if self.is_mlflow:
                self.mlflow_log(f"{self.tag}.{key}", value, step=step)

        if self.is_wandb:
            # grouping in wandb is done with /
            self.wandb_log(
                {
                    f"{self.tag}/{key}": value
                    for key, value in metrics.items()
                    if key not in metrics_to_ignore
                },
                step=step,
            )

        metrics_: Dict[str, Any] = dict(metrics)  # shallow copy
        if "step" in metrics_:
            assert step == metrics_["step"]
        else:
            metrics_["step"] = step
        metrics_["at"] = datetime.utcnow().isoformat()
        with self.jsonl_path.open("a") as fp:
            fp.write(f"{json.dumps(metrics_)}\n")

    def close(self):
        if not self.is_master:
            return

        if self.summary_writer is not None:
            self.summary_writer.close()
            self.summary_writer = None

        if self.is_wandb:
            import wandb

            # to be sure we are not hanging while finishing
            wandb.finish()

        if self.is_mlflow:
            import mlflow

            mlflow.end_run()

    def __del__(self):
        if self.summary_writer is not None:
            raise RuntimeError(
                "MetricsLogger not closed properly! You should "
                "make sure the close() method is called!"
            )
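A minimal sketch of how these pieces fit together on the master process (the run directory and step values are made up, and it is an assumption that WandbArgs/MLFlowArgs can be default-constructed with project and tracking_uri left as None):

# Sketch only: single-process usage with TensorBoard + JSONL output, no wandb/mlflow.
run_dir = Path("/tmp/run")                 # hypothetical output directory
run_dir.mkdir(parents=True, exist_ok=True)

metrics_logger = MetricsLogger(
    dst_dir=run_dir,
    tag="eval",
    is_master=True,
    wandb_args=WandbArgs(),                # assumption: defaults leave project=None
    mlflow_args=MLFlowArgs(),              # assumption: defaults leave tracking_uri=None
)

eval_logs = get_eval_logs(step=10, train_loss=1.234, perplexity=3.44, eval_loss=1.236)
logger.info(eval_log_msg(eval_logs))       # e.g. "step: 000010 - eval_perplexity: 3.440 - ..."
metrics_logger.log(eval_logs, step=10)     # TensorBoard scalars + metrics.eval.jsonl
metrics_logger.close()                     # required: __del__ raises if left open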
finetune/monitoring/utils.py
ADDED
@@ -0,0 +1,34 @@
import datetime
import logging
import sys
import time


class DeltaTimeFormatter(logging.Formatter):
    def format(self, record):
        delta = datetime.timedelta(
            seconds=int(record.relativeCreated / 1000)
        )  # no milliseconds
        record.delta = delta
        return super().format(record)


def set_logger(level: int = logging.INFO):
    root = logging.getLogger()
    root.handlers.clear()
    root.setLevel(level)
    tz, *_ = time.tzname

    LOGFORMAT = "%(asctime)s - %(delta)s - %(name)s - %(levelname)s - %(message)s"
    TIMEFORMAT = f"%Y-%m-%d %H:%M:%S ({tz})"
    formatter = DeltaTimeFormatter(LOGFORMAT, TIMEFORMAT)

    handler = logging.StreamHandler(sys.stdout)
    handler.setLevel(level)
    handler.setFormatter(formatter)
    root.addHandler(handler)

    handler = logging.StreamHandler(sys.stderr)
    handler.setLevel(logging.WARNING)
    handler.setFormatter(formatter)
    root.addHandler(handler)
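A small usage sketch (the logger name and message are arbitrary): after set_logger, every record carries both wall-clock time and the elapsed time since process start, and warnings are duplicated on stderr.

set_logger(logging.INFO)
log = logging.getLogger("train")           # arbitrary logger name
log.info("starting run")
# e.g. "2024-06-01 12:00:00 (UTC) - 0:00:00 - train - INFO - starting run"
log.warning("this also goes to stderr")    # WARNING and above hit the stderr handler too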
finetune/utils.py
ADDED
@@ -0,0 +1,83 @@
import contextlib
import dataclasses
import datetime
import logging
import time
from typing import Optional, Protocol

import torch

logger = logging.getLogger("utils")


@dataclasses.dataclass
class TrainState:
    max_steps: int
    step: int = 0
    elapsed_time: float = 0.0
    n_seen_tokens: int = 0
    this_step_time: float = 0.0
    begin_step_time: float = 0.0
    this_eval_perplexity: Optional[float] = None
    this_eval_loss: Optional[float] = None

    def start_step(self):
        self.step += 1
        self.begin_step_time = time.time()

    def end_step(self, n_batch_tokens: int):
        self.this_step_time = time.time() - self.begin_step_time
        self.this_step_tokens = n_batch_tokens

        self.elapsed_time += self.this_step_time
        self.n_seen_tokens += self.this_step_tokens

        self.begin_step_time = time.time()

    @property
    def wps(self):
        return self.this_step_tokens / self.this_step_time

    @property
    def avg_wps(self):
        return self.n_seen_tokens / self.elapsed_time

    @property
    def eta(self):
        steps_left = self.max_steps - self.step
        avg_time_per_step = self.elapsed_time / self.step

        return steps_left * avg_time_per_step


def set_random_seed(seed: int) -> None:
    """Set random seed for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


class Closable(Protocol):
    def close(self):
        pass


@contextlib.contextmanager
def logged_closing(thing: Closable, name: str):
    """
    Log the closing so we can be sure nothing is left hanging at exit time.
    """
    try:
        setattr(thing, "wrapped_by_closing", True)
        yield
    finally:
        logger.info(f"Closing: {name}")
        try:
            thing.close()
        except Exception:
            logger.error(f"Error while closing {name}!")
            raise
        logger.info(f"Closed: {name}")


def now_as_str() -> str:
    return datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
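A small sketch tying TrainState and logged_closing together, assuming the definitions above are in scope (the step count, sleep, and token number are made up):

# Sketch only: time two fake training steps, then make sure a resource is closed at exit.
state = TrainState(max_steps=2)
for _ in range(state.max_steps):
    state.start_step()
    time.sleep(0.01)                       # stand-in for forward/backward/optimizer step
    state.end_step(n_batch_tokens=4096)    # made-up token count
print(f"wps={state.wps:.0f} avg_wps={state.avg_wps:.0f} eta={state.eta:.2f}s")

class DummyWriter:
    def close(self):
        logger.info("writer closed")

with logged_closing(DummyWriter(), "dummy writer"):
    pass                                   # the training loop would run here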
finetune/wrapped_model.py
ADDED
@@ -0,0 +1,227 @@
import functools
import json
import logging
import math
from pathlib import Path
from typing import Callable, Union

import safetensors.torch
import torch
import torch.distributed.fsdp.wrap as torch_wrap
from torch.distributed.fsdp import BackwardPrefetch
from torch.distributed.fsdp.api import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

from model.args import ModelArgs, MoeArgs
from model.transformer import Transformer, TransformerBlock

from .args import LoraArgs
from .checkpointing import Checkpointer
from .distributed import (
    get_rank,
    get_world_size,
)

logger = logging.getLogger(__name__)


def main_logger_info(message: str) -> None:
    if get_rank() == 0:
        logger.info(message)


def get_fsdp_policy(is_lora: bool) -> Callable[[torch.nn.Module], bool]:
    """
    This function instantiates the FSDP wrap policy.
    - Each Transformer block becomes its own FSDP group so that only a single Transformer block is sharded at a time.
    - If LoRA is enabled, we additionally create separate FSDP sub-groups for every trainable and non-trainable parameter group,
      since this is a requirement for mixed requires_grad=True/False training. See: https://pytorch.org/docs/stable/fsdp.html
    """

    # Each transformer block becomes an FSDP group, each being sharded separately
    transformer_block_wrap_policy = functools.partial(
        torch_wrap.transformer_auto_wrap_policy,
        transformer_layer_cls=(TransformerBlock,),
    )

    if not is_lora:
        return transformer_block_wrap_policy

    def fsdp_lora_policy_fn(module):
        return all(p.requires_grad for p in module.parameters())

    # For LoRA training, trainable and non-trainable parameters need to be put into
    # different FSDP groups
    fsdp_lora_policy = functools.partial(
        torch_wrap.lambda_auto_wrap_policy, lambda_fn=fsdp_lora_policy_fn
    )

    policies = [fsdp_lora_policy, transformer_block_wrap_policy]

    return functools.partial(torch_wrap._or_policy, policies=policies)


def log_train_params(model: Union[torch.nn.Module, FullyShardedDataParallel]):
    world_size = get_world_size()

    num_params = world_size * sum(p.numel() for p in model.parameters())
    num_train_params = world_size * sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )

    main_logger_info(
        f"{num_train_params:,.0f} out of {num_params:,.0f} parameters are finetuned ({num_train_params / num_params * 100:.2f}%)."
    )


def initialize_lora_parameters(model: torch.nn.Module, param_dtype: torch.dtype):
    """
    Initialize LoRA layers with Kaiming uniform and zeros.
    See the original paper for more info: https://arxiv.org/abs/2106.09685 and the
    original github repo: https://github.com/microsoft/LoRA/blob/a0a92e0f26c067cf94747bdbf1ce73793fa44d19/loralib/layers.py#L122
    """
    for m_name, module in model.named_modules():
        if all(p.is_meta for p in module.parameters()):
            for p_name, param in module.named_parameters():
                module._parameters[p_name] = torch.nn.Parameter(
                    torch.empty_like(param, device="cpu", dtype=param_dtype)
                )
                param = module._parameters[p_name]

                if m_name.split(".")[-1] == "lora_A":
                    torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
                elif m_name.split(".")[-1] == "lora_B":
                    torch.nn.init.zeros_(param)
                else:
                    raise ValueError(
                        "Only LoRA layers should be randomly initialized."
                    )


def load_model(
    folder: Path,
    lora: LoraArgs,
    checkpoint: bool,
    param_dtype: torch.dtype,
) -> FullyShardedDataParallel:
    with open(folder / "params.json", "r") as f:
        args = json.loads(f.read())

    model_args = ModelArgs(
        lora=lora,
        dim=args["dim"],
        n_layers=args["n_layers"],
        head_dim=args["head_dim"],
        hidden_dim=args["hidden_dim"],
        n_heads=args["n_heads"],
        n_kv_heads=args["n_kv_heads"],
        norm_eps=args["norm_eps"],
        vocab_size=args["vocab_size"],
    )

    if model_args.vocab_size == 32000:
        raise ValueError(
            f"Fine-tuning is not supported for older model versions with vocab_size 32000. Make sure to extend your model to vocab_size=32768 using `python -m utils.extend_model_vocab --original_model_ckpt {folder} --extended_model_ckpt {folder}_extended`."
        )

    assert (
        model_args.vocab_size >= 32768
    ), "Make sure to use a model with a vocab size of at least 32768"

    if args.get("rope_theta") is not None:
        model_args.rope_theta = args["rope_theta"]

    if args.get("moe") is not None:
        model_args.moe = MoeArgs(**args["moe"])

    with torch.device("meta"):
        model = Transformer(args=model_args, checkpoint=checkpoint)

    if get_rank() == 0:
        state_dict = load_state_dict(folder, dtype=param_dtype)

        model.load_state_dict(state_dict, assign=True)  # type: ignore
        logger.info("Loaded model on cpu!")

        if lora.enable:
            logger.info("Initializing lora layers ...")
            # initialize LoRA layers
            initialize_lora_parameters(model, param_dtype)

        assert not any(
            p.is_meta for p in model.parameters()
        ), "All parameters should be initialized by now"
        assert all(
            p.dtype == param_dtype for p in model.parameters()
        ), f"All parameters should be on {param_dtype}"

        logger.info("Finished initialization!")
        param_init_fn = None
    else:

        def param_init_fn(m):
            m.to_empty(device=torch.cuda.current_device(), recurse=False)
            m.to(param_dtype)

        assert all(
            p.is_meta for p in model.parameters()
        ), "All parameters should be on meta"

    torch.distributed.barrier()

    # only finetune LoRA parameters and freeze before wrapping
    if lora.enable:
        for name, param in model.named_parameters():
            if "lora" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    auto_wrap_policy = get_fsdp_policy(model_args.lora.enable)

    main_logger_info(f"Sharding model over {get_world_size()} GPUs ...")

    wrapped_model = FullyShardedDataParallel(
        model,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        auto_wrap_policy=auto_wrap_policy,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        limit_all_gathers=True,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
        param_init_fn=param_init_fn,
    )
    main_logger_info("Model sharded!")

    log_train_params(wrapped_model)

    return wrapped_model


@torch.no_grad()
def load_state_dict(path: Path, dtype: torch.dtype):
    assert path.is_dir(), path

    this_safetensors_path = Checkpointer.consolidated_path(path, use_safetensors=True)
    this_torch_path = Checkpointer.consolidated_path(path, use_safetensors=False)

    assert (
        this_safetensors_path.exists() or this_torch_path.exists()
    ), f"Either {this_safetensors_path} or {this_torch_path} must exist."
    assert not (
        this_safetensors_path.exists() and this_torch_path.exists()
    ), f"Only one of {this_safetensors_path} or {this_torch_path} should exist."

    if this_safetensors_path.exists():
        logger.info(f"Reloading model from {this_safetensors_path} ...")
        model_state_dict = safetensors.torch.load_file(this_safetensors_path)
    else:
        logger.info(f"Reloading model from {this_torch_path} ...")
        model_state_dict = torch.load(this_torch_path)

    logger.info(f"Converting model to dtype {dtype} ...")

    for k, v in model_state_dict.items():
        model_state_dict[k] = v.to(dtype)

    return model_state_dict
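A hedged sketch of how load_model would typically be invoked from a training entry point. The checkpoint folder, dtype, and LoRA settings below are assumptions; LoraArgs is only known here to expose an enable flag, and torch.distributed must already be initialized (e.g. via torchrun) with a visible GPU.

# Sketch only: assumes the functions above are in scope and NCCL/torchrun set up the process group.
import torch.distributed as dist

if not dist.is_initialized():
    dist.init_process_group(backend="nccl")
torch.cuda.set_device(get_rank() % torch.cuda.device_count())

wrapped_model = load_model(
    folder=Path("/checkpoints/7B"),    # hypothetical consolidated checkpoint directory
    lora=LoraArgs(enable=True),        # assumption: remaining LoraArgs fields keep their defaults
    checkpoint=True,                   # passed through to Transformer (presumably activation checkpointing)
    param_dtype=torch.bfloat16,
)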
huggingface.ipynb
ADDED
@@ -0,0 +1,40 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import shutil\n",
    "\n",
    "from huggingface_hub import HfApi, HfFolder, Repository\n",
    "\n",
    "repo_id = \"your_username/your_model_name\"\n",
    "repo_local_path = \"./path_to_your_model\"\n",
    "\n",
    "# Create the repository object and clone the repo\n",
    "repo = Repository(local_dir=repo_local_path, clone_from=repo_id)\n",
    "\n",
    "# Copy your model files to the repository\n",
    "model_files = [\"config.json\", \"pytorch_model.bin\", \"tokenizer_config.json\", \"vocab.json\"]\n",
    "for file in model_files:\n",
    "    shutil.copy(file, repo_local_path)\n",
    "\n",
    "# Push the model files to the repository\n",
    "repo.push_to_hub(commit_message=\"Initial model upload\")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chemistralpy310",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
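The notebook uses the git-based Repository workflow; for reference, a sketch of the same upload through HfApi.upload_folder, assuming a recent huggingface_hub and that you are already authenticated (the repo id and local path are placeholders):

from huggingface_hub import HfApi

api = HfApi()
api.upload_folder(
    folder_path="./path_to_your_model",        # placeholder local directory with the model files
    repo_id="your_username/your_model_name",   # placeholder repo id
    commit_message="Initial model upload",
)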