# Scraped from the Hugging Face Space file view (author: Porjaz).
# Last commit: "Update app.py" (f873ad8, verified); file size 12.5 kB.
import spaces
import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import gc
from functools import partial
import gradio as gr
import torch
from speechbrain.inference.interfaces import Pretrained, foreign_class
from transformers import T5Tokenizer, T5ForConditionalGeneration
import librosa
import whisper_timestamped as whisper
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline, Wav2Vec2ForCTC, AutoProcessor
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Allow TensorFloat-32 for CUDA matmuls (faster, slightly lower precision).
torch.backends.cuda.matmul.allow_tf32 = True
def clean_up_memory():
    """Release memory after a request has been served.

    Runs the Python garbage collector first, so that unreachable objects (and
    any tensors they hold) are freed, then asks PyTorch to release its
    unoccupied cached CUDA memory.
    """
    for release in (gc.collect, torch.cuda.empty_cache):
        release()
@spaces.GPU(duration=15)
def recap_sentence(string):
    """Restore capitalization and punctuation of ``string`` with the mT5 recap model.

    The text is prefixed with the task prompt the model expects, encoded as a
    single-item batch, moved to the module-level ``device`` and decoded with
    beam search (5 beams, at most 768 tokens).

    Args:
        string: Raw, lower-case / unpunctuated transcript text.

    Returns:
        The decoded model output with special tokens removed.
    """
    prompt = "restore capitalization and punctuation: " + string
    encoded = recap_tokenizer([prompt], return_tensors="pt", padding=True)
    encoded = encoded.to(device)
    generated = recap_model.generate(
        **encoded,
        max_length=768,
        num_beams=5,
        early_stopping=True,
    )
    # Single-item batch: drop the batch dimension before decoding.
    return recap_tokenizer.decode(generated.squeeze(0), skip_special_tokens=True)
@spaces.GPU(duration=30)
def return_prediction_w2v2(mic=None, file=None, device=device):
    """Transcribe up to 60 s of audio with the wav2vec2 model and restore casing.

    Args:
        mic: Path to a microphone recording, or ``None``. Takes priority over
            ``file`` when both are given.
        file: Path to an uploaded audio file, or ``None``.
        device: Torch device passed through to the classifier.

    Returns:
        The capitalized, punctuated transcript, or an error message when
        neither input is provided.
    """
    # Mic and file were handled identically; pick the source once.
    audio_path = mic if mic is not None else file
    if audio_path is None:
        return "You must either provide a mic recording or a file"

    # Load resampled to 16 kHz and keep at most the first 60 seconds.
    waveform, sr = librosa.load(audio_path, sr=16000)
    waveform = waveform[:60 * sr]
    w2v2_result = w2v2_classifier.classify_file_w2v2(waveform, device)

    # Element 0 of the classifier output is used as the transcript text
    # (presumably the hypothesis string — confirm in custom_interface_app.py).
    recap_result = recap_sentence(w2v2_result[0])

    # Upper-case a lowercase letter two characters after ".", "!" or "?"
    # (punctuation + one separator). One linear pass over a char list instead
    # of rebuilding the whole string for every change.
    chars = list(recap_result)
    for i in range(2, len(chars)):
        if chars[i - 2] in (".", "!", "?") and chars[i].islower():
            chars[i] = chars[i].upper()
    recap_result = "".join(chars)

    clean_up_memory()
    return recap_result
