# Hugging Face Space: Hebrew speech -> transcription -> translation -> DALL-E mini image.
# --- Imports ----------------------------------------------------------------
import os
import random

import gradio as gr
import jax
import numpy as np
import soundfile as sf
import sox
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCTC, pipeline
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

# --- Global model setup (runs once at import time) --------------------------
# Run the torch models on GPU when available; JAX handles its own placement.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): api_token is read but never used below — presumably intended
# for gated model downloads; confirm before removing.
api_token = os.getenv("API_TOKEN")

# Hebrew speech recognition (wav2vec2 CTC model + its processor).
asr_processor = AutoProcessor.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")
asr_model = AutoModelForCTC.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")

# Hebrew -> English translation pipeline.
he_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-he-en")

# Model references
# dalle-mini: mega is too large for this Space, so use the mini checkpoint.
# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be wandb artifact, HF Hub, local folder or google bucket
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN model used to decode DALL-E token codes back into pixels.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
def generate_image(text):
    """Generate a 256x256 PIL image from an English text prompt with DALL-E mini.

    Parameters
    ----------
    text : str
        English prompt describing the desired image.

    Returns
    -------
    PIL.Image.Image
        The first decoded image of the generated batch.
    """
    tokenized_prompt = processor([text])
    # Sampling hyper-parameters: no top-k/top-p truncation, mildly cooled
    # temperature, classifier-free guidance scale of 3.0.
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0
    encoded_images = model.generate(
        **tokenized_prompt,
        # BUG FIX: randint needs int bounds — 1e7 is a float, which
        # random.randrange rejects on Python 3.12+ (deprecated since 3.10).
        prng_key=jax.random.PRNGKey(random.randint(0, 10**7)),
        params=model.params,
        top_k=gen_top_k,
        top_p=gen_top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )
    # Drop the BOS token, then decode the VQGAN codes into RGB pixels.
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = vqgan.decode_code(encoded_images, vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    img = decoded_images[0]
    # Scale [0, 1] floats to uint8 [0, 255] for PIL.
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))
def convert(inputfile, outfile):
    """Transcode *inputfile* to a 16 kHz mono 16-bit signed WAV at *outfile*."""
    transformer = sox.Transformer()
    transformer.set_output_format(
        file_type="wav",
        channels=1,
        encoding="signed-integer",
        rate=16000,
        bits=16,
    )
    transformer.build(inputfile, outfile)
def parse_transcription(wav_file):
    """Transcribe Hebrew speech, translate it to English, and draw an image.

    Parameters
    ----------
    wav_file : file-like
        Recording from the gradio microphone input; ``wav_file.name`` is the
        path of the temporary audio file on disk.

    Returns
    -------
    tuple
        ``(transcription, image)`` — the Hebrew transcription (str) and the
        generated PIL image.
    """
    # Resample the recording to the 16 kHz mono WAV the ASR model expects.
    # BUG FIX: os.path.splitext instead of split('.')[0], which truncated
    # at the FIRST dot and broke on paths containing extra dots.
    base, _ = os.path.splitext(wav_file.name)
    converted_path = base + "16k.wav"
    convert(wav_file.name, converted_path)
    speech, _ = sf.read(converted_path)

    # Transcribe to hebrew. no_grad avoids building an autograd graph
    # during pure inference.
    input_values = asr_processor(
        speech, sampling_rate=16_000, return_tensors="pt"
    ).input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    print(transcription)

    # Translate to english — the DALL-E checkpoint expects English prompts.
    translated = he_en_translator(transcription)[0]['translation_text']
    print(translated)

    # Generate image from the translated prompt.
    image = generate_image(translated)
    return transcription, image
# --- Gradio UI ---------------------------------------------------------------
# Microphone in; transcript text + generated image out.
outputs = [gr.outputs.Textbox(label="transcript"), gr.outputs.Image(label='')]
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=outputs,
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
).launch(inline=False)