transiteration
commited on
Commit
•
0479abb
1
Parent(s):
aeeafe2
Upload 8 files
Browse files- evaluate.py +60 -0
- ksc/.ksc-train.json.swp +0 -0
- ksc/test_manifest.json +3 -0
- ksc/train_manifest.json +3 -0
- ksc/val_manifest.json +3 -0
- requirements.txt +214 -0
- train.py +163 -0
- transcribe.py +22 -0
evaluate.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import nemo.collections.asr as nemo_asr
|
5 |
+
import torch
|
6 |
+
from omegaconf import open_dict
|
7 |
+
|
8 |
+
|
9 |
+
def evaluate_model(model_path: str, test_manifest: str, batch_size: int = 1) -> Dict:
|
10 |
+
|
11 |
+
# Determine the device (CPU or GPU)
|
12 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
13 |
+
|
14 |
+
# Restore the ASR model from the provided path
|
15 |
+
model = nemo_asr.models.ASRModel.restore_from(restore_path=model_path)
|
16 |
+
model.to(device)
|
17 |
+
model.eval()
|
18 |
+
|
19 |
+
# Update the model configuration for evaluation
|
20 |
+
with open_dict(model.cfg):
|
21 |
+
model.cfg.validation_ds.manifest_filepath = test_manifest
|
22 |
+
model.cfg.validation_ds.batch_size = batch_size
|
23 |
+
|
24 |
+
# Set up the test data using the updated configuration
|
25 |
+
model.setup_test_data(model.cfg.validation_ds)
|
26 |
+
|
27 |
+
wer_nums = []
|
28 |
+
wer_denoms = []
|
29 |
+
|
30 |
+
# Iterate through the test data
|
31 |
+
for test_batch in model.test_dataloader():
|
32 |
+
# Extract elements from the test batch
|
33 |
+
test_batch = [x for x in test_batch]
|
34 |
+
targets = test_batch[2].to(device)
|
35 |
+
targets_lengths = test_batch[3].to(device)
|
36 |
+
# Forward pass through the model
|
37 |
+
log_probs, encoded_len, greedy_predictions = model(input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device))
|
38 |
+
# Compute Word Error Rate (WER) and store results
|
39 |
+
model._wer.update(greedy_predictions, targets, targets_lengths)
|
40 |
+
_, wer_num, wer_denom = model._wer.compute()
|
41 |
+
model._wer.reset()
|
42 |
+
wer_nums.append(wer_num.detach().cpu().numpy())
|
43 |
+
wer_denoms.append(wer_denom.detach().cpu().numpy())
|
44 |
+
# Free up memory by deleting variables
|
45 |
+
del test_batch, log_probs, targets, targets_lengths, encoded_len, greedy_predictions
|
46 |
+
|
47 |
+
# Compute the WER score
|
48 |
+
wer_score = sum(wer_nums) / sum(wer_denoms)
|
49 |
+
print({"WER_score": wer_score})
|
50 |
+
|
51 |
+
|
52 |
+
if __name__ == "__main__":
|
53 |
+
# Parse command line arguments
|
54 |
+
parser = argparse.ArgumentParser()
|
55 |
+
parser.add_argument("--model_path", default=None, help="Path to a model to evaluate.")
|
56 |
+
parser.add_argument("--test_manifest", help="Path for train manifest JSON file.")
|
57 |
+
parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
|
58 |
+
args = parser.parse_args()
|
59 |
+
|
60 |
+
evaluate_model(model_path=args.model_path, test_manifest=args.test_manifest, batch_size=args.batch_size)
|
ksc/.ksc-train.json.swp
ADDED
Binary file (12.3 kB). View file
|
|
ksc/test_manifest.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{"audio_filepath": "ksc/test/crowdsourced/5f60e4ecdb745.wav", "text": "бірақ бізде онымен айналысатын қожалықтар саны саусақпен санарлық", "duration": 7.936}
|
2 |
+
{"audio_filepath": "ksc/test/crowdsourced/5f609b2ad370e.wav", "text": "солардың бірі маңғыстаулық шопан есет өтесов", "duration": 5.4613125}
|
3 |
+
{"audio_filepath": "ksc/test/crowdsourced/5f5682c1c4739.wav", "text": "иесіз жануарларды ату керек пе әлде асырау керек пе", "duration": 6.7413125}
|
ksc/train_manifest.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{"audio_filepath": "ksc/train/crowdsourced/5f5861e2baed1.wav", "text": "төрт бұрышты ошақ деп атайды", "duration": 2.816}
|
2 |
+
{"audio_filepath": "ksc/train/crowdsourced/5f2b0a559b15f.wav", "text": "оны адамдардың жақтырмаушылық пен көре алмаушылықтары бағдат қаласынан кетуге мәжбүр етті", "duration": 10.0693125}
|
3 |
+
{"audio_filepath": "ksc/train/crowdsourced/5f57300c9e48b.wav", "text": "кіші сордук ресейдегі өзен", "duration": 4.352}
|
ksc/val_manifest.json
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
{"audio_filepath": "ksc/val/crowdsourced/5f5a35cfcd7b2.wav", "text": "егер мынандай жағдайда әуе компаниясы көмек бере алмаса мемлекет қол ұшын созады", "duration": 8.192}
|
2 |
+
{"audio_filepath": "ksc/val/crowdsourced/5f60ff61277a3.wav", "text": "өйткені түсіру жұмыстары басталғанға дейін актерлер екі ай бойы арнайы каскадерлік жаттығулардан өткен", "duration": 7.936}
|
3 |
+
{"audio_filepath": "ksc/val/crowdsourced/5f5a5a5a199d7.wav", "text": "беру аяқталды дауыс санау басталды", "duration": 3.072}
|
requirements.txt
ADDED
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.1.0
|
2 |
+
aiohttp==3.8.6
|
3 |
+
aiosignal==1.3.1
|
4 |
+
alabaster==0.7.13
|
5 |
+
antlr4-python3-runtime==4.9.3
|
6 |
+
appdirs==1.4.4
|
7 |
+
async-timeout==4.0.3
|
8 |
+
asynctest==0.13.0
|
9 |
+
attrdict==2.0.1
|
10 |
+
attrs==23.2.0
|
11 |
+
audioread==3.0.1
|
12 |
+
Babel==2.14.0
|
13 |
+
backcall==0.2.0
|
14 |
+
beautifulsoup4==4.12.3
|
15 |
+
black==19.10b0
|
16 |
+
boto3==1.33.13
|
17 |
+
botocore==1.33.13
|
18 |
+
braceexpand==0.1.7
|
19 |
+
cachetools==5.3.2
|
20 |
+
certifi==2023.11.17
|
21 |
+
cffi==1.15.1
|
22 |
+
charset-normalizer==3.3.2
|
23 |
+
click==8.0.2
|
24 |
+
colorama==0.4.6
|
25 |
+
comm==0.1.4
|
26 |
+
cycler==0.11.0
|
27 |
+
Cython==3.0.8
|
28 |
+
decorator==5.1.1
|
29 |
+
Distance==0.1.3
|
30 |
+
docker-pycreds==0.4.0
|
31 |
+
docopt==0.6.2
|
32 |
+
docutils==0.17.1
|
33 |
+
editdistance==0.6.2
|
34 |
+
exceptiongroup==1.2.0
|
35 |
+
fasttext==0.9.2
|
36 |
+
filelock==3.12.2
|
37 |
+
flake8==5.0.4
|
38 |
+
Flake8-pyproject==1.2.3
|
39 |
+
fonttools==4.38.0
|
40 |
+
frozendict==2.4.0
|
41 |
+
frozenlist==1.3.3
|
42 |
+
fsspec==2023.1.0
|
43 |
+
ftfy==6.1.1
|
44 |
+
future==0.18.3
|
45 |
+
g2p-en==2.1.0
|
46 |
+
gdown==4.7.3
|
47 |
+
gitdb==4.0.11
|
48 |
+
GitPython==3.1.41
|
49 |
+
google-auth==2.26.2
|
50 |
+
google-auth-oauthlib==0.4.6
|
51 |
+
grpcio==1.60.0
|
52 |
+
h5py==3.8.0
|
53 |
+
huggingface-hub==0.16.4
|
54 |
+
Hydra==2.5
|
55 |
+
hydra-core==1.3.2
|
56 |
+
idna==3.6
|
57 |
+
imagesize==1.4.1
|
58 |
+
importlib-metadata==4.2.0
|
59 |
+
importlib-resources==5.12.0
|
60 |
+
inflect==6.0.5
|
61 |
+
iniconfig==2.0.0
|
62 |
+
ipadic==1.0.0
|
63 |
+
ipython==7.34.0
|
64 |
+
ipywidgets==8.1.1
|
65 |
+
isort==4.3.21
|
66 |
+
jedi==0.19.1
|
67 |
+
jieba==0.42.1
|
68 |
+
Jinja2==3.1.3
|
69 |
+
jmespath==1.0.1
|
70 |
+
joblib==1.3.2
|
71 |
+
jupyterlab-widgets==3.0.9
|
72 |
+
kaldi-io==0.9.8
|
73 |
+
kaldi-python-io==1.2.2
|
74 |
+
kaldiio==2.18.0
|
75 |
+
kiwisolver==1.4.5
|
76 |
+
latexcodec==2.0.1
|
77 |
+
lazy_loader==0.3
|
78 |
+
librosa==0.10.1
|
79 |
+
llvmlite==0.39.1
|
80 |
+
lxml==5.1.0
|
81 |
+
Markdown==3.4.4
|
82 |
+
markdown-it-py==2.2.0
|
83 |
+
MarkupSafe==2.1.3
|
84 |
+
marshmallow==3.19.0
|
85 |
+
matplotlib==3.5.3
|
86 |
+
matplotlib-inline==0.1.6
|
87 |
+
mccabe==0.7.0
|
88 |
+
mdurl==0.1.2
|
89 |
+
mecab-python3==1.0.6
|
90 |
+
mpmath==1.3.0
|
91 |
+
msgpack==1.0.5
|
92 |
+
multidict==6.0.4
|
93 |
+
nemo-toolkit==1.7.0
|
94 |
+
nltk==3.8.1
|
95 |
+
numba==0.56.4
|
96 |
+
numpy==1.21.6
|
97 |
+
oauthlib==3.2.2
|
98 |
+
omegaconf==2.3.0
|
99 |
+
onnx==1.14.1
|
100 |
+
OpenCC==1.1.6
|
101 |
+
packaging==23.2
|
102 |
+
pandas==1.3.5
|
103 |
+
pangu==4.0.6.1
|
104 |
+
parameterized==0.9.0
|
105 |
+
parso==0.8.3
|
106 |
+
pathspec==0.11.2
|
107 |
+
pesq==0.0.4
|
108 |
+
pexpect==4.9.0
|
109 |
+
pickleshare==0.7.5
|
110 |
+
Pillow==9.5.0
|
111 |
+
pip-api==0.0.30
|
112 |
+
pipreqs==0.4.13
|
113 |
+
platformdirs==4.0.0
|
114 |
+
pluggy==1.2.0
|
115 |
+
pooch==1.8.0
|
116 |
+
portalocker==2.7.0
|
117 |
+
prompt-toolkit==3.0.43
|
118 |
+
protobuf==3.20.3
|
119 |
+
psutil==5.9.7
|
120 |
+
ptyprocess==0.7.0
|
121 |
+
pyannote.core==5.0.0
|
122 |
+
pyannote.database==5.0.1
|
123 |
+
pyannote.metrics==3.2.1
|
124 |
+
pyasn1==0.5.1
|
125 |
+
pyasn1-modules==0.3.0
|
126 |
+
pybind11==2.11.1
|
127 |
+
pybtex==0.24.0
|
128 |
+
pybtex-docutils==1.0.3
|
129 |
+
pycodestyle==2.9.1
|
130 |
+
pycparser==2.21
|
131 |
+
pydantic==1.10.13
|
132 |
+
pyDeprecate==0.3.1
|
133 |
+
pydub==0.25.1
|
134 |
+
pyflakes==2.5.0
|
135 |
+
Pygments==2.17.2
|
136 |
+
pyparsing==3.1.1
|
137 |
+
pypinyin==0.50.0
|
138 |
+
PySocks==1.7.1
|
139 |
+
pystoi==0.4.1
|
140 |
+
pytest==7.4.4
|
141 |
+
pytest-runner==6.0.1
|
142 |
+
python-dateutil==2.8.2
|
143 |
+
pytorch-lightning==1.5.10
|
144 |
+
pytz==2023.3.post1
|
145 |
+
pyupgrade==3.3.2
|
146 |
+
PyYAML==5.4.1
|
147 |
+
rapidfuzz==3.4.0
|
148 |
+
regex==2023.12.25
|
149 |
+
requests==2.31.0
|
150 |
+
requests-oauthlib==1.3.1
|
151 |
+
rich==13.7.0
|
152 |
+
rsa==4.9
|
153 |
+
ruamel.yaml==0.18.5
|
154 |
+
ruamel.yaml.clib==0.2.8
|
155 |
+
s3transfer==0.8.2
|
156 |
+
sacrebleu==2.4.0
|
157 |
+
sacremoses==0.0.53
|
158 |
+
safetensors==0.4.1
|
159 |
+
scikit-learn==1.0.2
|
160 |
+
scipy==1.7.3
|
161 |
+
sentencepiece==0.1.99
|
162 |
+
sentry-sdk==1.39.2
|
163 |
+
setproctitle==1.3.3
|
164 |
+
shellingham==1.5.4
|
165 |
+
six==1.16.0
|
166 |
+
smmap==5.0.1
|
167 |
+
snowballstemmer==2.2.0
|
168 |
+
sortedcontainers==2.4.0
|
169 |
+
soundfile==0.12.1
|
170 |
+
soupsieve==2.4.1
|
171 |
+
sox==1.4.1
|
172 |
+
soxr==0.3.7
|
173 |
+
Sphinx==5.3.0
|
174 |
+
sphinxcontrib-applehelp==1.0.2
|
175 |
+
sphinxcontrib-bibtex==2.6.2
|
176 |
+
sphinxcontrib-devhelp==1.0.2
|
177 |
+
sphinxcontrib-htmlhelp==2.0.0
|
178 |
+
sphinxcontrib-jsmath==1.0.1
|
179 |
+
sphinxcontrib-qthelp==1.0.3
|
180 |
+
sphinxcontrib-serializinghtml==1.1.5
|
181 |
+
sympy==1.10.1
|
182 |
+
tabulate==0.9.0
|
183 |
+
tensorboard==2.11.2
|
184 |
+
tensorboard-data-server==0.6.1
|
185 |
+
tensorboard-plugin-wit==1.8.1
|
186 |
+
threadpoolctl==3.1.0
|
187 |
+
tokenize-rt==5.0.0
|
188 |
+
tokenizers==0.13.3
|
189 |
+
toml==0.10.2
|
190 |
+
tomli==2.0.1
|
191 |
+
torch==1.12.1+cu116
|
192 |
+
torch-stft==0.1.4
|
193 |
+
torchaudio==0.12.1+cu116
|
194 |
+
torchmetrics==0.11.4
|
195 |
+
torchvision==0.13.1+cu116
|
196 |
+
tqdm==4.66.1
|
197 |
+
traitlets==5.9.0
|
198 |
+
transformers==4.30.2
|
199 |
+
typed-ast==1.5.5
|
200 |
+
typer==0.9.0
|
201 |
+
typing_extensions==4.7.1
|
202 |
+
Unidecode==1.3.8
|
203 |
+
urllib3==1.26.18
|
204 |
+
wandb==0.16.2
|
205 |
+
wcwidth==0.2.13
|
206 |
+
webdataset==0.1.62
|
207 |
+
Werkzeug==2.2.3
|
208 |
+
wget==3.2
|
209 |
+
widgetsnbextension==4.0.9
|
210 |
+
wrapt==1.16.0
|
211 |
+
yarg==0.1.9
|
212 |
+
yarl==1.9.4
|
213 |
+
youtokentome==1.0.6
|
214 |
+
zipp==3.15.0
|
train.py
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import nemo.collections.asr as nemo_asr
|
5 |
+
import pytorch_lightning as ptl
|
6 |
+
from nemo.utils import exp_manager, logging
|
7 |
+
from omegaconf import OmegaConf, open_dict
|
8 |
+
|
9 |
+
|
10 |
+
def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_size: int, num_epochs: int, model_save_path: str = None,) -> None:
|
11 |
+
|
12 |
+
# Loading a STT Quartznet 15x5 model
|
13 |
+
model = nemo_asr.models.ASRModel.from_pretrained("stt_en_quartznet15x5")
|
14 |
+
# New vocabulary for a model
|
15 |
+
new_vocabulary = [
|
16 |
+
" ",
|
17 |
+
"а",
|
18 |
+
"б",
|
19 |
+
"в",
|
20 |
+
"г",
|
21 |
+
"д",
|
22 |
+
"е",
|
23 |
+
"ж",
|
24 |
+
"з",
|
25 |
+
"и",
|
26 |
+
"й",
|
27 |
+
"к",
|
28 |
+
"л",
|
29 |
+
"м",
|
30 |
+
"н",
|
31 |
+
"о",
|
32 |
+
"п",
|
33 |
+
"р",
|
34 |
+
"с",
|
35 |
+
"т",
|
36 |
+
"у",
|
37 |
+
"ф",
|
38 |
+
"х",
|
39 |
+
"ц",
|
40 |
+
"ч",
|
41 |
+
"ш",
|
42 |
+
"щ",
|
43 |
+
"ъ",
|
44 |
+
"ы",
|
45 |
+
"ь",
|
46 |
+
"э",
|
47 |
+
"ю",
|
48 |
+
"я",
|
49 |
+
"і",
|
50 |
+
"ғ",
|
51 |
+
"қ",
|
52 |
+
"ң",
|
53 |
+
"ү",
|
54 |
+
"ұ",
|
55 |
+
"һ",
|
56 |
+
"ә",
|
57 |
+
"ө",
|
58 |
+
]
|
59 |
+
|
60 |
+
with open_dict(model.cfg):
|
61 |
+
# Setting up the labels and sample rate
|
62 |
+
model.cfg.labels = new_vocabulary
|
63 |
+
model.cfg.sample_rate = 16000
|
64 |
+
|
65 |
+
# Train dataset
|
66 |
+
model.cfg.train_ds.manifest_filepath = train_manifest
|
67 |
+
model.cfg.train_ds.labels = new_vocabulary
|
68 |
+
model.cfg.train_ds.normalize_transcripts = False
|
69 |
+
model.cfg.train_ds.batch_size = batch_size
|
70 |
+
model.cfg.train_ds.num_workers = 10
|
71 |
+
model.cfg.train_ds.pin_memory = True
|
72 |
+
model.cfg.train_ds.trim_silence = True
|
73 |
+
|
74 |
+
# Validation dataset
|
75 |
+
model.cfg.validation_ds.manifest_filepath = val_manifest
|
76 |
+
model.cfg.validation_ds.labels = new_vocabulary
|
77 |
+
model.cfg.validation_ds.normalize_transcripts = False
|
78 |
+
model.cfg.validation_ds.batch_size = batch_size
|
79 |
+
model.cfg.validation_ds.num_workers = 10
|
80 |
+
model.cfg.validation_ds.pin_memory = True
|
81 |
+
model.cfg.validation_ds.trim_silence = True
|
82 |
+
|
83 |
+
# Setting up an optimizer and scheduler
|
84 |
+
model.cfg.optim.lr = 0.001
|
85 |
+
model.cfg.optim.betas = [0.8, 0.5]
|
86 |
+
model.cfg.optim.weight_decay = 0.001
|
87 |
+
model.cfg.optim.sched.warmup_steps = 500
|
88 |
+
model.cfg.optim.sched.min_lr = 1e-6
|
89 |
+
|
90 |
+
model.change_vocabulary(new_vocabulary=new_vocabulary)
|
91 |
+
model.setup_training_data(model.cfg.train_ds)
|
92 |
+
model.setup_validation_data(model.cfg.validation_ds)
|
93 |
+
|
94 |
+
# Unfreezing encoders to update the parameters
|
95 |
+
model.encoder.unfreeze()
|
96 |
+
logging.info("Model encoder has been un-frozen")
|
97 |
+
|
98 |
+
# Setting up data augmentation
|
99 |
+
model.spec_augmentation = model.from_config_dict(model.cfg.spec_augment)
|
100 |
+
|
101 |
+
# Setting up the metrics
|
102 |
+
model._wer.use_cer = True
|
103 |
+
model._wer.log_prediction = True
|
104 |
+
|
105 |
+
# Trainer
|
106 |
+
trainer = ptl.Trainer(
|
107 |
+
accelerator="gpu",
|
108 |
+
max_epochs=num_epochs,
|
109 |
+
accumulate_grad_batches=1,
|
110 |
+
enable_checkpointing=False,
|
111 |
+
logger=False,
|
112 |
+
log_every_n_steps=100,
|
113 |
+
check_val_every_n_epoch=1,
|
114 |
+
precision=16,
|
115 |
+
)
|
116 |
+
|
117 |
+
# Setting up model with the trainer
|
118 |
+
model.set_trainer(trainer)
|
119 |
+
|
120 |
+
# Experiment tracking
|
121 |
+
LANGUAGE = "kz"
|
122 |
+
config = exp_manager.ExpManagerConfig(
|
123 |
+
exp_dir=f"experiments/lang-{LANGUAGE}/",
|
124 |
+
name=f"ASR-Model-Language-{LANGUAGE}",
|
125 |
+
checkpoint_callback_params=exp_manager.CallbackParams(monitor="val_wer", mode="min", always_save_nemo=True, save_best_model=True,),
|
126 |
+
)
|
127 |
+
config = OmegaConf.structured(config)
|
128 |
+
exp_manager.exp_manager(trainer, config)
|
129 |
+
|
130 |
+
# Final Configuration
|
131 |
+
print("-----------------------------------------------------------")
|
132 |
+
print("Updated STT Model Configuration:")
|
133 |
+
print(OmegaConf.to_yaml(model.cfg))
|
134 |
+
print("-----------------------------------------------------------")
|
135 |
+
|
136 |
+
# # Fitting the model
|
137 |
+
trainer.fit(model)
|
138 |
+
|
139 |
+
# # Saving the model
|
140 |
+
if model_save_path:
|
141 |
+
model.save_to(f"{model_save_path}")
|
142 |
+
print(f"Model saved at path : {os.getcwd() + os.path.sep + model_save_path}")
|
143 |
+
|
144 |
+
|
145 |
+
if __name__ == "__main__":
|
146 |
+
# Parse command line arguments
|
147 |
+
parser = argparse.ArgumentParser()
|
148 |
+
parser.add_argument("--train_manifest", help="Path for train manifest JSON file.")
|
149 |
+
parser.add_argument("--val_manifest", help="Path for validation manifest JSON file.")
|
150 |
+
parser.add_argument("--accelerator", help="What accelerator type to use (cpu, gpu, tpu, etc.).")
|
151 |
+
parser.add_argument("--batch_size", type=int, default=1, help="Batch size of the dataset to train.")
|
152 |
+
parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train for.")
|
153 |
+
parser.add_argument("--model_save_path", default=None, help="Path for saving a trained model.")
|
154 |
+
args = parser.parse_args()
|
155 |
+
|
156 |
+
train_model(
|
157 |
+
train_manifest=args.train_manifest,
|
158 |
+
val_manifest=args.val_manifest,
|
159 |
+
accelerator=args.accelerator,
|
160 |
+
batch_size=args.batch_size,
|
161 |
+
num_epochs=args.num_epochs,
|
162 |
+
model_save_path=args.model_save_path,
|
163 |
+
)
|
transcribe.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from typing import Dict
|
3 |
+
|
4 |
+
import nemo.collections.asr as nemo_asr
|
5 |
+
|
6 |
+
|
7 |
+
def predict_model(model_path: str, audio_file_path: str) -> Dict:
|
8 |
+
# Restore the ASR model from the provided path
|
9 |
+
model = nemo_asr.models.ASRModel.restore_from(restore_path=model_path)
|
10 |
+
# Transcribe the given audio file
|
11 |
+
text = model.transcribe([audio_file_path])
|
12 |
+
print({"result": text[0]})
|
13 |
+
|
14 |
+
|
15 |
+
if __name__ == "__main__":
|
16 |
+
# Parse command line arguments
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--model_path", default=None, help="Path to a model to evaluate.")
|
19 |
+
parser.add_argument("--audio_file_path", help="Path for train manifest JSON file.")
|
20 |
+
args = parser.parse_args()
|
21 |
+
|
22 |
+
predict_model(model_path=args.model_path, audio_file_path=args.audio_file_path)
|