import gradio as gr
from transformers import pipeline, Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
from transformers import T5ForConditionalGeneration, T5Tokenizer
import os
import soundfile as sf
from pyannote.audio import Pipeline
import torch
from pydub import AudioSegment
from pydub.playback import play
from datetime import datetime, timedelta
import time
from silero_vad import load_silero_vad, read_audio, get_speech_timestamps
#from googletrans import Translator
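# The HF access token is needed to download the pyannote speaker diarization pipeline.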
HF_TOKEN = os.environ.get("HF_TOKEN")
sr = 16000
channels = 1
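# Two ASR checkpoints are listed; only the bond005 wav2vec2 model (with an n-gram LM decoder) is loaded below.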
model_name = "Mihaj/wav2vec2-large-xls-r-300m-ruOH-alphav"
bond005_model = "bond005/wav2vec2-large-ru-golos-with-lm"
processor = Wav2Vec2ProcessorWithLM.from_pretrained(bond005_model)
model = Wav2Vec2ForCTC.from_pretrained(bond005_model)
pipe = pipeline("automatic-speech-recognition", model=model, tokenizer=processor, feature_extractor=processor.feature_extractor, decoder=processor.decoder)
vad_model = load_silero_vad()  # Silero VAD model, used for the "Fast" segmentation option
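# Pyannote speaker diarization pipeline, used for the "Accurate" segmentation option.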
pipeline_dia = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1",
                                        use_auth_token=HF_TOKEN)
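# Small T5 model used to translate recognized Russian text into English.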
device = 'cpu'  # or 'cuda' to run translation on GPU
model_name_tr = 'utrobinmv/t5_translate_en_ru_zh_small_1024'
model_tr = T5ForConditionalGeneration.from_pretrained(model_name_tr)
model_tr.to(device)
tokenizer_tr = T5Tokenizer.from_pretrained(model_name_tr)
prefix = 'translate to en: '
# translator = Translator()
temp_path = "temp.wav"
def preprocess(audio_path):
    """Convert the input file to a mono 16 kHz WAV that the ASR pipeline expects."""
    print("PREPROCESSING STARTED")
    # Let pydub/ffmpeg detect the container, so both MP3 and WAV uploads work.
    sound = AudioSegment.from_file(audio_path)
    sound = sound.set_frame_rate(sr)
    sound = sound.set_channels(channels)
    sound.export(temp_path, format="wav")
    print("PREPROCESSING ENDED")
    return temp_path
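# Transcribe an uploaded or recorded file. With "Subtitles" enabled the audio is first
# split into segments (Silero VAD for "Fast", pyannote speaker diarization for "Accurate"),
# each segment is recognized and optionally translated to English, and the result is
# returned as SRT-style numbered subtitles; otherwise the whole file is recognized at once.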
def fast_transcribe(diarise, how_diarise, translate, audio):
    audio = preprocess(audio)
    y, sr = sf.read(audio)
    if diarise:
        if how_diarise == "Accurate":
            print("DIARISING")
            dia = pipeline_dia(audio)
            print("DIARISING ENDED")
            lines = []
            for i, line in enumerate(dia.to_lab().split('\n')):
                if line.strip() != "":
                    # Each .lab line is "<start_sec> <end_sec> <speaker_label>".
                    res = line.split(" ")
                    start = int(float(res[0]) * sr)
                    # utcfromtimestamp avoids local-timezone offsets in the subtitle timestamps.
                    start_time = str(datetime.utcfromtimestamp(start / sr)).split()[1]
                    start_time_prts = start_time.split(":")
                    start_time_srt = f"{start_time_prts[0]}:{start_time_prts[1]}:{float(start_time_prts[2]):06.3f}".replace('.', ',')
                    end = int(float(res[1]) * sr)
                    end_time = str(datetime.utcfromtimestamp(end / sr)).split()[1]
                    end_time_prts = end_time.split(":")
                    end_time_srt = f"{end_time_prts[0]}:{end_time_prts[1]}:{float(end_time_prts[2]):06.3f}".replace('.', ',')
                    label = res[2]
                    print(f"RECOGNISING LINE_{i} T_START {start_time_srt} T_END {end_time_srt} SPEAKER_{label}")
                    trans = pipe(y[start:end], chunk_length_s=10, stride_length_s=(4, 2))["text"]
                    if not translate:
                        lines.append(f"{i+1}\n{start_time_srt} --> {end_time_srt}\n[{label}] {trans}\n")
                    else:
                        print("TRANSLATION STARTED")
                        src_text = prefix + trans
                        # translate the Russian line to English
                        input_ids = tokenizer_tr(src_text, return_tensors="pt")
                        generated_tokens = model_tr.generate(**input_ids.to(device))
                        trans_eng = tokenizer_tr.batch_decode(generated_tokens, skip_special_tokens=True)[0]
                        #trans_eng = translator.translate(trans, src='ru', dest="en").text
                        print(f"TRANSLATION ENDED RESULT {trans_eng}")
                        lines.append(f"{i+1}\n{start_time_srt} --> {end_time_srt}\n[{label}] {trans}\n[{label}] {trans_eng}\n")
                    print("RECOGNISING ENDED")
                    print(f"LINE RESULT {trans}")
        else:
            print("DIARISING")
            wav = read_audio(audio)  # backend (sox, soundfile, or ffmpeg) required!
            speech_timestamps = get_speech_timestamps(wav, vad_model, speech_pad_ms=80, min_silence_duration_ms=150, window_size_samples=256)
            print("DIARISING ENDED")
            lines = []
            for i, line in enumerate(speech_timestamps):
                # Silero VAD returns segment boundaries as sample indices.
                start = line['start']
                start_time = str(datetime.utcfromtimestamp(start / sr)).split()[1]
                start_time_prts = start_time.split(":")
                start_time_srt = f"{start_time_prts[0]}:{start_time_prts[1]}:{float(start_time_prts[2]):06.3f}".replace('.', ',')
                end = line['end']
                end_time = str(datetime.utcfromtimestamp(end / sr)).split()[1]
                end_time_prts = end_time.split(":")
                end_time_srt = f"{end_time_prts[0]}:{end_time_prts[1]}:{float(end_time_prts[2]):06.3f}".replace('.', ',')
                print(f"RECOGNISING LINE_{i} T_START {start_time_srt} T_END {end_time_srt}")
                trans = pipe(y[start:end], chunk_length_s=10, stride_length_s=(4, 2))["text"]
                print("RECOGNISING ENDED")
                if not translate:
                    lines.append(f"{i+1}\n{start_time_srt} --> {end_time_srt}\n{trans}\n")
                else:
                    print("TRANSLATION STARTED")
                    src_text = prefix + trans
                    # translate the Russian line to English
                    input_ids = tokenizer_tr(src_text, return_tensors="pt")
                    generated_tokens = model_tr.generate(**input_ids.to(device))
                    trans_eng = tokenizer_tr.batch_decode(generated_tokens, skip_special_tokens=True)[0]
                    #trans_eng = translator.translate(trans, src='ru', dest="en").text
                    print(f"TRANSLATION ENDED RESULT {trans_eng}")
                    lines.append(f"{i+1}\n{start_time_srt} --> {end_time_srt}\n{trans}\n{trans_eng}\n")
                print(f"LINE RESULT {trans}")
        text = "\n".join(lines)
    else:
        print("RECOGNISING FULL AUDIO")
        res = pipe(y, chunk_length_s=10, stride_length_s=(4, 2))
        print("RECOGNISING FULL AUDIO ENDED")
        text = res["text"]
    return text
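# Gradio interface: one tab with segmentation/translation options, an audio input, and a text output.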
with gr.Blocks() as demo:
    gr.Markdown("""
    # Wav2Vec2 RuOH
    Realtime demo for Russian Oral History recognition using several diarization methods (Silero VAD, Pyannote)
    and a large Wav2Vec2 model from bond005: https://huggingface.co/bond005/wav2vec2-large-ru-golos-with-lm
    """)
    with gr.Tab("Fast Translation"):
        with gr.Row():
            with gr.Column():
                fast_diarize_input = gr.Checkbox(label="Subtitles", info="Do you want subtitles?")
                fast_diarize_radio_input = gr.Radio(["Fast", "Accurate", "-"], label="Subtitle segmentation option", info="Choose how the audio is split into segments: faster but lower quality (Silero VAD), or slower but higher quality (Pyannote diarization, which also distinguishes speakers)")
                fast_translate_input = gr.Checkbox(label="Translate", info="Do you want a translation into English?")
                fast_audio_input = gr.Audio(type="filepath")
                fast_output = gr.Textbox()
                fast_inputs = [fast_diarize_input, fast_diarize_radio_input, fast_translate_input, fast_audio_input]
                fast_recognize_button = gr.Button("Run")
                fast_recognize_button.click(fast_transcribe, inputs=fast_inputs, outputs=fast_output)

if __name__ == "__main__":
    demo.launch()