transiteration committed on
Commit aeeafe2
Parent: 8ea3efa

Delete transcribe_speech.py

Files changed (1)
  1. transcribe_speech.py +0 -173
transcribe_speech.py DELETED
@@ -1,173 +0,0 @@
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import contextlib
- import glob
- import json
- import os
- from dataclasses import dataclass
- from typing import Optional
-
- import pytorch_lightning as pl
- import torch
- from omegaconf import OmegaConf
-
- from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
- from nemo.collections.asr.metrics.wer import word_error_rate
- from nemo.collections.asr.models import ASRModel
- from nemo.core.config import hydra_runner
- from nemo.utils import logging, model_utils
-
-
- """
- # Transcribe audio
- # Arguments
- #   model_path: path to .nemo ASR checkpoint
- #   pretrained_name: name of pretrained ASR model (from NGC registry)
- #   audio_dir: path to directory with audio files
- #   dataset_manifest: path to dataset JSON manifest file (in NeMo format)
- #
- # ASR model can be specified by either "model_path" or "pretrained_name".
- # Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
- # Results are returned in a JSON manifest file.
-
- python transcribe_speech.py \
-     model_path=null \
-     pretrained_name=null \
-     audio_dir="" \
-     dataset_manifest="" \
-     output_filename=""
- """
-
-
- @dataclass
- class TranscriptionConfig:
-     # Required configs
-     model_path: Optional[str] = None  # Path to a .nemo file
-     pretrained_name: Optional[str] = None  # Name of a pretrained model
-     audio_dir: Optional[str] = None  # Path to a directory which contains audio files
-     dataset_manifest: Optional[str] = None  # Path to dataset's JSON manifest
-
-     # General configs
-     output_filename: Optional[str] = None
-     batch_size: int = 32
-     cuda: Optional[bool] = None  # will switch to cuda if available, defaults to CPU otherwise
-     amp: bool = False
-     audio_type: str = "wav"
-
-     # decoding strategy for RNNT models
-     rnnt_decoding: RNNTDecodingConfig = RNNTDecodingConfig()
-
-
- @hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
- def main(cfg: TranscriptionConfig):
-     logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')
-
-     if cfg.model_path is None and cfg.pretrained_name is None:
-         raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
-     if cfg.audio_dir is None and cfg.dataset_manifest is None:
-         raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
-
-     # setup GPU
-     if cfg.cuda is None:
-         cfg.cuda = torch.cuda.is_available()
-
-     if type(cfg.cuda) == int:
-         device_id = int(cfg.cuda)
-     else:
-         device_id = 0
-
-     device = torch.device(f'cuda:{device_id}' if cfg.cuda else 'cpu')
-
-     # setup model
-     if cfg.model_path is not None:
-         # restore model from .nemo file path
-         model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
-         classpath = model_cfg.target  # original class path
-         imported_class = model_utils.import_class_by_path(classpath)  # type: ASRModel
-         logging.info(f"Restoring model : {imported_class.__name__}")
-         asr_model = imported_class.restore_from(restore_path=cfg.model_path, map_location=device)  # type: ASRModel
-         model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
-     else:
-         # restore model by name
-         asr_model = ASRModel.from_pretrained(model_name=cfg.pretrained_name, map_location=device)  # type: ASRModel
-         model_name = cfg.pretrained_name
-
-     trainer = pl.Trainer(gpus=int(cfg.cuda))
-     asr_model.set_trainer(trainer)
-     asr_model = asr_model.eval()
-
-     # Setup decoding strategy
-     if hasattr(asr_model, 'change_decoding_strategy'):
-         asr_model.change_decoding_strategy(cfg.rnnt_decoding)
-
-     # get audio filenames
-     if cfg.audio_dir is not None:
-         filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
-     else:
-         # get filenames from manifest
-         filepaths = []
-         references = []
-         with open(cfg.dataset_manifest, 'r', encoding='utf-8') as f:
-             for line in f:
-                 item = json.loads(line)
-                 filepaths.append(item['audio_filepath'])
-                 references.append(item['text'])
-     logging.info(f"\nTranscribing {len(filepaths)} files...\n")
-
-     # setup AMP (optional)
-     if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
-         logging.info("AMP enabled!\n")
-         autocast = torch.cuda.amp.autocast
-     else:
-
-         @contextlib.contextmanager
-         def autocast():
-             yield
-
-     # transcribe audio
-     with autocast():
-         with torch.no_grad():
-             transcriptions = asr_model.transcribe(filepaths, batch_size=cfg.batch_size)
-     logging.info(f"Finished transcribing {len(filepaths)} files !")
-
-     wer_value = word_error_rate(hypotheses=transcriptions, references=references, use_cer=False)
-     logging.info(f'Got WER of {wer_value}. Tolerance was 1.0')
-
-     if cfg.output_filename is None:
-         # create default output filename
-         if cfg.audio_dir is not None:
-             cfg.output_filename = os.path.dirname(os.path.join(cfg.audio_dir, '.')) + '.json'
-         else:
-             cfg.output_filename = cfg.dataset_manifest.replace('.json', f'_{model_name}.json')
-
-     logging.info(f"Writing transcriptions into file: {cfg.output_filename}")
-
-     with open(cfg.output_filename, 'w', encoding='utf-8') as f:
-         if cfg.audio_dir is not None:
-             for idx, text in enumerate(transcriptions):
-                 item = {'audio_filepath': filepaths[idx], 'pred_text': text}
-                 f.write(json.dumps(item) + "\n")
-         else:
-             with open(cfg.dataset_manifest, 'r', encoding='utf-8') as fr:
-                 for idx, line in enumerate(fr):
-                     item = json.loads(line)
-                     item['pred_text'] = transcriptions[idx]
-                     f.write(json.dumps(item, ensure_ascii=False) + "\n")
-
-     logging.info("Finished writing predictions !")
-
-
- if __name__ == '__main__':
-     main()  # noqa pylint: disable=no-value-for-parameter
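
For context, the deleted file was a thin Hydra wrapper around NeMo's ASRModel API. The sketch below reproduces its core flow using only the calls visible in the diff (from_pretrained, eval, transcribe, and manifest-style JSON output); the model name, audio directory, and output path are placeholder assumptions, not values from this commit.

# Minimal sketch of the deleted script's core flow. The model name and
# local paths are assumed placeholders, not values from the commit.
import glob
import json
import os

import torch

from nemo.collections.asr.models import ASRModel

# Load a pretrained ASR model by NGC name ("stt_en_conformer_ctc_small" is
# an assumed example; any name accepted by from_pretrained works).
asr_model = ASRModel.from_pretrained(model_name="stt_en_conformer_ctc_small")
asr_model = asr_model.eval()

# Gather audio files the way the script's audio_dir branch did.
filepaths = sorted(glob.glob(os.path.join("my_audio_dir", "*.wav")))  # assumed directory

# Transcribe without gradients, mirroring the script's torch.no_grad() block.
with torch.no_grad():
    transcriptions = asr_model.transcribe(filepaths, batch_size=32)

# Emit one JSON line per file, matching the manifest layout the script wrote.
with open("transcriptions.json", "w", encoding="utf-8") as f:
    for path, text in zip(filepaths, transcriptions):
        f.write(json.dumps({"audio_filepath": path, "pred_text": text}) + "\n")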