import gc
import subprocess
from functools import lru_cache
from pathlib import Path

import gradio as gr
import omegaconf
import torch
from nemo.collections.asr.models import ASRModel

from dev_utils import DEV_LOGGER


def get_model_cache():
    """Return an LRU-cached loader that keeps up to 3 ASR models in memory."""
    @lru_cache(maxsize=3)
    def get_model(model):
        # model = 'nvidia/' + model
        asr_model = ASRModel.from_pretrained(model_name=model, map_location='cpu')
        # asr_model = ASRModel.restore_from(model, map_location='cpu')
        DEV_LOGGER.info(f'USING MODEL: {asr_model}')

        # Enable alignment/timestamp computation with greedy batched decoding.
        cfg = asr_model.cfg.decoding
        with omegaconf.open_dict(cfg):
            cfg['strategy'] = 'greedy_batch'
            cfg['preserve_alignments'] = True
            cfg['compute_timestamps'] = True
        asr_model.change_decoding_strategy(cfg)

        asr_model.eval()
        return asr_model

    return get_model


def get_offsets_cache():
    """Return an LRU-cached transcription helper that memoizes the last (model, audio) pair."""
    @lru_cache(maxsize=1)
    def get_offsets(model, audio_path):
        # Run transcription in full precision; autocast is explicitly disabled.
        with torch.cuda.amp.autocast(enabled=False, dtype=torch.bfloat16):
            with torch.inference_mode():
                hypotheses = model.transcribe(audio=[audio_path], return_hypotheses=True)

        # Some models return (best_hypotheses, all_hypotheses); keep the best ones.
        if isinstance(hypotheses, tuple) and len(hypotheses) == 2:
            hypotheses = hypotheses[0]

        # Keep only the word/char/segment offsets; drop the raw per-frame timesteps.
        del hypotheses[0].timestamp['timestep']
        offsets = hypotheses[0].timestamp.copy()
        del hypotheses
        return offsets

    return get_offsets


def adjust_decoding(asr_model, timestamps_type, frame_dur):
    """For segment timestamps, fall back to a 1-second gap threshold when the
    tokenizer has no sentence-ending punctuation to split on."""
    if timestamps_type == 'segment':
        supported_punctuation = asr_model.tokenizer.supported_punctuation.intersection({'.', '!', '?'})
        if not supported_punctuation:
            cfg = asr_model.cfg.decoding
            with omegaconf.open_dict(cfg):
                cfg['segment_gap_threshold'] = int(1 // frame_dur)
            asr_model.change_decoding_strategy(cfg)
    return asr_model


def process_audio(input_file):
    """Convert the input file to mono 16 kHz with sox and return the new path."""
    output_file = Path(input_file)
    output_file = output_file.with_stem(output_file.stem + '_processed').as_posix()
    command = ['sox', input_file, output_file, 'channels', '1', 'rate', '16000']
    try:
        subprocess.run(command, check=True)
        return output_file
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        raise gr.Error('Failed to convert audio to mono 16 kHz') from e


def get_aligned_transcription(model_name, audio_path, timestamps_type,
                              get_model_func, get_offsets_func, device,
                              preloaded_model=None):
    """Transcribe the audio and return (text, start_s, end_s) tuples per word or
    segment, converting encoder frame offsets to seconds."""
    timestamps_type = 'segment' if timestamps_type == 'Segments' else 'word'

    model = preloaded_model if preloaded_model is not None else get_model_func(model_name)
    model.to(device)

    # Duration of one encoder output frame in seconds.
    frame_dur = model.cfg.preprocessor.window_stride * model.cfg.encoder.subsampling_factor
    model = adjust_decoding(model, timestamps_type, frame_dur)

    try:
        offsets = get_offsets_func(model, audio_path)
        offsets = offsets[timestamps_type]

        timestamps = []
        for unit in offsets:
            start_s = unit['start_offset'] * frame_dur
            end_s = unit['end_offset'] * frame_dur
            text = unit[timestamps_type]
            if timestamps_type == 'segment' and model.tokenizer.supports_capitalization:
                text = text.capitalize()
            timestamps.append((text, start_s, end_s))
    except torch.cuda.OutOfMemoryError:
        raise gr.Error('CUDA out of memory. Please try a shorter audio file.')
    finally:
        # Always move the model off the GPU and release cached memory.
        model.cpu()
        del model
        gc.collect()
        torch.cuda.empty_cache()

    return timestamps
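

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original app wiring): the model id and
    # the audio path below are illustrative placeholders; swap in your own values.
    get_model = get_model_cache()
    get_offsets = get_offsets_cache()

    audio = process_audio('example.wav')  # hypothetical input file
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    segments = get_aligned_transcription(
        model_name='nvidia/parakeet-tdt-1.1b',  # example NeMo model id
        audio_path=audio,
        timestamps_type='Segments',
        get_model_func=get_model,
        get_offsets_func=get_offsets,
        device=device,
    )
    for text, start_s, end_s in segments:
        print(f'[{start_s:7.2f}s - {end_s:7.2f}s] {text}')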