Amir Zait
bugfix
2d31a01
raw
history blame
3.51 kB
import soundfile as sf
import gradio as gr
import jax
import numpy as np
import os
from PIL import Image
import random
import sox
import torch
from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Torch device selection. Nothing in this file moves a model or tensor onto it,
# so it is currently informational only.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Optional API token read from the environment; not referenced elsewhere in this file.
api_token = os.getenv("API_TOKEN")

# Speech recognition (Hebrew wav2vec2 CTC model) and Hebrew->English translation.
# Repo IDs are kept as named constants, like the DALL-E / VQGAN references below.
ASR_MODEL = "imvladikon/wav2vec2-xls-r-300m-hebrew"
HE_EN_TRANSLATION_MODEL = "Helsinki-NLP/opus-mt-tc-big-he-en"
asr_processor = AutoProcessor.from_pretrained(ASR_MODEL)
asr_model = AutoModelForCTC.from_pretrained(ASR_MODEL)
he_en_translator = pipeline("translation", model=HE_EN_TRANSLATION_MODEL)

# Model references
# dalle-mini; the "mega" checkpoint is too large to use here.
# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"  # can be a wandb artifact, Hugging Face Hub repo, local folder or Google bucket
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None  # no pinned revision for the DALL-E model
# VQGAN model (decodes the generated image codes into pixels), pinned to a commit.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
def generate_image(
    text,
    *,
    seed=None,
    temperature=0.85,
    cond_scale=3.0,
    top_k=None,
    top_p=None,
):
    """Generate one image for an English text prompt with DALL-E mini + VQGAN.

    The prompt is tokenized by ``processor``. ``model.generate`` samples image
    codes, and ``vqgan.decode_code`` turns those codes into pixels.

    Args:
        text: Prompt text (the caller passes the English translation).
        seed: Seed for the JAX PRNG key. When None, a fresh random seed is
            drawn, so repeated calls with the same prompt give different images.
        temperature: Sampling temperature forwarded to ``model.generate``.
        cond_scale: Forwarded as ``condition_scale`` to ``model.generate``.
        top_k: Optional top-k sampling filter (None leaves it unset).
        top_p: Optional top-p sampling filter (None leaves it unset).

    Returns:
        PIL.Image.Image: the first decoded image, 256x256 with 3 channels,
        8-bit.
    """
    if seed is None:
        # randint needs integer bounds: a float bound such as 1e7 is deprecated
        # on Python 3.10/3.11 and raises TypeError on 3.12+.
        seed = random.randint(0, 10_000_000)

    tokenized_prompt = processor([text])
    encoded_images = model.generate(
        **tokenized_prompt,
        prng_key=jax.random.PRNGKey(seed),
        params=model.params,
        top_k=top_k,
        top_p=top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )
    # Drop the first position of each sequence (presumably the decoder start
    # token) so only image codes are passed to the VQGAN decoder.
    image_codes = encoded_images.sequences[..., 1:]
    decoded_images = vqgan.decode_code(image_codes, vqgan.params)
    # Clip to [0, 1] and lay out as (N, 256, 256, 3) before converting to 8-bit.
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    first_image = np.asarray(decoded_images[0] * 255, dtype=np.uint8)
    return Image.fromarray(first_image)
def convert(inputfile, outfile):
    """Re-encode an audio file as a 16 kHz, mono, 16-bit signed-integer WAV.

    Args:
        inputfile: Path of the source audio file.
        outfile: Path the converted WAV is written to.
    """
    wav_16k_mono = dict(
        file_type="wav",
        channels=1,
        encoding="signed-integer",
        rate=16000,
        bits=16,
    )
    transformer = sox.Transformer()
    transformer.set_output_format(**wav_16k_mono)
    transformer.build(inputfile, outfile)
def parse_transcription(wav_file):
    """Transcribe a Hebrew recording, translate it to English and draw it.

    Args:
        wav_file: The microphone upload from Gradio. It is an object whose
            ``.name`` attribute is the recorded file's path, or None when the
            optional input was left empty.

    Returns:
        tuple: ``(transcription, image)``, i.e. the Hebrew transcription and
        the PIL image generated from its English translation. Returns
        ``("", None)`` when no recording was provided.
    """
    if wav_file is None:
        # The microphone input is declared optional, so there may be nothing to process.
        return "", None

    # Resample to 16 kHz mono next to the upload. splitext strips only the real
    # extension; split('.') would cut at the first dot anywhere in the path.
    stem, _ = os.path.splitext(wav_file.name)
    resampled_path = stem + "16k.wav"
    try:
        convert(wav_file.name, resampled_path)
        speech, _ = sf.read(resampled_path)
    finally:
        # The resampled copy is only needed for this read; don't leave it on disk.
        if os.path.exists(resampled_path):
            os.remove(resampled_path)

    # Transcribe to Hebrew: CTC logits, greedy argmax, then decode to text.
    input_values = asr_processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    with torch.no_grad():  # inference only; don't build an autograd graph
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    print(transcription)

    # Translate the transcription to English.
    translated = he_en_translator(transcription)[0]['translation_text']
    print(translated)

    # Generate an image from the English text.
    image = generate_image(translated)
    return transcription, image
# Gradio UI: one optional microphone recording in; the transcript text and the
# generated image out. parse_transcription handles each submission.
outputs = [
    gr.outputs.Textbox(label="transcript"),
    gr.outputs.Image(label=''),
]
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)

interface_options = dict(
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
)
demo = gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=outputs,
    **interface_options,
)
demo.launch(inline=False)