import argparse
import logging
import os
import re

import torch
import torchaudio
from ctc_forced_aligner import (
    generate_emissions,
    get_alignments,
    get_spans,
    load_alignment_model,
    postprocess_results,
    preprocess_text,
)
from deepmultilingualpunctuation import PunctuationModel
from nemo.collections.asr.models.msdd_models import NeuralDiarizer

from helpers import (
    cleanup,
    create_config,
    get_realigned_ws_mapping_with_punctuation,
    get_sentences_speaker_mapping,
    get_speaker_aware_transcript,
    get_words_speaker_mapping,
    langs_to_iso,
    punct_model_langs,
    whisper_langs,
    write_srt,
)
from transcription_helpers import transcribe_batched

# Per-device compute type passed to the transcription backend
mtypes = {"cpu": "int8", "cuda": "float16"}

# Initialize parser
parser = argparse.ArgumentParser()
parser.add_argument(
    "-a", "--audio", help="name of the target audio file", required=True
)
parser.add_argument(
    "--no-stem",
    action="store_false",
    dest="stemming",
    default=True,
    help="Disables source separation. "
    "This helps with long files that don't contain a lot of music.",
)
parser.add_argument(
    "--suppress_numerals",
    action="store_true",
    dest="suppress_numerals",
    default=False,
    help="Suppresses numerical digits. "
    "This helps the diarization accuracy but converts all digits into written text.",
)
parser.add_argument(
    "--whisper-model",
    dest="model_name",
    default="medium.en",
    help="name of the Whisper model to use",
)
parser.add_argument(
    "--batch-size",
    type=int,
    dest="batch_size",
    default=8,
    help="Batch size for batched inference, reduce if you run out of memory, "
    "set to 0 for non-batched inference",
)
parser.add_argument(
    "--language",
    type=str,
    default=None,
    choices=whisper_langs,
    help="Language spoken in the audio, specify None to perform language detection",
)
parser.add_argument(
    "--device",
    dest="device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="if you have a GPU use 'cuda', otherwise 'cpu'",
)
args = parser.parse_args()

if args.stemming:
    # Isolate vocals from the rest of the audio
    return_code = os.system(
        f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"'
    )

    if return_code != 0:
        logging.warning(
            "Source splitting failed, using original audio file. "
            "Use --no-stem argument to disable it."
        )
        vocal_target = args.audio
    else:
        vocal_target = os.path.join(
            "temp_outputs",
            "htdemucs",
            os.path.splitext(os.path.basename(args.audio))[0],
            "vocals.wav",
        )
else:
    vocal_target = args.audio

# Transcribe the audio file
whisper_results, language, audio_waveform = transcribe_batched(
    vocal_target,
    args.language,
    args.batch_size,
    args.model_name,
    mtypes[args.device],
    args.suppress_numerals,
    args.device,
)

# Forced alignment of the transcript to the audio
alignment_model, alignment_tokenizer, alignment_dictionary = load_alignment_model(
    args.device,
    dtype=torch.float16 if args.device == "cuda" else torch.float32,
)

audio_waveform = (
    torch.from_numpy(audio_waveform)
    .to(alignment_model.dtype)
    .to(alignment_model.device)
)
emissions, stride = generate_emissions(
    alignment_model, audio_waveform, batch_size=args.batch_size
)

# Release the alignment model to free GPU memory
del alignment_model
torch.cuda.empty_cache()

full_transcript = "".join(segment["text"] for segment in whisper_results)

tokens_starred, text_starred = preprocess_text(
    full_transcript,
    romanize=True,
    language=langs_to_iso[language],
)

segments, scores, blank_id = get_alignments(
    emissions,
    tokens_starred,
    alignment_dictionary,
)

spans = get_spans(tokens_starred, segments, alignment_tokenizer.decode(blank_id))

word_timestamps = postprocess_results(text_starred, spans, stride, scores)


# Convert audio to mono for NeMo compatibility
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, "temp_outputs")
os.makedirs(temp_path, exist_ok=True)
torchaudio.save(
    os.path.join(temp_path, "mono_file.wav"),
    audio_waveform.cpu().unsqueeze(0).float(),
    16000,
    channels_first=True,
)

# Initialize NeMo MSDD diarization model
msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(args.device)
msdd_model.diarize()

del msdd_model
torch.cuda.empty_cache()

# Read the timestamp <-> speaker label mapping from the RTTM output
speaker_ts = []
with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
    lines = f.readlines()
    for line in lines:
        line_list = line.split(" ")
        s = int(float(line_list[5]) * 1000)
        e = s + int(float(line_list[8]) * 1000)
        speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])

wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")

if language in punct_model_langs:
    # Restore punctuation in the transcript to help realign the sentences
    punct_model = PunctuationModel(model="kredor/punctuate-all")

    words_list = list(map(lambda x: x["word"], wsm))

    labeled_words = punct_model.predict(words_list, chunk_size=230)

    ending_puncts = ".?!"
    model_puncts = ".,;:!?"

    # We don't want to punctuate U.S.A. with a period. Right?
    is_acronym = lambda x: re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)

    for word_dict, labeled_tuple in zip(wsm, labeled_words):
        word = word_dict["word"]
        if (
            word
            and labeled_tuple[1] in ending_puncts
            and (word[-1] not in model_puncts or is_acronym(word))
        ):
            word += labeled_tuple[1]
            if word.endswith(".."):
                word = word.rstrip(".")
            word_dict["word"] = word

else:
    logging.warning(
        f"Punctuation restoration is not available for {language} language. "
        "Using the original punctuation."
    )

wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)

with open(f"{os.path.splitext(args.audio)[0]}.txt", "w", encoding="utf-8-sig") as f:
    get_speaker_aware_transcript(ssm, f)

with open(f"{os.path.splitext(args.audio)[0]}.srt", "w", encoding="utf-8-sig") as srt:
    write_srt(ssm, srt)

cleanup(temp_path)
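
# Example invocations (a sketch, not part of the script's logic): the file name
# "diarize.py" and the audio paths below are assumptions; only the flags defined
# in the argparse section above are used.
#   python diarize.py -a interview.wav --whisper-model medium.en --batch-size 8
#   python diarize.py -a podcast.mp3 --no-stem --language en --device cpu
# Outputs are written next to the input audio as <audio>.txt and <audio>.srt.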