@spaces.GPU(duration=30)
def return_prediction_whisper_mic(mic=None, device=device):
    """Transcribe the first 30 s of a microphone recording with Буки-Whisper.

    Args:
        mic: Path to the recorded audio file, or ``None``.
        device: Torch device passed through to the classifier.

    Returns:
        The capitalized, punctuated transcript, or an error message when no
        recording was provided.
    """
    if mic is None:
        return "You must provide a mic recording"

    audio, rate = librosa.load(mic, sr=16000)
    clipped = audio[:30 * rate]  # only the first 30 seconds are transcribed
    # Element 0 is used as the transcript text (presumably the hypothesis).
    transcript = whisper_classifier.classify_file_whisper_mkd(clipped, device)[0]

    restored = recap_sentence(transcript)
    # Upper-case a lowercase letter that sits two characters after sentence-ending
    # punctuation (e.g. ". x" -> ". X").
    source_text = restored
    for pos, ch in enumerate(source_text):
        if pos <= 1 or not ch.islower():
            continue
        if restored[pos - 2] not in [".", "!", "?"]:
            continue
        restored = restored[:pos] + ch.upper() + restored[pos + 1:]

    clean_up_memory()
    return restored
@spaces.GPU(duration=60)
def return_prediction_whisper_file(file=None, device=device):
    """Stream a capitalized, punctuated transcript of an uploaded audio file.

    This is a generator: after each segment produced by the streaming
    Буки-Whisper classifier it yields the cumulative transcript so far.

    Args:
        file: Path to the uploaded audio file, or ``None``.
        device: Torch device passed through to the classifier.

    Yields:
        The growing transcript, or a single error message when no file was
        provided (in which case nothing else is yielded).
    """
    if file is None:
        # Report the missing input and stop; there is nothing to transcribe.
        yield "You must provide an audio file"
        return

    waveform, sr = librosa.load(file, sr=16000)
    waveform = waveform[:3600 * sr]  # cap at one hour of 16 kHz audio
    whisper_result = whisper_classifier.classify_file_whisper_mkd_streaming(waveform, device)

    recap_result = ""
    prev_segment = ""
    for segment in whisper_result:
        text = segment[0]
        if prev_segment:
            # Prepend the previous segment (presumably as context for the recap
            # model), then strip that many leading words from the output.
            # NOTE(review): assumes the recap model preserves the word count of
            # the prefix — confirm.
            prev_segment_len = len(prev_segment.split())
            recap_segment = recap_sentence(prev_segment + " " + text)
        else:
            # No context prefix, so nothing to strip; resetting here prevents a
            # stale count from an earlier segment being applied.
            prev_segment_len = 0
            recap_segment = recap_sentence(text)
        recap_segment = " ".join(recap_segment.split()[prev_segment_len:])
        prev_segment = text
        recap_result += recap_segment + " "

        # Upper-case a lowercase letter two characters after ".", "!" or "?"
        # in one linear pass over the accumulated transcript.
        chars = list(recap_result)
        for i in range(2, len(chars)):
            if chars[i - 2] in (".", "!", "?") and chars[i].islower():
                chars[i] = chars[i].upper()
        recap_result = "".join(chars)
        yield recap_result

    clean_up_memory()
