import json
import logging
import os
import shutil

import nltk
import wget
from omegaconf import OmegaConf

from whisperx.alignment import DEFAULT_ALIGN_MODELS_HF, DEFAULT_ALIGN_MODELS_TORCH
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE

punct_model_langs = [
    "en", "fr", "de", "es", "it", "nl",
    "pt", "bg", "pl", "cs", "sk", "sl",
]

wav2vec2_langs = list(DEFAULT_ALIGN_MODELS_TORCH.keys()) + list(
    DEFAULT_ALIGN_MODELS_HF.keys()
)

whisper_langs = sorted(LANGUAGES.keys()) + sorted(
    [k.title() for k in TO_LANGUAGE_CODE.keys()]
)

langs_to_iso = {
    "aa": "aar", "ab": "abk", "ae": "ave", "af": "afr", "ak": "aka", "am": "amh",
    "an": "arg", "ar": "ara", "as": "asm", "av": "ava", "ay": "aym", "az": "aze",
    "ba": "bak", "be": "bel", "bg": "bul", "bh": "bih", "bi": "bis", "bm": "bam",
    "bn": "ben", "bo": "tib", "br": "bre", "bs": "bos", "ca": "cat", "ce": "che",
    "ch": "cha", "co": "cos", "cr": "cre", "cs": "cze", "cu": "chu", "cv": "chv",
    "cy": "wel", "da": "dan", "de": "ger", "dv": "div", "dz": "dzo", "ee": "ewe",
    "el": "gre", "en": "eng", "eo": "epo", "es": "spa", "et": "est", "eu": "baq",
    "fa": "per", "ff": "ful", "fi": "fin", "fj": "fij", "fo": "fao", "fr": "fre",
    "fy": "fry", "ga": "gle", "gd": "gla", "gl": "glg", "gn": "grn", "gu": "guj",
    "gv": "glv", "ha": "hau", "he": "heb", "hi": "hin", "ho": "hmo", "hr": "hrv",
    "ht": "hat", "hu": "hun", "hy": "arm", "hz": "her", "ia": "ina", "id": "ind",
    "ie": "ile", "ig": "ibo", "ii": "iii", "ik": "ipk", "io": "ido", "is": "ice",
    "it": "ita", "iu": "iku", "ja": "jpn", "jv": "jav", "ka": "geo", "kg": "kon",
    "ki": "kik", "kj": "kua", "kk": "kaz", "kl": "kal", "km": "khm", "kn": "kan",
    "ko": "kor", "kr": "kau", "ks": "kas", "ku": "kur", "kv": "kom", "kw": "cor",
    "ky": "kir", "la": "lat", "lb": "ltz", "lg": "lug", "li": "lim", "ln": "lin",
    "lo": "lao", "lt": "lit", "lu": "lub", "lv": "lav", "mg": "mlg", "mh": "mah",
    "mi": "mao", "mk": "mac", "ml": "mal", "mn": "mon", "mr": "mar", "ms": "may",
    "mt": "mlt", "my": "bur", "na": "nau", "nb": "nob", "nd": "nde", "ne": "nep",
    "ng": "ndo", "nl": "dut", "nn": "nno", "no": "nor", "nr": "nbl", "nv": "nav",
    "ny": "nya", "oc": "oci", "oj": "oji", "om": "orm", "or": "ori", "os": "oss",
    "pa": "pan", "pi": "pli", "pl": "pol", "ps": "pus", "pt": "por", "qu": "que",
    "rm": "roh", "rn": "run", "ro": "rum", "ru": "rus", "rw": "kin", "sa": "san",
    "sc": "srd", "sd": "snd", "se": "sme", "sg": "sag", "si": "sin", "sk": "slo",
    "sl": "slv", "sm": "smo", "sn": "sna", "so": "som", "sq": "alb", "sr": "srp",
    "ss": "ssw", "st": "sot", "su": "sun", "sv": "swe", "sw": "swa", "ta": "tam",
    "te": "tel", "tg": "tgk", "th": "tha", "ti": "tir", "tk": "tuk", "tl": "tgl",
    "tn": "tsn", "to": "ton", "tr": "tur", "ts": "tso", "tt": "tat", "tw": "twi",
    "ty": "tah", "ug": "uig", "uk": "ukr", "ur": "urd", "uz": "uzb", "ve": "ven",
    "vi": "vie", "vo": "vol", "wa": "wln", "wo": "wol", "xh": "xho", "yi": "yid",
    "yo": "yor", "za": "zha", "zh": "chi", "zu": "zul",
}


def create_config(output_dir):
    DOMAIN_TYPE = "telephonic"  # Can be "meeting", "telephonic", or "general", based on the domain of the audio file
    CONFIG_LOCAL_DIRECTORY = "nemo_msdd_configs"
    CONFIG_FILE_NAME = f"diar_infer_{DOMAIN_TYPE}.yaml"
    MODEL_CONFIG_PATH = os.path.join(CONFIG_LOCAL_DIRECTORY, CONFIG_FILE_NAME)
    if not os.path.exists(MODEL_CONFIG_PATH):
        os.makedirs(CONFIG_LOCAL_DIRECTORY, exist_ok=True)
        CONFIG_URL = f"https://raw.githubusercontent.com/NVIDIA/NeMo/main/examples/speaker_tasks/diarization/conf/inference/{CONFIG_FILE_NAME}"
        MODEL_CONFIG_PATH = wget.download(CONFIG_URL, MODEL_CONFIG_PATH)
    config = OmegaConf.load(MODEL_CONFIG_PATH)

    data_dir = os.path.join(output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)
    meta = {
        "audio_filepath": os.path.join(output_dir, "mono_file.wav"),
        "offset": 0,
        "duration": None,
        "label": "infer",
        "text": "-",
        "rttm_filepath": None,
        "uem_filepath": None,
    }
    with open(os.path.join(data_dir, "input_manifest.json"), "w") as fp:
        json.dump(meta, fp)
        fp.write("\n")

    pretrained_vad = "vad_multilingual_marblenet"
    pretrained_speaker_model = "titanet_large"
    config.num_workers = 0
    config.diarizer.manifest_filepath = os.path.join(data_dir, "input_manifest.json")
    config.diarizer.out_dir = (
        output_dir  # Directory to store intermediate files and prediction outputs
    )
    config.diarizer.speaker_embeddings.model_path = pretrained_speaker_model
    config.diarizer.oracle_vad = (
        False  # do not use oracle VAD; compute it with the VAD model set below
    )
    config.diarizer.clustering.parameters.oracle_num_speakers = False

    # Use the pretrained NeMo VAD model
    config.diarizer.vad.model_path = pretrained_vad
    config.diarizer.vad.parameters.onset = 0.8
    config.diarizer.vad.parameters.offset = 0.6
    config.diarizer.vad.parameters.pad_offset = -0.05

    config.diarizer.msdd_model.model_path = (
        "diar_msdd_telephonic"  # Telephonic speaker diarization model
    )

    return config


def get_word_ts_anchor(s, e, option="start"):
    if option == "end":
        return e
    elif option == "mid":
        return (s + e) / 2
    return s


def get_words_speaker_mapping(wrd_ts, spk_ts, word_anchor_option="start"):
    s, e, sp = spk_ts[0]
    wrd_pos, turn_idx = 0, 0
    wrd_spk_mapping = []
    for wrd_dict in wrd_ts:
        ws, we, wrd = (
            int(wrd_dict["start"] * 1000),
            int(wrd_dict["end"] * 1000),
            wrd_dict["text"],
        )
        wrd_pos = get_word_ts_anchor(ws, we, word_anchor_option)
        while wrd_pos > float(e):
            turn_idx += 1
            turn_idx = min(turn_idx, len(spk_ts) - 1)
            s, e, sp = spk_ts[turn_idx]
            if turn_idx == len(spk_ts) - 1:
                e = get_word_ts_anchor(ws, we, option="end")
        wrd_spk_mapping.append(
            {"word": wrd, "start_time": ws, "end_time": we, "speaker": sp}
        )
    return wrd_spk_mapping


sentence_ending_punctuations = ".?!"
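# Illustrative example of the data shapes these helpers pass around (values are
# made up; word timestamps arrive in seconds, speaker turns in milliseconds):
#
#   word_ts = [{"text": "Hello", "start": 0.52, "end": 0.9}, ...]
#   spk_ts = [[0, 4100, 0], [4100, 9800, 1], ...]  # [start_ms, end_ms, speaker_id]
#   wsm = get_words_speaker_mapping(word_ts, spk_ts, "start")
#   # -> [{"word": "Hello", "start_time": 520, "end_time": 900, "speaker": 0}, ...]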
def get_first_word_idx_of_sentence(word_idx, word_list, speaker_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    left_idx = word_idx
    while (
        left_idx > 0
        and word_idx - left_idx < max_words
        and speaker_list[left_idx - 1] == speaker_list[left_idx]
        and not is_word_sentence_end(left_idx - 1)
    ):
        left_idx -= 1

    return left_idx if left_idx == 0 or is_word_sentence_end(left_idx - 1) else -1


def get_last_word_idx_of_sentence(word_idx, word_list, max_words):
    is_word_sentence_end = (
        lambda x: x >= 0 and word_list[x][-1] in sentence_ending_punctuations
    )
    right_idx = word_idx
    while (
        right_idx < len(word_list) - 1
        and right_idx - word_idx < max_words
        and not is_word_sentence_end(right_idx)
    ):
        right_idx += 1

    return (
        right_idx
        if right_idx == len(word_list) - 1 or is_word_sentence_end(right_idx)
        else -1
    )


def get_realigned_ws_mapping_with_punctuation(
    word_speaker_mapping, max_words_in_sentence=50
):
    is_word_sentence_end = (
        lambda x: x >= 0
        and word_speaker_mapping[x]["word"][-1] in sentence_ending_punctuations
    )
    wsp_len = len(word_speaker_mapping)

    words_list, speaker_list = [], []
    for k, line_dict in enumerate(word_speaker_mapping):
        word, speaker = line_dict["word"], line_dict["speaker"]
        words_list.append(word)
        speaker_list.append(speaker)

    k = 0
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k]
        if (
            k < wsp_len - 1
            and speaker_list[k] != speaker_list[k + 1]
            and not is_word_sentence_end(k)
        ):
            left_idx = get_first_word_idx_of_sentence(
                k, words_list, speaker_list, max_words_in_sentence
            )
            right_idx = (
                get_last_word_idx_of_sentence(
                    k, words_list, max_words_in_sentence - k + left_idx - 1
                )
                if left_idx > -1
                else -1
            )
            if min(left_idx, right_idx) == -1:
                k += 1
                continue

            spk_labels = speaker_list[left_idx : right_idx + 1]
            mod_speaker = max(set(spk_labels), key=spk_labels.count)
            if spk_labels.count(mod_speaker) < len(spk_labels) // 2:
                k += 1
                continue

            speaker_list[left_idx : right_idx + 1] = [mod_speaker] * (
                right_idx - left_idx + 1
            )
            k = right_idx

        k += 1

    k, realigned_list = 0, []
    while k < len(word_speaker_mapping):
        line_dict = word_speaker_mapping[k].copy()
        line_dict["speaker"] = speaker_list[k]
        realigned_list.append(line_dict)
        k += 1

    return realigned_list


def get_sentences_speaker_mapping(word_speaker_mapping, spk_ts):
    sentence_checker = nltk.tokenize.PunktSentenceTokenizer().text_contains_sentbreak
    s, e, spk = spk_ts[0]
    prev_spk = spk

    snts = []
    snt = {"speaker": f"Speaker {spk}", "start_time": s, "end_time": e, "text": ""}

    for wrd_dict in word_speaker_mapping:
        wrd, spk = wrd_dict["word"], wrd_dict["speaker"]
        s, e = wrd_dict["start_time"], wrd_dict["end_time"]
        if spk != prev_spk or sentence_checker(snt["text"] + " " + wrd):
            snts.append(snt)
            snt = {
                "speaker": f"Speaker {spk}",
                "start_time": s,
                "end_time": e,
                "text": "",
            }
        else:
            snt["end_time"] = e
        snt["text"] += wrd + " "
        prev_spk = spk

    snts.append(snt)
    return snts


def get_speaker_aware_transcript(sentences_speaker_mapping, f):
    previous_speaker = sentences_speaker_mapping[0]["speaker"]
    f.write(f"{previous_speaker}: ")

    for sentence_dict in sentences_speaker_mapping:
        speaker = sentence_dict["speaker"]
        sentence = sentence_dict["text"]

        # If this speaker doesn't match the previous one, start a new paragraph
        if speaker != previous_speaker:
            f.write(f"\n\n{speaker}: ")
            previous_speaker = speaker

        # No matter what, write the current sentence
        f.write(sentence + " ")
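# Illustrative shape of the sentence-level mapping produced above and consumed by
# the writers below (times in milliseconds; values are made up):
#
#   ssm = get_sentences_speaker_mapping(wsm, spk_ts)
#   # -> [{"speaker": "Speaker 0", "start_time": 520, "end_time": 4100,
#   #      "text": "Hello there. How are you? "}, ...]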
def format_timestamp(
    milliseconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
    assert milliseconds >= 0, "non-negative timestamp expected"

    hours = milliseconds // 3_600_000
    milliseconds -= hours * 3_600_000

    minutes = milliseconds // 60_000
    milliseconds -= minutes * 60_000

    seconds = milliseconds // 1_000
    milliseconds -= seconds * 1_000

    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )


def write_srt(transcript, file):
    """
    Write a transcript to a file in SRT format.
    """
    for i, segment in enumerate(transcript, start=1):
        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start_time'], always_include_hours=True, decimal_marker=',')} --> "
            f"{format_timestamp(segment['end_time'], always_include_hours=True, decimal_marker=',')}\n"
            f"{segment['speaker']}: {segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )


def find_numeral_symbol_tokens(tokenizer):
    numeral_symbol_tokens = [
        -1,
    ]
    for token, token_id in tokenizer.get_vocab().items():
        has_numeral_symbol = any(c in "0123456789%$£" for c in token)
        if has_numeral_symbol:
            numeral_symbol_tokens.append(token_id)
    return numeral_symbol_tokens


def _get_next_start_timestamp(word_timestamps, current_word_index, final_timestamp):
    # if current word is the last word
    if current_word_index == len(word_timestamps) - 1:
        return word_timestamps[current_word_index]["start"]

    next_word_index = current_word_index + 1
    while current_word_index < len(word_timestamps) - 1:
        if word_timestamps[next_word_index].get("start") is None:
            # if next word doesn't have a start timestamp
            # merge it with the current word and delete it
            word_timestamps[current_word_index]["word"] += (
                " " + word_timestamps[next_word_index]["word"]
            )

            word_timestamps[next_word_index]["word"] = None
            next_word_index += 1
            if next_word_index == len(word_timestamps):
                return final_timestamp

        else:
            return word_timestamps[next_word_index]["start"]


def filter_missing_timestamps(
    word_timestamps, initial_timestamp=0, final_timestamp=None
):
    # handle the first and last word
    if word_timestamps[0].get("start") is None:
        word_timestamps[0]["start"] = (
            initial_timestamp if initial_timestamp is not None else 0
        )
        word_timestamps[0]["end"] = _get_next_start_timestamp(
            word_timestamps, 0, final_timestamp
        )

    result = [
        word_timestamps[0],
    ]

    for i, ws in enumerate(word_timestamps[1:], start=1):
        # if ws doesn't have a start and end
        # use the previous end as start and next start as end
        if ws.get("start") is None and ws.get("word") is not None:
            ws["start"] = word_timestamps[i - 1]["end"]
            ws["end"] = _get_next_start_timestamp(word_timestamps, i, final_timestamp)

        if ws["word"] is not None:
            result.append(ws)
    return result


def cleanup(path: str):
    """path could either be relative or absolute."""
    # check if file or directory exists
    if os.path.isfile(path) or os.path.islink(path):
        # remove file
        os.remove(path)
    elif os.path.isdir(path):
        # remove directory and all its content
        shutil.rmtree(path)
    else:
        raise ValueError("Path {} is not a file or dir.".format(path))


def process_language_arg(language: str, model_name: str):
    """
    Process the language argument to make sure it's valid and convert language
    names to language codes.
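
    Example (illustrative):
        process_language_arg("English", "medium.en")  returns "en"
        process_language_arg("de", "large-v2")        returns "de"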
""" if language is not None: language = language.lower() if language not in LANGUAGES: if language in TO_LANGUAGE_CODE: language = TO_LANGUAGE_CODE[language] else: raise ValueError(f"Unsupported language: {language}") if model_name.endswith(".en") and language != "en": if language is not None: logging.warning( f"{model_name} is an English-only model but received '{language}'; using English instead." ) language = "en" return language