# Page-header residue from the file viewer (not part of the program):
# Amir Zait | fixes | e8b13db | raw | history blame | 1.93 kB
from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
import soundfile as sf
import gradio as gr
import torch
import sox
import os
from image_generator import generate_image
# Prefer CUDA when torch reports it as available; otherwise use the CPU.
# NOTE(review): `device` is not referenced by the other code visible in this
# file -- confirm whether the models were meant to be moved onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Token read from the API_TOKEN environment variable (None when it is unset).
api_token = os.environ.get("API_TOKEN")

# Checkpoint identifiers, named once so the processor and the model cannot
# drift apart.
_ASR_CHECKPOINT = "imvladikon/wav2vec2-xls-r-300m-hebrew"
_HE_EN_CHECKPOINT = "Helsinki-NLP/opus-mt-tc-big-he-en"

# Speech recognition: the processor prepares audio / decodes ids, the CTC
# model produces the logits.
asr_processor = AutoProcessor.from_pretrained(_ASR_CHECKPOINT)
asr_model = AutoModelForCTC.from_pretrained(_ASR_CHECKPOINT)

# Hebrew -> English translation pipeline.
he_en_translator = pipeline("translation", model=_HE_EN_CHECKPOINT)
def convert(inputfile, outfile):
    """Transcode ``inputfile`` into a WAV file written to ``outfile`` via SoX.

    The output format is fixed: single channel, 16-bit signed-integer
    samples at a 16000 Hz sample rate.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path the converted WAV file is written to.
    """
    target_format = {
        "file_type": "wav",
        "channels": 1,
        "encoding": "signed-integer",
        "rate": 16000,
        "bits": 16,
    }
    transformer = sox.Transformer()
    transformer.set_output_format(**target_format)
    transformer.build(inputfile, outfile)
def parse_transcription(wav_file):
    """Turn a recorded utterance into a generated image.

    Steps: resample the recording with ``convert``, transcribe it with
    ``asr_processor`` / ``asr_model``, translate the transcription with
    ``he_en_translator``, and hand the English text to ``generate_image``.

    Args:
        wav_file: Object whose ``name`` attribute is the path of the recorded
            audio file, or ``None`` when no recording was supplied.

    Returns:
        The value returned by ``generate_image`` for the translated text, or
        ``None`` when ``wav_file`` is ``None``.
    """
    # The audio input may be left empty; there is nothing to transcribe then.
    if wav_file is None:
        return None

    # Strip only the final extension. Splitting on the first '.' would cut
    # the path short whenever a directory (or the stem) contains a dot.
    stem, _extension = os.path.splitext(wav_file.name)
    resampled_path = stem + "16k.wav"

    try:
        convert(wav_file.name, resampled_path)
        speech, _ = sf.read(resampled_path)
    finally:
        # The resampled copy is only needed until it has been read; remove it
        # so repeated requests do not accumulate files. Best effort only.
        try:
            os.remove(resampled_path)
        except OSError:
            pass

    # Transcribe. Gradients are never needed here, so skip autograd tracking.
    with torch.no_grad():
        input_values = asr_processor(
            speech, sampling_rate=16_000, return_tensors="pt"
        ).input_values
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)

    # Translate the transcription to English.
    translated = he_en_translator(transcription)[0]['translation_text']

    # Generate the image from the English text.
    return generate_image(translated)
# UI components: an optional microphone recording in, an unlabeled image out.
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
output = gr.outputs.Image(label='')

# Build the interface around parse_transcription, then start it.
demo = gr.Interface(
    fn=parse_transcription,
    inputs=[input_mic],
    outputs=output,
    title="Draw Me A Sheep in Hebrew",
    theme='huggingface',
    layout='horizontal',
    analytics_enabled=False,
    show_tips=False,
    enable_queue=True,
)
demo.launch(inline=False)