Commit
·
d1017b4
0
Parent(s):
Duplicate from jbochi/madlad400-8b-lm
Browse filesCo-authored-by: J Bochi <[email protected]>
- .gitattributes +36 -0
- README.md +462 -0
- added_tokens.json +2 -0
- config.json +39 -0
- decoderonlyt5_config.py +11 -0
- decoderonlyt5_modeling.py +840 -0
- model-00000-of-00007.safetensors +3 -0
- model-00001-of-00007.safetensors +3 -0
- model-00002-of-00007.safetensors +3 -0
- model-00003-of-00007.safetensors +3 -0
- model-00004-of-00007.safetensors +3 -0
- model-00005-of-00007.safetensors +3 -0
- model-00006-of-00007.safetensors +3 -0
- model-00007-of-00007.safetensors +3 -0
- model.safetensors.index.json +262 -0
- special_tokens_map.json +23 -0
- spiece.model +3 -0
- tokenizer.json +3 -0
- tokenizer_config.json +38 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
README.md
ADDED
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
license: apache-2.0
|
3 |
+
language:
|
4 |
+
- en
|
5 |
+
- ru
|
6 |
+
- es
|
7 |
+
- fr
|
8 |
+
- de
|
9 |
+
- it
|
10 |
+
- pt
|
11 |
+
- pl
|
12 |
+
- nl
|
13 |
+
- vi
|
14 |
+
- tr
|
15 |
+
- sv
|
16 |
+
- id
|
17 |
+
- ro
|
18 |
+
- cs
|
19 |
+
- zh
|
20 |
+
- hu
|
21 |
+
- ja
|
22 |
+
- th
|
23 |
+
- fi
|
24 |
+
- fa
|
25 |
+
- uk
|
26 |
+
- da
|
27 |
+
- el
|
28 |
+
- "no"
|
29 |
+
- bg
|
30 |
+
- sk
|
31 |
+
- ko
|
32 |
+
- ar
|
33 |
+
- lt
|
34 |
+
- ca
|
35 |
+
- sl
|
36 |
+
- he
|
37 |
+
- et
|
38 |
+
- lv
|
39 |
+
- hi
|
40 |
+
- sq
|
41 |
+
- ms
|
42 |
+
- az
|
43 |
+
- sr
|
44 |
+
- ta
|
45 |
+
- hr
|
46 |
+
- kk
|
47 |
+
- is
|
48 |
+
- ml
|
49 |
+
- mr
|
50 |
+
- te
|
51 |
+
- af
|
52 |
+
- gl
|
53 |
+
- fil
|
54 |
+
- be
|
55 |
+
- mk
|
56 |
+
- eu
|
57 |
+
- bn
|
58 |
+
- ka
|
59 |
+
- mn
|
60 |
+
- bs
|
61 |
+
- uz
|
62 |
+
- ur
|
63 |
+
- sw
|
64 |
+
- yue
|
65 |
+
- ne
|
66 |
+
- kn
|
67 |
+
- kaa
|
68 |
+
- gu
|
69 |
+
- si
|
70 |
+
- cy
|
71 |
+
- eo
|
72 |
+
- la
|
73 |
+
- hy
|
74 |
+
- ky
|
75 |
+
- tg
|
76 |
+
- ga
|
77 |
+
- mt
|
78 |
+
- my
|
79 |
+
- km
|
80 |
+
- tt
|
81 |
+
- so
|
82 |
+
- ku
|
83 |
+
- ps
|
84 |
+
- pa
|
85 |
+
- rw
|
86 |
+
- lo
|
87 |
+
- ha
|
88 |
+
- dv
|
89 |
+
- fy
|
90 |
+
- lb
|
91 |
+
- ckb
|
92 |
+
- mg
|
93 |
+
- gd
|
94 |
+
- am
|
95 |
+
- ug
|
96 |
+
- ht
|
97 |
+
- grc
|
98 |
+
- hmn
|
99 |
+
- sd
|
100 |
+
- jv
|
101 |
+
- mi
|
102 |
+
- tk
|
103 |
+
- ceb
|
104 |
+
- yi
|
105 |
+
- ba
|
106 |
+
- fo
|
107 |
+
- or
|
108 |
+
- xh
|
109 |
+
- su
|
110 |
+
- kl
|
111 |
+
- ny
|
112 |
+
- sm
|
113 |
+
- sn
|
114 |
+
- co
|
115 |
+
- zu
|
116 |
+
- ig
|
117 |
+
- yo
|
118 |
+
- pap
|
119 |
+
- st
|
120 |
+
- haw
|
121 |
+
- as
|
122 |
+
- oc
|
123 |
+
- cv
|
124 |
+
- lus
|
125 |
+
- tet
|
126 |
+
- gsw
|
127 |
+
- sah
|
128 |
+
- br
|
129 |
+
- rm
|
130 |
+
- sa
|
131 |
+
- bo
|
132 |
+
- om
|
133 |
+
- se
|
134 |
+
- ce
|
135 |
+
- cnh
|
136 |
+
- ilo
|
137 |
+
- hil
|
138 |
+
- udm
|
139 |
+
- os
|
140 |
+
- lg
|
141 |
+
- ti
|
142 |
+
- vec
|
143 |
+
- ts
|
144 |
+
- tyv
|
145 |
+
- kbd
|
146 |
+
- ee
|
147 |
+
- iba
|
148 |
+
- av
|
149 |
+
- kha
|
150 |
+
- to
|
151 |
+
- tn
|
152 |
+
- nso
|
153 |
+
- fj
|
154 |
+
- zza
|
155 |
+
- ak
|
156 |
+
- ada
|
157 |
+
- otq
|
158 |
+
- dz
|
159 |
+
- bua
|
160 |
+
- cfm
|
161 |
+
- ln
|
162 |
+
- chm
|
163 |
+
- gn
|
164 |
+
- krc
|
165 |
+
- wa
|
166 |
+
- hif
|
167 |
+
- yua
|
168 |
+
- srn
|
169 |
+
- war
|
170 |
+
- rom
|
171 |
+
- bik
|
172 |
+
- pam
|
173 |
+
- sg
|
174 |
+
- lu
|
175 |
+
- ady
|
176 |
+
- kbp
|
177 |
+
- syr
|
178 |
+
- ltg
|
179 |
+
- myv
|
180 |
+
- iso
|
181 |
+
- kac
|
182 |
+
- bho
|
183 |
+
- ay
|
184 |
+
- kum
|
185 |
+
- qu
|
186 |
+
- za
|
187 |
+
- pag
|
188 |
+
- ngu
|
189 |
+
- ve
|
190 |
+
- pck
|
191 |
+
- zap
|
192 |
+
- tyz
|
193 |
+
- hui
|
194 |
+
- bbc
|
195 |
+
- tzo
|
196 |
+
- tiv
|
197 |
+
- ksd
|
198 |
+
- gom
|
199 |
+
- min
|
200 |
+
- ang
|
201 |
+
- nhe
|
202 |
+
- bgp
|
203 |
+
- nzi
|
204 |
+
- nnb
|
205 |
+
- nv
|
206 |
+
- zxx
|
207 |
+
- bci
|
208 |
+
- kv
|
209 |
+
- new
|
210 |
+
- mps
|
211 |
+
- alt
|
212 |
+
- meu
|
213 |
+
- bew
|
214 |
+
- fon
|
215 |
+
- iu
|
216 |
+
- abt
|
217 |
+
- mgh
|
218 |
+
- mnw
|
219 |
+
- tvl
|
220 |
+
- dov
|
221 |
+
- tlh
|
222 |
+
- ho
|
223 |
+
- kw
|
224 |
+
- mrj
|
225 |
+
- meo
|
226 |
+
- crh
|
227 |
+
- mbt
|
228 |
+
- emp
|
229 |
+
- ace
|
230 |
+
- ium
|
231 |
+
- mam
|
232 |
+
- gym
|
233 |
+
- mai
|
234 |
+
- crs
|
235 |
+
- pon
|
236 |
+
- ubu
|
237 |
+
- fip
|
238 |
+
- quc
|
239 |
+
- gv
|
240 |
+
- kj
|
241 |
+
- btx
|
242 |
+
- ape
|
243 |
+
- chk
|
244 |
+
- rcf
|
245 |
+
- shn
|
246 |
+
- tzh
|
247 |
+
- mdf
|
248 |
+
- ppk
|
249 |
+
- ss
|
250 |
+
- gag
|
251 |
+
- cab
|
252 |
+
- kri
|
253 |
+
- seh
|
254 |
+
- ibb
|
255 |
+
- tbz
|
256 |
+
- bru
|
257 |
+
- enq
|
258 |
+
- ach
|
259 |
+
- cuk
|
260 |
+
- kmb
|
261 |
+
- wo
|
262 |
+
- kek
|
263 |
+
- qub
|
264 |
+
- tab
|
265 |
+
- bts
|
266 |
+
- kos
|
267 |
+
- rwo
|
268 |
+
- cak
|
269 |
+
- tuc
|
270 |
+
- bum
|
271 |
+
- cjk
|
272 |
+
- gil
|
273 |
+
- stq
|
274 |
+
- tsg
|
275 |
+
- quh
|
276 |
+
- mak
|
277 |
+
- arn
|
278 |
+
- ban
|
279 |
+
- jiv
|
280 |
+
- sja
|
281 |
+
- yap
|
282 |
+
- tcy
|
283 |
+
- toj
|
284 |
+
- twu
|
285 |
+
- xal
|
286 |
+
- amu
|
287 |
+
- rmc
|
288 |
+
- hus
|
289 |
+
- nia
|
290 |
+
- kjh
|
291 |
+
- bm
|
292 |
+
- guh
|
293 |
+
- mas
|
294 |
+
- acf
|
295 |
+
- dtp
|
296 |
+
- ksw
|
297 |
+
- bzj
|
298 |
+
- din
|
299 |
+
- zne
|
300 |
+
- mad
|
301 |
+
- msi
|
302 |
+
- mag
|
303 |
+
- mkn
|
304 |
+
- kg
|
305 |
+
- lhu
|
306 |
+
- ch
|
307 |
+
- qvi
|
308 |
+
- mh
|
309 |
+
- djk
|
310 |
+
- sus
|
311 |
+
- mfe
|
312 |
+
- srm
|
313 |
+
- dyu
|
314 |
+
- ctu
|
315 |
+
- gui
|
316 |
+
- pau
|
317 |
+
- inb
|
318 |
+
- bi
|
319 |
+
- mni
|
320 |
+
- guc
|
321 |
+
- jam
|
322 |
+
- wal
|
323 |
+
- jac
|
324 |
+
- bas
|
325 |
+
- gor
|
326 |
+
- skr
|
327 |
+
- nyu
|
328 |
+
- noa
|
329 |
+
- sda
|
330 |
+
- gub
|
331 |
+
- nog
|
332 |
+
- cni
|
333 |
+
- teo
|
334 |
+
- tdx
|
335 |
+
- sxn
|
336 |
+
- rki
|
337 |
+
- nr
|
338 |
+
- frp
|
339 |
+
- alz
|
340 |
+
- taj
|
341 |
+
- lrc
|
342 |
+
- cce
|
343 |
+
- rn
|
344 |
+
- jvn
|
345 |
+
- hvn
|
346 |
+
- nij
|
347 |
+
- dwr
|
348 |
+
- izz
|
349 |
+
- msm
|
350 |
+
- bus
|
351 |
+
- ktu
|
352 |
+
- chr
|
353 |
+
- maz
|
354 |
+
- tzj
|
355 |
+
- suz
|
356 |
+
- knj
|
357 |
+
- bim
|
358 |
+
- gvl
|
359 |
+
- bqc
|
360 |
+
- tca
|
361 |
+
- pis
|
362 |
+
- prk
|
363 |
+
- laj
|
364 |
+
- mel
|
365 |
+
- qxr
|
366 |
+
- niq
|
367 |
+
- ahk
|
368 |
+
- shp
|
369 |
+
- hne
|
370 |
+
- spp
|
371 |
+
- koi
|
372 |
+
- krj
|
373 |
+
- quf
|
374 |
+
- luz
|
375 |
+
- agr
|
376 |
+
- tsc
|
377 |
+
- mqy
|
378 |
+
- gof
|
379 |
+
- gbm
|
380 |
+
- miq
|
381 |
+
- dje
|
382 |
+
- awa
|
383 |
+
- bjj
|
384 |
+
- qvz
|
385 |
+
- sjp
|
386 |
+
- tll
|
387 |
+
- raj
|
388 |
+
- kjg
|
389 |
+
- bgz
|
390 |
+
- quy
|
391 |
+
- cbk
|
392 |
+
- akb
|
393 |
+
- oj
|
394 |
+
- ify
|
395 |
+
- mey
|
396 |
+
- ks
|
397 |
+
- cac
|
398 |
+
- brx
|
399 |
+
- qup
|
400 |
+
- syl
|
401 |
+
- jax
|
402 |
+
- ff
|
403 |
+
- ber
|
404 |
+
- tks
|
405 |
+
- trp
|
406 |
+
- mrw
|
407 |
+
- adh
|
408 |
+
- smt
|
409 |
+
- srr
|
410 |
+
- ffm
|
411 |
+
- qvc
|
412 |
+
- mtr
|
413 |
+
- ann
|
414 |
+
- kaa
|
415 |
+
- aa
|
416 |
+
- noe
|
417 |
+
- nut
|
418 |
+
- gyn
|
419 |
+
- kwi
|
420 |
+
- xmm
|
421 |
+
- msb
|
422 |
+
library_name: transformers
|
423 |
+
tags:
|
424 |
+
- text-generation-inference
|
425 |
+
datasets:
|
426 |
+
- allenai/MADLAD-400
|
427 |
+
---
|
428 |
+
|
429 |
+
This model has the safetensors weights for the [Madlad-400](https://github.com/google-research/google-research/tree/master/madlad_400) 8B param **language model**.
|
430 |
+
|
431 |
+
The HF transformers code to run inference is not ready yet. The [original implementation](https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L1484) is in JAX/Flaxformer.
|
432 |
+
|
433 |
+
The model architecture is the same as [Palm 8B](https://arxiv.org/pdf/2204.02311.pdf).
|
434 |
+
|
435 |
+
It's a decoder-only T5 with 32 layers, 16 query heads, 1 KV head, and 4096 embedding size.
|
436 |
+
|
437 |
+
These are the main differences relative to the original T5 architecture:
|
438 |
+
|
439 |
+
- SwiGLU Activation
|
440 |
+
- Parallel Layers
|
441 |
+
- Multi-Query Attention
|
442 |
+
- RoPE Embeddings
|
443 |
+
- Shared Input-Output Embeddings
|
444 |
+
- No biases
|
445 |
+
- Bidirectional attention
|
446 |
+
- Layer Norm with `center_scale_at_zero` and final layer with `use_scale=False`
|
447 |
+
|
448 |
+
If you are looking for the language models models, here are the available versions:
|
449 |
+
- [3B](https://huggingface.co/jbochi/madlad400-3b-mt)
|
450 |
+
- [7B](https://huggingface.co/jbochi/madlad400-7b-mt)
|
451 |
+
- [7B-BT](https://huggingface.co/jbochi/madlad400-7b-mt-bt)
|
452 |
+
- [10B](https://huggingface.co/jbochi/madlad400-10b-mt)
|
453 |
+
|
454 |
+
|
455 |
+
Article: [MADLAD-400: A Multilingual And Document-Level Large Audited Dataset](https://arxiv.org/abs/2309.04662)
|
456 |
+
|
457 |
+
Abstract:
|
458 |
+
|
459 |
+
> We introduce MADLAD-400, a manually audited, general domain 3T token monolingual dataset based on CommonCrawl, spanning 419 languages. We discuss the limitations revealed by self-auditing MADLAD-400, and the role data auditing had in the dataset creation process. We then train and release a 10.7B-parameter multilingual machine translation model on 250 billion tokens covering over 450 languages using publicly available data, and find that it is competitive with models that are significantly larger, and report the results on different domains. In addition, we train a 8B-parameter language model, and assess the results on few-shot translation. We make the baseline models available to the research community.
|
460 |
+
|
461 |
+
|
462 |
+
|
added_tokens.json
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
}
|
config.json
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"DecoderOnlyT5Model"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "decoderonlyt5_config.DecoderOnlyT5Config",
|
7 |
+
"AutoModelForCausalLM": "decoderonlyt5_modeling.DecoderOnlyT5Model"
|
8 |
+
},
|
9 |
+
"d_ff": 16384,
|
10 |
+
"d_kv": 256,
|
11 |
+
"d_model": 4096,
|
12 |
+
"dropout_rate": 0.0,
|
13 |
+
"decoder_start_token_id": 0,
|
14 |
+
"pad_token_id": 1,
|
15 |
+
"eos_token_id": 3,
|
16 |
+
"feed_forward_proj": "gated-swish",
|
17 |
+
"initializer_factor": 1.0,
|
18 |
+
"is_encoder_decoder": false,
|
19 |
+
"is_decoder_only": true,
|
20 |
+
"layer_norm_epsilon": 1e-06,
|
21 |
+
"model_type": "t5",
|
22 |
+
"n_positions": 512,
|
23 |
+
"num_layers": 0,
|
24 |
+
"num_decoder_layers": 32,
|
25 |
+
"num_heads": 16,
|
26 |
+
"output_past": true,
|
27 |
+
"relative_attention_max_distance": 128,
|
28 |
+
"relative_attention_num_buckets": 32,
|
29 |
+
"task_specific_params": {},
|
30 |
+
"tie_word_embeddings": true,
|
31 |
+
"transformers_version": "4.23.1",
|
32 |
+
"use_cache": true,
|
33 |
+
"vocab_size": 256512,
|
34 |
+
"parallel_layers": true,
|
35 |
+
"has_relative_attention_bias": false,
|
36 |
+
"multi_query_attention": true,
|
37 |
+
"use_rotary_embedding": true,
|
38 |
+
"rotary_embedding_max_timescale": 1000
|
39 |
+
}
|
decoderonlyt5_config.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.models.t5.configuration_t5 import T5Config
|
2 |
+
|
3 |
+
|
4 |
+
class DecoderOnlyT5Config(T5Config):
|
5 |
+
is_decoder_only = True
|
6 |
+
# whether to call attention and mlp in parallel.
|
7 |
+
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L384
|
8 |
+
parallel_layers = True
|
9 |
+
has_relative_attention_bias = False
|
10 |
+
# https://arxiv.org/abs/1911.02150
|
11 |
+
multi_query_attention = True
|
decoderonlyt5_modeling.py
ADDED
@@ -0,0 +1,840 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import math
|
3 |
+
from typing import Optional, Tuple, Union
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from torch import nn
|
7 |
+
from torch.nn import CrossEntropyLoss
|
8 |
+
from transformers.models.t5 import modeling_t5
|
9 |
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
10 |
+
from transformers.utils import (
|
11 |
+
add_start_docstrings_to_model_forward,
|
12 |
+
logging,
|
13 |
+
replace_return_docstrings,
|
14 |
+
)
|
15 |
+
|
16 |
+
from .decoderonlyt5_config import DecoderOnlyT5Config
|
17 |
+
|
18 |
+
|
19 |
+
logger = logging.get_logger(__name__)
|
20 |
+
_CONFIG_FOR_DOC = "DecoderOnlyT5Config"
|
21 |
+
|
22 |
+
|
23 |
+
class DecoderOnlyT5LayerNorm(nn.Module):
|
24 |
+
def __init__(self, hidden_size, eps=1e-6, use_scale=True, center_scale_at_zero=False):
|
25 |
+
"""
|
26 |
+
Construct a layernorm module in the T5 style No bias and no subtraction of mean.
|
27 |
+
"""
|
28 |
+
super().__init__()
|
29 |
+
if use_scale:
|
30 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
31 |
+
else:
|
32 |
+
assert not center_scale_at_zero
|
33 |
+
self.weight = None
|
34 |
+
self.center_scale_at_zero = center_scale_at_zero
|
35 |
+
self.variance_epsilon = eps
|
36 |
+
|
37 |
+
def forward(self, hidden_states):
|
38 |
+
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/layer_norm.py#L30
|
39 |
+
|
40 |
+
# layer norm should always be calculated in float32
|
41 |
+
mean2 = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
42 |
+
hidden_states = hidden_states * torch.rsqrt(mean2 + self.variance_epsilon)
|
43 |
+
|
44 |
+
# convert into float16 if necessary
|
45 |
+
if self.weight is None:
|
46 |
+
return hidden_states
|
47 |
+
if self.weight.dtype == torch.float16:
|
48 |
+
hidden_states = hidden_states.to(torch.float16)
|
49 |
+
if self.center_scale_at_zero:
|
50 |
+
return (self.weight + 1.0) * hidden_states
|
51 |
+
else:
|
52 |
+
return self.weight * hidden_states
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
class DecoderOnlyT5LayerFF(modeling_t5.T5LayerFF):
|
57 |
+
def __init__(self, config: DecoderOnlyT5Config):
|
58 |
+
super(modeling_t5.T5LayerFF, self).__init__()
|
59 |
+
if config.is_gated_act:
|
60 |
+
self.DenseReluDense = modeling_t5.T5DenseGatedActDense(config)
|
61 |
+
else:
|
62 |
+
self.DenseReluDense = modeling_t5.T5DenseActDense(config)
|
63 |
+
|
64 |
+
if not config.parallel_layers:
|
65 |
+
self.layer_norm = modeling_t5.DecoderOnlyT5LayerNorm(
|
66 |
+
config.d_model, eps=config.layer_norm_epsilon
|
67 |
+
)
|
68 |
+
else:
|
69 |
+
self.layer_norm = nn.Identity()
|
70 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
71 |
+
|
72 |
+
|
73 |
+
# LlamaRotaryEmbedding
|
74 |
+
class DecoderOnlyT5RotaryEmbedding(nn.Module):
|
75 |
+
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
|
76 |
+
super().__init__()
|
77 |
+
|
78 |
+
self.dim = dim
|
79 |
+
self.max_position_embeddings = max_position_embeddings
|
80 |
+
self.base = base
|
81 |
+
inv_freq = 1.0 / (
|
82 |
+
self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)
|
83 |
+
)
|
84 |
+
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
85 |
+
|
86 |
+
# Build here to make `torch.jit.trace` work.
|
87 |
+
self._set_cos_sin_cache(
|
88 |
+
seq_len=max_position_embeddings,
|
89 |
+
device=self.inv_freq.device,
|
90 |
+
dtype=torch.get_default_dtype(),
|
91 |
+
)
|
92 |
+
|
93 |
+
def _set_cos_sin_cache(self, seq_len, device, dtype):
|
94 |
+
self.max_seq_len_cached = seq_len
|
95 |
+
t = torch.arange(
|
96 |
+
self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype
|
97 |
+
)
|
98 |
+
|
99 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
100 |
+
# Different from paper, but it uses a different permutation in order to obtain the same calculation
|
101 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
102 |
+
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
|
103 |
+
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
|
104 |
+
|
105 |
+
def forward(self, x, seq_len=None):
|
106 |
+
# x: [bs, num_attention_heads, seq_len, head_size]
|
107 |
+
if seq_len > self.max_seq_len_cached:
|
108 |
+
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
|
109 |
+
|
110 |
+
return (
|
111 |
+
self.cos_cached[:seq_len].to(dtype=x.dtype),
|
112 |
+
self.sin_cached[:seq_len].to(dtype=x.dtype),
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
def rotate_half(x):
|
117 |
+
"""Rotates half the hidden dims of the input."""
|
118 |
+
x1 = x[..., : x.shape[-1] // 2]
|
119 |
+
x2 = x[..., x.shape[-1] // 2 :]
|
120 |
+
return torch.cat((-x2, x1), dim=-1)
|
121 |
+
|
122 |
+
|
123 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
|
124 |
+
"""Applies Rotary Position Embedding to the query and key tensors.
|
125 |
+
|
126 |
+
Args:
|
127 |
+
q (`torch.Tensor`): The query tensor.
|
128 |
+
k (`torch.Tensor`): The key tensor.
|
129 |
+
cos (`torch.Tensor`): The cosine part of the rotary embedding.
|
130 |
+
sin (`torch.Tensor`): The sine part of the rotary embedding.
|
131 |
+
position_ids (`torch.Tensor`):
|
132 |
+
The position indices of the tokens corresponding to the query and key tensors. For example, this can be
|
133 |
+
used to pass offsetted position ids when working with a KV-cache.
|
134 |
+
unsqueeze_dim (`int`, *optional*, defaults to 1):
|
135 |
+
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
|
136 |
+
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
|
137 |
+
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
|
138 |
+
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
|
139 |
+
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
|
140 |
+
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
|
141 |
+
Returns:
|
142 |
+
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
|
143 |
+
"""
|
144 |
+
cos = cos[position_ids].unsqueeze(unsqueeze_dim)
|
145 |
+
sin = sin[position_ids].unsqueeze(unsqueeze_dim)
|
146 |
+
q_embed = (q * cos) + (rotate_half(q) * sin)
|
147 |
+
k_embed = (k * cos) + (rotate_half(k) * sin)
|
148 |
+
return q_embed, k_embed
|
149 |
+
|
150 |
+
|
151 |
+
# https://github.com/huggingface/transformers/blob/7ee995fd9c692761c4601ddbffa2ac2ec9f27b0b/src/transformers/models/llama/modeling_llama.py#L263
|
152 |
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
153 |
+
"""
|
154 |
+
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
155 |
+
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
156 |
+
"""
|
157 |
+
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
158 |
+
if n_rep == 1:
|
159 |
+
return hidden_states
|
160 |
+
hidden_states = hidden_states[:, :, None, :, :].expand(
|
161 |
+
batch, num_key_value_heads, n_rep, slen, head_dim
|
162 |
+
)
|
163 |
+
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
164 |
+
|
165 |
+
|
166 |
+
class DecoderOnlyT5Attention(modeling_t5.T5Attention):
|
167 |
+
"""
|
168 |
+
Supports both multi-head and multi-query attention.
|
169 |
+
https://arxiv.org/abs/1911.02150
|
170 |
+
https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/components/attention/dense_attention.py#L292
|
171 |
+
"""
|
172 |
+
|
173 |
+
def __init__(self, config: DecoderOnlyT5Config, has_relative_attention_bias=False):
|
174 |
+
super(modeling_t5.T5Attention, self).__init__()
|
175 |
+
self.is_decoder = config.is_decoder
|
176 |
+
assert not has_relative_attention_bias
|
177 |
+
assert config.use_rotary_embedding
|
178 |
+
self.d_model = config.d_model
|
179 |
+
self.head_dim = config.d_kv
|
180 |
+
self.num_heads = config.num_heads
|
181 |
+
self.num_key_value_heads = 1 if config.multi_query_attention else self.n_heads
|
182 |
+
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
|
183 |
+
self.attention_dropout = config.dropout_rate
|
184 |
+
self.inner_dim = self.num_heads * self.head_dim
|
185 |
+
self.kv_inner_dim = self.num_key_value_heads * self.head_dim
|
186 |
+
self.rotary_emb = DecoderOnlyT5RotaryEmbedding(
|
187 |
+
self.head_dim,
|
188 |
+
max_position_embeddings=config.relative_attention_max_distance,
|
189 |
+
base=config.rotary_embedding_max_timescale,
|
190 |
+
)
|
191 |
+
|
192 |
+
# Mesh TensorFlow initialization to avoid scaling before softmax
|
193 |
+
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
|
194 |
+
self.k = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
|
195 |
+
self.v = nn.Linear(self.d_model, self.kv_inner_dim, bias=False)
|
196 |
+
self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
|
197 |
+
|
198 |
+
self.pruned_heads = set()
|
199 |
+
self.gradient_checkpointing = False
|
200 |
+
|
201 |
+
def forward(
|
202 |
+
self,
|
203 |
+
hidden_states: torch.Tensor,
|
204 |
+
key_value_states=None,
|
205 |
+
position_bias=None,
|
206 |
+
mask: Optional[torch.Tensor] = None,
|
207 |
+
layer_head_mask=None,
|
208 |
+
position_ids: Optional[torch.LongTensor] = None,
|
209 |
+
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
210 |
+
output_attentions: bool = False,
|
211 |
+
use_cache: bool = False,
|
212 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
213 |
+
assert key_value_states is None
|
214 |
+
assert position_bias is None
|
215 |
+
assert layer_head_mask is None
|
216 |
+
|
217 |
+
bsz, q_len, _ = hidden_states.size()
|
218 |
+
|
219 |
+
query_states = self.q(hidden_states)
|
220 |
+
key_states = self.k(hidden_states)
|
221 |
+
value_states = self.v(hidden_states)
|
222 |
+
|
223 |
+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
|
224 |
+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
225 |
+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
|
226 |
+
|
227 |
+
kv_seq_len = key_states.shape[-2]
|
228 |
+
if past_key_value is not None:
|
229 |
+
kv_seq_len += past_key_value[0].shape[-2]
|
230 |
+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
231 |
+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
|
232 |
+
|
233 |
+
if past_key_value is not None:
|
234 |
+
# reuse k, v, self_attention
|
235 |
+
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
236 |
+
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
237 |
+
|
238 |
+
past_key_value = (key_states, value_states) if use_cache else None
|
239 |
+
|
240 |
+
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
241 |
+
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
242 |
+
|
243 |
+
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
|
244 |
+
|
245 |
+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
246 |
+
raise ValueError(
|
247 |
+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
248 |
+
f" {attn_weights.size()}"
|
249 |
+
)
|
250 |
+
|
251 |
+
if mask is not None:
|
252 |
+
if mask.size() != (bsz, 1, q_len, kv_seq_len):
|
253 |
+
raise ValueError(
|
254 |
+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {mask.size()}"
|
255 |
+
)
|
256 |
+
attn_weights = attn_weights + mask
|
257 |
+
|
258 |
+
# upcast attention to fp32
|
259 |
+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
|
260 |
+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout)
|
261 |
+
attn_output = torch.matmul(attn_weights, value_states)
|
262 |
+
|
263 |
+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
264 |
+
raise ValueError(
|
265 |
+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
266 |
+
f" {attn_output.size()}"
|
267 |
+
)
|
268 |
+
|
269 |
+
attn_output = attn_output.transpose(1, 2).contiguous()
|
270 |
+
attn_output = attn_output.reshape(bsz, q_len, self.inner_dim)
|
271 |
+
attn_output = self.o(attn_output)
|
272 |
+
|
273 |
+
present_key_value_state = (
|
274 |
+
(key_states, value_states) if (self.is_decoder and use_cache) else None
|
275 |
+
)
|
276 |
+
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
|
277 |
+
|
278 |
+
if output_attentions:
|
279 |
+
outputs = outputs + (attn_weights,)
|
280 |
+
return outputs
|
281 |
+
|
282 |
+
|
283 |
+
class DecoderOnlyT5LayerSelfAttention(modeling_t5.T5LayerSelfAttention):
|
284 |
+
def __init__(self, config, has_relative_attention_bias=False):
|
285 |
+
super(modeling_t5.T5LayerSelfAttention, self).__init__()
|
286 |
+
self.SelfAttention = DecoderOnlyT5Attention(
|
287 |
+
config, has_relative_attention_bias=has_relative_attention_bias
|
288 |
+
)
|
289 |
+
self.layer_norm = DecoderOnlyT5LayerNorm(
|
290 |
+
config.d_model,
|
291 |
+
eps=config.layer_norm_epsilon,
|
292 |
+
use_scale=True,
|
293 |
+
center_scale_at_zero=True,
|
294 |
+
)
|
295 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
296 |
+
self.parallel_layers = config.parallel_layers
|
297 |
+
|
298 |
+
def forward(
|
299 |
+
self,
|
300 |
+
hidden_states,
|
301 |
+
attention_mask=None,
|
302 |
+
position_bias=None,
|
303 |
+
position_ids=None,
|
304 |
+
layer_head_mask=None,
|
305 |
+
past_key_value=None,
|
306 |
+
use_cache=False,
|
307 |
+
output_attentions=False,
|
308 |
+
):
|
309 |
+
if not self.parallel_layers:
|
310 |
+
x = self.layer_norm(hidden_states)
|
311 |
+
else:
|
312 |
+
x = hidden_states
|
313 |
+
attention_output = self.SelfAttention(
|
314 |
+
x,
|
315 |
+
mask=attention_mask,
|
316 |
+
position_bias=position_bias,
|
317 |
+
position_ids=position_ids,
|
318 |
+
layer_head_mask=layer_head_mask,
|
319 |
+
past_key_value=past_key_value,
|
320 |
+
use_cache=use_cache,
|
321 |
+
output_attentions=output_attentions,
|
322 |
+
)
|
323 |
+
if not self.parallel_layers:
|
324 |
+
# When parallel_layers is True, the residual connection is applied
|
325 |
+
# in the decoder block instead of here.
|
326 |
+
hidden_states = hidden_states + self.dropout(attention_output[0])
|
327 |
+
else:
|
328 |
+
hidden_states = attention_output[0]
|
329 |
+
outputs = (hidden_states,) + attention_output[
|
330 |
+
1:
|
331 |
+
] # add attentions if we output them
|
332 |
+
return outputs
|
333 |
+
|
334 |
+
|
335 |
+
class DecoderOnlyT5Block(modeling_t5.T5Block):
|
336 |
+
def __init__(self, config, has_relative_attention_bias=False):
|
337 |
+
super(modeling_t5.T5Block, self).__init__()
|
338 |
+
self.is_decoder = config.is_decoder
|
339 |
+
self.is_decoder_only = config.is_decoder_only
|
340 |
+
self.layer = nn.ModuleList()
|
341 |
+
self.layer.append(
|
342 |
+
DecoderOnlyT5LayerSelfAttention(
|
343 |
+
config, has_relative_attention_bias=has_relative_attention_bias
|
344 |
+
)
|
345 |
+
)
|
346 |
+
if self.is_decoder:
|
347 |
+
if config.is_decoder_only:
|
348 |
+
self.layer.append(nn.Identity())
|
349 |
+
else:
|
350 |
+
self.layer.append(modeling_t5.T5LayerCrossAttention(config))
|
351 |
+
self.parallel_layers = config.parallel_layers
|
352 |
+
self.layer.append(DecoderOnlyT5LayerFF(config))
|
353 |
+
|
354 |
+
def forward(
|
355 |
+
self,
|
356 |
+
hidden_states,
|
357 |
+
attention_mask=None,
|
358 |
+
position_bias=None,
|
359 |
+
position_ids=None,
|
360 |
+
encoder_hidden_states=None,
|
361 |
+
layer_head_mask=None,
|
362 |
+
past_key_value=None,
|
363 |
+
use_cache=False,
|
364 |
+
output_attentions=False,
|
365 |
+
encoder_attention_mask=None,
|
366 |
+
encoder_decoder_position_bias=None,
|
367 |
+
cross_attn_layer_head_mask=None,
|
368 |
+
return_dict=True,
|
369 |
+
):
|
370 |
+
assert encoder_attention_mask is None
|
371 |
+
assert encoder_decoder_position_bias is None
|
372 |
+
assert cross_attn_layer_head_mask is None
|
373 |
+
if past_key_value is not None:
|
374 |
+
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
|
375 |
+
|
376 |
+
if len(past_key_value) != expected_num_past_key_values:
|
377 |
+
raise ValueError(
|
378 |
+
f"There should be {expected_num_past_key_values} past states. "
|
379 |
+
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
|
380 |
+
f"Got {len(past_key_value)} past key / value states"
|
381 |
+
)
|
382 |
+
self_attn_past_key_value = past_key_value[:2]
|
383 |
+
else:
|
384 |
+
self_attn_past_key_value = None
|
385 |
+
|
386 |
+
ff_layer = self.layer[-1]
|
387 |
+
if self.parallel_layers:
|
388 |
+
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L563-L568
|
389 |
+
x = self.layer[0].layer_norm(hidden_states)
|
390 |
+
ff_output = ff_layer(x)
|
391 |
+
else:
|
392 |
+
x = hidden_states
|
393 |
+
|
394 |
+
self_attention_outputs = self.layer[0](
|
395 |
+
x,
|
396 |
+
attention_mask=attention_mask,
|
397 |
+
position_bias=position_bias,
|
398 |
+
position_ids=position_ids,
|
399 |
+
layer_head_mask=layer_head_mask,
|
400 |
+
past_key_value=self_attn_past_key_value,
|
401 |
+
use_cache=use_cache,
|
402 |
+
output_attentions=output_attentions,
|
403 |
+
)
|
404 |
+
x, present_key_value_state = self_attention_outputs[:2]
|
405 |
+
attention_outputs = self_attention_outputs[
|
406 |
+
2:
|
407 |
+
] # Keep self-attention outputs and relative position weights
|
408 |
+
|
409 |
+
# clamp inf values to enable fp16 training
|
410 |
+
if x.dtype == torch.float16:
|
411 |
+
clamp_value = torch.where(
|
412 |
+
torch.isinf(x).any(),
|
413 |
+
torch.finfo(x.dtype).max - 1000,
|
414 |
+
torch.finfo(x.dtype).max,
|
415 |
+
)
|
416 |
+
x = torch.clamp(x, min=-clamp_value, max=clamp_value)
|
417 |
+
|
418 |
+
do_cross_attention = (
|
419 |
+
self.is_decoder
|
420 |
+
and not self.is_decoder_only
|
421 |
+
and encoder_hidden_states is not None
|
422 |
+
)
|
423 |
+
assert not do_cross_attention
|
424 |
+
|
425 |
+
if self.parallel_layers:
|
426 |
+
# https://github.com/google/flaxformer/blob/ea17eb012a1d340ddff017b7a534c2162aaec34c/flaxformer/architectures/t5/t5_architecture.py#L534-L578
|
427 |
+
x = x + ff_output
|
428 |
+
x *= 2**-0.5
|
429 |
+
hidden_states = hidden_states + self.layer[0].dropout(x)
|
430 |
+
else:
|
431 |
+
hidden_states = ff_layer(x)
|
432 |
+
|
433 |
+
# clamp inf values to enable fp16 training
|
434 |
+
if hidden_states.dtype == torch.float16:
|
435 |
+
clamp_value = torch.where(
|
436 |
+
torch.isinf(hidden_states).any(),
|
437 |
+
torch.finfo(hidden_states.dtype).max - 1000,
|
438 |
+
torch.finfo(hidden_states.dtype).max,
|
439 |
+
)
|
440 |
+
hidden_states = torch.clamp(
|
441 |
+
hidden_states, min=-clamp_value, max=clamp_value
|
442 |
+
)
|
443 |
+
|
444 |
+
outputs = (hidden_states,)
|
445 |
+
|
446 |
+
if use_cache:
|
447 |
+
outputs = outputs + (present_key_value_state,) + attention_outputs
|
448 |
+
else:
|
449 |
+
outputs = outputs + attention_outputs
|
450 |
+
|
451 |
+
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
|
452 |
+
|
453 |
+
|
454 |
+
class DecoderOnlyT5Stack(modeling_t5.T5Stack):
|
455 |
+
def __init__(self, config, embed_tokens=None):
|
456 |
+
super(modeling_t5.T5Stack, self).__init__(config)
|
457 |
+
|
458 |
+
self.embed_tokens = embed_tokens
|
459 |
+
self.is_decoder = config.is_decoder
|
460 |
+
|
461 |
+
self.block = nn.ModuleList(
|
462 |
+
[
|
463 |
+
DecoderOnlyT5Block(
|
464 |
+
config,
|
465 |
+
has_relative_attention_bias=(
|
466 |
+
config.has_relative_attention_bias and bool(i == 0)
|
467 |
+
),
|
468 |
+
)
|
469 |
+
for i in range(config.num_layers)
|
470 |
+
]
|
471 |
+
)
|
472 |
+
self.final_layer_norm = DecoderOnlyT5LayerNorm(
|
473 |
+
config.d_model,
|
474 |
+
eps=config.layer_norm_epsilon,
|
475 |
+
use_scale=False,
|
476 |
+
center_scale_at_zero=False,
|
477 |
+
)
|
478 |
+
self.dropout = nn.Dropout(config.dropout_rate)
|
479 |
+
|
480 |
+
# Initialize weights and apply final processing
|
481 |
+
self.post_init()
|
482 |
+
# Model parallel
|
483 |
+
self.model_parallel = False
|
484 |
+
self.device_map = None
|
485 |
+
self.gradient_checkpointing = False
|
486 |
+
|
487 |
+
def forward(
|
488 |
+
self,
|
489 |
+
input_ids=None,
|
490 |
+
position_ids=None,
|
491 |
+
attention_mask=None,
|
492 |
+
encoder_hidden_states=None,
|
493 |
+
encoder_attention_mask=None,
|
494 |
+
inputs_embeds=None,
|
495 |
+
head_mask=None,
|
496 |
+
cross_attn_head_mask=None,
|
497 |
+
past_key_values=None,
|
498 |
+
use_cache=None,
|
499 |
+
output_attentions=None,
|
500 |
+
output_hidden_states=None,
|
501 |
+
return_dict=None,
|
502 |
+
):
|
503 |
+
# Model parallel
|
504 |
+
if self.model_parallel:
|
505 |
+
torch.cuda.set_device(self.first_device)
|
506 |
+
self.embed_tokens = self.embed_tokens.to(self.first_device)
|
507 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
508 |
+
output_attentions = (
|
509 |
+
output_attentions
|
510 |
+
if output_attentions is not None
|
511 |
+
else self.config.output_attentions
|
512 |
+
)
|
513 |
+
output_hidden_states = (
|
514 |
+
output_hidden_states
|
515 |
+
if output_hidden_states is not None
|
516 |
+
else self.config.output_hidden_states
|
517 |
+
)
|
518 |
+
return_dict = (
|
519 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
520 |
+
)
|
521 |
+
|
522 |
+
if input_ids is not None and inputs_embeds is not None:
|
523 |
+
err_msg_prefix = "decoder_" if self.is_decoder else ""
|
524 |
+
raise ValueError(
|
525 |
+
f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
|
526 |
+
)
|
527 |
+
elif input_ids is not None:
|
528 |
+
input_shape = input_ids.size()
|
529 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
530 |
+
elif inputs_embeds is not None:
|
531 |
+
input_shape = inputs_embeds.size()[:-1]
|
532 |
+
else:
|
533 |
+
err_msg_prefix = "decoder_" if self.is_decoder else ""
|
534 |
+
raise ValueError(
|
535 |
+
f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds"
|
536 |
+
)
|
537 |
+
|
538 |
+
if position_ids is None:
|
539 |
+
seq_length = input_ids.shape[1]
|
540 |
+
past_key_values_length = (
|
541 |
+
0 if past_key_values is None else past_key_values[0][0].shape[2]
|
542 |
+
)
|
543 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
544 |
+
position_ids = torch.arange(
|
545 |
+
past_key_values_length,
|
546 |
+
seq_length + past_key_values_length,
|
547 |
+
dtype=torch.long,
|
548 |
+
device=device,
|
549 |
+
).unsqueeze(0)
|
550 |
+
|
551 |
+
if inputs_embeds is None:
|
552 |
+
if self.embed_tokens is None:
|
553 |
+
raise ValueError(
|
554 |
+
"You have to initialize the model with valid token embeddings"
|
555 |
+
)
|
556 |
+
inputs_embeds = self.embed_tokens(input_ids)
|
557 |
+
|
558 |
+
batch_size, seq_length = input_shape
|
559 |
+
|
560 |
+
# required mask seq length can be calculated via length of past
|
561 |
+
mask_seq_length = (
|
562 |
+
past_key_values[0][0].shape[2] + seq_length
|
563 |
+
if past_key_values is not None
|
564 |
+
else seq_length
|
565 |
+
)
|
566 |
+
|
567 |
+
if use_cache is True:
|
568 |
+
if not self.is_decoder:
|
569 |
+
raise ValueError(
|
570 |
+
f"`use_cache` can only be set to `True` if {self} is used as a decoder"
|
571 |
+
)
|
572 |
+
|
573 |
+
if attention_mask is None:
|
574 |
+
attention_mask = torch.ones(
|
575 |
+
batch_size, mask_seq_length, device=inputs_embeds.device
|
576 |
+
)
|
577 |
+
|
578 |
+
# initialize past_key_values with `None` if past does not exist
|
579 |
+
if past_key_values is None:
|
580 |
+
past_key_values = [None] * len(self.block)
|
581 |
+
|
582 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
583 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
584 |
+
extended_attention_mask = self.get_extended_attention_mask(
|
585 |
+
attention_mask, input_shape
|
586 |
+
)
|
587 |
+
|
588 |
+
if self.gradient_checkpointing and self.training:
|
589 |
+
if use_cache:
|
590 |
+
logger.warning_once(
|
591 |
+
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
592 |
+
)
|
593 |
+
use_cache = False
|
594 |
+
|
595 |
+
# Prepare head mask if needed
|
596 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_layers)
|
597 |
+
cross_attn_head_mask = self.get_head_mask(
|
598 |
+
cross_attn_head_mask, self.config.num_layers
|
599 |
+
)
|
600 |
+
present_key_value_states = () if use_cache else None
|
601 |
+
all_hidden_states = () if output_hidden_states else None
|
602 |
+
all_attentions = () if output_attentions else None
|
603 |
+
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
|
604 |
+
position_bias = None
|
605 |
+
|
606 |
+
hidden_states = self.dropout(inputs_embeds)
|
607 |
+
|
608 |
+
for i, (layer_module, past_key_value) in enumerate(
|
609 |
+
zip(self.block, past_key_values)
|
610 |
+
):
|
611 |
+
layer_head_mask = head_mask[i]
|
612 |
+
cross_attn_layer_head_mask = cross_attn_head_mask[i]
|
613 |
+
# Model parallel
|
614 |
+
if self.model_parallel:
|
615 |
+
torch.cuda.set_device(hidden_states.device)
|
616 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
617 |
+
if attention_mask is not None:
|
618 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
619 |
+
if position_bias is not None:
|
620 |
+
position_bias = position_bias.to(hidden_states.device)
|
621 |
+
if layer_head_mask is not None:
|
622 |
+
layer_head_mask = layer_head_mask.to(hidden_states.device)
|
623 |
+
|
624 |
+
if output_hidden_states:
|
625 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
626 |
+
|
627 |
+
if self.gradient_checkpointing and self.training:
|
628 |
+
layer_outputs = self._gradient_checkpointing_func(
|
629 |
+
layer_module.forward,
|
630 |
+
hidden_states,
|
631 |
+
extended_attention_mask,
|
632 |
+
position_bias,
|
633 |
+
None,
|
634 |
+
None,
|
635 |
+
None,
|
636 |
+
layer_head_mask,
|
637 |
+
cross_attn_layer_head_mask,
|
638 |
+
None, # past_key_value is always None with gradient checkpointing
|
639 |
+
use_cache,
|
640 |
+
output_attentions,
|
641 |
+
)
|
642 |
+
else:
|
643 |
+
layer_outputs = layer_module(
|
644 |
+
hidden_states,
|
645 |
+
attention_mask=extended_attention_mask,
|
646 |
+
position_bias=position_bias,
|
647 |
+
position_ids=position_ids,
|
648 |
+
encoder_hidden_states=None,
|
649 |
+
encoder_attention_mask=None,
|
650 |
+
encoder_decoder_position_bias=None,
|
651 |
+
layer_head_mask=layer_head_mask,
|
652 |
+
cross_attn_layer_head_mask=cross_attn_layer_head_mask,
|
653 |
+
past_key_value=past_key_value,
|
654 |
+
use_cache=use_cache,
|
655 |
+
output_attentions=output_attentions,
|
656 |
+
)
|
657 |
+
|
658 |
+
# layer_outputs is a tuple with:
|
659 |
+
# hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
|
660 |
+
if use_cache is False:
|
661 |
+
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
|
662 |
+
|
663 |
+
hidden_states, present_key_value_state = layer_outputs[:2]
|
664 |
+
|
665 |
+
# We share the position biases between the layers - the first layer store them
|
666 |
+
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
|
667 |
+
# (cross-attention position bias), (cross-attention weights)
|
668 |
+
position_bias = layer_outputs[2]
|
669 |
+
# append next layer key value states
|
670 |
+
if use_cache:
|
671 |
+
present_key_value_states = present_key_value_states + (
|
672 |
+
present_key_value_state,
|
673 |
+
)
|
674 |
+
|
675 |
+
if output_attentions:
|
676 |
+
all_attentions = all_attentions + (layer_outputs[3],)
|
677 |
+
if self.is_decoder:
|
678 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
|
679 |
+
|
680 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
681 |
+
if self.model_parallel:
|
682 |
+
for k, v in self.device_map.items():
|
683 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
684 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
685 |
+
|
686 |
+
hidden_states = self.final_layer_norm(hidden_states)
|
687 |
+
hidden_states = self.dropout(hidden_states)
|
688 |
+
|
689 |
+
# Add last layer
|
690 |
+
if output_hidden_states:
|
691 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
692 |
+
|
693 |
+
if not return_dict:
|
694 |
+
return tuple(
|
695 |
+
v
|
696 |
+
for v in [
|
697 |
+
hidden_states,
|
698 |
+
present_key_value_states,
|
699 |
+
all_hidden_states,
|
700 |
+
all_attentions,
|
701 |
+
all_cross_attentions,
|
702 |
+
]
|
703 |
+
if v is not None
|
704 |
+
)
|
705 |
+
return modeling_t5.BaseModelOutputWithPastAndCrossAttentions(
|
706 |
+
last_hidden_state=hidden_states,
|
707 |
+
past_key_values=present_key_value_states,
|
708 |
+
hidden_states=all_hidden_states,
|
709 |
+
attentions=all_attentions,
|
710 |
+
cross_attentions=all_cross_attentions,
|
711 |
+
)
|
712 |
+
|
713 |
+
|
714 |
+
class DecoderOnlyT5Model(modeling_t5.T5ForConditionalGeneration):
|
715 |
+
def __init__(self, config: DecoderOnlyT5Config):
|
716 |
+
super(modeling_t5.T5ForConditionalGeneration, self).__init__(config)
|
717 |
+
self.model_dim = config.d_model
|
718 |
+
|
719 |
+
self.shared = nn.Embedding(config.vocab_size, config.d_model)
|
720 |
+
assert (
|
721 |
+
self.config.num_layers == 0
|
722 |
+
), "Decoder only model cannot have encoder layers"
|
723 |
+
self.encoder = None
|
724 |
+
|
725 |
+
decoder_config = copy.deepcopy(config)
|
726 |
+
decoder_config.is_decoder = True
|
727 |
+
decoder_config.is_encoder_decoder = False
|
728 |
+
decoder_config.num_layers = config.num_decoder_layers
|
729 |
+
self.decoder = DecoderOnlyT5Stack(decoder_config, self.shared)
|
730 |
+
|
731 |
+
self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
|
732 |
+
|
733 |
+
# Initialize weights and apply final processing
|
734 |
+
self.post_init()
|
735 |
+
|
736 |
+
# Model parallel
|
737 |
+
self.model_parallel = False
|
738 |
+
self.device_map = None
|
739 |
+
|
740 |
+
def _tie_weights(self):
|
741 |
+
if not self.config.tie_word_embeddings:
|
742 |
+
return
|
743 |
+
if self.decoder:
|
744 |
+
self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
|
745 |
+
|
746 |
+
@add_start_docstrings_to_model_forward(modeling_t5.T5_INPUTS_DOCSTRING)
|
747 |
+
@replace_return_docstrings(
|
748 |
+
output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
|
749 |
+
)
|
750 |
+
def forward(
|
751 |
+
self,
|
752 |
+
input_ids: Optional[torch.LongTensor] = None,
|
753 |
+
position_ids: Optional[torch.LongTensor] = None,
|
754 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
755 |
+
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
756 |
+
inputs_embeds: Optional[torch.FloatTensor] = None,
|
757 |
+
labels: Optional[torch.LongTensor] = None,
|
758 |
+
use_cache: Optional[bool] = None,
|
759 |
+
output_attentions: Optional[bool] = None,
|
760 |
+
output_hidden_states: Optional[bool] = None,
|
761 |
+
return_dict: Optional[bool] = None,
|
762 |
+
) -> Union[Tuple, CausalLMOutputWithPast]:
|
763 |
+
r"""
|
764 |
+
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
765 |
+
Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
|
766 |
+
config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
|
767 |
+
labels in `[0, ..., config.vocab_size]`
|
768 |
+
|
769 |
+
Returns:
|
770 |
+
|
771 |
+
Examples:
|
772 |
+
|
773 |
+
```"""
|
774 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
775 |
+
return_dict = (
|
776 |
+
return_dict if return_dict is not None else self.config.use_return_dict
|
777 |
+
)
|
778 |
+
|
779 |
+
if self.model_parallel:
|
780 |
+
torch.cuda.set_device(self.decoder.first_device)
|
781 |
+
|
782 |
+
# Set device for model parallelism
|
783 |
+
if self.model_parallel:
|
784 |
+
torch.cuda.set_device(self.decoder.first_device)
|
785 |
+
if input_ids is not None:
|
786 |
+
input_ids = input_ids.to(self.decoder.first_device)
|
787 |
+
if attention_mask is not None:
|
788 |
+
attention_mask = attention_mask.to(self.decoder.first_device)
|
789 |
+
|
790 |
+
# Decode
|
791 |
+
outputs = self.decoder(
|
792 |
+
input_ids=input_ids,
|
793 |
+
position_ids=position_ids,
|
794 |
+
attention_mask=attention_mask,
|
795 |
+
inputs_embeds=inputs_embeds,
|
796 |
+
past_key_values=past_key_values,
|
797 |
+
encoder_hidden_states=None,
|
798 |
+
encoder_attention_mask=None,
|
799 |
+
head_mask=None,
|
800 |
+
cross_attn_head_mask=None,
|
801 |
+
use_cache=use_cache,
|
802 |
+
output_attentions=output_attentions,
|
803 |
+
output_hidden_states=output_hidden_states,
|
804 |
+
return_dict=return_dict,
|
805 |
+
)
|
806 |
+
|
807 |
+
sequence_output = outputs[0]
|
808 |
+
|
809 |
+
# Set device for model parallelism
|
810 |
+
if self.model_parallel:
|
811 |
+
torch.cuda.set_device(self.decoder.first_device)
|
812 |
+
self.lm_head = self.lm_head.to(self.decoder.first_device)
|
813 |
+
sequence_output = sequence_output.to(self.lm_head.weight.device)
|
814 |
+
|
815 |
+
if self.config.tie_word_embeddings:
|
816 |
+
# Rescale output before projecting on vocab
|
817 |
+
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
|
818 |
+
sequence_output = sequence_output * (self.model_dim**-0.5)
|
819 |
+
|
820 |
+
lm_logits = self.lm_head(sequence_output)
|
821 |
+
|
822 |
+
loss = None
|
823 |
+
if labels is not None:
|
824 |
+
loss_fct = CrossEntropyLoss(ignore_index=-100)
|
825 |
+
# move labels to correct device to enable PP
|
826 |
+
labels = labels.to(lm_logits.device)
|
827 |
+
loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
|
828 |
+
# TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
|
829 |
+
|
830 |
+
if not return_dict:
|
831 |
+
output = (lm_logits,) + outputs[1:]
|
832 |
+
return ((loss,) + output) if loss is not None else output
|
833 |
+
|
834 |
+
return CausalLMOutputWithPast(
|
835 |
+
loss=loss,
|
836 |
+
logits=lm_logits,
|
837 |
+
past_key_values=outputs.past_key_values,
|
838 |
+
hidden_states=outputs.hidden_states,
|
839 |
+
attentions=outputs.attentions,
|
840 |
+
)
|
model-00000-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:eee3b0c4eef668152f9f1106f18bf0a892bd04ba8b26017d7d5865f49dec5f3c
|
3 |
+
size 5150622792
|
model-00001-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:29b44a655d9261522e705d963d1fa7ca1717c3e6bcfb9402fa69cb8ee6156c6f
|
3 |
+
size 4739650416
|
model-00002-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1dbd16405fa07722d953ba5c99aeb8ae05c1068cb4018a9622e5828336c1b9c8
|
3 |
+
size 4739650424
|
model-00003-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7444d7670f145f12706970a191b260b609a82626c5e417f08f1039c05fbdda75
|
3 |
+
size 4739650456
|
model-00004-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:471e481023c3f4622d2ff0b1031a34c55469fb10fe1252f5c5c62e8f95418b4b
|
3 |
+
size 4739650456
|
model-00005-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e934d8fcf5bc54649e82b29c1322b7ce3d13e8915d2cac205e65660e0a4cdbbb
|
3 |
+
size 4739650456
|
model-00006-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f3f82de95453b49cb5a78ec517fac3219556f83edd9075a20a7ac95a577b5e93
|
3 |
+
size 4739650456
|
model-00007-of-00007.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f816e875186339eaf9afa8964d34952c80e697ade369154512f9900ea0a33553
|
3 |
+
size 947930104
|
model.safetensors.index.json
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"metadata": {},
|
3 |
+
"weight_map": {
|
4 |
+
"shared.weight": "model-00000-of-00007.safetensors",
|
5 |
+
"decoder.block.0.layer.0.layer_norm.weight": "model-00000-of-00007.safetensors",
|
6 |
+
"decoder.block.0.layer.0.SelfAttention.k.weight": "model-00000-of-00007.safetensors",
|
7 |
+
"decoder.block.0.layer.0.SelfAttention.o.weight": "model-00000-of-00007.safetensors",
|
8 |
+
"decoder.block.0.layer.0.SelfAttention.q.weight": "model-00000-of-00007.safetensors",
|
9 |
+
"decoder.block.0.layer.0.SelfAttention.v.weight": "model-00000-of-00007.safetensors",
|
10 |
+
"decoder.block.0.layer.2.DenseReluDense.wi_0.weight": "model-00000-of-00007.safetensors",
|
11 |
+
"decoder.block.0.layer.2.DenseReluDense.wi_1.weight": "model-00000-of-00007.safetensors",
|
12 |
+
"decoder.block.0.layer.2.DenseReluDense.wo.weight": "model-00000-of-00007.safetensors",
|
13 |
+
"decoder.block.1.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
|
14 |
+
"decoder.block.1.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
|
15 |
+
"decoder.block.1.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
|
16 |
+
"decoder.block.1.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
|
17 |
+
"decoder.block.1.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
|
18 |
+
"decoder.block.1.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
|
19 |
+
"decoder.block.1.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
|
20 |
+
"decoder.block.1.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
|
21 |
+
"decoder.block.2.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
|
22 |
+
"decoder.block.2.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
|
23 |
+
"decoder.block.2.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
|
24 |
+
"decoder.block.2.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
|
25 |
+
"decoder.block.2.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
|
26 |
+
"decoder.block.2.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
|
27 |
+
"decoder.block.2.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
|
28 |
+
"decoder.block.2.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
|
29 |
+
"decoder.block.3.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
|
30 |
+
"decoder.block.3.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
|
31 |
+
"decoder.block.3.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
|
32 |
+
"decoder.block.3.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
|
33 |
+
"decoder.block.3.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
|
34 |
+
"decoder.block.3.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
|
35 |
+
"decoder.block.3.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
|
36 |
+
"decoder.block.3.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
|
37 |
+
"decoder.block.4.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
|
38 |
+
"decoder.block.4.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
|
39 |
+
"decoder.block.4.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
|
40 |
+
"decoder.block.4.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
|
41 |
+
"decoder.block.4.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
|
42 |
+
"decoder.block.4.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
|
43 |
+
"decoder.block.4.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
|
44 |
+
"decoder.block.4.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
|
45 |
+
"decoder.block.5.layer.0.layer_norm.weight": "model-00001-of-00007.safetensors",
|
46 |
+
"decoder.block.5.layer.0.SelfAttention.k.weight": "model-00001-of-00007.safetensors",
|
47 |
+
"decoder.block.5.layer.0.SelfAttention.o.weight": "model-00001-of-00007.safetensors",
|
48 |
+
"decoder.block.5.layer.0.SelfAttention.q.weight": "model-00001-of-00007.safetensors",
|
49 |
+
"decoder.block.5.layer.0.SelfAttention.v.weight": "model-00001-of-00007.safetensors",
|
50 |
+
"decoder.block.5.layer.2.DenseReluDense.wi_0.weight": "model-00001-of-00007.safetensors",
|
51 |
+
"decoder.block.5.layer.2.DenseReluDense.wi_1.weight": "model-00001-of-00007.safetensors",
|
52 |
+
"decoder.block.5.layer.2.DenseReluDense.wo.weight": "model-00001-of-00007.safetensors",
|
53 |
+
"decoder.block.6.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
|
54 |
+
"decoder.block.6.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
|
55 |
+
"decoder.block.6.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
|
56 |
+
"decoder.block.6.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
|
57 |
+
"decoder.block.6.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
|
58 |
+
"decoder.block.6.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
|
59 |
+
"decoder.block.6.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
|
60 |
+
"decoder.block.6.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
|
61 |
+
"decoder.block.7.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
|
62 |
+
"decoder.block.7.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
|
63 |
+
"decoder.block.7.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
|
64 |
+
"decoder.block.7.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
|
65 |
+
"decoder.block.7.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
|
66 |
+
"decoder.block.7.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
|
67 |
+
"decoder.block.7.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
|
68 |
+
"decoder.block.7.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
|
69 |
+
"decoder.block.8.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
|
70 |
+
"decoder.block.8.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
|
71 |
+
"decoder.block.8.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
|
72 |
+
"decoder.block.8.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
|
73 |
+
"decoder.block.8.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
|
74 |
+
"decoder.block.8.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
|
75 |
+
"decoder.block.8.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
|
76 |
+
"decoder.block.8.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
|
77 |
+
"decoder.block.9.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
|
78 |
+
"decoder.block.9.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
|
79 |
+
"decoder.block.9.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
|
80 |
+
"decoder.block.9.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
|
81 |
+
"decoder.block.9.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
|
82 |
+
"decoder.block.9.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
|
83 |
+
"decoder.block.9.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
|
84 |
+
"decoder.block.9.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
|
85 |
+
"decoder.block.10.layer.0.layer_norm.weight": "model-00002-of-00007.safetensors",
|
86 |
+
"decoder.block.10.layer.0.SelfAttention.k.weight": "model-00002-of-00007.safetensors",
|
87 |
+
"decoder.block.10.layer.0.SelfAttention.o.weight": "model-00002-of-00007.safetensors",
|
88 |
+
"decoder.block.10.layer.0.SelfAttention.q.weight": "model-00002-of-00007.safetensors",
|
89 |
+
"decoder.block.10.layer.0.SelfAttention.v.weight": "model-00002-of-00007.safetensors",
|
90 |
+
"decoder.block.10.layer.2.DenseReluDense.wi_0.weight": "model-00002-of-00007.safetensors",
|
91 |
+
"decoder.block.10.layer.2.DenseReluDense.wi_1.weight": "model-00002-of-00007.safetensors",
|
92 |
+
"decoder.block.10.layer.2.DenseReluDense.wo.weight": "model-00002-of-00007.safetensors",
|
93 |
+
"decoder.block.11.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
|
94 |
+
"decoder.block.11.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
|
95 |
+
"decoder.block.11.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
|
96 |
+
"decoder.block.11.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
|
97 |
+
"decoder.block.11.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
|
98 |
+
"decoder.block.11.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
|
99 |
+
"decoder.block.11.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
|
100 |
+
"decoder.block.11.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
|
101 |
+
"decoder.block.12.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
|
102 |
+
"decoder.block.12.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
|
103 |
+
"decoder.block.12.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
|
104 |
+
"decoder.block.12.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
|
105 |
+
"decoder.block.12.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
|
106 |
+
"decoder.block.12.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
|
107 |
+
"decoder.block.12.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
|
108 |
+
"decoder.block.12.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
|
109 |
+
"decoder.block.13.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
|
110 |
+
"decoder.block.13.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
|
111 |
+
"decoder.block.13.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
|
112 |
+
"decoder.block.13.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
|
113 |
+
"decoder.block.13.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
|
114 |
+
"decoder.block.13.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
|
115 |
+
"decoder.block.13.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
|
116 |
+
"decoder.block.13.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
|
117 |
+
"decoder.block.14.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
|
118 |
+
"decoder.block.14.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
|
119 |
+
"decoder.block.14.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
|
120 |
+
"decoder.block.14.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
|
121 |
+
"decoder.block.14.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
|
122 |
+
"decoder.block.14.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
|
123 |
+
"decoder.block.14.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
|
124 |
+
"decoder.block.14.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
|
125 |
+
"decoder.block.15.layer.0.layer_norm.weight": "model-00003-of-00007.safetensors",
|
126 |
+
"decoder.block.15.layer.0.SelfAttention.k.weight": "model-00003-of-00007.safetensors",
|
127 |
+
"decoder.block.15.layer.0.SelfAttention.o.weight": "model-00003-of-00007.safetensors",
|
128 |
+
"decoder.block.15.layer.0.SelfAttention.q.weight": "model-00003-of-00007.safetensors",
|
129 |
+
"decoder.block.15.layer.0.SelfAttention.v.weight": "model-00003-of-00007.safetensors",
|
130 |
+
"decoder.block.15.layer.2.DenseReluDense.wi_0.weight": "model-00003-of-00007.safetensors",
|
131 |
+
"decoder.block.15.layer.2.DenseReluDense.wi_1.weight": "model-00003-of-00007.safetensors",
|
132 |
+
"decoder.block.15.layer.2.DenseReluDense.wo.weight": "model-00003-of-00007.safetensors",
|
133 |
+
"decoder.block.16.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
|
134 |
+
"decoder.block.16.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
|
135 |
+
"decoder.block.16.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
|
136 |
+
"decoder.block.16.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
|
137 |
+
"decoder.block.16.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
|
138 |
+
"decoder.block.16.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
|
139 |
+
"decoder.block.16.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
|
140 |
+
"decoder.block.16.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
|
141 |
+
"decoder.block.17.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
|
142 |
+
"decoder.block.17.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
|
143 |
+
"decoder.block.17.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
|
144 |
+
"decoder.block.17.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
|
145 |
+
"decoder.block.17.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
|
146 |
+
"decoder.block.17.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
|
147 |
+
"decoder.block.17.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
|
148 |
+
"decoder.block.17.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
|
149 |
+
"decoder.block.18.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
|
150 |
+
"decoder.block.18.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
|
151 |
+
"decoder.block.18.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
|
152 |
+
"decoder.block.18.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
|
153 |
+
"decoder.block.18.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
|
154 |
+
"decoder.block.18.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
|
155 |
+
"decoder.block.18.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
|
156 |
+
"decoder.block.18.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
|
157 |
+
"decoder.block.19.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
|
158 |
+
"decoder.block.19.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
|
159 |
+
"decoder.block.19.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
|
160 |
+
"decoder.block.19.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
|
161 |
+
"decoder.block.19.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
|
162 |
+
"decoder.block.19.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
|
163 |
+
"decoder.block.19.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
|
164 |
+
"decoder.block.19.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
|
165 |
+
"decoder.block.20.layer.0.layer_norm.weight": "model-00004-of-00007.safetensors",
|
166 |
+
"decoder.block.20.layer.0.SelfAttention.k.weight": "model-00004-of-00007.safetensors",
|
167 |
+
"decoder.block.20.layer.0.SelfAttention.o.weight": "model-00004-of-00007.safetensors",
|
168 |
+
"decoder.block.20.layer.0.SelfAttention.q.weight": "model-00004-of-00007.safetensors",
|
169 |
+
"decoder.block.20.layer.0.SelfAttention.v.weight": "model-00004-of-00007.safetensors",
|
170 |
+
"decoder.block.20.layer.2.DenseReluDense.wi_0.weight": "model-00004-of-00007.safetensors",
|
171 |
+
"decoder.block.20.layer.2.DenseReluDense.wi_1.weight": "model-00004-of-00007.safetensors",
|
172 |
+
"decoder.block.20.layer.2.DenseReluDense.wo.weight": "model-00004-of-00007.safetensors",
|
173 |
+
"decoder.block.21.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
|
174 |
+
"decoder.block.21.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
|
175 |
+
"decoder.block.21.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
|
176 |
+
"decoder.block.21.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
|
177 |
+
"decoder.block.21.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
|
178 |
+
"decoder.block.21.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
|
179 |
+
"decoder.block.21.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
|
180 |
+
"decoder.block.21.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
|
181 |
+
"decoder.block.22.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
|
182 |
+
"decoder.block.22.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
|
183 |
+
"decoder.block.22.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
|
184 |
+
"decoder.block.22.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
|
185 |
+
"decoder.block.22.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
|
186 |
+
"decoder.block.22.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
|
187 |
+
"decoder.block.22.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
|
188 |
+
"decoder.block.22.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
|
189 |
+
"decoder.block.23.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
|
190 |
+
"decoder.block.23.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
|
191 |
+
"decoder.block.23.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
|
192 |
+
"decoder.block.23.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
|
193 |
+
"decoder.block.23.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
|
194 |
+
"decoder.block.23.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
|
195 |
+
"decoder.block.23.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
|
196 |
+
"decoder.block.23.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
|
197 |
+
"decoder.block.24.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
|
198 |
+
"decoder.block.24.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
|
199 |
+
"decoder.block.24.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
|
200 |
+
"decoder.block.24.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
|
201 |
+
"decoder.block.24.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
|
202 |
+
"decoder.block.24.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
|
203 |
+
"decoder.block.24.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
|
204 |
+
"decoder.block.24.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
|
205 |
+
"decoder.block.25.layer.0.layer_norm.weight": "model-00005-of-00007.safetensors",
|
206 |
+
"decoder.block.25.layer.0.SelfAttention.k.weight": "model-00005-of-00007.safetensors",
|
207 |
+
"decoder.block.25.layer.0.SelfAttention.o.weight": "model-00005-of-00007.safetensors",
|
208 |
+
"decoder.block.25.layer.0.SelfAttention.q.weight": "model-00005-of-00007.safetensors",
|
209 |
+
"decoder.block.25.layer.0.SelfAttention.v.weight": "model-00005-of-00007.safetensors",
|
210 |
+
"decoder.block.25.layer.2.DenseReluDense.wi_0.weight": "model-00005-of-00007.safetensors",
|
211 |
+
"decoder.block.25.layer.2.DenseReluDense.wi_1.weight": "model-00005-of-00007.safetensors",
|
212 |
+
"decoder.block.25.layer.2.DenseReluDense.wo.weight": "model-00005-of-00007.safetensors",
|
213 |
+
"decoder.block.26.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
|
214 |
+
"decoder.block.26.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
|
215 |
+
"decoder.block.26.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
|
216 |
+
"decoder.block.26.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
|
217 |
+
"decoder.block.26.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
|
218 |
+
"decoder.block.26.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
|
219 |
+
"decoder.block.26.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
|
220 |
+
"decoder.block.26.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
|
221 |
+
"decoder.block.27.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
|
222 |
+
"decoder.block.27.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
|
223 |
+
"decoder.block.27.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
|
224 |
+
"decoder.block.27.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
|
225 |
+
"decoder.block.27.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
|
226 |
+
"decoder.block.27.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
|
227 |
+
"decoder.block.27.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
|
228 |
+
"decoder.block.27.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
|
229 |
+
"decoder.block.28.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
|
230 |
+
"decoder.block.28.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
|
231 |
+
"decoder.block.28.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
|
232 |
+
"decoder.block.28.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
|
233 |
+
"decoder.block.28.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
|
234 |
+
"decoder.block.28.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
|
235 |
+
"decoder.block.28.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
|
236 |
+
"decoder.block.28.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
|
237 |
+
"decoder.block.29.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
|
238 |
+
"decoder.block.29.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
|
239 |
+
"decoder.block.29.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
|
240 |
+
"decoder.block.29.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
|
241 |
+
"decoder.block.29.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
|
242 |
+
"decoder.block.29.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
|
243 |
+
"decoder.block.29.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
|
244 |
+
"decoder.block.29.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
|
245 |
+
"decoder.block.30.layer.0.layer_norm.weight": "model-00006-of-00007.safetensors",
|
246 |
+
"decoder.block.30.layer.0.SelfAttention.k.weight": "model-00006-of-00007.safetensors",
|
247 |
+
"decoder.block.30.layer.0.SelfAttention.o.weight": "model-00006-of-00007.safetensors",
|
248 |
+
"decoder.block.30.layer.0.SelfAttention.q.weight": "model-00006-of-00007.safetensors",
|
249 |
+
"decoder.block.30.layer.0.SelfAttention.v.weight": "model-00006-of-00007.safetensors",
|
250 |
+
"decoder.block.30.layer.2.DenseReluDense.wi_0.weight": "model-00006-of-00007.safetensors",
|
251 |
+
"decoder.block.30.layer.2.DenseReluDense.wi_1.weight": "model-00006-of-00007.safetensors",
|
252 |
+
"decoder.block.30.layer.2.DenseReluDense.wo.weight": "model-00006-of-00007.safetensors",
|
253 |
+
"decoder.block.31.layer.0.layer_norm.weight": "model-00007-of-00007.safetensors",
|
254 |
+
"decoder.block.31.layer.0.SelfAttention.k.weight": "model-00007-of-00007.safetensors",
|
255 |
+
"decoder.block.31.layer.0.SelfAttention.o.weight": "model-00007-of-00007.safetensors",
|
256 |
+
"decoder.block.31.layer.0.SelfAttention.q.weight": "model-00007-of-00007.safetensors",
|
257 |
+
"decoder.block.31.layer.0.SelfAttention.v.weight": "model-00007-of-00007.safetensors",
|
258 |
+
"decoder.block.31.layer.2.DenseReluDense.wi_0.weight": "model-00007-of-00007.safetensors",
|
259 |
+
"decoder.block.31.layer.2.DenseReluDense.wi_1.weight": "model-00007-of-00007.safetensors",
|
260 |
+
"decoder.block.31.layer.2.DenseReluDense.wo.weight": "model-00007-of-00007.safetensors"
|
261 |
+
}
|
262 |
+
}
|
special_tokens_map.json
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"eos_token": {
|
3 |
+
"content": "</s>",
|
4 |
+
"lstrip": false,
|
5 |
+
"normalized": false,
|
6 |
+
"rstrip": false,
|
7 |
+
"single_word": false
|
8 |
+
},
|
9 |
+
"pad_token": {
|
10 |
+
"content": "<s>",
|
11 |
+
"lstrip": false,
|
12 |
+
"normalized": false,
|
13 |
+
"rstrip": false,
|
14 |
+
"single_word": false
|
15 |
+
},
|
16 |
+
"unk_token": {
|
17 |
+
"content": "<unk>",
|
18 |
+
"lstrip": false,
|
19 |
+
"normalized": false,
|
20 |
+
"rstrip": false,
|
21 |
+
"single_word": false
|
22 |
+
}
|
23 |
+
}
|
spiece.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef11ac9a22c7503492f56d48dce53be20e339b63605983e9f27d2cd0e0f3922c
|
3 |
+
size 4427844
|
tokenizer.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a2799ccc696b752ba00c34f58726bfe253a04921ceb6cfc620400f560474790b
|
3 |
+
size 16629031
|
tokenizer_config.json
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"added_tokens_decoder": {
|
3 |
+
"0": {
|
4 |
+
"content": "<unk>",
|
5 |
+
"lstrip": false,
|
6 |
+
"normalized": false,
|
7 |
+
"rstrip": false,
|
8 |
+
"single_word": false,
|
9 |
+
"special": true
|
10 |
+
},
|
11 |
+
"1": {
|
12 |
+
"content": "<s>",
|
13 |
+
"lstrip": false,
|
14 |
+
"normalized": false,
|
15 |
+
"rstrip": false,
|
16 |
+
"single_word": false,
|
17 |
+
"special": true
|
18 |
+
},
|
19 |
+
"2": {
|
20 |
+
"content": "</s>",
|
21 |
+
"lstrip": false,
|
22 |
+
"normalized": false,
|
23 |
+
"rstrip": false,
|
24 |
+
"single_word": false,
|
25 |
+
"special": true
|
26 |
+
}
|
27 |
+
},
|
28 |
+
"additional_special_tokens": [],
|
29 |
+
"clean_up_tokenization_spaces": true,
|
30 |
+
"eos_token": "</s>",
|
31 |
+
"extra_ids": 0,
|
32 |
+
"legacy": false,
|
33 |
+
"model_max_length": 1000000000000000019884624838656,
|
34 |
+
"pad_token": "<s>",
|
35 |
+
"sp_model_kwargs": {},
|
36 |
+
"tokenizer_class": "T5Tokenizer",
|
37 |
+
"unk_token": "<unk>"
|
38 |
+
}
|