import argparse
import logging
import os
import re
import torch
import torchaudio
from ctc_forced_aligner import (
generate_emissions,
get_alignments,
get_spans,
load_alignment_model,
postprocess_results,
preprocess_text,
)
from deepmultilingualpunctuation import PunctuationModel
from nemo.collections.asr.models.msdd_models import NeuralDiarizer
from helpers import (
cleanup,
create_config,
get_realigned_ws_mapping_with_punctuation,
get_sentences_speaker_mapping,
get_speaker_aware_transcript,
get_words_speaker_mapping,
langs_to_iso,
punct_model_langs,
whisper_langs,
write_srt,
)
from transcription_helpers import transcribe_batched
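
# Compute (quantization) type for the transcription backend on each device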
mtypes = {"cpu": "int8", "cuda": "float16"}
# Initialize parser
parser = argparse.ArgumentParser()
parser.add_argument(
"-a", "--audio", help="name of the target audio file", required=True
)
parser.add_argument(
"--no-stem",
action="store_false",
dest="stemming",
default=True,
help="Disables source separation."
"This helps with long files that don't contain a lot of music.",
)
parser.add_argument(
"--suppress_numerals",
action="store_true",
dest="suppress_numerals",
default=False,
help="Suppresses Numerical Digits."
"This helps the diarization accuracy but converts all digits into written text.",
)
parser.add_argument(
"--whisper-model",
dest="model_name",
default="medium.en",
help="name of the Whisper model to use",
)
parser.add_argument(
"--batch-size",
type=int,
dest="batch_size",
default=8,
help="Batch size for batched inference, reduce if you run out of memory, set to 0 for non-batched inference",
)
parser.add_argument(
"--language",
type=str,
default=None,
choices=whisper_langs,
help="Language spoken in the audio, specify None to perform language detection",
)
parser.add_argument(
"--device",
dest="device",
default="cuda" if torch.cuda.is_available() else "cpu",
help="if you have a GPU use 'cuda', otherwise 'cpu'",
)
args = parser.parse_args()
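
# Optional source separation: run Demucs (htdemucs, two-stems) to isolate the
# vocals so background music doesn't interfere with transcription and diarization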
if args.stemming:
# Isolate vocals from the rest of the audio
return_code = os.system(
f'python3 -m demucs.separate -n htdemucs --two-stems=vocals "{args.audio}" -o "temp_outputs"'
)
if return_code != 0:
logging.warning(
"Source splitting failed, using original audio file. Use --no-stem argument to disable it."
)
vocal_target = args.audio
else:
vocal_target = os.path.join(
"temp_outputs",
"htdemucs",
os.path.splitext(os.path.basename(args.audio))[0],
"vocals.wav",
)
else:
vocal_target = args.audio
# Transcribe the audio file
whisper_results, language, audio_waveform = transcribe_batched(
vocal_target,
args.language,
args.batch_size,
args.model_name,
mtypes[args.device],
args.suppress_numerals,
args.device,
)
# Forced Alignment
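# Load a CTC alignment model (float16 on GPU, float32 on CPU) and align the
# Whisper transcript against the audio to obtain word-level timestamps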
alignment_model, alignment_tokenizer, alignment_dictionary = load_alignment_model(
args.device,
dtype=torch.float16 if args.device == "cuda" else torch.float32,
)
audio_waveform = (
torch.from_numpy(audio_waveform)
.to(alignment_model.dtype)
.to(alignment_model.device)
)
emissions, stride = generate_emissions(
alignment_model, audio_waveform, batch_size=args.batch_size
)
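# Free the alignment model; only its emissions are needed from here on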
del alignment_model
torch.cuda.empty_cache()
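
# Join the Whisper segments into a single transcript and tokenize it (with
# romanization) in the format the aligner expects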
full_transcript = "".join(segment["text"] for segment in whisper_results)
tokens_starred, text_starred = preprocess_text(
full_transcript,
romanize=True,
language=langs_to_iso[language],
)
segments, scores, blank_id = get_alignments(
emissions,
tokens_starred,
alignment_dictionary,
)
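# Map the aligned tokens back to word spans, then convert those spans into
# word-level timestamps using the emission stride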
spans = get_spans(tokens_starred, segments, alignment_tokenizer.decode(blank_id))
word_timestamps = postprocess_results(text_starred, spans, stride, scores)
# Convert the audio to mono and save it as a 16 kHz WAV for NeMo compatibility
ROOT = os.getcwd()
temp_path = os.path.join(ROOT, "temp_outputs")
os.makedirs(temp_path, exist_ok=True)
torchaudio.save(
os.path.join(temp_path, "mono_file.wav"),
audio_waveform.cpu().unsqueeze(0).float(),
16000,
channels_first=True,
)
# Initialize NeMo MSDD diarization model
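# create_config builds the NeMo diarizer config for the temp directory; the
# resulting RTTM is read back from temp_outputs/pred_rttms below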
msdd_model = NeuralDiarizer(cfg=create_config(temp_path)).to(args.device)
msdd_model.diarize()
del msdd_model
torch.cuda.empty_cache()
# Read the diarizer's RTTM output and build (start_ms, end_ms, speaker) entries
speaker_ts = []
with open(os.path.join(temp_path, "pred_rttms", "mono_file.rttm"), "r") as f:
lines = f.readlines()
for line in lines:
line_list = line.split(" ")
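        # RTTM fields: index 5 = segment start (s), index 8 = duration (s),
        # index 11 = speaker label (e.g. "speaker_0")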
s = int(float(line_list[5]) * 1000)
e = s + int(float(line_list[8]) * 1000)
speaker_ts.append([s, e, int(line_list[11].split("_")[-1])])
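
# Assign each aligned word to the speaker active at the word's start time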
wsm = get_words_speaker_mapping(word_timestamps, speaker_ts, "start")
if language in punct_model_langs:
    # Restore punctuation in the transcript to help realign the sentences
punct_model = PunctuationModel(model="kredor/punctuate-all")
    words_list = [word_dict["word"] for word_dict in wsm]
    labeled_words = punct_model.predict(words_list, chunk_size=230)
ending_puncts = ".?!"
model_puncts = ".,;:!?"
# We don't want to punctuate U.S.A. with a period. Right?
    def is_acronym(x):
        return re.fullmatch(r"\b(?:[a-zA-Z]\.){2,}", x)
    for word_dict, labeled_tuple in zip(wsm, labeled_words):
word = word_dict["word"]
if (
word
and labeled_tuple[1] in ending_puncts
and (word[-1] not in model_puncts or is_acronym(word))
):
word += labeled_tuple[1]
if word.endswith(".."):
word = word.rstrip(".")
word_dict["word"] = word
else:
logging.warning(
f"Punctuation restoration is not available for {language} language. Using the original punctuation."
)
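
# Realign word/speaker boundaries using the restored punctuation, then group
# the words into speaker-attributed sentences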
wsm = get_realigned_ws_mapping_with_punctuation(wsm)
ssm = get_sentences_speaker_mapping(wsm, speaker_ts)
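
# Write the speaker-aware transcript (.txt) and subtitles (.srt) next to the
# original audio file, then remove the temporary outputs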
with open(f"{os.path.splitext(args.audio)[0]}.txt", "w", encoding="utf-8-sig") as f:
get_speaker_aware_transcript(ssm, f)
with open(f"{os.path.splitext(args.audio)[0]}.srt", "w", encoding="utf-8-sig") as srt:
write_srt(ssm, srt)
cleanup(temp_path)