jpc commited on
Commit
5190182
β€’
1 Parent(s): a68da21

Removed the local copy of the whisperspeech library

Browse files
whisperspeech/__init__.py DELETED
@@ -1 +0,0 @@
1
- __version__ = "0.5.6"
 
 
whisperspeech/_modidx.py DELETED
@@ -1,615 +0,0 @@
1
- # Autogenerated by nbdev
2
-
3
- d = { 'settings': { 'branch': 'master',
4
- 'doc_baseurl': '/WhisperSpeech',
5
- 'doc_host': 'https://collabora.github.io',
6
- 'git_url': 'https://github.com/collabora/WhisperSpeech',
7
- 'lib_path': 'whisperspeech'},
8
- 'syms': { 'whisperspeech.a2wav': { 'whisperspeech.a2wav.Vocoder': ('6. quality-boosting vocoder.html#vocoder', 'whisperspeech/a2wav.py'),
9
- 'whisperspeech.a2wav.Vocoder.__init__': ( '6. quality-boosting vocoder.html#vocoder.__init__',
10
- 'whisperspeech/a2wav.py'),
11
- 'whisperspeech.a2wav.Vocoder.decode': ( '6. quality-boosting vocoder.html#vocoder.decode',
12
- 'whisperspeech/a2wav.py'),
13
- 'whisperspeech.a2wav.Vocoder.decode_to_file': ( '6. quality-boosting '
14
- 'vocoder.html#vocoder.decode_to_file',
15
- 'whisperspeech/a2wav.py'),
16
- 'whisperspeech.a2wav.Vocoder.decode_to_notebook': ( '6. quality-boosting '
17
- 'vocoder.html#vocoder.decode_to_notebook',
18
- 'whisperspeech/a2wav.py')},
19
- 'whisperspeech.extract_acoustic': { 'whisperspeech.extract_acoustic.extract_Atoks': ( '1. acoustic token '
20
- 'extraction.html#extract_atoks',
21
- 'whisperspeech/extract_acoustic.py'),
22
- 'whisperspeech.extract_acoustic.extract_acoustic': ( '1. acoustic token '
23
- 'extraction.html#extract_acoustic',
24
- 'whisperspeech/extract_acoustic.py'),
25
- 'whisperspeech.extract_acoustic.load': ( '1. acoustic token extraction.html#load',
26
- 'whisperspeech/extract_acoustic.py'),
27
- 'whisperspeech.extract_acoustic.load_model': ( '1. acoustic token '
28
- 'extraction.html#load_model',
29
- 'whisperspeech/extract_acoustic.py')},
30
- 'whisperspeech.extract_semb': { 'whisperspeech.extract_semb.encode_semantic': ( '2c. whisper semantic embedding '
31
- 'extraction.html#encode_semantic',
32
- 'whisperspeech/extract_semb.py'),
33
- 'whisperspeech.extract_semb.extract_semantic': ( '2c. whisper semantic embedding '
34
- 'extraction.html#extract_semantic',
35
- 'whisperspeech/extract_semb.py'),
36
- 'whisperspeech.extract_semb.load_model': ( '2c. whisper semantic embedding '
37
- 'extraction.html#load_model',
38
- 'whisperspeech/extract_semb.py')},
39
- 'whisperspeech.fetch_models': { 'whisperspeech.fetch_models.main': ( '0. download models.html#main',
40
- 'whisperspeech/fetch_models.py')},
41
- 'whisperspeech.modules': { 'whisperspeech.modules.Decoder': ('a. neural modules.html#decoder', 'whisperspeech/modules.py'),
42
- 'whisperspeech.modules.Decoder.__init__': ( 'a. neural modules.html#decoder.__init__',
43
- 'whisperspeech/modules.py'),
44
- 'whisperspeech.modules.Decoder.forward': ( 'a. neural modules.html#decoder.forward',
45
- 'whisperspeech/modules.py'),
46
- 'whisperspeech.modules.Encoder': ('a. neural modules.html#encoder', 'whisperspeech/modules.py'),
47
- 'whisperspeech.modules.Encoder.__init__': ( 'a. neural modules.html#encoder.__init__',
48
- 'whisperspeech/modules.py'),
49
- 'whisperspeech.modules.Encoder.forward': ( 'a. neural modules.html#encoder.forward',
50
- 'whisperspeech/modules.py'),
51
- 'whisperspeech.modules.LayerNorm': ('a. neural modules.html#layernorm', 'whisperspeech/modules.py'),
52
- 'whisperspeech.modules.LayerNorm.forward': ( 'a. neural modules.html#layernorm.forward',
53
- 'whisperspeech/modules.py'),
54
- 'whisperspeech.modules.LinearHead': ( 'a. neural modules.html#linearhead',
55
- 'whisperspeech/modules.py'),
56
- 'whisperspeech.modules.MultiHeadAttention': ( 'a. neural modules.html#multiheadattention',
57
- 'whisperspeech/modules.py'),
58
- 'whisperspeech.modules.MultiHeadAttention.__init__': ( 'a. neural '
59
- 'modules.html#multiheadattention.__init__',
60
- 'whisperspeech/modules.py'),
61
- 'whisperspeech.modules.MultiHeadAttention.forward': ( 'a. neural '
62
- 'modules.html#multiheadattention.forward',
63
- 'whisperspeech/modules.py'),
64
- 'whisperspeech.modules.MultiHeadAttention.qkv_attention_pth20': ( 'a. neural '
65
- 'modules.html#multiheadattention.qkv_attention_pth20',
66
- 'whisperspeech/modules.py'),
67
- 'whisperspeech.modules.MultiHeadAttention.qkv_attention_vanilla': ( 'a. neural '
68
- 'modules.html#multiheadattention.qkv_attention_vanilla',
69
- 'whisperspeech/modules.py'),
70
- 'whisperspeech.modules.MultiHeadAttention.qkv_attention_xformers': ( 'a. neural '
71
- 'modules.html#multiheadattention.qkv_attention_xformers',
72
- 'whisperspeech/modules.py'),
73
- 'whisperspeech.modules.QueryHead': ('a. neural modules.html#queryhead', 'whisperspeech/modules.py'),
74
- 'whisperspeech.modules.ResidualAttentionBlock': ( 'a. neural modules.html#residualattentionblock',
75
- 'whisperspeech/modules.py'),
76
- 'whisperspeech.modules.ResidualAttentionBlock.__init__': ( 'a. neural '
77
- 'modules.html#residualattentionblock.__init__',
78
- 'whisperspeech/modules.py'),
79
- 'whisperspeech.modules.ResidualAttentionBlock.forward': ( 'a. neural '
80
- 'modules.html#residualattentionblock.forward',
81
- 'whisperspeech/modules.py'),
82
- 'whisperspeech.modules.Rotary': ('a. neural modules.html#rotary', 'whisperspeech/modules.py'),
83
- 'whisperspeech.modules.Rotary.__init__': ( 'a. neural modules.html#rotary.__init__',
84
- 'whisperspeech/modules.py'),
85
- 'whisperspeech.modules.Rotary.forward': ( 'a. neural modules.html#rotary.forward',
86
- 'whisperspeech/modules.py'),
87
- 'whisperspeech.modules.SumDecoder': ( 'a. neural modules.html#sumdecoder',
88
- 'whisperspeech/modules.py'),
89
- 'whisperspeech.modules.SumDecoder.__init__': ( 'a. neural modules.html#sumdecoder.__init__',
90
- 'whisperspeech/modules.py'),
91
- 'whisperspeech.modules.SumDecoder.forward': ( 'a. neural modules.html#sumdecoder.forward',
92
- 'whisperspeech/modules.py'),
93
- 'whisperspeech.modules.apply_rotary_pos_emb': ( 'a. neural modules.html#apply_rotary_pos_emb',
94
- 'whisperspeech/modules.py'),
95
- 'whisperspeech.modules.init_transformer': ( 'a. neural modules.html#init_transformer',
96
- 'whisperspeech/modules.py'),
97
- 'whisperspeech.modules.rotate_half': ( 'a. neural modules.html#rotate_half',
98
- 'whisperspeech/modules.py'),
99
- 'whisperspeech.modules.sinusoids': ('a. neural modules.html#sinusoids', 'whisperspeech/modules.py')},
100
- 'whisperspeech.pipeline': { 'whisperspeech.pipeline.Pipeline': ('7. pipeline.html#pipeline', 'whisperspeech/pipeline.py'),
101
- 'whisperspeech.pipeline.Pipeline.__init__': ( '7. pipeline.html#pipeline.__init__',
102
- 'whisperspeech/pipeline.py'),
103
- 'whisperspeech.pipeline.Pipeline.generate': ( '7. pipeline.html#pipeline.generate',
104
- 'whisperspeech/pipeline.py'),
105
- 'whisperspeech.pipeline.Pipeline.generate_atoks': ( '7. pipeline.html#pipeline.generate_atoks',
106
- 'whisperspeech/pipeline.py'),
107
- 'whisperspeech.pipeline.Pipeline.generate_to_file': ( '7. pipeline.html#pipeline.generate_to_file',
108
- 'whisperspeech/pipeline.py'),
109
- 'whisperspeech.pipeline.Pipeline.generate_to_notebook': ( '7. '
110
- 'pipeline.html#pipeline.generate_to_notebook',
111
- 'whisperspeech/pipeline.py')},
112
- 'whisperspeech.prepare_s2a_dataset': { 'whisperspeech.prepare_s2a_dataset.flac_to_s2a_name': ( '4a. s2a dataset '
113
- 'preparation.html#flac_to_s2a_name',
114
- 'whisperspeech/prepare_s2a_dataset.py'),
115
- 'whisperspeech.prepare_s2a_dataset.prepare_s2a': ( '4a. s2a dataset '
116
- 'preparation.html#prepare_s2a',
117
- 'whisperspeech/prepare_s2a_dataset.py'),
118
- 'whisperspeech.prepare_s2a_dataset.resampler': ( '4a. s2a dataset '
119
- 'preparation.html#resampler',
120
- 'whisperspeech/prepare_s2a_dataset.py')},
121
- 'whisperspeech.prepare_t2s_dataset': { 'whisperspeech.prepare_t2s_dataset.Transcriber': ( '5a. t2s dataset '
122
- 'preparation.html#transcriber',
123
- 'whisperspeech/prepare_t2s_dataset.py'),
124
- 'whisperspeech.prepare_t2s_dataset.Transcriber.__init__': ( '5a. t2s dataset '
125
- 'preparation.html#transcriber.__init__',
126
- 'whisperspeech/prepare_t2s_dataset.py'),
127
- 'whisperspeech.prepare_t2s_dataset.Transcriber.transcribe': ( '5a. t2s dataset '
128
- 'preparation.html#transcriber.transcribe',
129
- 'whisperspeech/prepare_t2s_dataset.py'),
130
- 'whisperspeech.prepare_t2s_dataset.flac_to_t2s_name': ( '5a. t2s dataset '
131
- 'preparation.html#flac_to_t2s_name',
132
- 'whisperspeech/prepare_t2s_dataset.py'),
133
- 'whisperspeech.prepare_t2s_dataset.prepare_t2s': ( '5a. t2s dataset '
134
- 'preparation.html#prepare_t2s',
135
- 'whisperspeech/prepare_t2s_dataset.py')},
136
- 'whisperspeech.s2a_delar_mup_wds': { 'whisperspeech.s2a_delar_mup_wds.CMLMVisual': ( '4b. semantic to acoustic token '
137
- 'modeling.html#cmlmvisual',
138
- 'whisperspeech/s2a_delar_mup_wds.py'),
139
- 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.__init__': ( '4b. semantic to acoustic token '
140
- 'modeling.html#cmlmvisual.__init__',
141
- 'whisperspeech/s2a_delar_mup_wds.py'),
142
- 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.add_data': ( '4b. semantic to acoustic token '
143
- 'modeling.html#cmlmvisual.add_data',
144
- 'whisperspeech/s2a_delar_mup_wds.py'),
145
- 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.add_table_row': ( '4b. semantic to acoustic '
146
- 'token '
147
- 'modeling.html#cmlmvisual.add_table_row',
148
- 'whisperspeech/s2a_delar_mup_wds.py'),
149
- 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.hide': ( '4b. semantic to acoustic token '
150
- 'modeling.html#cmlmvisual.hide',
151
- 'whisperspeech/s2a_delar_mup_wds.py'),
152
- 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.on_iter': ( '4b. semantic to acoustic token '
153
- 'modeling.html#cmlmvisual.on_iter',
154
- 'whisperspeech/s2a_delar_mup_wds.py'),
155
- 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.plot': ( '4b. semantic to acoustic token '
156
- 'modeling.html#cmlmvisual.plot',
157
- 'whisperspeech/s2a_delar_mup_wds.py'),
158
- 'whisperspeech.s2a_delar_mup_wds.CMLMVisual.show': ( '4b. semantic to acoustic token '
159
- 'modeling.html#cmlmvisual.show',
160
- 'whisperspeech/s2a_delar_mup_wds.py'),
161
- 'whisperspeech.s2a_delar_mup_wds.DelSumDecoder': ( '4b. semantic to acoustic token '
162
- 'modeling.html#delsumdecoder',
163
- 'whisperspeech/s2a_delar_mup_wds.py'),
164
- 'whisperspeech.s2a_delar_mup_wds.DelSumDecoder.__init__': ( '4b. semantic to acoustic '
165
- 'token '
166
- 'modeling.html#delsumdecoder.__init__',
167
- 'whisperspeech/s2a_delar_mup_wds.py'),
168
- 'whisperspeech.s2a_delar_mup_wds.DelSumDecoder.forward': ( '4b. semantic to acoustic '
169
- 'token '
170
- 'modeling.html#delsumdecoder.forward',
171
- 'whisperspeech/s2a_delar_mup_wds.py'),
172
- 'whisperspeech.s2a_delar_mup_wds.EmbeddingProjector': ( '4b. semantic to acoustic token '
173
- 'modeling.html#embeddingprojector',
174
- 'whisperspeech/s2a_delar_mup_wds.py'),
175
- 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention': ( '4b. semantic to acoustic token '
176
- 'modeling.html#multiheadattention',
177
- 'whisperspeech/s2a_delar_mup_wds.py'),
178
- 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.__init__': ( '4b. semantic to '
179
- 'acoustic token '
180
- 'modeling.html#multiheadattention.__init__',
181
- 'whisperspeech/s2a_delar_mup_wds.py'),
182
- 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.forward': ( '4b. semantic to acoustic '
183
- 'token '
184
- 'modeling.html#multiheadattention.forward',
185
- 'whisperspeech/s2a_delar_mup_wds.py'),
186
- 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.qkv_attention_pth20': ( '4b. semantic '
187
- 'to acoustic '
188
- 'token '
189
- 'modeling.html#multiheadattention.qkv_attention_pth20',
190
- 'whisperspeech/s2a_delar_mup_wds.py'),
191
- 'whisperspeech.s2a_delar_mup_wds.MultiHeadAttention.qkv_attention_xformers': ( '4b. '
192
- 'semantic '
193
- 'to '
194
- 'acoustic '
195
- 'token '
196
- 'modeling.html#multiheadattention.qkv_attention_xformers',
197
- 'whisperspeech/s2a_delar_mup_wds.py'),
198
- 'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock': ( '4b. semantic to acoustic '
199
- 'token '
200
- 'modeling.html#residualattentionblock',
201
- 'whisperspeech/s2a_delar_mup_wds.py'),
202
- 'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock.__init__': ( '4b. semantic to '
203
- 'acoustic token '
204
- 'modeling.html#residualattentionblock.__init__',
205
- 'whisperspeech/s2a_delar_mup_wds.py'),
206
- 'whisperspeech.s2a_delar_mup_wds.ResidualAttentionBlock.forward': ( '4b. semantic to '
207
- 'acoustic token '
208
- 'modeling.html#residualattentionblock.forward',
209
- 'whisperspeech/s2a_delar_mup_wds.py'),
210
- 'whisperspeech.s2a_delar_mup_wds.Rotary': ( '4b. semantic to acoustic token '
211
- 'modeling.html#rotary',
212
- 'whisperspeech/s2a_delar_mup_wds.py'),
213
- 'whisperspeech.s2a_delar_mup_wds.Rotary.__init__': ( '4b. semantic to acoustic token '
214
- 'modeling.html#rotary.__init__',
215
- 'whisperspeech/s2a_delar_mup_wds.py'),
216
- 'whisperspeech.s2a_delar_mup_wds.Rotary.forward': ( '4b. semantic to acoustic token '
217
- 'modeling.html#rotary.forward',
218
- 'whisperspeech/s2a_delar_mup_wds.py'),
219
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer': ( '4b. semantic to acoustic token '
220
- 'modeling.html#sadelartransformer',
221
- 'whisperspeech/s2a_delar_mup_wds.py'),
222
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.__init__': ( '4b. semantic to '
223
- 'acoustic token '
224
- 'modeling.html#sadelartransformer.__init__',
225
- 'whisperspeech/s2a_delar_mup_wds.py'),
226
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.device': ( '4b. semantic to acoustic '
227
- 'token '
228
- 'modeling.html#sadelartransformer.device',
229
- 'whisperspeech/s2a_delar_mup_wds.py'),
230
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.embed_stoks': ( '4b. semantic to '
231
- 'acoustic token '
232
- 'modeling.html#sadelartransformer.embed_stoks',
233
- 'whisperspeech/s2a_delar_mup_wds.py'),
234
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.forward': ( '4b. semantic to acoustic '
235
- 'token '
236
- 'modeling.html#sadelartransformer.forward',
237
- 'whisperspeech/s2a_delar_mup_wds.py'),
238
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.generate': ( '4b. semantic to '
239
- 'acoustic token '
240
- 'modeling.html#sadelartransformer.generate',
241
- 'whisperspeech/s2a_delar_mup_wds.py'),
242
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.get_extra_state': ( '4b. semantic to '
243
- 'acoustic token '
244
- 'modeling.html#sadelartransformer.get_extra_state',
245
- 'whisperspeech/s2a_delar_mup_wds.py'),
246
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.get_metrics': ( '4b. semantic to '
247
- 'acoustic token '
248
- 'modeling.html#sadelartransformer.get_metrics',
249
- 'whisperspeech/s2a_delar_mup_wds.py'),
250
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.init_transformer': ( '4b. semantic to '
251
- 'acoustic token '
252
- 'modeling.html#sadelartransformer.init_transformer',
253
- 'whisperspeech/s2a_delar_mup_wds.py'),
254
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_checkpoint': ( '4b. semantic to '
255
- 'acoustic token '
256
- 'modeling.html#sadelartransformer.load_checkpoint',
257
- 'whisperspeech/s2a_delar_mup_wds.py'),
258
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_frozen_semantic_embeddings': ( '4b. '
259
- 'semantic '
260
- 'to '
261
- 'acoustic '
262
- 'token '
263
- 'modeling.html#sadelartransformer.load_frozen_semantic_embeddings',
264
- 'whisperspeech/s2a_delar_mup_wds.py'),
265
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.load_model': ( '4b. semantic to '
266
- 'acoustic token '
267
- 'modeling.html#sadelartransformer.load_model',
268
- 'whisperspeech/s2a_delar_mup_wds.py'),
269
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.save_model': ( '4b. semantic to '
270
- 'acoustic token '
271
- 'modeling.html#sadelartransformer.save_model',
272
- 'whisperspeech/s2a_delar_mup_wds.py'),
273
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.set_extra_state': ( '4b. semantic to '
274
- 'acoustic token '
275
- 'modeling.html#sadelartransformer.set_extra_state',
276
- 'whisperspeech/s2a_delar_mup_wds.py'),
277
- 'whisperspeech.s2a_delar_mup_wds.SADelARTransformer.setup': ( '4b. semantic to acoustic '
278
- 'token '
279
- 'modeling.html#sadelartransformer.setup',
280
- 'whisperspeech/s2a_delar_mup_wds.py'),
281
- 'whisperspeech.s2a_delar_mup_wds.Tunables': ( '4b. semantic to acoustic token '
282
- 'modeling.html#tunables',
283
- 'whisperspeech/s2a_delar_mup_wds.py'),
284
- 'whisperspeech.s2a_delar_mup_wds.Tunables.__post_init__': ( '4b. semantic to acoustic '
285
- 'token '
286
- 'modeling.html#tunables.__post_init__',
287
- 'whisperspeech/s2a_delar_mup_wds.py'),
288
- 'whisperspeech.s2a_delar_mup_wds.Tunables.upgrade': ( '4b. semantic to acoustic token '
289
- 'modeling.html#tunables.upgrade',
290
- 'whisperspeech/s2a_delar_mup_wds.py'),
291
- 'whisperspeech.s2a_delar_mup_wds._make_model': ( '4b. semantic to acoustic token '
292
- 'modeling.html#_make_model',
293
- 'whisperspeech/s2a_delar_mup_wds.py'),
294
- 'whisperspeech.s2a_delar_mup_wds.apply_rotary_pos_emb': ( '4b. semantic to acoustic token '
295
- 'modeling.html#apply_rotary_pos_emb',
296
- 'whisperspeech/s2a_delar_mup_wds.py'),
297
- 'whisperspeech.s2a_delar_mup_wds.load_datasets': ( '4b. semantic to acoustic token '
298
- 'modeling.html#load_datasets',
299
- 'whisperspeech/s2a_delar_mup_wds.py'),
300
- 'whisperspeech.s2a_delar_mup_wds.make_model': ( '4b. semantic to acoustic token '
301
- 'modeling.html#make_model',
302
- 'whisperspeech/s2a_delar_mup_wds.py'),
303
- 'whisperspeech.s2a_delar_mup_wds.pad_samples': ( '4b. semantic to acoustic token '
304
- 'modeling.html#pad_samples',
305
- 'whisperspeech/s2a_delar_mup_wds.py'),
306
- 'whisperspeech.s2a_delar_mup_wds.rand': ( '4b. semantic to acoustic token '
307
- 'modeling.html#rand',
308
- 'whisperspeech/s2a_delar_mup_wds.py'),
309
- 'whisperspeech.s2a_delar_mup_wds.random_trunc': ( '4b. semantic to acoustic token '
310
- 'modeling.html#random_trunc',
311
- 'whisperspeech/s2a_delar_mup_wds.py'),
312
- 'whisperspeech.s2a_delar_mup_wds.rotate_half': ( '4b. semantic to acoustic token '
313
- 'modeling.html#rotate_half',
314
- 'whisperspeech/s2a_delar_mup_wds.py'),
315
- 'whisperspeech.s2a_delar_mup_wds.speaker_id_extractor': ( '4b. semantic to acoustic token '
316
- 'modeling.html#speaker_id_extractor',
317
- 'whisperspeech/s2a_delar_mup_wds.py')},
318
- 'whisperspeech.t2s_up_wds': { 'whisperspeech.t2s_up_wds.CharTokenizer': ( '5b. text to semantic token '
319
- 'modeling.html#chartokenizer',
320
- 'whisperspeech/t2s_up_wds.py'),
321
- 'whisperspeech.t2s_up_wds.CharTokenizer.decode': ( '5b. text to semantic token '
322
- 'modeling.html#chartokenizer.decode',
323
- 'whisperspeech/t2s_up_wds.py'),
324
- 'whisperspeech.t2s_up_wds.CharTokenizer.encode': ( '5b. text to semantic token '
325
- 'modeling.html#chartokenizer.encode',
326
- 'whisperspeech/t2s_up_wds.py'),
327
- 'whisperspeech.t2s_up_wds.Decoder': ( '5b. text to semantic token modeling.html#decoder',
328
- 'whisperspeech/t2s_up_wds.py'),
329
- 'whisperspeech.t2s_up_wds.Decoder.__init__': ( '5b. text to semantic token '
330
- 'modeling.html#decoder.__init__',
331
- 'whisperspeech/t2s_up_wds.py'),
332
- 'whisperspeech.t2s_up_wds.Decoder.forward': ( '5b. text to semantic token '
333
- 'modeling.html#decoder.forward',
334
- 'whisperspeech/t2s_up_wds.py'),
335
- 'whisperspeech.t2s_up_wds.EmbeddingProjector': ( '5b. text to semantic token '
336
- 'modeling.html#embeddingprojector',
337
- 'whisperspeech/t2s_up_wds.py'),
338
- 'whisperspeech.t2s_up_wds.Encoder': ( '5b. text to semantic token modeling.html#encoder',
339
- 'whisperspeech/t2s_up_wds.py'),
340
- 'whisperspeech.t2s_up_wds.Encoder.__init__': ( '5b. text to semantic token '
341
- 'modeling.html#encoder.__init__',
342
- 'whisperspeech/t2s_up_wds.py'),
343
- 'whisperspeech.t2s_up_wds.Encoder.forward': ( '5b. text to semantic token '
344
- 'modeling.html#encoder.forward',
345
- 'whisperspeech/t2s_up_wds.py'),
346
- 'whisperspeech.t2s_up_wds.TSARTransformer': ( '5b. text to semantic token '
347
- 'modeling.html#tsartransformer',
348
- 'whisperspeech/t2s_up_wds.py'),
349
- 'whisperspeech.t2s_up_wds.TSARTransformer.__init__': ( '5b. text to semantic token '
350
- 'modeling.html#tsartransformer.__init__',
351
- 'whisperspeech/t2s_up_wds.py'),
352
- 'whisperspeech.t2s_up_wds.TSARTransformer.device': ( '5b. text to semantic token '
353
- 'modeling.html#tsartransformer.device',
354
- 'whisperspeech/t2s_up_wds.py'),
355
- 'whisperspeech.t2s_up_wds.TSARTransformer.ensure_tokenizer': ( '5b. text to semantic token '
356
- 'modeling.html#tsartransformer.ensure_tokenizer',
357
- 'whisperspeech/t2s_up_wds.py'),
358
- 'whisperspeech.t2s_up_wds.TSARTransformer.forward': ( '5b. text to semantic token '
359
- 'modeling.html#tsartransformer.forward',
360
- 'whisperspeech/t2s_up_wds.py'),
361
- 'whisperspeech.t2s_up_wds.TSARTransformer.generate': ( '5b. text to semantic token '
362
- 'modeling.html#tsartransformer.generate',
363
- 'whisperspeech/t2s_up_wds.py'),
364
- 'whisperspeech.t2s_up_wds.TSARTransformer.generate_batch': ( '5b. text to semantic token '
365
- 'modeling.html#tsartransformer.generate_batch',
366
- 'whisperspeech/t2s_up_wds.py'),
367
- 'whisperspeech.t2s_up_wds.TSARTransformer.init_transformer': ( '5b. text to semantic token '
368
- 'modeling.html#tsartransformer.init_transformer',
369
- 'whisperspeech/t2s_up_wds.py'),
370
- 'whisperspeech.t2s_up_wds.TSARTransformer.load_checkpoint': ( '5b. text to semantic token '
371
- 'modeling.html#tsartransformer.load_checkpoint',
372
- 'whisperspeech/t2s_up_wds.py'),
373
- 'whisperspeech.t2s_up_wds.TSARTransformer.load_frozen_semantic_embeddings': ( '5b. text to '
374
- 'semantic token '
375
- 'modeling.html#tsartransformer.load_frozen_semantic_embeddings',
376
- 'whisperspeech/t2s_up_wds.py'),
377
- 'whisperspeech.t2s_up_wds.TSARTransformer.load_model': ( '5b. text to semantic token '
378
- 'modeling.html#tsartransformer.load_model',
379
- 'whisperspeech/t2s_up_wds.py'),
380
- 'whisperspeech.t2s_up_wds.TSARTransformer.save_model': ( '5b. text to semantic token '
381
- 'modeling.html#tsartransformer.save_model',
382
- 'whisperspeech/t2s_up_wds.py'),
383
- 'whisperspeech.t2s_up_wds.TSARTransformer.setup': ( '5b. text to semantic token '
384
- 'modeling.html#tsartransformer.setup',
385
- 'whisperspeech/t2s_up_wds.py'),
386
- 'whisperspeech.t2s_up_wds.Tunables': ( '5b. text to semantic token modeling.html#tunables',
387
- 'whisperspeech/t2s_up_wds.py'),
388
- 'whisperspeech.t2s_up_wds.Tunables.__post_init__': ( '5b. text to semantic token '
389
- 'modeling.html#tunables.__post_init__',
390
- 'whisperspeech/t2s_up_wds.py'),
391
- 'whisperspeech.t2s_up_wds._make_model': ( '5b. text to semantic token modeling.html#_make_model',
392
- 'whisperspeech/t2s_up_wds.py'),
393
- 'whisperspeech.t2s_up_wds.ar_padder': ( '5b. text to semantic token modeling.html#ar_padder',
394
- 'whisperspeech/t2s_up_wds.py'),
395
- 'whisperspeech.t2s_up_wds.build_speaker_map': ( '5b. text to semantic token '
396
- 'modeling.html#build_speaker_map',
397
- 'whisperspeech/t2s_up_wds.py'),
398
- 'whisperspeech.t2s_up_wds.char_per_seconder': ( '5b. text to semantic token '
399
- 'modeling.html#char_per_seconder',
400
- 'whisperspeech/t2s_up_wds.py'),
401
- 'whisperspeech.t2s_up_wds.load_datasets': ( '5b. text to semantic token '
402
- 'modeling.html#load_datasets',
403
- 'whisperspeech/t2s_up_wds.py'),
404
- 'whisperspeech.t2s_up_wds.make_model': ( '5b. text to semantic token modeling.html#make_model',
405
- 'whisperspeech/t2s_up_wds.py'),
406
- 'whisperspeech.t2s_up_wds.rand': ( '5b. text to semantic token modeling.html#rand',
407
- 'whisperspeech/t2s_up_wds.py'),
408
- 'whisperspeech.t2s_up_wds.speaker_id_extractor': ( '5b. text to semantic token '
409
- 'modeling.html#speaker_id_extractor',
410
- 'whisperspeech/t2s_up_wds.py'),
411
- 'whisperspeech.t2s_up_wds.tokenizer': ( '5b. text to semantic token modeling.html#tokenizer',
412
- 'whisperspeech/t2s_up_wds.py')},
413
- 'whisperspeech.train': { 'whisperspeech.train.SimpleVisual': ('b1. training.html#simplevisual', 'whisperspeech/train.py'),
414
- 'whisperspeech.train.SimpleVisual.__init__': ( 'b1. training.html#simplevisual.__init__',
415
- 'whisperspeech/train.py'),
416
- 'whisperspeech.train.SimpleVisual.add_data': ( 'b1. training.html#simplevisual.add_data',
417
- 'whisperspeech/train.py'),
418
- 'whisperspeech.train.SimpleVisual.add_table_row': ( 'b1. training.html#simplevisual.add_table_row',
419
- 'whisperspeech/train.py'),
420
- 'whisperspeech.train.SimpleVisual.hide': ( 'b1. training.html#simplevisual.hide',
421
- 'whisperspeech/train.py'),
422
- 'whisperspeech.train.SimpleVisual.on_iter': ( 'b1. training.html#simplevisual.on_iter',
423
- 'whisperspeech/train.py'),
424
- 'whisperspeech.train.SimpleVisual.plot': ( 'b1. training.html#simplevisual.plot',
425
- 'whisperspeech/train.py'),
426
- 'whisperspeech.train.SimpleVisual.show': ( 'b1. training.html#simplevisual.show',
427
- 'whisperspeech/train.py'),
428
- 'whisperspeech.train.train': ('b1. training.html#train', 'whisperspeech/train.py'),
429
- 'whisperspeech.train.validate': ('b1. training.html#validate', 'whisperspeech/train.py')},
430
- 'whisperspeech.train_multi': { 'whisperspeech.train_multi.TrainingTask': ( 'b2. training (lightning).html#trainingtask',
431
- 'whisperspeech/train_multi.py'),
432
- 'whisperspeech.train_multi.TrainingTask.__init__': ( 'b2. training '
433
- '(lightning).html#trainingtask.__init__',
434
- 'whisperspeech/train_multi.py'),
435
- 'whisperspeech.train_multi.TrainingTask.configure_optimizers': ( 'b2. training '
436
- '(lightning).html#trainingtask.configure_optimizers',
437
- 'whisperspeech/train_multi.py'),
438
- 'whisperspeech.train_multi.TrainingTask.on_fit_start': ( 'b2. training '
439
- '(lightning).html#trainingtask.on_fit_start',
440
- 'whisperspeech/train_multi.py'),
441
- 'whisperspeech.train_multi.TrainingTask.on_validation_epoch_end': ( 'b2. training '
442
- '(lightning).html#trainingtask.on_validation_epoch_end',
443
- 'whisperspeech/train_multi.py'),
444
- 'whisperspeech.train_multi.TrainingTask.test_step': ( 'b2. training '
445
- '(lightning).html#trainingtask.test_step',
446
- 'whisperspeech/train_multi.py'),
447
- 'whisperspeech.train_multi.TrainingTask.training_step': ( 'b2. training '
448
- '(lightning).html#trainingtask.training_step',
449
- 'whisperspeech/train_multi.py'),
450
- 'whisperspeech.train_multi.TrainingTask.validation_step': ( 'b2. training '
451
- '(lightning).html#trainingtask.validation_step',
452
- 'whisperspeech/train_multi.py'),
453
- 'whisperspeech.train_multi.parse_and_call': ( 'b2. training (lightning).html#parse_and_call',
454
- 'whisperspeech/train_multi.py')},
455
- 'whisperspeech.vad': { 'whisperspeech.vad.extract_segments': ( '1b. voice activity detection.html#extract_segments',
456
- 'whisperspeech/vad.py'),
457
- 'whisperspeech.vad.fix_dots_in_names': ( '1b. voice activity detection.html#fix_dots_in_names',
458
- 'whisperspeech/vad.py'),
459
- 'whisperspeech.vad.flac_to_vad_name': ( '1b. voice activity detection.html#flac_to_vad_name',
460
- 'whisperspeech/vad.py'),
461
- 'whisperspeech.vad.load_dataset': ( '1b. voice activity detection.html#load_dataset',
462
- 'whisperspeech/vad.py'),
463
- 'whisperspeech.vad.process_shard': ( '1b. voice activity detection.html#process_shard',
464
- 'whisperspeech/vad.py'),
465
- 'whisperspeech.vad.segment_audio': ( '1b. voice activity detection.html#segment_audio',
466
- 'whisperspeech/vad.py')},
467
- 'whisperspeech.verify_wds': { 'whisperspeech.verify_wds.process_shard': ( '0. verify webdataset archives.html#process_shard',
468
- 'whisperspeech/verify_wds.py')},
469
- 'whisperspeech.vq_stoks': { 'whisperspeech.vq_stoks.RQBottleneckTransformer': ( '2b. whisper quantization (semantic token) '
470
- 'model.html#rqbottlenecktransformer',
471
- 'whisperspeech/vq_stoks.py'),
472
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.__init__': ( '2b. whisper quantization (semantic '
473
- 'token) '
474
- 'model.html#rqbottlenecktransformer.__init__',
475
- 'whisperspeech/vq_stoks.py'),
476
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.decode_text': ( '2b. whisper quantization '
477
- '(semantic token) '
478
- 'model.html#rqbottlenecktransformer.decode_text',
479
- 'whisperspeech/vq_stoks.py'),
480
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.dequantize': ( '2b. whisper quantization (semantic '
481
- 'token) '
482
- 'model.html#rqbottlenecktransformer.dequantize',
483
- 'whisperspeech/vq_stoks.py'),
484
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.device': ( '2b. whisper quantization (semantic '
485
- 'token) '
486
- 'model.html#rqbottlenecktransformer.device',
487
- 'whisperspeech/vq_stoks.py'),
488
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.downsample_embeddings': ( '2b. whisper '
489
- 'quantization (semantic '
490
- 'token) '
491
- 'model.html#rqbottlenecktransformer.downsample_embeddings',
492
- 'whisperspeech/vq_stoks.py'),
493
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.encode_audio': ( '2b. whisper quantization '
494
- '(semantic token) '
495
- 'model.html#rqbottlenecktransformer.encode_audio',
496
- 'whisperspeech/vq_stoks.py'),
497
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.encode_mel': ( '2b. whisper quantization (semantic '
498
- 'token) '
499
- 'model.html#rqbottlenecktransformer.encode_mel',
500
- 'whisperspeech/vq_stoks.py'),
501
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.ensure_whisper': ( '2b. whisper quantization '
502
- '(semantic token) '
503
- 'model.html#rqbottlenecktransformer.ensure_whisper',
504
- 'whisperspeech/vq_stoks.py'),
505
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.extract_teacher': ( '2b. whisper quantization '
506
- '(semantic token) '
507
- 'model.html#rqbottlenecktransformer.extract_teacher',
508
- 'whisperspeech/vq_stoks.py'),
509
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.forward': ( '2b. whisper quantization (semantic '
510
- 'token) '
511
- 'model.html#rqbottlenecktransformer.forward',
512
- 'whisperspeech/vq_stoks.py'),
513
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.get_metrics': ( '2b. whisper quantization '
514
- '(semantic token) '
515
- 'model.html#rqbottlenecktransformer.get_metrics',
516
- 'whisperspeech/vq_stoks.py'),
517
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.init_transformer': ( '2b. whisper quantization '
518
- '(semantic token) '
519
- 'model.html#rqbottlenecktransformer.init_transformer',
520
- 'whisperspeech/vq_stoks.py'),
521
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.load_checkpoint': ( '2b. whisper quantization '
522
- '(semantic token) '
523
- 'model.html#rqbottlenecktransformer.load_checkpoint',
524
- 'whisperspeech/vq_stoks.py'),
525
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.load_model': ( '2b. whisper quantization (semantic '
526
- 'token) '
527
- 'model.html#rqbottlenecktransformer.load_model',
528
- 'whisperspeech/vq_stoks.py'),
529
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.quantize': ( '2b. whisper quantization (semantic '
530
- 'token) '
531
- 'model.html#rqbottlenecktransformer.quantize',
532
- 'whisperspeech/vq_stoks.py'),
533
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.save_model': ( '2b. whisper quantization (semantic '
534
- 'token) '
535
- 'model.html#rqbottlenecktransformer.save_model',
536
- 'whisperspeech/vq_stoks.py'),
537
- 'whisperspeech.vq_stoks.RQBottleneckTransformer.setup': ( '2b. whisper quantization (semantic '
538
- 'token) '
539
- 'model.html#rqbottlenecktransformer.setup',
540
- 'whisperspeech/vq_stoks.py'),
541
- 'whisperspeech.vq_stoks.Tunables': ( '2b. whisper quantization (semantic token) '
542
- 'model.html#tunables',
543
- 'whisperspeech/vq_stoks.py'),
544
- 'whisperspeech.vq_stoks.Tunables.__post_init__': ( '2b. whisper quantization (semantic token) '
545
- 'model.html#tunables.__post_init__',
546
- 'whisperspeech/vq_stoks.py'),
547
- 'whisperspeech.vq_stoks.Tunables.upgrade': ( '2b. whisper quantization (semantic token) '
548
- 'model.html#tunables.upgrade',
549
- 'whisperspeech/vq_stoks.py'),
550
- 'whisperspeech.vq_stoks.add_masks': ( '2b. whisper quantization (semantic token) '
551
- 'model.html#add_masks',
552
- 'whisperspeech/vq_stoks.py'),
553
- 'whisperspeech.vq_stoks.derived_dataset': ( '2b. whisper quantization (semantic token) '
554
- 'model.html#derived_dataset',
555
- 'whisperspeech/vq_stoks.py'),
556
- 'whisperspeech.vq_stoks.load_datasets': ( '2b. whisper quantization (semantic token) '
557
- 'model.html#load_datasets',
558
- 'whisperspeech/vq_stoks.py'),
559
- 'whisperspeech.vq_stoks.logrand': ( '2b. whisper quantization (semantic token) model.html#logrand',
560
- 'whisperspeech/vq_stoks.py'),
561
- 'whisperspeech.vq_stoks.make_model': ( '2b. whisper quantization (semantic token) '
562
- 'model.html#make_model',
563
- 'whisperspeech/vq_stoks.py'),
564
- 'whisperspeech.vq_stoks.merge_in': ( '2b. whisper quantization (semantic token) '
565
- 'model.html#merge_in',
566
- 'whisperspeech/vq_stoks.py'),
567
- 'whisperspeech.vq_stoks.rand': ( '2b. whisper quantization (semantic token) model.html#rand',
568
- 'whisperspeech/vq_stoks.py'),
569
- 'whisperspeech.vq_stoks.tokenize_text': ( '2b. whisper quantization (semantic token) '
570
- 'model.html#tokenize_text',
571
- 'whisperspeech/vq_stoks.py')},
572
- 'whisperspeech.wer_metrics': { 'whisperspeech.wer_metrics.DfBuilder': ( 'c. word error rate metrics.html#dfbuilder',
573
- 'whisperspeech/wer_metrics.py'),
574
- 'whisperspeech.wer_metrics.DfBuilder.__init__': ( 'c. word error rate '
575
- 'metrics.html#dfbuilder.__init__',
576
- 'whisperspeech/wer_metrics.py'),
577
- 'whisperspeech.wer_metrics.DfBuilder.df': ( 'c. word error rate metrics.html#dfbuilder.df',
578
- 'whisperspeech/wer_metrics.py'),
579
- 'whisperspeech.wer_metrics.DfBuilder.push': ( 'c. word error rate metrics.html#dfbuilder.push',
580
- 'whisperspeech/wer_metrics.py'),
581
- 'whisperspeech.wer_metrics.WERStats': ( 'c. word error rate metrics.html#werstats',
582
- 'whisperspeech/wer_metrics.py'),
583
- 'whisperspeech.wer_metrics.WERStats.__init__': ( 'c. word error rate '
584
- 'metrics.html#werstats.__init__',
585
- 'whisperspeech/wer_metrics.py'),
586
- 'whisperspeech.wer_metrics.WERStats.push_sample': ( 'c. word error rate '
587
- 'metrics.html#werstats.push_sample',
588
- 'whisperspeech/wer_metrics.py'),
589
- 'whisperspeech.wer_metrics.librispeech_data': ( 'c. word error rate '
590
- 'metrics.html#librispeech_data',
591
- 'whisperspeech/wer_metrics.py'),
592
- 'whisperspeech.wer_metrics.whisper_normalize': ( 'c. word error rate '
593
- 'metrics.html#whisper_normalize',
594
- 'whisperspeech/wer_metrics.py')},
595
- 'whisperspeech.wh_transcribe': { 'whisperspeech.wh_transcribe.chunk_merger': ( '2a. whisper quantization dataset '
596
- 'preparation.html#chunk_merger',
597
- 'whisperspeech/wh_transcribe.py'),
598
- 'whisperspeech.wh_transcribe.flac_to_txt_name': ( '2a. whisper quantization dataset '
599
- 'preparation.html#flac_to_txt_name',
600
- 'whisperspeech/wh_transcribe.py'),
601
- 'whisperspeech.wh_transcribe.merge_in': ( '2a. whisper quantization dataset '
602
- 'preparation.html#merge_in',
603
- 'whisperspeech/wh_transcribe.py'),
604
- 'whisperspeech.wh_transcribe.process_shard': ( '2a. whisper quantization dataset '
605
- 'preparation.html#process_shard',
606
- 'whisperspeech/wh_transcribe.py'),
607
- 'whisperspeech.wh_transcribe.random_cutter': ( '2a. whisper quantization dataset '
608
- 'preparation.html#random_cutter',
609
- 'whisperspeech/wh_transcribe.py'),
610
- 'whisperspeech.wh_transcribe.split_to_chunks': ( '2a. whisper quantization dataset '
611
- 'preparation.html#split_to_chunks',
612
- 'whisperspeech/wh_transcribe.py'),
613
- 'whisperspeech.wh_transcribe.wds_compose': ( '2a. whisper quantization dataset '
614
- 'preparation.html#wds_compose',
615
- 'whisperspeech/wh_transcribe.py')}}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/a2wav.py DELETED
@@ -1,45 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/6. Quality-boosting vocoder.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['Vocoder']
5
-
6
- # %% ../nbs/6. Quality-boosting vocoder.ipynb 1
7
- from vocos import Vocos
8
- import torch
9
- import torchaudio
10
-
11
- # %% ../nbs/6. Quality-boosting vocoder.ipynb 2
12
- class Vocoder:
13
- def __init__(self, repo_id="charactr/vocos-encodec-24khz"):
14
- self.vocos = Vocos.from_pretrained(repo_id).cuda()
15
-
16
- def is_notebook(self):
17
- try:
18
- return get_ipython().__class__.__name__ == "ZMQInteractiveShell"
19
- except:
20
- return False
21
-
22
- @torch.no_grad()
23
- def decode(self, atoks):
24
- if len(atoks.shape) == 3:
25
- b,q,t = atoks.shape
26
- atoks = atoks.permute(1,0,2)
27
- else:
28
- q,t = atoks.shape
29
-
30
- features = self.vocos.codes_to_features(atoks)
31
- bandwidth_id = torch.tensor({2:0,4:1,8:2}[q]).cuda()
32
- return self.vocos.decode(features, bandwidth_id=bandwidth_id)
33
-
34
- def decode_to_file(self, fname, atoks):
35
- audio = self.decode(atoks)
36
- torchaudio.save(fname, audio.cpu(), 24000)
37
- if self.is_notebook():
38
- from IPython.display import display, HTML, Audio
39
- display(HTML(f'<a href="{fname}" target="_blank">Listen to {fname}</a>'))
40
-
41
- def decode_to_notebook(self, atoks):
42
- from IPython.display import display, HTML, Audio
43
-
44
- audio = self.decode(atoks)
45
- display(Audio(audio.cpu().numpy(), rate=24000))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/extract_acoustic.py DELETED
@@ -1,56 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1. Acoustic token extraction.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['load', 'load_model', 'extract_Atoks', 'extract_acoustic']
5
-
6
- # %% ../nbs/1. Acoustic token extraction.ipynb 2
7
- import torch
8
- import torchaudio
9
- import gc
10
-
11
- from pathlib import Path
12
- from fastcore.script import *
13
- from fastprogress import progress_bar, master_bar
14
-
15
- # %% ../nbs/1. Acoustic token extraction.ipynb 5
16
- def load(fname, newsr=24000):
17
- """Load an audio file to the GPU and resample to `newsr`."""
18
- x, sr = torchaudio.load(fname)
19
- _tform = torchaudio.transforms.Resample(sr, newsr)
20
- return _tform(x).cuda().unsqueeze(0)
21
-
22
- # %% ../nbs/1. Acoustic token extraction.ipynb 6
23
- def load_model():
24
- "Load the pretrained EnCodec model"
25
- from encodec.model import EncodecModel
26
- model = EncodecModel.encodec_model_24khz()
27
- model.set_target_bandwidth(1.5)
28
- model.cuda().eval();
29
- return model
30
-
31
- # %% ../nbs/1. Acoustic token extraction.ipynb 7
32
- def extract_Atoks(model, audio):
33
- """Extract EnCodec tokens for the given `audio` tensor (or file path)
34
- using the given `model` (see `load_model`)."""
35
- if isinstance(audio, (Path, str)):
36
- audio = load(audio)
37
- with torch.no_grad():
38
- frames = torch.cat([model.encode(segment)[0][0]
39
- for segment in torch.split(audio, 320*20000, dim=-1)], dim=-1)
40
- return frames
41
-
42
- # %% ../nbs/1. Acoustic token extraction.ipynb 8
43
- @call_parse
44
- def extract_acoustic(
45
- srcdir:Path, # source dir, should contain *.flac files
46
- outdir:Path, # output dir, will get the *.encodec files
47
- ):
48
- "Convert audio files to .encodec files with tensors of tokens"
49
- model = load_model()
50
- outdir.mkdir(exist_ok=True, parents=True)
51
- for name in progress_bar(list(srcdir.rglob('*.flac'))):
52
- outname = outdir/name.with_suffix('.encodec').name
53
- tokens = extract_Atoks(model, name)
54
- torch.save(tokens, outname)
55
- del tokens
56
- gc.collect()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/fetch_models.py DELETED
@@ -1,17 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/0. Download models.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = []
5
-
6
- # %% ../nbs/0. Download models.ipynb 1
7
- from fastcore.script import call_parse
8
- import whisperx
9
- import whisper
10
-
11
- # %% ../nbs/0. Download models.ipynb 3
12
- @call_parse
13
- def main():
14
- whisper.load_model('base.en')
15
- whisper.load_model('small.en')
16
- whisperx.vad.load_vad_model('cpu')
17
- whisperx.asr.load_model('medium.en', "cpu", compute_type="float16", language='en')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/languages.py DELETED
@@ -1,131 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B. Languages.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['to_id']
5
-
6
- # %% ../nbs/B. Languages.ipynb 3
7
- LANGUAGES = {
8
- "en": "english",
9
- "zh": "chinese",
10
- "de": "german",
11
- "es": "spanish",
12
- "ru": "russian",
13
- "ko": "korean",
14
- "fr": "french",
15
- "ja": "japanese",
16
- "pt": "portuguese",
17
- "tr": "turkish",
18
- "pl": "polish",
19
- "ca": "catalan",
20
- "nl": "dutch",
21
- "ar": "arabic",
22
- "sv": "swedish",
23
- "it": "italian",
24
- "id": "indonesian",
25
- "hi": "hindi",
26
- "fi": "finnish",
27
- "vi": "vietnamese",
28
- "he": "hebrew",
29
- "uk": "ukrainian",
30
- "el": "greek",
31
- "ms": "malay",
32
- "cs": "czech",
33
- "ro": "romanian",
34
- "da": "danish",
35
- "hu": "hungarian",
36
- "ta": "tamil",
37
- "no": "norwegian",
38
- "th": "thai",
39
- "ur": "urdu",
40
- "hr": "croatian",
41
- "bg": "bulgarian",
42
- "lt": "lithuanian",
43
- "la": "latin",
44
- "mi": "maori",
45
- "ml": "malayalam",
46
- "cy": "welsh",
47
- "sk": "slovak",
48
- "te": "telugu",
49
- "fa": "persian",
50
- "lv": "latvian",
51
- "bn": "bengali",
52
- "sr": "serbian",
53
- "az": "azerbaijani",
54
- "sl": "slovenian",
55
- "kn": "kannada",
56
- "et": "estonian",
57
- "mk": "macedonian",
58
- "br": "breton",
59
- "eu": "basque",
60
- "is": "icelandic",
61
- "hy": "armenian",
62
- "ne": "nepali",
63
- "mn": "mongolian",
64
- "bs": "bosnian",
65
- "kk": "kazakh",
66
- "sq": "albanian",
67
- "sw": "swahili",
68
- "gl": "galician",
69
- "mr": "marathi",
70
- "pa": "punjabi",
71
- "si": "sinhala",
72
- "km": "khmer",
73
- "sn": "shona",
74
- "yo": "yoruba",
75
- "so": "somali",
76
- "af": "afrikaans",
77
- "oc": "occitan",
78
- "ka": "georgian",
79
- "be": "belarusian",
80
- "tg": "tajik",
81
- "sd": "sindhi",
82
- "gu": "gujarati",
83
- "am": "amharic",
84
- "yi": "yiddish",
85
- "lo": "lao",
86
- "uz": "uzbek",
87
- "fo": "faroese",
88
- "ht": "haitian creole",
89
- "ps": "pashto",
90
- "tk": "turkmen",
91
- "nn": "nynorsk",
92
- "mt": "maltese",
93
- "sa": "sanskrit",
94
- "lb": "luxembourgish",
95
- "my": "myanmar",
96
- "bo": "tibetan",
97
- "tl": "tagalog",
98
- "mg": "malagasy",
99
- "as": "assamese",
100
- "tt": "tatar",
101
- "haw": "hawaiian",
102
- "ln": "lingala",
103
- "ha": "hausa",
104
- "ba": "bashkir",
105
- "jw": "javanese",
106
- "su": "sundanese",
107
- }
108
-
109
- # %% ../nbs/B. Languages.ipynb 4
110
- # language code lookup by name, with a few language aliases
111
- TO_LANGUAGE_CODE = {
112
- **{language: code for code, language in LANGUAGES.items()},
113
- "burmese": "my",
114
- "valencian": "ca",
115
- "flemish": "nl",
116
- "haitian": "ht",
117
- "letzeburgesch": "lb",
118
- "pushto": "ps",
119
- "panjabi": "pa",
120
- "moldavian": "ro",
121
- "moldovan": "ro",
122
- "sinhalese": "si",
123
- "castilian": "es",
124
- }
125
-
126
- # %% ../nbs/B. Languages.ipynb 5
127
- languages = tuple(LANGUAGES.keys())
128
-
129
- # %% ../nbs/B. Languages.ipynb 6
130
- def to_id(lang):
131
- return languages.index(TO_LANGUAGE_CODE.get(lang, lang))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/modules.py DELETED
@@ -1,331 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/A. Neural modules.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['LayerNorm', 'LinearHead', 'QueryHead', 'init_transformer', 'sinusoids', 'MultiHeadAttention',
5
- 'ResidualAttentionBlock', 'BaseDecoder', 'EmbeddingProjector', 'FlexEmbeddings']
6
-
7
- # %% ../nbs/A. Neural modules.ipynb 2
8
- import torch
9
- import numpy as np
10
- import math
11
-
12
- from torch import Tensor, nn
13
- import torch.nn.functional as F
14
- from typing import Dict, Iterable, Optional
15
-
16
- # import xformers.ops as xops
17
-
18
- # %% ../nbs/A. Neural modules.ipynb 3
19
- # Code in this file is mostly borrowed from
20
- # https://github.com/openai/whisper/blob/main/whisper/model.py
21
- # and is under the MIT License
22
-
23
- class LayerNorm(nn.LayerNorm):
24
- def forward(self, x):
25
- return super().forward(x.float()).type(x.dtype)
26
-
27
- # Used in ΞΌP to initialize the weights and configure the optimizer
28
- # These two layers map the transformer width into a fixed dimension
29
- class LinearHead(nn.Linear):
30
- pass
31
-
32
- class QueryHead(nn.Linear):
33
- pass
34
-
35
- # based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
36
- def init_transformer(m):
37
- if isinstance(m, (nn.Linear, nn.Embedding)):
38
- torch.nn.init.trunc_normal_(m.weight, std=.02)
39
- if isinstance(m, nn.Linear) and m.bias is not None:
40
- torch.nn.init.constant_(m.bias, 0)
41
- elif isinstance(m, nn.LayerNorm):
42
- torch.nn.init.constant_(m.bias, 0)
43
- torch.nn.init.constant_(m.weight, 1.0)
44
-
45
- # %% ../nbs/A. Neural modules.ipynb 4
46
- def sinusoids(length, channels, max_timescale=10000):
47
- """Returns sinusoids for positional embedding"""
48
- assert channels % 2 == 0
49
- log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
50
- inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
51
- scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
52
- return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
53
-
54
- # %% ../nbs/A. Neural modules.ipynb 5
55
- class MultiHeadAttention(nn.Module):
56
- def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False, cross=False):
57
- super().__init__()
58
- self.n_state = n_state
59
- self.n_head = n_head
60
- self.sqrt_qk_scale = math.sqrt(qk_scale)
61
- self.query = QueryHead(n_state, n_state)
62
- self.key = nn.Linear(n_state, n_state, bias=False)
63
- self.value = nn.Linear(n_state, n_state)
64
- self.out = nn.Linear(n_state, n_state)
65
- self.cross = cross
66
- self.query_subsampling = 1
67
- self.key_subsampling = 1
68
-
69
- self.cached_kvx = None
70
- self.register_buffer('k_cache', None)
71
- self.register_buffer('v_cache', None)
72
-
73
- self.rotary = None
74
- if rope:
75
- self.rotary = Rotary(n_state // n_head)
76
- self.qkv = None
77
- self.kv = None
78
-
79
- def setup_kv_cache(self, max_batch_size, max_seq_len, dtype=torch.float32):
80
- cache_shape = (max_batch_size, self.n_head, max_seq_len, self.n_state//self.n_head)
81
- self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=self.key.weight.device)
82
- self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=self.value.weight.device)
83
-
84
- def merge_linears(self, layers, mults):
85
- bias = [x.bias for x in layers if x.bias is not None][0]
86
- din, dout = layers[0].weight.shape
87
- new = nn.Linear(din, len(layers) * dout).to(layers[0].weight.device)
88
- with torch.no_grad():
89
- new.weight[:] = torch.cat([x.weight * m for x,m in zip(layers, mults)])
90
- new.bias[:] = torch.cat([torch.zeros_like(bias) if x.bias is None else x.bias * m for x, m in zip(layers, mults)])
91
- return new
92
-
93
- def convert_for_eval(self):
94
- if self.qkv or self.kv: raise AttributeError("already converted")
95
-
96
- self.odim = self.key.weight.shape[1]
97
- if self.cross:
98
- self.q = self.merge_linears([self.query], [self.sqrt_qk_scale])
99
- self.kv = self.merge_linears([self.key, self.value],
100
- [self.sqrt_qk_scale, 1])
101
- else:
102
- self.qkv = self.merge_linears([self.query, self.key, self.value],
103
- [self.sqrt_qk_scale, self.sqrt_qk_scale, 1])
104
-
105
- def split_heads(self, x, x_positions, rope=False, subsampling=1):
106
- x = x.view(*x.shape[:2], self.n_head, -1)
107
- if rope:
108
- x = rope_rotate(x, x_positions * subsampling, *self.rotary(x))
109
- return x.permute(0, 2, 1, 3)
110
-
111
- def forward(
112
- self,
113
- qx,
114
- q_positions,
115
- kvx,
116
- kv_positions,
117
- causal = False,
118
- mask=None,
119
- ):
120
- if self.qkv:
121
- q,k,v = self.qkv(qx).split(self.odim, dim=-1)
122
- elif self.kv:
123
- q = self.q(qx)
124
- k,v = self.kv(kvx).split(self.odim, dim=-1)
125
- else:
126
- q,k,v = None,None,None
127
-
128
- if q is None: q = self.query(qx) * self.sqrt_qk_scale
129
- q = self.split_heads(q, q_positions, rope = self.rotary, subsampling = self.query_subsampling)
130
-
131
- if kvx is not self.cached_kvx:
132
- if k is None: k = self.key(kvx) * self.sqrt_qk_scale
133
- k = self.split_heads(k, kv_positions, rope = self.rotary, subsampling = self.key_subsampling)
134
- if v is None: v = self.value(kvx)
135
- v = self.split_heads(v, kv_positions)
136
- if self.k_cache is not None:
137
- self.k_cache[:,:,kv_positions] = k
138
- self.v_cache[:,:,kv_positions] = v
139
-
140
- if self.k_cache is not None:
141
- k, v = self.k_cache, self.v_cache
142
-
143
- if mask is not None:
144
- mask = mask[q_positions]
145
-
146
- wv = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0, is_causal=causal)
147
-
148
- return self.out(wv.permute(0, 2, 1, 3).flatten(start_dim=2))
149
-
150
- # %% ../nbs/A. Neural modules.ipynb 6
151
- # modified from https://blog.eleuther.ai/rotary-embeddings/
152
-
153
- import torch
154
-
155
- class Rotary(torch.nn.Module):
156
- def __init__(self, dim, base=10000):
157
- super().__init__()
158
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
159
- self.register_buffer("inv_freq", inv_freq)
160
- self.seq_len_cached = None
161
- self.cos_cached = None
162
- self.sin_cached = None
163
-
164
- def forward(self, x, seq_dim=1):
165
- seq_len = x.shape[seq_dim]
166
- if not self.seq_len_cached or seq_len > self.seq_len_cached:
167
- self.seq_len_cached = 2500
168
- # self.seq_len_cached = seq_len
169
-
170
- t = torch.arange(self.seq_len_cached, device=x.device).type_as(self.inv_freq)
171
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
172
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
173
- self.cos_cached = emb.cos()[None, :, None, :]
174
- self.sin_cached = emb.sin()[None, :, None, :]
175
- return self.cos_cached, self.sin_cached
176
-
177
-
178
- # rotary pos emb helpers:
179
- def rotate_half(x):
180
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
181
- return torch.cat(
182
- (-x2, x1), dim=len(x.shape)-1
183
- )
184
-
185
- def rope_rotate(x, positions, cos, sin):
186
- return x * cos[:,positions] + rotate_half(x) * sin[:,positions]
187
-
188
- # %% ../nbs/A. Neural modules.ipynb 7
189
- class ResidualAttentionBlock(nn.Module):
190
- def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
191
- qk_scale: float = 1, ffn_mult: int = 4):
192
- super().__init__()
193
- self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
194
- self.attn_ln = LayerNorm(n_state)
195
-
196
- self.cross_attn = (
197
- MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope, cross=True) if cross_attention else None
198
- )
199
- self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
200
-
201
- n_mlp = n_state * ffn_mult
202
- self.mlp = nn.Sequential(
203
- nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
204
- )
205
- self.mlp_ln = LayerNorm(n_state)
206
-
207
- def setup_kv_cache(self, max_batch_size, max_seq_len, max_cross_seq_len=None):
208
- self.attn.setup_kv_cache(max_batch_size, max_seq_len)
209
- if self.cross_attn:
210
- self.cross_attn.setup_kv_cache(max_batch_size, max_cross_seq_len)
211
-
212
- def forward(
213
- self,
214
- x: Tensor,
215
- x_positions: Tensor = None,
216
- xa: Optional[Tensor] = None,
217
- xa_positions: Optional[Tensor] = None,
218
- causal = False,
219
- mask=None,
220
- ):
221
- lnx = self.attn_ln(x)
222
- x = x + self.attn(lnx, x_positions, lnx, x_positions, causal=causal, mask=mask)
223
- if self.cross_attn:
224
- lnx = self.cross_attn_ln(x)
225
- x = x + self.cross_attn(lnx, x_positions, xa, xa_positions)
226
- x = x + self.mlp(self.mlp_ln(x))
227
- return x
228
-
229
- # %% ../nbs/A. Neural modules.ipynb 8
230
- class BaseDecoder(nn.Module):
231
- def __init__(self, depth=6, n_head=6, width=384, qk_scale=1, ffn_mult=4, length=2250, rope=False):
232
- super().__init__()
233
- self.length = length
234
- self.width = width
235
- self.layers = nn.ModuleList([
236
- ResidualAttentionBlock(
237
- self.width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope
238
- ) for _ in range(math.floor(depth))
239
- ])
240
-
241
- self.ln_post = LayerNorm(width)
242
-
243
- mask = torch.empty(length, length).fill_(-torch.inf).triu_(1)
244
- self.register_buffer("mask", mask, persistent=False)
245
-
246
- def forward(self, x, x_positions, xenc, xenc_positions):
247
- for i,l in enumerate(self.layers):
248
- x = l(x, x_positions, xenc, xenc_positions, causal=False, mask=self.mask)
249
-
250
- x = self.ln_post(x)
251
-
252
- return x
253
-
254
- # %% ../nbs/A. Neural modules.ipynb 9
255
- class EmbeddingProjector(nn.Linear):
256
- pass
257
-
258
- class FlexEmbeddings(nn.Module):
259
- def __init__(self, codes, width, special_codes=None, frozen_width=None, special_embedding=None, unembed=True):
260
- super().__init__()
261
- self.codes = codes
262
- self.special_codes = special_codes
263
- if frozen_width is None: frozen_width = width
264
-
265
- self.main = nn.Embedding(codes, frozen_width or width)
266
- self.emb_to_hidden = EmbeddingProjector(frozen_width, width) if frozen_width != width else None
267
- self.hidden_to_emb = EmbeddingProjector(width, frozen_width) if unembed and frozen_width != width else None
268
- if special_codes:
269
- self.special = special_embedding or nn.Embedding(special_codes, width)
270
-
271
- self.register_buffer('merged_in', None)
272
- self.register_buffer('merged_out', None)
273
- self.register_buffer('bias_out', None)
274
-
275
- def set_frozen_embeddings(self, values):
276
- with torch.no_grad():
277
- self.main.weight[:] = values
278
- self.main.lr_scale = 0
279
-
280
- @torch.no_grad()
281
- def convert_for_eval(self):
282
- if not self.special_codes: return
283
- # in
284
- main_w = self.main.weight
285
- if self.emb_to_hidden is not None: main_w = self.emb_to_hidden(main_w)
286
- weight = torch.cat([main_w, self.special.weight], dim=0)
287
- self.merged_in = nn.Embedding(*weight.shape, _weight=weight)
288
-
289
- # out
290
- weight = self.main.weight
291
- if self.hidden_to_emb: weight = weight @ self.hidden_to_emb.weight
292
- self.merged_out = torch.cat([weight.T, self.special.weight.T], dim=1).T.contiguous() # T is for F.linear
293
- if self.hidden_to_emb:
294
- self.bias_out = torch.cat([
295
- self.hidden_to_emb.bias @ self.main.weight.T,
296
- torch.zeros(self.special.weight.shape[0], device=weight.device, dtype=weight.dtype)
297
- ], dim=0)
298
- else:
299
- self.bias_out = None
300
-
301
- def forward(self, toks):
302
- if not self.training and self.merged_in is not None:
303
- return self.merged_in(toks)
304
-
305
- if self.special_codes:
306
- special_mask = toks >= self.codes
307
- embs = self.main(torch.where(special_mask, 0, toks))
308
- else:
309
- embs = self.main(toks)
310
-
311
- if self.emb_to_hidden: embs = self.emb_to_hidden(embs)
312
-
313
- if self.special_codes:
314
- embs[special_mask] = self.special(toks[special_mask] - self.codes).to(embs.dtype)
315
-
316
- return embs
317
-
318
- def unembed(self, embs):
319
- if not self.training and self.merged_out is not None:
320
- return F.linear(embs, self.merged_out, self.bias_out) # embs @ self.merged_out + self.bias_out
321
-
322
- orig_embs = embs
323
- if self.hidden_to_emb: embs = self.hidden_to_emb(embs)
324
-
325
- main_logits = (embs @ self.main.weight.to(embs.dtype).T).float()
326
-
327
- if not self.special_codes:
328
- return main_logits
329
-
330
- special_logits = (orig_embs @ self.special.weight.to(orig_embs.dtype).T).float()
331
- return torch.cat([main_logits, special_logits], dim=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/pipeline.py DELETED
@@ -1,93 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/7. Pipeline.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['Pipeline']
5
-
6
- # %% ../nbs/7. Pipeline.ipynb 1
7
- import torch
8
- from whisperspeech.t2s_up_wds_mlang_enclm import TSARTransformer
9
- from whisperspeech.s2a_delar_mup_wds_mlang import SADelARTransformer
10
- from whisperspeech.a2wav import Vocoder
11
- import traceback
12
- from pathlib import Path
13
-
14
- # %% ../nbs/7. Pipeline.ipynb 2
15
- class Pipeline:
16
- default_speaker = torch.tensor(
17
- [-0.2929, -0.4503, 0.4155, -0.1417, 0.0473, -0.1624, -0.2322, 0.7071,
18
- 0.4800, 0.5496, 0.0410, 0.6236, 0.4729, 0.0587, 0.2194, -0.0466,
19
- -0.3036, 0.0497, 0.5028, -0.1703, 0.5039, -0.6464, 0.3857, -0.7350,
20
- -0.1605, 0.4808, 0.5397, -0.4851, 0.1774, -0.8712, 0.5789, 0.1785,
21
- -0.1417, 0.3039, 0.4232, -0.0186, 0.2685, 0.6153, -0.3103, -0.5706,
22
- -0.4494, 0.3394, -0.6184, -0.3617, 1.1041, -0.1178, -0.1885, 0.1997,
23
- 0.5571, -0.2906, -0.0477, -0.4048, -0.1062, 1.4779, 0.1639, -0.3712,
24
- -0.1776, -0.0568, -0.6162, 0.0110, -0.0207, -0.1319, -0.3854, 0.7248,
25
- 0.0343, 0.5724, 0.0670, 0.0486, -0.3813, 0.1738, 0.3017, 1.0502,
26
- 0.1550, 0.5708, 0.0366, 0.5093, 0.0294, -0.7091, -0.8220, -0.1583,
27
- -0.2343, 0.1366, 0.7372, -0.0631, 0.1505, 0.4600, -0.1252, -0.5245,
28
- 0.7523, -0.0386, -0.2587, 1.0066, -0.2037, 0.1617, -0.3800, 0.2790,
29
- 0.0184, -0.5111, -0.7291, 0.1627, 0.2367, -0.0192, 0.4822, -0.4458,
30
- 0.1457, -0.5884, 0.1909, 0.2563, -0.2035, -0.0377, 0.7771, 0.2139,
31
- 0.3801, 0.6047, -0.6043, -0.2563, -0.0726, 0.3856, 0.3217, 0.0823,
32
- -0.1302, 0.3287, 0.5693, 0.2453, 0.8231, 0.0072, 1.0327, 0.6065,
33
- -0.0620, -0.5572, 0.5220, 0.2485, 0.1520, 0.0222, -0.2179, -0.7392,
34
- -0.3855, 0.1822, 0.1042, 0.7133, 0.3583, 0.0606, -0.0424, -0.9189,
35
- -0.4882, -0.5480, -0.5719, -0.1660, -0.3439, -0.5814, -0.2542, 0.0197,
36
- 0.4942, 0.0915, -0.0420, -0.0035, 0.5578, 0.1051, -0.0891, 0.2348,
37
- 0.6876, -0.6685, 0.8215, -0.3692, -0.3150, -0.0462, -0.6806, -0.2661,
38
- -0.0308, -0.0050, 0.6756, -0.1647, 1.0734, 0.0049, 0.4969, 0.0259,
39
- -0.8949, 0.0731, 0.0886, 0.3442, -0.1433, -0.6804, 0.2204, 0.1859,
40
- 0.2702, 0.1699, -0.1443, -0.9614, 0.3261, 0.1718, 0.3545, -0.0686]
41
- )
42
-
43
- def __init__(self, t2s_ref=None, s2a_ref=None, optimize=True, torch_compile=False):
44
- args = dict()
45
- try:
46
- if t2s_ref:
47
- args["ref"] = t2s_ref
48
- self.t2s = TSARTransformer.load_model(**args).cuda()
49
- if optimize: self.t2s.optimize(torch_compile=torch_compile)
50
- except:
51
- print("Failed to load the T2S model:")
52
- print(traceback.format_exc())
53
- try:
54
- if s2a_ref:
55
- args["ref"] = s2a_ref
56
- self.s2a = SADelARTransformer.load_model(**args).cuda()
57
- if optimize: self.s2a.optimize(torch_compile=torch_compile)
58
- except:
59
- print("Failed to load the S2A model:")
60
- print(traceback.format_exc())
61
- self.vocoder = Vocoder()
62
- self.encoder = None
63
-
64
- def extract_spk_emb(self, fname):
65
- """Extracts a speaker embedding from the first 30 seconds of the give audio file.
66
- """
67
- import torchaudio
68
- if self.encoder is None:
69
- from speechbrain.pretrained import EncoderClassifier
70
- self.encoder = EncoderClassifier.from_hparams("speechbrain/spkrec-ecapa-voxceleb",
71
- savedir="~/.cache/speechbrain/",
72
- run_opts={"device": "cuda"})
73
- samples, sr = torchaudio.load(fname)
74
- samples = self.encoder.audio_normalizer(samples[0,:30*sr], sr)
75
- spk_emb = self.encoder.encode_batch(samples)
76
- return spk_emb[0,0]
77
-
78
- def generate_atoks(self, text, speaker=None, lang='en', cps=15, step_callback=None):
79
- if speaker is None: speaker = self.default_speaker
80
- elif isinstance(speaker, (str, Path)): speaker = self.extract_spk_emb(speaker)
81
- text = text.replace("\n", " ")
82
- stoks = self.t2s.generate(text, cps=cps, lang=lang, step=step_callback)
83
- atoks = self.s2a.generate(stoks, speaker.unsqueeze(0), step=step_callback)
84
- return atoks
85
-
86
- def generate(self, text, speaker=None, lang='en', cps=15, step_callback=None):
87
- return self.vocoder.decode(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=step_callback))
88
-
89
- def generate_to_file(self, fname, text, speaker=None, lang='en', cps=15, step_callback=None):
90
- self.vocoder.decode_to_file(fname, self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=None))
91
-
92
- def generate_to_notebook(self, text, speaker=None, lang='en', cps=15, step_callback=None):
93
- self.vocoder.decode_to_notebook(self.generate_atoks(text, speaker, lang=lang, cps=cps, step_callback=None))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/prepare_s2a_dataset.py DELETED
@@ -1,112 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4A. S2A dataset preparation.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['flac_to_s2a_name']
5
-
6
- # %% ../nbs/4A. S2A dataset preparation.ipynb 2
7
- import sys
8
- import os
9
- import itertools
10
- from pathlib import Path
11
-
12
- import numpy as np
13
- import torch
14
- import torchaudio
15
- import torch.nn.functional as F
16
- from torch.profiler import profile, record_function, ProfilerActivity
17
-
18
- from fastprogress import progress_bar
19
- from fastcore.script import *
20
-
21
- import whisper
22
- from . import vad, wh_transcribe, vq_stoks, extract_acoustic
23
- import webdataset as wds
24
-
25
- # %% ../nbs/4A. S2A dataset preparation.ipynb 4
26
- def flac_to_s2a_name(input):
27
- if '-flac-' in input:
28
- return input.rsplit("/", 1)[1].replace('flac', 's2a') + ".gz"
29
- else:
30
- return input.rsplit("/", 1)[1].replace('raw', 's2a') + ".gz"
31
-
32
- # %% ../nbs/4A. S2A dataset preparation.ipynb 6
33
- def resampler(newsr = 24000, key = 'samples_24k'):
34
- _last_sr = None
35
- tform = None
36
-
37
- def _resample(samples):
38
- for s in samples:
39
- sr = s['sample_rate']
40
- if sr != newsr:
41
- if sr != _last_sr: tform = torchaudio.transforms.Resample(sr, newsr)
42
- s[key] = tform(s['samples'])
43
- else:
44
- s[key] = s['samples']
45
- yield s
46
-
47
- return _resample
48
-
49
- # %% ../nbs/4A. S2A dataset preparation.ipynb 9
50
- @call_parse
51
- def prepare_s2a(
52
- input:str, # FLAC webdataset file path (or - to read the names from stdin)
53
- proc_dataset_path:Path, # processed VAD files path
54
- output:str=None, # output file name
55
- vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
56
- n_samples:int=None, # process a limited amount of samples
57
- batch_size:int=1, # process several segments at once
58
- fix_dots:bool=False, # fix dots in file names
59
- ):
60
- if ":" in vq_model:
61
- repo, fname = vq_model.split(":", 1)
62
- vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
63
- else:
64
- vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
65
- amodel = extract_acoustic.load_model()
66
- amodel.set_target_bandwidth(3)
67
-
68
- if input == "-":
69
- input = [f.strip() for f in sys.stdin.readlines()]
70
- assert output, "please provide the output shard name"
71
- else:
72
- if output is None: output = flac_to_s2a_name(input)
73
- input = [input]
74
-
75
- total = n_samples//batch_size if n_samples else 'noinfer'
76
-
77
- ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names if fix_dots else None).compose(
78
- wds.decode(wds.torch_audio),
79
- wds.select(lambda x: 'wav' in x or 'flac' in x),
80
- vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
81
- wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
82
- lambda x: wh_transcribe.split_to_chunks(x),
83
- resampler(),
84
- resampler(16000, 'samples_16k'),
85
- wds.to_tuple('__key__', 'rpad_s', 'samples_16k', 'samples_24k'),
86
- wds.batched(64),
87
- )
88
-
89
- dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
90
-
91
- speakers = set()
92
- tmp = output+".tmp"
93
- with wds.TarWriter(tmp) as sink:
94
- for keys, rpad_ss, samples, samples24k in progress_bar(dl, total=total):
95
- with record_function('to_cuda'):
96
- samples, samples24k = samples.cuda(), samples24k.unsqueeze(1).cuda()
97
- with record_function('encodec'):
98
- atoks = amodel.encode(samples24k)[0][0]
99
- with record_function('vq_stoks'):
100
- stoks = vq_model.encode_audio(samples)
101
- with record_function('from_cuda'):
102
- atoks, stoks = atoks.cpu().numpy().astype(np.int16), stoks.cpu().numpy().astype(np.int16)
103
- for key, rpad_s, _atoks, _stoks in zip(keys, rpad_ss, atoks, stoks):
104
- speakers.add(key.split('/')[1])
105
- sink.write({
106
- "__key__": key,
107
- "atoks.npy": _atoks[:,:int(-rpad_s * 75)],
108
- "stoks.npy": _stoks[:int(-rpad_s * 25)],
109
- })
110
- with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
111
- if not n_samples:
112
- os.rename(tmp, output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/prepare_t2s_dataset.py DELETED
@@ -1,111 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5A. T2S dataset preparation.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = []
5
-
6
- # %% ../nbs/5A. T2S dataset preparation.ipynb 2
7
- import sys
8
- import os
9
- import itertools
10
- from pathlib import Path
11
-
12
- import numpy as np
13
- import torch
14
- import torchaudio
15
- import torch.nn.functional as F
16
- from torch.profiler import profile, record_function, ProfilerActivity
17
-
18
- from fastprogress import progress_bar
19
- from fastcore.script import *
20
-
21
- import whisper, whisperx
22
- from . import vad, wh_transcribe, vq_stoks, extract_acoustic
23
- import webdataset as wds
24
-
25
- # %% ../nbs/5A. T2S dataset preparation.ipynb 4
26
- def flac_to_t2s_name(input):
27
- return input.rsplit("/", 1)[1].replace('flac', 't2s') + ".gz"
28
-
29
- # %% ../nbs/5A. T2S dataset preparation.ipynb 6
30
- class Transcriber:
31
- """
32
- A helper class to transcribe a batch of 30 second audio chunks.
33
- """
34
- def __init__(self, model_size, lang=False):
35
- self.model = whisperx.asr.load_model(model_size, "cuda", compute_type="float16", language=lang)
36
- # without calling vad_model at least once the rest segfaults for some reason...
37
- self.model.vad_model({"waveform": torch.zeros(1, 16000), "sample_rate": 16000})
38
-
39
- def transcribe(self, batch):
40
- batch = whisper.log_mel_spectrogram(batch)
41
- embs = self.model.model.encode(batch.cpu().numpy())
42
- return self.model.tokenizer.tokenizer.decode_batch([x.sequences_ids[0] for x in
43
- self.model.model.model.generate(
44
- embs,
45
- [self.model.model.get_prompt(self.model.tokenizer, [], without_timestamps=True)]*len(batch),
46
- )])
47
-
48
- # %% ../nbs/5A. T2S dataset preparation.ipynb 7
49
- @call_parse
50
- def prepare_t2s(
51
- input:str, # FLAC webdataset file path (or - to read the names from stdin)
52
- proc_dataset_path:Path, # processed VAD files path
53
- output:str=None, # output file name
54
- vq_model:str="collabora/spear-tts-pytorch:whisper-vq-stoks.model", # the model path (use repo_id:filename to download it from hugginface)
55
- n_samples:int=None, # process a limited amount of samples
56
- batch_size:int=1, # process several segments at once
57
- transcription_model:str="small.en",
58
- ):
59
- if ":" in vq_model:
60
- repo, fname = vq_model.split(":", 1)
61
- vq_model = vq_stoks.RQBottleneckTransformer.load_model(repo, fname).cuda()
62
- else:
63
- vq_model = vq_stoks.RQBottleneckTransformer.load_model(local_filename=vq_model).cuda()
64
- transcriber = Transcriber(transcription_model)
65
-
66
- if input == "-":
67
- input = [f.strip() for f in sys.stdin.readlines()]
68
- assert output, "please provide the output shard name"
69
- else:
70
- if output is None: output = flac_to_t2s_name(input)
71
- input = [input]
72
-
73
- total = n_samples//batch_size if n_samples else 'noinfer'
74
- if n_samples: print(f"Benchmarking run of {n_samples} samples ({total} batches)")
75
-
76
- ds = wds.WebDataset(input, shardshuffle=True, rename_files=vad.fix_dots_in_names).compose(
77
- wds.decode(wds.torch_audio),
78
- vq_stoks.merge_in(vq_stoks.derived_dataset(proc_dataset_path, 'vad')),
79
- wds.map_dict(**{"vad.npy": lambda s: wh_transcribe.chunk_merger(s, wh_transcribe.random_cutter)}),
80
- lambda x: wh_transcribe.split_to_chunks(x),
81
- # drop the first and last segment because they tend to be inaccurate
82
- # (the transcriptions don't have the "LibriVox" header and "end of chapter" suffix)
83
- wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
84
- wds.to_tuple('__key__', 'rpad', 'samples'),
85
- wds.batched(64),
86
- )
87
-
88
- dl = wds.WebLoader(ds, num_workers=4, batch_size=None).unbatched().shuffle(2000).batched(batch_size)
89
-
90
- speakers = set()
91
- tmp = output+".tmp"
92
- with wds.TarWriter(tmp) as sink:
93
- for keys, rpads, samples in progress_bar(dl, total=total):
94
- with record_function('to_cuda'):
95
- csamples = samples.cuda()
96
- with record_function('transcribe'):
97
- txts = transcriber.transcribe(csamples)
98
- with record_function('vq_stoks'):
99
- stoks = vq_model.encode_audio(csamples)
100
- with record_function('from_cuda'):
101
- stoks = stoks.cpu().numpy().astype(np.int16)
102
- for key, rpad, txt, _stoks in zip(keys, rpads, txts, stoks):
103
- speakers.add(key.split('/')[1])
104
- sink.write({
105
- "__key__": key,
106
- "txt": txt,
107
- "stoks.npy": _stoks[:int(-rpad/16000 * 25)],
108
- })
109
- with open(output+".speakers.txt", "w") as f: f.write("\n".join(speakers))
110
- if not n_samples:
111
- os.rename(tmp, output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/s2a_delar_mup_wds.py DELETED
@@ -1,688 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Semantic to acoustic token modeling.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['load_datasets', 'CMLMVisual', 'Rotary', 'rotate_half', 'apply_rotary_pos_emb', 'ResidualAttentionBlock',
5
- 'MultiHeadAttention', 'DelSumDecoder', 'EmbeddingProjector', 'rand', 'Tunables', 'SADelARTransformer']
6
-
7
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 1
8
- import io
9
- import time
10
- import math
11
- import random
12
- import dataclasses
13
-
14
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 2
15
- import torch
16
- import torch.nn as nn
17
- import torch.nn.functional as F
18
- from torch.profiler import profile, record_function, ProfilerActivity, schedule
19
- from fastcore.basics import store_attr
20
- from huggingface_hub import hf_hub_download
21
-
22
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 3
23
- from pathlib import Path
24
- import json
25
- from fastprogress import progress_bar, master_bar
26
- import webdataset as wds
27
-
28
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 4
29
- from .train import *
30
- from .modules import *
31
- from . import vq_stoks
32
-
33
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 8
34
- def rand(start, end):
35
- return random.random() * (end - start) + start
36
-
37
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 9
38
- def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
39
- atoks_per_second = atoks_len / 30
40
- def _trunc(samples):
41
- for s in samples:
42
- if random.random() < random_trunc_p:
43
- seconds = rand(0.3, 30)
44
- s['atoks.npy'] = s['atoks.npy'][:,:math.ceil(seconds * atoks_per_second)]
45
- s['stoks.npy'] = s['stoks.npy'][:math.ceil(s['atoks.npy'].shape[-1]/atoks_len*stoks_len)]
46
- yield s
47
- return _trunc
48
-
49
- def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
50
- def _pad(samples):
51
- for s in samples:
52
- s['stoks.npy'] = F.pad(torch.tensor(s['stoks.npy']), (0, stoks_len - s['stoks.npy'].shape[-1]), value=stoks_pad_token)
53
- s['atoks.npy'] = F.pad(torch.tensor(s['atoks.npy']), (0, atoks_len - s['atoks.npy'].shape[-1]), value=-100)
54
- yield s
55
- return _pad
56
-
57
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 10
58
- def speaker_id_extractor(speaker_map):
59
- def _extractor(samples):
60
- for s in samples:
61
- s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
62
- yield s
63
- return _extractor
64
-
65
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 14
66
- def load_datasets(
67
- input:str, # webdataset folder
68
- samples:int, # samples per epoch
69
- subsample:float=1, # use a fraction of the files
70
- val_samples:int=512,
71
- random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
72
- stoks_pad_token=4096,
73
- ):
74
-
75
- if isinstance(input, (Path, str)):
76
- path = Path(input)
77
- if path.is_dir():
78
- glob = '*-s2a-*.tar.gz'
79
- else:
80
- glob = path.name
81
- path = path.parent
82
- input = Path(path).glob(glob)
83
- elif isinstance(input, list):
84
- pass
85
- else:
86
- raise ArgumentError("input should be either a list or a path with an optional glob specifier")
87
- shards = [str(x) for x in input]
88
-
89
- speakers = set()
90
- for shard in shards:
91
- with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
92
- speakers = {id:i for i,id in enumerate(sorted(speakers))}
93
-
94
- def ds(shards, length):
95
- ds = wds.WebDataset(wds.ResampledShards(shards)).compose(
96
- wds.decode(),
97
- speaker_id_extractor(speakers),
98
- random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
99
- pad_samples(stoks_pad_token=stoks_pad_token),
100
- wds.to_tuple('stoks.npy', 'atoks.npy', 'speaker'),
101
- wds.batched(64),
102
- )
103
- ds.speakers = speakers
104
- ds.total_samples = length
105
- return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)
106
-
107
- return (
108
- ds(shards[1:], samples),
109
- ds(shards[:1], val_samples),
110
- )
111
-
112
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 33
113
- import pylab as plt
114
- import fastprogress
115
- import IPython
116
- import numpy as np
117
-
118
- class CMLMVisual:
119
- """Visualize training progress"""
120
- def __init__ (self, model, masterbar, total_steps):
121
- self.model = model
122
- self.masterbar = masterbar
123
- self.total_steps = total_steps
124
- self.epochs = total_steps // masterbar.main_bar.total
125
-
126
- gs = plt.GridSpec(3, 1, height_ratios=[2,2,1])
127
- graph_fig = plt.figure(figsize=(10,6))
128
- self.graph_fig = graph_fig
129
- self.loss_p = graph_fig.add_subplot(gs[0])
130
- self.acc_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
131
- self.acc_p.tick_params('x', labelbottom=False)
132
- self.lr_p = graph_fig.add_subplot(gs[2], sharex=self.loss_p)
133
- self.lr_p.tick_params('x', labelbottom=False)
134
- self.graph_out = None
135
-
136
- self.its = []
137
- self.train_losses = []
138
- self.val_losses = []
139
- self.lr_history = []
140
- self.acc = np.nan
141
- self.acc_history = []
142
- self.pacc_history = []
143
-
144
- def show(self):
145
- self.start_t = time.time()
146
- self.masterbar.write(["samples", "train", "val", "time"], table=True)
147
- self.graph_out = display(self.graph_fig, display_id=True)
148
- self.acc_out = display(IPython.display.HTML(''), display_id=True)
149
-
150
- def hide(self):
151
- if self.graph_out is not None:
152
- self.graph_out.update(IPython.display.HTML(''))
153
-
154
- def plot(self):
155
- loss_p, acc_p, lr_p = self.loss_p, self.acc_p, self.lr_p
156
- loss_p.clear()
157
- loss_p.plot(self.its, self.train_losses)
158
- loss_p.plot(self.its, self.val_losses)
159
- loss_p.set_xlim(0, self.total_steps)
160
- loss_p.set_yscale('log')
161
- acc_p.clear()
162
- for k in self.acc_history[-1].keys():
163
- acc_p.plot(self.its, [x[k] for x in self.acc_history], ':')
164
- # acc_p.plot(self.its, np.stack(self.pacc_history), label=range(len(self.pacc_history[0])))
165
- lr_p.clear()
166
- lrs = np.array(self.lr_history)
167
- lr_p.plot(self.its, lrs)
168
- self.graph_out.update(self.graph_fig)
169
-
170
- def add_data(self, it, lr, train_loss, val_los):
171
- self.its.append(it)
172
- self.train_losses.append(train_loss)
173
- self.val_losses.append(val_los)
174
- self.lr_history.append(lr)
175
- metrics = self.model.get_metrics()
176
- self.acc_history.append(metrics)
177
- # self.acc_out.update(f"Accuracy: {self.entropy_history[-1]:.2f}")
178
- # self.pacc_history.append((self.model.pval_true / self.model.pval_total).cpu().numpy())
179
- # if self.acc_history:
180
- html = "<h5>Accuracies:</h5><table>"
181
- html += "<thead>"+(''.join([f"<td>{k}<td>" for k,x in metrics.items()]))+"</thead>"
182
- html += "<tr>"+(''.join([f"<td>{x*100:.1f}%<td>" for k,x in metrics.items()]))+"</tr>"
183
- html += "</table>"
184
- self.acc_out.update(IPython.display.HTML(html))
185
- self.plot()
186
-
187
- def add_table_row(self, it, avg_train_loss, val_loss):
188
- elapsed_t = time.time() - self.start_t
189
- self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
190
-
191
- def on_iter(self, bar, it, avg_train_loss, val_loss):
192
- epoch = math.ceil(it / self.total_steps * self.epochs)
193
- bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
194
-
195
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 34
196
- # modified from https://blog.eleuther.ai/rotary-embeddings/
197
- import torch
198
-
199
- class Rotary(torch.nn.Module):
200
- def __init__(self, dim, base=10000):
201
- super().__init__()
202
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
203
- self.register_buffer("inv_freq", inv_freq)
204
- self.seq_len_cached = None
205
- self.cos_cached = None
206
- self.sin_cached = None
207
-
208
- def forward(self, x, seq_dim=1):
209
- seq_len = x.shape[seq_dim]
210
- if seq_len != self.seq_len_cached:
211
- self.seq_len_cached = seq_len
212
- t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
213
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
214
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
215
- self.cos_cached = emb.cos()[None, :, None, :]
216
- self.sin_cached = emb.sin()[None, :, None, :]
217
- return self.cos_cached, self.sin_cached
218
-
219
-
220
- # rotary pos emb helpers:
221
- def rotate_half(x):
222
- x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
223
- return torch.cat(
224
- (-x2, x1), dim=-1
225
- )
226
-
227
- #@torch.jit.script
228
- def apply_rotary_pos_emb(q, k, cos, sin):
229
- return (q * cos[:,:q.shape[1]]) + (rotate_half(q) * sin[:,:q.shape[1]]), (k * cos) + (rotate_half(k) * sin)
230
-
231
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 35
232
- from torch import Tensor, nn
233
- import torch.nn.functional as F
234
- from typing import Dict, Iterable, Optional
235
-
236
- class ResidualAttentionBlock(nn.Module):
237
- def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
238
- qk_scale: float = 1, ffn_mult: int = 4):
239
- super().__init__()
240
-
241
- self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
242
- self.attn_ln = LayerNorm(n_state)
243
-
244
- self.cross_attn = (
245
- MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope) if cross_attention else None
246
- )
247
- self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
248
-
249
- n_mlp = n_state * ffn_mult
250
- self.mlp = nn.Sequential(
251
- nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
252
- )
253
- self.mlp_ln = LayerNorm(n_state)
254
-
255
- def forward(
256
- self,
257
- x: Tensor,
258
- xa: Optional[Tensor] = None,
259
- causal = False,
260
- kv_cache: Optional[dict] = None,
261
- ):
262
- x = x + self.attn(self.attn_ln(x), causal=causal, kv_cache=kv_cache)[0]
263
- if self.cross_attn:
264
- x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
265
- x = x + self.mlp(self.mlp_ln(x))
266
- return x
267
-
268
- class MultiHeadAttention(nn.Module):
269
- def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False):
270
- super().__init__()
271
- self.n_head = n_head
272
- self.sqrt_qk_scale = math.sqrt(qk_scale)
273
- self.query = QueryHead(n_state, n_state)
274
- self.key = nn.Linear(n_state, n_state, bias=False)
275
- self.value = nn.Linear(n_state, n_state)
276
- self.out = nn.Linear(n_state, n_state)
277
-
278
- self.rotary = None
279
- if rope:
280
- self.rotary = Rotary(n_state // n_head)
281
-
282
- def forward(
283
- self,
284
- x: Tensor,
285
- xa: Optional[Tensor] = None,
286
- causal = False,
287
- kv_cache: Optional[dict] = None,
288
- ):
289
- q = self.query(x)
290
-
291
- if kv_cache is None or xa is None or self.key not in kv_cache:
292
- # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
293
- # otherwise, perform key/value projections for self- or cross-attention as usual.
294
- k = self.key(x if xa is None else xa)
295
- v = self.value(x if xa is None else xa)
296
- else:
297
- # for cross-attention, calculate keys and values once and reuse in subsequent calls.
298
- k = kv_cache[self.key]
299
- v = kv_cache[self.value]
300
-
301
- if self.sqrt_qk_scale != 1:
302
- q *= self.sqrt_qk_scale
303
- k *= self.sqrt_qk_scale
304
-
305
- wv, qk = self.qkv_attention_pth20(q, k, v, causal)
306
- # wv, qk = self.qkv_attention_xformers(q, k, v, causal)
307
-
308
- return self.out(wv), qk
309
-
310
- def qkv_attention_pth20(
311
- self, q: Tensor, k: Tensor, v: Tensor, causal = False
312
- ):
313
- n_batch, n_ctx, n_state = q.shape
314
- q = q.view(*q.shape[:2], self.n_head, -1)
315
- k = k.view(*k.shape[:2], self.n_head, -1)
316
- v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
317
-
318
- #print('before rot:', q.shape, k.shape)
319
- if self.rotary:
320
- q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
321
- #print(' after rot:', q.shape, k.shape)
322
-
323
- k = k.permute(0, 2, 1, 3)
324
- q = q.permute(0, 2, 1, 3)
325
- # modified for better performance under PyTorch 2.0
326
- wv = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0, is_causal=causal)
327
-
328
- # previously we've returned q@k which we don't have now
329
- # since it's not actually used anywhere else, let's just keep two return values for compatibility
330
- return wv.permute(0, 2, 1, 3).flatten(start_dim=2), None
331
-
332
- def qkv_attention_xformers(
333
- self, q: Tensor, k: Tensor, v: Tensor, causal = False
334
- ):
335
- n_batch, n_ctx, n_state = q.shape
336
- q = q.view(*q.shape[:2], self.n_head, -1)
337
- k = k.view(*k.shape[:2], self.n_head, -1)
338
- v = v.view(*v.shape[:2], self.n_head, -1)
339
-
340
- if self.rotary:
341
- q, k = apply_rotary_pos_emb(q, k, *self.rotary(k))
342
-
343
- bias = xops.LowerTriangularMask() if causal else None
344
- wv = xops.memory_efficient_attention(q,k,v, attn_bias=bias)
345
-
346
- # previously we've returned q@k which we don't have now
347
- # since it's not actually used anywhere else, let's just keep two return values for compatibility
348
- return wv.flatten(start_dim=2), None
349
-
350
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 36
351
- class DelSumDecoder(nn.Module):
352
- def __init__(self, depth=6, n_head=6, head_width=64, qk_scale=1, ffn_mult=4, length=2250, codes=1024, quantizers=8, linear_heads=True, rope=False, pos_embs=None):
353
- super().__init__()
354
- self.length = length
355
- width = n_head * head_width
356
- self.width = width
357
- self.codes = codes
358
- self.quantizers = quantizers
359
- self.linear_heads = linear_heads
360
-
361
- self.embeddings = nn.ModuleList([nn.Embedding(codes+1, width) for _ in range(quantizers)])
362
- if pos_embs is not None:
363
- self.register_buffer("positional_embedding", pos_embs)
364
-
365
- self.layers = nn.ModuleList([
366
- ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope) for _ in range(math.floor(depth))
367
- ])
368
-
369
- self.ln_post = LayerNorm(width)
370
-
371
- if self.linear_heads:
372
- self.heads = LinearHead(width, (codes+1) * quantizers, bias=False)
373
- else:
374
- self.splitter = nn.Sequential(
375
- nn.Linear(width, width * quantizers),
376
- nn.GELU(),
377
- )
378
- self.heads = nn.ModuleList([
379
- LinearHead(width, codes+1, bias=True) for _ in range(quantizers)
380
- ])
381
-
382
- def forward(self, toks, xenc):
383
- b,_,n = toks.shape
384
- newn = min(n+1, self.length)
385
- embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
386
- for i in range(self.quantizers):
387
- embs[:,:i+1] += self.embeddings[i](torch.tensor([self.codes], device=xenc.device))
388
- if i < n:
389
- embs[:,i+1:] += self.embeddings[i](toks[:,i,:newn-i-1])
390
-
391
- x = embs.to(xenc.dtype)
392
-
393
- for l in self.layers:
394
- x = l(x, xenc, causal=True)
395
- x = self.ln_post(x)
396
-
397
- if self.linear_heads:
398
- logits = self.heads(x).view(b,newn,self.quantizers,self.codes+1).permute(0,2,1,3)
399
- else:
400
- split = self.splitter(x).view(b,newn,self.quantizers,self.width)
401
- logits = torch.stack([self.heads[q](split[:,:,q]) for q in range(self.quantizers)], dim=1)
402
-
403
- return logits
404
-
405
- class EmbeddingProjector(nn.Linear):
406
- pass
407
-
408
- def rand(start, end):
409
- return random.random() * (end - start) + start
410
-
411
- @dataclasses.dataclass
412
- class Tunables:
413
- init_std :float = 9
414
- embeddings_std :float = 0.2
415
- embeddings_lr_scale: float = 10
416
- output_mult :float = 5.6
417
- # FIXME: try separate mults for self and cross attention
418
- query_mult :float = .3
419
- encoder_depth_ratio :float = 0.25
420
- linear_heads :bool = False
421
- rope :bool = True
422
-
423
- lr0 :float = 3e-3
424
- clip_gradient_norm :float = 2
425
- weight_decay :float = 1e-3
426
- warmup_steps :float = 2000
427
-
428
- random :bool = False
429
-
430
- def __post_init__(self):
431
- # randomize the hyperparams if requested
432
- if self.random:
433
- self.init_std = 2*10**rand(0,1)
434
- self.embeddings_std = 10**rand(-1.7,-0.22)
435
- self.embeddings_lr_scale = 2**rand(2,4)
436
- self.output_mult = 2**rand(1.5,3)
437
- self.query_mult = 2**rand(-3,-1.3)
438
- self.encoder_depth_ratio = random.choice([0.25,0.5])
439
- self.linear_heads = False
440
- self.rope = True
441
-
442
- self.lr0 = 3e-3
443
- self.clip_gradient_norm = 10**rand(-1,1)
444
- self.warmup_steps = 100*(10**rand(1.18,1.3))
445
-
446
- @staticmethod
447
- def upgrade(args):
448
- args = {k:v for k,v in args.items()}
449
- def old_default(name, value):
450
- if name not in args: args[name] = value
451
- old_default('rope', False)
452
- old_default('linear_heads', True)
453
- return args
454
-
455
- class SADelARTransformer(nn.Module):
456
- def __init__(self, depth=3, ctx_n=2250, stoks_len=750, stoks_codes=4097, stoks_width=None, spk_width=None, n_head=3, head_width=64, ffn_mult=4,
457
- quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
458
- super().__init__()
459
- self.quantizers = quantizers
460
- width = n_head * head_width
461
- store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
462
- self.width = width
463
- self.base_width = 3 * head_width
464
- self.tunables = tunables
465
-
466
- if stoks_width is None: stoks_width = width
467
- if spk_width is None: spk_width = width
468
- self.emb_factor = width != stoks_width
469
- self.spk_factor = width != spk_width
470
-
471
- if tunables.rope:
472
- self.positional_embeddings = None
473
- else:
474
- self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))
475
-
476
- self.speaker_embedding = nn.Embedding(len(speaker_map), width)
477
- self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
478
- if self.emb_factor:
479
- self.emb_to_hidden = nn.Linear(stoks_width, width)
480
-
481
- if self.spk_factor:
482
- self.spk_to_hidden = EmbeddingProjector(spk_width, width)
483
-
484
- qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
485
-
486
- encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
487
- decoder_depth = depth * 2 - encoder_depth
488
- self.encoder = nn.Sequential(*[
489
- ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
490
- ])
491
- self.ln_post = LayerNorm(width)
492
-
493
- self.decoder = DelSumDecoder(pos_embs=self.positional_embeddings, qk_scale=qk_scale,
494
- length=ctx_n, n_head=n_head, head_width=head_width, ffn_mult=ffn_mult,
495
- depth=decoder_depth, quantizers=quantizers,
496
- linear_heads=tunables.linear_heads, rope=tunables.rope)
497
-
498
- self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
499
- self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
500
- self.apply(self.init_transformer)
501
-
502
- def setup(self, device):
503
- pass
504
-
505
- def load_frozen_semantic_embeddings(self, vqmodel):
506
- with torch.no_grad():
507
- self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
508
- self.semantic_embedding.lr_scale = 0
509
-
510
- def init_transformer(self, m):
511
- if isinstance(m, LinearHead):
512
- m.no_weight_decay = True
513
- torch.nn.init.constant_(m.weight, 0)
514
- elif isinstance(m, QueryHead):
515
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
516
- torch.nn.init.constant_(m.weight, 0)
517
- elif isinstance(m, nn.Embedding):
518
- m.no_weight_decay = True
519
- m.lr_scale = self.tunables.embeddings_lr_scale
520
- std = self.tunables.embeddings_std
521
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
522
- elif isinstance(m, EmbeddingProjector):
523
- m.lr_scale = self.tunables.embeddings_lr_scale/2
524
- std = self.tunables.init_std
525
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
526
- elif isinstance(m, nn.Linear):
527
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
528
- std = self.tunables.init_std / m.weight.shape[1]
529
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
530
- if m.bias is not None:
531
- torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
532
- elif isinstance(m, nn.LayerNorm):
533
- m.no_weight_decay = True
534
- torch.nn.init.constant_(m.bias, 0)
535
- torch.nn.init.constant_(m.weight, 1)
536
-
537
- def embed_stoks(self, Stoks):
538
- b,n = Stoks.shape
539
- if self.stoks_len == 1500:
540
- # converts 50 toks/s to 75 toks/s by adding padding between every two tokens
541
- x = Stoks.reshape(b,n//2,2)
542
- x = x.repeat_interleave(2, -1)[:,:,:3]
543
- x[:,:,1] = 1024
544
- x = x.reshape(b,n//2*3)
545
- else:
546
- # it's a lot easier with 25 toks/s
547
- x = Stoks.repeat_interleave(3, -1)
548
- # embed semantic tokens
549
- Sembs = self.semantic_embedding(x.to(torch.long))
550
- if self.emb_factor:
551
- Sembs = self.emb_to_hidden(Sembs)
552
- return Sembs
553
-
554
- def forward(self, Stoks, Atoks, speakers, noloss=False):
555
- Atoks = Atoks.to(torch.long)
556
- semb = self.embed_stoks(Stoks)
557
- with record_function("encoder"):
558
- if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
559
- xenc = self.ln_post(self.encoder(semb))
560
- # xenc = torch.zeros_like(xenc)
561
- with record_function("decoder"):
562
- Atoks_gt = Atoks.clone()
563
- Atoks_gt[Atoks == -100] = 1024
564
- # we can randomize speaker ids during validation to measure
565
- # the importance of the speaker embedding vs. just the acoustic prompt/prefix
566
- # if not self.training: speakers = speakers[torch.randperm(speakers.nelement())]
567
- spk_embs = self.speaker_embedding(speakers)
568
- if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
569
- logits = self.decoder(Atoks_gt, xenc + spk_embs.unsqueeze(1))
570
- logits *= self.tunables.output_mult / (self.width / self.base_width)
571
-
572
- if noloss:
573
- return logits
574
-
575
- with record_function("loss"):
576
- N = Atoks.shape[-1]
577
- loss = 0
578
- for i in range(self.quantizers):
579
- loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
580
- loss /= self.quantizers
581
-
582
- if not self.training:
583
- for i in range(self.quantizers):
584
- Atoks_i = Atoks[:,i,:N-i]
585
- valid_Atoks = Atoks_i != -100
586
- self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
587
- self.val_total[i] += valid_Atoks.float().sum()
588
-
589
- return logits, loss
590
-
591
- def get_metrics(self):
592
- metrics = {
593
- f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
594
- }
595
- self.val_true[:] = 0
596
- self.val_total[:] = 0
597
- return metrics
598
-
599
- #
600
- # inference
601
- #
602
- @classmethod
603
- def load_model(cls, repo_id="collabora/whisperspeech", filename="s2a_up_wds.model", local_filename=None):
604
- if not local_filename:
605
- local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
606
- spec = torch.load(local_filename)
607
- if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
608
- model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
609
- model.load_state_dict(spec['state_dict'])
610
- model.eval()
611
- return model
612
-
613
- def get_extra_state(self):
614
- return { 'speaker_map': self.speaker_map }
615
-
616
- def set_extra_state(self, st):
617
- self.speaker_map = st['speaker_map']
618
-
619
- def load_checkpoint(self, local_filename):
620
- spec = torch.load(local_filename, map_location='cpu')
621
- assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
622
- state_dict = {k.replace('model.', ''):v
623
- for k,v in spec['state_dict'].items()}
624
- self.load_state_dict(state_dict)
625
- return self
626
-
627
- def save_model(self, fname):
628
- torch.save(dict(config = self.__stored_args__,
629
- tunables = dataclasses.asdict(self.tunables),
630
- state_dict = self.state_dict()), fname)
631
-
632
- @property
633
- def device(self):
634
- return next(self.parameters()).device
635
-
636
- @torch.no_grad()
637
- def generate(self, stoks, speakers, N=None, T=0.7, top_k=None, show_progress_bar=True):
638
- dev = self.device
639
- if self.stoks_len == 1500:
640
- N = N or len(stoks) * 3 // 2
641
- else:
642
- N = N or len(stoks) * 3
643
- stoks = F.pad(stoks.to(dev), (0, self.stoks_len - len(stoks)), value=self.stoks_codes-1).unsqueeze(0)
644
- speakers = torch.tensor([self.speaker_map[spk] for spk in speakers], device=dev)
645
- toks = torch.zeros((1,self.quantizers,N), dtype=torch.long, device=dev)
646
- it = range(0,N)
647
- if show_progress_bar: it = progress_bar(it)
648
- for i in it:
649
- p = self(stoks, toks[:,:,:i], speakers, noloss=True)
650
- last_p = p[0,:,-1]
651
- if top_k:
652
- last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
653
- for j,tok in enumerate(torch.multinomial((last_p / float(T)).softmax(-1), 1)):
654
- toks[0,j,max(0,i-j)] = tok
655
- if toks[0,0,i] == 1024: return toks[0,:,:i]
656
- return toks[0]
657
-
658
- # %% ../nbs/4B. Semantic to acoustic token modeling.ipynb 37
659
- def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None, **kwargs):
660
- assert(dataset is not None)
661
- kwargs = dict(speaker_map=dataset.speakers, quantizers=quantizers, tunables=tunables, **kwargs)
662
- if size == 'micro':
663
- return SADelARTransformer(depth=4, n_head=3, ffn_mult=2, **kwargs)
664
- if size == 'tiny-narrow':
665
- return SADelARTransformer(depth=4, n_head=6, ffn_mult=1, **kwargs)
666
- if size == 'tiny':
667
- return SADelARTransformer(depth=4, n_head=6, **kwargs)
668
- if size == 'base':
669
- return SADelARTransformer(depth=6, n_head=8, **kwargs)
670
- if size == 'base-deep':
671
- return SADelARTransformer(depth=9, n_head=8, **kwargs)
672
- if size == 'base-wide':
673
- return SADelARTransformer(depth=6, n_head=12, **kwargs)
674
- if size == 'small/2':
675
- return SADelARTransformer(depth=9, n_head=12, **kwargs)
676
- if size == 'small':
677
- return SADelARTransformer(depth=12, n_head=12, **kwargs)
678
- if size == 'medium':
679
- return SADelARTransformer(depth=24, n_head=16, **kwargs)
680
-
681
- def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
682
- if frozen_embeddings_model:
683
- vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
684
- model = _make_model(size, quantizers, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
685
- model.load_frozen_semantic_embeddings(vqmodel)
686
- else:
687
- model = _make_model(size, quantizers, tunables, dataset)
688
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/s2a_delar_mup_wds_mlang.py DELETED
@@ -1,564 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['load_dataset', 'DelSumEmbedding', 'DelSumHead', 'rand', 'Tunables', 'SADelARTransformer']
5
-
6
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 1
7
- import io
8
- import time
9
- import math
10
- import random
11
- import dataclasses
12
-
13
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 2
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- import numpy as np
18
- from torch.profiler import profile, record_function, ProfilerActivity, schedule
19
- from fastcore.basics import store_attr
20
- from huggingface_hub import hf_hub_download
21
-
22
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 3
23
- from pathlib import Path
24
- import json
25
- from fastprogress import progress_bar, master_bar
26
-
27
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 4
28
- from .modules import *
29
-
30
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 8
31
- def rand(start, end):
32
- return random.random() * (end - start) + start
33
-
34
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 9
35
- def random_trunc(random_trunc_p, atoks_len = 2250, stoks_len = 750):
36
- atoks_per_second = atoks_len / 30
37
- def _trunc(samples):
38
- for s in samples:
39
- if random.random() < random_trunc_p:
40
- seconds = rand(0.3, 30)
41
- s['atoks.npy'] = s['atoks.npy'][:,:math.ceil(seconds * atoks_per_second)]
42
- s['stoks.npy'] = s['stoks.npy'][:math.ceil(s['atoks.npy'].shape[-1]/atoks_len*stoks_len)]
43
- yield s
44
- return _trunc
45
-
46
- def pad_samples(atoks_len = 2250, stoks_len = 750, stoks_pad_token = 4096):
47
- def _pad(samples):
48
- for s in samples:
49
- s['stoks.npy'] = F.pad(torch.tensor(s['stoks.npy']), (1, stoks_len - s['stoks.npy'].shape[-1]-1), value=stoks_pad_token)
50
- s['out_stoks'] = F.pad(torch.tensor(s['stoks.npy']), (0, stoks_len - s['stoks.npy'].shape[-1]), value=stoks_pad_token)
51
- s['atoks.npy'] = F.pad(torch.tensor(s['atoks.npy']), (0, atoks_len - s['atoks.npy'].shape[-1]), value=-100)
52
- yield s
53
- return _pad
54
-
55
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 10
56
- def make_speaker_map(shards):
57
- speakers = set()
58
- for shard in shards:
59
- with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
60
- return {id:i for i,id in enumerate(sorted(speakers))}
61
-
62
- def speaker_id_extractor(speaker_map):
63
- def _extractor(samples):
64
- for s in samples:
65
- s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
66
- yield s
67
- return _extractor
68
-
69
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 27
70
- def load_dataset(
71
- atoks_shard_spec:str, # webdataset folder
72
- stoks_shard_dir:str, # stoks webdataset base dir
73
- samples:int, # samples per epoch
74
- random_trunc_p:float=0,# probability of truncating the input to less than 30 seconds
75
- vq_codes:int=4096,
76
- language:str='en',
77
- weight:float=1,
78
- validation:bool=False,
79
- exclude_files:str=None,
80
- randomize_speakers:bool=False,
81
- ):
82
- import webdataset as wds
83
- from whisperspeech import utils
84
-
85
- shards = utils.shard_glob(atoks_shard_spec)
86
- excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
87
-
88
- def check_for_nan(s):
89
- if torch.tensor(s['spk_emb.npy']).isnan().any(): print("found NaN:", s['__key__'])
90
- return s
91
-
92
- def set_language(x):
93
- x['language'] = language
94
- return x
95
-
96
- same_on_all_nodes = lambda urls: urls # will only be used for validation
97
- ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
98
- wds.decode(),
99
- utils.merge_in(utils.derived_dataset('maxvad-stoks', base='atoks-3kbps', suffix='', dir=stoks_shard_dir)),
100
- wds.map(check_for_nan),
101
- wds.select(lambda s: s['__key__'] not in excludes),
102
- wds.map_dict(**{'spk_emb.npy':np.nan_to_num}), # remove nans from the speaker embedding model
103
- random_trunc(random_trunc_p) if random_trunc_p > 0 else lambda x: x,
104
- pad_samples(stoks_pad_token=vq_codes-1),
105
- wds.map(set_language),
106
- wds.to_tuple('stoks.npy', 'atoks.npy', 'spk_emb.npy', 'language', 'out_stoks'),
107
- wds.shuffle(20000, initial=20000),
108
- wds.batched(64),
109
- )
110
- if randomize_speakers:
111
- rng = np.random.default_rng()
112
- ds = ds.compose(
113
- wds.map_tuple(None, None, lambda x: rng.permutation(x), None),
114
- )
115
- if validation:
116
- ds = ds.slice(samples // 64)
117
- ds.total_samples = samples
118
- ds.weight = weight
119
-
120
- return ds
121
-
122
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 37
123
- class DelSumEmbedding(nn.Module):
124
- def __init__(self, n_head=6, head_width=64, atoks_width=None, length=2250, codes=1024, quantizers=8, pos_embs=None):
125
- super().__init__()
126
- self.length = length
127
- width = n_head * head_width
128
- if atoks_width is None: atoks_width = width
129
- self.width = width
130
- self.quantizers = quantizers
131
-
132
- emb = None
133
- embs = []
134
- for _ in range(quantizers):
135
- emb = FlexEmbeddings(codes, width, special_codes=2, frozen_width=atoks_width,
136
- special_embedding=emb and emb.special)
137
- embs.append(emb)
138
- self.embeddings = nn.ModuleList(embs)
139
- if pos_embs is not None:
140
- self.register_buffer("positional_embedding", pos_embs)
141
-
142
- def forward(self, toks, xenc):
143
- with record_function("embeddings"):
144
- b,_,n = toks.shape
145
- newn = min(n, self.length)
146
-
147
- embs = torch.zeros((b,newn,self.width), dtype=xenc.dtype, device=xenc.device)
148
- for i in range(self.quantizers):
149
- embs[:, :] += self.embeddings[i](toks[:,i,:])
150
-
151
- x = embs.to(xenc.dtype)
152
- return x
153
-
154
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 38
155
- class DelSumHead(nn.Module):
156
- def __init__(self, quantizers=8, n_head=6, head_width=64):
157
- super().__init__()
158
- self.width = n_head * head_width
159
- self.quantizers = quantizers
160
- self.splitter = nn.Sequential(
161
- nn.Linear(self.width, self.width * quantizers),
162
- nn.GELU(),
163
- )
164
-
165
- def forward(self, x, embeddings=None):
166
- b, newn, _ = x.shape
167
- with record_function("splitter"):
168
- split = self.splitter(x).view(b,newn,self.quantizers,self.width)
169
- with record_function("unembed"):
170
- logits = torch.stack([embeddings[q].unembed(split[:,:,q]) for q in range(self.quantizers)], dim=1)
171
- return logits
172
-
173
- def rand(start, end):
174
- return random.random() * (end - start) + start
175
-
176
- @dataclasses.dataclass
177
- class Tunables:
178
- init_std :float = 9
179
- embeddings_std :float = 0.2
180
- embeddings_lr_scale: float = 10
181
- output_mult :float = 5.6
182
- # FIXME: try separate mults for self and cross attention
183
- query_mult :float = .3
184
- encoder_depth_ratio :float = 0.25
185
- linear_heads :bool = False
186
- rope :bool = True
187
-
188
- lr0 :float = 3e-3
189
- clip_gradient_norm :float = 2
190
- weight_decay :float = 1e-3
191
- warmup_steps :float = 2000
192
-
193
- random :bool = False
194
-
195
- def __post_init__(self):
196
- # randomize the hyperparams if requested
197
- if self.random:
198
- self.init_std = 2*10**rand(0,1)
199
- self.embeddings_std = 10**rand(-1.7,-0.22)
200
- self.embeddings_lr_scale = 2**rand(2,4)
201
- self.output_mult = 2**rand(1.5,3)
202
- self.query_mult = 2**rand(-3,-1.3)
203
- self.encoder_depth_ratio = random.choice([0.25,0.5])
204
- self.linear_heads = False
205
- self.rope = True
206
-
207
- self.lr0 = 3e-3
208
- self.clip_gradient_norm = 10**rand(-1,1)
209
- self.warmup_steps = 100*(10**rand(1.18,1.3))
210
-
211
- @staticmethod
212
- def upgrade(args):
213
- args = {k:v for k,v in args.items()}
214
- def old_default(name, value):
215
- if name not in args: args[name] = value
216
- old_default('rope', False)
217
- old_default('linear_heads', True)
218
- return args
219
-
220
- class SADelARTransformer(nn.Module):
221
- def __init__(self, depth=3, ctx_n=2250,
222
- stoks_len=750, stoks_codes=4097, stoks_width=None,
223
- spk_width=None,
224
- atoks_width=None,
225
- n_head=3, head_width=64, ffn_mult=4,
226
- quantizers=8, speaker_map={"1":0}, tunables=Tunables()):
227
- super().__init__()
228
- self.quantizers = quantizers
229
- self.codes = 1024
230
- width = n_head * head_width
231
- store_attr("depth,ctx_n,stoks_len,stoks_codes,stoks_width,spk_width,atoks_width,n_head,head_width,ffn_mult,quantizers,speaker_map")
232
- self.width = width
233
- self.base_width = 3 * head_width
234
- self.tunables = tunables
235
-
236
- if stoks_width is None: stoks_width = width
237
- if spk_width is None: spk_width = width
238
- self.emb_factor = width != stoks_width
239
- self.spk_factor = width != spk_width
240
-
241
- if tunables.rope:
242
- self.positional_embeddings = None
243
- else:
244
- self.register_buffer('positional_embeddings', sinusoids(ctx_n, width))
245
-
246
- # self.speaker_embedding = nn.Embedding(len(speaker_map), spk_width)
247
- self.semantic_embedding = nn.Embedding(stoks_codes, stoks_width)
248
- if self.emb_factor:
249
- self.emb_to_hidden = nn.Linear(stoks_width, width)
250
- self.hidden_to_emb = nn.Linear(width, stoks_width)
251
-
252
- if self.spk_factor:
253
- self.spk_to_hidden = nn.Linear(spk_width, width)
254
-
255
- qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
256
-
257
- encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
258
- decoder_depth = depth * 2 - encoder_depth
259
- self.encoder = nn.Sequential(*[
260
- ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(encoder_depth)
261
- ]) # FIXME: enclm requires causal attention here
262
- self.ln_post = LayerNorm(width)
263
-
264
- self.embds = DelSumEmbedding(
265
- pos_embs=self.positional_embeddings, length=ctx_n,
266
- n_head=n_head, head_width=head_width, atoks_width=atoks_width,
267
- quantizers=quantizers,
268
- )
269
- self.decoder = BaseDecoder(qk_scale=qk_scale, length=ctx_n,
270
- n_head=n_head, width=n_head * head_width,
271
- ffn_mult=ffn_mult, depth=decoder_depth,
272
- rope=tunables.rope)
273
- self.head = DelSumHead(n_head=n_head, head_width=head_width, quantizers=quantizers)
274
- for l in self.decoder.layers:
275
- l.cross_attn.key_subsampling = 3
276
- # for l in self.encoder:
277
- # l.attn.key_subsampling = 3
278
- # l.attn.query_subsampling = 3
279
-
280
- self.register_buffer('val_true', torch.zeros(self.quantizers).cuda())
281
- self.register_buffer('val_total', torch.zeros(self.quantizers).cuda())
282
- self.apply(self.init_transformer)
283
-
284
- def setup(self, device):
285
- pass
286
-
287
- def load_frozen_semantic_embeddings(self, vqmodel):
288
- with torch.no_grad():
289
- self.semantic_embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
290
- self.semantic_embedding.lr_scale = 0
291
-
292
- def load_frozen_acoustic_embeddings(self, amodel):
293
- for i in range(self.quantizers):
294
- self.decoder.embeddings[i].set_frozen_embeddings(amodel.quantizer.vq.layers[i].codebook)
295
-
296
- def init_transformer(self, m):
297
- if isinstance(m, LinearHead):
298
- m.no_weight_decay = True
299
- torch.nn.init.constant_(m.weight, 0)
300
- elif isinstance(m, QueryHead):
301
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
302
- torch.nn.init.constant_(m.weight, 0)
303
- elif isinstance(m, nn.Embedding):
304
- m.no_weight_decay = True
305
- m.lr_scale = self.tunables.embeddings_lr_scale
306
- std = self.tunables.embeddings_std
307
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
308
- # elif isinstance(m, EmbeddingProjector):
309
- # m.lr_scale = self.tunables.embeddings_lr_scale #1/(m.weight.shape[1] / self.base_width)
310
- # m.lr_scale = 2/(m.weight.shape[1] / self.base_width)
311
- # std = self.tunables.init_std / m.weight.shape[1]
312
- # torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
313
- elif isinstance(m, nn.Linear):
314
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
315
- std = self.tunables.init_std / m.weight.shape[1]
316
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
317
- if m.bias is not None:
318
- torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
319
- elif isinstance(m, nn.LayerNorm):
320
- m.no_weight_decay = True
321
- torch.nn.init.constant_(m.bias, 0)
322
- torch.nn.init.constant_(m.weight, 1)
323
-
324
- def embed_stoks(self, Stoks):
325
- b,n = Stoks.shape
326
- if self.stoks_len == 1500:
327
- # converts 50 toks/s to 75 toks/s by adding padding between every two tokens
328
- x = Stoks.reshape(b,n//2,2)
329
- x = x.repeat_interleave(2, -1)[:,:,:3]
330
- x[:,:,1] = 1024
331
- x = x.reshape(b,n//2*3)
332
- else:
333
- # it's a lot easier with 25 toks/s
334
- # x = Stoks.repeat_interleave(3, -1)
335
- x = Stoks
336
- # embed semantic tokens
337
- Sembs = self.semantic_embedding(x.to(torch.long))
338
- if self.emb_factor:
339
- Sembs = self.emb_to_hidden(Sembs)
340
- return Sembs
341
-
342
- def _encoder(self, semb, positions):
343
- x = semb
344
- for l in self.encoder: x = l(x, positions)
345
- return self.ln_post(x)
346
-
347
- def run_encoder(self, Stoks, speakers):
348
- semb = self.embed_stoks(Stoks)
349
- with record_function("encoder"):
350
- if self.positional_embeddings is not None: semb = semb + self.positional_embeddings
351
- positions = torch.arange(0, semb.shape[1], device=semb.device)
352
- xenc = self._encoder(semb, positions)
353
- if self.training:
354
- enc_logits = (self.hidden_to_emb(xenc) @ self.semantic_embedding.weight.to(xenc.dtype).T).float()
355
- enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
356
- else:
357
- enc_logits = None
358
- # print(xenc.shape, speakers.shape)
359
- spk_embs = F.normalize(speakers, dim=-1) # use extracted embeddings
360
- if self.spk_factor: spk_embs = self.spk_to_hidden(spk_embs)
361
- return xenc + spk_embs.unsqueeze(1), positions, enc_logits
362
-
363
- def forward(self, Stoks, Atoks, speakers, langs=None, out_stoks=None, noloss=False, xenc=None, xenc_positions=None, atoks_positions=None):
364
- if xenc is None:
365
- Atoks = Atoks.to(torch.long)
366
- out_stoks = out_stoks.to(torch.long)
367
- Atoks_gt = Atoks.clone()
368
- Atoks_gt[Atoks == -100] = 1024
369
- xenc, enc_logits = self.run_encoder(Stoks, speakers)
370
- else:
371
- Atoks_gt = Atoks
372
- with record_function("decoder"):
373
- embs = self.embds(Atoks, xenc)
374
- if atoks_positions is None: atoks_positions = torch.arange(0, embs.shape[1], device=embs.device)
375
- x = self.decoder(embs, atoks_positions, xenc, xenc_positions)
376
- logits = self.head(x, embeddings=self.embds.embeddings)
377
- logits *= self.tunables.output_mult / (self.width / self.base_width)
378
-
379
- if noloss:
380
- return logits
381
-
382
- with record_function("loss"):
383
- N = Atoks.shape[-1]
384
- loss = 0
385
- for i in range(self.quantizers):
386
- loss += F.cross_entropy(logits[:,i,i:].reshape(-1,logits.shape[-1]), Atoks[:,i,:N-i].reshape(-1))
387
- if self.training and i == 0:
388
- loss *= 5
389
- loss /= self.quantizers
390
- if self.training:
391
- loss += 0.1 * F.cross_entropy(enc_logits.transpose(-1,-2), out_stoks)
392
-
393
- if not self.training:
394
- for i in range(self.quantizers):
395
- Atoks_i = Atoks[:,i,:N-i]
396
- valid_Atoks = Atoks_i != -100
397
- self.val_true[i] += (logits[:,i,i:].argmax(-1)[valid_Atoks] == Atoks_i[valid_Atoks]).float().sum()
398
- self.val_total[i] += valid_Atoks.float().sum()
399
-
400
- return logits, loss
401
-
402
- def get_metrics(self):
403
- metrics = {
404
- f'acc_{i}':x.item() for i,x in enumerate(self.val_true / self.val_total)
405
- }
406
- self.val_true[:] = 0
407
- self.val_total[:] = 0
408
- return metrics
409
-
410
- #
411
- # inference
412
- #
413
- @classmethod
414
- def load_model(cls, ref="collabora/whisperspeech:s2a-q4-small-en+pl.model",
415
- repo_id=None, filename=None, local_filename=None):
416
- if repo_id is None and filename is None and local_filename is None:
417
- if ":" in ref:
418
- repo_id, filename = ref.split(":", 1)
419
- else:
420
- local_filename = ref
421
- if not local_filename:
422
- local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
423
- spec = torch.load(local_filename)
424
- if '_extra_state' not in spec['state_dict']: spec['state_dict']['_extra_state'] = { 'speaker_map': spec['config']['speaker_map'] }
425
- model = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec['tunables'])))
426
- model.load_state_dict(spec['state_dict'])
427
- model.eval()
428
- return model
429
-
430
- def get_extra_state(self):
431
- return { 'speaker_map': self.speaker_map }
432
-
433
- def set_extra_state(self, st):
434
- self.speaker_map = st['speaker_map']
435
-
436
- def load_checkpoint(self, local_filename):
437
- spec = torch.load(local_filename, map_location='cpu')
438
- assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
439
- state_dict = {k.replace('model.', ''):v
440
- for k,v in spec['state_dict'].items()}
441
- self.load_state_dict(state_dict)
442
- return self
443
-
444
- def save_model(self, fname):
445
- torch.save(dict(config = self.__stored_args__,
446
- tunables = dataclasses.asdict(self.tunables),
447
- state_dict = self.state_dict()), fname)
448
-
449
- def switch_dtypes(self, dtype=torch.float16):
450
- self.dtype = dtype
451
- for n,m in self.named_modules():
452
- # convert every leaf layer apart from the LayerNorms
453
- if isinstance(m, (nn.Linear, nn.Embedding)):
454
- m.to(dtype)
455
- # take care of buffers ([kv]_cache, masks) that are not in the leaf layers
456
- for bn,b in m.named_buffers(recurse=False):
457
- setattr(m,bn,b.to(dtype))
458
-
459
- def optimize(self, max_batch_size=1, dtype=torch.float16, torch_compile=True):
460
- for emb in self.embds.embeddings:
461
- emb.convert_for_eval()
462
- for l in self.encoder:
463
- l.attn.convert_for_eval()
464
- for l in self.decoder.layers:
465
- l.attn.convert_for_eval()
466
- l.cross_attn.convert_for_eval()
467
- l.setup_kv_cache(max_batch_size, self.ctx_n, self.stoks_len)
468
- self.switch_dtypes(dtype)
469
- if torch_compile:
470
- self.generate_next = torch.compile(self.generate_next, mode="reduce-overhead", fullgraph=True)
471
-
472
- @property
473
- def device(self):
474
- return next(self.parameters()).device
475
-
476
- # from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
477
- def multinomial_sample_one_no_sync(self, probs_sort): # Does multinomial sampling without a cuda synchronization
478
- q = torch.empty_like(probs_sort).exponential_(1)
479
- return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
480
-
481
- def logits_to_probs(self, logits, T=1.0, top_k=None):
482
- logits = logits / max(T, 1e-5)
483
-
484
- if top_k is not None:
485
- v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
486
- pivot = v.select(-1, -1).unsqueeze(-1)
487
- logits = torch.where(logits < pivot, -float("Inf"), logits)
488
- probs = torch.nn.functional.softmax(logits, dim=-1)
489
- return probs
490
-
491
- def sample(self, logits, T=1.0, top_k=None):
492
- probs = self.logits_to_probs(logits[0,:,-1], T, top_k)
493
- idx_next = self.multinomial_sample_one_no_sync(probs)
494
- return idx_next
495
-
496
- def generate_one(self, toks, positions, langs, xenc, xenc_positions, T, top_k):
497
- probs = self(None, toks, None, langs, noloss=True, xenc=xenc, xenc_positions=xenc_positions, atoks_positions=positions)
498
- return self.sample(probs, T, top_k)
499
-
500
- def generate_next(self, *args, **kwargs):
501
- return self.generate_one(*args, **kwargs)
502
-
503
- @torch.no_grad()
504
- def generate(self, stoks, speakers, langs=None, N=None, T=0.7, top_k=None, show_progress_bar=True, step=None, subsample_enc=False):
505
- dev = self.device
506
- N = N or len(stoks) * 3
507
- stoks = F.pad(stoks.to(dev), (1, self.stoks_len - len(stoks)-1), value=self.stoks_codes-1).unsqueeze(0)
508
- speakers = speakers.to(device=dev, dtype=self.dtype)
509
- toks = torch.full((1,self.quantizers,2250), self.codes+1, dtype=torch.long, device=dev)
510
- it = range(1,min(N,2250-1))
511
- if show_progress_bar: it = progress_bar(it)
512
- with record_function("encode"):
513
- xenc, xenc_positions, _ = self.run_encoder(stoks, speakers)
514
- toks_positions = torch.arange(N, device=dev)
515
- with record_function("prefill"):
516
- toks[0,0,1] = self.generate_one(toks[:,:,:1], toks_positions[:1], langs, xenc, xenc_positions, T, top_k)[0,0]
517
- with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
518
- for i in it:
519
- with record_function("generate_one"):
520
- toks[0,:i+1,i+1] = self.generate_next(toks[:,:,i:i+1], toks_positions[i:i+1], langs, xenc, xenc_positions, T, top_k)[:i+1,0]
521
-
522
- # for profiling, debugging or early exit
523
- if step is not None: step()
524
- # shift tokens
525
- toks = toks[:,:,1:N]
526
- for j in range(self.quantizers):
527
- toks[0, j] = torch.roll(toks[0, j], -j)
528
- return toks[0]
529
-
530
- # %% ../nbs/4B. Multi-language semantic to acoustic token modeling.ipynb 39
531
- def _make_model(size:str, quantizers:int=4, tunables:Tunables=Tunables(), **kwargs):
532
- kwargs = dict(quantizers=quantizers, tunables=tunables, **kwargs)
533
- if size == 'micro':
534
- return SADelARTransformer(depth=4, n_head=3, ffn_mult=2, **kwargs)
535
- if size == 'tiny-narrow':
536
- return SADelARTransformer(depth=4, n_head=6, ffn_mult=1, **kwargs)
537
- if size == 'tiny':
538
- return SADelARTransformer(depth=4, n_head=6, **kwargs)
539
- if size == 'base':
540
- return SADelARTransformer(depth=6, n_head=8, **kwargs)
541
- if size == 'base-deep':
542
- return SADelARTransformer(depth=9, n_head=8, **kwargs)
543
- if size == 'base-wide':
544
- return SADelARTransformer(depth=6, n_head=12, **kwargs)
545
- if size == 'small/2':
546
- return SADelARTransformer(depth=9, n_head=12, **kwargs)
547
- if size == 'small':
548
- return SADelARTransformer(depth=12, n_head=12, **kwargs)
549
- if size == 'medium':
550
- return SADelARTransformer(depth=24, n_head=16, **kwargs)
551
-
552
- def make_model(size:str, quantizers:int=4, frozen_embeddings_model:str=None, frozen_acoustic_embeddings:bool=False, spk_width:int=None, tunables:Tunables=Tunables(), dataset=None):
553
- from encodec.model import EncodecModel
554
- from whisperspeech import vq_stoks
555
-
556
- amodel = EncodecModel.encodec_model_24khz() if frozen_acoustic_embeddings else None
557
- vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model) if frozen_embeddings_model else None
558
- model = _make_model(size, quantizers, tunables,
559
- spk_width=spk_width,
560
- atoks_width=amodel and amodel.quantizer.vq.layers[0]._codebook.embed.shape[-1],
561
- stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
562
- if vqmodel: model.load_frozen_semantic_embeddings(vqmodel)
563
- if amodel: model.load_frozen_acoustic_embeddings(amodel)
564
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/t2s_up_wds.py DELETED
@@ -1,442 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5B. Text to semantic token modeling.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['load_datasets', 'rand', 'Tunables', 'Encoder', 'Decoder', 'TSARTransformer', 'make_model']
5
-
6
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 1
7
- import dataclasses
8
- import random
9
- import math
10
- import torch
11
- import torch.nn as nn
12
- import torch.nn.functional as F
13
- from torch.profiler import record_function
14
-
15
- from huggingface_hub import hf_hub_download
16
- from fastcore.basics import store_attr
17
- from fastprogress import progress_bar
18
-
19
- import webdataset as wds
20
-
21
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 2
22
- from pathlib import Path
23
- import pylab as plt
24
- import pandas as pd
25
- import numpy as np
26
-
27
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 3
28
- import whisper
29
- from whisperspeech.train import *
30
- from whisperspeech.modules import *
31
- from whisperspeech import vq_stoks
32
-
33
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 8
34
- import re
35
-
36
- class CharTokenizer:
37
- """Trivial tokenizer – just use UTF-8 bytes"""
38
- eot = 0
39
-
40
- def encode(self, txt):
41
- return list(bytes(txt.strip(), 'utf-8'))
42
-
43
- def decode(self, tokens):
44
- return bytes(tokens).decode('utf-8')
45
-
46
- def tokenizer(ikey, okey, length):
47
- """Tokenizes a transcript"""
48
- tok = CharTokenizer()
49
- def _tokenizer(samples):
50
- for s in samples:
51
- toks = torch.tensor(tok.encode(s[ikey]))
52
- s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot)
53
- yield s
54
- return _tokenizer
55
-
56
- def ar_padder(ikey, okey, length, pad_token):
57
- """Pads the tokens for autoregresive training"""
58
- def _ar_padder(samples):
59
- for s in samples:
60
- toks = s[ikey]
61
- if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks)
62
- toks = toks.to(torch.long)
63
- s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token)
64
- s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token)
65
- yield s
66
- return _ar_padder
67
-
68
- def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25):
69
- """Adds the characters per second metric to the input data"""
70
- def _char_per_seconder(samples):
71
- for s in samples:
72
- secs = s[stoks_key].shape[-1] / stoks_per_second
73
- s[cps_key] = len(s[txt_key]) / secs
74
- yield s
75
- return _char_per_seconder
76
-
77
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 9
78
- def build_speaker_map(shards):
79
- speakers = set()
80
- for shard in shards:
81
- with open(shard+'.speakers.txt') as f: speakers = speakers.union(set(x.strip() for x in f.readlines()))
82
- return {id:i for i,id in enumerate(speakers)}
83
-
84
- def speaker_id_extractor(speaker_map):
85
- def _extractor(samples):
86
- for s in samples:
87
- s['speaker'] = torch.tensor(speaker_map[s['__key__'].split("/")[1]])
88
- yield s
89
- return _extractor
90
-
91
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 10
92
- def load_datasets(
93
- input:str, # webdataset folder or shard list
94
- samples:int, # samples per epoch
95
- subsample:float=1, # use a fraction of the files
96
- val_samples:int=512,
97
- vq_codes:int=4096,
98
- ):
99
- if isinstance(input, (Path, str)):
100
- path = Path(input)
101
- if path.is_dir():
102
- glob = '*-t2s-*.tar.gz'
103
- else:
104
- glob = path.name
105
- path = path.parent
106
- input = Path(path).glob(glob)
107
- elif isinstance(input, list):
108
- pass
109
- else:
110
- raise ArgumentError("input should be either a list of a path with an optional glob specifier")
111
- shards = [str(x) for x in input]
112
-
113
- speaker_map = build_speaker_map(shards)
114
-
115
- def ds(shards, length):
116
- ds = wds.WebDataset(wds.ResampledShards(shards)).compose(
117
- wds.decode(),
118
- speaker_id_extractor(speaker_map),
119
- wds.select(lambda s: s['stoks.npy'].shape[-1] > 12), # select samples > .5s
120
- tokenizer('txt', 'ttoks', length=550),
121
- ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1),
122
- char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25),
123
- wds.to_tuple('ttoks', 'speaker', 'cps', 'in_stoks', 'out_stoks'),
124
- wds.batched(64)
125
- )
126
- ds.speakers = speaker_map
127
- ds.total_samples = length
128
- ds.stoks_len = 750
129
- ds.stoks_codes = vq_codes
130
- ds.ttoks_len = 550
131
- return ds.compose(wds.slice(length // 64)).with_epoch(length // 64).with_length(length // 64)
132
-
133
- return (
134
- ds(shards[1:], samples),
135
- ds(shards[:1], val_samples),
136
- )
137
-
138
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 14
139
- def rand(start, end):
140
- return random.random() * (end - start) + start
141
-
142
- @dataclasses.dataclass
143
- class Tunables:
144
- init_std :float = 1
145
- embeddings_std :float = .01
146
- embeddings_lr_scale: float = 5
147
- embedding_projector_lr_scale: float = 2.5
148
- output_mult :float = .35
149
- query_mult :float = 1
150
- encoder_depth_ratio :float = 0.25
151
- eot_dropout_p :float = .5
152
- cps_input: bool = True
153
- cps_bins: int = 32
154
-
155
- lr0 :float = 1.5e-3
156
- clip_gradient_norm :float = .2
157
- weight_decay :float = 1e-1
158
- warmup_steps :float = 4000
159
-
160
- random :bool = False
161
-
162
- def __post_init__(self):
163
- # randomize the hyperparams if requested
164
- if self.random:
165
- self.init_std = 10**rand(-1,1)
166
- self.embeddings_std = 10**rand(-3,-.7)
167
- self.embeddings_lr_scale = rand(2,6)
168
- self.output_mult = rand(0.25,0.65)
169
- self.query_mult = 2**rand(-2,3)
170
- self.encoder_depth_ratio = 0.25
171
-
172
- self.lr0 = rand(1,5)*1e-3
173
- self.clip_gradient_norm = 10**rand(-3,0)
174
- self.warmup_steps = 100*(10**rand(1,1.85))
175
-
176
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 15
177
- class EmbeddingProjector(nn.Linear):
178
- pass
179
-
180
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 16
181
- class Encoder(nn.Module):
182
- def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()):
183
- super().__init__()
184
- self.emb_width = emb_width
185
-
186
- self.emb_factor = width != emb_width
187
-
188
- self.embedding = nn.Embedding(codes, emb_width)
189
- if self.emb_factor:
190
- self.emb_to_hidden = EmbeddingProjector(emb_width, width)
191
-
192
- if pos_embs is None: pos_embs = sinusoids(length, width)
193
- self.register_buffer("positional_embedding", pos_embs)
194
-
195
- self.layers = nn.Sequential(*[
196
- ResidualAttentionBlock(width, n_head,
197
- qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
198
- ])
199
-
200
- self.ln_post = LayerNorm(width)
201
-
202
- def forward(self, Stoks):
203
- xin = self.embedding(Stoks)
204
- if self.emb_factor:
205
- xin = self.emb_to_hidden(xin)
206
-
207
- assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape"
208
- xin = (xin + self.positional_embedding).to(xin.dtype)
209
-
210
- return self.ln_post(self.layers(xin))
211
-
212
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 17
213
- class Decoder(nn.Module):
214
- def __init__(self, depth=6, stoks_width=384, width=384, n_head=6, length=1500, codes=1024, ffn_mult=4, pos_embs=None, tunables=Tunables()):
215
- super().__init__()
216
- self.length = length
217
- self.codes = codes
218
- self.width = width
219
- self.stoks_width = stoks_width
220
-
221
- self.emb_factor = width != stoks_width
222
-
223
- # embed semantic tokens
224
- self.embedding = nn.Embedding(codes, stoks_width)
225
- if self.emb_factor:
226
- self.emb_to_hidden = EmbeddingProjector(stoks_width, width)
227
- self.hidden_to_emb = EmbeddingProjector(width, stoks_width)
228
-
229
- if pos_embs is None: pos_embs = sinusoids(length, width)
230
- self.register_buffer("positional_embedding", pos_embs)
231
-
232
- self.layers = nn.ModuleList([
233
- ResidualAttentionBlock(width, n_head, cross_attention=True,
234
- qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
235
- ])
236
- self.ln_post = LayerNorm(width)
237
-
238
- def forward(self, Stoks, xenc, cps=None):
239
- Sembs = self.embedding(Stoks)
240
-
241
- if self.emb_factor:
242
- Sembs = self.emb_to_hidden(Sembs)
243
-
244
- xin = (Sembs + self.positional_embedding[:Sembs.shape[1]]).to(xenc.dtype)
245
- if cps is not None: xin = xin + cps
246
-
247
- x = xin
248
- for l in self.layers: x = l(x, xenc, causal=True)
249
-
250
- x = self.ln_post(x)
251
-
252
- if self.emb_factor:
253
- x = self.hidden_to_emb(x)
254
-
255
- logits = (x @ self.embedding.weight.to(x.dtype).T).float()
256
- return logits
257
-
258
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 18
259
- class TSARTransformer(nn.Module):
260
- def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4, language='en',
261
- ttoks_len=200, ttoks_codes=50364, ttoks_width=None,
262
- stoks_len=1500, stoks_codes=1024, stoks_width=None,
263
- tunables=Tunables()):
264
- assert language == 'en', "only english is supported right now"
265
- super().__init__()
266
- store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes,language")
267
-
268
- width = n_head * head_width
269
- self.width = width
270
- self.base_width = 3 * head_width
271
- self.tunables = tunables
272
- if self.stoks_width is None: self.stoks_width = self.width
273
- if self.ttoks_width is None: self.ttoks_width = self.width
274
-
275
- if tunables.cps_input:
276
- self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width)
277
- else:
278
- self.cps_embeddings = None
279
-
280
- encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
281
- decoder_depth = depth * 2 - encoder_depth
282
- tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables)
283
- self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args)
284
- self.decoder = Decoder(length=stoks_len, codes=stoks_codes, stoks_width=self.stoks_width, depth=decoder_depth, **tformer_args)
285
-
286
- self.tokenizer = None
287
-
288
- self.apply(self.init_transformer)
289
-
290
- def load_frozen_semantic_embeddings(self, vqmodel):
291
- with torch.no_grad():
292
- self.decoder.embedding.weight[:] = vqmodel.rq.layers[0]._codebook.embed[0]
293
- self.decoder.embedding.lr_scale = 0
294
-
295
- def setup(self, device):
296
- pass
297
-
298
- def init_transformer(self, m):
299
- if isinstance(m, LinearHead):
300
- m.no_weight_decay = True
301
- torch.nn.init.constant_(m.weight, 0)
302
- elif isinstance(m, QueryHead):
303
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
304
- torch.nn.init.constant_(m.weight, 0)
305
- elif isinstance(m, nn.Embedding):
306
- m.no_weight_decay = True
307
- m.lr_scale = self.tunables.embeddings_lr_scale
308
- std = self.tunables.embeddings_std
309
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
310
- elif isinstance(m, EmbeddingProjector):
311
- m.lr_scale = self.tunables.embedding_projector_lr_scale
312
- std = self.tunables.init_std
313
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
314
- elif isinstance(m, nn.Linear):
315
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
316
- std = self.tunables.init_std / m.weight.shape[1]
317
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
318
- if m.bias is not None:
319
- torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
320
- elif isinstance(m, nn.LayerNorm):
321
- m.no_weight_decay = True
322
- torch.nn.init.constant_(m.bias, 0)
323
- torch.nn.init.constant_(m.weight, 1)
324
-
325
- def forward(self, Ttoks, speakers, cpss, in_stoks, out_stoks=None, loss=True):
326
- with record_function("encoder"):
327
- xenc = self.encoder(Ttoks.to(torch.long))
328
- with record_function("decoder"):
329
- if self.cps_embeddings:
330
- cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
331
- cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
332
- cps_embs = self.cps_embeddings(cps_bin).unsqueeze(1)
333
- else:
334
- cps_embs = None
335
- logits = self.decoder(in_stoks, xenc, cps=cps_embs) * self.tunables.output_mult / (self.width / self.base_width)
336
- if loss is not None:
337
- with record_function("loss"):
338
- loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks)#, reduction='none')
339
- return logits, loss
340
-
341
- #
342
- # inference
343
- #
344
- @classmethod
345
- def load_model(cls, repo_id="collabora/whisperspeech", filename="t2s_up_wds.model", local_filename=None):
346
- if not local_filename:
347
- local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
348
- spec = torch.load(local_filename)
349
- model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
350
- model.load_state_dict(spec['state_dict'])
351
- model.eval()
352
- return model
353
-
354
- def load_checkpoint(self, local_filename):
355
- spec = torch.load(local_filename, map_location='cpu')
356
- assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
357
- state_dict = {k.replace('model.', ''):v
358
- for k,v in spec['state_dict'].items()}
359
- self.load_state_dict(state_dict)
360
- return self
361
-
362
- def save_model(self, fname):
363
- torch.save(dict(config = self.__stored_args__,
364
- tunables = dataclasses.asdict(self.tunables),
365
- state_dict = self.state_dict()), fname)
366
-
367
- def ensure_tokenizer(self):
368
- assert not self.training
369
- if self.tokenizer is None: self.tokenizer = CharTokenizer()
370
- #whisper.tokenizer.get_tokenizer(multilingual=True)
371
-
372
- @property
373
- def device(self):
374
- return next(self.parameters()).device
375
-
376
- @torch.no_grad()
377
- def generate(self, txt, cps=15, N=None, T=0.7, top_k=None, show_progress_bar=True):
378
- self.ensure_tokenizer()
379
- N = N or self.stoks_len
380
- dev = self.device
381
- ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev)
382
- ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0)
383
- cpss = torch.tensor([cps], device=dev)
384
- toks = torch.zeros((1,N), dtype=torch.long, device=dev)
385
- toks[0,0] = self.stoks_codes-1
386
- it = range(1,N)
387
- if show_progress_bar: it = progress_bar(it)
388
- for i in it:
389
- p, _ = self(ttoks, None, cpss, toks[:,:i], loss=None)
390
- last_p = p[0,-1]
391
- if top_k:
392
- last_p[last_p < torch.topk(last_p, top_k).values[-1,None]] = -torch.inf
393
- tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
394
- toks[0,i] = tok
395
- if toks[0,i] == self.stoks_codes-1: return toks[0,1:i]
396
- return toks[0,1:]
397
-
398
- @torch.no_grad()
399
- def generate_batch(self, txts, N=None, T=1.1, top_k=7, show_progress_bar=True):
400
- self.ensure_tokenizer()
401
- N = self.stoks_len
402
- dev = self.device
403
- ttoks = []
404
- for txt in txts:
405
- ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev)
406
- ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0)
407
- ttoks.append(ttoks_)
408
- ttoks = torch.cat(ttoks, dim=0)
409
- toks = torch.zeros((len(ttoks),N), dtype=torch.long, device=dev)
410
- it = range(N)
411
- if show_progress_bar: it = progress_bar(it)
412
- for i in it:
413
- p, _ = self(ttoks, toks[:,:i], loss=None)
414
- last_p = p[:,-1]
415
- if top_k:
416
- last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
417
- tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
418
- toks[:,i] = tok[:,0]
419
- if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,:i]
420
- return toks
421
-
422
- # %% ../nbs/5B. Text to semantic token modeling.ipynb 19
423
- def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs):
424
- kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs)
425
- if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes
426
- if size == 'micro':
427
- return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs)
428
- if size == 'tiny':
429
- return TSARTransformer(depth=4, n_head=6, **kwargs)
430
- if size == 'base':
431
- return TSARTransformer(depth=6, n_head=8, **kwargs)
432
- if size == 'small':
433
- return TSARTransformer(depth=12, n_head=16, **kwargs)
434
-
435
- def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
436
- if frozen_embeddings_model:
437
- vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
438
- model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
439
- model.load_frozen_semantic_embeddings(vqmodel)
440
- else:
441
- model = _make_model(size, quantizers, tunables, dataset)
442
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/t2s_up_wds_mlang_enclm.py DELETED
@@ -1,519 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/5B. Multi-lang text to semantic token modeling.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['load_dataset', 'rand', 'Tunables', 'T2SEmbedding', 'Encoder', 'TSARTransformer', 'make_model']
5
-
6
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 1
7
- import dataclasses
8
- import random
9
- import math
10
- import itertools
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from torch.profiler import record_function
15
-
16
- from huggingface_hub import hf_hub_download
17
- from fastcore.basics import store_attr
18
- from fastprogress import progress_bar
19
-
20
- from pathlib import Path
21
-
22
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 2
23
- from whisperspeech.modules import *
24
- from whisperspeech import languages
25
-
26
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 6
27
- import re
28
-
29
- class CharTokenizer:
30
- """Trivial tokenizer – just use UTF-8 bytes"""
31
- eot = 0
32
-
33
- def encode(self, txt):
34
- return list(bytes(txt.strip(), 'utf-8'))
35
-
36
- def decode(self, tokens):
37
- return bytes(tokens).decode('utf-8')
38
-
39
- def tokenizer(ikey, okey, length):
40
- """Tokenizes a transcript"""
41
- tok = CharTokenizer()
42
- def _tokenizer(samples):
43
- for s in samples:
44
- toks = torch.tensor(tok.encode(s[ikey]))
45
- s[okey] = F.pad(toks, (0, length - toks.shape[-1]), value=tok.eot)
46
- yield s
47
- return _tokenizer
48
-
49
- def ar_padder(ikey, okey, length, pad_token):
50
- """Pads the tokens for autoregresive training"""
51
- import numpy as np
52
-
53
- def _ar_padder(samples):
54
- for s in samples:
55
- toks = s[ikey]
56
- if isinstance(toks, (list, np.ndarray)): toks = torch.tensor(toks)
57
- toks = toks.to(torch.long)
58
- s['in_' +okey] = F.pad(toks, (1, length - toks.shape[-1] - 1), value=pad_token)
59
- s['out_'+okey] = F.pad(toks, (0, length - toks.shape[-1]), value=pad_token)
60
- yield s
61
- return _ar_padder
62
-
63
- def char_per_seconder(txt_key, stoks_key, cps_key, stoks_per_second=25):
64
- """Adds the characters per second metric to the input data"""
65
- def _char_per_seconder(samples):
66
- for s in samples:
67
- secs = s[stoks_key].shape[-1] / stoks_per_second
68
- s[cps_key] = len(s[txt_key]) / secs
69
- yield s
70
- return _char_per_seconder
71
-
72
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 7
73
- def load_dataset(
74
- txt_shard_spec:str, # transcription webdataset shards
75
- stoks_shard_dir:str, # stoks webdataset base dir
76
- samples:int, # samples per epoch
77
- txt_kind:str='small.en-txt',
78
- vq_codes:int=4096,
79
- language:str='en',
80
- weight:float=1,
81
- validation:bool=False,
82
- exclude_files:str=None,
83
- ):
84
- import webdataset as wds
85
- from whisperspeech import utils
86
-
87
- shards = utils.shard_glob(txt_shard_spec)
88
- excludes = {x for file in exclude_files.split() for x in utils.readlines(file)} if exclude_files else set()
89
-
90
- language = languages.to_id(language)
91
-
92
- def set_language(x):
93
- x['language'] = language
94
- return x
95
-
96
- same_on_all_nodes = lambda urls: urls # will only be used for validation
97
- ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
98
- wds.decode(),
99
- utils.merge_in(utils.derived_dataset('eqvad-stoks', base=txt_kind, suffix='', dir=stoks_shard_dir)),
100
- # discard validation samples, select samples > .5s
101
- wds.select(lambda s: s['__key__'] not in excludes and s['stoks.npy'].shape[-1] > 12),
102
- tokenizer('txt', 'ttoks', length=550),
103
- ar_padder('stoks.npy', 'stoks', length=750, pad_token=vq_codes-1),
104
- ar_padder('ttoks', 'ttoks', length=550, pad_token=CharTokenizer.eot),
105
- char_per_seconder('txt', 'stoks.npy', 'cps', stoks_per_second=25),
106
- wds.map(set_language),
107
- wds.to_tuple('in_ttoks', 'out_ttoks', 'language', 'cps', 'in_stoks', 'out_stoks'),
108
- wds.shuffle(20000, initial=20000),
109
- wds.batched(64)
110
- )
111
- if validation:
112
- ds = ds.slice(samples // 64)
113
- ds.total_samples = samples
114
- ds.stoks_len = 750
115
- ds.stoks_codes = vq_codes
116
- ds.ttoks_len = 550
117
- ds.weight = weight
118
-
119
- return ds
120
-
121
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 14
122
- def rand(start, end):
123
- return random.random() * (end - start) + start
124
-
125
- @dataclasses.dataclass
126
- class Tunables:
127
- init_std :float = 1
128
- embeddings_std :float = .01
129
- embeddings_lr_scale: float = 5
130
- embedding_projector_lr_scale: float = 2.5
131
- output_mult :float = .35
132
- query_mult :float = 1
133
- encoder_depth_ratio :float = 0.25
134
- eot_dropout_p :float = .5
135
- cps_input: bool = True
136
- cps_bins: int = 32
137
-
138
- lr0 :float = 1.5e-3
139
- clip_gradient_norm :float = .2
140
- weight_decay :float = 1e-1
141
- warmup_steps :float = 4000
142
-
143
- random :bool = False
144
-
145
- def __post_init__(self):
146
- # randomize the hyperparams if requested
147
- if self.random:
148
- self.init_std = 10**rand(-1,1)
149
- self.embeddings_std = 10**rand(-3,-.7)
150
- self.embeddings_lr_scale = rand(2,6)
151
- self.output_mult = rand(0.25,0.65)
152
- self.query_mult = 2**rand(-2,3)
153
- self.encoder_depth_ratio = 0.25
154
-
155
- self.lr0 = rand(1,5)*1e-3
156
- self.clip_gradient_norm = 10**rand(-3,0)
157
- self.warmup_steps = 100*(10**rand(1,1.85))
158
-
159
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 15
160
- class T2SEmbedding(nn.Module):
161
- def __init__(self, length=1500, codes=1024, width=384, pos_embs=None, stoks_width=384):
162
- super().__init__()
163
- self.embedding = FlexEmbeddings(codes, width, special_codes=1, frozen_width=stoks_width)
164
- if pos_embs is None: pos_embs = sinusoids(length, width)
165
- self.register_buffer("positional_embedding", pos_embs)
166
-
167
- def forward(self, Stoks, xenc, cps=None, offset=0):
168
- Sembs = self.embedding(Stoks)
169
- xin = (Sembs + self.positional_embedding[offset : offset + Sembs.shape[1]]).to(xenc.dtype)
170
- if cps is not None: xin = xin + cps
171
- return xin, offset
172
-
173
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 16
174
- class Encoder(nn.Module):
175
- def __init__(self, depth=6, width=384, n_head=6, length=1500, codes=1024, emb_width=384, ffn_mult=4, pos_embs=None, tunables=Tunables()):
176
- super().__init__()
177
- self.emb_width = emb_width
178
-
179
- self.embedding = FlexEmbeddings(codes, width, frozen_width=emb_width)
180
-
181
- if pos_embs is None: pos_embs = sinusoids(length, width)
182
- self.register_buffer("positional_embedding", pos_embs)
183
-
184
- self.layers = nn.ModuleList([
185
- ResidualAttentionBlock(width, n_head,
186
- qk_scale=tunables.query_mult*8/math.sqrt(width/n_head), ffn_mult=ffn_mult) for _ in range(depth)
187
- ])
188
-
189
- self.ln_post = LayerNorm(width)
190
-
191
- mask = torch.empty(length, length).fill_(-torch.inf).triu_(1)
192
- self.register_buffer("mask", mask, persistent=False)
193
-
194
- def forward(self, Stoks, positions, lang_emb=None):
195
- xin = self.embedding(Stoks)
196
-
197
- if lang_emb is not None: xin += lang_emb
198
-
199
- # assert xin.shape[1:] == self.positional_embedding.shape, "incorrect semantic token shape"
200
- x = (xin +
201
- self.positional_embedding[positions]).to(xin.dtype)
202
-
203
- for l in self.layers: x = l(x, positions, causal=False, mask=self.mask)
204
-
205
- return self.ln_post(x)
206
-
207
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 17
208
- class TSARTransformer(nn.Module):
209
- def __init__(self, depth=6, n_head=6, head_width=64, ffn_mult=4,
210
- ttoks_len=200, ttoks_codes=256, ttoks_width=None,
211
- stoks_len=1500, stoks_codes=1024, stoks_width=None,
212
- tunables=Tunables()):
213
- super().__init__()
214
- store_attr("depth,n_head,head_width,ffn_mult,stoks_width,ttoks_width,ttoks_len,stoks_len,ttoks_codes,stoks_codes")
215
-
216
- width = n_head * head_width
217
- self.width = width
218
- self.base_width = 3 * head_width
219
- self.tunables = tunables
220
- if self.stoks_width is None: self.stoks_width = self.width
221
- if self.ttoks_width is None: self.ttoks_width = self.width
222
-
223
- self.lang_embeddings = nn.Embedding(len(languages.languages), width)
224
- if tunables.cps_input:
225
- self.cps_embeddings = nn.Embedding(tunables.cps_bins, self.width)
226
- else:
227
- self.cps_embeddings = None
228
-
229
- encoder_depth = int(depth * 2 * tunables.encoder_depth_ratio)
230
- decoder_depth = depth * 2 - encoder_depth
231
- tformer_args = dict(width=width, n_head=n_head, ffn_mult=ffn_mult, tunables=tunables)
232
- self.encoder = Encoder(length=ttoks_len, codes=ttoks_codes, emb_width=self.ttoks_width, depth=encoder_depth, **tformer_args)
233
- self.embeddings = T2SEmbedding(length=stoks_len, codes=stoks_codes, width=width, stoks_width=self.stoks_width)
234
-
235
- self.decoder = BaseDecoder(
236
- length=stoks_len,
237
- depth=decoder_depth,
238
- qk_scale=tunables.query_mult*8/math.sqrt(width/n_head),
239
- width=width, n_head=n_head, ffn_mult=ffn_mult,
240
- )
241
- self.tokenizer = None
242
-
243
- self.apply(self.init_transformer)
244
-
245
- def load_frozen_semantic_embeddings(self, vqmodel):
246
- self.embeddings.embedding.set_frozen_embeddings(vqmodel.rq.layers[0]._codebook.embed[0])
247
-
248
- def setup(self, device):
249
- pass
250
-
251
- def init_transformer(self, m):
252
- if isinstance(m, LinearHead):
253
- m.no_weight_decay = True
254
- torch.nn.init.constant_(m.weight, 0)
255
- elif isinstance(m, QueryHead):
256
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
257
- torch.nn.init.constant_(m.weight, 0)
258
- elif isinstance(m, nn.Embedding):
259
- m.no_weight_decay = True
260
- m.lr_scale = self.tunables.embeddings_lr_scale
261
- std = self.tunables.embeddings_std
262
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
263
- elif isinstance(m, EmbeddingProjector):
264
- m.lr_scale = self.tunables.embedding_projector_lr_scale
265
- std = self.tunables.init_std
266
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
267
- elif isinstance(m, nn.Linear):
268
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
269
- std = self.tunables.init_std / m.weight.shape[1]
270
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
271
- if m.bias is not None:
272
- torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
273
- elif isinstance(m, nn.LayerNorm):
274
- m.no_weight_decay = True
275
- torch.nn.init.constant_(m.bias, 0)
276
- torch.nn.init.constant_(m.weight, 1)
277
-
278
- def _embed_cps(self, cpss):
279
- if self.cps_embeddings is None: return None
280
-
281
- cps_bin = (cpss / 20 * self.tunables.cps_bins).to(torch.long)
282
- cps_bin[cps_bin >= self.tunables.cps_bins] = self.tunables.cps_bins-1
283
- return self.cps_embeddings(cps_bin).unsqueeze(1)
284
-
285
- def run_encoder(self, in_ttoks, languages, cpss):
286
- if len(languages.shape) != 3: lang_embs = self.lang_embeddings(languages)
287
- else: lang_embs = languages
288
- if len(lang_embs.shape) == 2: lang_embs = lang_embs.unsqueeze(1)
289
-
290
- cps_emb = self._embed_cps(cpss)
291
-
292
- with record_function("encoder"):
293
- positions = torch.arange(0, in_ttoks.shape[1], device=in_ttoks.device)
294
- xenc = self.encoder(in_ttoks.to(torch.long), positions, lang_emb=lang_embs)
295
-
296
- return xenc, positions, cps_emb
297
-
298
- def forward(self, in_ttoks, out_ttoks, languages, cpss, in_stoks, in_stoks_positions, out_stoks=None, loss=True, offset=None, xenc=None, xenc_positions=None, cps_emb=None):
299
- if xenc is None:
300
- xenc, cps_emb = self.run_encoder(in_ttoks, languages, cpss)
301
-
302
- with record_function("decoder"):
303
- x = (self.embeddings.embedding(in_stoks) +
304
- self.embeddings.positional_embedding[in_stoks_positions] +
305
- cps_emb).to(xenc[0].dtype)
306
- x = self.decoder(x, in_stoks_positions, xenc, xenc_positions)
307
- logits = self.embeddings.embedding.unembed(x)
308
- logits = logits * self.tunables.output_mult / (self.width / self.base_width)
309
-
310
- if loss is not None:
311
- enc_logits = self.encoder.embedding.unembed(xenc[0])
312
- enc_logits = enc_logits * self.tunables.output_mult / (self.width / self.base_width)
313
- with record_function("loss"):
314
- loss = F.cross_entropy(logits.transpose(-1,-2), out_stoks)
315
- if self.training:
316
- loss += 0.1 * F.cross_entropy(enc_logits.transpose(-1,-2), out_ttoks)
317
-
318
- return logits, loss
319
-
320
- #
321
- # inference
322
- #
323
- @classmethod
324
- def load_model(cls, ref="collabora/whisperspeech:t2s-small-en+pl.model",
325
- repo_id=None, filename=None, local_filename=None):
326
- if repo_id is None and filename is None and local_filename is None:
327
- if ":" in ref:
328
- repo_id, filename = ref.split(":", 1)
329
- else:
330
- local_filename = ref
331
- if not local_filename:
332
- local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
333
- spec = torch.load(local_filename)
334
- model = cls(**spec['config'], tunables=Tunables(**spec['tunables']))
335
- model.load_state_dict(spec['state_dict'])
336
- model.eval()
337
- return model
338
-
339
- def load_checkpoint(self, local_filename):
340
- spec = torch.load(local_filename, map_location='cpu')
341
- assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
342
- state_dict = {k.replace('model.', ''):v
343
- for k,v in spec['state_dict'].items()}
344
- self.load_state_dict(state_dict)
345
- return self
346
-
347
- def save_model(self, fname):
348
- torch.save(dict(config = self.__stored_args__,
349
- tunables = dataclasses.asdict(self.tunables),
350
- state_dict = self.state_dict()), fname)
351
-
352
- def ensure_tokenizer(self):
353
- assert not self.training
354
- if self.tokenizer is None: self.tokenizer = CharTokenizer()
355
-
356
- def switch_dtypes(self, dtype=torch.float16):
357
- self.dtype = dtype
358
- for n,m in self.named_modules():
359
- # convert every leaf layer apart from the LayerNorms
360
- if isinstance(m, (nn.Linear, nn.Embedding)):
361
- m.to(dtype)
362
- # take care of buffers ([kv]_cache, masks) that are not in the leaf layers
363
- for bn,b in m.named_buffers(recurse=False):
364
- setattr(m,bn,b.to(dtype))
365
-
366
- def optimize(self, max_batch_size=1, dtype=torch.float16, torch_compile=True):
367
- for emb in [self.embeddings.embedding, self.embeddings.embedding]:
368
- emb.convert_for_eval()
369
- for l in self.encoder.layers:
370
- l.attn.convert_for_eval()
371
- for l in self.decoder.layers:
372
- l.attn.convert_for_eval()
373
- l.cross_attn.convert_for_eval()
374
- l.setup_kv_cache(max_batch_size, self.stoks_len, self.ttoks_len)
375
- self.switch_dtypes(dtype)
376
- if torch_compile:
377
- self.generate_next = torch.compile(self.generate_next, mode="reduce-overhead", fullgraph=True)
378
-
379
- @property
380
- def device(self):
381
- return next(self.parameters()).device
382
-
383
- # from https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
384
- def multinomial_sample_one_no_sync(self, probs_sort): # Does multinomial sampling without a cuda synchronization
385
- q = torch.empty_like(probs_sort).exponential_(1)
386
- return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
387
-
388
- def logits_to_probs(self, logits, T=1.0, top_k=None):
389
- logits = logits / max(T, 1e-5)
390
-
391
- logits[self.embeddings.embedding.codes:] = -torch.inf
392
- if top_k is not None:
393
- v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
394
- pivot = v.select(-1, -1).unsqueeze(-1)
395
- logits = torch.where(logits < pivot, -float("Inf"), logits)
396
-
397
- probs = torch.nn.functional.softmax(logits, dim=-1)
398
- return probs
399
-
400
- def sample(self, logits, T=1.0, top_k=None):
401
- probs = self.logits_to_probs(logits[0,-1], T, top_k)
402
- idx_next = self.multinomial_sample_one_no_sync(probs)
403
- return idx_next
404
-
405
- def generate_one(self, toks, toks_positions, cps_emb, xenc, xenc_positions, T, top_k):
406
- probs, _ = self(None, None, None, None, toks, toks_positions, loss=None, xenc=xenc, xenc_positions=xenc_positions, cps_emb=cps_emb)
407
- return self.sample(probs, T, top_k)
408
-
409
- def generate_next(self, *args, **kwargs):
410
- return self.generate_one(*args, **kwargs)
411
-
412
- @torch.no_grad()
413
- def prep(self, txt, cps=15, lang="en"):
414
- dev = self.device
415
- ttoks = torch.tensor(self.tokenizer.encode(txt), device=dev)
416
- ttoks = F.pad(ttoks, (0, self.ttoks_len - len(ttoks)), value=self.tokenizer.eot).unsqueeze(0)
417
- cpss = torch.tensor([cps], device=dev)
418
- langs = torch.tensor([languages.to_id(lang)], device=dev)
419
- return ttoks, cpss, langs
420
-
421
- @torch.no_grad()
422
- def generate(self, txt, cps=15, lang="en", N=None, T=0.7, top_k=None, step=None, show_progress_bar=True):
423
- self.ensure_tokenizer()
424
- N = N or self.stoks_len
425
- dev = self.device
426
- ttoks = []
427
- langs = []
428
- if isinstance(lang, list):
429
- lang0 = lang[0]
430
- assert isinstance(txt, list), "lang and txt have to be both lists or strings"
431
- for txt, lang in zip(txt, lang):
432
- tt = self.tokenizer.encode(txt)
433
- ttoks += tt
434
- langs += [languages.to_id(lang)] * len(tt)
435
- elif isinstance(lang, torch.Tensor):
436
- langs = lang
437
- ttoks = self.tokenizer.encode(txt)
438
- else:
439
- lang0 = lang
440
- ttoks = self.tokenizer.encode(txt)
441
- langs = torch.tensor([languages.to_id(lang)], device=dev).unsqueeze(0)
442
- ttoks = torch.tensor(ttoks, device=dev)
443
- ttoks = F.pad(ttoks, (1, self.ttoks_len - len(ttoks) - 1), value=self.tokenizer.eot).unsqueeze(0)
444
- cpss = torch.tensor([cps], device=dev)
445
- if not isinstance(langs, torch.Tensor):
446
- langs = torch.tensor(langs, device=dev)
447
- langs = F.pad(langs, (1, self.ttoks_len - len(langs) - 1), value=languages.to_id(lang0)).unsqueeze(0)
448
- it = range(0,N-1)
449
- if show_progress_bar: it = progress_bar(it)
450
-
451
- toks = torch.zeros((1,N), dtype=torch.long, device=dev)
452
- toks[:,0] = self.stoks_codes-1
453
- toks_positions = torch.arange(N, device=dev)
454
- with record_function("encode"):
455
- xenc, xenc_positions, cps_emb = self.run_encoder(ttoks, langs, cpss)
456
- toks_positions = torch.arange(N+1, device=dev)
457
- # contrary to S2A this model works without prefill and is actually a tiny bit faster
458
- # with record_function("prefill"):
459
- # toks[0,1] = self.generate_one(toks[:,:1], toks_positions[:1], cps_emb, xenc, xenc_positions, T, top_k)
460
- with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
461
- for i in it:
462
- toks[0,i+1] = self.generate_next(toks[:,i:i+1], toks_positions[i:i+1], cps_emb, xenc, xenc_positions, T, top_k)
463
- if i % 25 == 0 and toks[0,i+1] == self.stoks_codes-1: return toks[0,:i+1]
464
-
465
- # for profiling, debugging or early exit
466
- if step is not None: step()
467
- return toks[0,:]
468
-
469
- @torch.no_grad()
470
- def generate_batch(self, txts, N=None, T=1.1, top_k=7, show_progress_bar=True):
471
- self.ensure_tokenizer()
472
- N = self.stoks_len
473
- dev = self.device
474
- ttoks = []
475
- for txt in txts:
476
- ttoks_ = torch.tensor(self.tokenizer.encode(txt), device=dev)
477
- ttoks_ = F.pad(ttoks_, (0, self.ttoks_len - len(ttoks_)), value=self.tokenizer.eot).unsqueeze(0)
478
- ttoks.append(ttoks_)
479
- ttoks = torch.cat(ttoks, dim=0)
480
- toks = torch.zeros((len(ttoks),N), dtype=torch.long, device=dev)
481
- it = range(N)
482
- if show_progress_bar: it = progress_bar(it)
483
- for i in it:
484
- p, _ = self(ttoks, toks[:,:i], loss=None)
485
- last_p = p[:,-1]
486
- if top_k:
487
- last_p[last_p < torch.topk(last_p, top_k).values[:,-1,None]] = -torch.inf
488
- tok = torch.multinomial((last_p / float(T)).softmax(-1), 1)
489
- toks[:,i] = tok[:,0]
490
- if (toks[:,i] == self.stoks_codes-1).all(): return toks[:,:i]
491
- return toks
492
-
493
- # %% ../nbs/5B. Multi-lang text to semantic token modeling.ipynb 18
494
- def _make_model(size:str, tunables:Tunables=Tunables(), dataset=None, **kwargs):
495
- kwargs = dict(stoks_len = dataset.stoks_len, ttoks_len = dataset.ttoks_len, tunables=tunables, **kwargs)
496
- if 'stoks_codes' not in kwargs: kwargs['stoks_codes'] = dataset.stoks_codes
497
- if size == 'micro':
498
- return TSARTransformer(depth=2, n_head=3, ffn_mult=1, **kwargs)
499
- if size == 'tiny':
500
- return TSARTransformer(depth=4, n_head=6, **kwargs)
501
- if size == 'base':
502
- return TSARTransformer(depth=6, n_head=8, **kwargs)
503
- if size == 'small':
504
- return TSARTransformer(depth=12, n_head=12, **kwargs)
505
- if size == 'small+':
506
- return TSARTransformer(depth=12, n_head=16, **kwargs)
507
- if size == 'medium':
508
- return TSARTransformer(depth=24, n_head=16, **kwargs)
509
-
510
- def make_model(size:str, frozen_embeddings_model:str=None, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
511
- from whisperspeech import vq_stoks
512
-
513
- if frozen_embeddings_model:
514
- vqmodel = vq_stoks.RQBottleneckTransformer.load_model(frozen_embeddings_model)
515
- model = _make_model(size, tunables, dataset, stoks_codes=vqmodel.vq_codes+1, stoks_width=vqmodel.rq.layers[0]._codebook.embed[0].shape[-1])
516
- model.load_frozen_semantic_embeddings(vqmodel)
517
- else:
518
- model = _make_model(size, tunables, dataset, mode=mode)
519
- return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/train.py DELETED
@@ -1,271 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B1. Training.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['SimpleVisual', 'validate', 'train']
5
-
6
- # %% ../nbs/B1. Training.ipynb 2
7
- import io
8
- import time
9
- import random
10
- from pathlib import Path
11
-
12
- from fastprogress import progress_bar, master_bar
13
- import fastprogress
14
-
15
- import numpy as np
16
- import pylab as plt
17
- import math
18
-
19
- import IPython
20
-
21
- import torch
22
- import torch.nn as nn
23
- from torch.utils.data.dataloader import DataLoader
24
- from torch.profiler import record_function
25
-
26
- import webdataset as wds
27
-
28
- torch.backends.cudnn.benchmark = True
29
- torch.backends.cudnn.enabled = True
30
- torch.backends.cuda.matmul.allow_tf32 = True
31
- torch.set_float32_matmul_precision('medium')
32
-
33
- # %% ../nbs/B1. Training.ipynb 3
34
- class SimpleVisual:
35
- def __init__ (self, model, masterbar, total_steps):
36
- self.model = model
37
- self.masterbar = masterbar
38
- self.total_steps = total_steps
39
- self.epochs = total_steps // masterbar.main_bar.total
40
-
41
- gs = plt.GridSpec(2, 1, height_ratios=[3,1])
42
- graph_fig = plt.figure(figsize=(10,6))
43
- self.graph_fig = graph_fig
44
- self.loss_p = graph_fig.add_subplot(gs[0])
45
- self.lr_p = graph_fig.add_subplot(gs[1], sharex=self.loss_p)
46
- self.lr_p.tick_params('x', labelbottom=False)
47
- self.graph_out = None
48
-
49
- self.its = []
50
- self.train_losses = []
51
- self.val_losses = []
52
- self.lr_history = []
53
-
54
- def show(self):
55
- self.start_t = time.time()
56
- self.masterbar.write(["samples", "train", "val", "time"], table=True)
57
- self.graph_out = display(self.graph_fig, display_id=True, clear=True)
58
-
59
- def hide(self):
60
- if self.graph_out is not None:
61
- self.graph_out.update(IPython.display.HTML(''))
62
-
63
- def plot(self):
64
- loss_p, lr_p = self.loss_p, self.lr_p
65
- loss_p.clear()
66
- loss_p.plot(self.its, self.train_losses)
67
- loss_p.plot(self.its, self.val_losses)
68
- loss_p.set_xlim(0, self.total_steps)
69
- loss_p.set_yscale('log')
70
- lr_p.clear()
71
- lrs = np.array(self.lr_history)
72
- lr_p.plot(self.its, lrs)
73
- self.graph_out.update(self.graph_fig)
74
-
75
- def add_data(self, it, lr, train_loss, val_los):
76
- self.its.append(it)
77
- self.train_losses.append(train_loss)
78
- self.val_losses.append(val_los)
79
- self.lr_history.append(lr)
80
- self.plot()
81
-
82
- def add_table_row(self, it, avg_train_loss, val_loss):
83
- elapsed_t = time.time() - self.start_t
84
- self.masterbar.write([it, f"{avg_train_loss:.5f}", f"{val_loss:.5f}", fastprogress.core.format_time(elapsed_t)], table=True)
85
-
86
- def on_iter(self, bar, it, avg_train_loss, val_loss):
87
- epoch = math.ceil(it / self.total_steps * self.epochs)
88
- bar.comment = f"#{epoch}/{self.epochs} loss: {avg_train_loss:.3f} / {val_loss:.3f}"
89
-
90
- # %% ../nbs/B1. Training.ipynb 4
91
- # FIXME: we need to keep this synchronised with the validation code below...
92
- def validate(model, val, half=True, bs=16, drop_last=False, dl_workers=8, device="cuda"):
93
- if isinstance(val, torch.utils.data.IterableDataset):
94
- val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
95
- .unbatched().shuffle(1024).batched(bs)
96
- else:
97
- val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
98
-
99
- with torch.no_grad():
100
- val_loss = 0
101
- val_samples = 0
102
- for args in val_loader:
103
- args = [x.to(device, non_blocking=True) for x in args]
104
- with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
105
- ps, loss = model(*args)
106
- N = args[0].shape[0]
107
- val_loss += loss.mean().item() * N
108
- val_samples += N
109
- val_loss = val_loss / val_samples
110
-
111
- return val_loss
112
-
113
- # %% ../nbs/B1. Training.ipynb 5
114
- def train(checkpoint_path, model, train, val, half=True, bs=16, lr=1e-4, drop_last=False,
115
- weight_decay=0.1, warmup_steps=10000, epochs=10, clip_gradient_norm=None,
116
- dl_workers=8, visual_class = SimpleVisual, profiler=None,
117
- run_valid_every_iters=8000, table_row_every_iters=80000, chkpt_every_iters=None,
118
- device="cuda", trainable_params=None):
119
- if chkpt_every_iters is None:
120
- chkpt_every_iters = table_row_every_iters
121
-
122
- mb = master_bar(range(epochs))
123
- if isinstance(train, torch.utils.data.IterableDataset):
124
- pct_start = min(0.3, warmup_steps / (epochs * (train.total_samples//bs)))
125
- visual = visual_class(model, mb, epochs * train.total_samples)
126
- # pct_start = min(0.3, warmup_steps / (epochs * len(train)))
127
- # visual = visual_class(model, mb, epochs*len(train)*bs)
128
- else:
129
- pct_start = min(0.3, warmup_steps / (epochs * len(train) / bs))
130
- visual = visual_class(model, mb, epochs*len(train))
131
- model.visual = visual
132
-
133
- Path(checkpoint_path).mkdir(exist_ok=True)
134
-
135
- if isinstance(train, torch.utils.data.IterableDataset):
136
- # train_loader = DataLoader(train, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False, shuffle=False)
137
- # val_loader = DataLoader(val, batch_size=None, num_workers=dl_workers, pin_memory=True, drop_last=False)
138
- train_loader = wds.WebLoader(train, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
139
- .unbatched().shuffle(1024).batched(bs, partial=False)
140
- val_loader = wds.WebLoader(val, batch_size=None, num_workers=dl_workers, drop_last=drop_last) \
141
- .unbatched().shuffle(1024).batched(bs)
142
- else:
143
- train_loader = DataLoader(train, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last, shuffle=True)
144
- val_loader = DataLoader(val, batch_size=bs, num_workers=dl_workers, pin_memory=True, drop_last=drop_last)
145
-
146
- val_loss = torch.nan
147
- avg_train_loss = torch.nan
148
-
149
- if hasattr(model, 'setup'):
150
- model.setup(device)
151
-
152
- try:
153
- scheduler = None
154
-
155
- if trainable_params is None: trainable_params = model.parameters()
156
- all_params = set(trainable_params)
157
- customized_params = set()
158
- groups = []
159
- group_map = {}
160
- for name,m in model.named_modules():
161
- if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
162
- m_trainable = [x for x in m.parameters() if x in all_params]
163
- if not m_trainable: continue
164
- customized_params |= set(m_trainable)
165
- m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
166
- m_lr = lr * getattr(m, 'lr_scale', 1)
167
- group = group_map.get((m_wd, m_lr), None)
168
- if not group:
169
- group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
170
- groups.append(group)
171
- group_map[(m_wd, m_lr)] = group
172
- group['params'] += m_trainable
173
- group['names'].append(name)
174
-
175
- other_params = all_params - customized_params
176
-
177
- if other_params:
178
- groups = groups + [
179
- {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
180
- ]
181
-
182
- optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), fused=device!='cpu', params=groups)
183
- model._optimizer = optimizer
184
- scaler = torch.cuda.amp.GradScaler(enabled=half)
185
- scheduler = torch.optim.lr_scheduler.OneCycleLR(
186
- optimizer, pct_start=pct_start, steps_per_epoch=math.ceil(train.total_samples/bs), epochs=epochs,
187
- max_lr=[pg.get('lr', lr) for pg in groups],
188
- final_div_factor=25)
189
-
190
- it = 0
191
- next_val_it = it + 50
192
- next_chkpt_it = chkpt_every_iters
193
- next_table_it = table_row_every_iters
194
-
195
- visual.show()
196
-
197
- running_loss = [0]
198
-
199
- for epoch in mb:
200
- bar = progress_bar(train_loader, total=train.total_samples//bs, parent=mb)
201
- for args in bar:
202
- with record_function("forward"):
203
- args = [x.to(device, non_blocking=True) for x in args]
204
-
205
- # zero the parameter gradients
206
- optimizer.zero_grad(set_to_none=True)
207
-
208
- with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
209
- ps, loss = model(*args)
210
- loss = loss.mean()
211
-
212
- with record_function("backward"):
213
- scaler.scale(loss).backward()
214
-
215
- if clip_gradient_norm:
216
- scaler.unscale_(optimizer)
217
- # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
218
- torch.nn.utils.clip_grad_norm_(model.parameters(), clip_gradient_norm)
219
-
220
- scaler.step(optimizer)
221
- scaler.update()
222
-
223
- scheduler.step()
224
-
225
- if profiler is not None: profiler.step()
226
-
227
- with record_function("running_loss"):
228
- running_loss.append(loss.item())
229
- running_loss = running_loss[-5:]
230
- avg_train_loss = sum(running_loss)/len(running_loss)
231
-
232
- if it >= next_chkpt_it:
233
- with record_function("checkpoint"):
234
- next_chkpt_it += chkpt_every_iters
235
- torch.save(model.state_dict(), f'{checkpoint_path}/{it:08d}.pt')
236
-
237
- if it >= next_val_it:
238
- next_val_it += run_valid_every_iters
239
- with record_function("validation"):
240
- with record_function("model.eval"):
241
- model.eval()
242
- with torch.no_grad():
243
- val_loss = 0
244
- val_samples = 0
245
- for args in val_loader:
246
- args = [x.to(device, non_blocking=True) for x in args]
247
- with torch.autocast(device_type=device, dtype=torch.float16 if half else torch.float32, enabled=device!='cpu'):
248
- ps, loss = model(*args)
249
- N = args[0].shape[0]
250
- val_loss += loss.mean().item() * N
251
- val_samples += N
252
- val_loss = val_loss / val_samples
253
- with record_function("model.train"):
254
- model.train()
255
- with record_function("plotting"):
256
- visual.add_data(it, scheduler.get_last_lr(), avg_train_loss, val_loss)
257
-
258
- if it >= next_table_it:
259
- visual.add_table_row(it, avg_train_loss, val_loss)
260
- next_table_it += table_row_every_iters
261
-
262
- it += bs
263
- visual.on_iter(bar, it, avg_train_loss, val_loss)
264
- except KeyboardInterrupt:
265
- mb.write(f"interrupted")
266
- mb.show()
267
- pass
268
- finally:
269
- visual.add_table_row(it, avg_train_loss, val_loss)
270
- mb.show()
271
- visual.hide()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/train_multi.py DELETED
@@ -1,263 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/B2. Training (Lightning).ipynb.
2
-
3
- # %% auto 0
4
- __all__ = []
5
-
6
- # %% ../nbs/B2. Training (Lightning).ipynb 2
7
- import io
8
- import time
9
- import random
10
- from pathlib import Path
11
-
12
- from fastprogress import progress_bar, master_bar
13
- import fastprogress
14
- import wandb
15
-
16
- import numpy as np
17
- import pylab as plt
18
-
19
- import torch
20
- import torch.nn as nn
21
- from torch.utils.data.dataloader import DataLoader
22
- from torch.profiler import record_function
23
-
24
- # %% ../nbs/B2. Training (Lightning).ipynb 3
25
- import lightning.pytorch as pl
26
- import math
27
-
28
- class TrainingTask(pl.LightningModule):
29
- def __init__(self, model, model_hparams=None):
30
- super().__init__()
31
- self.model = model
32
- self.model_hparams = model_hparams
33
-
34
- def on_fit_start(self):
35
- if getattr(self.model, 'setup'):
36
- self.model.setup(self.device)
37
-
38
- def configure_optimizers(self):
39
- """ Initialize AdamW optimizer"""
40
- lr = self.model_hparams['lr0']
41
- weight_decay = self.model_hparams['weight_decay']
42
-
43
- all_params = set(model.parameters())
44
- customized_params = set()
45
- groups = []
46
- group_map = {}
47
- for name,m in model.named_modules():
48
- if hasattr(m, 'no_weight_decay') or hasattr(m, 'lr_scale'):
49
- customized_params |= set(m.parameters())
50
- m_wd = 0 if hasattr(m, 'no_weight_decay') else weight_decay
51
- m_lr = lr * getattr(m, 'lr_scale', 1)
52
- group = group_map.get((m_wd, m_lr), None)
53
- if not group:
54
- group = {"params": [], "names": [], "weight_decay": m_wd, "lr": m_lr}
55
- groups.append(group)
56
- group_map[(m_wd, m_lr)] = group
57
- group['params'] += m.parameters()
58
- group['names'].append(name)
59
-
60
- other_params = all_params - customized_params
61
-
62
- param_groups = groups + [
63
- {"names": ["other"], "params": list(other_params), "weight_decay": weight_decay },
64
- ]
65
-
66
- optimizer = torch.optim.AdamW(lr=lr, betas=(0.9, 0.95), params=param_groups)
67
-
68
- # modified from https://github.com/Lightning-AI/lightning/issues/5449#issuecomment-1501597319
69
- def num_steps_per_epoch() -> int:
70
- """Get number of steps"""
71
- # Accessing _data_source is flaky and might break
72
- dataset = self.trainer.fit_loop._data_source.dataloader()
73
- dataset_size = len(dataset)
74
- # math.ceil so always overestimate (underestimating throws exceptions)
75
- num_steps = math.ceil(dataset_size / self.trainer.accumulate_grad_batches)
76
- return num_steps
77
-
78
- total_steps = self.model_hparams['epochs'] * num_steps_per_epoch()
79
- self.model_hparams['pct_start'] = min(0.3, self.model_hparams['warmup_steps'] / total_steps)
80
-
81
- print(f"{self.model_hparams['epochs']=} epochs x {num_steps_per_epoch()=} steps")
82
-
83
- lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
84
- optimizer,
85
- pct_start=self.model_hparams['pct_start'],
86
- max_lr=[pg.get('lr', lr) for pg in param_groups],
87
- steps_per_epoch=num_steps_per_epoch(),
88
- epochs=int(self.model_hparams['epochs']),
89
- final_div_factor=25
90
- )
91
-
92
- return [optimizer], [{'scheduler': lr_scheduler, 'interval': 'step'}]
93
-
94
- def training_step(self, train_batch, batch_idx):
95
- train_logits, train_loss = self.model.forward(*train_batch)
96
-
97
- self.log("train_loss", train_loss, sync_dist=True)
98
- return train_loss
99
-
100
- def validation_step(self, val_batch, batch_idx):
101
- val_logits, val_loss = self.model.forward(*val_batch)
102
-
103
- self.log("val_loss", val_loss, sync_dist=True)
104
- return val_loss
105
-
106
- def on_validation_epoch_end(self):
107
- if hasattr(self.model, 'get_metrics'):
108
- self.log_dict({'metrics/'+k:v for k,v in self.model.get_metrics().items()}, sync_dist=True)
109
-
110
- def test_step(self, val_batch, batch_idx):
111
- test_logits, test_loss = self.model.forward(*val_batch)
112
-
113
- self.log("test_loss", test_loss, sync_dist=True)
114
- return test_loss
115
-
116
- # %% ../nbs/B2. Training (Lightning).ipynb 4
117
- from fastcore.script import anno_parser
118
- import shlex
119
-
120
- # watch out: we can only pass Python values as keyword arguments (not positional)
121
- # everything else has to be a string
122
- def parse_and_call(name, fun, args, kwargs={}, log_to_wandb=True):
123
- p = anno_parser(fun)
124
- args = p.parse_args(args).__dict__
125
- args.pop('xtra'); args.pop('pdb')
126
- args.update({k:v for k, v in kwargs.items()})
127
- if log_to_wandb and type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
128
- wandb_logger.experiment.config[name] = {k:v for k,v in args.items() if k not in ['dataset', 'tunables']}
129
- return fun(**args)
130
-
131
- # %% ../nbs/B2. Training (Lightning).ipynb 8
132
- import argparse
133
-
134
- parser = argparse.ArgumentParser()
135
- parser.add_argument('--task', type=str, help='Task to train')
136
- parser.add_argument('--seed', type=int, default=0, help='Global training seed')
137
- parser.add_argument('--batch-size', type=int, default=16, help='total batch size for all GPUs')
138
- parser.add_argument('--workers', type=int, default=8, help='max dataloader workers (per RANK in DDP mode)')
139
- parser.add_argument('--input-dir', type=str, default='', help='input data path') # fixed in the model for now
140
- parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints/", help="directory to save the checkpoints")
141
- parser.add_argument('--epochs', type=int, default=10, help='total training epochs')
142
- parser.add_argument('--validate-every-n-steps', type=int, default=500, help='how training steps to run between validations')
143
- parser.add_argument('--weight-decay', type=float, default=1e-2, help='optimizer weight decay')
144
- parser.add_argument('--lr0', type=float, default=1e-4, help='optimizer initial learning rate')
145
- parser.add_argument('--clip-gradient-norm', type=float, default=None, help='enable gradient norm clipping')
146
- parser.add_argument('--accumulate-grad-batches', type=int, default=1, help='perform the optimizer step only after going through several batches of samples')
147
- parser.add_argument('--precision', type=str, default="16-mixed", help="floating point precision")
148
- parser.add_argument('--warmup-steps', type=int, default=10000, help='total number steps during which the learning rate rises (defaults to 10k updates)')
149
- parser.add_argument('--tunables', type=str, default="", help='tunable hyperparameters')
150
- parser.add_argument('--resume-from', type=Path, default=None, help='resume training from the given checkpoint')
151
- parser.add_argument('--strategy', type=str, default='ddp', help='distributed training strategy')
152
- parser.add_argument('--wandb-suffix', type=str, default=None, help='W&B project name suffix')
153
- parser.add_argument('--wandb-task-name', type=str, default=None, help='Task name for the W&B project name')
154
-
155
- args = parser.parse_args().__dict__
156
-
157
- task_args: list = shlex.split(args.pop("task"))
158
- task_name, task_args = task_args[0], task_args[1:]
159
- input_args: list = shlex.split(args.pop("input_dir"))
160
- checkpoint_dir: str = args.pop("checkpoint_dir")
161
- num_workers: int = args.pop("workers")
162
- batch_size: int = args.pop("batch_size")
163
- epochs: int = args.pop("epochs")
164
- tunables_args: list = shlex.split(args.pop("tunables"))
165
-
166
- hyp_params = {}
167
- hyp_params['batch_size'] = batch_size
168
- hyp_params['warmup_steps'] = args['warmup_steps']
169
- hyp_params['weight_decay'] = args['weight_decay']
170
- hyp_params['clip_gradient_norm'] = args['clip_gradient_norm']
171
- hyp_params['accumulate_grad_batches'] = args['accumulate_grad_batches']
172
- hyp_params['precision'] = args['precision']
173
- hyp_params['lr0'] = args['lr0']
174
- hyp_params['epochs'] = epochs
175
- hyp_params['strategy'] = args['strategy']
176
-
177
- # %% ../nbs/B2. Training (Lightning).ipynb 9
178
- from lightning.pytorch.loggers import WandbLogger
179
- from lightning.pytorch.callbacks import LearningRateMonitor
180
- import datetime
181
- import webdataset as wds
182
- import importlib
183
-
184
- torch.set_float32_matmul_precision('medium')
185
-
186
- project = f"WhisperSpeech-{args['wandb_task_name'] or task_name}"
187
- if args['wandb_suffix']:
188
- project += "-"+args['wandb_suffix']
189
-
190
- wandb_logger = WandbLogger(project=project)
191
-
192
- ckpt_callback = pl.callbacks.ModelCheckpoint(
193
- dirpath=f'{task_name}-{epochs}e',
194
- filename=task_name+"-{epoch}-{step}-{val_loss:.2f}",
195
- monitor="val_loss",
196
- save_top_k=4,
197
- train_time_interval=datetime.timedelta(minutes=5),
198
- )
199
-
200
- lr_monitor_callback = LearningRateMonitor(logging_interval='step')
201
-
202
- from torch.utils.data import DataLoader
203
-
204
- task = importlib.import_module("whisperspeech."+task_name)
205
-
206
- train_ds, val_ds = parse_and_call('dataset', task.load_datasets, input_args)
207
-
208
- tunables = None
209
- if hasattr(task, "Tunables"):
210
- import dataclasses
211
- tunables = parse_and_call('tunables', task.Tunables, tunables_args, log_to_wandb=False)
212
- if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
213
- wandb_logger.experiment.config['tunables'] = dataclasses.asdict(tunables)
214
-
215
- for name in ["lr0", "clip_gradient_norm", "weight_decay", "warmup_steps"]:
216
- val = getattr(tunables, name, None)
217
- if val is not None: hyp_params[name] = val
218
-
219
- if isinstance(train_ds, torch.utils.data.IterableDataset):
220
- dl_batch_size, dl_shuffle = None, False
221
- pin_memory = False
222
- else:
223
- dl_batch_size, dl_shuffle = batch_size, True
224
- pin_memory = True
225
-
226
- val_loader = wds.WebLoader(val_ds,
227
- batch_size=dl_batch_size,
228
- num_workers=num_workers,
229
- drop_last=False,
230
- pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(val_ds.total_samples // batch_size)
231
-
232
- train_loader = wds.WebLoader(train_ds,
233
- batch_size=dl_batch_size,
234
- num_workers=num_workers,
235
- drop_last=False,
236
- shuffle=dl_shuffle,
237
- pin_memory=pin_memory).unbatched().shuffle(1024).batched(batch_size).with_length(train_ds.total_samples // batch_size)
238
-
239
- model_kwargs = dict(dataset=train_ds)
240
- if tunables is not None: model_kwargs['tunables'] = tunables
241
- model = parse_and_call('model', task.make_model, task_args, model_kwargs)
242
-
243
- task = TrainingTask(model, model_hparams=hyp_params)
244
-
245
- trainer = pl.Trainer(strategy=hyp_params['strategy'],
246
- max_epochs=hyp_params['epochs'],
247
- accelerator="gpu",
248
- profiler="simple",
249
- precision=hyp_params['precision'],
250
- gradient_clip_val=hyp_params['clip_gradient_norm'],
251
- accumulate_grad_batches=hyp_params['accumulate_grad_batches'],
252
- val_check_interval=args.pop("validate_every_n_steps"),
253
- enable_checkpointing=True,
254
- logger=wandb_logger,
255
- callbacks=[ckpt_callback, lr_monitor_callback])
256
-
257
- if type(wandb_logger.experiment.config) == wandb.sdk.wandb_config.Config:
258
- wandb_logger.experiment.config.update(hyp_params)
259
-
260
- kwargs = {}
261
- if 'resume_from' in args:
262
- kwargs['ckpt_path'] = args['resume_from']
263
- trainer.fit(model=task, train_dataloaders=train_loader, val_dataloaders=val_loader, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/utils.py DELETED
@@ -1,159 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/D. Common dataset utilities.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['shard_glob', 'join_datasets', 'resampler', 'derived_name', 'derived_dataset', 'merge_in', 'AtomicTarWriter',
5
- 'readlines']
6
-
7
- # %% ../nbs/D. Common dataset utilities.ipynb 1
8
- import os
9
- import torch
10
- import torchaudio
11
- from pathlib import Path
12
- import webdataset as wds
13
- from contextlib import contextmanager
14
-
15
- import torch.nn.functional as F
16
-
17
- # %% ../nbs/D. Common dataset utilities.ipynb 2
18
- def shard_glob(input):
19
- if '{' in input:
20
- return wds.shardlists.expand_urls(input)
21
- if isinstance(input, (Path, str)):
22
- path = Path(input)
23
- if path.is_dir():
24
- glob = '*.tar.gz'
25
- else:
26
- glob = path.name
27
- path = path.parent
28
- input = Path(path).glob(glob)
29
- else:
30
- raise ArgumentError("input should be either a list or a path with an optional glob specifier")
31
- return [str(x) for x in input]
32
-
33
- # %% ../nbs/D. Common dataset utilities.ipynb 3
34
- class join_datasets(torch.utils.data.IterableDataset):
35
- def __init__(self, datasets):
36
- self.datasets = datasets
37
-
38
- def __iter__(self):
39
- probs = torch.tensor([getattr(ds, 'weight', 1) for ds in self.datasets], dtype=torch.float)
40
- its = [iter(ds) for ds in self.datasets]
41
- while True:
42
- try:
43
- yield next(its[torch.multinomial(probs, 1)])
44
- except StopIteration:
45
- return
46
-
47
- def __len__(self):
48
- return sum([ds.total_samples for ds in self.datasets])
49
-
50
- # %% ../nbs/D. Common dataset utilities.ipynb 5
51
- def resampler(newsr = 24000, key = 'samples_24k'):
52
- _last_sr = None
53
- tform = None
54
-
55
- def _resample(samples):
56
- for s in samples:
57
- sr = s['sample_rate']
58
- if sr != newsr:
59
- if sr != _last_sr: tform = torchaudio.transforms.Resample(sr, newsr)
60
- s[key] = tform(s['samples'])
61
- else:
62
- s[key] = s['samples']
63
- yield s
64
-
65
- return _resample
66
-
67
- # %% ../nbs/D. Common dataset utilities.ipynb 6
68
- def derived_name(input, kind, base="audio", suffix=".gz", dir=None):
69
- dir = Path(dir) if dir else Path(input).parent
70
- return str(dir/(Path(input).name.replace(f"-{base}-", f"-{kind}-") + suffix))
71
-
72
- # %% ../nbs/D. Common dataset utilities.ipynb 7
73
- def derived_dataset(kind, base='audio', suffix=".gz", decoders=[], dir=None):
74
- def deriver(url):
75
- url = str(derived_name(url, kind, base=base, suffix=suffix, dir=dir))
76
- return wds.WebDataset(
77
- wds.SimpleShardList([url])
78
- ).decode(*decoders)
79
- return deriver
80
-
81
- # %% ../nbs/D. Common dataset utilities.ipynb 8
82
- def merge_in(dataset_fun):
83
- """Merge a dataset into the current one returning samples with the union of keys. Pass in a function
84
- that takes a URL of a sample and returns a dataset for it (called everytime the URL changes).
85
-
86
- It requires (and validates) that both datasets have the same ordering of keys so you have
87
- to use it before any sample shuffling. Shard shuffling is ok.
88
- """
89
- def merge_loop(main_samples):
90
- #print("new merge loop:", dataset_fun)
91
- merged_samples = None
92
- cur_url = None
93
- i = None
94
- for s in main_samples:
95
- url = s['__url__']
96
- if url != cur_url:
97
- # this will open a new file when we get the first sample with a new __url__
98
- merged_samples = iter(dataset_fun(url))
99
- cur_url = url
100
- try:
101
- merge_s = next(merged_samples)
102
- except StopIteration:
103
- # if the original shard got repeated we won't observe a __url__ change
104
- # in this case restart the dataset from the beginning
105
- merged_samples = iter(dataset_fun(url))
106
- merge_s = next(merged_samples)
107
- assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
108
- news = {}
109
- news.update(merge_s)
110
- news.update(s)
111
- yield news
112
- return merge_loop
113
-
114
- # %% ../nbs/D. Common dataset utilities.ipynb 9
115
- def split_to_chunks(stream, ikey='vad.npy', metakeys=[], pad_to_seconds=30, random_shift=False):
116
- for s in stream:
117
- audio, sr = s['audio']
118
- imax = len(s[ikey]) - 1
119
- for i,(ts,te) in enumerate(s[ikey]):
120
- samples = audio[0,int(ts*sr):int(te*sr)]
121
- if pad_to_seconds is not None:
122
- padding = pad_to_seconds*sr-samples.shape[-1]
123
- lpad = random.randint(0, padding) if random_shift else 0
124
- samples = F.pad(samples, (lpad, padding-lpad))
125
- subs = {"__key__": s['__key__'] + f"_{i:03d}",
126
- "src_key": s['__key__'],
127
- "__url__": s['__url__'],
128
- "i": i, "imax": imax,
129
- "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
130
- "lpad": lpad, "rpad": padding-lpad,
131
- "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
132
- "samples": samples, "sample_rate": sr}
133
- for k in metakeys:
134
- subs[k] = s[k][i]
135
- yield subs
136
-
137
- # %% ../nbs/D. Common dataset utilities.ipynb 10
138
- def vad_dataset(shards, ikey='vad.npy', kind='vad'):
139
- return wds.WebDataset(shards).compose(
140
- wds.decode(wds.torch_audio),
141
- merge_in(derived_dataset(kind)),
142
- wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
143
- wds.rename(audio="flac;mp3;wav;ogg"),
144
- lambda x: split_to_chunks(x, ikey=ikey),
145
- )
146
-
147
- # %% ../nbs/D. Common dataset utilities.ipynb 11
148
- @contextmanager
149
- def AtomicTarWriter(name, throwaway=False):
150
- tmp = name+".tmp"
151
- with wds.TarWriter(tmp, compress=name.endswith('gz')) as sink:
152
- yield sink
153
- if not throwaway:
154
- os.rename(tmp, name)
155
-
156
- # %% ../nbs/D. Common dataset utilities.ipynb 12
157
- def readlines(fname):
158
- with open(fname) as file:
159
- return [line.rstrip() for line in file]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/vad.py DELETED
@@ -1,71 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/1B. Voice activity detection.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = []
5
-
6
- # %% ../nbs/1B. Voice activity detection.ipynb 3
7
- import os
8
- import torch
9
- import torchaudio
10
-
11
- from pathlib import Path
12
- from fastprogress import progress_bar
13
- from fastcore.script import call_parse
14
-
15
- import whisperx
16
- import random
17
- import numpy as np
18
- import webdataset as wds
19
-
20
- # %% ../nbs/1B. Voice activity detection.ipynb 5
21
- # some of the original file names have a dot in their name
22
- # webdataset does not like it so let's patch it
23
- def fix_dots_in_names(name):
24
- name, ext = name.rsplit('.', 1)
25
- return ".".join((name.replace('.', '_'), ext))
26
-
27
- def load_dataset(url, decode=True, rename_files=None):
28
- ds = wds.WebDataset(url, rename_files=rename_files)
29
- if not decode: return ds
30
- return ds.decode(wds.torch_audio)
31
-
32
- # %% ../nbs/1B. Voice activity detection.ipynb 7
33
- def extract_segments(vad_result, max_duration):
34
- binarize = whisperx.vad.Binarize(max_duration=max_duration)
35
- segments = binarize(vad_result)
36
- return [(x.start, x.end) for x in segments.get_timeline()]
37
-
38
- def segment_audio(vad_model, audio, sr=16000):
39
- vad_result = vad_model({"waveform": audio, "sample_rate": sr})
40
- return extract_segments(vad_result, 30)
41
-
42
- # %% ../nbs/1B. Voice activity detection.ipynb 13
43
- def flac_to_vad_name(input):
44
- if '-flac-' in input:
45
- return input.rsplit("/", 1)[1].replace('flac', 'vad') + ".gz"
46
- else:
47
- return input.rsplit("/", 1)[1].replace('raw', 'vad') + ".gz"
48
-
49
- @call_parse
50
- def process_shard(
51
- input:str, # input shard URL/path
52
- output:str=None, # output shard URL/path
53
- fix_dots:bool=False, # fix dots in LibriLight filenames
54
- ):
55
- if output is None: output = flac_to_vad_name(input)
56
-
57
- ds = torch.utils.data.DataLoader(load_dataset(input, rename_files=fix_dots_in_names if fix_dots else None), num_workers=2, batch_size=None)
58
- vad_model = whisperx.vad.load_vad_model('cuda')
59
-
60
- tmp = output+".tmp"
61
- with wds.TarWriter(tmp) as sink:
62
- for s in progress_bar(ds, total='noinfer'):
63
- audio, sr = s.get('flac', s.get('wav', (None, None)))
64
- if audio is None:
65
- print(f"warning: '{s['__key__']}' does not contain an audio file")
66
- continue
67
- sink.write({
68
- "__key__": s['__key__'],
69
- "vad.npy": np.array(segment_audio(vad_model, audio, sr=sr), dtype=np.float16)
70
- })
71
- os.rename(tmp, output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/vq_stoks.py DELETED
@@ -1,493 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2B. Whisper quantization (semantic token) model.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['RQBottleneckTransformer', 'make_model']
5
-
6
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 2
7
- import io
8
- import sys
9
- import time
10
- import torch
11
- import torchaudio
12
-
13
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 3
14
- from pathlib import Path
15
- import json
16
- from fastprogress import progress_bar, master_bar
17
- import fastprogress
18
- import numpy as np
19
- import pylab as plt
20
- import pandas as pd
21
- import random
22
-
23
- import whisper
24
- from huggingface_hub import hf_hub_download
25
- from fastcore.basics import store_attr
26
-
27
- from torch import nn
28
- import torch.optim as optim
29
- import torch.nn.functional as F
30
- from torch.utils.data.dataloader import DataLoader
31
- import webdataset as wds
32
- from . import utils
33
-
34
- from vector_quantize_pytorch import ResidualVQ
35
-
36
- from fastcore.script import *
37
-
38
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 9
39
- def merge_in(dataset_fun):
40
- """Merge a dataset into the current one returning samples with the union of keys. Pass in a function
41
- that takes a URL of a sample and returns a dataset for it (called everytime the URL changes).
42
-
43
- It requires (and validates) that both datasets have the same ordering of keys so you have
44
- to use it before any sample shuffling. Shard shuffling is ok.
45
- """
46
- def merge_loop(main_samples):
47
- #print("new merge loop:", dataset_fun)
48
- merged_samples = None
49
- cur_url = None
50
- i = None
51
- for s in main_samples:
52
- url = s['__url__']
53
- if url != cur_url:
54
- # this will open a new file when we get the first sample with a new __url__
55
- merged_samples = iter(dataset_fun(url))
56
- cur_url = url
57
- try:
58
- merge_s = next(merged_samples)
59
- except StopIteration:
60
- # if the original shard got repeated we won't observe a __url__ change
61
- # in this case restart the dataset from the beginning
62
- merged_samples = iter(dataset_fun(url))
63
- merge_s = next(merged_samples)
64
- assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
65
- news = {}
66
- news.update(merge_s)
67
- news.update(s)
68
- yield news
69
- return merge_loop
70
-
71
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 10
72
- def derived_dataset(kind, key='audio'):
73
- def deriver(url):
74
- url = str(Path(url).parent/(Path(url).name.replace(key, kind) + ".gz"))
75
- return wds.WebDataset(
76
- wds.SimpleShardList([url])
77
- ).decode()
78
- return deriver
79
-
80
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 17
81
- def add_masks(samples):
82
- for s in samples:
83
- seconds = s['tend'] - s['tstart']
84
- # a mask (downsampled to the Whisper encoder token rate of 50/s) is used
85
- # to teach the model the concept of padding
86
- # this let's us decode shorter sequences later
87
- mask = torch.zeros(30*16000//320, dtype=torch.bool)
88
- mask[:int(seconds * 16000) // 320] = 1
89
- s['mask'] = mask
90
- yield s
91
-
92
- def tokenize_text(samples, ttoks_size=200, model="base.en", language="en"):
93
- multilingual = not model.endswith(".en")
94
- tokenizer = whisper.tokenizer.get_tokenizer(multilingual, language=language, task="transcribe")
95
- for s in samples:
96
- ttoks = tokenizer.encode(s['txt'])
97
- tokens = list(tokenizer.sot_sequence) + ttoks
98
- rpad = ttoks_size - len(tokens)
99
- s['in_ttoks'] = F.pad(torch.tensor(tokens), (0, rpad), value=tokenizer.eot)
100
- s['out_ttoks'] = F.pad(torch.tensor(tokens[1:] + [tokenizer.eot]), (0, rpad), value=-100)
101
- yield s
102
-
103
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 22
104
- def load_dataset(
105
- shard_spec:str,
106
- proc_dataset_path:Path, # processed VAD and txt files
107
- samples:int, # set the per-GPU sample count
108
- txt_label:str="base.en-txt", # the label of the files containing transcriptions
109
- model:str="base.en",
110
- key:str="flac",
111
- language:str=None,
112
- validation:bool=False,
113
- ):
114
- from . import wh_transcribe
115
- shards = utils.shard_glob(shard_spec)
116
-
117
- if not language and model.endswith('en'): language = 'en'
118
- assert language, "please provide the dataset language for multilang models"
119
-
120
- same_on_all_nodes = lambda urls: urls # will only be used for validation
121
- ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
122
- wds.decode(wds.torch_audio),
123
- wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
124
- wds.rename(audio="flac;mp3;wav;ogg"),
125
- merge_in(derived_dataset(proc_dataset_path, 'vad', key=key)),
126
- wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
127
- wh_transcribe.split_to_chunks,
128
- utils.resampler(16000, 'samples_16k'),
129
- merge_in(derived_dataset(proc_dataset_path, txt_label, key=key)),
130
- )
131
- if 'librilight' in shards[0]:
132
- ds = ds.compose(
133
- # drop the first and last segment because they tend to be inaccurate
134
- # (the transcriptions don't have the "LibriVox" headers and "end of chapter" suffixes)
135
- wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
136
- )
137
- ds = ds.compose(
138
- add_masks,
139
- lambda x: tokenize_text(x, model=model, language=language),
140
- wds.to_tuple('samples_16k', 'mask', 'in_ttoks', 'out_ttoks'),
141
- wds.batched(32),
142
- )
143
- ds.total_samples = samples
144
-
145
- return ds
146
-
147
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 28
148
- from whisperspeech.train import *
149
- from whisperspeech.modules import *
150
-
151
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 29
152
- import dataclasses
153
-
154
- def rand(start, end):
155
- return random.random() * (end - start) + start
156
-
157
- def logrand(start, end):
158
- return 10**rand(math.log10(start), math.log10(end))
159
-
160
- @dataclasses.dataclass
161
- class Tunables:
162
- init_std :float = 1.5
163
- embeddings_std :float = 4.5e-2
164
- embeddings_lr_scale: float = 1
165
- output_mult :float = 1
166
- query_mult :float = 2
167
- rope :bool = True
168
- mask_embs :bool = True # force embeddings corresponding to the input audio padding to a constant value
169
- downsample_conv: bool = False
170
- downsample_mean: bool = True
171
-
172
- codebook_dim: int = 32
173
- codebook_decay: float = 0.9
174
-
175
- lr0 :float = .9e-3
176
- clip_gradient_norm :float = 2
177
- weight_decay :float = 1e-3
178
- warmup_steps :float = 850
179
-
180
- random :bool = False
181
-
182
- def __post_init__(self):
183
- # randomize the hyperparams if requested
184
- if self.random:
185
- self.init_std = logrand(1, 2)
186
- self.embeddings_std = logrand(3e-2,6e-2)
187
- self.embeddings_lr_scale = 2**rand(0,3)
188
- self.output_mult = 2**rand(-3,3)
189
- self.query_mult = logrand(1,8)
190
- self.codebook_dim = int(logrand(30,50))
191
- self.codebook_decay = logrand(0.86,0.95)
192
- self.rope = True
193
- self.mask_embs = True
194
- self.downsample_mean = True
195
-
196
- self.lr0 = logrand(.8e-3,1e-3)
197
- self.clip_gradient_norm = 10**rand(-1,1)
198
- self.warmup_steps = logrand(700,1000)
199
-
200
- @staticmethod
201
- def upgrade(args):
202
- args = {k:v for k,v in args.items()}
203
- def old_default(name, value):
204
- if name not in args: args[name] = value
205
- old_default('output_mult', 1)
206
- old_default('query_mult', 1)
207
- old_default('rope', False)
208
- old_default('mask_embs', False)
209
- old_default('downsample_conv', False)
210
- old_default('downsample_mean', False)
211
- if 'encoder_depth_ratio' in args: del args['encoder_depth_ratio']
212
- if 'vq_codes' in args: del args['vq_codes']
213
- return args
214
-
215
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 30
216
- import math
217
-
218
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 31
219
- class RQBottleneckTransformer(nn.Module):
220
- def __init__(self, vq_codes=512, q_depth=12, depth=1, n_head=2, head_width=64, ffn_mult=4,
221
- codebook_dim=2, threshold_ema_dead_code=2, use_cosine_sim = False, kl_loss_mul=1,
222
- downsample=1,
223
- whisper_model_name='tiny.en', tunables=Tunables()):
224
- super().__init__()
225
- width = n_head * head_width
226
- store_attr("codebook_dim,vq_codes,q_depth,n_head,head_width,ffn_mult,depth,use_cosine_sim,downsample,whisper_model_name")
227
- self.width = width
228
- self.base_width = 3 * head_width
229
- self.vq_codes = vq_codes
230
- self.tunables = tunables
231
- self.stoks_len = 1500//downsample
232
- self.stoks_per_sec = self.stoks_len//30
233
-
234
- qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)
235
-
236
- self.kl_loss_mul = kl_loss_mul
237
-
238
- n_mlp = width * ffn_mult
239
- self.mlp = nn.Sequential(
240
- nn.Linear(width, n_mlp), nn.GELU(), nn.Linear(n_mlp, width)
241
- )
242
- self.mlp_ln = LayerNorm(width)
243
-
244
- if tunables.downsample_conv:
245
- self.downsample_conv = nn.Conv1d(width, width, kernel_size=3, stride=downsample, padding=1)
246
- else:
247
- self.downsample_conv = None
248
-
249
- if tunables.mask_embs: vq_codes = vq_codes + 1
250
- self.rq = ResidualVQ(
251
- dim = width,
252
- codebook_size = vq_codes, # codebook size
253
- decay = tunables.codebook_decay, # the exponential moving average decay, lower means the dictionary will change faster
254
- commitment_weight = 1., # the weight on the commitment loss
255
- threshold_ema_dead_code = threshold_ema_dead_code,
256
- use_cosine_sim = use_cosine_sim,
257
- codebook_dim = codebook_dim,
258
- num_quantizers= 1,
259
- )
260
-
261
- self.ce_lossf = nn.CrossEntropyLoss(ignore_index=-100)
262
- self.kl_lossf = nn.KLDivLoss(reduction='batchmean')
263
-
264
- self.positional_embedding = nn.Embedding(1500, width) # FIXME: should be self.stoks_len
265
-
266
- self.out_blocks = nn.Sequential(*[
267
- ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(depth)
268
- ])
269
- self.ln_post = LayerNorm(width)
270
-
271
- self.whmodel = None
272
-
273
- self.apply(self.init_transformer)
274
- self.register_buffer('val_true', torch.zeros(1).cuda())
275
- self.register_buffer('val_total', torch.zeros(1).cuda())
276
-
277
- def setup(self, device):
278
- self.ensure_whisper(device)
279
-
280
- def init_transformer(self, m):
281
- if isinstance(m, LinearHead):
282
- m.no_weight_decay = True
283
- torch.nn.init.constant_(m.weight, 0)
284
- elif isinstance(m, QueryHead):
285
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
286
- torch.nn.init.constant_(m.weight, 0)
287
- elif isinstance(m, nn.Embedding):
288
- m.no_weight_decay = True
289
- m.lr_scale = self.tunables.embeddings_lr_scale
290
- std = self.tunables.embeddings_std
291
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
292
- elif isinstance(m, nn.Linear):
293
- m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
294
- std = self.tunables.init_std / m.weight.shape[1]
295
- torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
296
- if m.bias is not None:
297
- torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
298
- elif isinstance(m, nn.LayerNorm):
299
- m.no_weight_decay = True
300
- torch.nn.init.constant_(m.bias, 0)
301
- torch.nn.init.constant_(m.weight, 1)
302
-
303
- @property
304
- def device(self):
305
- return next(self.parameters()).device
306
-
307
- #
308
- # training
309
- #
310
- @torch.no_grad()
311
- def extract_teacher(self, samples, input_toks, output_toks):
312
- embs = self.whmodel[0].encoder(whisper.log_mel_spectrogram(samples))
313
- teacher_logits = self.whmodel[0].decoder(input_toks, embs)
314
- # set teacher logits to 0 for padding positions so KLDivLoss ignores them
315
- teacher_logits[output_toks == -100] = 0
316
- return embs, teacher_logits
317
-
318
- def downsample_embeddings(self, x):
319
- if self.downsample_conv is not None:
320
- return x[:,::self.downsample] + self.downsample_conv(x.transpose(-1,-2)).transpose(-2,-1)
321
- elif self.tunables.downsample_mean:
322
- bs,slen,depth = x.shape
323
- return x.reshape(bs,slen//self.downsample,self.downsample,depth).mean(-2)
324
- else:
325
- return x[:,::self.downsample]
326
-
327
- def forward(self, samples, mask, input_toks, output_toks):
328
- embs, teacher_logits = self.extract_teacher(samples, input_toks, output_toks)
329
-
330
- x = self.downsample_embeddings(embs)
331
- x = x + self.mlp(self.mlp_ln(x))
332
- # VQ bottleneck
333
- quantized, self.indices, self.commit_loss = self.rq(x)
334
- self.commit_loss = self.commit_loss.mean()
335
-
336
- x = quantized.repeat_interleave(self.downsample, -2)
337
- project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
338
- if self.tunables.mask_embs: x[~mask] = project_out(self.rq.layers[0]._codebook.embed[0,self.vq_codes])
339
- positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
340
- x = x + self.positional_embedding(positions)
341
- x = self.ln_post(self.out_blocks(x))
342
-
343
- logits = self.whmodel[0].decoder(input_toks, x)
344
- self.ce_loss = self.ce_lossf(logits.view(-1,logits.shape[-1]), output_toks.view(-1))
345
- self.kl_loss = self.kl_lossf(F.log_softmax(logits, dim=-1), F.softmax(teacher_logits, dim=-1))
346
- loss = self.ce_loss + self.kl_loss_mul * self.kl_loss + self.commit_loss
347
-
348
- if not self.training:
349
- valid_toks = output_toks != -100
350
- self.val_true += (logits.argmax(-1)[valid_toks] == output_toks[valid_toks]).float().sum()
351
- self.val_total += valid_toks.float().sum()
352
-
353
- return x, loss
354
-
355
- def get_metrics(self):
356
- metrics = {
357
- 'acc_0': (self.val_true / self.val_total).item(),
358
- }
359
- self.val_true[:] = 0
360
- self.val_total[:] = 0
361
- return metrics
362
-
363
- #
364
- # inference
365
- #
366
- @classmethod
367
- def load_model(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
368
- repo_id=None, filename=None, local_filename=None):
369
- if repo_id is None and filename is None and local_filename is None:
370
- if ":" in ref:
371
- repo_id, filename = ref.split(":", 1)
372
- else:
373
- local_filename = ref
374
- if not local_filename:
375
- local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
376
- spec = torch.load(local_filename)
377
- vqmodel = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))))
378
- vqmodel.load_state_dict(spec['state_dict'])
379
- vqmodel.eval()
380
- return vqmodel
381
-
382
- def load_checkpoint(self, local_filename):
383
- spec = torch.load(local_filename, map_location='cpu')
384
- assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
385
- state_dict = {k.replace('model.', ''):v
386
- for k,v in spec['state_dict'].items()}
387
- self.load_state_dict(state_dict)
388
- return self
389
-
390
- def save_model(self, fname, store_parameters=True):
391
- torch.save(dict(config = self.__stored_args__,
392
- tunables = dataclasses.asdict(self.tunables),
393
- state_dict = self.state_dict() if store_parameters else None), fname)
394
-
395
- def ensure_whisper(self, device):
396
- # the list wrapper is a hack to make sure the whole of Whisper is not sucked into self.parameters()
397
- if self.whmodel is None: self.whmodel = [whisper.load_model(self.whisper_model_name, device=device)]
398
- self.decoding_options = whisper.DecodingOptions()
399
- multilingual = not self.whisper_model_name.endswith('.en')
400
- self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)
401
-
402
- def quantize(self, embs):
403
- x = self.downsample_embeddings(embs)
404
- x = x + self.mlp(self.mlp_ln(x))
405
- _, stoks, _ = self.rq(x)
406
- if self.q_depth == 1:
407
- stoks = stoks.squeeze(-1)
408
- return stoks
409
-
410
- def dequantize(self, stoks):
411
- assert self.q_depth == 1
412
- assert len(stoks.shape) == 1, "batch processing is not supported"
413
- if isinstance(stoks, np.ndarray): stoks = torch.tensor(stoks)
414
- # remove padding
415
- padding = torch.nonzero(stoks == self.vq_codes)
416
- if padding.any(): stoks = stoks[:padding[0,0]]
417
- stoks = F.pad(stoks, (0,self.stoks_len - stoks.shape[-1]), value=self.vq_codes if self.tunables.mask_embs else 0)
418
- x = self.rq.layers[0]._codebook.embed[0,stoks.to(torch.long).view(-1)]
419
- x = x.repeat_interleave(self.downsample, -2)
420
- project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
421
- x = project_out(x).unsqueeze(0)
422
- positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
423
- x = x + self.positional_embedding(positions)
424
- return self.ln_post(self.out_blocks(x))
425
-
426
- def encode_audio(self, audio):
427
- if isinstance(audio, str):
428
- x, sr = torchaudio.load(audio)
429
- x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
430
- audio = x.unsqueeze(0)
431
- return self.encode_mel(whisper.log_mel_spectrogram(audio).to(self.device))
432
-
433
- def encode_mel(self, mel):
434
- assert len(mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
435
- self.ensure_whisper(self.device)
436
- n = mel.shape[-1]
437
- if n > whisper.audio.N_FRAMES:
438
- padding = 0
439
- padded = mel[:,:,:whisper.audio.N_FRAMES]
440
- else:
441
- padding = -n % whisper.audio.N_FRAMES
442
- padded = F.pad(mel, (0, padding), value=-1.5)
443
- embs = self.whmodel[0].encoder(padded)#.to(self.whmodel[0].device))#[:,:n//2]
444
- stoks = self.quantize(embs)
445
- if self.tunables.mask_embs:
446
- return stoks[:,:n//2//self.downsample]
447
- else:
448
- return stoks
449
-
450
- def decode_text(self, stoks, decoding_options=None):
451
- self.ensure_whisper(self.device)
452
- if decoding_options is None: decoding_options = self.decoding_options
453
- embs = self.dequantize(stoks).to(self.whmodel[0].device)
454
- return self.whmodel[0].decode(embs, decoding_options)
455
-
456
- # %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 33
457
- def make_model(size:str, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
458
- if size == 'base.en-2d-4096c':
459
- model = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
460
- downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
461
- whisper_model_name=size.split("-")[0], tunables=tunables)
462
- return model
463
- if size == 'base.en-2d-512c':
464
- model = RQBottleneckTransformer(codebook_dim=32, vq_codes=512, q_depth=1, n_head=8, depth=1,
465
- downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
466
- whisper_model_name=size.split("-")[0], tunables=tunables)
467
- return model
468
- if size == 'base.en-2d-512c-dim64':
469
- model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
470
- downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
471
- whisper_model_name=size.split("-")[0], tunables=tunables)
472
- return model
473
- if size == 'base-2d-512c-dim64':
474
- model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
475
- downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
476
- whisper_model_name=size.split("-")[0], tunables=tunables)
477
- return model
478
- if size == 'base-2d-1024c-dim64':
479
- model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=8, depth=1,
480
- downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
481
- whisper_model_name=size.split("-")[0], tunables=tunables)
482
- return model
483
- if size == 'medium-2d-512c-dim64':
484
- model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=16, depth=1,
485
- downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
486
- whisper_model_name=size.split("-")[0], tunables=tunables)
487
- return model
488
- if size == 'medium-2d-1024c-dim64':
489
- model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=16, depth=1,
490
- downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
491
- whisper_model_name=size.split("-")[0], tunables=tunables)
492
- return model
493
- raise ArgumentError(f"invalid model size: {size}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/wer_metrics.py DELETED
@@ -1,77 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/C. Word error rate metrics.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = ['librispeech_data', 'DfBuilder', 'WERStats']
5
-
6
- # %% ../nbs/C. Word error rate metrics.ipynb 2
7
- import jiwer
8
- from whisper_normalizer.english import EnglishTextNormalizer
9
-
10
- import torchaudio
11
- from pathlib import Path
12
- import pandas as pd
13
-
14
- # %% ../nbs/C. Word error rate metrics.ipynb 3
15
- engnorm = EnglishTextNormalizer()
16
- def whisper_normalize(x):
17
- if type(x) == list:
18
- return [engnorm(y) for y in x]
19
- else:
20
- return engnorm(x)
21
-
22
- default_transform = jiwer.transforms.Compose([
23
- jiwer.transforms.ToLowerCase(),
24
- jiwer.transforms.ExpandCommonEnglishContractions(),
25
- whisper_normalize,
26
- jiwer.transforms.RemoveMultipleSpaces(),
27
- jiwer.transforms.Strip(),
28
- jiwer.transforms.RemovePunctuation(),
29
- jiwer.transforms.ReduceToListOfListOfWords(),
30
- ])
31
-
32
- # %% ../nbs/C. Word error rate metrics.ipynb 5
33
- def librispeech_data(datadir, sample_rate=16000):
34
- for file in Path(datadir).rglob('*.txt'):
35
- for line in file.read_text().split('\n'):
36
- if not line: continue
37
- idx, text = line.split(" ", 1)
38
- x, sr = torchaudio.load((file.parent/idx).with_suffix('.flac'))
39
- if sr != sample_rate:
40
- x = torchaudio.transforms.Resample(sr, self.sample_rate)(x)
41
- yield x, text
42
-
43
- # %% ../nbs/C. Word error rate metrics.ipynb 6
44
- class DfBuilder:
45
- def __init__(self):
46
- self.data = {}
47
-
48
- def push(self, **kwargs):
49
- for k,v in kwargs.items():
50
- if k not in self.data:
51
- self.data[k] = [v]
52
- else:
53
- self.data[k].append(v)
54
-
55
- def df(self):
56
- return pd.DataFrame(self.data)
57
-
58
- # %% ../nbs/C. Word error rate metrics.ipynb 7
59
- class WERStats(DfBuilder):
60
- def __init__(self, transform=default_transform):
61
- super().__init__()
62
- self.reference_transform = transform
63
- self.hypothesis_transform = transform
64
-
65
- def push_sample(self, snd, gt_text, text, idx=None):
66
- if snd is not None: self.push(secs = snd.shape[-1]/16000)
67
- diff = jiwer.process_words(gt_text, text, reference_transform=self.reference_transform, hypothesis_transform=self.hypothesis_transform)
68
- self.push(
69
- idx = idx,
70
- gt_text = gt_text,
71
- text = text,
72
- wer = diff.wer,
73
- mer = diff.mer,
74
- wil = diff.wil,
75
- wip = diff.wip,
76
- )
77
- return diff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
whisperspeech/wh_transcribe.py DELETED
@@ -1,146 +0,0 @@
1
- # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2A. Whisper quantization dataset preparation.ipynb.
2
-
3
- # %% auto 0
4
- __all__ = []
5
-
6
- # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 3
7
- import os
8
- import io
9
- import time
10
- import torch
11
- import torchaudio
12
-
13
- # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 4
14
- from pathlib import Path
15
- import json
16
- from fastprogress import progress_bar, master_bar
17
- import numpy as np
18
- import random
19
-
20
- import whisper
21
-
22
- from torch import nn
23
- import torch.nn.functional as F
24
- from torch.utils.data.dataloader import DataLoader
25
-
26
- from fastcore.script import *
27
-
28
- from . import vad
29
- import webdataset as wds
30
-
31
- # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 9
32
- # let's make it a bit more conservative
33
- # with full 30 second chunks it sometimes misses a small part of the transcript
34
- def random_cutter(dur):
35
- if random.random() < 0.5:
36
- return dur > 28 * (random.random()*0.95+0.05)
37
- else:
38
- return dur > 28
39
-
40
- def chunk_merger(segments, should_cut=lambda x: x > 28):
41
- if len(segments) == 0: return segments
42
- curr_start = segments[0][0]
43
- curr_end = 0
44
- merged = []
45
-
46
- for ts,te in segments:
47
- if should_cut(te - curr_start) and curr_end - curr_start > 0:
48
- merged.append((curr_start, curr_end))
49
- curr_start = ts
50
- curr_end = te
51
- merged.append((curr_start, curr_end))
52
- return merged
53
-
54
- # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 18
55
- def merge_in(*datasets):
56
- """Merge multiple datasets into the current one returning samples with the union of keys.
57
-
58
- It requires (and validates) all datasets to have the same ordering of keys so you have
59
- to use it before any sample shuffling. Shard shuffling is ok.
60
- """
61
- def merge_loop(main_samples):
62
- for samples in zip(*[main_samples]+[iter(x) for x in datasets]):
63
- key = samples[0]['__key__']
64
- news = {}
65
- for s in samples:
66
- assert s['__key__'] == key
67
- news.update(s)
68
- yield news
69
- return merge_loop
70
-
71
- # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 19
72
- import copy
73
-
74
- # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 20
75
- # a workaround for https://github.com/webdataset/webdataset/issues/297
76
- # should be possible to use ds.compose here
77
- def wds_compose(ds, *args):
78
- ds = copy.copy(ds)
79
- ds.pipeline = copy.copy(ds.pipeline)
80
- for f in args:
81
- ds.append(f)
82
- return ds
83
-
84
- # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 24
85
- def split_to_chunks(stream, pad_to_seconds=30, random_shift=False):
86
- for s in stream:
87
- audio, sr = s.get('flac', s.get('wav', (None, None)))
88
- if audio is None:
89
- print(f"warning: '{s['__key__']}' does not contain an audio file")
90
- continue
91
- imax = len(s['vad.npy']) - 1
92
- for i,(ts,te) in enumerate(s['vad.npy']):
93
- samples = audio[0,int(ts*sr):int(te*sr)]
94
- if pad_to_seconds is not None:
95
- padding = pad_to_seconds*sr-samples.shape[-1]
96
- lpad = random.randint(0, padding) if random_shift else 0
97
- samples = F.pad(samples, (lpad, padding-lpad))
98
- yield {"__key__": s['__key__'] + f"_{i:03d}",
99
- "__url__": s['__url__'],
100
- "i": i, "imax": imax,
101
- "tstart": ts, "tend": te, "total_seconds": audio.shape[-1]/sr,
102
- "lpad": lpad, "rpad": padding-lpad,
103
- "lpad_s": lpad/sr, "rpad_s": (padding-lpad)/sr,
104
- "samples": samples, "sample_rate": sr}
105
-
106
- # %% ../nbs/2A. Whisper quantization dataset preparation.ipynb 38
107
- def flac_to_txt_name(input, model_size):
108
- return input.rsplit("/", 1)[1].replace('flac', f'{model_size}-txt') + ".gz"
109
-
110
- @call_parse
111
- def process_shard(
112
- input:str, # input shard URL/path
113
- output:str=None, # output shard URL/path
114
- bs:int=None, # batch size (16 uses around 11GB of VRAM)
115
- n_samples:int=None, # limit the number of samples (useful for quick benchmarking)
116
- whisper_model:str="base.en" # Whisper model size
117
- ):
118
- if output is None: output = flac_to_txt_name(input, whisper_model)
119
- if bs is None: bs = 16
120
- if n_samples is None: n_samples = 'noinfer'
121
- else: n_samples = n_samples // bs
122
-
123
- ds = wds_compose(vad.load_dataset(input),
124
- merge_in(wds.WebDataset(vad.flac_to_vad_name(input)).decode()),
125
- wds.map_dict(**{"vad.npy":chunk_merger}),
126
- split_to_chunks,
127
- wds.to_tuple('__key__', 'samples'),
128
- wds.batched(bs),
129
- )
130
- dl = DataLoader(ds, num_workers=2, batch_size=None)
131
-
132
- whmodel = whisper.load_model(whisper_model)
133
- decoding_options = whisper.DecodingOptions(language='en')
134
-
135
- tmp = output+".tmp"
136
- with wds.TarWriter(tmp) as sink:
137
- for keys, samples in progress_bar(dl, total=n_samples):
138
- with torch.no_grad():
139
- embs = whmodel.encoder(whisper.log_mel_spectrogram(samples).cuda())
140
- decs = whmodel.decode(embs, decoding_options)
141
- for key, dec in zip(keys, decs):
142
- sink.write({
143
- "__key__": key,
144
- "txt": dec.text,
145
- })
146
- os.rename(tmp, output)