transliteration committed
Commit 0479abb
1 parent: aeeafe2

Upload 8 files
evaluate.py ADDED
@@ -0,0 +1,60 @@
import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr
import torch
from omegaconf import open_dict


def evaluate_model(model_path: str, test_manifest: str, batch_size: int = 1) -> Dict:

    # Determine the device (CPU or GPU)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Restore the ASR model from the provided path
    model = nemo_asr.models.ASRModel.restore_from(restore_path=model_path)
    model.to(device)
    model.eval()

    # Update the model configuration for evaluation
    with open_dict(model.cfg):
        model.cfg.validation_ds.manifest_filepath = test_manifest
        model.cfg.validation_ds.batch_size = batch_size

    # Set up the test data using the updated configuration
    model.setup_test_data(model.cfg.validation_ds)

    wer_nums = []
    wer_denoms = []

    # Iterate through the test data
    for test_batch in model.test_dataloader():
        # Extract elements from the test batch
        test_batch = [x for x in test_batch]
        targets = test_batch[2].to(device)
        targets_lengths = test_batch[3].to(device)
        # Forward pass through the model
        log_probs, encoded_len, greedy_predictions = model(
            input_signal=test_batch[0].to(device), input_signal_length=test_batch[1].to(device)
        )
        # Compute Word Error Rate (WER) and store results
        model._wer.update(greedy_predictions, targets, targets_lengths)
        _, wer_num, wer_denom = model._wer.compute()
        model._wer.reset()
        wer_nums.append(wer_num.detach().cpu().numpy())
        wer_denoms.append(wer_denom.detach().cpu().numpy())
        # Free up memory by deleting variables
        del test_batch, log_probs, targets, targets_lengths, encoded_len, greedy_predictions

    # Compute the WER score
    wer_score = sum(wer_nums) / sum(wer_denoms)
    print({"WER_score": wer_score})
    return {"WER_score": wer_score}


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default=None, help="Path to a model to evaluate.")
    parser.add_argument("--test_manifest", help="Path to the test manifest JSON file.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for evaluation.")
    args = parser.parse_args()

    evaluate_model(model_path=args.model_path, test_manifest=args.test_manifest, batch_size=args.batch_size)
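
A typical invocation, assuming a fine-tuned .nemo checkpoint (the checkpoint name below is illustrative, not part of this commit) and the test manifest shipped here, would be:

python evaluate.py --model_path asr_model_kz.nemo --test_manifest ksc/test_manifest.json --batch_size 4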
ksc/.ksc-train.json.swp ADDED
Binary file (12.3 kB).
 
ksc/test_manifest.json ADDED
@@ -0,0 +1,3 @@
{"audio_filepath": "ksc/test/crowdsourced/5f60e4ecdb745.wav", "text": "бірақ бізде онымен айналысатын қожалықтар саны саусақпен санарлық", "duration": 7.936}
{"audio_filepath": "ksc/test/crowdsourced/5f609b2ad370e.wav", "text": "солардың бірі маңғыстаулық шопан есет өтесов", "duration": 5.4613125}
{"audio_filepath": "ksc/test/crowdsourced/5f5682c1c4739.wav", "text": "иесіз жануарларды ату керек пе әлде асырау керек пе", "duration": 6.7413125}
ksc/train_manifest.json ADDED
@@ -0,0 +1,3 @@
{"audio_filepath": "ksc/train/crowdsourced/5f5861e2baed1.wav", "text": "төрт бұрышты ошақ деп атайды", "duration": 2.816}
{"audio_filepath": "ksc/train/crowdsourced/5f2b0a559b15f.wav", "text": "оны адамдардың жақтырмаушылық пен көре алмаушылықтары бағдат қаласынан кетуге мәжбүр етті", "duration": 10.0693125}
{"audio_filepath": "ksc/train/crowdsourced/5f57300c9e48b.wav", "text": "кіші сордук ресейдегі өзен", "duration": 4.352}
ksc/val_manifest.json ADDED
@@ -0,0 +1,3 @@
{"audio_filepath": "ksc/val/crowdsourced/5f5a35cfcd7b2.wav", "text": "егер мынандай жағдайда әуе компаниясы көмек бере алмаса мемлекет қол ұшын созады", "duration": 8.192}
{"audio_filepath": "ksc/val/crowdsourced/5f60ff61277a3.wav", "text": "өйткені түсіру жұмыстары басталғанға дейін актерлер екі ай бойы арнайы каскадерлік жаттығулардан өткен", "duration": 7.936}
{"audio_filepath": "ksc/val/crowdsourced/5f5a5a5a199d7.wav", "text": "беру аяқталды дауыс санау басталды", "duration": 3.072}
requirements.txt ADDED
@@ -0,0 +1,214 @@
absl-py==2.1.0
aiohttp==3.8.6
aiosignal==1.3.1
alabaster==0.7.13
antlr4-python3-runtime==4.9.3
appdirs==1.4.4
async-timeout==4.0.3
asynctest==0.13.0
attrdict==2.0.1
attrs==23.2.0
audioread==3.0.1
Babel==2.14.0
backcall==0.2.0
beautifulsoup4==4.12.3
black==19.10b0
boto3==1.33.13
botocore==1.33.13
braceexpand==0.1.7
cachetools==5.3.2
certifi==2023.11.17
cffi==1.15.1
charset-normalizer==3.3.2
click==8.0.2
colorama==0.4.6
comm==0.1.4
cycler==0.11.0
Cython==3.0.8
decorator==5.1.1
Distance==0.1.3
docker-pycreds==0.4.0
docopt==0.6.2
docutils==0.17.1
editdistance==0.6.2
exceptiongroup==1.2.0
fasttext==0.9.2
filelock==3.12.2
flake8==5.0.4
Flake8-pyproject==1.2.3
fonttools==4.38.0
frozendict==2.4.0
frozenlist==1.3.3
fsspec==2023.1.0
ftfy==6.1.1
future==0.18.3
g2p-en==2.1.0
gdown==4.7.3
gitdb==4.0.11
GitPython==3.1.41
google-auth==2.26.2
google-auth-oauthlib==0.4.6
grpcio==1.60.0
h5py==3.8.0
huggingface-hub==0.16.4
Hydra==2.5
hydra-core==1.3.2
idna==3.6
imagesize==1.4.1
importlib-metadata==4.2.0
importlib-resources==5.12.0
inflect==6.0.5
iniconfig==2.0.0
ipadic==1.0.0
ipython==7.34.0
ipywidgets==8.1.1
isort==4.3.21
jedi==0.19.1
jieba==0.42.1
Jinja2==3.1.3
jmespath==1.0.1
joblib==1.3.2
jupyterlab-widgets==3.0.9
kaldi-io==0.9.8
kaldi-python-io==1.2.2
kaldiio==2.18.0
kiwisolver==1.4.5
latexcodec==2.0.1
lazy_loader==0.3
librosa==0.10.1
llvmlite==0.39.1
lxml==5.1.0
Markdown==3.4.4
markdown-it-py==2.2.0
MarkupSafe==2.1.3
marshmallow==3.19.0
matplotlib==3.5.3
matplotlib-inline==0.1.6
mccabe==0.7.0
mdurl==0.1.2
mecab-python3==1.0.6
mpmath==1.3.0
msgpack==1.0.5
multidict==6.0.4
nemo-toolkit==1.7.0
nltk==3.8.1
numba==0.56.4
numpy==1.21.6
oauthlib==3.2.2
omegaconf==2.3.0
onnx==1.14.1
OpenCC==1.1.6
packaging==23.2
pandas==1.3.5
pangu==4.0.6.1
parameterized==0.9.0
parso==0.8.3
pathspec==0.11.2
pesq==0.0.4
pexpect==4.9.0
pickleshare==0.7.5
Pillow==9.5.0
pip-api==0.0.30
pipreqs==0.4.13
platformdirs==4.0.0
pluggy==1.2.0
pooch==1.8.0
portalocker==2.7.0
prompt-toolkit==3.0.43
protobuf==3.20.3
psutil==5.9.7
ptyprocess==0.7.0
pyannote.core==5.0.0
pyannote.database==5.0.1
pyannote.metrics==3.2.1
pyasn1==0.5.1
pyasn1-modules==0.3.0
pybind11==2.11.1
pybtex==0.24.0
pybtex-docutils==1.0.3
pycodestyle==2.9.1
pycparser==2.21
pydantic==1.10.13
pyDeprecate==0.3.1
pydub==0.25.1
pyflakes==2.5.0
Pygments==2.17.2
pyparsing==3.1.1
pypinyin==0.50.0
PySocks==1.7.1
pystoi==0.4.1
pytest==7.4.4
pytest-runner==6.0.1
python-dateutil==2.8.2
pytorch-lightning==1.5.10
pytz==2023.3.post1
pyupgrade==3.3.2
PyYAML==5.4.1
rapidfuzz==3.4.0
regex==2023.12.25
requests==2.31.0
requests-oauthlib==1.3.1
rich==13.7.0
rsa==4.9
ruamel.yaml==0.18.5
ruamel.yaml.clib==0.2.8
s3transfer==0.8.2
sacrebleu==2.4.0
sacremoses==0.0.53
safetensors==0.4.1
scikit-learn==1.0.2
scipy==1.7.3
sentencepiece==0.1.99
sentry-sdk==1.39.2
setproctitle==1.3.3
shellingham==1.5.4
six==1.16.0
smmap==5.0.1
snowballstemmer==2.2.0
sortedcontainers==2.4.0
soundfile==0.12.1
soupsieve==2.4.1
sox==1.4.1
soxr==0.3.7
Sphinx==5.3.0
sphinxcontrib-applehelp==1.0.2
sphinxcontrib-bibtex==2.6.2
sphinxcontrib-devhelp==1.0.2
sphinxcontrib-htmlhelp==2.0.0
sphinxcontrib-jsmath==1.0.1
sphinxcontrib-qthelp==1.0.3
sphinxcontrib-serializinghtml==1.1.5
sympy==1.10.1
tabulate==0.9.0
tensorboard==2.11.2
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
threadpoolctl==3.1.0
tokenize-rt==5.0.0
tokenizers==0.13.3
toml==0.10.2
tomli==2.0.1
torch==1.12.1+cu116
torch-stft==0.1.4
torchaudio==0.12.1+cu116
torchmetrics==0.11.4
torchvision==0.13.1+cu116
tqdm==4.66.1
traitlets==5.9.0
transformers==4.30.2
typed-ast==1.5.5
typer==0.9.0
typing_extensions==4.7.1
Unidecode==1.3.8
urllib3==1.26.18
wandb==0.16.2
wcwidth==0.2.13
webdataset==0.1.62
Werkzeug==2.2.3
wget==3.2
widgetsnbextension==4.0.9
wrapt==1.16.0
yarg==0.1.9
yarl==1.9.4
youtokentome==1.0.6
zipp==3.15.0
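
These pinned dependencies can be installed with pip install -r requirements.txt; note that the +cu116 builds of torch, torchaudio, and torchvision typically need the matching PyTorch wheel index (e.g. --extra-index-url https://download.pytorch.org/whl/cu116) in addition to PyPI.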
train.py ADDED
@@ -0,0 +1,163 @@
import argparse
import os

import nemo.collections.asr as nemo_asr
import pytorch_lightning as ptl
from nemo.utils import exp_manager, logging
from omegaconf import OmegaConf, open_dict


def train_model(train_manifest: str, val_manifest: str, accelerator: str, batch_size: int, num_epochs: int, model_save_path: str = None) -> None:

    # Load a pretrained STT QuartzNet 15x5 model
    model = nemo_asr.models.ASRModel.from_pretrained("stt_en_quartznet15x5")
    # New vocabulary for the model (Kazakh Cyrillic alphabet)
    new_vocabulary = [
        " ",
        "а",
        "б",
        "в",
        "г",
        "д",
        "е",
        "ж",
        "з",
        "и",
        "й",
        "к",
        "л",
        "м",
        "н",
        "о",
        "п",
        "р",
        "с",
        "т",
        "у",
        "ф",
        "х",
        "ц",
        "ч",
        "ш",
        "щ",
        "ъ",
        "ы",
        "ь",
        "э",
        "ю",
        "я",
        "і",
        "ғ",
        "қ",
        "ң",
        "ү",
        "ұ",
        "һ",
        "ә",
        "ө",
    ]

    with open_dict(model.cfg):
        # Set the labels and sample rate
        model.cfg.labels = new_vocabulary
        model.cfg.sample_rate = 16000

        # Train dataset
        model.cfg.train_ds.manifest_filepath = train_manifest
        model.cfg.train_ds.labels = new_vocabulary
        model.cfg.train_ds.normalize_transcripts = False
        model.cfg.train_ds.batch_size = batch_size
        model.cfg.train_ds.num_workers = 10
        model.cfg.train_ds.pin_memory = True
        model.cfg.train_ds.trim_silence = True

        # Validation dataset
        model.cfg.validation_ds.manifest_filepath = val_manifest
        model.cfg.validation_ds.labels = new_vocabulary
        model.cfg.validation_ds.normalize_transcripts = False
        model.cfg.validation_ds.batch_size = batch_size
        model.cfg.validation_ds.num_workers = 10
        model.cfg.validation_ds.pin_memory = True
        model.cfg.validation_ds.trim_silence = True

        # Optimizer and scheduler
        model.cfg.optim.lr = 0.001
        model.cfg.optim.betas = [0.8, 0.5]
        model.cfg.optim.weight_decay = 0.001
        model.cfg.optim.sched.warmup_steps = 500
        model.cfg.optim.sched.min_lr = 1e-6

    model.change_vocabulary(new_vocabulary=new_vocabulary)
    model.setup_training_data(model.cfg.train_ds)
    model.setup_validation_data(model.cfg.validation_ds)

    # Unfreeze the encoder so its parameters are updated during fine-tuning
    model.encoder.unfreeze()
    logging.info("Model encoder has been un-frozen")

    # Set up data augmentation
    model.spec_augmentation = model.from_config_dict(model.cfg.spec_augment)

    # Metrics: report character error rate and log predictions
    model._wer.use_cer = True
    model._wer.log_prediction = True

    # Trainer (use the accelerator passed on the command line)
    trainer = ptl.Trainer(
        accelerator=accelerator,
        max_epochs=num_epochs,
        accumulate_grad_batches=1,
        enable_checkpointing=False,
        logger=False,
        log_every_n_steps=100,
        check_val_every_n_epoch=1,
        precision=16,
    )

    # Attach the trainer to the model
    model.set_trainer(trainer)

    # Experiment tracking
    LANGUAGE = "kz"
    config = exp_manager.ExpManagerConfig(
        exp_dir=f"experiments/lang-{LANGUAGE}/",
        name=f"ASR-Model-Language-{LANGUAGE}",
        checkpoint_callback_params=exp_manager.CallbackParams(
            monitor="val_wer", mode="min", always_save_nemo=True, save_best_model=True,
        ),
    )
    config = OmegaConf.structured(config)
    exp_manager.exp_manager(trainer, config)

    # Final configuration
    print("-----------------------------------------------------------")
    print("Updated STT Model Configuration:")
    print(OmegaConf.to_yaml(model.cfg))
    print("-----------------------------------------------------------")

    # Fit the model
    trainer.fit(model)

    # Save the model
    if model_save_path:
        model.save_to(model_save_path)
        print(f"Model saved at path: {os.getcwd() + os.path.sep + model_save_path}")


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_manifest", help="Path to the train manifest JSON file.")
    parser.add_argument("--val_manifest", help="Path to the validation manifest JSON file.")
    parser.add_argument("--accelerator", help="Accelerator type to use (cpu, gpu, tpu, etc.).")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size for training.")
    parser.add_argument("--num_epochs", type=int, default=1, help="Number of epochs to train for.")
    parser.add_argument("--model_save_path", default=None, help="Path for saving the trained model.")
    args = parser.parse_args()

    train_model(
        train_manifest=args.train_manifest,
        val_manifest=args.val_manifest,
        accelerator=args.accelerator,
        batch_size=args.batch_size,
        num_epochs=args.num_epochs,
        model_save_path=args.model_save_path,
    )
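
A hypothetical fine-tuning run on the sample manifests from this commit (batch size, epoch count, and output checkpoint name are illustrative, not part of this commit):

python train.py --train_manifest ksc/train_manifest.json --val_manifest ksc/val_manifest.json --accelerator gpu --batch_size 16 --num_epochs 50 --model_save_path asr_model_kz.nemo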
transcribe.py ADDED
@@ -0,0 +1,22 @@
import argparse
from typing import Dict

import nemo.collections.asr as nemo_asr


def predict_model(model_path: str, audio_file_path: str) -> Dict:
    # Restore the ASR model from the provided path
    model = nemo_asr.models.ASRModel.restore_from(restore_path=model_path)
    # Transcribe the given audio file
    text = model.transcribe([audio_file_path])
    result = {"result": text[0]}
    print(result)
    return result


if __name__ == "__main__":
    # Parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", default=None, help="Path to a model to use for transcription.")
    parser.add_argument("--audio_file_path", help="Path to the audio file to transcribe.")
    args = parser.parse_args()

    predict_model(model_path=args.model_path, audio_file_path=args.audio_file_path)
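
A hypothetical transcription call, reusing the illustrative checkpoint name from the training example above and one of the audio files referenced in the test manifest:

python transcribe.py --model_path asr_model_kz.nemo --audio_file_path ksc/test/crowdsourced/5f60e4ecdb745.wav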