|
import gradio as gr |
|
from transformers import pipeline, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import os |
|
import soundfile as sf |
|
from pyannote.audio import Pipeline |
|
import torch |
|
from pydub import AudioSegment |
|
|
from datetime import datetime, timedelta |
|
|
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps |
|
|
|
|
|
|
|
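# Hugging Face token (read from the environment); needed to download the gated pyannote diarization pipeline.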
HF_TOKEN = os.environ.get("HF_TOKEN") |
|
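# Target audio format for the ASR models: 16 kHz, mono.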
sr = 16000 |
|
channels = 1 |
|
|
|
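# ASR checkpoints: the ruOH model below is an alternative that is not used further;
# recognition uses the bond005 Russian Golos wav2vec2 model with a language-model decoder.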
model_name = "Mihaj/wav2vec2-large-xls-r-300m-ruOH-alphav" |
|
bond005_model = "bond005/wav2vec2-large-ru-golos-with-lm" |
|
processor = Wav2Vec2ProcessorWithLM.from_pretrained(bond005_model) |
|
model = Wav2Vec2ForCTC.from_pretrained(bond005_model) |
|
pipe = pipeline(
    "automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    decoder=processor.decoder,
)
|
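# Silero VAD model, used for the 'Fast' segmentation option (no speaker labels).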
vad_model = load_silero_vad()
|
|
|
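# pyannote speaker diarization pipeline, used for the 'Accurate' segmentation option.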
pipeline_dia = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1", |
|
use_auth_token=HF_TOKEN) |
|
|
|
device = 'cpu' |
|
|
|
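# T5 translation model (en/ru/zh), run on CPU.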
model_name_tr = 'utrobinmv/t5_translate_en_ru_zh_small_1024' |
|
model_tr = T5ForConditionalGeneration.from_pretrained(model_name_tr) |
|
model_tr.to(device) |
|
tokenizer_tr = T5Tokenizer.from_pretrained(model_name_tr) |
|
|
|
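# Task prefix that tells the T5 model to translate into English.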
prefix = 'translate to en: ' |
|
|
|
|
|
|
|
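# Uploads are converted to this temporary WAV file before recognition.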
temp_path = "temp.wav" |
|
|
|
def preprocess(audio_path): |
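    """Convert the uploaded file to 16 kHz mono WAV (written to temp_path) and return its path."""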
|
print("PREPROCESSING STARTED") |
|
    sound = AudioSegment.from_file(audio_path)  # let pydub/ffmpeg detect the input format (wav, mp3, ...)
|
sound = sound.set_frame_rate(sr) |
|
sound = sound.set_channels(channels) |
|
sound.export(temp_path, format="wav") |
|
print("PREPROCESSING ENDED") |
|
return temp_path |
|
|
|
def fast_transcribe(diarise, how_diarise, translate, audio): |
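    """Transcribe an uploaded audio file, optionally with segmentation and English translation.

    diarise:     when True, split the audio into segments and return SRT-style blocks.
    how_diarise: "Accurate" uses pyannote speaker diarization; any other value uses Silero VAD.
    translate:   when True, append an English translation to each segment.
    audio:       path to the uploaded file (Gradio "filepath" audio input).
    """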
|
audio = preprocess(audio) |
|
y, sr = sf.read(audio) |
|
if diarise: |
|
if how_diarise=="Accurate": |
|
print("DIARISING") |
|
dia = pipeline_dia(audio) |
|
print("DIARISING ENDED") |
|
lines = [] |
|
for i, line in enumerate(dia.to_lab().split('\n')): |
|
if line.strip() != "": |
|
res = line.split(" ") |
|
                    # Segment boundaries in samples (pyannote reports seconds).
                    start = int(float(res[0]) * sr)
                    end = int(float(res[1]) * sr)
                    label = res[2]
                    # Format the boundaries as SRT-style "HH:MM:SS,mmm" timestamps,
                    # computed from the offsets themselves rather than the local timezone.
                    start_time = str(datetime(1970, 1, 1) + timedelta(seconds=start / sr)).split()[1]
                    start_time_prts = start_time.split(":")
                    start_time_srt = f"{start_time_prts[0]}:{start_time_prts[1]}:{float(start_time_prts[2]):06.3f}".replace('.', ',')
                    end_time = str(datetime(1970, 1, 1) + timedelta(seconds=end / sr)).split()[1]
                    end_time_prts = end_time.split(":")
                    end_time_srt = f"{end_time_prts[0]}:{end_time_prts[1]}:{float(end_time_prts[2]):06.3f}".replace('.', ',')
|
print(f"RECOGNISING LINE_{i} T_START {start_time_srt} T_END {end_time_srt} SPEAKER_{label}") |
|
trans = pipe(y[start:end], chunk_length_s=10, stride_length_s=(4, 2))["text"] |
|
if not translate: |
|
lines.append(f"{i+1}\n{start_time_srt} --> {end_time_srt}\n[{label}] {trans}\n") |
|
else: |
|
print("TRANSLATION STARTED") |
|
src_text = prefix + trans |
|
|
|
input_ids = tokenizer_tr(src_text, return_tensors="pt") |
|
generated_tokens = model_tr.generate(**input_ids.to(device)) |
|
|
|
trans_eng = tokenizer_tr.batch_decode(generated_tokens, skip_special_tokens=True)[0] |
|
|
|
print(f"TRANSLATION ENDED RESULT {trans_eng}") |
|
lines.append(f"{i+1}\n{start_time_srt} --> {end_time_srt}\n[{label}] {trans}\n[{label}] {trans_eng}\n") |
|
print("RECOGNISING ENDED") |
|
print(f"LINE RESULT {trans}") |
|
else: |
|
print("DIARISING") |
|
wav = read_audio(audio) |
|
            speech_timestamps = get_speech_timestamps(wav, vad_model, speech_pad_ms=80, min_silence_duration_ms=150, window_size_samples=256)
|
print("DIARISING ENDED") |
|
lines = [] |
|
for i, line in enumerate(speech_timestamps): |
|
                # Silero VAD reports boundaries in samples; format them as SRT-style
                # "HH:MM:SS,mmm" timestamps, independent of the local timezone.
                start = line['start']
                end = line['end']
                start_time = str(datetime(1970, 1, 1) + timedelta(seconds=start / sr)).split()[1]
                start_time_prts = start_time.split(":")
                start_time_srt = f"{start_time_prts[0]}:{start_time_prts[1]}:{float(start_time_prts[2]):06.3f}".replace('.', ',')
                end_time = str(datetime(1970, 1, 1) + timedelta(seconds=end / sr)).split()[1]
                end_time_prts = end_time.split(":")
                end_time_srt = f"{end_time_prts[0]}:{end_time_prts[1]}:{float(end_time_prts[2]):06.3f}".replace('.', ',')
|
print(f"RECOGNISING LINE_{i} T_START {start_time_srt} T_END {end_time_srt}") |
|
trans = pipe(y[start:end], chunk_length_s=10, stride_length_s=(4, 2))["text"] |
|
print("RECOGNISING ENDED") |
|
if not translate: |
|
                    lines.append(f"{i+1}\n{start_time_srt} --> {end_time_srt}\n{trans}\n")
|
else: |
|
print("TRANSLATION STARTED") |
|
src_text = prefix + trans |
|
|
|
input_ids = tokenizer_tr(src_text, return_tensors="pt") |
|
generated_tokens = model_tr.generate(**input_ids.to(device)) |
|
|
|
trans_eng = tokenizer_tr.batch_decode(generated_tokens, skip_special_tokens=True)[0] |
|
|
|
print(f"TRANSLATION ENDED RESULT {trans_eng}") |
|
lines.append(f"{i+1}\n{start_time_srt} --> {end_time_srt}\n{trans}\n{trans_eng}\n") |
|
|
|
print(f"LINE RESULT {trans}") |
|
        text = "\n".join(lines)
|
else: |
|
print("RECOGNISING FULL AUDIO") |
|
res = pipe(y, chunk_length_s=10, stride_length_s=(4, 2)) |
|
print("RECOGNISING FULL AUDIO ENDED") |
|
text = res["text"] |
|
return text |
|
|
|
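# Gradio UI: segmentation/translation options, an audio input and a text output in a single tab.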
with gr.Blocks() as demo: |
|
    gr.Markdown("""
# Wav2Vec2 RuOH

Real-time demo for Russian Oral History recognition using two segmentation/diarization methods (Silero VAD, pyannote) and the wav2vec2 large model from bond005: https://huggingface.co/bond005/wav2vec2-large-ru-golos-with-lm
""")
|
with gr.Tab("Fast Translation"): |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
fast_diarize_input = gr.Checkbox(label="Subtitles", info="Do you want subtitles?") |
|
                fast_diarize_radio_input = gr.Radio(["Fast", "Accurate", "-"], label="Subtitle segmentation method", info="How to split the audio into subtitle segments: 'Fast' uses Silero VAD (quicker, lower quality, no speaker labels); 'Accurate' uses pyannote speaker diarization (slower, higher quality, distinguishes speakers).")
|
fast_translate_input = gr.Checkbox(label="Translate", info="Do you want translation to English?") |
|
fast_audio_input = gr.Audio(type="filepath") |
|
|
|
fast_output = gr.Textbox() |
|
|
|
fast_inputs = [fast_diarize_input, fast_diarize_radio_input, fast_translate_input, fast_audio_input] |
|
fast_recognize_button = gr.Button("Run") |
|
|
|
|
|
fast_recognize_button.click(fast_transcribe, inputs=fast_inputs, outputs=fast_output) |
|
|
|
if __name__ == "__main__": |
|
demo.launch() |