import contextlib
import glob
import json
import os
from dataclasses import dataclass, field
from typing import Optional

import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf

from nemo.collections.asr.metrics.rnnt_wer import RNNTDecodingConfig
from nemo.collections.asr.metrics.wer import word_error_rate
from nemo.collections.asr.models import ASRModel
from nemo.core.config import hydra_runner
from nemo.utils import logging, model_utils


"""
# Transcribe audio

# Arguments
# model_path: path to .nemo ASR checkpoint
# pretrained_name: name of pretrained ASR model (from NGC registry)
# audio_dir: path to directory with audio files
# dataset_manifest: path to dataset JSON manifest file (in NeMo format)
#
# ASR model can be specified by either "model_path" or "pretrained_name".
# Data for transcription can be defined with either "audio_dir" or "dataset_manifest".
# Results are returned in a JSON manifest file.

python transcribe_speech.py \
    model_path=null \
    pretrained_name=null \
    audio_dir="" \
    dataset_manifest="" \
    output_filename=""
"""


@dataclass
class TranscriptionConfig:
    # Required configs: the model source and the data source
    model_path: Optional[str] = None
    pretrained_name: Optional[str] = None
    audio_dir: Optional[str] = None
    dataset_manifest: Optional[str] = None

    # General configs
    output_filename: Optional[str] = None
    batch_size: int = 32
    cuda: Optional[bool] = None
    amp: bool = False
    audio_type: str = "wav"

    # Decoding strategy for RNNT models; a default_factory is used so the dataclass
    # default is not a shared mutable instance.
    rnnt_decoding: RNNTDecodingConfig = field(default_factory=RNNTDecodingConfig)


@hydra_runner(config_name="TranscriptionConfig", schema=TranscriptionConfig)
def main(cfg: TranscriptionConfig):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    if cfg.model_path is None and cfg.pretrained_name is None:
        raise ValueError("Both cfg.model_path and cfg.pretrained_name cannot be None!")
    if cfg.audio_dir is None and cfg.dataset_manifest is None:
        raise ValueError("Both cfg.audio_dir and cfg.dataset_manifest cannot be None!")
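
    # Select the compute device: cfg.cuda may be left unset (auto-detect CUDA),
    # set to a bool, or set to an integer GPU index on the command line.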
    if cfg.cuda is None:
        cfg.cuda = torch.cuda.is_available()

    if type(cfg.cuda) == int:
        device_id = int(cfg.cuda)
    else:
        device_id = 0

    device = torch.device(f'cuda:{device_id}' if cfg.cuda else 'cpu')
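
    # Restore the ASR model either from a local .nemo checkpoint (model_path)
    # or by downloading a pretrained checkpoint from NGC by name (pretrained_name).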
    if cfg.model_path is not None:
        model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
        classpath = model_cfg.target
        imported_class = model_utils.import_class_by_path(classpath)
        logging.info(f"Restoring model : {imported_class.__name__}")
        asr_model = imported_class.restore_from(restore_path=cfg.model_path, map_location=device)
        model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
    else:
        asr_model = ASRModel.from_pretrained(model_name=cfg.pretrained_name, map_location=device)
        model_name = cfg.pretrained_name
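
    # Attach a PyTorch Lightning trainer and switch the model to evaluation mode for inference.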
    trainer = pl.Trainer(gpus=int(cfg.cuda))
    asr_model.set_trainer(trainer)
    asr_model = asr_model.eval()
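
    # Apply the RNNT decoding configuration when the loaded model supports
    # switching decoding strategies; other models are left unchanged.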
    if hasattr(asr_model, 'change_decoding_strategy'):
        asr_model.change_decoding_strategy(cfg.rnnt_decoding)
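
    # Collect audio file paths: either every "*.<audio_type>" file under audio_dir,
    # or the 'audio_filepath' entries of the JSON-lines manifest (whose 'text'
    # fields are kept as references for WER).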
    if cfg.audio_dir is not None:
        filepaths = list(glob.glob(os.path.join(cfg.audio_dir, f"*.{cfg.audio_type}")))
    else:
        filepaths = []
        references = []
        with open(cfg.dataset_manifest, 'r', encoding='utf-8') as f:
            for line in f:
                item = json.loads(line)
                filepaths.append(item['audio_filepath'])
                references.append(item['text'])

    logging.info(f"\nTranscribing {len(filepaths)} files...\n")
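
    # Run under mixed-precision autocast when requested and supported;
    # otherwise fall back to a no-op context manager.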
    if cfg.amp and torch.cuda.is_available() and hasattr(torch.cuda, 'amp') and hasattr(torch.cuda.amp, 'autocast'):
        logging.info("AMP enabled!\n")
        autocast = torch.cuda.amp.autocast
    else:

        @contextlib.contextmanager
        def autocast():
            yield
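
    # Transcribe all files in batches with gradients disabled.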
    with autocast():
        with torch.no_grad():
            transcriptions = asr_model.transcribe(filepaths, batch_size=cfg.batch_size)
    logging.info(f"Finished transcribing {len(filepaths)} files !")

    # WER can only be computed when a manifest with reference transcripts was provided;
    # in audio_dir mode there are no references, so the check is skipped.
    if cfg.dataset_manifest is not None:
        wer_value = word_error_rate(hypotheses=transcriptions, references=references, use_cer=False)
        logging.info(f'Got WER of {wer_value}. Tolerance was 1.0')
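
    # Derive a default output manifest name when none was given: "<audio_dir>.json"
    # in directory mode, or "<manifest>_<model_name>.json" in manifest mode.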
    if cfg.output_filename is None:
        if cfg.audio_dir is not None:
            cfg.output_filename = os.path.dirname(os.path.join(cfg.audio_dir, '.')) + '.json'
        else:
            cfg.output_filename = cfg.dataset_manifest.replace('.json', f'_{model_name}.json')

    logging.info(f"Writing transcriptions into file: {cfg.output_filename}")
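
    # Write one JSON object per line; in manifest mode the original manifest entries
    # are copied and augmented with a 'pred_text' field.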
    with open(cfg.output_filename, 'w', encoding='utf-8') as f:
        if cfg.audio_dir is not None:
            for idx, text in enumerate(transcriptions):
                item = {'audio_filepath': filepaths[idx], 'pred_text': text}
                f.write(json.dumps(item) + "\n")
        else:
            with open(cfg.dataset_manifest, 'r', encoding='utf-8') as fr:
                for idx, line in enumerate(fr):
                    item = json.loads(line)
                    item['pred_text'] = transcriptions[idx]
                    f.write(json.dumps(item, ensure_ascii=False) + "\n")

    logging.info("Finished writing predictions !")


if __name__ == '__main__':
    main()