anonymoussubmitter222 committed on
Commit
e63fe3d
1 Parent(s): 7d9fedb

added app file

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the complete set.
Files changed (50):
  1. .gitattributes +1 -0
  2. Untitled.ipynb +242 -0
  3. __pycache__/lm_tunisian.cpython-38.pyc +0 -0
  4. app.py +371 -0
  5. ctc_train.py +339 -0
  6. debugging.csv +11 -0
  7. file.wav +0 -0
  8. lm_decoded_ctc.py +376 -0
  9. lm_tunisian.py +361 -0
  10. partly_frozen_splitted_wavlm/1986/ctc_train.py +339 -0
  11. partly_frozen_splitted_wavlm/1986/env.log +402 -0
  12. partly_frozen_splitted_wavlm/1986/hyperparams.yaml +162 -0
  13. partly_frozen_splitted_wavlm/1986/lm_decoded_ctc.py +377 -0
  14. partly_frozen_splitted_wavlm/1986/lm_tunisian.py +361 -0
  15. partly_frozen_splitted_wavlm/1986/log.txt +0 -0
  16. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/CKPT.yaml +4 -0
  17. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/brain.ckpt +3 -0
  18. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/counter.ckpt +3 -0
  19. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/dataloader-TRAIN.ckpt +3 -0
  20. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/model.ckpt +3 -0
  21. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/modelopt.ckpt +3 -0
  22. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/scheduler_model.ckpt +3 -0
  23. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/scheduler_wav2vec.ckpt +3 -0
  24. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/wav2vec2.ckpt +3 -0
  25. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/wav2vec_opt.ckpt +3 -0
  26. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/CKPT.yaml +4 -0
  27. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/brain.ckpt +3 -0
  28. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/counter.ckpt +3 -0
  29. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/dataloader-TRAIN.ckpt +3 -0
  30. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/model.ckpt +3 -0
  31. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/modelopt.ckpt +3 -0
  32. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/scheduler_model.ckpt +3 -0
  33. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/scheduler_wav2vec.ckpt +3 -0
  34. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/wav2vec2.ckpt +3 -0
  35. partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/wav2vec_opt.ckpt +3 -0
  36. partly_frozen_splitted_wavlm/1986/save/label_encoder.txt +46 -0
  37. partly_frozen_splitted_wavlm/1986/train_log.txt +86 -0
  38. partly_frozen_splitted_wavlm/1986/wer_test.txt +0 -0
  39. partly_frozen_splitted_wavlm/1986/wer_test_salah.txt +61 -0
  40. partly_frozen_splitted_wavlm/1986/wer_test_salah_local.txt +21 -0
  41. partly_frozen_splitted_wavlm/ctc_train.py +339 -0
  42. partly_frozen_splitted_wavlm/env.log +379 -0
  43. partly_frozen_splitted_wavlm/hyperparams.yaml +162 -0
  44. partly_frozen_splitted_wavlm/log.txt +1998 -0
  45. partly_frozen_splitted_wavlm/save/label_encoder.txt +37 -0
  46. recording.webm +0 -0
  47. requirements.txt +16 -0
  48. running_tunisian.ipynb +0 -0
  49. samples/Salah1.wav +0 -0
  50. samples/Salah10.wav +0 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ tunisian.arpa filter=lfs diff=lfs merge=lfs -text
Untitled.ipynb ADDED
@@ -0,0 +1,242 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "71d69be2",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torchaudio\n",
11
+ "import numpy as np \n",
12
+ "import torch\n",
13
+ "import pandas as pd\n",
14
+ "import os"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 2,
20
+ "id": "eb5c6da2",
21
+ "metadata": {
22
+ "scrolled": true
23
+ },
24
+ "outputs": [
25
+ {
26
+ "data": {
27
+ "text/plain": [
28
+ "['Salah1.wav',\n",
29
+ " 'Salah2.wav',\n",
30
+ " 'Salah3.wav',\n",
31
+ " 'Salah4.wav',\n",
32
+ " 'Salah5.wav',\n",
33
+ " 'Salah6.wav',\n",
34
+ " 'Salah7.wav',\n",
35
+ " 'Salah8.wav',\n",
36
+ " 'Salah9.wav',\n",
37
+ " 'Salah10.wav']"
38
+ ]
39
+ },
40
+ "execution_count": 2,
41
+ "metadata": {},
42
+ "output_type": "execute_result"
43
+ }
44
+ ],
45
+ "source": [
46
+ "files = os.listdir(\"./\")\n",
47
+ "files = [x for x in files if \".wav\" in x]\n",
48
+ "files = [f\"Salah{i}.wav\" for i in range(1,11)]\n",
49
+ "files"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": 4,
55
+ "id": "b2be1d8e",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "words = {}\n",
60
+ "words[1] = \"نحب ماكلة بنينة كسكروت نظيف و رخيص\"\n",
61
+ "words[2]= \"باهي وقتاش نمشيو ال تونس\"\n",
62
+ "words[3] = \"اعطيني خمسة الاف و خمسة ميا بلاهي\"\n",
63
+ "words[4] = \"تعبت هاني راكش في الدار\"\n",
64
+ "words[5] = \"نهار السبت ماشي نقرى ان شاء الله\"\n",
65
+ "words[6]= \"زعما نلقى أحمد في الستاد ولا ماهوش هوني\"\n",
66
+ "words[7]= \"نحب نمشي ال بنزرت نرتاح شوية\"\n",
67
+ "words[8] = \"حكيت مع لولاد قالولي كل شي مريقل نهار السبت\"\n",
68
+ "words[9] = \"ناكل كفتاجي و نجم نشري شوية حوت زادة\"\n",
69
+ "words[10] = \"انتي خويا و عشيري صالح نحبك\"\n",
70
+ "words = [words[i] for i in range(1,11)]"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 6,
76
+ "id": "c46588ba",
77
+ "metadata": {},
78
+ "outputs": [
79
+ {
80
+ "name": "stdout",
81
+ "output_type": "stream",
82
+ "text": [
83
+ "torch.Size([1, 238080])\n",
84
+ "torch.Size([1, 184320])\n",
85
+ "torch.Size([1, 207360])\n",
86
+ "torch.Size([1, 168960])\n",
87
+ "torch.Size([1, 168960])\n",
88
+ "torch.Size([1, 192000])\n",
89
+ "torch.Size([1, 184320])\n",
90
+ "torch.Size([1, 199680])\n",
91
+ "torch.Size([1, 230400])\n",
92
+ "torch.Size([1, 192000])\n"
93
+ ]
94
+ }
95
+ ],
96
+ "source": [
97
+ "durations= []\n",
98
+ "path_jz = \"samples/\"\n",
99
+ "paths = [os.path.join(path_jz,x) for x in files]\n",
100
+ "srs= [48000 for x in paths]\n",
101
+ "IDs=[]\n",
102
+ "for f in files: \n",
103
+ " x,sr = torchaudio.load(f)\n",
104
+ " new_audio = torch.mean(x, dim=0).unsqueeze(0)\n",
105
+ " print(new_audio.shape)\n",
106
+ " torchaudio.save(os.path.join(\"monoaudiotun\", f), new_audio, sr)\n",
107
+ " duration = float(x.shape[1]) / sr\n",
108
+ " durations.append(duration)\n",
109
+ " IDs.append(f.split(\".\")[0])\n",
110
+ " \n"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 7,
116
+ "id": "b71db098",
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "test_salah = pd.DataFrame(\n",
121
+ " {'ID': IDs,\n",
122
+ " 'duration': durations,\n",
123
+ " 'wav': paths,\n",
124
+ " \"sr\": srs,\n",
125
+ " \"wrd\": words\n",
126
+ " })\n"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 8,
132
+ "id": "b3fdd365",
133
+ "metadata": {},
134
+ "outputs": [],
135
+ "source": [
136
+ "test_salah.to_csv(\"test_salah_local.csv\", index=False)"
137
+ ]
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "execution_count": 28,
142
+ "id": "f6ac8451",
143
+ "metadata": {},
144
+ "outputs": [
145
+ {
146
+ "name": "stdout",
147
+ "output_type": "stream",
148
+ "text": [
149
+ "%WER 45.59 [ 31 / 68, 3 ins, 7 del, 21 sub ]\n",
150
+ "%SER 90.00 [ 9 / 10 ]\n",
151
+ "Scored 10 sentences, 0 not present in hyp.\n",
152
+ "================================================================================\n",
153
+ "ALIGNMENTS\n",
154
+ "\n",
155
+ "Format:\n",
156
+ "<utterance-id>, WER DETAILS\n",
157
+ "<eps> ; reference ; on ; the ; first ; line\n",
158
+ " I ; S ; = ; = ; S ; D \n",
159
+ " and ; hypothesis ; on ; the ; third ; <eps>\n",
160
+ "================================================================================\n",
161
+ "Salah4, %WER 0.00 [ 0 / 5, 0 ins, 0 del, 0 sub ]\n",
162
+ "تعبت ; هاني ; راكش ; في ; الدار\n",
163
+ " = ; = ; = ; = ; = \n",
164
+ "تعبت ; هاني ; راكش ; في ; الدار\n",
165
+ "================================================================================\n",
166
+ "Salah5, %WER 57.14 [ 4 / 7, 0 ins, 1 del, 3 sub ]\n",
167
+ "نهار ; السبت ; ماشي ; نقرى ; ان ; شاء ; ا��له\n",
168
+ " = ; = ; = ; S ; S ; S ; D \n",
169
+ "نهار ; السبت ; ماشي ; نقرا ; إن ; شاءالله ; <eps>\n",
170
+ "================================================================================\n",
171
+ "Salah2, %WER 60.00 [ 3 / 5, 0 ins, 1 del, 2 sub ]\n",
172
+ "باهي ; وقتاش ; نمشيو ; ال ; تونس\n",
173
+ " = ; = ; S ; S ; D \n",
174
+ "باهي ; وقتاش ; نمشيوا ; لتونس ; <eps>\n",
175
+ "================================================================================\n",
176
+ "Salah7, %WER 33.33 [ 2 / 6, 0 ins, 1 del, 1 sub ]\n",
177
+ "نحب ; نمشي ; ال ; بنزرت ; نرتاح ; شوية\n",
178
+ " = ; = ; S ; D ; = ; = \n",
179
+ "نحب ; نمشي ; لبنزرت ; <eps> ; نرتاح ; شوية\n",
180
+ "================================================================================\n",
181
+ "Salah6, %WER 37.50 [ 3 / 8, 0 ins, 0 del, 3 sub ]\n",
182
+ "زعما ; نلقى ; أحمد ; في ; الستاد ; ولا ; ماهوش ; هوني\n",
183
+ " S ; = ; = ; = ; S ; S ; = ; = \n",
184
+ "زعمة ; نلقى ; أحمد ; في ; السعد ; وإلا ; ماهوش ; هوني\n",
185
+ "================================================================================\n",
186
+ "Salah10, %WER 83.33 [ 5 / 6, 1 ins, 1 del, 3 sub ]\n",
187
+ "انتي ; <eps> ; خويا ; و ; عشيري ; صالح ; نحبك\n",
188
+ " S ; I ; = ; S ; S ; D ; = \n",
189
+ "إنت ; ي ; خويا ; وعشيلي ; صلاح ; <eps> ; نحبك\n",
190
+ "================================================================================\n",
191
+ "Salah8, %WER 44.44 [ 4 / 9, 2 ins, 0 del, 2 sub ]\n",
192
+ "حكيت ; مع ; لولاد ; قالولي ; كل ; شي ; مريقل ; <eps> ; <eps> ; نهار ; السبت\n",
193
+ " = ; = ; S ; = ; = ; = ; S ; I ; I ; = ; = \n",
194
+ "حكيت ; مع ; الأولاد ; قالولي ; كل ; شي ; مر ; ي ; ل ; نهار ; السبت\n",
195
+ "================================================================================\n",
196
+ "Salah3, %WER 85.71 [ 6 / 7, 0 ins, 1 del, 5 sub ]\n",
197
+ "اعطيني ; خمسة ; الاف ; و ; خمسة ; ميا ; بلاهي\n",
198
+ " S ; = ; S ; S ; S ; S ; D \n",
199
+ "أعطيني ; خمسة ; آلاف ; وخمسة ; ملا ; باللاهي ; <eps>\n",
200
+ "================================================================================\n",
201
+ "Salah9, %WER 25.00 [ 2 / 8, 0 ins, 1 del, 1 sub ]\n",
202
+ "ناكل ; كفتاجي ; و ; نجم ; نشري ; شوية ; حوت ; زادة\n",
203
+ " = ; = ; S ; D ; = ; = ; = ; = \n",
204
+ "ناكل ; كفتاجي ; وننجم ; <eps> ; نشري ; شوية ; حوت ; زادة\n",
205
+ "================================================================================\n",
206
+ "Salah1, %WER 28.57 [ 2 / 7, 0 ins, 1 del, 1 sub ]\n",
207
+ "نحب ; ماكلة ; بنينة ; كسكروت ; نظيف ; و ; رخيص\n",
208
+ " = ; = ; = ; = ; = ; S ; D \n",
209
+ "نحب ; ماكلة ; بنينة ; كسكروت ; نظيف ; ورخيص ; <eps>\n"
210
+ ]
211
+ }
212
+ ],
213
+ "source": [
214
+ "filein = \"wer_test_salah.txt\"\n",
215
+ "with open(filein, \"r\") as wer : \n",
216
+ " lines = wer.read().splitlines()\n",
217
+ " print(\"\\n\".join(lines))"
218
+ ]
219
+ }
220
+ ],
221
+ "metadata": {
222
+ "kernelspec": {
223
+ "display_name": "Python 3",
224
+ "language": "python",
225
+ "name": "python3"
226
+ },
227
+ "language_info": {
228
+ "codemirror_mode": {
229
+ "name": "ipython",
230
+ "version": 3
231
+ },
232
+ "file_extension": ".py",
233
+ "mimetype": "text/x-python",
234
+ "name": "python",
235
+ "nbconvert_exporter": "python",
236
+ "pygments_lexer": "ipython3",
237
+ "version": "3.8.5"
238
+ }
239
+ },
240
+ "nbformat": 4,
241
+ "nbformat_minor": 5
242
+ }
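
Note (not part of the commit): the report printed in the notebook's last cell is the SpeechBrain ErrorRateStats output stored in wer_test_salah.txt. A minimal sketch of how such a report is produced, assuming speechbrain is installed; the utterance ID and word lists below are illustrative only:

from speechbrain.utils.metric_stats import ErrorRateStats

wer_metric = ErrorRateStats()
# predictions and targets are lists of word lists, one entry per utterance
wer_metric.append(
    ids=["Salah4"],
    predict=[["تعبت", "هاني", "راكش", "في", "الدار"]],
    target=[["تعبت", "هاني", "راكش", "في", "الدار"]],
)
print(wer_metric.summarize("error_rate"))  # 0.0 for a perfect match
with open("wer_example.txt", "w") as f:
    wer_metric.write_stats(f)              # same layout as wer_test_salah.txt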
__pycache__/lm_tunisian.cpython-38.pyc ADDED
Binary file (9.51 kB).
 
app.py ADDED
@@ -0,0 +1,371 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import logging
5
+ import speechbrain as sb
6
+ from speechbrain.utils.distributed import run_on_main
7
+ from hyperpyyaml import load_hyperpyyaml
8
+ from pathlib import Path
9
+ import torchaudio.transforms as T
10
+ import torchaudio
11
+ import numpy as np
12
+
13
+ from pyctcdecode import build_ctcdecoder
14
+ hparams_file, run_opts, overrides = sb.parse_arguments(["wavlm_partly_frozen.yaml"])
15
+
16
+ # If distributed_launch=True then
17
+ # create ddp_group with the right communication protocol
18
+ sb.utils.distributed.ddp_init_group(run_opts)
19
+
20
+ with open(hparams_file) as fin:
21
+ hparams = load_hyperpyyaml(fin, overrides)
22
+
23
+ # Create experiment directory
24
+ sb.create_experiment_directory(
25
+ experiment_directory=hparams["output_folder"],
26
+ hyperparams_to_save=hparams_file,
27
+ overrides=overrides,
28
+ )
29
+ def read_labels_file(labels_file):
30
+ with open(labels_file, "r") as lf:
31
+ lines = lf.read().splitlines()
32
+ division = "==="
33
+ numbers = {}
34
+ for line in lines :
35
+ if division in line :
36
+ break
37
+ string, number = line.split("=>")
38
+ number = int(number)
39
+ string = string[1:-2]
40
+ numbers[number] = string
41
+ return [numbers[x] for x in range(len(numbers))]
42
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
43
+ print(labels)
44
+ labels = [""] + labels[1:]
45
+ print(len(labels))
46
+
47
+ # Dataset prep (parsing Librispeech)
48
+
49
+ resampler_8000 = T.Resample(8000, 16000, dtype=torch.float)
50
+
51
+ resampler_44100 =T.Resample(44100, 16000, dtype=torch.float)
52
+ resampler_48000 =T.Resample(48000, 16000, dtype=torch.float)
53
+
54
+
55
+ resamplers = {"8000": resampler_8000, "44100":resampler_44100, "48000": resampler_48000}
56
+ def dataio_prepare(hparams):
57
+ """This function prepares the datasets to be used in the brain class.
58
+ It also defines the data processing pipeline through user-defined functions."""
59
+ data_folder = hparams["data_folder"]
60
+
61
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
62
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
63
+ )
64
+
65
+ if hparams["sorting"] == "ascending":
66
+ # we sort training data to speed up training and get better results.
67
+ train_data = train_data.filtered_sorted(sort_key="duration")
68
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
69
+ hparams["train_dataloader_opts"]["shuffle"] = False
70
+
71
+ elif hparams["sorting"] == "descending":
72
+ train_data = train_data.filtered_sorted(
73
+ sort_key="duration", reverse=True
74
+ )
75
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
76
+ hparams["train_dataloader_opts"]["shuffle"] = False
77
+
78
+ elif hparams["sorting"] == "random":
79
+ pass
80
+
81
+ else:
82
+ raise NotImplementedError(
83
+ "sorting must be random, ascending or descending"
84
+ )
85
+
86
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
87
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
88
+ )
89
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
90
+
91
+ # test is separate
92
+ test_datasets = {}
93
+ for csv_file in hparams["test_csv"]:
94
+ name = Path(csv_file).stem
95
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
96
+ csv_path=csv_file, replacements={"data_root": data_folder}
97
+ )
98
+ test_datasets[name] = test_datasets[name].filtered_sorted(
99
+ sort_key="duration"
100
+ )
101
+
102
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
103
+
104
+ # 2. Define audio pipeline:
105
+ @sb.utils.data_pipeline.takes("wav", "sr")
106
+ @sb.utils.data_pipeline.provides("sig")
107
+ def audio_pipeline(wav, sr):
108
+ sig = sb.dataio.dataio.read_audio(wav)
109
+ sig = resamplers[sr](sig)
110
+ return sig
111
+
112
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
113
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
114
+
115
+ # 3. Define text pipeline:
116
+ @sb.utils.data_pipeline.takes("wrd")
117
+ @sb.utils.data_pipeline.provides(
118
+ "wrd", "char_list", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
119
+ )
120
+ def text_pipeline(wrd):
121
+ yield wrd
122
+ char_list = list(wrd)
123
+ yield char_list
124
+ tokens_list = label_encoder.encode_sequence(char_list)
125
+ yield tokens_list
126
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
127
+ yield tokens_bos
128
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
129
+ yield tokens_eos
130
+ tokens = torch.LongTensor(tokens_list)
131
+ yield tokens
132
+
133
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
134
+
135
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
136
+ special_labels = {
137
+ "bos_label": hparams["bos_index"],
138
+ "eos_label": hparams["eos_index"],
139
+ "blank_label": hparams["blank_index"],
140
+ }
141
+ label_encoder.load_or_create(
142
+ path=lab_enc_file,
143
+ from_didatasets=[train_data],
144
+ output_key="char_list",
145
+ special_labels=special_labels,
146
+ sequence_input=True,
147
+ )
148
+
149
+ # 4. Set output:
150
+ sb.dataio.dataset.set_output_keys(
151
+ datasets,
152
+ ["id", "sig", "wrd", "char_list", "tokens_bos", "tokens_eos", "tokens"],
153
+ )
154
+ return train_data, valid_data, test_datasets, label_encoder
155
+
156
+
157
+ class ASR(sb.Brain):
158
+ def compute_forward(self, batch, stage):
159
+ """Forward computations from the waveform batches to the output probabilities."""
160
+ batch = batch.to(self.device)
161
+ wavs, wav_lens = batch.sig
162
+ print(wavs)
163
+ tokens_bos, _ = batch.tokens_bos
164
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
165
+
166
+ # Forward pass
167
+ feats = self.modules.wav2vec2(wavs)
168
+ x = self.modules.enc(feats)
169
+ # Compute outputs
170
+ p_tokens = None
171
+ logits = self.modules.ctc_lin(x)
172
+ p_ctc = self.hparams.log_softmax(logits)
173
+ if stage != sb.Stage.TRAIN:
174
+ p_tokens = sb.decoders.ctc_greedy_decode(
175
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
176
+ )
177
+ return p_ctc, wav_lens, p_tokens
178
+
179
+ def treat_wav(self,sig):
180
+ feats = self.modules.wav2vec2(sig.to(self.device))
181
+ x = self.modules.enc(feats)
182
+ p_tokens = None
183
+ logits = self.modules.ctc_lin(x)
184
+ p_ctc = self.hparams.log_softmax(logits)
185
+ predicted_words =[]
186
+ for logs in p_ctc:
187
+ text = decoder.decode(logs.detach().cpu().numpy())
188
+ predicted_words.append(text.split(" "))
189
+ return " ".join(predicted_words[0])
190
+
191
+
192
+
193
+
194
+ def compute_objectives(self, predictions, batch, stage):
195
+ """Computes the loss (CTC+NLL) given predictions and targets."""
196
+
197
+ p_ctc, wav_lens, predicted_tokens = predictions
198
+
199
+ ids = batch.id
200
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
201
+ tokens, tokens_lens = batch.tokens
202
+
203
+ if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
204
+ tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
205
+ tokens_eos_lens = torch.cat(
206
+ [tokens_eos_lens, tokens_eos_lens], dim=0
207
+ )
208
+ tokens = torch.cat([tokens, tokens], dim=0)
209
+ tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
210
+
211
+ loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
212
+ loss = loss_ctc
213
+ if stage != sb.Stage.TRAIN:
214
+ # Decode token terms to words
215
+ predicted_words =[]
216
+ for logs in p_ctc:
217
+ text = decoder.decode(logs.detach().cpu().numpy())
218
+ predicted_words.append(text.split(" "))
219
+
220
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
221
+ self.wer_metric.append(ids, predicted_words, target_words)
222
+ self.cer_metric.append(ids, predicted_words, target_words)
223
+
224
+ return loss
225
+
226
+ def fit_batch(self, batch):
227
+ """Train the parameters given a single batch in input"""
228
+ predictions = self.compute_forward(batch, sb.Stage.TRAIN)
229
+ loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
230
+ loss.backward()
231
+ if self.check_gradients(loss):
232
+ self.wav2vec_optimizer.step()
233
+ self.model_optimizer.step()
234
+
235
+ self.wav2vec_optimizer.zero_grad()
236
+ self.model_optimizer.zero_grad()
237
+
238
+ return loss.detach()
239
+
240
+ def evaluate_batch(self, batch, stage):
241
+ """Computations needed for validation/test batches"""
242
+ predictions = self.compute_forward(batch, stage=stage)
243
+ with torch.no_grad():
244
+ loss = self.compute_objectives(predictions, batch, stage=stage)
245
+ return loss.detach()
246
+
247
+ def on_stage_start(self, stage, epoch):
248
+ """Gets called at the beginning of each epoch"""
249
+ if stage != sb.Stage.TRAIN:
250
+ self.cer_metric = self.hparams.cer_computer()
251
+ self.wer_metric = self.hparams.error_rate_computer()
252
+
253
+ def on_stage_end(self, stage, stage_loss, epoch):
254
+ """Gets called at the end of an epoch."""
255
+ # Compute/store important stats
256
+ stage_stats = {"loss": stage_loss}
257
+ if stage == sb.Stage.TRAIN:
258
+ self.train_stats = stage_stats
259
+ else:
260
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
261
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
262
+
263
+ # Perform end-of-iteration things, like annealing, logging, etc.
264
+ if stage == sb.Stage.VALID:
265
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
266
+ stage_stats["loss"]
267
+ )
268
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
269
+ stage_stats["loss"]
270
+ )
271
+ sb.nnet.schedulers.update_learning_rate(
272
+ self.model_optimizer, new_lr_model
273
+ )
274
+ sb.nnet.schedulers.update_learning_rate(
275
+ self.wav2vec_optimizer, new_lr_wav2vec
276
+ )
277
+ self.hparams.train_logger.log_stats(
278
+ stats_meta={
279
+ "epoch": epoch,
280
+ "lr_model": old_lr_model,
281
+ "lr_wav2vec": old_lr_wav2vec,
282
+ },
283
+ train_stats=self.train_stats,
284
+ valid_stats=stage_stats,
285
+ )
286
+ self.checkpointer.save_and_keep_only(
287
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
288
+ )
289
+ elif stage == sb.Stage.TEST:
290
+ self.hparams.train_logger.log_stats(
291
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
292
+ test_stats=stage_stats,
293
+ )
294
+ with open(self.hparams.wer_file, "w") as w:
295
+ self.wer_metric.write_stats(w)
296
+
297
+ def init_optimizers(self):
298
+ "Initializes the wav2vec2 optimizer and model optimizer"
299
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
300
+ self.modules.wav2vec2.parameters()
301
+ )
302
+ self.model_optimizer = self.hparams.model_opt_class(
303
+ self.hparams.model.parameters()
304
+ )
305
+
306
+ if self.checkpointer is not None:
307
+ self.checkpointer.add_recoverable(
308
+ "wav2vec_opt", self.wav2vec_optimizer
309
+ )
310
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
311
+
312
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
313
+
314
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
315
+ hparams
316
+ )
317
+
318
+
319
+ # We dynamically add the tokenizer to our brain class.
320
+ # NB: This tokenizer corresponds to the one used for the LM!!
321
+ decoder = build_ctcdecoder(
322
+ labels,
323
+ kenlm_model_path="tunisian.arpa", # either .arpa or .bin file
324
+ alpha=0.5, # tuned on a val set
325
+ beta=1, # tuned on a val set
326
+ )
327
+
328
+ asr_brain = ASR(
329
+ modules=hparams["modules"],
330
+ hparams=hparams,
331
+ run_opts=run_opts,
332
+ checkpointer=hparams["checkpointer"],
333
+ )
334
+ asr_brain.device= "cpu"
335
+ asr_brain.modules.to("cpu")
336
+ asr_brain.tokenizer = label_encoder
337
+
338
+ from enum import Enum, auto
339
+ class Stage(Enum):
340
+ TRAIN = auto()
341
+ VALID = auto()
342
+ TEST = auto()
343
+
344
+ asr_brain.on_evaluate_start()
345
+ asr_brain.modules.eval()
346
+ import gradio as gr
347
+ def treat_wav_file(file_mic, file_upload, resamplers = resamplers,asr=asr_brain, device="cpu") :
348
+
349
+ if (file_mic is not None) and (file_upload is not None):
350
+ warn_output = "WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
351
+ wav = file_mic
352
+ elif (file_mic is None) and (file_upload is None):
353
+ return "ERROR: You have to either use the microphone or upload an audio file"
354
+ elif file_mic is not None:
355
+ wav = file_mic
356
+ else:
357
+ wav = file_upload
358
+ sig, sr = torchaudio.load(wav)
359
+ tensor_wav = sig.to(device)
360
+ resampled = resamplers[str(sr)](tensor_wav)
361
+ sentence = asr_brain.treat_wav(resampled)
362
+ return sentence
363
+
364
+ gr.Interface(
365
+ fn=treat_wav_file,
366
+ inputs=[gr.inputs.Audio(source="microphone", type='filepath', optional=True),
367
+ gr.inputs.Audio(source="upload", type='filepath', optional=True)]
368
+ ,outputs="text").launch()
369
+
370
+
371
+
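
Note (not part of the commit): app.py builds its beam-search decoder with pyctcdecode and the LFS-tracked tunisian.arpa KenLM file. A minimal, self-contained sketch of the same call on a toy label set with random log-probabilities, so it runs without the language model; pass kenlm_model_path="tunisian.arpa" plus alpha/beta to enable LM rescoring as in app.py:

import numpy as np
from pyctcdecode import build_ctcdecoder

labels = ["", "ن", "ح", "ب", " "]   # index 0 is the CTC blank, mirroring labels = [""] + labels[1:] above
decoder = build_ctcdecoder(labels)   # no kenlm_model_path -> decoding without an external LM
log_probs = np.log(np.random.dirichlet(np.ones(len(labels)), size=50)).astype(np.float32)  # (time, vocab)
print(decoder.decode(log_probs))     # best hypothesis as a plain string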
ctc_train.py ADDED
@@ -0,0 +1,339 @@
1
+ #!/usr/bin/env/python3
2
+ """Recipe for training a wav2vec-based ctc ASR system with librispeech.
3
+ The system employs wav2vec as its encoder. Decoding is performed with
4
+ ctc greedy decoder.
5
+ To run this recipe, do the following:
6
+ > python train_with_wav2vec.py hparams/train_with_wav2vec.yaml
7
+ The neural network is trained on CTC likelihood target and character units
8
+ are used as basic recognition tokens. Training is performed on the full
9
+ LibriSpeech dataset (960 h).
10
+
11
+ Authors
12
+ * Sung-Lin Yeh 2021
13
+ * Titouan Parcollet 2021
14
+ * Ju-Chieh Chou 2020
15
+ * Mirco Ravanelli 2020
16
+ * Abdel Heba 2020
17
+ * Peter Plantinga 2020
18
+ * Samuele Cornell 2020
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import torch
24
+ import logging
25
+ import speechbrain as sb
26
+ from speechbrain.utils.distributed import run_on_main
27
+ from hyperpyyaml import load_hyperpyyaml
28
+ from pathlib import Path
29
+ import torchaudio.transforms as T
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Define training procedure
33
+ class ASR(sb.Brain):
34
+ def compute_forward(self, batch, stage):
35
+ """Forward computations from the waveform batches to the output probabilities."""
36
+ batch = batch.to(self.device)
37
+ wavs, wav_lens = batch.sig
38
+ tokens_bos, _ = batch.tokens_bos
39
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
40
+
41
+ # Forward pass
42
+ feats = self.modules.wav2vec2(wavs)
43
+ x = self.modules.enc(feats)
44
+ # Compute outputs
45
+ p_tokens = None
46
+ logits = self.modules.ctc_lin(x)
47
+ p_ctc = self.hparams.log_softmax(logits)
48
+ if stage != sb.Stage.TRAIN:
49
+ p_tokens = sb.decoders.ctc_greedy_decode(
50
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
51
+ )
52
+ return p_ctc, wav_lens, p_tokens
53
+
54
+ def compute_objectives(self, predictions, batch, stage):
55
+ """Computes the loss (CTC+NLL) given predictions and targets."""
56
+
57
+ p_ctc, wav_lens, predicted_tokens = predictions
58
+
59
+ ids = batch.id
60
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
61
+ tokens, tokens_lens = batch.tokens
62
+
63
+ if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
64
+ tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
65
+ tokens_eos_lens = torch.cat(
66
+ [tokens_eos_lens, tokens_eos_lens], dim=0
67
+ )
68
+ tokens = torch.cat([tokens, tokens], dim=0)
69
+ tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
70
+
71
+ loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
72
+ loss = loss_ctc
73
+
74
+ if stage != sb.Stage.TRAIN:
75
+ # Decode token terms to words
76
+ predicted_words = [
77
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
78
+ for utt_seq in predicted_tokens
79
+ ]
80
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
81
+ self.wer_metric.append(ids, predicted_words, target_words)
82
+ self.cer_metric.append(ids, predicted_words, target_words)
83
+
84
+ return loss
85
+
86
+ def fit_batch(self, batch):
87
+ """Train the parameters given a single batch in input"""
88
+ predictions = self.compute_forward(batch, sb.Stage.TRAIN)
89
+ loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
90
+ loss.backward()
91
+ if self.check_gradients(loss):
92
+ self.wav2vec_optimizer.step()
93
+ self.model_optimizer.step()
94
+
95
+ self.wav2vec_optimizer.zero_grad()
96
+ self.model_optimizer.zero_grad()
97
+
98
+ return loss.detach()
99
+
100
+ def evaluate_batch(self, batch, stage):
101
+ """Computations needed for validation/test batches"""
102
+ predictions = self.compute_forward(batch, stage=stage)
103
+ with torch.no_grad():
104
+ loss = self.compute_objectives(predictions, batch, stage=stage)
105
+ return loss.detach()
106
+
107
+ def on_stage_start(self, stage, epoch):
108
+ """Gets called at the beginning of each epoch"""
109
+ if stage != sb.Stage.TRAIN:
110
+ self.cer_metric = self.hparams.cer_computer()
111
+ self.wer_metric = self.hparams.error_rate_computer()
112
+
113
+ def on_stage_end(self, stage, stage_loss, epoch):
114
+ """Gets called at the end of an epoch."""
115
+ # Compute/store important stats
116
+ stage_stats = {"loss": stage_loss}
117
+ if stage == sb.Stage.TRAIN:
118
+ self.train_stats = stage_stats
119
+ else:
120
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
121
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
122
+
123
+ # Perform end-of-iteration things, like annealing, logging, etc.
124
+ if stage == sb.Stage.VALID:
125
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
126
+ stage_stats["loss"]
127
+ )
128
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
129
+ stage_stats["loss"]
130
+ )
131
+ sb.nnet.schedulers.update_learning_rate(
132
+ self.model_optimizer, new_lr_model
133
+ )
134
+ sb.nnet.schedulers.update_learning_rate(
135
+ self.wav2vec_optimizer, new_lr_wav2vec
136
+ )
137
+ self.hparams.train_logger.log_stats(
138
+ stats_meta={
139
+ "epoch": epoch,
140
+ "lr_model": old_lr_model,
141
+ "lr_wav2vec": old_lr_wav2vec,
142
+ },
143
+ train_stats=self.train_stats,
144
+ valid_stats=stage_stats,
145
+ )
146
+ self.checkpointer.save_and_keep_only(
147
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
148
+ )
149
+ elif stage == sb.Stage.TEST:
150
+ self.hparams.train_logger.log_stats(
151
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
152
+ test_stats=stage_stats,
153
+ )
154
+ with open(self.hparams.wer_file, "w") as w:
155
+ self.wer_metric.write_stats(w)
156
+
157
+ def init_optimizers(self):
158
+ "Initializes the wav2vec2 optimizer and model optimizer"
159
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
160
+ self.modules.wav2vec2.parameters()
161
+ )
162
+ self.model_optimizer = self.hparams.model_opt_class(
163
+ self.hparams.model.parameters()
164
+ )
165
+
166
+ if self.checkpointer is not None:
167
+ self.checkpointer.add_recoverable(
168
+ "wav2vec_opt", self.wav2vec_optimizer
169
+ )
170
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
171
+
172
+
173
+ def dataio_prepare(hparams):
174
+ """This function prepares the datasets to be used in the brain class.
175
+ It also defines the data processing pipeline through user-defined functions."""
176
+ data_folder = hparams["data_folder"]
177
+
178
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
179
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
180
+ )
181
+
182
+ if hparams["sorting"] == "ascending":
183
+ # we sort training data to speed up training and get better results.
184
+ train_data = train_data.filtered_sorted(sort_key="duration")
185
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
186
+ hparams["train_dataloader_opts"]["shuffle"] = False
187
+
188
+ elif hparams["sorting"] == "descending":
189
+ train_data = train_data.filtered_sorted(
190
+ sort_key="duration", reverse=True
191
+ )
192
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
193
+ hparams["train_dataloader_opts"]["shuffle"] = False
194
+
195
+ elif hparams["sorting"] == "random":
196
+ pass
197
+
198
+ else:
199
+ raise NotImplementedError(
200
+ "sorting must be random, ascending or descending"
201
+ )
202
+
203
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
204
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
205
+ )
206
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
207
+
208
+ # test is separate
209
+ test_datasets = {}
210
+ for csv_file in hparams["test_csv"]:
211
+ name = Path(csv_file).stem
212
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
213
+ csv_path=csv_file, replacements={"data_root": data_folder}
214
+ )
215
+ test_datasets[name] = test_datasets[name].filtered_sorted(
216
+ sort_key="duration"
217
+ )
218
+
219
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
220
+
221
+ # 2. Define audio pipeline:
222
+ @sb.utils.data_pipeline.takes("wav", "sr")
223
+ @sb.utils.data_pipeline.provides("sig")
224
+ def audio_pipeline(wav, sr):
225
+ sig = sb.dataio.dataio.read_audio(wav)
226
+ sig = resamplers[sr](sig)
227
+ return sig
228
+
229
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
230
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
231
+
232
+ # 3. Define text pipeline:
233
+ @sb.utils.data_pipeline.takes("wrd")
234
+ @sb.utils.data_pipeline.provides(
235
+ "wrd", "char_list", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
236
+ )
237
+ def text_pipeline(wrd):
238
+ yield wrd
239
+ char_list = list(wrd)
240
+ yield char_list
241
+ tokens_list = label_encoder.encode_sequence(char_list)
242
+ yield tokens_list
243
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
244
+ yield tokens_bos
245
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
246
+ yield tokens_eos
247
+ tokens = torch.LongTensor(tokens_list)
248
+ yield tokens
249
+
250
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
251
+
252
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
253
+ special_labels = {
254
+ "bos_label": hparams["bos_index"],
255
+ "eos_label": hparams["eos_index"],
256
+ "blank_label": hparams["blank_index"],
257
+ }
258
+ label_encoder.load_or_create(
259
+ path=lab_enc_file,
260
+ from_didatasets=[train_data],
261
+ output_key="char_list",
262
+ special_labels=special_labels,
263
+ sequence_input=True,
264
+ )
265
+
266
+ # 4. Set output:
267
+ sb.dataio.dataset.set_output_keys(
268
+ datasets,
269
+ ["id", "sig", "wrd", "char_list", "tokens_bos", "tokens_eos", "tokens"],
270
+ )
271
+ return train_data, valid_data, test_datasets, label_encoder
272
+
273
+
274
+ if __name__ == "__main__":
275
+
276
+ # CLI:
277
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
278
+
279
+ # If distributed_launch=True then
280
+ # create ddp_group with the right communication protocol
281
+ sb.utils.distributed.ddp_init_group(run_opts)
282
+
283
+ with open(hparams_file) as fin:
284
+ hparams = load_hyperpyyaml(fin, overrides)
285
+
286
+ # Create experiment directory
287
+ sb.create_experiment_directory(
288
+ experiment_directory=hparams["output_folder"],
289
+ hyperparams_to_save=hparams_file,
290
+ overrides=overrides,
291
+ )
292
+
293
+ # Dataset prep (parsing Librispeech)
294
+
295
+ resampler_8000 = T.Resample(8000, 16000, dtype=torch.float)
296
+
297
+ resampler_44100 =T.Resample(44100, 16000, dtype=torch.float)
298
+ resampler_32000 =T.Resample(32000, 16000, dtype=torch.float)
299
+ resampler_48000 =T.Resample(48000, 16000, dtype=torch.float)
300
+
301
+
302
+ resamplers = {"48000": resampler_48000,"8000": resampler_8000, "44100":resampler_44100, "32000":resampler_32000}
303
+
304
+ # here we create the datasets objects as well as tokenization and encoding
305
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
306
+ hparams
307
+ )
308
+
309
+ # Trainer initialization
310
+ asr_brain = ASR(
311
+ modules=hparams["modules"],
312
+ hparams=hparams,
313
+ run_opts=run_opts,
314
+ checkpointer=hparams["checkpointer"],
315
+ )
316
+ asr_brain.device= "cpu"
317
+ asr_brain.modules.to("cpu")
318
+
319
+ # We dynamically add the tokenizer to our brain class.
320
+ # NB: This tokenizer corresponds to the one used for the LM!!
321
+ asr_brain.tokenizer = label_encoder
322
+
323
+ # Training
324
+ asr_brain.fit(
325
+ asr_brain.hparams.epoch_counter,
326
+ train_data,
327
+ valid_data,
328
+ train_loader_kwargs=hparams["train_dataloader_opts"],
329
+ valid_loader_kwargs=hparams["valid_dataloader_opts"],
330
+ )
331
+
332
+ # Testing
333
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
334
+ asr_brain.hparams.wer_file = os.path.join(
335
+ hparams["output_folder"], "wer_{}.txt".format(k)
336
+ )
337
+ asr_brain.evaluate(
338
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"]
339
+ )
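
Note (not part of the commit): ctc_train.py keys its torchaudio resamplers on the sr column of the CSV so every waveform reaches the wav2vec 2.0 encoder at 16 kHz. A minimal sketch of that lookup on a synthetic one-second 48 kHz signal (random noise, purely illustrative):

import torch
import torchaudio.transforms as T

resamplers = {
    "48000": T.Resample(48000, 16000, dtype=torch.float),
    "44100": T.Resample(44100, 16000, dtype=torch.float),
}
sig = torch.randn(1, 48000)          # stand-in for one second of 48 kHz audio
sig_16k = resamplers["48000"](sig)   # -> torch.Size([1, 16000])
print(sig_16k.shape)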
debugging.csv ADDED
@@ -0,0 +1,11 @@
1
+ ID,duration,wav,sr,wrd
2
+ Salah1,4.96,samples/Salah1.wav,48000,نحب ماكلة بنينة كسكروت نظيف و رخيص
3
+ Salah2,3.84,samples/Salah2.wav,48000,باهي وقتاش نمشيو ال تونس
4
+ Salah3,4.32,samples/Salah3.wav,48000,اعطيني خمسة الاف و خمسة ميا بلاهي
5
+ Salah4,3.52,samples/Salah4.wav,48000,تعبت هاني راكش في الدار
6
+ Salah5,3.52,samples/Salah5.wav,48000,نهار السبت ماشي نقرى ان شاء الله
7
+ Salah6,4.0,samples/Salah6.wav,48000,زعما نلقى أحمد في الستاد ولا ماهوش هوني
8
+ Salah7,3.84,samples/Salah7.wav,48000,نحب نمشي ال بنزرت نرتاح شوية
9
+ Salah8,4.16,samples/Salah8.wav,48000,حكيت مع لولاد قالولي كل شي مريقل نهار السبت
10
+ Salah9,4.8,samples/Salah9.wav,48000,ناكل كفتاجي و نجم نشري شوية حوت زادة
11
+ Salah10,4.0,samples/Salah10.wav,48000,انتي خويا و عشيري صالح نحبك
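
Note (not part of the commit): debugging.csv follows the SpeechBrain manifest layout (ID, duration, wav, sr, wrd) expected by dataio_prepare. A minimal sketch of loading it back with DynamicItemDataset, assuming the samples/ folder sits next to the CSV:

import speechbrain as sb

dataset = sb.dataio.dataset.DynamicItemDataset.from_csv(csv_path="debugging.csv")

@sb.utils.data_pipeline.takes("wav")
@sb.utils.data_pipeline.provides("sig")
def audio_pipeline(wav):
    # read the waveform referenced by the manifest row
    return sb.dataio.dataio.read_audio(wav)

sb.dataio.dataset.add_dynamic_item([dataset], audio_pipeline)
sb.dataio.dataset.set_output_keys([dataset], ["id", "sig", "wrd"])
print(dataset[0]["wrd"])   # first transcript in the manifest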
file.wav ADDED
Binary file (288 kB).
 
lm_decoded_ctc.py ADDED
@@ -0,0 +1,376 @@
1
+ #!/usr/bin/env/python3
2
+ """Recipe for training a wav2vec-based ctc ASR system with librispeech.
3
+ The system employs wav2vec as its encoder. Decoding is performed with
4
+ ctc greedy decoder.
5
+ To run this recipe, do the following:
6
+ > python train_with_wav2vec.py hparams/train_with_wav2vec.yaml
7
+ The neural network is trained on CTC likelihood target and character units
8
+ are used as basic recognition tokens. Training is performed on the full
9
+ LibriSpeech dataset (960 h).
10
+
11
+ Authors
12
+ * Sung-Lin Yeh 2021
13
+ * Titouan Parcollet 2021
14
+ * Ju-Chieh Chou 2020
15
+ * Mirco Ravanelli 2020
16
+ * Abdel Heba 2020
17
+ * Peter Plantinga 2020
18
+ * Samuele Cornell 2020
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import torch
24
+ import logging
25
+ import speechbrain as sb
26
+ from speechbrain.utils.distributed import run_on_main
27
+ from hyperpyyaml import load_hyperpyyaml
28
+ from pathlib import Path
29
+ from pyctcdecode import build_ctcdecoder
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ # Define training procedure
35
+ class ASR(sb.Brain):
36
+ def compute_forward(self, batch, stage):
37
+ """Forward computations from the waveform batches to the output probabilities."""
38
+ batch = batch.to(self.device)
39
+ wavs, wav_lens = batch.sig
40
+ tokens_bos, _ = batch.tokens_bos
41
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
42
+
43
+ # Forward pass
44
+ feats = self.modules.wav2vec2(wavs)
45
+
46
+ x = self.modules.enc(feats.detach())[0]
47
+ #x = self.modules.enc(feats.detach())
48
+ # Compute outputs
49
+ p_tokens = None
50
+ logits = self.modules.ctc_lin(x)
51
+ p_ctc = self.hparams.log_softmax(logits)
52
+ if stage != sb.Stage.TRAIN:
53
+ p_tokens = sb.decoders.ctc_greedy_decode(
54
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
55
+ )
56
+ return p_ctc, wav_lens, p_tokens
57
+
58
+ def compute_objectives(self, predictions, batch, stage):
59
+ """Computes the loss (CTC+NLL) given predictions and targets."""
60
+
61
+ p_ctc, wav_lens, predicted_tokens = predictions
62
+
63
+ ids = batch.id
64
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
65
+ tokens, tokens_lens = batch.tokens
66
+
67
+ if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
68
+ tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
69
+ tokens_eos_lens = torch.cat(
70
+ [tokens_eos_lens, tokens_eos_lens], dim=0
71
+ )
72
+ tokens = torch.cat([tokens, tokens], dim=0)
73
+ tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
74
+
75
+ loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
76
+ loss = loss_ctc
77
+
78
+ if stage != sb.Stage.TRAIN:
79
+ # Decode token terms to words
80
+ predicted_words = [
81
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
82
+ for utt_seq in predicted_tokens
83
+ ]
84
+ predicted_words =[]
85
+ for logs in p_ctc:
86
+ text = decoder.decode(logs.detach().cpu().numpy())
87
+ predicted_words.append(text.split(" "))
88
+
89
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
90
+ self.wer_metric.append(ids, predicted_words, target_words)
91
+ self.cer_metric.append(ids, predicted_words, target_words)
92
+
93
+ return loss
94
+
95
+ def fit_batch(self, batch):
96
+ """Train the parameters given a single batch in input"""
97
+ predictions = self.compute_forward(batch, sb.Stage.TRAIN)
98
+ loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
99
+ loss.backward()
100
+ if self.check_gradients(loss):
101
+ self.wav2vec_optimizer.step()
102
+ self.model_optimizer.step()
103
+
104
+ self.wav2vec_optimizer.zero_grad()
105
+ self.model_optimizer.zero_grad()
106
+
107
+ return loss.detach()
108
+
109
+ def evaluate_batch(self, batch, stage):
110
+ """Computations needed for validation/test batches"""
111
+ predictions = self.compute_forward(batch, stage=stage)
112
+ with torch.no_grad():
113
+ loss = self.compute_objectives(predictions, batch, stage=stage)
114
+ return loss.detach()
115
+
116
+ def on_stage_start(self, stage, epoch):
117
+ """Gets called at the beginning of each epoch"""
118
+ if stage != sb.Stage.TRAIN:
119
+ self.cer_metric = self.hparams.cer_computer()
120
+ self.wer_metric = self.hparams.error_rate_computer()
121
+
122
+ def on_stage_end(self, stage, stage_loss, epoch):
123
+ """Gets called at the end of an epoch."""
124
+ # Compute/store important stats
125
+ stage_stats = {"loss": stage_loss}
126
+ if stage == sb.Stage.TRAIN:
127
+ self.train_stats = stage_stats
128
+ else:
129
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
130
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
131
+
132
+ # Perform end-of-iteration things, like annealing, logging, etc.
133
+ if stage == sb.Stage.VALID:
134
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
135
+ stage_stats["loss"]
136
+ )
137
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
138
+ stage_stats["loss"]
139
+ )
140
+ sb.nnet.schedulers.update_learning_rate(
141
+ self.model_optimizer, new_lr_model
142
+ )
143
+ sb.nnet.schedulers.update_learning_rate(
144
+ self.wav2vec_optimizer, new_lr_wav2vec
145
+ )
146
+ self.hparams.train_logger.log_stats(
147
+ stats_meta={
148
+ "epoch": epoch,
149
+ "lr_model": old_lr_model,
150
+ "lr_wav2vec": old_lr_wav2vec,
151
+ },
152
+ train_stats=self.train_stats,
153
+ valid_stats=stage_stats,
154
+ )
155
+ self.checkpointer.save_and_keep_only(
156
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
157
+ )
158
+ elif stage == sb.Stage.TEST:
159
+ self.hparams.train_logger.log_stats(
160
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
161
+ test_stats=stage_stats,
162
+ )
163
+ with open(self.hparams.wer_file, "w") as w:
164
+ self.wer_metric.write_stats(w)
165
+
166
+ def init_optimizers(self):
167
+ "Initializes the wav2vec2 optimizer and model optimizer"
168
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
169
+ self.modules.wav2vec2.parameters()
170
+ )
171
+ self.model_optimizer = self.hparams.model_opt_class(
172
+ self.hparams.model.parameters()
173
+ )
174
+
175
+ if self.checkpointer is not None:
176
+ self.checkpointer.add_recoverable(
177
+ "wav2vec_opt", self.wav2vec_optimizer
178
+ )
179
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
180
+
181
+
182
+ def dataio_prepare(hparams):
183
+ """This function prepares the datasets to be used in the brain class.
184
+ It also defines the data processing pipeline through user-defined functions."""
185
+ data_folder = hparams["data_folder"]
186
+
187
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
188
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
189
+ )
190
+
191
+ if hparams["sorting"] == "ascending":
192
+ # we sort training data to speed up training and get better results.
193
+ train_data = train_data.filtered_sorted(sort_key="duration")
194
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
195
+ hparams["train_dataloader_opts"]["shuffle"] = False
196
+
197
+ elif hparams["sorting"] == "descending":
198
+ train_data = train_data.filtered_sorted(
199
+ sort_key="duration", reverse=True
200
+ )
201
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
202
+ hparams["train_dataloader_opts"]["shuffle"] = False
203
+
204
+ elif hparams["sorting"] == "random":
205
+ pass
206
+
207
+ else:
208
+ raise NotImplementedError(
209
+ "sorting must be random, ascending or descending"
210
+ )
211
+
212
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
213
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
214
+ )
215
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
216
+
217
+ # test is separate
218
+ test_datasets = {}
219
+ for csv_file in hparams["test_csv"]:
220
+ name = Path(csv_file).stem
221
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
222
+ csv_path=csv_file, replacements={"data_root": data_folder}
223
+ )
224
+ test_datasets[name] = test_datasets[name].filtered_sorted(
225
+ sort_key="duration"
226
+ )
227
+
228
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
229
+
230
+ # 2. Define audio pipeline:
231
+ @sb.utils.data_pipeline.takes("wav")
232
+ @sb.utils.data_pipeline.provides("sig")
233
+ def audio_pipeline(wav):
234
+ sig = sb.dataio.dataio.read_audio(wav)
235
+ return sig
236
+
237
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
238
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
239
+
240
+ # 3. Define text pipeline:
241
+ @sb.utils.data_pipeline.takes("wrd")
242
+ @sb.utils.data_pipeline.provides(
243
+ "wrd", "char_list", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
244
+ )
245
+ def text_pipeline(wrd):
246
+ yield wrd
247
+ char_list = list(wrd)
248
+ yield char_list
249
+ tokens_list = label_encoder.encode_sequence(char_list)
250
+ yield tokens_list
251
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
252
+ yield tokens_bos
253
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
254
+ yield tokens_eos
255
+ tokens = torch.LongTensor(tokens_list)
256
+ yield tokens
257
+
258
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
259
+
260
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
261
+ special_labels = {
262
+ "bos_label": hparams["bos_index"],
263
+ "eos_label": hparams["eos_index"],
264
+ "blank_label": hparams["blank_index"],
265
+ }
266
+ label_encoder.load_or_create(
267
+ path=lab_enc_file,
268
+ from_didatasets=[train_data],
269
+ output_key="char_list",
270
+ special_labels=special_labels,
271
+ sequence_input=True,
272
+ )
273
+
274
+ # 4. Set output:
275
+ sb.dataio.dataset.set_output_keys(
276
+ datasets,
277
+ ["id", "sig", "wrd", "char_list", "tokens_bos", "tokens_eos", "tokens"],
278
+ )
279
+ return train_data, valid_data, test_datasets, label_encoder
280
+
281
+
282
+ if __name__ == "__main__":
283
+
284
+ # CLI:
285
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
286
+
287
+ # If distributed_launch=True then
288
+ # create ddp_group with the right communication protocol
289
+ sb.utils.distributed.ddp_init_group(run_opts)
290
+
291
+ with open(hparams_file) as fin:
292
+ hparams = load_hyperpyyaml(fin, overrides)
293
+
294
+ # Create experiment directory
295
+ sb.create_experiment_directory(
296
+ experiment_directory=hparams["output_folder"],
297
+ hyperparams_to_save=hparams_file,
298
+ overrides=overrides,
299
+ )
300
+ def read_labels_file(labels_file):
301
+ with open(labels_file, "r") as lf:
302
+ lines = lf.read().splitlines()
303
+ division = "==="
304
+ numbers = {}
305
+ for line in lines :
306
+ if division in line :
307
+ break
308
+ string, number = line.split("=>")
309
+ number = int(number)
310
+ string = string[1:-2]
311
+ numbers[number] = string
312
+ return [numbers[x] for x in range(len(numbers))]
313
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
314
+ print(labels)
315
+ labels = [""] + labels[1:]
316
+ print(len(labels))
317
+ decoder = build_ctcdecoder(
318
+ labels,
319
+ kenlm_model_path="/gpfsstore/rech/nou/uzn19yk/4-gram.arpa", # either .arpa or .bin file
320
+ alpha=0.5, # tuned on a val set
321
+ beta=1.0, # tuned on a val set
322
+ )
323
+
324
+ # Dataset prep (parsing Librispeech)
325
+
326
+ # multi-gpu (ddp) save data preparation
327
+ """
328
+ run_on_main(
329
+ prepare_librispeech,
330
+ kwargs={
331
+ "data_folder": hparams["data_folder"],
332
+ "tr_splits": hparams["train_splits"],
333
+ "dev_splits": hparams["dev_splits"],
334
+ "te_splits": hparams["test_splits"],
335
+ "save_folder": hparams["output_folder"],
336
+ "merge_lst": hparams["train_splits"],
337
+ "merge_name": "train.csv",
338
+ "skip_prep": hparams["skip_prep"],
339
+ },
340
+ )
341
+ """
342
+
343
+ # here we create the datasets objects as well as tokenization and encoding
344
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
345
+ hparams
346
+ )
347
+
348
+ # Trainer initialization
349
+ asr_brain = ASR(
350
+ modules=hparams["modules"],
351
+ hparams=hparams,
352
+ run_opts=run_opts,
353
+ checkpointer=hparams["checkpointer"],
354
+ )
355
+
356
+ # We dynamicaly add the tokenizer to our brain class.
357
+ # NB: This tokenizer corresponds to the one used for the LM!!
358
+ asr_brain.tokenizer = label_encoder
359
+
360
+ # Training
361
+ asr_brain.fit(
362
+ asr_brain.hparams.epoch_counter,
363
+ train_data,
364
+ valid_data,
365
+ train_loader_kwargs=hparams["train_dataloader_opts"],
366
+ valid_loader_kwargs=hparams["valid_dataloader_opts"],
367
+ )
368
+
369
+ # Testing
370
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
371
+ asr_brain.hparams.wer_file = os.path.join(
372
+ hparams["output_folder"], "wer_{}.txt".format(k)
373
+ )
374
+ asr_brain.evaluate(
375
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"]
376
+ )
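
Note (not part of the commit): read_labels_file in lm_decoded_ctc.py (and in app.py) assumes SpeechBrain's label_encoder.txt layout, i.e. 'token' => index lines followed by a === separator before the metadata. A tiny self-contained illustration with a made-up three-symbol encoder file; the parsing mirrors the function above:

def read_labels_file(labels_file):
    # same logic as in the script above: stop at the === separator,
    # map each index to its token, and return tokens in index order
    with open(labels_file, encoding="utf-8") as lf:
        lines = lf.read().splitlines()
    numbers = {}
    for line in lines:
        if "===" in line:
            break
        string, number = line.split("=>")
        numbers[int(number)] = string[1:-2]
    return [numbers[x] for x in range(len(numbers))]

with open("toy_label_encoder.txt", "w", encoding="utf-8") as f:
    f.write("'<blank>' => 0\n'ن' => 1\n' ' => 2\n================\n")

print(read_labels_file("toy_label_encoder.txt"))   # ['<blank>', 'ن', ' ']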
lm_tunisian.py ADDED
@@ -0,0 +1,361 @@
1
+ #!/usr/bin/env/python3
2
+ """Recipe for training a wav2vec-based ctc ASR system with librispeech.
3
+ The system employs wav2vec as its encoder. Decoding is performed with
4
+ ctc greedy decoder.
5
+ To run this recipe, do the following:
6
+ > python train_with_wav2vec.py hparams/train_with_wav2vec.yaml
7
+ The neural network is trained on CTC likelihood target and character units
8
+ are used as basic recognition tokens. Training is performed on the full
9
+ LibriSpeech dataset (960 h).
10
+
11
+ Authors
12
+ * Sung-Lin Yeh 2021
13
+ * Titouan Parcollet 2021
14
+ * Ju-Chieh Chou 2020
15
+ * Mirco Ravanelli 2020
16
+ * Abdel Heba 2020
17
+ * Peter Plantinga 2020
18
+ * Samuele Cornell 2020
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import torch
24
+ import logging
25
+ import speechbrain as sb
26
+ from speechbrain.utils.distributed import run_on_main
27
+ from hyperpyyaml import load_hyperpyyaml
28
+ from pathlib import Path
29
+ import torchaudio.transforms as T
30
+
31
+ from pyctcdecode import build_ctcdecoder
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Define training procedure
35
+ class ASR(sb.Brain):
36
+ def compute_forward(self, batch, stage):
37
+ """Forward computations from the waveform batches to the output probabilities."""
38
+ batch = batch.to(self.device)
39
+ wavs, wav_lens = batch.sig
40
+ tokens_bos, _ = batch.tokens_bos
41
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
42
+
43
+ # Forward pass
44
+ feats = self.modules.wav2vec2(wavs)
45
+ x = self.modules.enc(feats)
46
+ # Compute outputs
47
+ p_tokens = None
48
+ logits = self.modules.ctc_lin(x)
49
+ p_ctc = self.hparams.log_softmax(logits)
50
+ if stage != sb.Stage.TRAIN:
51
+ p_tokens = sb.decoders.ctc_greedy_decode(
52
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
53
+ )
54
+ return p_ctc, wav_lens, p_tokens
55
+
56
+ def compute_objectives(self, predictions, batch, stage):
57
+ """Computes the loss (CTC+NLL) given predictions and targets."""
58
+
59
+ p_ctc, wav_lens, predicted_tokens = predictions
60
+
61
+ ids = batch.id
62
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
63
+ tokens, tokens_lens = batch.tokens
64
+
65
+ if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
66
+ tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
67
+ tokens_eos_lens = torch.cat(
68
+ [tokens_eos_lens, tokens_eos_lens], dim=0
69
+ )
70
+ tokens = torch.cat([tokens, tokens], dim=0)
71
+ tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
72
+
73
+ loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
74
+ loss = loss_ctc
75
+ if stage != sb.Stage.TRAIN:
76
+ # Decode token terms to words
77
+ predicted_words =[]
78
+ for logs in p_ctc:
79
+ text = decoder.decode(logs.detach().cpu().numpy())
80
+ predicted_words.append(text.split(" "))
81
+
82
+
83
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
84
+ self.wer_metric.append(ids, predicted_words, target_words)
85
+ self.cer_metric.append(ids, predicted_words, target_words)
86
+
87
+ return loss
88
+
89
+ def fit_batch(self, batch):
90
+ """Train the parameters given a single batch in input"""
91
+ predictions = self.compute_forward(batch, sb.Stage.TRAIN)
92
+ loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
93
+ loss.backward()
94
+ if self.check_gradients(loss):
95
+ self.wav2vec_optimizer.step()
96
+ self.model_optimizer.step()
97
+
98
+ self.wav2vec_optimizer.zero_grad()
99
+ self.model_optimizer.zero_grad()
100
+
101
+ return loss.detach()
102
+
103
+ def evaluate_batch(self, batch, stage):
104
+ """Computations needed for validation/test batches"""
105
+ predictions = self.compute_forward(batch, stage=stage)
106
+ with torch.no_grad():
107
+ loss = self.compute_objectives(predictions, batch, stage=stage)
108
+ return loss.detach()
109
+
110
+ def on_stage_start(self, stage, epoch):
111
+ """Gets called at the beginning of each epoch"""
112
+ if stage != sb.Stage.TRAIN:
113
+ self.cer_metric = self.hparams.cer_computer()
114
+ self.wer_metric = self.hparams.error_rate_computer()
115
+
116
+ def on_stage_end(self, stage, stage_loss, epoch):
117
+ """Gets called at the end of an epoch."""
118
+ # Compute/store important stats
119
+ stage_stats = {"loss": stage_loss}
120
+ if stage == sb.Stage.TRAIN:
121
+ self.train_stats = stage_stats
122
+ else:
123
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
124
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
125
+
126
+ # Perform end-of-iteration things, like annealing, logging, etc.
127
+ if stage == sb.Stage.VALID:
128
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
129
+ stage_stats["loss"]
130
+ )
131
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
132
+ stage_stats["loss"]
133
+ )
134
+ sb.nnet.schedulers.update_learning_rate(
135
+ self.model_optimizer, new_lr_model
136
+ )
137
+ sb.nnet.schedulers.update_learning_rate(
138
+ self.wav2vec_optimizer, new_lr_wav2vec
139
+ )
140
+ self.hparams.train_logger.log_stats(
141
+ stats_meta={
142
+ "epoch": epoch,
143
+ "lr_model": old_lr_model,
144
+ "lr_wav2vec": old_lr_wav2vec,
145
+ },
146
+ train_stats=self.train_stats,
147
+ valid_stats=stage_stats,
148
+ )
149
+ self.checkpointer.save_and_keep_only(
150
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
151
+ )
152
+ elif stage == sb.Stage.TEST:
153
+ self.hparams.train_logger.log_stats(
154
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
155
+ test_stats=stage_stats,
156
+ )
157
+ with open(self.hparams.wer_file, "w") as w:
158
+ self.wer_metric.write_stats(w)
159
+
160
+ def init_optimizers(self):
161
+ "Initializes the wav2vec2 optimizer and model optimizer"
162
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
163
+ self.modules.wav2vec2.parameters()
164
+ )
165
+ self.model_optimizer = self.hparams.model_opt_class(
166
+ self.hparams.model.parameters()
167
+ )
168
+
169
+ if self.checkpointer is not None:
170
+ self.checkpointer.add_recoverable(
171
+ "wav2vec_opt", self.wav2vec_optimizer
172
+ )
173
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
174
+
175
+
176
+ def dataio_prepare(hparams):
177
+ """This function prepares the datasets to be used in the brain class.
178
+ It also defines the data processing pipeline through user-defined functions."""
179
+ data_folder = hparams["data_folder"]
180
+
181
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
182
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
183
+ )
184
+
185
+ if hparams["sorting"] == "ascending":
186
+ # we sort training data to speed up training and get better results.
187
+ train_data = train_data.filtered_sorted(sort_key="duration")
188
+ # when sorting, do not shuffle in the dataloader; otherwise the sorting is pointless
189
+ hparams["train_dataloader_opts"]["shuffle"] = False
190
+
191
+ elif hparams["sorting"] == "descending":
192
+ train_data = train_data.filtered_sorted(
193
+ sort_key="duration", reverse=True
194
+ )
195
+ # when sorting, do not shuffle in the dataloader; otherwise the sorting is pointless
196
+ hparams["train_dataloader_opts"]["shuffle"] = False
197
+
198
+ elif hparams["sorting"] == "random":
199
+ pass
200
+
201
+ else:
202
+ raise NotImplementedError(
203
+ "sorting must be random, ascending or descending"
204
+ )
205
+
206
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
207
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
208
+ )
209
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
210
+
211
+ # test is separate
212
+ test_datasets = {}
213
+ for csv_file in hparams["test_csv"]:
214
+ name = Path(csv_file).stem
215
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
216
+ csv_path=csv_file, replacements={"data_root": data_folder}
217
+ )
218
+ test_datasets[name] = test_datasets[name].filtered_sorted(
219
+ sort_key="duration"
220
+ )
221
+
222
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
223
+
224
+ # 2. Define audio pipeline:
225
+ @sb.utils.data_pipeline.takes("wav", "sr")
226
+ @sb.utils.data_pipeline.provides("sig")
227
+ def audio_pipeline(wav, sr):
228
+ sig = sb.dataio.dataio.read_audio(wav)
229
+ sig = resamplers[sr](sig)
230
+ return sig
231
+
232
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
233
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
234
+
235
+ # 3. Define text pipeline:
236
+ @sb.utils.data_pipeline.takes("wrd")
237
+ @sb.utils.data_pipeline.provides(
238
+ "wrd", "char_list", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
239
+ )
240
+ def text_pipeline(wrd):
241
+ yield wrd
242
+ char_list = list(wrd)
243
+ yield char_list
244
+ tokens_list = label_encoder.encode_sequence(char_list)
245
+ yield tokens_list
246
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
247
+ yield tokens_bos
248
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
249
+ yield tokens_eos
250
+ tokens = torch.LongTensor(tokens_list)
251
+ yield tokens
252
+
253
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
254
+
255
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
256
+ special_labels = {
257
+ "bos_label": hparams["bos_index"],
258
+ "eos_label": hparams["eos_index"],
259
+ "blank_label": hparams["blank_index"],
260
+ }
261
+ label_encoder.load_or_create(
262
+ path=lab_enc_file,
263
+ from_didatasets=[train_data],
264
+ output_key="char_list",
265
+ special_labels=special_labels,
266
+ sequence_input=True,
267
+ )
268
+
269
+ # 4. Set output:
270
+ sb.dataio.dataset.set_output_keys(
271
+ datasets,
272
+ ["id", "sig", "wrd", "char_list", "tokens_bos", "tokens_eos", "tokens"],
273
+ )
274
+ return train_data, valid_data, test_datasets, label_encoder
275
+
276
+
277
+ if __name__ == "__main__":
278
+
279
+ # CLI:
280
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
281
+
282
+ # If distributed_launch=True then
283
+ # create ddp_group with the right communication protocol
284
+ sb.utils.distributed.ddp_init_group(run_opts)
285
+
286
+ with open(hparams_file) as fin:
287
+ hparams = load_hyperpyyaml(fin, overrides)
288
+
289
+ # Create experiment directory
290
+ sb.create_experiment_directory(
291
+ experiment_directory=hparams["output_folder"],
292
+ hyperparams_to_save=hparams_file,
293
+ overrides=overrides,
294
+ )
295
+ def read_labels_file(labels_file):
296
+ with open(labels_file, "r") as lf:
297
+ lines = lf.read().splitlines()
298
+ division = "==="
299
+ numbers = {}
300
+ for line in lines :
301
+ if division in line :
302
+ break
303
+ string, number = line.split("=>")
304
+ number = int(number)
305
+ string = string[1:-2]
306
+ numbers[number] = string
307
+ return [numbers[x] for x in range(len(numbers))]
308
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
309
+ print(labels)
310
+ labels = [""] + labels[1:]
311
+ print(len(labels))
312
+ decoder = build_ctcdecoder(
313
+ labels,
314
+ kenlm_model_path="tunisian.arpa", # either .arpa or .bin file
315
+ alpha=0.5, # tuned on a val set
316
+ beta=1.0, # tuned on a val set
317
+ )
318
+
319
+ # Dataset prep (parsing Librispeech)
320
+
321
+ resampler_8000 = T.Resample(8000, 16000, dtype=torch.float)
322
+
323
+ resampler_44100 =T.Resample(44100, 16000, dtype=torch.float)
324
+ resampler_48000 =T.Resample(48000, 16000, dtype=torch.float)
325
+ resamplers = {"8000": resampler_8000, "44100":resampler_44100, "48000": resampler_48000}
326
+
327
+ # here we create the datasets objects as well as tokenization and encoding
328
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
329
+ hparams
330
+ )
331
+
332
+ # Trainer initialization
333
+ asr_brain = ASR(
334
+ modules=hparams["modules"],
335
+ hparams=hparams,
336
+ run_opts=run_opts,
337
+ checkpointer=hparams["checkpointer"],
338
+ )
339
+ asr_brain.device= "cpu"
340
+ asr_brain.modules.to("cpu")
341
+ # We dynamically add the tokenizer to our brain class.
342
+ # NB: This tokenizer corresponds to the one used for the LM!!
343
+ asr_brain.tokenizer = label_encoder
344
+
345
+ # Training
346
+ asr_brain.fit(
347
+ asr_brain.hparams.epoch_counter,
348
+ train_data,
349
+ valid_data,
350
+ train_loader_kwargs=hparams["train_dataloader_opts"],
351
+ valid_loader_kwargs=hparams["valid_dataloader_opts"],
352
+ )
353
+
354
+ # Testing
355
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
356
+ asr_brain.hparams.wer_file = os.path.join(
357
+ hparams["output_folder"], "wer_{}.txt".format(k)
358
+ )
359
+ asr_brain.evaluate(
360
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"]
361
+ )
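Note on the decoder used above: lm_tunisian.py builds a pyctcdecode beam-search decoder over the character labels (index 0 replaced by "" for the CTC blank) plus a KenLM model, and calls decoder.decode() on each utterance's log-probability matrix. A minimal, self-contained sketch of that call, using a dummy uniform matrix instead of p_ctc and no language model (the real script passes tunisian.arpa and the labels read from label_encoder.txt):

import numpy as np
from pyctcdecode import build_ctcdecoder

labels = ["", " ", "a", "b"]                           # toy vocabulary; index 0 is the CTC blank
decoder = build_ctcdecoder(labels)                     # no KenLM model in this sketch
log_probs = np.log(np.full((10, len(labels)), 0.25))   # dummy (time, vocab) matrix
print(decoder.decode(log_probs))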
partly_frozen_splitted_wavlm/1986/ctc_train.py ADDED
@@ -0,0 +1,339 @@
1
+ #!/usr/bin/env/python3
2
+ """Recipe for training a wav2vec-based ctc ASR system with librispeech.
3
+ The system employs wav2vec as its encoder. Decoding is performed with
4
+ ctc greedy decoder.
5
+ To run this recipe, do the following:
6
+ > python train_with_wav2vec.py hparams/train_with_wav2vec.yaml
7
+ The neural network is trained on CTC likelihood target and character units
8
+ are used as basic recognition tokens. Training is performed on the full
9
+ LibriSpeech dataset (960 h).
10
+
11
+ Authors
12
+ * Sung-Lin Yeh 2021
13
+ * Titouan Parcollet 2021
14
+ * Ju-Chieh Chou 2020
15
+ * Mirco Ravanelli 2020
16
+ * Abdel Heba 2020
17
+ * Peter Plantinga 2020
18
+ * Samuele Cornell 2020
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import torch
24
+ import logging
25
+ import speechbrain as sb
26
+ from speechbrain.utils.distributed import run_on_main
27
+ from hyperpyyaml import load_hyperpyyaml
28
+ from pathlib import Path
29
+ import torchaudio.transforms as T
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Define training procedure
33
+ class ASR(sb.Brain):
34
+ def compute_forward(self, batch, stage):
35
+ """Forward computations from the waveform batches to the output probabilities."""
36
+ batch = batch.to(self.device)
37
+ wavs, wav_lens = batch.sig
38
+ tokens_bos, _ = batch.tokens_bos
39
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
40
+
41
+ # Forward pass
42
+ feats = self.modules.wav2vec2(wavs)
43
+ x = self.modules.enc(feats)
44
+ # Compute outputs
45
+ p_tokens = None
46
+ logits = self.modules.ctc_lin(x)
47
+ p_ctc = self.hparams.log_softmax(logits)
48
+ if stage != sb.Stage.TRAIN:
49
+ p_tokens = sb.decoders.ctc_greedy_decode(
50
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
51
+ )
52
+ return p_ctc, wav_lens, p_tokens
53
+
54
+ def compute_objectives(self, predictions, batch, stage):
55
+ """Computes the loss (CTC+NLL) given predictions and targets."""
56
+
57
+ p_ctc, wav_lens, predicted_tokens = predictions
58
+
59
+ ids = batch.id
60
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
61
+ tokens, tokens_lens = batch.tokens
62
+
63
+ if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
64
+ tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
65
+ tokens_eos_lens = torch.cat(
66
+ [tokens_eos_lens, tokens_eos_lens], dim=0
67
+ )
68
+ tokens = torch.cat([tokens, tokens], dim=0)
69
+ tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
70
+
71
+ loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
72
+ loss = loss_ctc
73
+
74
+ if stage != sb.Stage.TRAIN:
75
+ # Decode token terms to words
76
+ predicted_words = [
77
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
78
+ for utt_seq in predicted_tokens
79
+ ]
80
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
81
+ self.wer_metric.append(ids, predicted_words, target_words)
82
+ self.cer_metric.append(ids, predicted_words, target_words)
83
+
84
+ return loss
85
+
86
+ def fit_batch(self, batch):
87
+ """Train the parameters given a single batch in input"""
88
+ predictions = self.compute_forward(batch, sb.Stage.TRAIN)
89
+ loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
90
+ loss.backward()
91
+ if self.check_gradients(loss):
92
+ self.wav2vec_optimizer.step()
93
+ self.model_optimizer.step()
94
+
95
+ self.wav2vec_optimizer.zero_grad()
96
+ self.model_optimizer.zero_grad()
97
+
98
+ return loss.detach()
99
+
100
+ def evaluate_batch(self, batch, stage):
101
+ """Computations needed for validation/test batches"""
102
+ predictions = self.compute_forward(batch, stage=stage)
103
+ with torch.no_grad():
104
+ loss = self.compute_objectives(predictions, batch, stage=stage)
105
+ return loss.detach()
106
+
107
+ def on_stage_start(self, stage, epoch):
108
+ """Gets called at the beginning of each epoch"""
109
+ if stage != sb.Stage.TRAIN:
110
+ self.cer_metric = self.hparams.cer_computer()
111
+ self.wer_metric = self.hparams.error_rate_computer()
112
+
113
+ def on_stage_end(self, stage, stage_loss, epoch):
114
+ """Gets called at the end of an epoch."""
115
+ # Compute/store important stats
116
+ stage_stats = {"loss": stage_loss}
117
+ if stage == sb.Stage.TRAIN:
118
+ self.train_stats = stage_stats
119
+ else:
120
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
121
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
122
+
123
+ # Perform end-of-iteration things, like annealing, logging, etc.
124
+ if stage == sb.Stage.VALID:
125
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
126
+ stage_stats["loss"]
127
+ )
128
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
129
+ stage_stats["loss"]
130
+ )
131
+ sb.nnet.schedulers.update_learning_rate(
132
+ self.model_optimizer, new_lr_model
133
+ )
134
+ sb.nnet.schedulers.update_learning_rate(
135
+ self.wav2vec_optimizer, new_lr_wav2vec
136
+ )
137
+ self.hparams.train_logger.log_stats(
138
+ stats_meta={
139
+ "epoch": epoch,
140
+ "lr_model": old_lr_model,
141
+ "lr_wav2vec": old_lr_wav2vec,
142
+ },
143
+ train_stats=self.train_stats,
144
+ valid_stats=stage_stats,
145
+ )
146
+ self.checkpointer.save_and_keep_only(
147
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
148
+ )
149
+ elif stage == sb.Stage.TEST:
150
+ self.hparams.train_logger.log_stats(
151
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
152
+ test_stats=stage_stats,
153
+ )
154
+ with open(self.hparams.wer_file, "w") as w:
155
+ self.wer_metric.write_stats(w)
156
+
157
+ def init_optimizers(self):
158
+ "Initializes the wav2vec2 optimizer and model optimizer"
159
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
160
+ self.modules.wav2vec2.parameters()
161
+ )
162
+ self.model_optimizer = self.hparams.model_opt_class(
163
+ self.hparams.model.parameters()
164
+ )
165
+
166
+ if self.checkpointer is not None:
167
+ self.checkpointer.add_recoverable(
168
+ "wav2vec_opt", self.wav2vec_optimizer
169
+ )
170
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
171
+
172
+
173
+ def dataio_prepare(hparams):
174
+ """This function prepares the datasets to be used in the brain class.
175
+ It also defines the data processing pipeline through user-defined functions."""
176
+ data_folder = hparams["data_folder"]
177
+
178
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
179
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
180
+ )
181
+
182
+ if hparams["sorting"] == "ascending":
183
+ # we sort training data to speed up training and get better results.
184
+ train_data = train_data.filtered_sorted(sort_key="duration")
185
+ # when sorting, do not shuffle in the dataloader; otherwise the sorting is pointless
186
+ hparams["train_dataloader_opts"]["shuffle"] = False
187
+
188
+ elif hparams["sorting"] == "descending":
189
+ train_data = train_data.filtered_sorted(
190
+ sort_key="duration", reverse=True
191
+ )
192
+ # when sorting, do not shuffle in the dataloader; otherwise the sorting is pointless
193
+ hparams["train_dataloader_opts"]["shuffle"] = False
194
+
195
+ elif hparams["sorting"] == "random":
196
+ pass
197
+
198
+ else:
199
+ raise NotImplementedError(
200
+ "sorting must be random, ascending or descending"
201
+ )
202
+
203
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
204
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
205
+ )
206
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
207
+
208
+ # test is separate
209
+ test_datasets = {}
210
+ for csv_file in hparams["test_csv"]:
211
+ name = Path(csv_file).stem
212
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
213
+ csv_path=csv_file, replacements={"data_root": data_folder}
214
+ )
215
+ test_datasets[name] = test_datasets[name].filtered_sorted(
216
+ sort_key="duration"
217
+ )
218
+
219
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
220
+
221
+ # 2. Define audio pipeline:
222
+ @sb.utils.data_pipeline.takes("wav", "sr")
223
+ @sb.utils.data_pipeline.provides("sig")
224
+ def audio_pipeline(wav, sr):
225
+ sig = sb.dataio.dataio.read_audio(wav)
226
+ sig = resamplers[sr](sig)
227
+ return sig
228
+
229
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
230
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
231
+
232
+ # 3. Define text pipeline:
233
+ @sb.utils.data_pipeline.takes("wrd")
234
+ @sb.utils.data_pipeline.provides(
235
+ "wrd", "char_list", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
236
+ )
237
+ def text_pipeline(wrd):
238
+ yield wrd
239
+ char_list = list(wrd)
240
+ yield char_list
241
+ tokens_list = label_encoder.encode_sequence(char_list)
242
+ yield tokens_list
243
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
244
+ yield tokens_bos
245
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
246
+ yield tokens_eos
247
+ tokens = torch.LongTensor(tokens_list)
248
+ yield tokens
249
+
250
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
251
+
252
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
253
+ special_labels = {
254
+ "bos_label": hparams["bos_index"],
255
+ "eos_label": hparams["eos_index"],
256
+ "blank_label": hparams["blank_index"],
257
+ }
258
+ label_encoder.load_or_create(
259
+ path=lab_enc_file,
260
+ from_didatasets=[train_data],
261
+ output_key="char_list",
262
+ special_labels=special_labels,
263
+ sequence_input=True,
264
+ )
265
+
266
+ # 4. Set output:
267
+ sb.dataio.dataset.set_output_keys(
268
+ datasets,
269
+ ["id", "sig", "wrd", "char_list", "tokens_bos", "tokens_eos", "tokens"],
270
+ )
271
+ return train_data, valid_data, test_datasets, label_encoder
272
+
273
+
274
+ if __name__ == "__main__":
275
+
276
+ # CLI:
277
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
278
+
279
+ # If distributed_launch=True then
280
+ # create ddp_group with the right communication protocol
281
+ sb.utils.distributed.ddp_init_group(run_opts)
282
+
283
+ with open(hparams_file) as fin:
284
+ hparams = load_hyperpyyaml(fin, overrides)
285
+
286
+ # Create experiment directory
287
+ sb.create_experiment_directory(
288
+ experiment_directory=hparams["output_folder"],
289
+ hyperparams_to_save=hparams_file,
290
+ overrides=overrides,
291
+ )
292
+
293
+ # Dataset prep (parsing Librispeech)
294
+
295
+ resampler_8000 = T.Resample(8000, 16000, dtype=torch.float)
296
+
297
+ resampler_44100 =T.Resample(44100, 16000, dtype=torch.float)
298
+ resampler_32000 =T.Resample(32000, 16000, dtype=torch.float)
299
+ resampler_48000 =T.Resample(48000, 16000, dtype=torch.float)
300
+
301
+
302
+ resamplers = {"48000": resampler_48000,"8000": resampler_8000, "44100":resampler_44100, "32000":resampler_32000}
303
+
304
+ # here we create the datasets objects as well as tokenization and encoding
305
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
306
+ hparams
307
+ )
308
+
309
+ # Trainer initialization
310
+ asr_brain = ASR(
311
+ modules=hparams["modules"],
312
+ hparams=hparams,
313
+ run_opts=run_opts,
314
+ checkpointer=hparams["checkpointer"],
315
+ )
316
+ asr_brain.device= "cpu"
317
+ asr_brain.modules.to("cpu")
318
+
319
+ # We dynamically add the tokenizer to our brain class.
320
+ # NB: This tokenizer corresponds to the one used for the LM!!
321
+ asr_brain.tokenizer = label_encoder
322
+
323
+ # Training
324
+ asr_brain.fit(
325
+ asr_brain.hparams.epoch_counter,
326
+ train_data,
327
+ valid_data,
328
+ train_loader_kwargs=hparams["train_dataloader_opts"],
329
+ valid_loader_kwargs=hparams["valid_dataloader_opts"],
330
+ )
331
+
332
+ # Testing
333
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
334
+ asr_brain.hparams.wer_file = os.path.join(
335
+ hparams["output_folder"], "wer_{}.txt".format(k)
336
+ )
337
+ asr_brain.evaluate(
338
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"]
339
+ )
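Note on the audio pipeline above: ctc_train.py looks up a torchaudio Resample transform keyed by the CSV's sr column, which is stored as a string. A minimal sketch of that lookup with a fake one-second 8 kHz signal:

import torch
import torchaudio.transforms as T

resamplers = {"8000": T.Resample(8000, 16000, dtype=torch.float)}  # string keys, as in the script
sig = torch.randn(8000)               # one second of fake 8 kHz audio
sig_16k = resamplers["8000"](sig)
print(sig_16k.shape)                  # torch.Size([16000])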
partly_frozen_splitted_wavlm/1986/env.log ADDED
@@ -0,0 +1,402 @@
1
+ SpeechBrain system description
2
+ ==============================
3
+ Python version:
4
+ 3.8.5 (default, Sep 4 2020, 07:30:14)
5
+ [GCC 7.3.0]
6
+ ==============================
7
+ Installed Python packages:
8
+ abkhazia==1.0
9
+ absl-py==0.11.0
10
+ aiohttp==3.8.0
11
+ aiosignal==1.2.0
12
+ alabaster==0.7.12
13
+ alembic==1.7.4
14
+ altair==4.2.0
15
+ altgraph==0.17
16
+ antlr4-python3-runtime==4.8
17
+ anyio==3.6.2
18
+ appdirs==1.4.4
19
+ argcomplete==1.12.2
20
+ argon2-cffi==20.1.0
21
+ asgiref==3.6.0
22
+ astunparse==1.6.3
23
+ async-generator==1.10
24
+ async-timeout==4.0.0
25
+ attrdict==2.0.1
26
+ attrs==20.3.0
27
+ audeer==1.16.0
28
+ audformat==0.11.5
29
+ audinterface==0.7.0
30
+ audiofile==1.0.0
31
+ audiomentations==0.25.0
32
+ audioread==2.1.9
33
+ audobject==0.4.14
34
+ audresample==0.1.6
35
+ -e git+https://github.com/facebookresearch/WavAugment.git@54afcdb00ccc852c2f030f239f8532c9562b550e#egg=augment
36
+ autopage==0.4.0
37
+ Babel==2.9.0
38
+ backcall==0.2.0
39
+ beautifulsoup4==4.10.0
40
+ black==19.10b0
41
+ bleach==3.3.0
42
+ boto3==1.20.2
43
+ botocore==1.23.2
44
+ braceexpand==0.1.7
45
+ cachetools==4.2.0
46
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
47
+ cffi==1.14.3
48
+ cfgv==3.2.0
49
+ chardet==3.0.4
50
+ charset-normalizer==2.0.7
51
+ click==7.1.2
52
+ cliff==3.9.0
53
+ clldutils==3.5.4
54
+ cmaes==0.8.2
55
+ cmake==3.18.4.post1
56
+ cmd2==2.2.0
57
+ colorama==0.4.4
58
+ colorlog==4.6.2
59
+ configparser==5.1.0
60
+ cryptography==38.0.4
61
+ csvw==1.8.1
62
+ cycler==0.10.0
63
+ Cython==0.29.21
64
+ dataclasses==0.6
65
+ datasets==1.5.0
66
+ decorator==4.4.2
67
+ deepspeech==0.9.1
68
+ defusedxml==0.7.1
69
+ denoiser==0.1.5
70
+ dill==0.3.3
71
+ Distance==0.1.3
72
+ distlib==0.3.1
73
+ Django==3.2.16
74
+ django-auditlog==2.2.1
75
+ django-filter==22.1
76
+ django-js-asset==1.2.2
77
+ django-mptt==0.14.0
78
+ djangorestframework==3.14.0
79
+ docker-pycreds==0.4.0
80
+ docopt==0.6.2
81
+ docutils==0.16
82
+ drf-excel==2.2.0
83
+ drf-flex-fields==1.0.0
84
+ drf-renderer-xlsx==0.4.1
85
+ easyocr==1.2.1
86
+ editdistance==0.6.0
87
+ emoji==2.2.0
88
+ entrypoints==0.3
89
+ et-xmlfile==1.1.0
90
+ exceptiongroup==1.1.0
91
+ farasapy==0.0.14
92
+ fastapi==0.89.0
93
+ fasttext==0.9.2
94
+ ffmpeg-python==0.2.0
95
+ ffmpy==0.3.0
96
+ filelock==3.0.12
97
+ flake8==3.7.9
98
+ flatbuffers==1.12
99
+ frozendict==2.0.7
100
+ frozenlist==1.2.0
101
+ fsspec==2021.11.0
102
+ future==0.18.2
103
+ g2p-en==2.1.0
104
+ gast==0.3.3
105
+ gdown==4.2.0
106
+ gensim==4.0.1
107
+ gitdb==4.0.9
108
+ GitPython==3.1.24
109
+ google-auth==1.24.0
110
+ google-auth-oauthlib==0.4.2
111
+ google-pasta==0.2.0
112
+ gradio==3.16.0
113
+ greenlet==1.1.2
114
+ grpcio==1.32.0
115
+ h11==0.14.0
116
+ h5features==1.3.2
117
+ h5py==2.10.0
118
+ htk-io==0.5
119
+ httpcore==0.16.3
120
+ httpx==0.23.3
121
+ huggingface-hub==0.9.1
122
+ hydra-colorlog==0.1.4
123
+ hydra-core==0.11.3
124
+ HyperPyYAML==1.1.0
125
+ hypothesis==6.61.2
126
+ identify==1.5.10
127
+ idna==2.10
128
+ imageio==2.9.0
129
+ imagesize==1.2.0
130
+ importlib-metadata==4.8.1
131
+ importlib-resources==5.2.2
132
+ inflect==5.3.0
133
+ ipadic==1.0.0
134
+ ipykernel==5.3.4
135
+ ipython==7.19.0
136
+ ipython-genutils==0.2.0
137
+ ipywebrtc==0.6.0
138
+ ipywidgets==7.6.3
139
+ iso-639==0.4.5
140
+ isodate==0.6.0
141
+ isort==4.3.21
142
+ jedi==0.17.2
143
+ jieba==0.42.1
144
+ Jinja2==2.11.2
145
+ jiwer==2.2.0
146
+ jmespath==0.10.0
147
+ joblib==0.17.0
148
+ jsonschema==3.2.0
149
+ julius==0.2.7
150
+ jupyter-client==6.1.7
151
+ jupyter-core==4.7.0
152
+ jupyterlab-pygments==0.1.2
153
+ jupyterlab-widgets==1.0.0
154
+ kaitaistruct==0.9
155
+ kaldi-io==0.9.4
156
+ kaldi-python-io==1.2.2
157
+ kaldiio==2.17.2
158
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip
159
+ Keras-Preprocessing==1.1.2
160
+ kiwisolver==1.3.1
161
+ lang-trans==0.6.0
162
+ latexcodec==2.0.1
163
+ ldap3==2.9.1
164
+ librosa==0.9.0
165
+ linkify-it-py==1.0.3
166
+ llvmlite==0.35.0
167
+ lxml==4.9.0
168
+ Mako==1.1.5
169
+ Markdown==3.3.3
170
+ markdown-it-py==2.1.0
171
+ MarkupSafe==1.1.1
172
+ marshmallow==3.14.0
173
+ matplotlib==3.3.3
174
+ mccabe==0.6.1
175
+ mcd==0.4
176
+ mdit-py-plugins==0.3.3
177
+ mdurl==0.1.2
178
+ mecab-python3==1.0.3
179
+ megatron-lm==2.2.0
180
+ mido==1.2.10
181
+ mistune==0.8.4
182
+ more-itertools==8.6.0
183
+ mpmath==1.2.1
184
+ multidict==5.2.0
185
+ multiprocess==0.70.11.1
186
+ nbclient==0.5.3
187
+ nbconvert==6.0.7
188
+ nbformat==5.1.3
189
+ NEMO==4.3.2
190
+ nemo-toolkit==1.4.0
191
+ nest-asyncio==1.5.1
192
+ networkx==2.5
193
+ nltk==3.5
194
+ nodeenv==1.5.0
195
+ notebook==6.3.0
196
+ numba==0.52.0
197
+ numpy==1.19.4
198
+ nvidia-cublas-cu11==11.10.3.66
199
+ nvidia-cuda-nvrtc-cu11==11.7.99
200
+ nvidia-cuda-runtime-cu11==11.7.99
201
+ nvidia-cudnn-cu11==8.5.0.96
202
+ oauthlib==3.1.0
203
+ omegaconf==1.4.1
204
+ onnx==1.10.2
205
+ OpenCC==1.1.2
206
+ opencv-python==4.4.0.46
207
+ openpyxl==3.0.9
208
+ opensmile==2.2.0
209
+ opt-einsum==3.3.0
210
+ optuna==2.10.0
211
+ orjson==3.8.4
212
+ oyaml==1.0
213
+ packaging==22.0
214
+ pandas==1.2.5
215
+ pandocfilters==1.4.3
216
+ pangu==4.0.6.1
217
+ parameterized==0.8.1
218
+ parso==0.7.1
219
+ pathspec==0.8.1
220
+ pathtools==0.1.2
221
+ pbr==5.6.0
222
+ pefile==2019.4.18
223
+ pescador==2.1.0
224
+ pesq==0.0.3
225
+ pexpect==4.8.0
226
+ phonemizer==2.2.1
227
+ pickleshare==0.7.5
228
+ Pillow==9.3.0
229
+ pip-api==0.0.23
230
+ pipreqs==0.4.11
231
+ pluggy==0.13.1
232
+ pooch==1.3.0
233
+ portalocker==2.3.2
234
+ pre-commit==2.9.0
235
+ pretty-midi==0.2.9
236
+ prettytable==2.2.1
237
+ progressbar2==3.53.1
238
+ prometheus-client==0.10.1
239
+ promise==2.3
240
+ prompt-toolkit==3.0.8
241
+ protobuf==3.14.0
242
+ psutil==5.6.6
243
+ ptyprocess==0.6.0
244
+ py==1.9.0
245
+ py-espeak-ng==0.1.8
246
+ pyannote.audio==1.1.1
247
+ pyannote.core==4.3
248
+ pyannote.database==4.1.1
249
+ pyannote.metrics==3.1
250
+ pyannote.pipeline==1.5.2
251
+ PyArabic==0.6.15
252
+ pyarrow==3.0.0
253
+ pyasn1==0.4.8
254
+ pyasn1-modules==0.2.8
255
+ pybind11==2.8.1
256
+ pybtex==0.24.0
257
+ pybtex-docutils==1.0.1
258
+ pycodestyle==2.5.0
259
+ pycparser==2.20
260
+ pycryptodome==3.16.0
261
+ pyctcdecode==0.4.0
262
+ pydantic==1.10.4
263
+ pyDeprecate==0.3.1
264
+ pydub==0.25.1
265
+ pyflakes==2.1.1
266
+ Pygments==2.7.2
267
+ pygtrie==2.5.0
268
+ pymodbus==2.5.3
269
+ pyparsing==2.4.7
270
+ pyperclip==1.8.2
271
+ pypinyin==0.43.0
272
+ pyrsistent==0.17.3
273
+ pyserial==3.5
274
+ PySocks==1.7.1
275
+ pystoi==0.3.3
276
+ pytest==5.4.1
277
+ pytest-runner==5.3.1
278
+ python-bidi==0.4.2
279
+ python-crfsuite==0.9.7
280
+ python-dateutil==2.8.2
281
+ python-Levenshtein==0.12.2
282
+ python-multipart==0.0.5
283
+ python-utils==2.4.0
284
+ pytorch-lightning==1.4.9
285
+ pytube==11.0.1
286
+ pytz==2022.6
287
+ PyWavelets==1.1.1
288
+ PyYAML==5.3.1
289
+ pyzmq==20.0.0
290
+ rapidfuzz==1.8.2
291
+ regex==2020.11.13
292
+ requests==2.28.1
293
+ requests-oauthlib==1.3.0
294
+ resampy==0.2.2
295
+ rfc3986==1.4.0
296
+ rsa==4.7
297
+ ruamel.yaml==0.17.21
298
+ ruamel.yaml.clib==0.2.7
299
+ s3m==1.1.0
300
+ s3transfer==0.5.0
301
+ sacrebleu==2.0.0
302
+ sacremoses==0.0.44
303
+ scikit-image==0.18.1
304
+ scikit-learn==0.23.2
305
+ scipy==1.5.4
306
+ -e git+https://github.com/sanghack81/SDCIT@00d060dde733fde9345154a494f81e97fb395ca7#egg=SDCIT
307
+ seaborn==0.11.1
308
+ segments==2.1.3
309
+ Send2Trash==1.5.0
310
+ sentencepiece==0.1.94
311
+ sentry-sdk==1.4.3
312
+ shellingham==1.4.0
313
+ shortuuid==1.0.7
314
+ SIDEKIT==1.3.8.5.2
315
+ simplejson==3.17.5
316
+ six==1.15.0
317
+ smart-open==5.0.0
318
+ smmap==5.0.0
319
+ sniffio==1.3.0
320
+ snowballstemmer==2.0.0
321
+ sortedcollections==2.1.0
322
+ sortedcontainers==2.4.0
323
+ sounddevice==0.4.5
324
+ SoundFile==0.10.3.post1
325
+ soupsieve==2.3
326
+ sox==1.4.1
327
+ sparsemax==0.1.9
328
+ speechbrain==0.5.13
329
+ sphfile==1.0.3
330
+ Sphinx==3.3.1
331
+ sphinx-rtd-theme==0.4.3
332
+ sphinxcontrib-applehelp==1.0.2
333
+ sphinxcontrib-bibtex==2.4.1
334
+ sphinxcontrib-devhelp==1.0.2
335
+ sphinxcontrib-htmlhelp==1.0.3
336
+ sphinxcontrib-jsmath==1.0.1
337
+ sphinxcontrib-qthelp==1.0.3
338
+ sphinxcontrib-serializinghtml==1.1.4
339
+ SQLAlchemy==1.4.25
340
+ sqlparse==0.4.2
341
+ stanza==1.4.2
342
+ starlette==0.22.0
343
+ stevedore==3.4.0
344
+ subprocess32==3.5.4
345
+ sympy==1.9
346
+ tabulate==0.8.9
347
+ tensorboard==2.4.0
348
+ tensorboard-plugin-wit==1.7.0
349
+ tensorflow==2.4.0
350
+ tensorflow-estimator==2.4.0
351
+ termcolor==1.1.0
352
+ terminado==0.9.4
353
+ testpath==0.4.4
354
+ threadpoolctl==2.1.0
355
+ tifffile==2020.12.8
356
+ tikzplotlib==0.9.8
357
+ tkseem==0.0.3
358
+ tokenizers==0.10.2
359
+ toml==0.10.2
360
+ toolz==0.12.0
361
+ torch==1.13.1
362
+ torch-stft==0.1.4
363
+ torchaudio==0.13.1
364
+ torchmetrics==0.6.0
365
+ torchvision==0.14.1
366
+ tornado==6.1
367
+ tqdm==4.61.1
368
+ trackrip==1.2.1
369
+ traitlets==5.0.5
370
+ transformers==4.15.0
371
+ typed-ast==1.4.1
372
+ typer==0.4.0
373
+ typing-extensions==4.4.0
374
+ uc-micro-py==1.0.1
375
+ Unidecode==1.3.2
376
+ uritemplate==3.0.1
377
+ urllib3==1.26.2
378
+ uvicorn==0.20.0
379
+ virtualenv==20.2.1
380
+ wandb==0.12.6
381
+ wcwidth==0.2.5
382
+ webdataset==0.1.62
383
+ webencodings==0.5.1
384
+ websockets==10.4
385
+ Werkzeug==1.0.1
386
+ wget==3.2
387
+ widgetsnbextension==3.5.1
388
+ wordninja==2.0.0
389
+ wrapt==1.12.1
390
+ xmltodict==0.13.0
391
+ xxhash==2.0.0
392
+ yamllint==1.23.0
393
+ yarg==0.1.9
394
+ yarl==1.7.2
395
+ yaspin==2.1.0
396
+ youtokentome==1.0.6
397
+ youtube-dl==2021.6.6
398
+ zipp==3.6.0
399
+ ==============================
400
+ Could not get git revision==============================
401
+ CUDA version:
402
+ 11.7
partly_frozen_splitted_wavlm/1986/hyperparams.yaml ADDED
@@ -0,0 +1,162 @@
1
+ # Generated 2023-01-08 from:
2
+ # /home/salah/kenlm_train/to_copy/wavlm_partly_frozen.yaml
3
+ # yamllint disable
4
+ # ################################
5
+ # Model: wav2vec2 + DNN + CTC
6
+ # Augmentation: SpecAugment
7
+ # Authors: Sung-Lin Yeh 2021
8
+ # ################################
9
+
10
+ # Seed needs to be set at top of yaml, before objects with parameters are made
11
+ seed: 1986
12
+ __set_seed: !apply:torch.manual_seed [1986]
13
+ output_folder: partly_frozen_splitted_wavlm/1986/
14
+ wer_file: partly_frozen_splitted_wavlm/1986//wer.txt
15
+ save_folder: partly_frozen_splitted_wavlm/1986//save
16
+ train_log: partly_frozen_splitted_wavlm/1986//train_log.txt
17
+
18
+ # URL for the biggest Fairseq english wav2vec2 model.
19
+
20
+ # Data files
21
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/Libri/LibriSpeech/ # e.g. /path/to/LibriSpeech
22
+ # noise/ris dataset will automatically be downloaded
23
+ data_folder_rirs: /gpfsscratch/rech/nou/uzn19yk/Libri/LibriSpeech/
24
+ train_splits: [train-clean-100]
25
+ dev_splits: [dev-clean]
26
+ test_splits: [test-clean, test-other]
27
+ skip_prep: false
28
+ ckpt_interval_minutes: 25 # save checkpoint every N min
29
+ csv_folder: /gpfsstore/rech/nou/uzn19yk/iwslt/splitted_clean_tunisian_csvs/
30
+ train_csv: test_salah_local.csv
31
+ valid_csv: test_salah_local.csv
32
+ test_csv:
33
+ - test_salah_local.csv
34
+
35
+ # Training parameters
36
+ number_of_epochs: 12
37
+ lr: 1
38
+ lr_wav2vec: 0.0001
39
+ sorting: ascending
40
+ auto_mix_prec: false
41
+ sample_rate: 16000
42
+
43
+ avoid_if_longer_than: 10
44
+ # With data_parallel batch_size is split into N jobs
45
+ # With DDP batch_size is multiplied by N jobs
46
+ # Must be 3 per GPU to fit 32GB of VRAM
47
+ batch_size: 1
48
+ test_batch_size: 1
49
+
50
+ # Dataloader options
51
+ train_dataloader_opts:
52
+ batch_size: 1
53
+
54
+ valid_dataloader_opts:
55
+ batch_size: 1
56
+
57
+ test_dataloader_opts:
58
+ batch_size: 1
59
+
60
+ # Model parameters
61
+ activation: &id001 !name:torch.nn.LeakyReLU
62
+ dnn_layers: 2
63
+ dnn_neurons: 1024
64
+ freeze_wav2vec: false
65
+
66
+ # Outputs
67
+ output_neurons: 41 # BPE size, index(blank/eos/bos) = 0
68
+
69
+ # Decoding parameters
70
+ blank_index: 0
71
+ bos_index: 1
72
+ eos_index: 2
73
+
74
+ #
75
+ # Functions and classes
76
+ #
77
+ epoch_counter: &id008 !new:speechbrain.utils.epoch_loop.EpochCounter
78
+
79
+ limit: 12
80
+
81
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
82
+ sample_rate: 16000
83
+ speeds: [95, 100, 105]
84
+
85
+ enc: &id003 !new:speechbrain.lobes.models.VanillaNN.VanillaNN
86
+ input_shape: [null, null, 1024]
87
+ activation: *id001
88
+ dnn_blocks: 2
89
+ dnn_neurons: 1024
90
+
91
+ wav2vec2: &id002 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
92
+ source: wavlm-large/
93
+ output_norm: true
94
+ freeze: false
95
+ freeze_feature_extractor: true
96
+ save_path: partly_frozen_splitted_wavlm/1986//save/wav2vec2_hubert_checkpoint
97
+
98
+ #####
99
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
100
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
101
+ # Fairseq github for the multilingual XLSR.
102
+ #
103
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt
104
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
105
+ # pretrained_path: !ref <wav2vec2_url>
106
+ # output_norm: True
107
+ # freeze: False
108
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
109
+
110
+ ctc_lin: &id004 !new:speechbrain.nnet.linear.Linear
111
+
112
+ input_size: 1024
113
+ n_neurons: 41
114
+
115
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
116
+ apply_log: true
117
+
118
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
119
+ blank_index: 0
120
+
121
+ modules:
122
+ wav2vec2: *id002
123
+ enc: *id003
124
+ ctc_lin: *id004
125
+ model: &id005 !new:torch.nn.ModuleList
126
+ - [*id003, *id004]
127
+ model_opt_class: !name:torch.optim.Adadelta
128
+ lr: 1
129
+ rho: 0.95
130
+ eps: 1.e-8
131
+
132
+ wav2vec_opt_class: !name:torch.optim.Adam
133
+ lr: 0.0001
134
+
135
+ lr_annealing_model: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
136
+ initial_value: 1
137
+ improvement_threshold: 0.0025
138
+ annealing_factor: 0.8
139
+ patient: 0
140
+
141
+ lr_annealing_wav2vec: &id007 !new:speechbrain.nnet.schedulers.NewBobScheduler
142
+ initial_value: 0.0001
143
+ improvement_threshold: 0.0025
144
+ annealing_factor: 0.9
145
+ patient: 0
146
+
147
+
148
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
149
+ checkpoints_dir: partly_frozen_splitted_wavlm/1986//save
150
+ recoverables:
151
+ wav2vec2: *id002
152
+ model: *id005
153
+ scheduler_model: *id006
154
+ scheduler_wav2vec: *id007
155
+ counter: *id008
156
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
157
+ save_file: partly_frozen_splitted_wavlm/1986//train_log.txt
158
+
159
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
160
+
161
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
162
+ split_tokens: true
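Note on the YAML above: it is loaded with hyperpyyaml, so &id/*id anchors share objects and the !name:/!ref tags resolve to callables and cross-references. A tiny illustration of those tags on a made-up snippet (not from the repo); the training scripts load hyperparams.yaml itself via load_hyperpyyaml together with CLI overrides:

from hyperpyyaml import load_hyperpyyaml

yaml_string = """
lr_wav2vec: 0.0001
wav2vec_opt_class: !name:torch.optim.Adam
    lr: !ref <lr_wav2vec>
"""
hparams = load_hyperpyyaml(yaml_string)
print(hparams["wav2vec_opt_class"])   # a partial of torch.optim.Adam with lr pre-bound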
partly_frozen_splitted_wavlm/1986/lm_decoded_ctc.py ADDED
@@ -0,0 +1,377 @@
1
+ #!/usr/bin/env/python3
2
+ """Recipe for training a wav2vec-based ctc ASR system with librispeech.
3
+ The system employs wav2vec as its encoder. Decoding is performed with
4
+ ctc greedy decoder.
5
+ To run this recipe, do the following:
6
+ > python train_with_wav2vec.py hparams/train_with_wav2vec.yaml
7
+ The neural network is trained on CTC likelihood target and character units
8
+ are used as basic recognition tokens. Training is performed on the full
9
+ LibriSpeech dataset (960 h).
10
+
11
+ Authors
12
+ * Sung-Lin Yeh 2021
13
+ * Titouan Parcollet 2021
14
+ * Ju-Chieh Chou 2020
15
+ * Mirco Ravanelli 2020
16
+ * Abdel Heba 2020
17
+ * Peter Plantinga 2020
18
+ * Samuele Cornell 2020
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import torch
24
+ import logging
25
+ import speechbrain as sb
26
+ from speechbrain.utils.distributed import run_on_main
27
+ from hyperpyyaml import load_hyperpyyaml
28
+ from pathlib import Path
29
+ from pyctcdecode import build_ctcdecoder
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ # Define training procedure
35
+ class ASR(sb.Brain):
36
+ def compute_forward(self, batch, stage):
37
+ """Forward computations from the waveform batches to the output probabilities."""
38
+ batch = batch.to(self.device)
39
+ wavs, wav_lens = batch.sig
40
+ tokens_bos, _ = batch.tokens_bos
41
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
42
+
43
+ # Forward pass
44
+ feats = self.modules.wav2vec2(wavs)
45
+
46
+ x = self.modules.enc(feats.detach())[0]
47
+ #x = self.modules.enc(feats.detach())
48
+ # Compute outputs
49
+ p_tokens = None
50
+ logits = self.modules.ctc_lin(x)
51
+ p_ctc = self.hparams.log_softmax(logits)
52
+ if stage != sb.Stage.TRAIN:
53
+ p_tokens = sb.decoders.ctc_greedy_decode(
54
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
55
+ )
56
+ return p_ctc, wav_lens, p_tokens
57
+
58
+ def compute_objectives(self, predictions, batch, stage):
59
+ """Computes the loss (CTC+NLL) given predictions and targets."""
60
+
61
+ p_ctc, wav_lens, predicted_tokens = predictions
62
+
63
+ ids = batch.id
64
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
65
+ tokens, tokens_lens = batch.tokens
66
+
67
+ if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
68
+ tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
69
+ tokens_eos_lens = torch.cat(
70
+ [tokens_eos_lens, tokens_eos_lens], dim=0
71
+ )
72
+ tokens = torch.cat([tokens, tokens], dim=0)
73
+ tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
74
+
75
+ loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
76
+ loss = loss_ctc
77
+
78
+ if stage != sb.Stage.TRAIN:
79
+ # Decode token terms to words
80
+ predicted_words = [
81
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
82
+ for utt_seq in predicted_tokens
83
+ ]
84
+ predicted_words =[]
85
+ for logs in p_ctc:
86
+ text = decoder.decode(logs.detach().cpu().numpy())
87
+ predicted_words.append(text.split(" "))
88
+
89
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
90
+ self.wer_metric.append(ids, predicted_words, target_words)
91
+ self.cer_metric.append(ids, predicted_words, target_words)
92
+
93
+ return loss
94
+
95
+ def fit_batch(self, batch):
96
+ """Train the parameters given a single batch in input"""
97
+ predictions = self.compute_forward(batch, sb.Stage.TRAIN)
98
+ loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
99
+ loss.backward()
100
+ if self.check_gradients(loss):
101
+ self.wav2vec_optimizer.step()
102
+ self.model_optimizer.step()
103
+
104
+ self.wav2vec_optimizer.zero_grad()
105
+ self.model_optimizer.zero_grad()
106
+
107
+ return loss.detach()
108
+
109
+ def evaluate_batch(self, batch, stage):
110
+ """Computations needed for validation/test batches"""
111
+ predictions = self.compute_forward(batch, stage=stage)
112
+ with torch.no_grad():
113
+ loss = self.compute_objectives(predictions, batch, stage=stage)
114
+ return loss.detach()
115
+
116
+ def on_stage_start(self, stage, epoch):
117
+ """Gets called at the beginning of each epoch"""
118
+ if stage != sb.Stage.TRAIN:
119
+ self.cer_metric = self.hparams.cer_computer()
120
+ self.wer_metric = self.hparams.error_rate_computer()
121
+
122
+ def on_stage_end(self, stage, stage_loss, epoch):
123
+ """Gets called at the end of an epoch."""
124
+ # Compute/store important stats
125
+ stage_stats = {"loss": stage_loss}
126
+ if stage == sb.Stage.TRAIN:
127
+ self.train_stats = stage_stats
128
+ else:
129
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
130
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
131
+
132
+ # Perform end-of-iteration things, like annealing, logging, etc.
133
+ if stage == sb.Stage.VALID:
134
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
135
+ stage_stats["loss"]
136
+ )
137
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
138
+ stage_stats["loss"]
139
+ )
140
+ sb.nnet.schedulers.update_learning_rate(
141
+ self.model_optimizer, new_lr_model
142
+ )
143
+ sb.nnet.schedulers.update_learning_rate(
144
+ self.wav2vec_optimizer, new_lr_wav2vec
145
+ )
146
+ self.hparams.train_logger.log_stats(
147
+ stats_meta={
148
+ "epoch": epoch,
149
+ "lr_model": old_lr_model,
150
+ "lr_wav2vec": old_lr_wav2vec,
151
+ },
152
+ train_stats=self.train_stats,
153
+ valid_stats=stage_stats,
154
+ )
155
+ self.checkpointer.save_and_keep_only(
156
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
157
+ )
158
+ elif stage == sb.Stage.TEST:
159
+ self.hparams.train_logger.log_stats(
160
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
161
+ test_stats=stage_stats,
162
+ )
163
+ with open(self.hparams.wer_file, "w") as w:
164
+ self.wer_metric.write_stats(w)
165
+
166
+ def init_optimizers(self):
167
+ "Initializes the wav2vec2 optimizer and model optimizer"
168
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
169
+ self.modules.wav2vec2.parameters()
170
+ )
171
+ self.model_optimizer = self.hparams.model_opt_class(
172
+ self.hparams.model.parameters()
173
+ )
174
+
175
+ if self.checkpointer is not None:
176
+ self.checkpointer.add_recoverable(
177
+ "wav2vec_opt", self.wav2vec_optimizer
178
+ )
179
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
180
+
181
+
182
+ def dataio_prepare(hparams):
183
+ """This function prepares the datasets to be used in the brain class.
184
+ It also defines the data processing pipeline through user-defined functions."""
185
+ data_folder = hparams["data_folder"]
186
+
187
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
188
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
189
+ )
190
+
191
+ if hparams["sorting"] == "ascending":
192
+ # we sort training data to speed up training and get better results.
193
+ train_data = train_data.filtered_sorted(sort_key="duration")
194
+ # when sorting, do not shuffle in the dataloader; otherwise the sorting is pointless
195
+ hparams["train_dataloader_opts"]["shuffle"] = False
196
+
197
+ elif hparams["sorting"] == "descending":
198
+ train_data = train_data.filtered_sorted(
199
+ sort_key="duration", reverse=True
200
+ )
201
+ # when sorting, do not shuffle in the dataloader; otherwise the sorting is pointless
202
+ hparams["train_dataloader_opts"]["shuffle"] = False
203
+
204
+ elif hparams["sorting"] == "random":
205
+ pass
206
+
207
+ else:
208
+ raise NotImplementedError(
209
+ "sorting must be random, ascending or descending"
210
+ )
211
+
212
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
213
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
214
+ )
215
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
216
+
217
+ # test is separate
218
+ test_datasets = {}
219
+ for csv_file in hparams["test_csv"]:
220
+ name = Path(csv_file).stem
221
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
222
+ csv_path=csv_file, replacements={"data_root": data_folder}
223
+ )
224
+ test_datasets[name] = test_datasets[name].filtered_sorted(
225
+ sort_key="duration"
226
+ )
227
+
228
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
229
+
230
+ # 2. Define audio pipeline:
231
+ @sb.utils.data_pipeline.takes("wav")
232
+ @sb.utils.data_pipeline.provides("sig")
233
+ def audio_pipeline(wav):
234
+ sig = sb.dataio.dataio.read_audio(wav)
235
+ return sig
236
+
237
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
238
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
239
+
240
+ # 3. Define text pipeline:
241
+ @sb.utils.data_pipeline.takes("wrd")
242
+ @sb.utils.data_pipeline.provides(
243
+ "wrd", "char_list", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
244
+ )
245
+ def text_pipeline(wrd):
246
+ yield wrd
247
+ char_list = list(wrd)
248
+ yield char_list
249
+ tokens_list = label_encoder.encode_sequence(char_list)
250
+ yield tokens_list
251
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
252
+ yield tokens_bos
253
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
254
+ yield tokens_eos
255
+ tokens = torch.LongTensor(tokens_list)
256
+ yield tokens
257
+
258
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
259
+
260
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
261
+ special_labels = {
262
+ "bos_label": hparams["bos_index"],
263
+ "eos_label": hparams["eos_index"],
264
+ "blank_label": hparams["blank_index"],
265
+ }
266
+ label_encoder.load_or_create(
267
+ path=lab_enc_file,
268
+ from_didatasets=[train_data],
269
+ output_key="char_list",
270
+ special_labels=special_labels,
271
+ sequence_input=True,
272
+ )
273
+
274
+ # 4. Set output:
275
+ sb.dataio.dataset.set_output_keys(
276
+ datasets,
277
+ ["id", "sig", "wrd", "char_list", "tokens_bos", "tokens_eos", "tokens"],
278
+ )
279
+ return train_data, valid_data, test_datasets, label_encoder
280
+
281
+
282
+ if __name__ == "__main__":
283
+
284
+ # CLI:
285
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
286
+
287
+ # If distributed_launch=True then
288
+ # create ddp_group with the right communication protocol
289
+ sb.utils.distributed.ddp_init_group(run_opts)
290
+
291
+ with open(hparams_file) as fin:
292
+ hparams = load_hyperpyyaml(fin, overrides)
293
+
294
+ # Create experiment directory
295
+ sb.create_experiment_directory(
296
+ experiment_directory=hparams["output_folder"],
297
+ hyperparams_to_save=hparams_file,
298
+ overrides=overrides,
299
+ )
300
+ def read_labels_file(labels_file):
301
+ with open(labels_file, "r") as lf:
302
+ lines = lf.read().splitlines()
303
+ division = "==="
304
+ numbers = {}
305
+ for line in lines :
306
+ if division in line :
307
+ break
308
+ string, number = line.split("=>")
309
+ number = int(number)
310
+ string = string[1:-2]
311
+ numbers[number] = string
312
+ return [numbers[x] for x in range(len(numbers))]
313
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
314
+ print(labels)
315
+ labels = [""] + labels[1:]
316
+ print(len(labels))
317
+ decoder = build_ctcdecoder(
318
+ labels,
319
+ kenlm_model_path="/gpfsstore/rech/nou/uzn19yk/4-gram.arpa", # either .arpa or .bin file
320
+ alpha=0.5, # tuned on a val set
321
+ beta=1.0, # tuned on a val set
322
+ )
323
+
324
+ # Dataset prep (parsing Librispeech)
325
+ from librispeech_prepare import prepare_librispeech # noqa
326
+
327
+ # multi-gpu (ddp) save data preparation
328
+ """
329
+ run_on_main(
330
+ prepare_librispeech,
331
+ kwargs={
332
+ "data_folder": hparams["data_folder"],
333
+ "tr_splits": hparams["train_splits"],
334
+ "dev_splits": hparams["dev_splits"],
335
+ "te_splits": hparams["test_splits"],
336
+ "save_folder": hparams["output_folder"],
337
+ "merge_lst": hparams["train_splits"],
338
+ "merge_name": "train.csv",
339
+ "skip_prep": hparams["skip_prep"],
340
+ },
341
+ )
342
+ """
343
+
344
+ # here we create the datasets objects as well as tokenization and encoding
345
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
346
+ hparams
347
+ )
348
+
349
+ # Trainer initialization
350
+ asr_brain = ASR(
351
+ modules=hparams["modules"],
352
+ hparams=hparams,
353
+ run_opts=run_opts,
354
+ checkpointer=hparams["checkpointer"],
355
+ )
356
+
357
+ # We dynamically add the tokenizer to our brain class.
358
+ # NB: This tokenizer corresponds to the one used for the LM!!
359
+ asr_brain.tokenizer = label_encoder
360
+
361
+ # Training
362
+ asr_brain.fit(
363
+ asr_brain.hparams.epoch_counter,
364
+ train_data,
365
+ valid_data,
366
+ train_loader_kwargs=hparams["train_dataloader_opts"],
367
+ valid_loader_kwargs=hparams["valid_dataloader_opts"],
368
+ )
369
+
370
+ # Testing
371
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
372
+ asr_brain.hparams.wer_file = os.path.join(
373
+ hparams["output_folder"], "wer_{}.txt".format(k)
374
+ )
375
+ asr_brain.evaluate(
376
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"]
377
+ )
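Note on read_labels_file above: it rebuilds the decoder vocabulary from SpeechBrain's label_encoder.txt, which stores one "'<label>' => <index>" pair per line before a "===" separator block. A minimal sketch of the same parsing on made-up file content:

lines = ["'<blank>' => 0", "' ' => 1", "'a' => 2", "================"]  # hypothetical file content
numbers = {}
for line in lines:
    if "===" in line:
        break
    string, number = line.split("=>")
    numbers[int(number)] = string[1:-2]    # drop the surrounding quotes and the trailing space
print([numbers[i] for i in range(len(numbers))])   # ['<blank>', ' ', 'a']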
partly_frozen_splitted_wavlm/1986/lm_tunisian.py ADDED
@@ -0,0 +1,361 @@
1
+ #!/usr/bin/env/python3
2
+ """Recipe for training a wav2vec-based ctc ASR system with librispeech.
3
+ The system employs wav2vec as its encoder. Decoding is performed with
4
+ ctc greedy decoder.
5
+ To run this recipe, do the following:
6
+ > python train_with_wav2vec.py hparams/train_with_wav2vec.yaml
7
+ The neural network is trained on CTC likelihood target and character units
8
+ are used as basic recognition tokens. Training is performed on the full
9
+ LibriSpeech dataset (960 h).
10
+
11
+ Authors
12
+ * Sung-Lin Yeh 2021
13
+ * Titouan Parcollet 2021
14
+ * Ju-Chieh Chou 2020
15
+ * Mirco Ravanelli 2020
16
+ * Abdel Heba 2020
17
+ * Peter Plantinga 2020
18
+ * Samuele Cornell 2020
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import torch
24
+ import logging
25
+ import speechbrain as sb
26
+ from speechbrain.utils.distributed import run_on_main
27
+ from hyperpyyaml import load_hyperpyyaml
28
+ from pathlib import Path
29
+ import torchaudio.transforms as T
30
+
31
+ from pyctcdecode import build_ctcdecoder
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Define training procedure
35
+ class ASR(sb.Brain):
36
+ def compute_forward(self, batch, stage):
37
+ """Forward computations from the waveform batches to the output probabilities."""
38
+ batch = batch.to(self.device)
39
+ wavs, wav_lens = batch.sig
40
+ tokens_bos, _ = batch.tokens_bos
41
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
42
+
43
+ # Forward pass
44
+ feats = self.modules.wav2vec2(wavs)
45
+ x = self.modules.enc(feats)
46
+ # Compute outputs
47
+ p_tokens = None
48
+ logits = self.modules.ctc_lin(x)
49
+ p_ctc = self.hparams.log_softmax(logits)
50
+ if stage != sb.Stage.TRAIN:
51
+ p_tokens = sb.decoders.ctc_greedy_decode(
52
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
53
+ )
54
+ return p_ctc, wav_lens, p_tokens
55
+
56
+ def compute_objectives(self, predictions, batch, stage):
57
+ """Computes the loss (CTC+NLL) given predictions and targets."""
58
+
59
+ p_ctc, wav_lens, predicted_tokens = predictions
60
+
61
+ ids = batch.id
62
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
63
+ tokens, tokens_lens = batch.tokens
64
+
65
+ if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
66
+ tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
67
+ tokens_eos_lens = torch.cat(
68
+ [tokens_eos_lens, tokens_eos_lens], dim=0
69
+ )
70
+ tokens = torch.cat([tokens, tokens], dim=0)
71
+ tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
72
+
73
+ loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
74
+ loss = loss_ctc
75
+ if stage != sb.Stage.TRAIN:
76
+ # Decode token terms to words
77
+ predicted_words =[]
78
+ for logs in p_ctc:
79
+ text = decoder.decode(logs.detach().cpu().numpy())
80
+ predicted_words.append(text.split(" "))
81
+
82
+
83
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
84
+ self.wer_metric.append(ids, predicted_words, target_words)
85
+ self.cer_metric.append(ids, predicted_words, target_words)
86
+
87
+ return loss
88
+
89
+ def fit_batch(self, batch):
90
+ """Train the parameters given a single batch in input"""
91
+ predictions = self.compute_forward(batch, sb.Stage.TRAIN)
92
+ loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
93
+ loss.backward()
94
+ if self.check_gradients(loss):
95
+ self.wav2vec_optimizer.step()
96
+ self.model_optimizer.step()
97
+
98
+ self.wav2vec_optimizer.zero_grad()
99
+ self.model_optimizer.zero_grad()
100
+
101
+ return loss.detach()
102
+
103
+ def evaluate_batch(self, batch, stage):
104
+ """Computations needed for validation/test batches"""
105
+ predictions = self.compute_forward(batch, stage=stage)
106
+ with torch.no_grad():
107
+ loss = self.compute_objectives(predictions, batch, stage=stage)
108
+ return loss.detach()
109
+
110
+ def on_stage_start(self, stage, epoch):
111
+ """Gets called at the beginning of each epoch"""
112
+ if stage != sb.Stage.TRAIN:
113
+ self.cer_metric = self.hparams.cer_computer()
114
+ self.wer_metric = self.hparams.error_rate_computer()
115
+
116
+ def on_stage_end(self, stage, stage_loss, epoch):
117
+ """Gets called at the end of an epoch."""
118
+ # Compute/store important stats
119
+ stage_stats = {"loss": stage_loss}
120
+ if stage == sb.Stage.TRAIN:
121
+ self.train_stats = stage_stats
122
+ else:
123
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
124
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
125
+
126
+ # Perform end-of-iteration things, like annealing, logging, etc.
127
+ if stage == sb.Stage.VALID:
128
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
129
+ stage_stats["loss"]
130
+ )
131
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
132
+ stage_stats["loss"]
133
+ )
134
+ sb.nnet.schedulers.update_learning_rate(
135
+ self.model_optimizer, new_lr_model
136
+ )
137
+ sb.nnet.schedulers.update_learning_rate(
138
+ self.wav2vec_optimizer, new_lr_wav2vec
139
+ )
140
+ self.hparams.train_logger.log_stats(
141
+ stats_meta={
142
+ "epoch": epoch,
143
+ "lr_model": old_lr_model,
144
+ "lr_wav2vec": old_lr_wav2vec,
145
+ },
146
+ train_stats=self.train_stats,
147
+ valid_stats=stage_stats,
148
+ )
149
+ self.checkpointer.save_and_keep_only(
150
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
151
+ )
152
+ elif stage == sb.Stage.TEST:
153
+ self.hparams.train_logger.log_stats(
154
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
155
+ test_stats=stage_stats,
156
+ )
157
+ with open(self.hparams.wer_file, "w") as w:
158
+ self.wer_metric.write_stats(w)
159
+
160
+ def init_optimizers(self):
161
+ "Initializes the wav2vec2 optimizer and model optimizer"
162
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
163
+ self.modules.wav2vec2.parameters()
164
+ )
165
+ self.model_optimizer = self.hparams.model_opt_class(
166
+ self.hparams.model.parameters()
167
+ )
168
+
169
+ if self.checkpointer is not None:
170
+ self.checkpointer.add_recoverable(
171
+ "wav2vec_opt", self.wav2vec_optimizer
172
+ )
173
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
174
+
175
+
176
+ def dataio_prepare(hparams):
177
+ """This function prepares the datasets to be used in the brain class.
178
+ It also defines the data processing pipeline through user-defined functions."""
179
+ data_folder = hparams["data_folder"]
180
+
181
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
182
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
183
+ )
184
+
185
+ if hparams["sorting"] == "ascending":
186
+ # we sort training data to speed up training and get better results.
187
+ train_data = train_data.filtered_sorted(sort_key="duration")
188
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
189
+ hparams["train_dataloader_opts"]["shuffle"] = False
190
+
191
+ elif hparams["sorting"] == "descending":
192
+ train_data = train_data.filtered_sorted(
193
+ sort_key="duration", reverse=True
194
+ )
195
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
196
+ hparams["train_dataloader_opts"]["shuffle"] = False
197
+
198
+ elif hparams["sorting"] == "random":
199
+ pass
200
+
201
+ else:
202
+ raise NotImplementedError(
203
+ "sorting must be random, ascending or descending"
204
+ )
205
+
206
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
207
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
208
+ )
209
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
210
+
211
+ # test is separate
212
+ test_datasets = {}
213
+ for csv_file in hparams["test_csv"]:
214
+ name = Path(csv_file).stem
215
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
216
+ csv_path=csv_file, replacements={"data_root": data_folder}
217
+ )
218
+ test_datasets[name] = test_datasets[name].filtered_sorted(
219
+ sort_key="duration"
220
+ )
221
+
222
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
223
+
224
+ # 2. Define audio pipeline:
225
+ @sb.utils.data_pipeline.takes("wav", "sr")
226
+ @sb.utils.data_pipeline.provides("sig")
227
+ def audio_pipeline(wav, sr):
228
+ sig = sb.dataio.dataio.read_audio(wav)
229
+ sig = resamplers[sr](sig)
230
+ return sig
231
+
232
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
233
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
234
+
235
+ # 3. Define text pipeline:
236
+ @sb.utils.data_pipeline.takes("wrd")
237
+ @sb.utils.data_pipeline.provides(
238
+ "wrd", "char_list", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
239
+ )
240
+ def text_pipeline(wrd):
241
+ yield wrd
242
+ char_list = list(wrd)
243
+ yield char_list
244
+ tokens_list = label_encoder.encode_sequence(char_list)
245
+ yield tokens_list
246
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
247
+ yield tokens_bos
248
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
249
+ yield tokens_eos
250
+ tokens = torch.LongTensor(tokens_list)
251
+ yield tokens
252
+
253
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
254
+
255
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
256
+ special_labels = {
257
+ "bos_label": hparams["bos_index"],
258
+ "eos_label": hparams["eos_index"],
259
+ "blank_label": hparams["blank_index"],
260
+ }
261
+ label_encoder.load_or_create(
262
+ path=lab_enc_file,
263
+ from_didatasets=[train_data],
264
+ output_key="char_list",
265
+ special_labels=special_labels,
266
+ sequence_input=True,
267
+ )
268
+
269
+ # 4. Set output:
270
+ sb.dataio.dataset.set_output_keys(
271
+ datasets,
272
+ ["id", "sig", "wrd", "char_list", "tokens_bos", "tokens_eos", "tokens"],
273
+ )
274
+ return train_data, valid_data, test_datasets, label_encoder
275
+
276
+
277
+ if __name__ == "__main__":
278
+
279
+ # CLI:
280
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
281
+
282
+ # If distributed_launch=True then
283
+ # create ddp_group with the right communication protocol
284
+ sb.utils.distributed.ddp_init_group(run_opts)
285
+
286
+ with open(hparams_file) as fin:
287
+ hparams = load_hyperpyyaml(fin, overrides)
288
+
289
+ # Create experiment directory
290
+ sb.create_experiment_directory(
291
+ experiment_directory=hparams["output_folder"],
292
+ hyperparams_to_save=hparams_file,
293
+ overrides=overrides,
294
+ )
295
+ def read_labels_file(labels_file):
296
+ with open(labels_file, "r") as lf:
297
+ lines = lf.read().splitlines()
298
+ division = "==="
299
+ numbers = {}
300
+ for line in lines :
301
+ if division in line :
302
+ break
303
+ string, number = line.split("=>")
304
+ number = int(number)
305
+ string = string[1:-2]
306
+ numbers[number] = string
307
+ return [numbers[x] for x in range(len(numbers))]
308
+ labels = read_labels_file(os.path.join(hparams["save_folder"], "label_encoder.txt"))
309
+ print(labels)
310
+ labels = [""] + labels[1:]
311
+ print(len(labels))
312
+ decoder = build_ctcdecoder(
313
+ labels,
314
+ kenlm_model_path="tunisian.arpa", # either .arpa or .bin file
315
+ alpha=0.5, # tuned on a val set
316
+ beta=1.0, # tuned on a val set
317
+ )
318
+
319
+ # Dataset prep (build resamplers for the input audio)
320
+
321
+ resampler_8000 = T.Resample(8000, 16000, dtype=torch.float)
322
+
323
+ resampler_44100 =T.Resample(44100, 16000, dtype=torch.float)
324
+ resampler_48000 =T.Resample(48000, 16000, dtype=torch.float)
325
+ resamplers = {"8000": resampler_8000, "44100":resampler_44100, "48000": resampler_48000}
326
+
327
+ # here we create the datasets objects as well as tokenization and encoding
328
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
329
+ hparams
330
+ )
331
+
332
+ # Trainer initialization
333
+ asr_brain = ASR(
334
+ modules=hparams["modules"],
335
+ hparams=hparams,
336
+ run_opts=run_opts,
337
+ checkpointer=hparams["checkpointer"],
338
+ )
339
+ asr_brain.device= "cpu"
340
+ asr_brain.modules.to("cpu")
341
+ # We dynamicaly add the tokenizer to our brain class.
342
+ # NB: This tokenizer corresponds to the one used for the LM!!
343
+ asr_brain.tokenizer = label_encoder
344
+
345
+ # Training
346
+ asr_brain.fit(
347
+ asr_brain.hparams.epoch_counter,
348
+ train_data,
349
+ valid_data,
350
+ train_loader_kwargs=hparams["train_dataloader_opts"],
351
+ valid_loader_kwargs=hparams["valid_dataloader_opts"],
352
+ )
353
+
354
+ # Testing
355
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
356
+ asr_brain.hparams.wer_file = os.path.join(
357
+ hparams["output_folder"], "wer_{}.txt".format(k)
358
+ )
359
+ asr_brain.evaluate(
360
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"]
361
+ )
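Note: lm_tunisian.py above differs from the plain CTC recipe only in the decoding step: it rebuilds the character inventory from the saved label_encoder.txt and wraps it in a pyctcdecode beam-search decoder backed by a KenLM model (tunisian.arpa). A minimal standalone sketch of that step follows; the placeholder log_probs array and the .strip()-based quote handling are illustrative additions, not taken from the script.

    import numpy as np
    from pyctcdecode import build_ctcdecoder

    def read_labels_file(labels_file):
        # Parse lines like "'م' => 38" into an index-ordered label list,
        # stopping at the "===" separator (same logic as read_labels_file above).
        numbers = {}
        with open(labels_file, "r") as lf:
            for line in lf.read().splitlines():
                if "===" in line:
                    break
                string, number = line.split("=>")
                numbers[int(number)] = string.strip()[1:-1]  # drop surrounding quotes
        return [numbers[i] for i in range(len(numbers))]

    labels = read_labels_file("partly_frozen_splitted_wavlm/1986/save/label_encoder.txt")
    labels = [""] + labels[1:]  # pyctcdecode expects "" (not '<blank>') at blank index 0

    decoder = build_ctcdecoder(
        labels,
        kenlm_model_path="tunisian.arpa",  # KenLM n-gram on Tunisian text, as in the script
        alpha=0.5,                         # LM weight used above
        beta=1.0,                          # word-insertion bonus used above
    )

    # log_probs: (time, vocab) CTC log-posteriors for one utterance (uniform placeholder here)
    log_probs = np.full((100, len(labels)), np.log(1.0 / len(labels)))
    print(decoder.decode(log_probs))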
partly_frozen_splitted_wavlm/1986/log.txt ADDED
The diff for this file is too large to render. See raw diff
 
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # yamllint disable
2
+ WER: 47.68855883035507
3
+ end-of-epoch: true
4
+ unixtime: 1672916345.1685827
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:33809a026a2c1febce7b03c8aafaee4ddfc851b2c70f180f8c06bf1017f4df5c
3
+ size 46
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b51d431df5d7f141cbececcf79edf3dd861c3b4069f0b11661a3eefacbba918
3
+ size 2
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d53eb07a66864f0a18b5a0c5d029b33bddb11f05025a6e385c92c9fbb618edee
3
+ size 6
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8782ec6abb3dd1a6aa05d54a8f159375707d2ef212f2d934c1209aab5f01a46b
3
+ size 8566935
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18828222ecebe54247f2e16e2eedf607c5fc34d0b0029f4cc3356f9397a10802
3
+ size 17133057
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:524fb9db201619c436c0f51539869e8ffd90c8b60bf30384726052e65076751a
3
+ size 623
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/scheduler_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4663a4d525410c34529e6c80671db216a35c9f285f1cf9a1ca7004d3e14679c7
3
+ size 623
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:96716b4e96d96f3aa00182e1e5bf99219b86f74990e2f05cb69165d6e96c8b4e
3
+ size 1262004913
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+11-59-05+00/wav2vec_opt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c39bdda929f079de1ab70cd624900b7fda0a6a6682446ab2a25296442d01862e
3
+ size 2490235001
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/CKPT.yaml ADDED
@@ -0,0 +1,4 @@
1
+ # yamllint disable
2
+ brain_intra_epoch_ckpt: true
3
+ end-of-epoch: false
4
+ unixtime: 1672917851.8581367
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/brain.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3bc72ed3a1d0a5dc95a83b7a139fce806d6929b8db82b1f8dd010d42556698ca
3
+ size 65
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/counter.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3fdba35f04dc8c462986c992bcf875546257113072a909c162f7e470e581e278
3
+ size 2
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/dataloader-TRAIN.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c8228f8d57233cd4d689b31048e7ec6e2c12b409b2501ac69252bbdb65ea575d
3
+ size 5
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:316a78537d2603b0ef0030df65934944a0cc88a0865c7983aa77a081605d1d05
3
+ size 8566935
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/modelopt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d721ac115b10592cb716b2641f80edf21549fb11c74bcf443f8213249c7fa7e8
3
+ size 17133057
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/scheduler_model.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:524fb9db201619c436c0f51539869e8ffd90c8b60bf30384726052e65076751a
3
+ size 623
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/scheduler_wav2vec.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4663a4d525410c34529e6c80671db216a35c9f285f1cf9a1ca7004d3e14679c7
3
+ size 623
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/wav2vec2.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8616b0fbaeb370bfae2df4d924367ae88d4a17a644b22f7050a607242f036fe5
3
+ size 1262004913
partly_frozen_splitted_wavlm/1986/save/CKPT+2023-01-05+12-24-11+00/wav2vec_opt.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3767ccf04ad421c4a71a167b10e0d31b064bf2f91f769ec42f9510f675dea148
3
+ size 2490235001
partly_frozen_splitted_wavlm/1986/save/label_encoder.txt ADDED
@@ -0,0 +1,46 @@
1
+ 'م' => 38
2
+ 'و' => 39
3
+ ' ' => 40
4
+ 'ق' => 3
5
+ 'ا' => 4
6
+ 'ع' => 5
7
+ 'د' => 6
8
+ 'ة' => 7
9
+ 'ت' => 8
10
+ 'ش' => 9
11
+ 'ي' => 10
12
+ 'ك' => 11
13
+ 'ه' => 12
14
+ 'ل' => 13
15
+ 'ح' => 14
16
+ 'ب' => 15
17
+ 'ن' => 16
18
+ 'ى' => 17
19
+ 'ر' => 18
20
+ 'ف' => 19
21
+ 'إ' => 20
22
+ 'س' => 21
23
+ 'أ' => 22
24
+ 'ض' => 23
25
+ 'ص' => 24
26
+ 'ط' => 25
27
+ 'خ' => 26
28
+ 'ج' => 27
29
+ 'ظ' => 28
30
+ 'ز' => 29
31
+ 'آ' => 30
32
+ 'ذ' => 31
33
+ 'غ' => 32
34
+ 'ث' => 33
35
+ 'ئ' => 34
36
+ 'ء' => 35
37
+ 'ؤ' => 36
38
+ 'ٱ' => 37
39
+ '<blank>' => 0
40
+ '<bos>' => 1
41
+ '<eos>' => 2
42
+ ================
43
+ 'starting_index' => 0
44
+ 'bos_label' => '<bos>'
45
+ 'eos_label' => '<eos>'
46
+ 'blank_label' => '<blank>'
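The mapping above is the character inventory written by SpeechBrain's CTCTextEncoder (one 'char' => index line per token, with the special labels listed after the === separator). A small sketch of reloading it and round-tripping a string through encode_sequence/decode_ndim, as the text pipeline and the evaluation code do; the load() call and the example string are assumptions, not taken from the repo.

    import speechbrain as sb

    label_encoder = sb.dataio.encoder.CTCTextEncoder()
    # Assumes the encoder's load() method is available in this SpeechBrain version;
    # the training scripts instead use load_or_create() with the training dataset.
    label_encoder.load("partly_frozen_splitted_wavlm/1986/save/label_encoder.txt")

    text = "نهار السبت"                 # illustrative Tunisian word string
    char_list = list(text)              # character-level tokens, as in text_pipeline
    tokens = label_encoder.encode_sequence(char_list)
    print(tokens)                       # indices from the table above
    print("".join(label_encoder.decode_ndim(tokens)))  # back to the original string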
partly_frozen_splitted_wavlm/1986/train_log.txt ADDED
@@ -0,0 +1,86 @@
1
+ epoch: 1, lr_model: 1, lr_wav2vec: 1.00e-04 - train loss: 1.24 - valid loss: 9.24e-01, valid CER: 26.54, valid WER: 59.12
2
+ epoch: 2, lr_model: 1, lr_wav2vec: 1.00e-04 - train loss: 9.67e-01 - valid loss: 9.03e-01, valid CER: 25.85, valid WER: 57.02
3
+ epoch: 3, lr_model: 1, lr_wav2vec: 1.00e-04 - train loss: 8.84e-01 - valid loss: 8.81e-01, valid CER: 24.89, valid WER: 55.02
4
+ epoch: 4, lr_model: 1, lr_wav2vec: 1.00e-04 - train loss: 8.19e-01 - valid loss: 8.31e-01, valid CER: 22.92, valid WER: 51.69
5
+ epoch: 5, lr_model: 1, lr_wav2vec: 1.00e-04 - train loss: 7.76e-01 - valid loss: 8.67e-01, valid CER: 23.36, valid WER: 50.67
6
+ epoch: 6, lr_model: 8.00e-01, lr_wav2vec: 9.00e-05 - train loss: 7.20e-01 - valid loss: 8.37e-01, valid CER: 22.87, valid WER: 49.84
7
+ epoch: 7, lr_model: 8.00e-01, lr_wav2vec: 9.00e-05 - train loss: 6.87e-01 - valid loss: 8.78e-01, valid CER: 23.90, valid WER: 51.19
8
+ epoch: 8, lr_model: 6.40e-01, lr_wav2vec: 8.10e-05 - train loss: 6.44e-01 - valid loss: 8.68e-01, valid CER: 23.03, valid WER: 49.83
9
+ epoch: 9, lr_model: 6.40e-01, lr_wav2vec: 8.10e-05 - train loss: 6.20e-01 - valid loss: 8.47e-01, valid CER: 22.77, valid WER: 48.42
10
+ epoch: 10, lr_model: 6.40e-01, lr_wav2vec: 8.10e-05 - train loss: 5.98e-01 - valid loss: 9.07e-01, valid CER: 24.31, valid WER: 49.76
11
+ epoch: 11, lr_model: 5.12e-01, lr_wav2vec: 7.29e-05 - train loss: 5.60e-01 - valid loss: 9.08e-01, valid CER: 23.75, valid WER: 49.33
12
+ epoch: 12, lr_model: 4.10e-01, lr_wav2vec: 6.56e-05 - train loss: 5.22e-01 - valid loss: 9.08e-01, valid CER: 22.61, valid WER: 47.69
13
+ Epoch loaded: 12 - test loss: 1.26e-04, test CER: 9.09, test WER: 42.65
14
+ Epoch loaded: 12 - test loss: 5.95e-02, test CER: 20.52, test WER: 40.71
15
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 10.97, test WER: 54.41
16
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
17
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
18
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
19
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
20
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
21
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
22
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
23
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
24
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
25
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
26
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
27
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
28
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
29
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
30
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
31
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
32
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
33
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
34
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
35
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
36
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
37
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
38
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
39
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
40
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
41
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
42
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
43
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 10.03, test WER: 48.53
44
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 10.03, test WER: 48.53
45
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 10.03, test WER: 48.53
46
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 10.03, test WER: 48.53
47
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
48
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
49
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
50
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
51
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
52
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
53
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
54
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
55
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
56
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.72, test WER: 45.59
57
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.40, test WER: 45.59
58
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.40, test WER: 45.59
59
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.40, test WER: 45.59
60
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
61
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
62
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
63
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
64
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
65
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
66
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
67
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
68
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
69
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
70
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
71
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
72
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
73
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
74
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
75
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
76
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
77
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
78
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 8.78, test WER: 44.12
79
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 10.66, test WER: 54.41
80
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 11.60, test WER: 60.29
81
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
82
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
83
+ Epoch loaded: 12 - test loss: 7.97e-04, test CER: 9.09, test WER: 42.65
84
+ Epoch loaded: 12 - test loss: 6.99e-05, test CER: 8.82, test WER: 42.86
85
+ Epoch loaded: 12 - test loss: 6.99e-05, test CER: 8.82, test WER: 42.86
86
+ Epoch loaded: 12 - test loss: 6.99e-05, test CER: 8.82, test WER: 42.86
partly_frozen_splitted_wavlm/1986/wer_test.txt ADDED
The diff for this file is too large to render. See raw diff
 
partly_frozen_splitted_wavlm/1986/wer_test_salah.txt ADDED
@@ -0,0 +1,61 @@
1
+ %WER 42.65 [ 29 / 68, 1 ins, 7 del, 21 sub ]
2
+ %SER 90.00 [ 9 / 10 ]
3
+ Scored 10 sentences, 0 not present in hyp.
4
+ ================================================================================
5
+ ALIGNMENTS
6
+
7
+ Format:
8
+ <utterance-id>, WER DETAILS
9
+ <eps> ; reference ; on ; the ; first ; line
10
+ I ; S ; = ; = ; S ; D
11
+ and ; hypothesis ; on ; the ; third ; <eps>
12
+ ================================================================================
13
+ Salah4, %WER 0.00 [ 0 / 5, 0 ins, 0 del, 0 sub ]
14
+ تعبت ; هاني ; راكش ; في ; الدار
15
+ = ; = ; = ; = ; =
16
+ تعبت ; هاني ; راكش ; في ; الدار
17
+ ================================================================================
18
+ Salah5, %WER 57.14 [ 4 / 7, 0 ins, 1 del, 3 sub ]
19
+ نهار ; السبت ; ماشي ; نقرى ; ان ; شاء ; الله
20
+ = ; = ; = ; S ; S ; S ; D
21
+ نهار ; السبت ; ماشي ; نقرا ; إن ; شاءالله ; <eps>
22
+ ================================================================================
23
+ Salah2, %WER 60.00 [ 3 / 5, 0 ins, 1 del, 2 sub ]
24
+ باهي ; وقتاش ; نمشيو ; ال ; تونس
25
+ = ; = ; S ; S ; D
26
+ باهي ; وقتاش ; نمشيوا ; لتونس ; <eps>
27
+ ================================================================================
28
+ Salah7, %WER 33.33 [ 2 / 6, 0 ins, 1 del, 1 sub ]
29
+ نحب ; نمشي ; ال ; بنزرت ; نرتاح ; شوية
30
+ = ; = ; S ; D ; = ; =
31
+ نحب ; نمشي ; لبنزرت ; <eps> ; نرتاح ; شوية
32
+ ================================================================================
33
+ Salah6, %WER 37.50 [ 3 / 8, 0 ins, 0 del, 3 sub ]
34
+ زعما ; نلقى ; أحمد ; في ; الستاد ; ولا ; ماهوش ; هوني
35
+ S ; = ; = ; = ; = ; S ; = ; S
36
+ زعمة ; نلقى ; أحمد ; في ; الستاد ; وإلا ; ماهوش ; كوني
37
+ ================================================================================
38
+ Salah10, %WER 66.67 [ 4 / 6, 1 ins, 1 del, 2 sub ]
39
+ انتي ; <eps> ; خويا ; و ; عشيري ; صالح ; نحبك
40
+ S ; I ; = ; S ; D ; = ; =
41
+ إنت ; ي ; خويا ; وعشيلي ; <eps> ; صالح ; نحبك
42
+ ================================================================================
43
+ Salah8, %WER 11.11 [ 1 / 9, 0 ins, 0 del, 1 sub ]
44
+ حكيت ; مع ; لولاد ; قالولي ; كل ; شي ; مريقل ; نهار ; السبت
45
+ = ; = ; S ; = ; = ; = ; = ; = ; =
46
+ حكيت ; مع ; الاولاد ; قالولي ; كل ; شي ; مريقل ; نهار ; السبت
47
+ ================================================================================
48
+ Salah3, %WER 85.71 [ 6 / 7, 0 ins, 1 del, 5 sub ]
49
+ اعطيني ; خمسة ; الاف ; و ; خمسة ; ميا ; بلاهي
50
+ S ; = ; S ; S ; S ; S ; D
51
+ أعطيني ; خمسة ; آلاف ; وخمسة ; مية ; باللاهي ; <eps>
52
+ ================================================================================
53
+ Salah9, %WER 37.50 [ 3 / 8, 0 ins, 1 del, 2 sub ]
54
+ ناكل ; كفتاجي ; و ; نجم ; نشري ; شوية ; حوت ; زادة
55
+ = ; S ; S ; D ; = ; = ; = ; =
56
+ ناكل ; الكفتاجي ; وننجم ; <eps> ; نشري ; شوية ; حوت ; زادة
57
+ ================================================================================
58
+ Salah1, %WER 42.86 [ 3 / 7, 0 ins, 1 del, 2 sub ]
59
+ نحب ; ماكلة ; بنينة ; كسكروت ; نظيف ; و ; رخيص
60
+ S ; = ; = ; = ; = ; S ; D
61
+ لحم ; ماكلة ; بنينة ; كسكروت ; نظيف ; ورخيص ; <eps>
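The report above (and wer_test_salah_local.txt below) is written by SpeechBrain's ErrorRateStats via the wer_metric.write_stats(w) call in the training scripts. A minimal sketch of producing the same kind of alignment report, reusing two reference/hypothesis pairs from the report above; the output path is illustrative.

    from speechbrain.utils.metric_stats import ErrorRateStats

    wer_metric = ErrorRateStats()

    # ids, hypotheses and references as lists of word lists, as in compute_objectives()
    ids = ["Salah4", "Salah8"]
    hyps = [
        ["تعبت", "هاني", "راكش", "في", "الدار"],
        ["حكيت", "مع", "الاولاد", "قالولي", "كل", "شي", "مريقل", "نهار", "السبت"],
    ]
    refs = [
        ["تعبت", "هاني", "راكش", "في", "الدار"],
        ["حكيت", "مع", "لولاد", "قالولي", "كل", "شي", "مريقل", "نهار", "السبت"],
    ]

    wer_metric.append(ids, hyps, refs)
    print(wer_metric.summarize("error_rate"))  # overall %WER over the two utterances

    with open("wer_demo.txt", "w") as w:       # writes an alignment report like the one above
        wer_metric.write_stats(w)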
partly_frozen_splitted_wavlm/1986/wer_test_salah_local.txt ADDED
@@ -0,0 +1,21 @@
1
+ %WER 42.86 [ 6 / 14, 0 ins, 2 del, 4 sub ]
2
+ %SER 100.00 [ 2 / 2 ]
3
+ Scored 2 sentences, 0 not present in hyp.
4
+ ================================================================================
5
+ ALIGNMENTS
6
+
7
+ Format:
8
+ <utterance-id>, WER DETAILS
9
+ <eps> ; reference ; on ; the ; first ; line
10
+ I ; S ; = ; = ; S ; D
11
+ and ; hypothesis ; on ; the ; third ; <eps>
12
+ ================================================================================
13
+ Salah1, %WER 42.86 [ 3 / 7, 0 ins, 1 del, 2 sub ]
14
+ نحب ; ماكلة ; بنينة ; كسكروت ; نظيف ; و ; رخيص
15
+ S ; = ; = ; = ; = ; S ; D
16
+ لحم ; ماكلة ; بنينة ; كسكروت ; نظيف ; ورخيص ; <eps>
17
+ ================================================================================
18
+ Salah2, %WER 42.86 [ 3 / 7, 0 ins, 1 del, 2 sub ]
19
+ نحب ; ماكلة ; بنينة ; كسكروت ; نظيف ; و ; رخيص
20
+ S ; = ; = ; = ; = ; S ; D
21
+ لحم ; ماكلة ; بنينة ; كسكروت ; نظيف ; ورخيص ; <eps>
partly_frozen_splitted_wavlm/ctc_train.py ADDED
@@ -0,0 +1,339 @@
+ #!/usr/bin/env python3
+ """Recipe for training a wav2vec-based CTC ASR system (adapted from the
+ LibriSpeech recipe) on the Tunisian data. The system employs wav2vec2/WavLM
+ as its encoder; decoding is performed with a CTC greedy decoder.
+ To run this recipe, do the following:
+ > python ctc_train.py hyperparams.yaml
+ The neural network is trained on the CTC likelihood target and character units
+ are used as basic recognition tokens.
+
+ Authors
+ * Sung-Lin Yeh 2021
+ * Titouan Parcollet 2021
+ * Ju-Chieh Chou 2020
+ * Mirco Ravanelli 2020
+ * Abdel Heba 2020
+ * Peter Plantinga 2020
+ * Samuele Cornell 2020
+ """
+
21
+ import os
22
+ import sys
23
+ import torch
24
+ import logging
25
+ import speechbrain as sb
26
+ from speechbrain.utils.distributed import run_on_main
27
+ from hyperpyyaml import load_hyperpyyaml
28
+ from pathlib import Path
29
+ import torchaudio.transforms as T
30
+ logger = logging.getLogger(__name__)
31
+
32
+ # Define training procedure
33
+ class ASR(sb.Brain):
34
+ def compute_forward(self, batch, stage):
35
+ """Forward computations from the waveform batches to the output probabilities."""
36
+ batch = batch.to(self.device)
37
+ wavs, wav_lens = batch.sig
38
+ tokens_bos, _ = batch.tokens_bos
39
+ wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
40
+
41
+ # Forward pass
42
+ feats = self.modules.wav2vec2(wavs)
43
+ x = self.modules.enc(feats)
44
+ # Compute outputs
45
+ p_tokens = None
46
+ logits = self.modules.ctc_lin(x)
47
+ p_ctc = self.hparams.log_softmax(logits)
48
+ if stage != sb.Stage.TRAIN:
49
+ p_tokens = sb.decoders.ctc_greedy_decode(
50
+ p_ctc, wav_lens, blank_id=self.hparams.blank_index
51
+ )
52
+ return p_ctc, wav_lens, p_tokens
53
+
54
+ def compute_objectives(self, predictions, batch, stage):
55
+ """Computes the loss (CTC+NLL) given predictions and targets."""
56
+
57
+ p_ctc, wav_lens, predicted_tokens = predictions
58
+
59
+ ids = batch.id
60
+ tokens_eos, tokens_eos_lens = batch.tokens_eos
61
+ tokens, tokens_lens = batch.tokens
62
+
63
+ if hasattr(self.modules, "env_corrupt") and stage == sb.Stage.TRAIN:
64
+ tokens_eos = torch.cat([tokens_eos, tokens_eos], dim=0)
65
+ tokens_eos_lens = torch.cat(
66
+ [tokens_eos_lens, tokens_eos_lens], dim=0
67
+ )
68
+ tokens = torch.cat([tokens, tokens], dim=0)
69
+ tokens_lens = torch.cat([tokens_lens, tokens_lens], dim=0)
70
+
71
+ loss_ctc = self.hparams.ctc_cost(p_ctc, tokens, wav_lens, tokens_lens)
72
+ loss = loss_ctc
73
+
74
+ if stage != sb.Stage.TRAIN:
75
+ # Decode token terms to words
76
+ predicted_words = [
77
+ "".join(self.tokenizer.decode_ndim(utt_seq)).split(" ")
78
+ for utt_seq in predicted_tokens
79
+ ]
80
+ target_words = [wrd.split(" ") for wrd in batch.wrd]
81
+ self.wer_metric.append(ids, predicted_words, target_words)
82
+ self.cer_metric.append(ids, predicted_words, target_words)
83
+
84
+ return loss
85
+
86
+ def fit_batch(self, batch):
87
+ """Train the parameters given a single batch in input"""
88
+ predictions = self.compute_forward(batch, sb.Stage.TRAIN)
89
+ loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
90
+ loss.backward()
91
+ if self.check_gradients(loss):
92
+ self.wav2vec_optimizer.step()
93
+ self.model_optimizer.step()
94
+
95
+ self.wav2vec_optimizer.zero_grad()
96
+ self.model_optimizer.zero_grad()
97
+
98
+ return loss.detach()
99
+
100
+ def evaluate_batch(self, batch, stage):
101
+ """Computations needed for validation/test batches"""
102
+ predictions = self.compute_forward(batch, stage=stage)
103
+ with torch.no_grad():
104
+ loss = self.compute_objectives(predictions, batch, stage=stage)
105
+ return loss.detach()
106
+
107
+ def on_stage_start(self, stage, epoch):
108
+ """Gets called at the beginning of each epoch"""
109
+ if stage != sb.Stage.TRAIN:
110
+ self.cer_metric = self.hparams.cer_computer()
111
+ self.wer_metric = self.hparams.error_rate_computer()
112
+
113
+ def on_stage_end(self, stage, stage_loss, epoch):
114
+ """Gets called at the end of an epoch."""
115
+ # Compute/store important stats
116
+ stage_stats = {"loss": stage_loss}
117
+ if stage == sb.Stage.TRAIN:
118
+ self.train_stats = stage_stats
119
+ else:
120
+ stage_stats["CER"] = self.cer_metric.summarize("error_rate")
121
+ stage_stats["WER"] = self.wer_metric.summarize("error_rate")
122
+
123
+ # Perform end-of-iteration things, like annealing, logging, etc.
124
+ if stage == sb.Stage.VALID:
125
+ old_lr_model, new_lr_model = self.hparams.lr_annealing_model(
126
+ stage_stats["loss"]
127
+ )
128
+ old_lr_wav2vec, new_lr_wav2vec = self.hparams.lr_annealing_wav2vec(
129
+ stage_stats["loss"]
130
+ )
131
+ sb.nnet.schedulers.update_learning_rate(
132
+ self.model_optimizer, new_lr_model
133
+ )
134
+ sb.nnet.schedulers.update_learning_rate(
135
+ self.wav2vec_optimizer, new_lr_wav2vec
136
+ )
137
+ self.hparams.train_logger.log_stats(
138
+ stats_meta={
139
+ "epoch": epoch,
140
+ "lr_model": old_lr_model,
141
+ "lr_wav2vec": old_lr_wav2vec,
142
+ },
143
+ train_stats=self.train_stats,
144
+ valid_stats=stage_stats,
145
+ )
146
+ self.checkpointer.save_and_keep_only(
147
+ meta={"WER": stage_stats["WER"]}, min_keys=["WER"],
148
+ )
149
+ elif stage == sb.Stage.TEST:
150
+ self.hparams.train_logger.log_stats(
151
+ stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
152
+ test_stats=stage_stats,
153
+ )
154
+ with open(self.hparams.wer_file, "w") as w:
155
+ self.wer_metric.write_stats(w)
156
+
157
+ def init_optimizers(self):
158
+ "Initializes the wav2vec2 optimizer and model optimizer"
159
+ self.wav2vec_optimizer = self.hparams.wav2vec_opt_class(
160
+ self.modules.wav2vec2.parameters()
161
+ )
162
+ self.model_optimizer = self.hparams.model_opt_class(
163
+ self.hparams.model.parameters()
164
+ )
165
+
166
+ if self.checkpointer is not None:
167
+ self.checkpointer.add_recoverable(
168
+ "wav2vec_opt", self.wav2vec_optimizer
169
+ )
170
+ self.checkpointer.add_recoverable("modelopt", self.model_optimizer)
171
+
172
+
173
+ def dataio_prepare(hparams):
174
+ """This function prepares the datasets to be used in the brain class.
175
+ It also defines the data processing pipeline through user-defined functions."""
176
+ data_folder = hparams["data_folder"]
177
+
178
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
179
+ csv_path=hparams["train_csv"], replacements={"data_root": data_folder},
180
+ )
181
+
182
+ if hparams["sorting"] == "ascending":
183
+ # we sort training data to speed up training and get better results.
184
+ train_data = train_data.filtered_sorted(sort_key="duration")
185
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
186
+ hparams["train_dataloader_opts"]["shuffle"] = False
187
+
188
+ elif hparams["sorting"] == "descending":
189
+ train_data = train_data.filtered_sorted(
190
+ sort_key="duration", reverse=True
191
+ )
192
+ # when sorting do not shuffle in dataloader ! otherwise is pointless
193
+ hparams["train_dataloader_opts"]["shuffle"] = False
194
+
195
+ elif hparams["sorting"] == "random":
196
+ pass
197
+
198
+ else:
199
+ raise NotImplementedError(
200
+ "sorting must be random, ascending or descending"
201
+ )
202
+
203
+ valid_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
204
+ csv_path=hparams["valid_csv"], replacements={"data_root": data_folder},
205
+ )
206
+ valid_data = valid_data.filtered_sorted(sort_key="duration")
207
+
208
+ # test is separate
209
+ test_datasets = {}
210
+ for csv_file in hparams["test_csv"]:
211
+ name = Path(csv_file).stem
212
+ test_datasets[name] = sb.dataio.dataset.DynamicItemDataset.from_csv(
213
+ csv_path=csv_file, replacements={"data_root": data_folder}
214
+ )
215
+ test_datasets[name] = test_datasets[name].filtered_sorted(
216
+ sort_key="duration"
217
+ )
218
+
219
+ datasets = [train_data, valid_data] + [i for k, i in test_datasets.items()]
220
+
221
+ # 2. Define audio pipeline:
222
+ @sb.utils.data_pipeline.takes("wav", "sr")
223
+ @sb.utils.data_pipeline.provides("sig")
224
+ def audio_pipeline(wav, sr):
225
+ sig = sb.dataio.dataio.read_audio(wav)
226
+ sig = resamplers[sr](sig)
227
+ return sig
228
+
229
+ sb.dataio.dataset.add_dynamic_item(datasets, audio_pipeline)
230
+ label_encoder = sb.dataio.encoder.CTCTextEncoder()
231
+
232
+ # 3. Define text pipeline:
233
+ @sb.utils.data_pipeline.takes("wrd")
234
+ @sb.utils.data_pipeline.provides(
235
+ "wrd", "char_list", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
236
+ )
237
+ def text_pipeline(wrd):
238
+ yield wrd
239
+ char_list = list(wrd)
240
+ yield char_list
241
+ tokens_list = label_encoder.encode_sequence(char_list)
242
+ yield tokens_list
243
+ tokens_bos = torch.LongTensor([hparams["bos_index"]] + (tokens_list))
244
+ yield tokens_bos
245
+ tokens_eos = torch.LongTensor(tokens_list + [hparams["eos_index"]])
246
+ yield tokens_eos
247
+ tokens = torch.LongTensor(tokens_list)
248
+ yield tokens
249
+
250
+ sb.dataio.dataset.add_dynamic_item(datasets, text_pipeline)
251
+
252
+ lab_enc_file = os.path.join(hparams["save_folder"], "label_encoder.txt")
253
+ special_labels = {
254
+ "bos_label": hparams["bos_index"],
255
+ "eos_label": hparams["eos_index"],
256
+ "blank_label": hparams["blank_index"],
257
+ }
258
+ label_encoder.load_or_create(
259
+ path=lab_enc_file,
260
+ from_didatasets=[train_data],
261
+ output_key="char_list",
262
+ special_labels=special_labels,
263
+ sequence_input=True,
264
+ )
265
+
266
+ # 4. Set output:
267
+ sb.dataio.dataset.set_output_keys(
268
+ datasets,
269
+ ["id", "sig", "wrd", "char_list", "tokens_bos", "tokens_eos", "tokens"],
270
+ )
271
+ return train_data, valid_data, test_datasets, label_encoder
272
+
273
+
274
+ if __name__ == "__main__":
275
+
276
+ # CLI:
277
+ hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
278
+
279
+ # If distributed_launch=True then
280
+ # create ddp_group with the right communication protocol
281
+ sb.utils.distributed.ddp_init_group(run_opts)
282
+
283
+ with open(hparams_file) as fin:
284
+ hparams = load_hyperpyyaml(fin, overrides)
285
+
286
+ # Create experiment directory
287
+ sb.create_experiment_directory(
288
+ experiment_directory=hparams["output_folder"],
289
+ hyperparams_to_save=hparams_file,
290
+ overrides=overrides,
291
+ )
292
+
293
+ # Dataset prep (build resamplers for the input audio)
294
+
295
+ resampler_8000 = T.Resample(8000, 16000, dtype=torch.float)
296
+
297
+ resampler_44100 =T.Resample(44100, 16000, dtype=torch.float)
298
+ resampler_32000 =T.Resample(32000, 16000, dtype=torch.float)
299
+ resampler_48000 =T.Resample(48000, 16000, dtype=torch.float)
300
+
301
+
302
+ resamplers = {"48000": resampler_48000,"8000": resampler_8000, "44100":resampler_44100, "32000":resampler_32000}
303
+
304
+ # here we create the datasets objects as well as tokenization and encoding
305
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
306
+ hparams
307
+ )
308
+
309
+ # Trainer initialization
310
+ asr_brain = ASR(
311
+ modules=hparams["modules"],
312
+ hparams=hparams,
313
+ run_opts=run_opts,
314
+ checkpointer=hparams["checkpointer"],
315
+ )
316
+ asr_brain.device= "cpu"
317
+ asr_brain.modules.to("cpu")
318
+
319
+ # We dynamicaly add the tokenizer to our brain class.
320
+ # NB: This tokenizer corresponds to the one used for the LM!!
321
+ asr_brain.tokenizer = label_encoder
322
+
323
+ # Training
324
+ asr_brain.fit(
325
+ asr_brain.hparams.epoch_counter,
326
+ train_data,
327
+ valid_data,
328
+ train_loader_kwargs=hparams["train_dataloader_opts"],
329
+ valid_loader_kwargs=hparams["valid_dataloader_opts"],
330
+ )
331
+
332
+ # Testing
333
+ for k in test_datasets.keys(): # keys are test_clean, test_other etc
334
+ asr_brain.hparams.wer_file = os.path.join(
335
+ hparams["output_folder"], "wer_{}.txt".format(k)
336
+ )
337
+ asr_brain.evaluate(
338
+ test_datasets[k], test_loader_kwargs=hparams["test_dataloader_opts"]
339
+ )
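A note on the audio pipeline in ctc_train.py (and lm_tunisian.py): each CSV row carries the source sample rate as a string in the "sr" column, and audio_pipeline dispatches to a pre-built torchaudio Resample transform so that every signal reaches the 16 kHz expected by the wav2vec2/WavLM encoder. A standalone sketch of the same dispatch for a single file; the helper function, the mono downmix and the sample path are illustrative additions, not part of the recipe (which reads audio with sb.dataio.dataio.read_audio).

    import torch
    import torchaudio
    import torchaudio.transforms as T

    # Same resamplers dict as in ctc_train.py, keyed by the source rate as a string.
    resamplers = {
        "8000": T.Resample(8000, 16000, dtype=torch.float),
        "32000": T.Resample(32000, 16000, dtype=torch.float),
        "44100": T.Resample(44100, 16000, dtype=torch.float),
        "48000": T.Resample(48000, 16000, dtype=torch.float),
    }

    def load_and_resample(path):
        # Illustrative helper: read a file, downmix to mono, resample to 16 kHz.
        sig, sr = torchaudio.load(path)
        sig = sig.mean(dim=0)
        return resamplers[str(sr)](sig) if str(sr) in resamplers else sig

    wav = load_and_resample("samples/Salah1.wav")  # one of the sample files in this repo
    print(wav.shape)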
partly_frozen_splitted_wavlm/env.log ADDED
@@ -0,0 +1,379 @@
1
+ SpeechBrain system description
2
+ ==============================
3
+ Python version:
4
+ 3.8.5 (default, Sep 4 2020, 07:30:14)
5
+ [GCC 7.3.0]
6
+ ==============================
7
+ Installed Python packages:
8
+ abkhazia==1.0
9
+ absl-py==0.11.0
10
+ aiohttp==3.8.0
11
+ aiosignal==1.2.0
12
+ alabaster==0.7.12
13
+ alembic==1.7.4
14
+ altgraph==0.17
15
+ antlr4-python3-runtime==4.8
16
+ appdirs==1.4.4
17
+ argcomplete==1.12.2
18
+ argon2-cffi==20.1.0
19
+ asgiref==3.6.0
20
+ astunparse==1.6.3
21
+ async-generator==1.10
22
+ async-timeout==4.0.0
23
+ attrdict==2.0.1
24
+ attrs==20.3.0
25
+ audeer==1.16.0
26
+ audformat==0.11.5
27
+ audinterface==0.7.0
28
+ audiofile==1.0.0
29
+ audiomentations==0.25.0
30
+ audioread==2.1.9
31
+ audobject==0.4.14
32
+ audresample==0.1.6
33
+ -e git+https://github.com/facebookresearch/WavAugment.git@54afcdb00ccc852c2f030f239f8532c9562b550e#egg=augment
34
+ autopage==0.4.0
35
+ Babel==2.9.0
36
+ backcall==0.2.0
37
+ beautifulsoup4==4.10.0
38
+ black==19.10b0
39
+ bleach==3.3.0
40
+ boto3==1.20.2
41
+ botocore==1.23.2
42
+ braceexpand==0.1.7
43
+ cachetools==4.2.0
44
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
45
+ cffi==1.14.3
46
+ cfgv==3.2.0
47
+ chardet==3.0.4
48
+ charset-normalizer==2.0.7
49
+ click==7.1.2
50
+ cliff==3.9.0
51
+ clldutils==3.5.4
52
+ cmaes==0.8.2
53
+ cmake==3.18.4.post1
54
+ cmd2==2.2.0
55
+ colorama==0.4.4
56
+ colorlog==4.6.2
57
+ configparser==5.1.0
58
+ cryptography==38.0.4
59
+ csvw==1.8.1
60
+ cycler==0.10.0
61
+ Cython==0.29.21
62
+ dataclasses==0.6
63
+ datasets==1.5.0
64
+ decorator==4.4.2
65
+ deepspeech==0.9.1
66
+ defusedxml==0.7.1
67
+ denoiser==0.1.5
68
+ dill==0.3.3
69
+ Distance==0.1.3
70
+ distlib==0.3.1
71
+ Django==3.2.16
72
+ django-auditlog==2.2.1
73
+ django-filter==22.1
74
+ django-js-asset==1.2.2
75
+ django-mptt==0.14.0
76
+ djangorestframework==3.14.0
77
+ docker-pycreds==0.4.0
78
+ docopt==0.6.2
79
+ docutils==0.16
80
+ drf-excel==2.2.0
81
+ drf-flex-fields==1.0.0
82
+ drf-renderer-xlsx==0.4.1
83
+ easyocr==1.2.1
84
+ editdistance==0.6.0
85
+ emoji==2.2.0
86
+ entrypoints==0.3
87
+ et-xmlfile==1.1.0
88
+ exceptiongroup==1.1.0
89
+ farasapy==0.0.14
90
+ fasttext==0.9.2
91
+ ffmpeg-python==0.2.0
92
+ filelock==3.0.12
93
+ flake8==3.7.9
94
+ flatbuffers==1.12
95
+ frozendict==2.0.7
96
+ frozenlist==1.2.0
97
+ fsspec==2021.11.0
98
+ future==0.18.2
99
+ g2p-en==2.1.0
100
+ gast==0.3.3
101
+ gdown==4.2.0
102
+ gensim==4.0.1
103
+ gitdb==4.0.9
104
+ GitPython==3.1.24
105
+ google-auth==1.24.0
106
+ google-auth-oauthlib==0.4.2
107
+ google-pasta==0.2.0
108
+ greenlet==1.1.2
109
+ grpcio==1.32.0
110
+ h5features==1.3.2
111
+ h5py==2.10.0
112
+ htk-io==0.5
113
+ huggingface-hub==0.9.1
114
+ hydra-colorlog==0.1.4
115
+ hydra-core==0.11.3
116
+ HyperPyYAML==1.1.0
117
+ hypothesis==6.61.2
118
+ identify==1.5.10
119
+ idna==2.10
120
+ imageio==2.9.0
121
+ imagesize==1.2.0
122
+ importlib-metadata==4.8.1
123
+ importlib-resources==5.2.2
124
+ inflect==5.3.0
125
+ ipadic==1.0.0
126
+ ipykernel==5.3.4
127
+ ipython==7.19.0
128
+ ipython-genutils==0.2.0
129
+ ipywidgets==7.6.3
130
+ iso-639==0.4.5
131
+ isodate==0.6.0
132
+ isort==4.3.21
133
+ jedi==0.17.2
134
+ jieba==0.42.1
135
+ Jinja2==2.11.2
136
+ jiwer==2.2.0
137
+ jmespath==0.10.0
138
+ joblib==0.17.0
139
+ jsonschema==3.2.0
140
+ julius==0.2.7
141
+ jupyter-client==6.1.7
142
+ jupyter-core==4.7.0
143
+ jupyterlab-pygments==0.1.2
144
+ jupyterlab-widgets==1.0.0
145
+ kaitaistruct==0.9
146
+ kaldi-io==0.9.4
147
+ kaldi-python-io==1.2.2
148
+ kaldiio==2.17.2
149
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip
150
+ Keras-Preprocessing==1.1.2
151
+ kiwisolver==1.3.1
152
+ lang-trans==0.6.0
153
+ latexcodec==2.0.1
154
+ ldap3==2.9.1
155
+ librosa==0.9.0
156
+ llvmlite==0.35.0
157
+ lxml==4.9.0
158
+ Mako==1.1.5
159
+ Markdown==3.3.3
160
+ MarkupSafe==1.1.1
161
+ marshmallow==3.14.0
162
+ matplotlib==3.3.3
163
+ mccabe==0.6.1
164
+ mcd==0.4
165
+ mecab-python3==1.0.3
166
+ megatron-lm==2.2.0
167
+ mido==1.2.10
168
+ mistune==0.8.4
169
+ more-itertools==8.6.0
170
+ mpmath==1.2.1
171
+ multidict==5.2.0
172
+ multiprocess==0.70.11.1
173
+ nbclient==0.5.3
174
+ nbconvert==6.0.7
175
+ nbformat==5.1.3
176
+ NEMO==4.3.2
177
+ nemo-toolkit==1.4.0
178
+ nest-asyncio==1.5.1
179
+ networkx==2.5
180
+ nltk==3.5
181
+ nodeenv==1.5.0
182
+ notebook==6.3.0
183
+ numba==0.52.0
184
+ numpy==1.19.4
185
+ nvidia-cublas-cu11==11.10.3.66
186
+ nvidia-cuda-nvrtc-cu11==11.7.99
187
+ nvidia-cuda-runtime-cu11==11.7.99
188
+ nvidia-cudnn-cu11==8.5.0.96
189
+ oauthlib==3.1.0
190
+ omegaconf==1.4.1
191
+ onnx==1.10.2
192
+ OpenCC==1.1.2
193
+ opencv-python==4.4.0.46
194
+ openpyxl==3.0.9
195
+ opensmile==2.2.0
196
+ opt-einsum==3.3.0
197
+ optuna==2.10.0
198
+ oyaml==1.0
199
+ packaging==22.0
200
+ pandas==1.2.5
201
+ pandocfilters==1.4.3
202
+ pangu==4.0.6.1
203
+ parameterized==0.8.1
204
+ parso==0.7.1
205
+ pathspec==0.8.1
206
+ pathtools==0.1.2
207
+ pbr==5.6.0
208
+ pefile==2019.4.18
209
+ pescador==2.1.0
210
+ pesq==0.0.3
211
+ pexpect==4.8.0
212
+ phonemizer==2.2.1
213
+ pickleshare==0.7.5
214
+ Pillow==9.3.0
215
+ pip-api==0.0.23
216
+ pipreqs==0.4.11
217
+ pluggy==0.13.1
218
+ pooch==1.3.0
219
+ portalocker==2.3.2
220
+ pre-commit==2.9.0
221
+ pretty-midi==0.2.9
222
+ prettytable==2.2.1
223
+ progressbar2==3.53.1
224
+ prometheus-client==0.10.1
225
+ promise==2.3
226
+ prompt-toolkit==3.0.8
227
+ protobuf==3.14.0
228
+ psutil==5.6.6
229
+ ptyprocess==0.6.0
230
+ py==1.9.0
231
+ py-espeak-ng==0.1.8
232
+ pyannote.audio==1.1.1
233
+ pyannote.core==4.3
234
+ pyannote.database==4.1.1
235
+ pyannote.metrics==3.1
236
+ pyannote.pipeline==1.5.2
237
+ PyArabic==0.6.15
238
+ pyarrow==3.0.0
239
+ pyasn1==0.4.8
240
+ pyasn1-modules==0.2.8
241
+ pybind11==2.8.1
242
+ pybtex==0.24.0
243
+ pybtex-docutils==1.0.1
244
+ pycodestyle==2.5.0
245
+ pycparser==2.20
246
+ pyctcdecode==0.4.0
247
+ pyDeprecate==0.3.1
248
+ pydub==0.25.1
249
+ pyflakes==2.1.1
250
+ Pygments==2.7.2
251
+ pygtrie==2.5.0
252
+ pymodbus==2.5.3
253
+ pyparsing==2.4.7
254
+ pyperclip==1.8.2
255
+ pypinyin==0.43.0
256
+ pyrsistent==0.17.3
257
+ pyserial==3.5
258
+ PySocks==1.7.1
259
+ pystoi==0.3.3
260
+ pytest==5.4.1
261
+ pytest-runner==5.3.1
262
+ python-bidi==0.4.2
263
+ python-crfsuite==0.9.7
264
+ python-dateutil==2.8.2
265
+ python-Levenshtein==0.12.2
266
+ python-utils==2.4.0
267
+ pytorch-lightning==1.4.9
268
+ pytube==11.0.1
269
+ pytz==2022.6
270
+ PyWavelets==1.1.1
271
+ PyYAML==5.3.1
272
+ pyzmq==20.0.0
273
+ rapidfuzz==1.8.2
274
+ regex==2020.11.13
275
+ requests==2.28.1
276
+ requests-oauthlib==1.3.0
277
+ resampy==0.2.2
278
+ rfc3986==1.4.0
279
+ rsa==4.7
280
+ ruamel.yaml==0.17.21
281
+ ruamel.yaml.clib==0.2.7
282
+ s3m==1.1.0
283
+ s3transfer==0.5.0
284
+ sacrebleu==2.0.0
285
+ sacremoses==0.0.44
286
+ scikit-image==0.18.1
287
+ scikit-learn==0.23.2
288
+ scipy==1.5.4
289
+ -e git+https://github.com/sanghack81/SDCIT@00d060dde733fde9345154a494f81e97fb395ca7#egg=SDCIT
290
+ seaborn==0.11.1
291
+ segments==2.1.3
292
+ Send2Trash==1.5.0
293
+ sentencepiece==0.1.94
294
+ sentry-sdk==1.4.3
295
+ shellingham==1.4.0
296
+ shortuuid==1.0.7
297
+ SIDEKIT==1.3.8.5.2
298
+ simplejson==3.17.5
299
+ six==1.15.0
300
+ smart-open==5.0.0
301
+ smmap==5.0.0
302
+ snowballstemmer==2.0.0
303
+ sortedcollections==2.1.0
304
+ sortedcontainers==2.4.0
305
+ sounddevice==0.4.5
306
+ SoundFile==0.10.3.post1
307
+ soupsieve==2.3
308
+ sox==1.4.1
309
+ sparsemax==0.1.9
310
+ speechbrain==0.5.13
311
+ sphfile==1.0.3
312
+ Sphinx==3.3.1
313
+ sphinx-rtd-theme==0.4.3
314
+ sphinxcontrib-applehelp==1.0.2
315
+ sphinxcontrib-bibtex==2.4.1
316
+ sphinxcontrib-devhelp==1.0.2
317
+ sphinxcontrib-htmlhelp==1.0.3
318
+ sphinxcontrib-jsmath==1.0.1
319
+ sphinxcontrib-qthelp==1.0.3
320
+ sphinxcontrib-serializinghtml==1.1.4
321
+ SQLAlchemy==1.4.25
322
+ sqlparse==0.4.2
323
+ stanza==1.4.2
324
+ stevedore==3.4.0
325
+ subprocess32==3.5.4
326
+ sympy==1.9
327
+ tabulate==0.8.9
328
+ tensorboard==2.4.0
329
+ tensorboard-plugin-wit==1.7.0
330
+ tensorflow==2.4.0
331
+ tensorflow-estimator==2.4.0
332
+ termcolor==1.1.0
333
+ terminado==0.9.4
334
+ testpath==0.4.4
335
+ threadpoolctl==2.1.0
336
+ tifffile==2020.12.8
337
+ tikzplotlib==0.9.8
338
+ tkseem==0.0.3
339
+ tokenizers==0.10.2
340
+ toml==0.10.2
341
+ torch==1.13.1
342
+ torch-stft==0.1.4
343
+ torchaudio==0.13.1
344
+ torchmetrics==0.6.0
345
+ torchvision==0.14.1
346
+ tornado==6.1
347
+ tqdm==4.61.1
348
+ trackrip==1.2.1
349
+ traitlets==5.0.5
350
+ transformers==4.15.0
351
+ typed-ast==1.4.1
352
+ typer==0.4.0
353
+ typing-extensions==3.7.4.3
354
+ Unidecode==1.3.2
355
+ uritemplate==3.0.1
356
+ urllib3==1.26.2
357
+ virtualenv==20.2.1
358
+ wandb==0.12.6
359
+ wcwidth==0.2.5
360
+ webdataset==0.1.62
361
+ webencodings==0.5.1
362
+ Werkzeug==1.0.1
363
+ wget==3.2
364
+ widgetsnbextension==3.5.1
365
+ wordninja==2.0.0
366
+ wrapt==1.12.1
367
+ xmltodict==0.13.0
368
+ xxhash==2.0.0
369
+ yamllint==1.23.0
370
+ yarg==0.1.9
371
+ yarl==1.7.2
372
+ yaspin==2.1.0
373
+ youtokentome==1.0.6
374
+ youtube-dl==2021.6.6
375
+ zipp==3.6.0
376
+ ==============================
377
+ Could not get git revision
+ ==============================
378
+ CUDA version:
379
+ 11.7
partly_frozen_splitted_wavlm/hyperparams.yaml ADDED
@@ -0,0 +1,162 @@
1
+ # Generated 2023-01-07 from:
2
+ # /home/salah/kenlm_train/to_copy/wavlm_partly_frozen.yaml
3
+ # yamllint disable
4
+ # ################################
5
+ # Model: wav2vec2 + DNN + CTC
6
+ # Augmentation: SpecAugment
7
+ # Authors: Sung-Lin Yeh 2021
8
+ # ################################
9
+
10
+ # Seed needs to be set at top of yaml, before objects with parameters are made
11
+ seed: 1986
12
+ __set_seed: !apply:torch.manual_seed [1986]
13
+ output_folder: partly_frozen_splitted_wavlm
14
+ wer_file: partly_frozen_splitted_wavlm/wer.txt
15
+ save_folder: partly_frozen_splitted_wavlm/save
16
+ train_log: partly_frozen_splitted_wavlm/train_log.txt
17
+
18
+ # URL for the biggest Fairseq english wav2vec2 model.
19
+
20
+ # Data files
21
+ data_folder: /gpfsscratch/rech/nou/uzn19yk/Libri/LibriSpeech/ # e.g. /path/to/LibriSpeech
22
+ # noise/ris dataset will automatically be downloaded
23
+ data_folder_rirs: /gpfsscratch/rech/nou/uzn19yk/Libri/LibriSpeech/
24
+ train_splits: [train-clean-100]
25
+ dev_splits: [dev-clean]
26
+ test_splits: [test-clean, test-other]
27
+ skip_prep: false
28
+ ckpt_interval_minutes: 25 # save checkpoint every N min
29
+ csv_folder: /gpfsstore/rech/nou/uzn19yk/iwslt/splitted_clean_tunisian_csvs/
30
+ train_csv: test_salah_local.csv
31
+ valid_csv: test_salah_local.csv
32
+ test_csv:
33
+ - test_salah_local.csv
34
+
35
+ # Training parameters
36
+ number_of_epochs: 12
37
+ lr: 1
38
+ lr_wav2vec: 0.0001
39
+ sorting: ascending
40
+ auto_mix_prec: false
41
+ sample_rate: 16000
42
+
43
+ avoid_if_longer_than: 10
44
+ # With data_parallel batch_size is split into N jobs
45
+ # With DDP batch_size is multiplied by N jobs
46
+ # Must be 3 per GPU to fit 32GB of VRAM
47
+ batch_size: 1
48
+ test_batch_size: 1
49
+
50
+ # Dataloader options
51
+ train_dataloader_opts:
52
+ batch_size: 1
53
+
54
+ valid_dataloader_opts:
55
+ batch_size: 1
56
+
57
+ test_dataloader_opts:
58
+ batch_size: 1
59
+
60
+ # Model parameters
61
+ activation: &id001 !name:torch.nn.LeakyReLU
62
+ dnn_layers: 2
63
+ dnn_neurons: 1024
64
+ freeze_wav2vec: false
65
+
66
+ # Outputs
67
+ output_neurons: 41 # character inventory size, incl. blank/bos/eos
68
+
69
+ # Decoding parameters
70
+ blank_index: 0
71
+ bos_index: 1
72
+ eos_index: 2
73
+
74
+ #
75
+ # Functions and classes
76
+ #
77
+ epoch_counter: &id008 !new:speechbrain.utils.epoch_loop.EpochCounter
78
+
79
+ limit: 12
80
+
81
+ augmentation: !new:speechbrain.lobes.augment.TimeDomainSpecAugment
82
+ sample_rate: 16000
83
+ speeds: [95, 100, 105]
84
+
85
+ enc: &id003 !new:speechbrain.lobes.models.VanillaNN.VanillaNN
86
+ input_shape: [null, null, 1024]
87
+ activation: *id001
88
+ dnn_blocks: 2
89
+ dnn_neurons: 1024
90
+
91
+ wav2vec2: &id002 !new:speechbrain.lobes.models.huggingface_wav2vec.HuggingFaceWav2Vec2
92
+ source: wavlm-large/
93
+ output_norm: true
94
+ freeze: false
95
+ freeze_feature_extractor: true
96
+ save_path: partly_frozen_splitted_wavlm/save/wav2vec2_hubert_checkpoint
97
+
98
+ #####
99
+ # Uncomment this block if you prefer to use a Fairseq pretrained model instead
100
+ # of a HuggingFace one. Here, we provide an URL that is obtained from the
101
+ # Fairseq github for the multilingual XLSR.
102
+ #
103
+ #wav2vec2_url: https://dl.fbaipublicfiles.com/fairseq/wav2vec/wav2vec_vox_960h_pl.pt
104
+ #wav2vec2: !new:speechbrain.lobes.models.fairseq_wav2vec.FairseqWav2Vec2
105
+ # pretrained_path: !ref <wav2vec2_url>
106
+ # output_norm: True
107
+ # freeze: False
108
+ # save_path: !ref <save_folder>/wav2vec2_checkpoint/model.pt
109
+
110
+ ctc_lin: &id004 !new:speechbrain.nnet.linear.Linear
111
+
112
+ input_size: 1024
113
+ n_neurons: 41
114
+
115
+ log_softmax: !new:speechbrain.nnet.activations.Softmax
116
+ apply_log: true
117
+
118
+ ctc_cost: !name:speechbrain.nnet.losses.ctc_loss
119
+ blank_index: 0
120
+
121
+ modules:
122
+ wav2vec2: *id002
123
+ enc: *id003
124
+ ctc_lin: *id004
125
+ model: &id005 !new:torch.nn.ModuleList
126
+ - [*id003, *id004]
127
+ model_opt_class: !name:torch.optim.Adadelta
128
+ lr: 1
129
+ rho: 0.95
130
+ eps: 1.e-8
131
+
132
+ wav2vec_opt_class: !name:torch.optim.Adam
133
+ lr: 0.0001
134
+
135
+ lr_annealing_model: &id006 !new:speechbrain.nnet.schedulers.NewBobScheduler
136
+ initial_value: 1
137
+ improvement_threshold: 0.0025
138
+ annealing_factor: 0.8
139
+ patient: 0
140
+
141
+ lr_annealing_wav2vec: &id007 !new:speechbrain.nnet.schedulers.NewBobScheduler
142
+ initial_value: 0.0001
143
+ improvement_threshold: 0.0025
144
+ annealing_factor: 0.9
145
+ patient: 0
146
+
147
+
148
+ checkpointer: !new:speechbrain.utils.checkpoints.Checkpointer
149
+ checkpoints_dir: partly_frozen_splitted_wavlm/save
150
+ recoverables:
151
+ wav2vec2: *id002
152
+ model: *id005
153
+ scheduler_model: *id006
154
+ scheduler_wav2vec: *id007
155
+ counter: *id008
156
+ train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger
157
+ save_file: partly_frozen_splitted_wavlm/train_log.txt
158
+
159
+ error_rate_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
160
+
161
+ cer_computer: !name:speechbrain.utils.metric_stats.ErrorRateStats
162
+ split_tokens: true
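This YAML is consumed by the training scripts through hyperpyyaml: sb.parse_arguments turns any extra CLI flags (e.g. --number_of_epochs 1) into an overrides string, and load_hyperpyyaml instantiates every !new:/!name: object (including the WavLM encoder) while applying those overrides. A minimal sketch of the same loading step outside the script; the override values are illustrative.

    from hyperpyyaml import load_hyperpyyaml

    # Overrides are plain YAML applied on top of the file, exactly as in the
    # __main__ block of ctc_train.py / lm_tunisian.py (where they come from the CLI).
    overrides = "number_of_epochs: 1\ntest_batch_size: 2"

    with open("partly_frozen_splitted_wavlm/hyperparams.yaml") as fin:
        # Note: this also instantiates the modules declared with !new:, so it will
        # try to load the wavlm-large/ checkpoint referenced by the wav2vec2 entry.
        hparams = load_hyperpyyaml(fin, overrides)

    print(hparams["output_folder"], hparams["number_of_epochs"])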
partly_frozen_splitted_wavlm/log.txt ADDED
@@ -0,0 +1,1998 @@
1
+ 2023-01-07 15:57:01,210 - speechbrain.core - INFO - Beginning experiment!
2
+ 2023-01-07 15:57:01,210 - speechbrain.core - INFO - Experiment folder: partly_frozen_splitted_wavlm
3
+ 2023-01-07 15:57:01,912 - speechbrain.utils.superpowers - DEBUG - abkhazia==1.0
4
+ absl-py==0.11.0
5
+ aiohttp==3.8.0
6
+ aiosignal==1.2.0
7
+ alabaster==0.7.12
8
+ alembic==1.7.4
9
+ altgraph==0.17
10
+ antlr4-python3-runtime==4.8
11
+ appdirs==1.4.4
12
+ argcomplete==1.12.2
13
+ argon2-cffi==20.1.0
14
+ asgiref==3.6.0
15
+ astunparse==1.6.3
16
+ async-generator==1.10
17
+ async-timeout==4.0.0
18
+ attrdict==2.0.1
19
+ attrs==20.3.0
20
+ audeer==1.16.0
21
+ audformat==0.11.5
22
+ audinterface==0.7.0
23
+ audiofile==1.0.0
24
+ audiomentations==0.25.0
25
+ audioread==2.1.9
26
+ audobject==0.4.14
27
+ audresample==0.1.6
28
+ -e git+https://github.com/facebookresearch/WavAugment.git@54afcdb00ccc852c2f030f239f8532c9562b550e#egg=augment
29
+ autopage==0.4.0
30
+ Babel==2.9.0
31
+ backcall==0.2.0
32
+ beautifulsoup4==4.10.0
33
+ black==19.10b0
34
+ bleach==3.3.0
35
+ boto3==1.20.2
36
+ botocore==1.23.2
37
+ braceexpand==0.1.7
38
+ cachetools==4.2.0
39
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
40
+ cffi==1.14.3
41
+ cfgv==3.2.0
42
+ chardet==3.0.4
43
+ charset-normalizer==2.0.7
44
+ click==7.1.2
45
+ cliff==3.9.0
46
+ clldutils==3.5.4
47
+ cmaes==0.8.2
48
+ cmake==3.18.4.post1
49
+ cmd2==2.2.0
50
+ colorama==0.4.4
51
+ colorlog==4.6.2
52
+ configparser==5.1.0
53
+ cryptography==38.0.4
54
+ csvw==1.8.1
55
+ cycler==0.10.0
56
+ Cython==0.29.21
57
+ dataclasses==0.6
58
+ datasets==1.5.0
59
+ decorator==4.4.2
60
+ deepspeech==0.9.1
61
+ defusedxml==0.7.1
62
+ denoiser==0.1.5
63
+ dill==0.3.3
64
+ Distance==0.1.3
65
+ distlib==0.3.1
66
+ Django==3.2.16
67
+ django-auditlog==2.2.1
68
+ django-filter==22.1
69
+ django-js-asset==1.2.2
70
+ django-mptt==0.14.0
71
+ djangorestframework==3.14.0
72
+ docker-pycreds==0.4.0
73
+ docopt==0.6.2
74
+ docutils==0.16
75
+ drf-excel==2.2.0
76
+ drf-flex-fields==1.0.0
77
+ drf-renderer-xlsx==0.4.1
78
+ easyocr==1.2.1
79
+ editdistance==0.6.0
80
+ emoji==2.2.0
81
+ entrypoints==0.3
82
+ et-xmlfile==1.1.0
83
+ exceptiongroup==1.1.0
84
+ farasapy==0.0.14
85
+ fasttext==0.9.2
86
+ ffmpeg-python==0.2.0
87
+ filelock==3.0.12
88
+ flake8==3.7.9
89
+ flatbuffers==1.12
90
+ frozendict==2.0.7
91
+ frozenlist==1.2.0
92
+ fsspec==2021.11.0
93
+ future==0.18.2
94
+ g2p-en==2.1.0
95
+ gast==0.3.3
96
+ gdown==4.2.0
97
+ gensim==4.0.1
98
+ gitdb==4.0.9
99
+ GitPython==3.1.24
100
+ google-auth==1.24.0
101
+ google-auth-oauthlib==0.4.2
102
+ google-pasta==0.2.0
103
+ greenlet==1.1.2
104
+ grpcio==1.32.0
105
+ h5features==1.3.2
106
+ h5py==2.10.0
107
+ htk-io==0.5
108
+ huggingface-hub==0.9.1
109
+ hydra-colorlog==0.1.4
110
+ hydra-core==0.11.3
111
+ HyperPyYAML==1.1.0
112
+ hypothesis==6.61.2
113
+ identify==1.5.10
114
+ idna==2.10
115
+ imageio==2.9.0
116
+ imagesize==1.2.0
117
+ importlib-metadata==4.8.1
118
+ importlib-resources==5.2.2
119
+ inflect==5.3.0
120
+ ipadic==1.0.0
121
+ ipykernel==5.3.4
122
+ ipython==7.19.0
123
+ ipython-genutils==0.2.0
124
+ ipywidgets==7.6.3
125
+ iso-639==0.4.5
126
+ isodate==0.6.0
127
+ isort==4.3.21
128
+ jedi==0.17.2
129
+ jieba==0.42.1
130
+ Jinja2==2.11.2
131
+ jiwer==2.2.0
132
+ jmespath==0.10.0
133
+ joblib==0.17.0
134
+ jsonschema==3.2.0
135
+ julius==0.2.7
136
+ jupyter-client==6.1.7
137
+ jupyter-core==4.7.0
138
+ jupyterlab-pygments==0.1.2
139
+ jupyterlab-widgets==1.0.0
140
+ kaitaistruct==0.9
141
+ kaldi-io==0.9.4
142
+ kaldi-python-io==1.2.2
143
+ kaldiio==2.17.2
144
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip
145
+ Keras-Preprocessing==1.1.2
146
+ kiwisolver==1.3.1
147
+ lang-trans==0.6.0
148
+ latexcodec==2.0.1
149
+ ldap3==2.9.1
150
+ librosa==0.9.0
151
+ llvmlite==0.35.0
152
+ lxml==4.9.0
153
+ Mako==1.1.5
154
+ Markdown==3.3.3
155
+ MarkupSafe==1.1.1
156
+ marshmallow==3.14.0
157
+ matplotlib==3.3.3
158
+ mccabe==0.6.1
159
+ mcd==0.4
160
+ mecab-python3==1.0.3
161
+ megatron-lm==2.2.0
162
+ mido==1.2.10
163
+ mistune==0.8.4
164
+ more-itertools==8.6.0
165
+ mpmath==1.2.1
166
+ multidict==5.2.0
167
+ multiprocess==0.70.11.1
168
+ nbclient==0.5.3
169
+ nbconvert==6.0.7
170
+ nbformat==5.1.3
171
+ NEMO==4.3.2
172
+ nemo-toolkit==1.4.0
173
+ nest-asyncio==1.5.1
174
+ networkx==2.5
175
+ nltk==3.5
176
+ nodeenv==1.5.0
177
+ notebook==6.3.0
178
+ numba==0.52.0
179
+ numpy==1.19.4
180
+ nvidia-cublas-cu11==11.10.3.66
181
+ nvidia-cuda-nvrtc-cu11==11.7.99
182
+ nvidia-cuda-runtime-cu11==11.7.99
183
+ nvidia-cudnn-cu11==8.5.0.96
184
+ oauthlib==3.1.0
185
+ omegaconf==1.4.1
186
+ onnx==1.10.2
187
+ OpenCC==1.1.2
188
+ opencv-python==4.4.0.46
189
+ openpyxl==3.0.9
190
+ opensmile==2.2.0
191
+ opt-einsum==3.3.0
192
+ optuna==2.10.0
193
+ oyaml==1.0
194
+ packaging==22.0
195
+ pandas==1.2.5
196
+ pandocfilters==1.4.3
197
+ pangu==4.0.6.1
198
+ parameterized==0.8.1
199
+ parso==0.7.1
200
+ pathspec==0.8.1
201
+ pathtools==0.1.2
202
+ pbr==5.6.0
203
+ pefile==2019.4.18
204
+ pescador==2.1.0
205
+ pesq==0.0.3
206
+ pexpect==4.8.0
207
+ phonemizer==2.2.1
208
+ pickleshare==0.7.5
209
+ Pillow==9.3.0
210
+ pip-api==0.0.23
211
+ pipreqs==0.4.11
212
+ pluggy==0.13.1
213
+ pooch==1.3.0
214
+ portalocker==2.3.2
215
+ pre-commit==2.9.0
216
+ pretty-midi==0.2.9
217
+ prettytable==2.2.1
218
+ progressbar2==3.53.1
219
+ prometheus-client==0.10.1
220
+ promise==2.3
221
+ prompt-toolkit==3.0.8
222
+ protobuf==3.14.0
223
+ psutil==5.6.6
224
+ ptyprocess==0.6.0
225
+ py==1.9.0
226
+ py-espeak-ng==0.1.8
227
+ pyannote.audio==1.1.1
228
+ pyannote.core==4.3
229
+ pyannote.database==4.1.1
230
+ pyannote.metrics==3.1
231
+ pyannote.pipeline==1.5.2
232
+ PyArabic==0.6.15
233
+ pyarrow==3.0.0
234
+ pyasn1==0.4.8
235
+ pyasn1-modules==0.2.8
236
+ pybind11==2.8.1
237
+ pybtex==0.24.0
238
+ pybtex-docutils==1.0.1
239
+ pycodestyle==2.5.0
240
+ pycparser==2.20
241
+ pyctcdecode==0.4.0
242
+ pyDeprecate==0.3.1
243
+ pydub==0.25.1
244
+ pyflakes==2.1.1
245
+ Pygments==2.7.2
246
+ pygtrie==2.5.0
247
+ pymodbus==2.5.3
248
+ pyparsing==2.4.7
249
+ pyperclip==1.8.2
250
+ pypinyin==0.43.0
251
+ pyrsistent==0.17.3
252
+ pyserial==3.5
253
+ PySocks==1.7.1
254
+ pystoi==0.3.3
255
+ pytest==5.4.1
256
+ pytest-runner==5.3.1
257
+ python-bidi==0.4.2
258
+ python-crfsuite==0.9.7
259
+ python-dateutil==2.8.2
260
+ python-Levenshtein==0.12.2
261
+ python-utils==2.4.0
262
+ pytorch-lightning==1.4.9
263
+ pytube==11.0.1
264
+ pytz==2022.6
265
+ PyWavelets==1.1.1
266
+ PyYAML==5.3.1
267
+ pyzmq==20.0.0
268
+ rapidfuzz==1.8.2
269
+ regex==2020.11.13
270
+ requests==2.28.1
271
+ requests-oauthlib==1.3.0
272
+ resampy==0.2.2
273
+ rfc3986==1.4.0
274
+ rsa==4.7
275
+ ruamel.yaml==0.17.21
276
+ ruamel.yaml.clib==0.2.7
277
+ s3m==1.1.0
278
+ s3transfer==0.5.0
279
+ sacrebleu==2.0.0
280
+ sacremoses==0.0.44
281
+ scikit-image==0.18.1
282
+ scikit-learn==0.23.2
283
+ scipy==1.5.4
284
+ -e git+https://github.com/sanghack81/SDCIT@00d060dde733fde9345154a494f81e97fb395ca7#egg=SDCIT
285
+ seaborn==0.11.1
286
+ segments==2.1.3
287
+ Send2Trash==1.5.0
288
+ sentencepiece==0.1.94
289
+ sentry-sdk==1.4.3
290
+ shellingham==1.4.0
291
+ shortuuid==1.0.7
292
+ SIDEKIT==1.3.8.5.2
293
+ simplejson==3.17.5
294
+ six==1.15.0
295
+ smart-open==5.0.0
296
+ smmap==5.0.0
297
+ snowballstemmer==2.0.0
298
+ sortedcollections==2.1.0
299
+ sortedcontainers==2.4.0
300
+ sounddevice==0.4.5
301
+ SoundFile==0.10.3.post1
302
+ soupsieve==2.3
303
+ sox==1.4.1
304
+ sparsemax==0.1.9
305
+ speechbrain==0.5.13
306
+ sphfile==1.0.3
307
+ Sphinx==3.3.1
308
+ sphinx-rtd-theme==0.4.3
309
+ sphinxcontrib-applehelp==1.0.2
310
+ sphinxcontrib-bibtex==2.4.1
311
+ sphinxcontrib-devhelp==1.0.2
312
+ sphinxcontrib-htmlhelp==1.0.3
313
+ sphinxcontrib-jsmath==1.0.1
314
+ sphinxcontrib-qthelp==1.0.3
315
+ sphinxcontrib-serializinghtml==1.1.4
316
+ SQLAlchemy==1.4.25
317
+ sqlparse==0.4.2
318
+ stanza==1.4.2
319
+ stevedore==3.4.0
320
+ subprocess32==3.5.4
321
+ sympy==1.9
322
+ tabulate==0.8.9
323
+ tensorboard==2.4.0
324
+ tensorboard-plugin-wit==1.7.0
325
+ tensorflow==2.4.0
326
+ tensorflow-estimator==2.4.0
327
+ termcolor==1.1.0
328
+ terminado==0.9.4
329
+ testpath==0.4.4
330
+ threadpoolctl==2.1.0
331
+ tifffile==2020.12.8
332
+ tikzplotlib==0.9.8
333
+ tkseem==0.0.3
334
+ tokenizers==0.10.2
335
+ toml==0.10.2
336
+ torch==1.13.1
337
+ torch-stft==0.1.4
338
+ torchaudio==0.13.1
339
+ torchmetrics==0.6.0
340
+ torchvision==0.14.1
341
+ tornado==6.1
342
+ tqdm==4.61.1
343
+ trackrip==1.2.1
344
+ traitlets==5.0.5
345
+ transformers==4.15.0
346
+ typed-ast==1.4.1
347
+ typer==0.4.0
348
+ typing-extensions==3.7.4.3
349
+ Unidecode==1.3.2
350
+ uritemplate==3.0.1
351
+ urllib3==1.26.2
352
+ virtualenv==20.2.1
353
+ wandb==0.12.6
354
+ wcwidth==0.2.5
355
+ webdataset==0.1.62
356
+ webencodings==0.5.1
357
+ Werkzeug==1.0.1
358
+ wget==3.2
359
+ widgetsnbextension==3.5.1
360
+ wordninja==2.0.0
361
+ wrapt==1.12.1
362
+ xmltodict==0.13.0
363
+ xxhash==2.0.0
364
+ yamllint==1.23.0
365
+ yarg==0.1.9
366
+ yarl==1.7.2
367
+ yaspin==2.1.0
368
+ youtokentome==1.0.6
369
+ youtube-dl==2021.6.6
370
+ zipp==3.6.0
371
+
372
+
373
+ 2023-01-07 15:57:02,047 - speechbrain.core - ERROR - Exception:
374
+ Traceback (most recent call last):
375
+ File "ctc_train.py", line 305, in <module>
376
+ train_data, valid_data, test_datasets, label_encoder = dataio_prepare(
377
+ File "ctc_train.py", line 178, in dataio_prepare
378
+ train_data = sb.dataio.dataset.DynamicItemDataset.from_csv(
379
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/dataio/dataset.py", line 365, in from_csv
380
+ data = load_data_csv(csv_path, replacements)
381
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/dataio/dataio.py", line 127, in load_data_csv
382
+ with open(csv_path, newline="") as csvfile:
383
+ FileNotFoundError: [Errno 2] No such file or directory: 'test_salah.csv'
384
+ 2023-01-07 15:57:54,234 - speechbrain.core - INFO - Beginning experiment!
385
+ 2023-01-07 15:57:54,234 - speechbrain.core - INFO - Experiment folder: partly_frozen_splitted_wavlm
386
+ 2023-01-07 15:57:54,891 - speechbrain.utils.superpowers - DEBUG - abkhazia==1.0
387
+ absl-py==0.11.0
388
+ aiohttp==3.8.0
389
+ aiosignal==1.2.0
390
+ alabaster==0.7.12
391
+ alembic==1.7.4
392
+ altgraph==0.17
393
+ antlr4-python3-runtime==4.8
394
+ appdirs==1.4.4
395
+ argcomplete==1.12.2
396
+ argon2-cffi==20.1.0
397
+ asgiref==3.6.0
398
+ astunparse==1.6.3
399
+ async-generator==1.10
400
+ async-timeout==4.0.0
401
+ attrdict==2.0.1
402
+ attrs==20.3.0
403
+ audeer==1.16.0
404
+ audformat==0.11.5
405
+ audinterface==0.7.0
406
+ audiofile==1.0.0
407
+ audiomentations==0.25.0
408
+ audioread==2.1.9
409
+ audobject==0.4.14
410
+ audresample==0.1.6
411
+ -e git+https://github.com/facebookresearch/WavAugment.git@54afcdb00ccc852c2f030f239f8532c9562b550e#egg=augment
412
+ autopage==0.4.0
413
+ Babel==2.9.0
414
+ backcall==0.2.0
415
+ beautifulsoup4==4.10.0
416
+ black==19.10b0
417
+ bleach==3.3.0
418
+ boto3==1.20.2
419
+ botocore==1.23.2
420
+ braceexpand==0.1.7
421
+ cachetools==4.2.0
422
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
423
+ cffi==1.14.3
424
+ cfgv==3.2.0
425
+ chardet==3.0.4
426
+ charset-normalizer==2.0.7
427
+ click==7.1.2
428
+ cliff==3.9.0
429
+ clldutils==3.5.4
430
+ cmaes==0.8.2
431
+ cmake==3.18.4.post1
432
+ cmd2==2.2.0
433
+ colorama==0.4.4
434
+ colorlog==4.6.2
435
+ configparser==5.1.0
436
+ cryptography==38.0.4
437
+ csvw==1.8.1
438
+ cycler==0.10.0
439
+ Cython==0.29.21
440
+ dataclasses==0.6
441
+ datasets==1.5.0
442
+ decorator==4.4.2
443
+ deepspeech==0.9.1
444
+ defusedxml==0.7.1
445
+ denoiser==0.1.5
446
+ dill==0.3.3
447
+ Distance==0.1.3
448
+ distlib==0.3.1
449
+ Django==3.2.16
450
+ django-auditlog==2.2.1
451
+ django-filter==22.1
452
+ django-js-asset==1.2.2
453
+ django-mptt==0.14.0
454
+ djangorestframework==3.14.0
455
+ docker-pycreds==0.4.0
456
+ docopt==0.6.2
457
+ docutils==0.16
458
+ drf-excel==2.2.0
459
+ drf-flex-fields==1.0.0
460
+ drf-renderer-xlsx==0.4.1
461
+ easyocr==1.2.1
462
+ editdistance==0.6.0
463
+ emoji==2.2.0
464
+ entrypoints==0.3
465
+ et-xmlfile==1.1.0
466
+ exceptiongroup==1.1.0
467
+ farasapy==0.0.14
468
+ fasttext==0.9.2
469
+ ffmpeg-python==0.2.0
470
+ filelock==3.0.12
471
+ flake8==3.7.9
472
+ flatbuffers==1.12
473
+ frozendict==2.0.7
474
+ frozenlist==1.2.0
475
+ fsspec==2021.11.0
476
+ future==0.18.2
477
+ g2p-en==2.1.0
478
+ gast==0.3.3
479
+ gdown==4.2.0
480
+ gensim==4.0.1
481
+ gitdb==4.0.9
482
+ GitPython==3.1.24
483
+ google-auth==1.24.0
484
+ google-auth-oauthlib==0.4.2
485
+ google-pasta==0.2.0
486
+ greenlet==1.1.2
487
+ grpcio==1.32.0
488
+ h5features==1.3.2
489
+ h5py==2.10.0
490
+ htk-io==0.5
491
+ huggingface-hub==0.9.1
492
+ hydra-colorlog==0.1.4
493
+ hydra-core==0.11.3
494
+ HyperPyYAML==1.1.0
495
+ hypothesis==6.61.2
496
+ identify==1.5.10
497
+ idna==2.10
498
+ imageio==2.9.0
499
+ imagesize==1.2.0
500
+ importlib-metadata==4.8.1
501
+ importlib-resources==5.2.2
502
+ inflect==5.3.0
503
+ ipadic==1.0.0
504
+ ipykernel==5.3.4
505
+ ipython==7.19.0
506
+ ipython-genutils==0.2.0
507
+ ipywidgets==7.6.3
508
+ iso-639==0.4.5
509
+ isodate==0.6.0
510
+ isort==4.3.21
511
+ jedi==0.17.2
512
+ jieba==0.42.1
513
+ Jinja2==2.11.2
514
+ jiwer==2.2.0
515
+ jmespath==0.10.0
516
+ joblib==0.17.0
517
+ jsonschema==3.2.0
518
+ julius==0.2.7
519
+ jupyter-client==6.1.7
520
+ jupyter-core==4.7.0
521
+ jupyterlab-pygments==0.1.2
522
+ jupyterlab-widgets==1.0.0
523
+ kaitaistruct==0.9
524
+ kaldi-io==0.9.4
525
+ kaldi-python-io==1.2.2
526
+ kaldiio==2.17.2
527
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip
528
+ Keras-Preprocessing==1.1.2
529
+ kiwisolver==1.3.1
530
+ lang-trans==0.6.0
531
+ latexcodec==2.0.1
532
+ ldap3==2.9.1
533
+ librosa==0.9.0
534
+ llvmlite==0.35.0
535
+ lxml==4.9.0
536
+ Mako==1.1.5
537
+ Markdown==3.3.3
538
+ MarkupSafe==1.1.1
539
+ marshmallow==3.14.0
540
+ matplotlib==3.3.3
541
+ mccabe==0.6.1
542
+ mcd==0.4
543
+ mecab-python3==1.0.3
544
+ megatron-lm==2.2.0
545
+ mido==1.2.10
546
+ mistune==0.8.4
547
+ more-itertools==8.6.0
548
+ mpmath==1.2.1
549
+ multidict==5.2.0
550
+ multiprocess==0.70.11.1
551
+ nbclient==0.5.3
552
+ nbconvert==6.0.7
553
+ nbformat==5.1.3
554
+ NEMO==4.3.2
555
+ nemo-toolkit==1.4.0
556
+ nest-asyncio==1.5.1
557
+ networkx==2.5
558
+ nltk==3.5
559
+ nodeenv==1.5.0
560
+ notebook==6.3.0
561
+ numba==0.52.0
562
+ numpy==1.19.4
563
+ nvidia-cublas-cu11==11.10.3.66
564
+ nvidia-cuda-nvrtc-cu11==11.7.99
565
+ nvidia-cuda-runtime-cu11==11.7.99
566
+ nvidia-cudnn-cu11==8.5.0.96
567
+ oauthlib==3.1.0
568
+ omegaconf==1.4.1
569
+ onnx==1.10.2
570
+ OpenCC==1.1.2
571
+ opencv-python==4.4.0.46
572
+ openpyxl==3.0.9
573
+ opensmile==2.2.0
574
+ opt-einsum==3.3.0
575
+ optuna==2.10.0
576
+ oyaml==1.0
577
+ packaging==22.0
578
+ pandas==1.2.5
579
+ pandocfilters==1.4.3
580
+ pangu==4.0.6.1
581
+ parameterized==0.8.1
582
+ parso==0.7.1
583
+ pathspec==0.8.1
584
+ pathtools==0.1.2
585
+ pbr==5.6.0
586
+ pefile==2019.4.18
587
+ pescador==2.1.0
588
+ pesq==0.0.3
589
+ pexpect==4.8.0
590
+ phonemizer==2.2.1
591
+ pickleshare==0.7.5
592
+ Pillow==9.3.0
593
+ pip-api==0.0.23
594
+ pipreqs==0.4.11
595
+ pluggy==0.13.1
596
+ pooch==1.3.0
597
+ portalocker==2.3.2
598
+ pre-commit==2.9.0
599
+ pretty-midi==0.2.9
600
+ prettytable==2.2.1
601
+ progressbar2==3.53.1
602
+ prometheus-client==0.10.1
603
+ promise==2.3
604
+ prompt-toolkit==3.0.8
605
+ protobuf==3.14.0
606
+ psutil==5.6.6
607
+ ptyprocess==0.6.0
608
+ py==1.9.0
609
+ py-espeak-ng==0.1.8
610
+ pyannote.audio==1.1.1
611
+ pyannote.core==4.3
612
+ pyannote.database==4.1.1
613
+ pyannote.metrics==3.1
614
+ pyannote.pipeline==1.5.2
615
+ PyArabic==0.6.15
616
+ pyarrow==3.0.0
617
+ pyasn1==0.4.8
618
+ pyasn1-modules==0.2.8
619
+ pybind11==2.8.1
620
+ pybtex==0.24.0
621
+ pybtex-docutils==1.0.1
622
+ pycodestyle==2.5.0
623
+ pycparser==2.20
624
+ pyctcdecode==0.4.0
625
+ pyDeprecate==0.3.1
626
+ pydub==0.25.1
627
+ pyflakes==2.1.1
628
+ Pygments==2.7.2
629
+ pygtrie==2.5.0
630
+ pymodbus==2.5.3
631
+ pyparsing==2.4.7
632
+ pyperclip==1.8.2
633
+ pypinyin==0.43.0
634
+ pyrsistent==0.17.3
635
+ pyserial==3.5
636
+ PySocks==1.7.1
637
+ pystoi==0.3.3
638
+ pytest==5.4.1
639
+ pytest-runner==5.3.1
640
+ python-bidi==0.4.2
641
+ python-crfsuite==0.9.7
642
+ python-dateutil==2.8.2
643
+ python-Levenshtein==0.12.2
644
+ python-utils==2.4.0
645
+ pytorch-lightning==1.4.9
646
+ pytube==11.0.1
647
+ pytz==2022.6
648
+ PyWavelets==1.1.1
649
+ PyYAML==5.3.1
650
+ pyzmq==20.0.0
651
+ rapidfuzz==1.8.2
652
+ regex==2020.11.13
653
+ requests==2.28.1
654
+ requests-oauthlib==1.3.0
655
+ resampy==0.2.2
656
+ rfc3986==1.4.0
657
+ rsa==4.7
658
+ ruamel.yaml==0.17.21
659
+ ruamel.yaml.clib==0.2.7
660
+ s3m==1.1.0
661
+ s3transfer==0.5.0
662
+ sacrebleu==2.0.0
663
+ sacremoses==0.0.44
664
+ scikit-image==0.18.1
665
+ scikit-learn==0.23.2
666
+ scipy==1.5.4
667
+ -e git+https://github.com/sanghack81/SDCIT@00d060dde733fde9345154a494f81e97fb395ca7#egg=SDCIT
668
+ seaborn==0.11.1
669
+ segments==2.1.3
670
+ Send2Trash==1.5.0
671
+ sentencepiece==0.1.94
672
+ sentry-sdk==1.4.3
673
+ shellingham==1.4.0
674
+ shortuuid==1.0.7
675
+ SIDEKIT==1.3.8.5.2
676
+ simplejson==3.17.5
677
+ six==1.15.0
678
+ smart-open==5.0.0
679
+ smmap==5.0.0
680
+ snowballstemmer==2.0.0
681
+ sortedcollections==2.1.0
682
+ sortedcontainers==2.4.0
683
+ sounddevice==0.4.5
684
+ SoundFile==0.10.3.post1
685
+ soupsieve==2.3
686
+ sox==1.4.1
687
+ sparsemax==0.1.9
688
+ speechbrain==0.5.13
689
+ sphfile==1.0.3
690
+ Sphinx==3.3.1
691
+ sphinx-rtd-theme==0.4.3
692
+ sphinxcontrib-applehelp==1.0.2
693
+ sphinxcontrib-bibtex==2.4.1
694
+ sphinxcontrib-devhelp==1.0.2
695
+ sphinxcontrib-htmlhelp==1.0.3
696
+ sphinxcontrib-jsmath==1.0.1
697
+ sphinxcontrib-qthelp==1.0.3
698
+ sphinxcontrib-serializinghtml==1.1.4
699
+ SQLAlchemy==1.4.25
700
+ sqlparse==0.4.2
701
+ stanza==1.4.2
702
+ stevedore==3.4.0
703
+ subprocess32==3.5.4
704
+ sympy==1.9
705
+ tabulate==0.8.9
706
+ tensorboard==2.4.0
707
+ tensorboard-plugin-wit==1.7.0
708
+ tensorflow==2.4.0
709
+ tensorflow-estimator==2.4.0
710
+ termcolor==1.1.0
711
+ terminado==0.9.4
712
+ testpath==0.4.4
713
+ threadpoolctl==2.1.0
714
+ tifffile==2020.12.8
715
+ tikzplotlib==0.9.8
716
+ tkseem==0.0.3
717
+ tokenizers==0.10.2
718
+ toml==0.10.2
719
+ torch==1.13.1
720
+ torch-stft==0.1.4
721
+ torchaudio==0.13.1
722
+ torchmetrics==0.6.0
723
+ torchvision==0.14.1
724
+ tornado==6.1
725
+ tqdm==4.61.1
726
+ trackrip==1.2.1
727
+ traitlets==5.0.5
728
+ transformers==4.15.0
729
+ typed-ast==1.4.1
730
+ typer==0.4.0
731
+ typing-extensions==3.7.4.3
732
+ Unidecode==1.3.2
733
+ uritemplate==3.0.1
734
+ urllib3==1.26.2
735
+ virtualenv==20.2.1
736
+ wandb==0.12.6
737
+ wcwidth==0.2.5
738
+ webdataset==0.1.62
739
+ webencodings==0.5.1
740
+ Werkzeug==1.0.1
741
+ wget==3.2
742
+ widgetsnbextension==3.5.1
743
+ wordninja==2.0.0
744
+ wrapt==1.12.1
745
+ xmltodict==0.13.0
746
+ xxhash==2.0.0
747
+ yamllint==1.23.0
748
+ yarg==0.1.9
749
+ yarl==1.7.2
750
+ yaspin==2.1.0
751
+ youtokentome==1.0.6
752
+ youtube-dl==2021.6.6
753
+ zipp==3.6.0
754
+
755
+
+ 2023-01-07 15:57:54,987 - speechbrain.dataio.encoder - DEBUG - Would load categorical encoding from partly_frozen_splitted_wavlm/save/label_encoder.txt, but file doesn't exist yet.
+ 2023-01-07 15:57:54,988 - speechbrain.dataio.encoder - INFO - Moving label 'ت' from index 0, because '<blank>' was put at its place.
+ 2023-01-07 15:57:54,988 - speechbrain.dataio.encoder - INFO - Moving label 'ع' from index 1, because '<bos>' was put at its place.
+ 2023-01-07 15:57:54,989 - speechbrain.dataio.encoder - INFO - Moving label 'ب' from index 2, because '<eos>' was put at its place.
+ 2023-01-07 15:57:54,989 - speechbrain.dataio.encoder - INFO - Load called, but CTCTextEncoder is not empty. Loaded data will overwrite everything. This is normal if there is e.g. an unk label defined at init.
+ 2023-01-07 15:57:54,990 - speechbrain.dataio.encoder - DEBUG - Loaded categorical encoding from partly_frozen_splitted_wavlm/save/label_encoder.txt
+ 2023-01-07 15:57:54,990 - speechbrain.core - INFO - Info: auto_mix_prec arg from hparam file is used
+ 2023-01-07 15:57:54,990 - speechbrain.core - INFO - Info: ckpt_interval_minutes arg from hparam file is used
+ 2023-01-07 15:57:57,073 - speechbrain.core - INFO - 313.4M trainable parameters in ASR
+ 2023-01-07 15:57:57,075 - speechbrain.utils.checkpoints - INFO - Would load a checkpoint here, but none found yet.
+ 2023-01-07 15:57:57,075 - speechbrain.utils.epoch_loop - INFO - Going into epoch 1
+ 2023-01-07 15:57:57,132 - speechbrain.core - ERROR - Exception:
+ Traceback (most recent call last):
+   File "ctc_train.py", line 322, in <module>
+     asr_brain.fit(
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/core.py", line 1153, in fit
+     self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/core.py", line 1009, in _fit_train
+     loss = self.fit_batch(batch)
+   File "ctc_train.py", line 88, in fit_batch
+     predictions = self.compute_forward(batch, sb.Stage.TRAIN)
+   File "ctc_train.py", line 42, in compute_forward
+     feats = self.modules.wav2vec2(wavs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
+     return forward_call(*input, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/lobes/models/huggingface_wav2vec.py", line 266, in forward
+     return self.extract_features(wav)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/lobes/models/huggingface_wav2vec.py", line 281, in extract_features
+     out = self.model(wav)[0]
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
+     return forward_call(*input, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/transformers/models/wavlm/modeling_wavlm.py", line 1232, in forward
+     extract_features = self.feature_extractor(input_values)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
+     return forward_call(*input, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/transformers/models/wavlm/modeling_wavlm.py", line 400, in forward
+     hidden_states = conv_layer(hidden_states)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
+     return forward_call(*input, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/transformers/models/wavlm/modeling_wavlm.py", line 270, in forward
+     hidden_states = self.conv(hidden_states)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
+     return forward_call(*input, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 313, in forward
+     return self._conv_forward(input, self.weight, self.bias)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 309, in _conv_forward
+     return F.conv1d(input, weight, bias, self.stride,
+ RuntimeError: Expected 2D (unbatched) or 3D (batched) input to conv1d, but got input of size: [2, 1, 168960, 1]
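
The RuntimeError above is a tensor-shape issue rather than a model issue: the WavLM feature extractor expects the batch as a 2-D (batch, time) float tensor, but here it arrived as [2, 1, 168960, 1], i.e. with channel dimensions apparently left over from audio loading. A minimal sketch of the kind of audio pipeline that avoids it (the function and variable names are illustrative, not the repo's):

    import torchaudio

    def audio_pipeline(wav_path):
        # A minimal sketch: torchaudio returns (channels, time); dropping the
        # channel axis leaves a 1-D signal, so batching yields (batch, time).
        sig, sr = torchaudio.load(wav_path)
        return sig.squeeze(0)

For mono files, sb.dataio.dataio.read_audio already returns such a 1-D signal, which is why the later runs get past this point.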
804
+ 2023-01-07 15:59:22,733 - speechbrain.core - INFO - Beginning experiment!
805
+ 2023-01-07 15:59:22,733 - speechbrain.core - INFO - Experiment folder: partly_frozen_splitted_wavlm
806
+ 2023-01-07 15:59:23,373 - speechbrain.utils.superpowers - DEBUG - abkhazia==1.0
807
+ absl-py==0.11.0
808
+ aiohttp==3.8.0
809
+ aiosignal==1.2.0
810
+ alabaster==0.7.12
811
+ alembic==1.7.4
812
+ altgraph==0.17
813
+ antlr4-python3-runtime==4.8
814
+ appdirs==1.4.4
815
+ argcomplete==1.12.2
816
+ argon2-cffi==20.1.0
817
+ asgiref==3.6.0
818
+ astunparse==1.6.3
819
+ async-generator==1.10
820
+ async-timeout==4.0.0
821
+ attrdict==2.0.1
822
+ attrs==20.3.0
823
+ audeer==1.16.0
824
+ audformat==0.11.5
825
+ audinterface==0.7.0
826
+ audiofile==1.0.0
827
+ audiomentations==0.25.0
828
+ audioread==2.1.9
829
+ audobject==0.4.14
830
+ audresample==0.1.6
831
+ -e git+https://github.com/facebookresearch/WavAugment.git@54afcdb00ccc852c2f030f239f8532c9562b550e#egg=augment
832
+ autopage==0.4.0
833
+ Babel==2.9.0
834
+ backcall==0.2.0
835
+ beautifulsoup4==4.10.0
836
+ black==19.10b0
837
+ bleach==3.3.0
838
+ boto3==1.20.2
839
+ botocore==1.23.2
840
+ braceexpand==0.1.7
841
+ cachetools==4.2.0
842
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
843
+ cffi==1.14.3
844
+ cfgv==3.2.0
845
+ chardet==3.0.4
846
+ charset-normalizer==2.0.7
847
+ click==7.1.2
848
+ cliff==3.9.0
849
+ clldutils==3.5.4
850
+ cmaes==0.8.2
851
+ cmake==3.18.4.post1
852
+ cmd2==2.2.0
853
+ colorama==0.4.4
854
+ colorlog==4.6.2
855
+ configparser==5.1.0
856
+ cryptography==38.0.4
857
+ csvw==1.8.1
858
+ cycler==0.10.0
859
+ Cython==0.29.21
860
+ dataclasses==0.6
861
+ datasets==1.5.0
862
+ decorator==4.4.2
863
+ deepspeech==0.9.1
864
+ defusedxml==0.7.1
865
+ denoiser==0.1.5
866
+ dill==0.3.3
867
+ Distance==0.1.3
868
+ distlib==0.3.1
869
+ Django==3.2.16
870
+ django-auditlog==2.2.1
871
+ django-filter==22.1
872
+ django-js-asset==1.2.2
873
+ django-mptt==0.14.0
874
+ djangorestframework==3.14.0
875
+ docker-pycreds==0.4.0
876
+ docopt==0.6.2
877
+ docutils==0.16
878
+ drf-excel==2.2.0
879
+ drf-flex-fields==1.0.0
880
+ drf-renderer-xlsx==0.4.1
881
+ easyocr==1.2.1
882
+ editdistance==0.6.0
883
+ emoji==2.2.0
884
+ entrypoints==0.3
885
+ et-xmlfile==1.1.0
886
+ exceptiongroup==1.1.0
887
+ farasapy==0.0.14
888
+ fasttext==0.9.2
889
+ ffmpeg-python==0.2.0
890
+ filelock==3.0.12
891
+ flake8==3.7.9
892
+ flatbuffers==1.12
893
+ frozendict==2.0.7
894
+ frozenlist==1.2.0
895
+ fsspec==2021.11.0
896
+ future==0.18.2
897
+ g2p-en==2.1.0
898
+ gast==0.3.3
899
+ gdown==4.2.0
900
+ gensim==4.0.1
901
+ gitdb==4.0.9
902
+ GitPython==3.1.24
903
+ google-auth==1.24.0
904
+ google-auth-oauthlib==0.4.2
905
+ google-pasta==0.2.0
906
+ greenlet==1.1.2
907
+ grpcio==1.32.0
908
+ h5features==1.3.2
909
+ h5py==2.10.0
910
+ htk-io==0.5
911
+ huggingface-hub==0.9.1
912
+ hydra-colorlog==0.1.4
913
+ hydra-core==0.11.3
914
+ HyperPyYAML==1.1.0
915
+ hypothesis==6.61.2
916
+ identify==1.5.10
917
+ idna==2.10
918
+ imageio==2.9.0
919
+ imagesize==1.2.0
920
+ importlib-metadata==4.8.1
921
+ importlib-resources==5.2.2
922
+ inflect==5.3.0
923
+ ipadic==1.0.0
924
+ ipykernel==5.3.4
925
+ ipython==7.19.0
926
+ ipython-genutils==0.2.0
927
+ ipywidgets==7.6.3
928
+ iso-639==0.4.5
929
+ isodate==0.6.0
930
+ isort==4.3.21
931
+ jedi==0.17.2
932
+ jieba==0.42.1
933
+ Jinja2==2.11.2
934
+ jiwer==2.2.0
935
+ jmespath==0.10.0
936
+ joblib==0.17.0
937
+ jsonschema==3.2.0
938
+ julius==0.2.7
939
+ jupyter-client==6.1.7
940
+ jupyter-core==4.7.0
941
+ jupyterlab-pygments==0.1.2
942
+ jupyterlab-widgets==1.0.0
943
+ kaitaistruct==0.9
944
+ kaldi-io==0.9.4
945
+ kaldi-python-io==1.2.2
946
+ kaldiio==2.17.2
947
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip
948
+ Keras-Preprocessing==1.1.2
949
+ kiwisolver==1.3.1
950
+ lang-trans==0.6.0
951
+ latexcodec==2.0.1
952
+ ldap3==2.9.1
953
+ librosa==0.9.0
954
+ llvmlite==0.35.0
955
+ lxml==4.9.0
956
+ Mako==1.1.5
957
+ Markdown==3.3.3
958
+ MarkupSafe==1.1.1
959
+ marshmallow==3.14.0
960
+ matplotlib==3.3.3
961
+ mccabe==0.6.1
962
+ mcd==0.4
963
+ mecab-python3==1.0.3
964
+ megatron-lm==2.2.0
965
+ mido==1.2.10
966
+ mistune==0.8.4
967
+ more-itertools==8.6.0
968
+ mpmath==1.2.1
969
+ multidict==5.2.0
970
+ multiprocess==0.70.11.1
971
+ nbclient==0.5.3
972
+ nbconvert==6.0.7
973
+ nbformat==5.1.3
974
+ NEMO==4.3.2
975
+ nemo-toolkit==1.4.0
976
+ nest-asyncio==1.5.1
977
+ networkx==2.5
978
+ nltk==3.5
979
+ nodeenv==1.5.0
980
+ notebook==6.3.0
981
+ numba==0.52.0
982
+ numpy==1.19.4
983
+ nvidia-cublas-cu11==11.10.3.66
984
+ nvidia-cuda-nvrtc-cu11==11.7.99
985
+ nvidia-cuda-runtime-cu11==11.7.99
986
+ nvidia-cudnn-cu11==8.5.0.96
987
+ oauthlib==3.1.0
988
+ omegaconf==1.4.1
989
+ onnx==1.10.2
990
+ OpenCC==1.1.2
991
+ opencv-python==4.4.0.46
992
+ openpyxl==3.0.9
993
+ opensmile==2.2.0
994
+ opt-einsum==3.3.0
995
+ optuna==2.10.0
996
+ oyaml==1.0
997
+ packaging==22.0
998
+ pandas==1.2.5
999
+ pandocfilters==1.4.3
1000
+ pangu==4.0.6.1
1001
+ parameterized==0.8.1
1002
+ parso==0.7.1
1003
+ pathspec==0.8.1
1004
+ pathtools==0.1.2
1005
+ pbr==5.6.0
1006
+ pefile==2019.4.18
1007
+ pescador==2.1.0
1008
+ pesq==0.0.3
1009
+ pexpect==4.8.0
1010
+ phonemizer==2.2.1
1011
+ pickleshare==0.7.5
1012
+ Pillow==9.3.0
1013
+ pip-api==0.0.23
1014
+ pipreqs==0.4.11
1015
+ pluggy==0.13.1
1016
+ pooch==1.3.0
1017
+ portalocker==2.3.2
1018
+ pre-commit==2.9.0
1019
+ pretty-midi==0.2.9
1020
+ prettytable==2.2.1
1021
+ progressbar2==3.53.1
1022
+ prometheus-client==0.10.1
1023
+ promise==2.3
1024
+ prompt-toolkit==3.0.8
1025
+ protobuf==3.14.0
1026
+ psutil==5.6.6
1027
+ ptyprocess==0.6.0
1028
+ py==1.9.0
1029
+ py-espeak-ng==0.1.8
1030
+ pyannote.audio==1.1.1
1031
+ pyannote.core==4.3
1032
+ pyannote.database==4.1.1
1033
+ pyannote.metrics==3.1
1034
+ pyannote.pipeline==1.5.2
1035
+ PyArabic==0.6.15
1036
+ pyarrow==3.0.0
1037
+ pyasn1==0.4.8
1038
+ pyasn1-modules==0.2.8
1039
+ pybind11==2.8.1
1040
+ pybtex==0.24.0
1041
+ pybtex-docutils==1.0.1
1042
+ pycodestyle==2.5.0
1043
+ pycparser==2.20
1044
+ pyctcdecode==0.4.0
1045
+ pyDeprecate==0.3.1
1046
+ pydub==0.25.1
1047
+ pyflakes==2.1.1
1048
+ Pygments==2.7.2
1049
+ pygtrie==2.5.0
1050
+ pymodbus==2.5.3
1051
+ pyparsing==2.4.7
1052
+ pyperclip==1.8.2
1053
+ pypinyin==0.43.0
1054
+ pyrsistent==0.17.3
1055
+ pyserial==3.5
1056
+ PySocks==1.7.1
1057
+ pystoi==0.3.3
1058
+ pytest==5.4.1
1059
+ pytest-runner==5.3.1
1060
+ python-bidi==0.4.2
1061
+ python-crfsuite==0.9.7
1062
+ python-dateutil==2.8.2
1063
+ python-Levenshtein==0.12.2
1064
+ python-utils==2.4.0
1065
+ pytorch-lightning==1.4.9
1066
+ pytube==11.0.1
1067
+ pytz==2022.6
1068
+ PyWavelets==1.1.1
1069
+ PyYAML==5.3.1
1070
+ pyzmq==20.0.0
1071
+ rapidfuzz==1.8.2
1072
+ regex==2020.11.13
1073
+ requests==2.28.1
1074
+ requests-oauthlib==1.3.0
1075
+ resampy==0.2.2
1076
+ rfc3986==1.4.0
1077
+ rsa==4.7
1078
+ ruamel.yaml==0.17.21
1079
+ ruamel.yaml.clib==0.2.7
1080
+ s3m==1.1.0
1081
+ s3transfer==0.5.0
1082
+ sacrebleu==2.0.0
1083
+ sacremoses==0.0.44
1084
+ scikit-image==0.18.1
1085
+ scikit-learn==0.23.2
1086
+ scipy==1.5.4
1087
+ -e git+https://github.com/sanghack81/SDCIT@00d060dde733fde9345154a494f81e97fb395ca7#egg=SDCIT
1088
+ seaborn==0.11.1
1089
+ segments==2.1.3
1090
+ Send2Trash==1.5.0
1091
+ sentencepiece==0.1.94
1092
+ sentry-sdk==1.4.3
1093
+ shellingham==1.4.0
1094
+ shortuuid==1.0.7
1095
+ SIDEKIT==1.3.8.5.2
1096
+ simplejson==3.17.5
1097
+ six==1.15.0
1098
+ smart-open==5.0.0
1099
+ smmap==5.0.0
1100
+ snowballstemmer==2.0.0
1101
+ sortedcollections==2.1.0
1102
+ sortedcontainers==2.4.0
1103
+ sounddevice==0.4.5
1104
+ SoundFile==0.10.3.post1
1105
+ soupsieve==2.3
1106
+ sox==1.4.1
1107
+ sparsemax==0.1.9
1108
+ speechbrain==0.5.13
1109
+ sphfile==1.0.3
1110
+ Sphinx==3.3.1
1111
+ sphinx-rtd-theme==0.4.3
1112
+ sphinxcontrib-applehelp==1.0.2
1113
+ sphinxcontrib-bibtex==2.4.1
1114
+ sphinxcontrib-devhelp==1.0.2
1115
+ sphinxcontrib-htmlhelp==1.0.3
1116
+ sphinxcontrib-jsmath==1.0.1
1117
+ sphinxcontrib-qthelp==1.0.3
1118
+ sphinxcontrib-serializinghtml==1.1.4
1119
+ SQLAlchemy==1.4.25
1120
+ sqlparse==0.4.2
1121
+ stanza==1.4.2
1122
+ stevedore==3.4.0
1123
+ subprocess32==3.5.4
1124
+ sympy==1.9
1125
+ tabulate==0.8.9
1126
+ tensorboard==2.4.0
1127
+ tensorboard-plugin-wit==1.7.0
1128
+ tensorflow==2.4.0
1129
+ tensorflow-estimator==2.4.0
1130
+ termcolor==1.1.0
1131
+ terminado==0.9.4
1132
+ testpath==0.4.4
1133
+ threadpoolctl==2.1.0
1134
+ tifffile==2020.12.8
1135
+ tikzplotlib==0.9.8
1136
+ tkseem==0.0.3
1137
+ tokenizers==0.10.2
1138
+ toml==0.10.2
1139
+ torch==1.13.1
1140
+ torch-stft==0.1.4
1141
+ torchaudio==0.13.1
1142
+ torchmetrics==0.6.0
1143
+ torchvision==0.14.1
1144
+ tornado==6.1
1145
+ tqdm==4.61.1
1146
+ trackrip==1.2.1
1147
+ traitlets==5.0.5
1148
+ transformers==4.15.0
1149
+ typed-ast==1.4.1
1150
+ typer==0.4.0
1151
+ typing-extensions==3.7.4.3
1152
+ Unidecode==1.3.2
1153
+ uritemplate==3.0.1
1154
+ urllib3==1.26.2
1155
+ virtualenv==20.2.1
1156
+ wandb==0.12.6
1157
+ wcwidth==0.2.5
1158
+ webdataset==0.1.62
1159
+ webencodings==0.5.1
1160
+ Werkzeug==1.0.1
1161
+ wget==3.2
1162
+ widgetsnbextension==3.5.1
1163
+ wordninja==2.0.0
1164
+ wrapt==1.12.1
1165
+ xmltodict==0.13.0
1166
+ xxhash==2.0.0
1167
+ yamllint==1.23.0
1168
+ yarg==0.1.9
1169
+ yarl==1.7.2
1170
+ yaspin==2.1.0
1171
+ youtokentome==1.0.6
1172
+ youtube-dl==2021.6.6
1173
+ zipp==3.6.0
1174
+
1175
+
+ 2023-01-07 15:59:23,493 - speechbrain.dataio.encoder - DEBUG - Loaded categorical encoding from partly_frozen_splitted_wavlm/save/label_encoder.txt
+ 2023-01-07 15:59:23,493 - speechbrain.dataio.encoder - INFO - Load called, but CTCTextEncoder is not empty. Loaded data will overwrite everything. This is normal if there is e.g. an unk label defined at init.
+ 2023-01-07 15:59:23,494 - speechbrain.dataio.encoder - DEBUG - Loaded categorical encoding from partly_frozen_splitted_wavlm/save/label_encoder.txt
+ 2023-01-07 15:59:23,495 - speechbrain.core - INFO - Info: auto_mix_prec arg from hparam file is used
+ 2023-01-07 15:59:23,495 - speechbrain.core - INFO - Info: ckpt_interval_minutes arg from hparam file is used
+ 2023-01-07 15:59:24,946 - speechbrain.core - INFO - 313.4M trainable parameters in ASR
+ 2023-01-07 15:59:24,949 - speechbrain.utils.checkpoints - INFO - Would load a checkpoint here, but none found yet.
+ 2023-01-07 15:59:24,949 - speechbrain.utils.epoch_loop - INFO - Going into epoch 1
+ 2023-01-07 15:59:27,528 - speechbrain.core - ERROR - Exception:
+ Traceback (most recent call last):
+   File "ctc_train.py", line 322, in <module>
+     asr_brain.fit(
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/core.py", line 1153, in fit
+     self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/core.py", line 1009, in _fit_train
+     loss = self.fit_batch(batch)
+   File "ctc_train.py", line 92, in fit_batch
+     self.wav2vec_optimizer.step()
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/optimizer.py", line 140, in wrapper
+     out = func(*args, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/optimizer.py", line 23, in _use_grad
+     ret = func(self, *args, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/adam.py", line 220, in step
+     state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 16.00 MiB (GPU 0; 5.80 GiB total capacity; 3.82 GiB already allocated; 52.50 MiB free; 3.94 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
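
This run gets past the forward pass and dies while Adam allocates its exp_avg_sq state for the wav2vec2 parameters: fine-tuning the ~313M-parameter WavLM encoder plus its optimizer state does not fit on the 5.8 GiB GPU reported in the message. The usual remedies are freezing more of the encoder (freeze: true in the wav2vec2 block), a smaller batch, or gradient accumulation; as a narrower, hedged sketch, the allocator hint the error message itself suggests can be set before CUDA is initialised:

    import os

    # A minimal sketch: the allocator option named in the error message only
    # limits fragmentation; it does not shrink the Adam state itself.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

    import torch  # import torch only after the variable is set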
1201
+ 2023-01-07 15:59:54,237 - speechbrain.core - INFO - Beginning experiment!
1202
+ 2023-01-07 15:59:54,237 - speechbrain.core - INFO - Experiment folder: partly_frozen_splitted_wavlm
1203
+ 2023-01-07 15:59:54,868 - speechbrain.utils.superpowers - DEBUG - abkhazia==1.0
1204
+ absl-py==0.11.0
1205
+ aiohttp==3.8.0
1206
+ aiosignal==1.2.0
1207
+ alabaster==0.7.12
1208
+ alembic==1.7.4
1209
+ altgraph==0.17
1210
+ antlr4-python3-runtime==4.8
1211
+ appdirs==1.4.4
1212
+ argcomplete==1.12.2
1213
+ argon2-cffi==20.1.0
1214
+ asgiref==3.6.0
1215
+ astunparse==1.6.3
1216
+ async-generator==1.10
1217
+ async-timeout==4.0.0
1218
+ attrdict==2.0.1
1219
+ attrs==20.3.0
1220
+ audeer==1.16.0
1221
+ audformat==0.11.5
1222
+ audinterface==0.7.0
1223
+ audiofile==1.0.0
1224
+ audiomentations==0.25.0
1225
+ audioread==2.1.9
1226
+ audobject==0.4.14
1227
+ audresample==0.1.6
1228
+ -e git+https://github.com/facebookresearch/WavAugment.git@54afcdb00ccc852c2f030f239f8532c9562b550e#egg=augment
1229
+ autopage==0.4.0
1230
+ Babel==2.9.0
1231
+ backcall==0.2.0
1232
+ beautifulsoup4==4.10.0
1233
+ black==19.10b0
1234
+ bleach==3.3.0
1235
+ boto3==1.20.2
1236
+ botocore==1.23.2
1237
+ braceexpand==0.1.7
1238
+ cachetools==4.2.0
1239
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
1240
+ cffi==1.14.3
1241
+ cfgv==3.2.0
1242
+ chardet==3.0.4
1243
+ charset-normalizer==2.0.7
1244
+ click==7.1.2
1245
+ cliff==3.9.0
1246
+ clldutils==3.5.4
1247
+ cmaes==0.8.2
1248
+ cmake==3.18.4.post1
1249
+ cmd2==2.2.0
1250
+ colorama==0.4.4
1251
+ colorlog==4.6.2
1252
+ configparser==5.1.0
1253
+ cryptography==38.0.4
1254
+ csvw==1.8.1
1255
+ cycler==0.10.0
1256
+ Cython==0.29.21
1257
+ dataclasses==0.6
1258
+ datasets==1.5.0
1259
+ decorator==4.4.2
1260
+ deepspeech==0.9.1
1261
+ defusedxml==0.7.1
1262
+ denoiser==0.1.5
1263
+ dill==0.3.3
1264
+ Distance==0.1.3
1265
+ distlib==0.3.1
1266
+ Django==3.2.16
1267
+ django-auditlog==2.2.1
1268
+ django-filter==22.1
1269
+ django-js-asset==1.2.2
1270
+ django-mptt==0.14.0
1271
+ djangorestframework==3.14.0
1272
+ docker-pycreds==0.4.0
1273
+ docopt==0.6.2
1274
+ docutils==0.16
1275
+ drf-excel==2.2.0
1276
+ drf-flex-fields==1.0.0
1277
+ drf-renderer-xlsx==0.4.1
1278
+ easyocr==1.2.1
1279
+ editdistance==0.6.0
1280
+ emoji==2.2.0
1281
+ entrypoints==0.3
1282
+ et-xmlfile==1.1.0
1283
+ exceptiongroup==1.1.0
1284
+ farasapy==0.0.14
1285
+ fasttext==0.9.2
1286
+ ffmpeg-python==0.2.0
1287
+ filelock==3.0.12
1288
+ flake8==3.7.9
1289
+ flatbuffers==1.12
1290
+ frozendict==2.0.7
1291
+ frozenlist==1.2.0
1292
+ fsspec==2021.11.0
1293
+ future==0.18.2
1294
+ g2p-en==2.1.0
1295
+ gast==0.3.3
1296
+ gdown==4.2.0
1297
+ gensim==4.0.1
1298
+ gitdb==4.0.9
1299
+ GitPython==3.1.24
1300
+ google-auth==1.24.0
1301
+ google-auth-oauthlib==0.4.2
1302
+ google-pasta==0.2.0
1303
+ greenlet==1.1.2
1304
+ grpcio==1.32.0
1305
+ h5features==1.3.2
1306
+ h5py==2.10.0
1307
+ htk-io==0.5
1308
+ huggingface-hub==0.9.1
1309
+ hydra-colorlog==0.1.4
1310
+ hydra-core==0.11.3
1311
+ HyperPyYAML==1.1.0
1312
+ hypothesis==6.61.2
1313
+ identify==1.5.10
1314
+ idna==2.10
1315
+ imageio==2.9.0
1316
+ imagesize==1.2.0
1317
+ importlib-metadata==4.8.1
1318
+ importlib-resources==5.2.2
1319
+ inflect==5.3.0
1320
+ ipadic==1.0.0
1321
+ ipykernel==5.3.4
1322
+ ipython==7.19.0
1323
+ ipython-genutils==0.2.0
1324
+ ipywidgets==7.6.3
1325
+ iso-639==0.4.5
1326
+ isodate==0.6.0
1327
+ isort==4.3.21
1328
+ jedi==0.17.2
1329
+ jieba==0.42.1
1330
+ Jinja2==2.11.2
1331
+ jiwer==2.2.0
1332
+ jmespath==0.10.0
1333
+ joblib==0.17.0
1334
+ jsonschema==3.2.0
1335
+ julius==0.2.7
1336
+ jupyter-client==6.1.7
1337
+ jupyter-core==4.7.0
1338
+ jupyterlab-pygments==0.1.2
1339
+ jupyterlab-widgets==1.0.0
1340
+ kaitaistruct==0.9
1341
+ kaldi-io==0.9.4
1342
+ kaldi-python-io==1.2.2
1343
+ kaldiio==2.17.2
1344
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip
1345
+ Keras-Preprocessing==1.1.2
1346
+ kiwisolver==1.3.1
1347
+ lang-trans==0.6.0
1348
+ latexcodec==2.0.1
1349
+ ldap3==2.9.1
1350
+ librosa==0.9.0
1351
+ llvmlite==0.35.0
1352
+ lxml==4.9.0
1353
+ Mako==1.1.5
1354
+ Markdown==3.3.3
1355
+ MarkupSafe==1.1.1
1356
+ marshmallow==3.14.0
1357
+ matplotlib==3.3.3
1358
+ mccabe==0.6.1
1359
+ mcd==0.4
1360
+ mecab-python3==1.0.3
1361
+ megatron-lm==2.2.0
1362
+ mido==1.2.10
1363
+ mistune==0.8.4
1364
+ more-itertools==8.6.0
1365
+ mpmath==1.2.1
1366
+ multidict==5.2.0
1367
+ multiprocess==0.70.11.1
1368
+ nbclient==0.5.3
1369
+ nbconvert==6.0.7
1370
+ nbformat==5.1.3
1371
+ NEMO==4.3.2
1372
+ nemo-toolkit==1.4.0
1373
+ nest-asyncio==1.5.1
1374
+ networkx==2.5
1375
+ nltk==3.5
1376
+ nodeenv==1.5.0
1377
+ notebook==6.3.0
1378
+ numba==0.52.0
1379
+ numpy==1.19.4
1380
+ nvidia-cublas-cu11==11.10.3.66
1381
+ nvidia-cuda-nvrtc-cu11==11.7.99
1382
+ nvidia-cuda-runtime-cu11==11.7.99
1383
+ nvidia-cudnn-cu11==8.5.0.96
1384
+ oauthlib==3.1.0
1385
+ omegaconf==1.4.1
1386
+ onnx==1.10.2
1387
+ OpenCC==1.1.2
1388
+ opencv-python==4.4.0.46
1389
+ openpyxl==3.0.9
1390
+ opensmile==2.2.0
1391
+ opt-einsum==3.3.0
1392
+ optuna==2.10.0
1393
+ oyaml==1.0
1394
+ packaging==22.0
1395
+ pandas==1.2.5
1396
+ pandocfilters==1.4.3
1397
+ pangu==4.0.6.1
1398
+ parameterized==0.8.1
1399
+ parso==0.7.1
1400
+ pathspec==0.8.1
1401
+ pathtools==0.1.2
1402
+ pbr==5.6.0
1403
+ pefile==2019.4.18
1404
+ pescador==2.1.0
1405
+ pesq==0.0.3
1406
+ pexpect==4.8.0
1407
+ phonemizer==2.2.1
1408
+ pickleshare==0.7.5
1409
+ Pillow==9.3.0
1410
+ pip-api==0.0.23
1411
+ pipreqs==0.4.11
1412
+ pluggy==0.13.1
1413
+ pooch==1.3.0
1414
+ portalocker==2.3.2
1415
+ pre-commit==2.9.0
1416
+ pretty-midi==0.2.9
1417
+ prettytable==2.2.1
1418
+ progressbar2==3.53.1
1419
+ prometheus-client==0.10.1
1420
+ promise==2.3
1421
+ prompt-toolkit==3.0.8
1422
+ protobuf==3.14.0
1423
+ psutil==5.6.6
1424
+ ptyprocess==0.6.0
1425
+ py==1.9.0
1426
+ py-espeak-ng==0.1.8
1427
+ pyannote.audio==1.1.1
1428
+ pyannote.core==4.3
1429
+ pyannote.database==4.1.1
1430
+ pyannote.metrics==3.1
1431
+ pyannote.pipeline==1.5.2
1432
+ PyArabic==0.6.15
1433
+ pyarrow==3.0.0
1434
+ pyasn1==0.4.8
1435
+ pyasn1-modules==0.2.8
1436
+ pybind11==2.8.1
1437
+ pybtex==0.24.0
1438
+ pybtex-docutils==1.0.1
1439
+ pycodestyle==2.5.0
1440
+ pycparser==2.20
1441
+ pyctcdecode==0.4.0
1442
+ pyDeprecate==0.3.1
1443
+ pydub==0.25.1
1444
+ pyflakes==2.1.1
1445
+ Pygments==2.7.2
1446
+ pygtrie==2.5.0
1447
+ pymodbus==2.5.3
1448
+ pyparsing==2.4.7
1449
+ pyperclip==1.8.2
1450
+ pypinyin==0.43.0
1451
+ pyrsistent==0.17.3
1452
+ pyserial==3.5
1453
+ PySocks==1.7.1
1454
+ pystoi==0.3.3
1455
+ pytest==5.4.1
1456
+ pytest-runner==5.3.1
1457
+ python-bidi==0.4.2
1458
+ python-crfsuite==0.9.7
1459
+ python-dateutil==2.8.2
1460
+ python-Levenshtein==0.12.2
1461
+ python-utils==2.4.0
1462
+ pytorch-lightning==1.4.9
1463
+ pytube==11.0.1
1464
+ pytz==2022.6
1465
+ PyWavelets==1.1.1
1466
+ PyYAML==5.3.1
1467
+ pyzmq==20.0.0
1468
+ rapidfuzz==1.8.2
1469
+ regex==2020.11.13
1470
+ requests==2.28.1
1471
+ requests-oauthlib==1.3.0
1472
+ resampy==0.2.2
1473
+ rfc3986==1.4.0
1474
+ rsa==4.7
1475
+ ruamel.yaml==0.17.21
1476
+ ruamel.yaml.clib==0.2.7
1477
+ s3m==1.1.0
1478
+ s3transfer==0.5.0
1479
+ sacrebleu==2.0.0
1480
+ sacremoses==0.0.44
1481
+ scikit-image==0.18.1
1482
+ scikit-learn==0.23.2
1483
+ scipy==1.5.4
1484
+ -e git+https://github.com/sanghack81/SDCIT@00d060dde733fde9345154a494f81e97fb395ca7#egg=SDCIT
1485
+ seaborn==0.11.1
1486
+ segments==2.1.3
1487
+ Send2Trash==1.5.0
1488
+ sentencepiece==0.1.94
1489
+ sentry-sdk==1.4.3
1490
+ shellingham==1.4.0
1491
+ shortuuid==1.0.7
1492
+ SIDEKIT==1.3.8.5.2
1493
+ simplejson==3.17.5
1494
+ six==1.15.0
1495
+ smart-open==5.0.0
1496
+ smmap==5.0.0
1497
+ snowballstemmer==2.0.0
1498
+ sortedcollections==2.1.0
1499
+ sortedcontainers==2.4.0
1500
+ sounddevice==0.4.5
1501
+ SoundFile==0.10.3.post1
1502
+ soupsieve==2.3
1503
+ sox==1.4.1
1504
+ sparsemax==0.1.9
1505
+ speechbrain==0.5.13
1506
+ sphfile==1.0.3
1507
+ Sphinx==3.3.1
1508
+ sphinx-rtd-theme==0.4.3
1509
+ sphinxcontrib-applehelp==1.0.2
1510
+ sphinxcontrib-bibtex==2.4.1
1511
+ sphinxcontrib-devhelp==1.0.2
1512
+ sphinxcontrib-htmlhelp==1.0.3
1513
+ sphinxcontrib-jsmath==1.0.1
1514
+ sphinxcontrib-qthelp==1.0.3
1515
+ sphinxcontrib-serializinghtml==1.1.4
1516
+ SQLAlchemy==1.4.25
1517
+ sqlparse==0.4.2
1518
+ stanza==1.4.2
1519
+ stevedore==3.4.0
1520
+ subprocess32==3.5.4
1521
+ sympy==1.9
1522
+ tabulate==0.8.9
1523
+ tensorboard==2.4.0
1524
+ tensorboard-plugin-wit==1.7.0
1525
+ tensorflow==2.4.0
1526
+ tensorflow-estimator==2.4.0
1527
+ termcolor==1.1.0
1528
+ terminado==0.9.4
1529
+ testpath==0.4.4
1530
+ threadpoolctl==2.1.0
1531
+ tifffile==2020.12.8
1532
+ tikzplotlib==0.9.8
1533
+ tkseem==0.0.3
1534
+ tokenizers==0.10.2
1535
+ toml==0.10.2
1536
+ torch==1.13.1
1537
+ torch-stft==0.1.4
1538
+ torchaudio==0.13.1
1539
+ torchmetrics==0.6.0
1540
+ torchvision==0.14.1
1541
+ tornado==6.1
1542
+ tqdm==4.61.1
1543
+ trackrip==1.2.1
1544
+ traitlets==5.0.5
1545
+ transformers==4.15.0
1546
+ typed-ast==1.4.1
1547
+ typer==0.4.0
1548
+ typing-extensions==3.7.4.3
1549
+ Unidecode==1.3.2
1550
+ uritemplate==3.0.1
1551
+ urllib3==1.26.2
1552
+ virtualenv==20.2.1
1553
+ wandb==0.12.6
1554
+ wcwidth==0.2.5
1555
+ webdataset==0.1.62
1556
+ webencodings==0.5.1
1557
+ Werkzeug==1.0.1
1558
+ wget==3.2
1559
+ widgetsnbextension==3.5.1
1560
+ wordninja==2.0.0
1561
+ wrapt==1.12.1
1562
+ xmltodict==0.13.0
1563
+ xxhash==2.0.0
1564
+ yamllint==1.23.0
1565
+ yarg==0.1.9
1566
+ yarl==1.7.2
1567
+ yaspin==2.1.0
1568
+ youtokentome==1.0.6
1569
+ youtube-dl==2021.6.6
1570
+ zipp==3.6.0
1571
+
1572
+
+ 2023-01-07 15:59:54,960 - speechbrain.dataio.encoder - DEBUG - Loaded categorical encoding from partly_frozen_splitted_wavlm/save/label_encoder.txt
+ 2023-01-07 15:59:54,960 - speechbrain.dataio.encoder - INFO - Load called, but CTCTextEncoder is not empty. Loaded data will overwrite everything. This is normal if there is e.g. an unk label defined at init.
+ 2023-01-07 15:59:54,961 - speechbrain.dataio.encoder - DEBUG - Loaded categorical encoding from partly_frozen_splitted_wavlm/save/label_encoder.txt
+ 2023-01-07 15:59:54,961 - speechbrain.core - INFO - Info: auto_mix_prec arg from hparam file is used
+ 2023-01-07 15:59:54,961 - speechbrain.core - INFO - Info: ckpt_interval_minutes arg from hparam file is used
+ 2023-01-07 15:59:56,396 - speechbrain.core - INFO - 313.4M trainable parameters in ASR
+ 2023-01-07 15:59:56,398 - speechbrain.utils.checkpoints - INFO - Would load a checkpoint here, but none found yet.
+ 2023-01-07 15:59:56,398 - speechbrain.utils.epoch_loop - INFO - Going into epoch 1
+ 2023-01-07 15:59:57,755 - speechbrain.core - ERROR - Exception:
+ Traceback (most recent call last):
+   File "ctc_train.py", line 322, in <module>
+     asr_brain.fit(
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/core.py", line 1153, in fit
+     self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/core.py", line 1009, in _fit_train
+     loss = self.fit_batch(batch)
+   File "ctc_train.py", line 92, in fit_batch
+     self.wav2vec_optimizer.step()
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/optimizer.py", line 140, in wrapper
+     out = func(*args, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/optimizer.py", line 23, in _use_grad
+     ret = func(self, *args, **kwargs)
+   File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/adam.py", line 218, in step
+     state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
+ torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 5.80 GiB total capacity; 3.86 GiB already allocated; 64.50 MiB free; 3.93 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
1598
+ 2023-01-07 16:00:47,395 - speechbrain.core - INFO - Beginning experiment!
1599
+ 2023-01-07 16:00:47,395 - speechbrain.core - INFO - Experiment folder: partly_frozen_splitted_wavlm
1600
+ 2023-01-07 16:00:48,055 - speechbrain.utils.superpowers - DEBUG - abkhazia==1.0
1601
+ absl-py==0.11.0
1602
+ aiohttp==3.8.0
1603
+ aiosignal==1.2.0
1604
+ alabaster==0.7.12
1605
+ alembic==1.7.4
1606
+ altgraph==0.17
1607
+ antlr4-python3-runtime==4.8
1608
+ appdirs==1.4.4
1609
+ argcomplete==1.12.2
1610
+ argon2-cffi==20.1.0
1611
+ asgiref==3.6.0
1612
+ astunparse==1.6.3
1613
+ async-generator==1.10
1614
+ async-timeout==4.0.0
1615
+ attrdict==2.0.1
1616
+ attrs==20.3.0
1617
+ audeer==1.16.0
1618
+ audformat==0.11.5
1619
+ audinterface==0.7.0
1620
+ audiofile==1.0.0
1621
+ audiomentations==0.25.0
1622
+ audioread==2.1.9
1623
+ audobject==0.4.14
1624
+ audresample==0.1.6
1625
+ -e git+https://github.com/facebookresearch/WavAugment.git@54afcdb00ccc852c2f030f239f8532c9562b550e#egg=augment
1626
+ autopage==0.4.0
1627
+ Babel==2.9.0
1628
+ backcall==0.2.0
1629
+ beautifulsoup4==4.10.0
1630
+ black==19.10b0
1631
+ bleach==3.3.0
1632
+ boto3==1.20.2
1633
+ botocore==1.23.2
1634
+ braceexpand==0.1.7
1635
+ cachetools==4.2.0
1636
+ certifi @ file:///croot/certifi_1671487769961/work/certifi
1637
+ cffi==1.14.3
1638
+ cfgv==3.2.0
1639
+ chardet==3.0.4
1640
+ charset-normalizer==2.0.7
1641
+ click==7.1.2
1642
+ cliff==3.9.0
1643
+ clldutils==3.5.4
1644
+ cmaes==0.8.2
1645
+ cmake==3.18.4.post1
1646
+ cmd2==2.2.0
1647
+ colorama==0.4.4
1648
+ colorlog==4.6.2
1649
+ configparser==5.1.0
1650
+ cryptography==38.0.4
1651
+ csvw==1.8.1
1652
+ cycler==0.10.0
1653
+ Cython==0.29.21
1654
+ dataclasses==0.6
1655
+ datasets==1.5.0
1656
+ decorator==4.4.2
1657
+ deepspeech==0.9.1
1658
+ defusedxml==0.7.1
1659
+ denoiser==0.1.5
1660
+ dill==0.3.3
1661
+ Distance==0.1.3
1662
+ distlib==0.3.1
1663
+ Django==3.2.16
1664
+ django-auditlog==2.2.1
1665
+ django-filter==22.1
1666
+ django-js-asset==1.2.2
1667
+ django-mptt==0.14.0
1668
+ djangorestframework==3.14.0
1669
+ docker-pycreds==0.4.0
1670
+ docopt==0.6.2
1671
+ docutils==0.16
1672
+ drf-excel==2.2.0
1673
+ drf-flex-fields==1.0.0
1674
+ drf-renderer-xlsx==0.4.1
1675
+ easyocr==1.2.1
1676
+ editdistance==0.6.0
1677
+ emoji==2.2.0
1678
+ entrypoints==0.3
1679
+ et-xmlfile==1.1.0
1680
+ exceptiongroup==1.1.0
1681
+ farasapy==0.0.14
1682
+ fasttext==0.9.2
1683
+ ffmpeg-python==0.2.0
1684
+ filelock==3.0.12
1685
+ flake8==3.7.9
1686
+ flatbuffers==1.12
1687
+ frozendict==2.0.7
1688
+ frozenlist==1.2.0
1689
+ fsspec==2021.11.0
1690
+ future==0.18.2
1691
+ g2p-en==2.1.0
1692
+ gast==0.3.3
1693
+ gdown==4.2.0
1694
+ gensim==4.0.1
1695
+ gitdb==4.0.9
1696
+ GitPython==3.1.24
1697
+ google-auth==1.24.0
1698
+ google-auth-oauthlib==0.4.2
1699
+ google-pasta==0.2.0
1700
+ greenlet==1.1.2
1701
+ grpcio==1.32.0
1702
+ h5features==1.3.2
1703
+ h5py==2.10.0
1704
+ htk-io==0.5
1705
+ huggingface-hub==0.9.1
1706
+ hydra-colorlog==0.1.4
1707
+ hydra-core==0.11.3
1708
+ HyperPyYAML==1.1.0
1709
+ hypothesis==6.61.2
1710
+ identify==1.5.10
1711
+ idna==2.10
1712
+ imageio==2.9.0
1713
+ imagesize==1.2.0
1714
+ importlib-metadata==4.8.1
1715
+ importlib-resources==5.2.2
1716
+ inflect==5.3.0
1717
+ ipadic==1.0.0
1718
+ ipykernel==5.3.4
1719
+ ipython==7.19.0
1720
+ ipython-genutils==0.2.0
1721
+ ipywidgets==7.6.3
1722
+ iso-639==0.4.5
1723
+ isodate==0.6.0
1724
+ isort==4.3.21
1725
+ jedi==0.17.2
1726
+ jieba==0.42.1
1727
+ Jinja2==2.11.2
1728
+ jiwer==2.2.0
1729
+ jmespath==0.10.0
1730
+ joblib==0.17.0
1731
+ jsonschema==3.2.0
1732
+ julius==0.2.7
1733
+ jupyter-client==6.1.7
1734
+ jupyter-core==4.7.0
1735
+ jupyterlab-pygments==0.1.2
1736
+ jupyterlab-widgets==1.0.0
1737
+ kaitaistruct==0.9
1738
+ kaldi-io==0.9.4
1739
+ kaldi-python-io==1.2.2
1740
+ kaldiio==2.17.2
1741
+ kenlm @ https://github.com/kpu/kenlm/archive/master.zip
1742
+ Keras-Preprocessing==1.1.2
1743
+ kiwisolver==1.3.1
1744
+ lang-trans==0.6.0
1745
+ latexcodec==2.0.1
1746
+ ldap3==2.9.1
1747
+ librosa==0.9.0
1748
+ llvmlite==0.35.0
1749
+ lxml==4.9.0
1750
+ Mako==1.1.5
1751
+ Markdown==3.3.3
1752
+ MarkupSafe==1.1.1
+ marshmallow==3.14.0
+ matplotlib==3.3.3
+ mccabe==0.6.1
+ mcd==0.4
+ mecab-python3==1.0.3
+ megatron-lm==2.2.0
+ mido==1.2.10
+ mistune==0.8.4
+ more-itertools==8.6.0
+ mpmath==1.2.1
+ multidict==5.2.0
+ multiprocess==0.70.11.1
+ nbclient==0.5.3
+ nbconvert==6.0.7
+ nbformat==5.1.3
+ NEMO==4.3.2
+ nemo-toolkit==1.4.0
+ nest-asyncio==1.5.1
+ networkx==2.5
+ nltk==3.5
+ nodeenv==1.5.0
+ notebook==6.3.0
+ numba==0.52.0
+ numpy==1.19.4
+ nvidia-cublas-cu11==11.10.3.66
+ nvidia-cuda-nvrtc-cu11==11.7.99
+ nvidia-cuda-runtime-cu11==11.7.99
+ nvidia-cudnn-cu11==8.5.0.96
+ oauthlib==3.1.0
+ omegaconf==1.4.1
+ onnx==1.10.2
+ OpenCC==1.1.2
+ opencv-python==4.4.0.46
+ openpyxl==3.0.9
+ opensmile==2.2.0
+ opt-einsum==3.3.0
+ optuna==2.10.0
+ oyaml==1.0
+ packaging==22.0
+ pandas==1.2.5
+ pandocfilters==1.4.3
+ pangu==4.0.6.1
+ parameterized==0.8.1
+ parso==0.7.1
+ pathspec==0.8.1
+ pathtools==0.1.2
+ pbr==5.6.0
+ pefile==2019.4.18
+ pescador==2.1.0
+ pesq==0.0.3
+ pexpect==4.8.0
+ phonemizer==2.2.1
+ pickleshare==0.7.5
+ Pillow==9.3.0
+ pip-api==0.0.23
+ pipreqs==0.4.11
+ pluggy==0.13.1
+ pooch==1.3.0
+ portalocker==2.3.2
+ pre-commit==2.9.0
+ pretty-midi==0.2.9
+ prettytable==2.2.1
+ progressbar2==3.53.1
+ prometheus-client==0.10.1
+ promise==2.3
+ prompt-toolkit==3.0.8
+ protobuf==3.14.0
+ psutil==5.6.6
+ ptyprocess==0.6.0
+ py==1.9.0
+ py-espeak-ng==0.1.8
+ pyannote.audio==1.1.1
+ pyannote.core==4.3
+ pyannote.database==4.1.1
+ pyannote.metrics==3.1
+ pyannote.pipeline==1.5.2
+ PyArabic==0.6.15
+ pyarrow==3.0.0
+ pyasn1==0.4.8
+ pyasn1-modules==0.2.8
+ pybind11==2.8.1
+ pybtex==0.24.0
+ pybtex-docutils==1.0.1
+ pycodestyle==2.5.0
+ pycparser==2.20
+ pyctcdecode==0.4.0
+ pyDeprecate==0.3.1
+ pydub==0.25.1
+ pyflakes==2.1.1
+ Pygments==2.7.2
+ pygtrie==2.5.0
+ pymodbus==2.5.3
+ pyparsing==2.4.7
+ pyperclip==1.8.2
+ pypinyin==0.43.0
+ pyrsistent==0.17.3
+ pyserial==3.5
+ PySocks==1.7.1
+ pystoi==0.3.3
+ pytest==5.4.1
+ pytest-runner==5.3.1
+ python-bidi==0.4.2
+ python-crfsuite==0.9.7
+ python-dateutil==2.8.2
+ python-Levenshtein==0.12.2
+ python-utils==2.4.0
+ pytorch-lightning==1.4.9
+ pytube==11.0.1
+ pytz==2022.6
+ PyWavelets==1.1.1
+ PyYAML==5.3.1
+ pyzmq==20.0.0
+ rapidfuzz==1.8.2
+ regex==2020.11.13
+ requests==2.28.1
+ requests-oauthlib==1.3.0
+ resampy==0.2.2
+ rfc3986==1.4.0
+ rsa==4.7
+ ruamel.yaml==0.17.21
+ ruamel.yaml.clib==0.2.7
+ s3m==1.1.0
+ s3transfer==0.5.0
+ sacrebleu==2.0.0
+ sacremoses==0.0.44
+ scikit-image==0.18.1
+ scikit-learn==0.23.2
+ scipy==1.5.4
+ -e git+https://github.com/sanghack81/SDCIT@00d060dde733fde9345154a494f81e97fb395ca7#egg=SDCIT
+ seaborn==0.11.1
+ segments==2.1.3
+ Send2Trash==1.5.0
+ sentencepiece==0.1.94
+ sentry-sdk==1.4.3
+ shellingham==1.4.0
+ shortuuid==1.0.7
+ SIDEKIT==1.3.8.5.2
+ simplejson==3.17.5
+ six==1.15.0
+ smart-open==5.0.0
+ smmap==5.0.0
+ snowballstemmer==2.0.0
+ sortedcollections==2.1.0
+ sortedcontainers==2.4.0
+ sounddevice==0.4.5
+ SoundFile==0.10.3.post1
+ soupsieve==2.3
+ sox==1.4.1
+ sparsemax==0.1.9
+ speechbrain==0.5.13
+ sphfile==1.0.3
+ Sphinx==3.3.1
+ sphinx-rtd-theme==0.4.3
+ sphinxcontrib-applehelp==1.0.2
+ sphinxcontrib-bibtex==2.4.1
+ sphinxcontrib-devhelp==1.0.2
+ sphinxcontrib-htmlhelp==1.0.3
+ sphinxcontrib-jsmath==1.0.1
+ sphinxcontrib-qthelp==1.0.3
+ sphinxcontrib-serializinghtml==1.1.4
+ SQLAlchemy==1.4.25
+ sqlparse==0.4.2
+ stanza==1.4.2
+ stevedore==3.4.0
+ subprocess32==3.5.4
+ sympy==1.9
+ tabulate==0.8.9
+ tensorboard==2.4.0
+ tensorboard-plugin-wit==1.7.0
+ tensorflow==2.4.0
+ tensorflow-estimator==2.4.0
+ termcolor==1.1.0
+ terminado==0.9.4
+ testpath==0.4.4
+ threadpoolctl==2.1.0
+ tifffile==2020.12.8
+ tikzplotlib==0.9.8
+ tkseem==0.0.3
+ tokenizers==0.10.2
+ toml==0.10.2
+ torch==1.13.1
+ torch-stft==0.1.4
+ torchaudio==0.13.1
+ torchmetrics==0.6.0
+ torchvision==0.14.1
+ tornado==6.1
+ tqdm==4.61.1
+ trackrip==1.2.1
+ traitlets==5.0.5
+ transformers==4.15.0
+ typed-ast==1.4.1
+ typer==0.4.0
+ typing-extensions==3.7.4.3
+ Unidecode==1.3.2
+ uritemplate==3.0.1
+ urllib3==1.26.2
+ virtualenv==20.2.1
+ wandb==0.12.6
+ wcwidth==0.2.5
+ webdataset==0.1.62
+ webencodings==0.5.1
+ Werkzeug==1.0.1
+ wget==3.2
+ widgetsnbextension==3.5.1
+ wordninja==2.0.0
+ wrapt==1.12.1
+ xmltodict==0.13.0
+ xxhash==2.0.0
+ yamllint==1.23.0
+ yarg==0.1.9
+ yarl==1.7.2
+ yaspin==2.1.0
+ youtokentome==1.0.6
+ youtube-dl==2021.6.6
+ zipp==3.6.0
+
+
+ 2023-01-07 16:00:48,174 - speechbrain.dataio.encoder - DEBUG - Loaded categorical encoding from partly_frozen_splitted_wavlm/save/label_encoder.txt
+ 2023-01-07 16:00:48,174 - speechbrain.dataio.encoder - INFO - Load called, but CTCTextEncoder is not empty. Loaded data will overwrite everything. This is normal if there is e.g. an unk label defined at init.
+ 2023-01-07 16:00:48,175 - speechbrain.dataio.encoder - DEBUG - Loaded categorical encoding from partly_frozen_splitted_wavlm/save/label_encoder.txt
+ 2023-01-07 16:00:48,176 - speechbrain.core - INFO - Info: auto_mix_prec arg from hparam file is used
+ 2023-01-07 16:00:48,176 - speechbrain.core - INFO - Info: ckpt_interval_minutes arg from hparam file is used
+ 2023-01-07 16:00:49,674 - speechbrain.core - INFO - 313.4M trainable parameters in ASR
+ 2023-01-07 16:00:50,298 - speechbrain.utils.checkpoints - INFO - Would load a checkpoint here, but none found yet.
+ 2023-01-07 16:00:50,298 - speechbrain.utils.epoch_loop - INFO - Going into epoch 1
+ 2023-01-07 16:01:08,641 - speechbrain.core - ERROR - Exception:
+ Traceback (most recent call last):
+ File "ctc_train.py", line 324, in <module>
+ asr_brain.fit(
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/core.py", line 1153, in fit
+ self._fit_train(train_set=train_set, epoch=epoch, enable=enable)
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/speechbrain/core.py", line 1009, in _fit_train
+ loss = self.fit_batch(batch)
+ File "ctc_train.py", line 92, in fit_batch
+ self.wav2vec_optimizer.step()
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/optimizer.py", line 140, in wrapper
+ out = func(*args, **kwargs)
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/optimizer.py", line 23, in _use_grad
+ ret = func(self, *args, **kwargs)
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/adam.py", line 234, in step
+ adam(params_with_grad,
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/adam.py", line 300, in adam
+ func(params,
+ File "/home/salah/anaconda3/envs/salah/lib/python3.8/site-packages/torch/optim/adam.py", line 410, in _single_tensor_adam
+ denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
+ KeyboardInterrupt
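
The interrupted call above sits in the recipe's fit_batch override, which steps two optimizers: one for the WavLM/wav2vec 2.0 encoder (wav2vec_optimizer, named in the traceback) and one for the layers on top. The following is only a minimal sketch of that pattern as it appears in common SpeechBrain CTC fine-tuning recipes; the attribute name model_optimizer is an assumption and the body may not match ctc_train.py exactly:

import speechbrain as sb

class ASR(sb.Brain):
    def fit_batch(self, batch):
        # Forward pass and CTC loss for one training batch.
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
        loss.backward()
        # Step both optimizers only when gradients are finite; the
        # KeyboardInterrupt above landed inside this optimizer step.
        if self.check_gradients(loss):
            self.wav2vec_optimizer.step()
            self.model_optimizer.step()  # assumed name for the top-layer optimizer
        self.wav2vec_optimizer.zero_grad()
        self.model_optimizer.zero_grad()
        return loss.detach().cpu()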
partly_frozen_splitted_wavlm/save/label_encoder.txt ADDED
@@ -0,0 +1,37 @@
+ 'ت' => 29
+ 'ع' => 30
+ 'ب' => 31
+ ' ' => 3
+ 'ه' => 4
+ 'ا' => 5
+ 'ن' => 6
+ 'ي' => 7
+ 'ر' => 8
+ 'ك' => 9
+ 'ش' => 10
+ 'ف' => 11
+ 'ل' => 12
+ 'د' => 13
+ 'س' => 14
+ 'م' => 15
+ 'ق' => 16
+ 'ى' => 17
+ 'ء' => 18
+ 'و' => 19
+ 'ح' => 20
+ 'ز' => 21
+ 'ة' => 22
+ 'أ' => 23
+ 'خ' => 24
+ 'ص' => 25
+ 'ط' => 26
+ 'ج' => 27
+ 'ظ' => 28
+ '<blank>' => 0
+ '<bos>' => 1
+ '<eos>' => 2
+ ================
+ 'starting_index' => 0
+ 'bos_label' => '<bos>'
+ 'eos_label' => '<eos>'
+ 'blank_label' => '<blank>'
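
This label_encoder.txt stores the character-level output vocabulary used for CTC: each Arabic character (plus the space) maps to an integer index, with <blank>, <bos> and <eos> reserved as special labels and the encoder metadata listed after the ================ separator. SpeechBrain's CTCTextEncoder loads this file itself; the snippet below is only a standalone, illustrative parser for the character => index part of the format shown above, with a hypothetical usage example:

import ast

def load_label_map(path="partly_frozen_splitted_wavlm/save/label_encoder.txt"):
    # Parse lines of the form 'ت' => 29 until the metadata separator.
    char2idx = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line.startswith("="):  # "================" starts the metadata block
                break
            if "=>" in line:
                key, value = line.split("=>")
                char2idx[ast.literal_eval(key.strip())] = int(value)
    return char2idx

if __name__ == "__main__":
    label_map = load_label_map()
    # Encode a short transcript character by character; index 0 (<blank>)
    # is reserved for CTC and never appears in the targets.
    print([label_map[ch] for ch in "مرحبا بيك"])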
recording.webm ADDED
Binary file (48.8 kB). View file
 
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ -r lint-requirements.txt
+ huggingface_hub>=0.7.0
+ hyperpyyaml>=0.0.1
+ joblib>=0.14.1
+ numpy>=1.17.0
+ packaging
+ pre-commit>=2.3.0
+ scipy>=1.4.1, <1.9
+ sentencepiece>=0.1.91
+ SoundFile; sys_platform == 'win32'
+ torch>=1.9.0
+ torchaudio>=0.9.0
+ tqdm>=4.42.0
+ transformers==4.15
+ speechbrain
+ pyctcdecode
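
requirements.txt pins the demo's runtime dependencies; transformers is held at 4.15, matching the transformers==4.15.0 entry in the training environment logged above, while speechbrain and pyctcdecode are left unpinned. A small, illustrative sanity check (not part of the repository) for the key packages before launching the app could look like this:

from importlib.metadata import PackageNotFoundError, version

# Hypothetical helper: report whether the main pinned dependencies are present.
for name, expected in [("transformers", "4.15"), ("speechbrain", ""), ("pyctcdecode", "")]:
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"{name}: not installed")
        continue
    note = "ok" if installed.startswith(expected) else f"expected {expected}*"
    print(f"{name}: {installed} ({note})")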
running_tunisian.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
samples/Salah1.wav ADDED
Binary file (952 kB). View file
 
samples/Salah10.wav ADDED
Binary file (768 kB). View file