# Hugging Face Space: "Draw Me A Sheep in Hebrew"
# Pipeline: Hebrew speech -> ASR transcription -> English translation -> generated image.
from transformers import AutoProcessor, AutoModelForCTC | |
from transformers import pipeline | |
import soundfile as sf | |
import gradio as gr | |
import librosa | |
import torch | |
import sox | |
import os | |
from image_generator import generate_image | |
# --- Module-level setup -------------------------------------------------

# Run inference on GPU when one is available; otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face API token supplied through the environment (may be None).
api_token = os.getenv("API_TOKEN")

# Hebrew speech-to-text: wav2vec2 CTC model and its matching processor.
asr_processor = AutoProcessor.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")
asr_model = AutoModelForCTC.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")

# Hebrew -> English machine-translation pipeline.
he_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-he-en")
def process_audio_file(file):
    """Load an audio file and prepare it as model-ready input values.

    Parameters
    ----------
    file : str or file-like
        Anything accepted by ``librosa.load``.

    Returns
    -------
    torch.Tensor
        The ``input_values`` tensor produced by the ASR processor,
        resampled to 16 kHz as the wav2vec2 model expects.
    """
    data, sr = librosa.load(file)
    if sr != 16000:
        # librosa >= 0.10 requires keyword arguments for resample.
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
    # BUG FIX: the original referenced an undefined global `processor`;
    # the module defines `asr_processor`.
    input_values = asr_processor(data, sampling_rate=16_000, return_tensors="pt").input_values  # .to(device)
    return input_values
def transcribe(file_mic, file_upload):
    """Transcribe Hebrew speech from a microphone recording or an uploaded file.

    When both sources are supplied, the microphone recording wins and a
    warning is prepended to the transcription. When neither is supplied,
    an error string is returned instead.

    Parameters
    ----------
    file_mic : path-like or None
        Audio recorded from the microphone.
    file_upload : path-like or None
        Audio uploaded by the user.

    Returns
    -------
    str
        The transcription (possibly prefixed by a warning), or an error message.
    """
    warn_output = ""
    if (file_mic is not None) and (file_upload is not None):
        warn_output = "WARNING: You've uploaded an audio file and used the microphone. The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
        file = file_mic
    elif (file_mic is None) and (file_upload is None):
        return "ERROR: You have to either use the microphone or upload an audio file"
    elif file_mic is not None:
        file = file_mic
    else:
        file = file_upload

    input_values = process_audio_file(file)
    # BUG FIX: the original referenced undefined globals `model` and
    # `processor`; the module defines `asr_model` and `asr_processor`.
    logits = asr_model(input_values).logits
    # Greedy CTC decoding: most likely token at each timestep.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    return warn_output + transcription
def convert(inputfile, outfile):
    """Transcode ``inputfile`` into a 16 kHz mono 16-bit signed WAV at ``outfile``."""
    transformer = sox.Transformer()
    transformer.set_output_format(
        file_type="wav",
        channels=1,
        encoding="signed-integer",
        rate=16000,
        bits=16,
    )
    transformer.build(inputfile, outfile)
# BUG FIX: the original code re-defined `generate_image(text)` here as an
# empty stub (`pass`), shadowing the real implementation imported from
# `image_generator` at the top of the file. That made the whole pipeline
# silently return None instead of an image. The stub is removed so the
# imported `generate_image` is the one that gets called.
def parse_transcription(wav_file):
    """Run the full pipeline: Hebrew audio -> transcription -> English -> image.

    Parameters
    ----------
    wav_file : file-like
        Object with a ``.name`` path attribute (a gradio temp file).

    Returns
    -------
    The image produced by ``generate_image`` for the translated text.
    """
    # Normalize the recording to the 16 kHz mono WAV the ASR model expects.
    # NOTE(review): the converted file is written next to the original and
    # never cleaned up — acceptable for a demo Space's temp directory.
    filename = wav_file.name.split('.')[0]
    convert(wav_file.name, filename + "16k.wav")
    speech, _ = sf.read(filename + "16k.wav")
    # Hebrew speech -> text via greedy CTC decoding.
    # (Removed a leftover debug `print(speech.shape)` from the original.)
    input_values = asr_processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    # Hebrew -> English, then text -> image.
    translated = he_en_translator(transcription)[0]['translation_text']
    image = generate_image(translated)
    return image
# --- Gradio UI wiring ---------------------------------------------------
output = gr.outputs.Image(label='')
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
# NOTE(review): `input_upload` is created but never passed to the Interface
# below — presumably an intended alternative input; confirm before removing.
input_upload = gr.inputs.Audio(source="upload", type="file", optional=True)

demo = gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=output,
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
)
demo.launch(inline=False)