# Spaces status banner captured along with this file (not part of the program): Runtime error
from transformers import AutoProcessor, AutoModelForCTC | |
from transformers import pipeline | |
import soundfile as sf | |
import gradio as gr | |
import torch | |
import sox | |
import os | |
from image_generator import generate_image | |
# Use the GPU when torch reports one, otherwise stay on the CPU.
# NOTE(review): nothing in the visible code moves a model or tensor to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# None when the API_TOKEN environment variable is unset.
# NOTE(review): not referenced by the visible code; presumably consumed elsewhere.
api_token = os.getenv("API_TOKEN")

# The processor and the CTC model are loaded from the same checkpoint.
_ASR_CHECKPOINT = "imvladikon/wav2vec2-xls-r-300m-hebrew"
asr_processor = AutoProcessor.from_pretrained(_ASR_CHECKPOINT)
asr_model = AutoModelForCTC.from_pretrained(_ASR_CHECKPOINT)

# Translation pipeline applied to the transcription before image generation.
he_en_translator = pipeline(
    task="translation",
    model="Helsinki-NLP/opus-mt-tc-big-he-en",
)
def convert(inputfile, outfile):
    """Re-encode an audio file as 16 kHz, mono, 16-bit signed-integer WAV.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path the converted WAV file is written to.
    """
    output_format = {
        "file_type": "wav",
        "rate": 16000,
        "bits": 16,
        "channels": 1,
        "encoding": "signed-integer",
    }
    transformer = sox.Transformer()
    transformer.set_output_format(**output_format)
    transformer.build(inputfile, outfile)
def parse_transcription(wav_file):
    """Turn a recorded utterance into a generated image.

    The recording is converted to a 16 kHz mono WAV, transcribed with the
    ASR model, translated with the Hebrew-to-English pipeline, and the
    translated text is handed to ``generate_image``.

    Args:
        wav_file: Object exposing the recording's path as ``.name``, or
            ``None`` when no audio was supplied (the microphone input is
            declared optional).

    Returns:
        The value returned by ``generate_image`` for the translated text,
        or ``None`` when ``wav_file`` is ``None``.
    """
    # The input is optional, so there may be nothing to transcribe.
    if wav_file is None:
        return None

    source_path = wav_file.name
    # splitext strips only the final extension; splitting on the first "."
    # would truncate paths whose directories contain dots.
    root, _ext = os.path.splitext(source_path)
    resampled_path = root + "16k.wav"
    try:
        convert(source_path, resampled_path)
        speech, _ = sf.read(resampled_path)
    finally:
        # The converted file is only an intermediate; do not let one pile up
        # per request. It may not exist if the conversion itself failed.
        try:
            os.remove(resampled_path)
        except FileNotFoundError:
            pass

    # transcribe to hebrew
    input_values = asr_processor(
        speech, sampling_rate=16_000, return_tensors="pt"
    ).input_values
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)

    # translate to english
    translated = he_en_translator(transcription)[0]["translation_text"]

    # generate image
    return generate_image(translated)
# Optional microphone recording in (type="file"), unlabeled image out.
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
output = gr.outputs.Image(label='')

# Build the interface first, then start it, instead of one chained call.
demo = gr.Interface(
    fn=parse_transcription,
    inputs=[input_mic],
    outputs=output,
    title="Draw Me A Sheep in Hebrew",
    theme='huggingface',
    layout='horizontal',
    show_tips=False,
    analytics_enabled=False,
    enable_queue=True,
)
demo.launch(inline=False)