def return_prediction_compare(mic=None, file=None, device=device):
    """Transcribe one recording with three ASR systems for side-by-side comparison.

    The audio is loaded at 16 kHz and truncated (60 s for ``mic``, 30 s for
    ``file``). It is then transcribed with Буки-Whisper, OpenAI Whisper and MMS,
    and each transcript is passed through ``recap_sentence``.

    NOTE(review): ``pipe_whisper``, ``processor_mms`` and ``mms_model`` are not
    defined anywhere in this file as shown. Calling this function therefore
    raises NameError unless they are provided elsewhere. No tab in the visible
    UI uses it — confirm before re-enabling it.

    Args:
        mic: Path to a microphone recording, or ``None`` (takes priority).
        file: Path to an uploaded audio file, or ``None``.
        device: Torch device passed through to the classifiers.

    Returns:
        One string with a labelled section per model, or an error message when
        neither input is provided.
    """
    def capitalize_after_sentence_end(text):
        # Upper-case a lowercase letter two characters after ".", "!" or "?"
        # (punctuation + one separator), in a single linear pass.
        chars = list(text)
        for i in range(2, len(chars)):
            if chars[i - 2] in (".", "!", "?") and chars[i].islower():
                chars[i] = chars[i].upper()
        return "".join(chars)

    # NOTE(review): mic input is capped at 60 s but file input at 30 s; kept
    # as-is — confirm this asymmetry is intended.
    if mic is not None:
        audio_path, max_seconds = mic, 60
    elif file is not None:
        audio_path, max_seconds = file, 30
    else:
        return "You must either provide a mic recording or a file"

    waveform, sr = librosa.load(audio_path, sr=16000)
    waveform = waveform[:max_seconds * sr]
    whisper_mkd_result = whisper_classifier.classify_file_whisper_mkd(waveform, device)
    whisper_result = whisper_classifier.classify_file_whisper(waveform, pipe_whisper, device)
    mms_result = whisper_classifier.classify_file_mms(waveform, processor_mms, mms_model, device)

    recap_result_whisper_mkd = capitalize_after_sentence_end(recap_sentence(whisper_mkd_result[0]))
    recap_result_whisper = capitalize_after_sentence_end(recap_sentence(whisper_result[0]))
    recap_result_mms = capitalize_after_sentence_end(recap_sentence(mms_result[0]))

    clean_up_memory()
    return (
        "Буки-Whisper:\n" + recap_result_whisper_mkd
        + "\n\n" + "MMS:\n" + recap_result_mms
        + "\n\n" + "OpenAI Whisper:\n" + recap_result_whisper
    )
# Gradio calls handlers with only the audio input(s). Bind ``device``
# explicitly so each handler the UI receives already has it fixed.
(
    return_prediction_whisper_mic_with_device,
    return_prediction_whisper_file_with_device,
    return_prediction_w2v2_with_device,
) = (
    partial(handler, device=device)
    for handler in (
        return_prediction_whisper_mic,
        return_prediction_whisper_file,
        return_prediction_w2v2,
    )
)
def _load_speechbrain_asr(source):
    """Load the ``ASR`` class from ``source``'s custom interface, move it to ``device`` and set eval mode."""
    model = foreign_class(source=source, pymodule_file="custom_interface_app.py", classname="ASR")
    model = model.to(device)
    model.eval()
    return model


# Speech-recognition models (Whisper and wav2vec2 variants fine-tuned for Macedonian).
whisper_classifier = _load_speechbrain_asr("Macedonian-ASR/whisper-large-v3-macedonian-asr")
w2v2_classifier = _load_speechbrain_asr("Macedonian-ASR/wav2vec2-aed-macedonian-asr")

# mT5 model that restores capitalization and punctuation, loaded in float16.
recap_model_name = "Macedonian-ASR/mt5-restore-capitalization-macedonian"
recap_tokenizer = T5Tokenizer.from_pretrained(recap_model_name)
recap_model = T5ForConditionalGeneration.from_pretrained(recap_model_name, torch_dtype=torch.float16)
recap_model.to(device)
recap_model.eval()
def _build_transcription_interface(fn, source, live):
    """Wrap ``fn`` in a Gradio Interface: one audio input (passed as a file path) -> one textbox."""
    return gr.Interface(
        fn=fn,
        inputs=gr.Audio(sources=source, type="filepath"),
        outputs=gr.Textbox(),
        allow_flagging="never",
        live=live,
    )


