transiteration commited on
Commit
5072015
1 Parent(s): ca961de

Upload transcribe_speech.py

Browse files
Files changed (1) hide show
  1. transcribe_speech.py +173 -0
transcribe_speech.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import contextlib
16
+ import glob
17
+ import json
18
+ import os
19
+ from dataclasses import dataclass
20
+ from typing import Optional
21
+
22
+ import pytorch_lightning as pl
23
+ import torch
24
+ from omegaconf import OmegaConf
25
+
26
+ from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
27
+ from nemo.collections.asr.metrics.wer import word_error_rate
28
+ from nemo.collections.asr.models import ASRModel
29
+ from nemo.core.config import hydra_runner
30
+ from nemo.utils import logging, model_utils
31
+
32
+
33
+ """
34
+ # Transcribe audio
35
+ # Arguments
36
+ # model_path: path to .nemo ASR checkpoint
37
+ # pretrained_name: name of pretrained ASR model (from NGC registry)
38
+ # audio_dir: path to directory with audio files
39
+ # dataset_manifest: path to dataset JSON manifest file (in NeMo format)
40
+ #
41
+ # ASR model can be specified by either "model_path" or "pretrained_name".
42
+ # Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
43
+ # Results are returned in a JSON manifest file.
44
+
45
+ python transcribe_speech.py \
46
+ model_path=null \
47
+ pretrained_name=null \
48
+ audio_dir="" \
49
+ dataset_manifest="" \
50
+ output_filename=""
51
+ """
52
+
53
+
54
+ @dataclass
55
+ class TranscriptionConfig:
56
+ # Required configs
57
+ model_path: Optional[str] = None # Path to a .nemo file
58
+ pretrained_name: Optional[str] = None # Name of a pretrained model
59
+ audio_dir: Optional[str] = None # Path to a directory which contains audio files
60
+ dataset_manifest: Optional[str] = None # Path to dataset's JSON manifest
61
+
62
+ # General configs
63
+ output_filename: Optional[str] = None
64
+ batch_size: int = 32
65
+ cuda: Optional[bool] = None # will switch to cuda if available, defaults to CPU otherwise
66
+ amp: bool = False
67
+ audio_type: str = "wav"
68
+
69
+ # decoding strategy for RNNT models
70
+ rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig()
71
+
72
+
73
+ @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
74
+ def main(cfg: TranscriptionConfig):
75
+ logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
76
+
77
+ if cfg.model_path is None and cfg.pretrained_name is None:
78
+ raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
79
+ if cfg.audio_dir is None and cfg.dataset_manifest is None:
80
+ raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
81
+
82
+ # setup GPU
83
+ if cfg.cuda is None:
84
+ cfg.cuda = torch.cuda.is_available()
85
+
86
+ if type(cfg.cuda) == int:
87
+ device_id = int(cfg.cuda)
88
+ else:
89
+ device_id = 0
90
+
91
+ device = torch.device(f'cuda:{device_id}' if cfg.cuda else 'cpu')
92
+
93
+ # setup model
94
+ if cfg.model_path is not None:
95
+ # restore model from .nemo file path
96
+ model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
97
+ classpath = model_cfg.target # original class path
98
+ imported_class = model_utils.import_class_by_path(classpath) # type: ASRModel
99
+ logging.info(f"Restoring model : {imported_class.__name__}")
100
+ asr_model = imported_class.restore_from(restore_path=cfg.model_path, map_location=device) # type: ASRModel
101
+ model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
102
+ else:
103
+ # restore model by name
104
+ asr_model = ASRModel.from_pretrained(model_name=cfg.pretrained_name, map_location=device) # type: ASRModel
105
+ model_name = cfg.pretrained_name
106
+
107
+ trainer = pl.Trainer(gpus=int(cfg.cuda))
108
+ asr_model.set_trainer(trainer)
109
+ asr_model = asr_model.eval()
110
+
111
+ # Setup decoding strategy
112
+ if hasattr(asr_model, 'change_decoding_strategy'):
113
+ asr_model.change_decoding_strategy(cfg.rnnt_decoding)
114
+
115
+ # get audio filenames
116
+ if cfg.audio_dir is not None:
117
+ filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
118
+ else:
119
+ # get filenames from manifest
120
+ filepaths = []
121
+ references = []
122
+ with open(cfg.dataset_manifest, 'r', encoding='utf-8') as f:
123
+ for line in f:
124
+ item = json.loads(line)
125
+ filepaths.append(item['audio_filepath'])
126
+ references.append(item['text'])
127
+ logging.info(f"\nTranscribing {len(filepaths)} files...\n")
128
+
129
+ # setup AMP (optional)
130
+ if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
131
+ logging.info("AMP enabled!\n")
132
+ autocast = torch.cuda.amp.autocast
133
+ else:
134
+
135
+ @contextlib.contextmanager
136
+ def autocast():
137
+ yield
138
+
139
+ # transcribe audio
140
+ with autocast():
141
+ with torch.no_grad():
142
+ transcriptions = asr_model.transcribe(filepaths, batch_size=cfg.batch_size)
143
+ logging.info(f"Finished transcribing {len(filepaths)} files !")
144
+
145
+ wer_value = word_error_rate(hypotheses=transcriptions, references=references, use_cer=False)
146
+ logging.info(f'Got WER of {wer_value}. Tolerance was 1.0')
147
+
148
+ if cfg.output_filename is None:
149
+ # create default output filename
150
+ if cfg.audio_dir is not None:
151
+ cfg.output_filename = os.path.dirname(os.path.join(cfg.audio_dir, '.')) + '.json'
152
+ else:
153
+ cfg.output_filename = cfg.dataset_manifest.replace('.json', f'_{model_name}.json')
154
+
155
+ logging.info(f"Writing transcriptions into file: {cfg.output_filename}")
156
+
157
+ with open(cfg.output_filename, 'w', encoding='utf-8') as f:
158
+ if cfg.audio_dir is not None:
159
+ for idx, text in enumerate(transcriptions):
160
+ item = {'audio_filepath': filepaths[idx], 'pred_text': text}
161
+ f.write(json.dumps(item) + "\n")
162
+ else:
163
+ with open(cfg.dataset_manifest, 'r', encoding='utf-8') as fr:
164
+ for idx, line in enumerate(fr):
165
+ item = json.loads(line)
166
+ item['pred_text'] = transcriptions[idx]
167
+ f.write(json.dumps(item, ensure_ascii=False) + "\n")
168
+
169
+ logging.info("Finished writing predictions !")
170
+
171
+
172
+ if __name__ == '__main__':
173
+ main() # noqa pylint: disable=no-value-for-parameter