mic_transcribe_whisper = _build_transcription_interface(
    return_prediction_whisper_mic_with_device, "microphone", live=False
)
# Upload-based variant; it streams because its handler is a generator.
file_transcribe_whisper = _build_transcription_interface(
    return_prediction_whisper_file_with_device, "upload", live=True
)
mic_transcribe_w2v2 = _build_transcription_interface(
    return_prediction_w2v2_with_device, "microphone", live=False
)
# Markdown (with an inline HTML logo) rendered at the top of the app. It lists
# the authors, the institutional affiliation, the data sources used for
# training, and a link to instructions for donating Macedonian speech via
# Mozilla Common Voice.
# NOTE(review): Markdown needs a blank line to end a list. Without one before
# "Оваа колаборација ...", that paragraph renders as part of the last list
# item — confirm the rendering.
project_description = '''
<img src="https://i.ibb.co/hYhkkhg/Buki-logo-1.jpg"
alt="Bookie logo"
style="float: right; width: 150px; height: 150px; margin-left: 10px;" />
## Автори:
1. **Дејан Порјазовски**
2. **Илина Јакимовска**
3. **Ордан Чукалиев**
4. **Никола Стиков**
Оваа колаборација е дел од активностите на **Центарот за напредни интердисциплинарни истражувања ([ЦеНИИс](https://ukim.edu.mk/en/centri/centar-za-napredni-interdisciplinarni-istrazhuvanja-ceniis))** при УКИМ.
## Во тренирањето на овој модел се употребени податоци од:
1. Дигитален архив за етнолошки и антрополошки ресурси ([ДАЕАР](https://iea.pmf.ukim.edu.mk/tabs/view/61f236ed7d95176b747c20566ddbda1a)) при Институтот за етнологија и антропологија, Природно-математички факултет при УКИМ.
2. Аудио верзија на меѓународното списание [„ЕтноАнтропоЗум“](https://etno.pmf.ukim.mk/index.php/eaz/issue/archive) на Институтот за етнологија и антропологија, Природно-математички факултет при УКИМ.
3. Аудио подкастот [„Обични луѓе“](https://obicniluge.mk/episodes/) на Илина Јакимовска
4. Научните видеа од серијалот [„Наука за деца“](http://naukazadeca.mk), фондација [КАНТАРОТ](https://qantarot.substack.com/)
5. Македонска верзија на [Mozilla Common Voice](https://commonvoice.mozilla.org/en/datasets) (верзија 18.0)
## Како да придонесете за подобрување на македонските модели за препознавање на говор?
На следниот [линк](https://drive.google.com/file/d/1YdZJz9o1X8AMc6J4MNPnVZjASyIXnvoZ/view?usp=sharing) ќе најдете инструкции за тоа како да донирате македонски говор преку платформата Mozilla Common Voice.
'''
# Custom CSS passed to gr.Blocks.
# - .custom-markdown: forces 15px black Arial text in elements carrying
#   elem_classes="custom-markdown" (the project description Markdown).
# - .gradio-container: page background. It is declared twice; the second rule
#   (#f3f3f3, !important) overrides the first (#f0f0f0), which has no effect.
css = """
.gradio-container {
background-color: #f0f0f0; /* Set your desired background color */
}
.custom-markdown p, .custom-markdown li, .custom-markdown h2, .custom-markdown a, .custom-markdown strong {
font-size: 15px !important;
font-family: Arial, sans-serif !important;
color: black !important;
}
.gradio-container {
background-color: #f3f3f3 !important;
}
"""
# Top-level UI: project description followed by one tab per ASR model.
# delete_cache=(60, 120): per the Gradio docs, every 60 s delete cached files
# older than 120 s.
transcriber_app = gr.Blocks(css=css, delete_cache=(60, 120))
with transcriber_app:
    gr.Markdown(project_description, elem_classes="custom-markdown")
    gr.TabbedInterface(
        [mic_transcribe_whisper, mic_transcribe_w2v2],
        ["Буки-Whisper транскрипција", "Буки-Wav2vec2 транскрипција"],
    )
    # Session state; the delete callback only logs when Gradio discards it.
    state = gr.State(value=[], delete_callback=lambda v: print("STATE DELETED"))
    # Gradio calls unload handlers with no arguments when a client closes or
    # refreshes the tab. Prediction functions do not belong here. Called
    # without input they only return an error string, and the
    # @spaces.GPU-decorated ones may request a GPU allocation for nothing.
    # Free memory instead.
    transcriber_app.unload(clean_up_memory)
# For local debugging: transcriber_app.launch(debug=True, share=True, ssl_verify=False)
if __name__ == "__main__":
    # Enable request queuing, then start the server with a public share link.
    transcriber_app.queue().launch(share=True)