# Hosting-page metadata captured with the file (not code): Amir Zait · fixed · 7c6a43f · raw · history blame · 3.41 kB
import soundfile as sf
import gradio as gr
import numpy as np
import os
from PIL import Image
import random
import sox
import torch
from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Pick the torch device: CUDA when available, CPU otherwise.
# NOTE(review): `device` and `api_token` are not referenced by the code visible
# in this file — confirm whether they are still needed.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Token read from the environment; None when API_TOKEN is unset.
api_token = os.environ.get("API_TOKEN")

# Hebrew speech recognition: processor and CTC model share one checkpoint.
_HEBREW_ASR_CHECKPOINT = "imvladikon/wav2vec2-xls-r-300m-hebrew"
asr_processor = AutoProcessor.from_pretrained(_HEBREW_ASR_CHECKPOINT)
asr_model = AutoModelForCTC.from_pretrained(_HEBREW_ASR_CHECKPOINT)

# Hebrew -> English translation pipeline.
he_en_translator = pipeline(
    "translation",
    model="Helsinki-NLP/opus-mt-tc-big-he-en",
)

# Model references.
# The DALL·E reference can be a wandb artifact, a Hugging Face Hub repo, a
# local folder or a Google bucket. The larger
# "dalle-mini/dalle-mini/mega-1-fp16:latest" checkpoint was judged too large
# here, so the mini checkpoint is used instead.
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN model, pinned to a specific commit.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# Load the generator, the VQGAN and the text processor (in that order).
model = DalleBart.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
)
vqgan = VQModel.from_pretrained(
    VQGAN_REPO,
    revision=VQGAN_COMMIT_ID,
)
processor = DalleBartProcessor.from_pretrained(
    DALLE_MODEL,
    revision=DALLE_COMMIT_ID,
)
def generate_image(text):
    """Generate one image from a text prompt with the DALL·E mini model.

    Args:
        text: Prompt string; it is wrapped in a one-element list and passed to
            the module-level ``processor``.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded sample, after the
        decoded batch is clipped to [0, 1] and reshaped to 256x256x3.
    """
    # Local import: the file's import block does not bring ``jax`` into scope.
    import jax

    tokenized_prompt = processor([text])

    # Sampling settings; top-k / top-p filtering is left disabled.
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0

    # ``random.randint`` needs integer bounds (``1e7`` is a float and is
    # rejected by recent Python versions), and the sampler expects a JAX PRNG
    # key rather than a bare integer seed.
    seed = random.randint(0, 10**7)
    prng_key = jax.random.PRNGKey(seed)

    encoded_images = model.generate(
        **tokenized_prompt,
        prng_key=prng_key,
        params=model.params,
        top_k=gen_top_k,
        top_p=gen_top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )

    # Drop the first token of every generated sequence (presumably BOS —
    # TODO confirm) so only image codes remain.
    encoded_images = encoded_images.sequences[..., 1:]

    # Image codes are turned into pixels by the VQGAN. ``model.decode`` is the
    # BART text-decoder step and does not take VQGAN parameters.
    decoded_images = vqgan.decode_code(encoded_images, params=vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    # Only the first sample is returned; scale [0, 1] floats to 8-bit pixels.
    img = decoded_images[0]
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))
def convert(inputfile, outfile):
    """Transcode *inputfile* with SoX and write the result to *outfile*.

    The output is requested as a mono, 16 kHz, 16-bit signed-integer WAV file.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path of the WAV file to write.
    """
    output_format = {
        "file_type": "wav",
        "channels": 1,
        "encoding": "signed-integer",
        "rate": 16000,
        "bits": 16,
    }
    transformer = sox.Transformer()
    transformer.set_output_format(**output_format)
    transformer.build(inputfile, outfile)
def parse_transcription(wav_file):
    """Turn a recorded Hebrew utterance into a generated image.

    Pipeline: resample the recording to 16 kHz mono WAV, transcribe it with
    the Hebrew ASR model, translate the transcription to English, then pass
    the English text to ``generate_image``.

    Args:
        wav_file: File-like object exposing a ``.name`` path to the recorded
            audio, or ``None`` when nothing was recorded (the microphone input
            is declared optional).

    Returns:
        The image returned by ``generate_image`` for the translated text.

    Raises:
        ValueError: If ``wav_file`` is ``None``.
    """
    if wav_file is None:
        # The input is optional, so fail with a clear message instead of an
        # AttributeError on ``None.name``.
        raise ValueError("No audio was recorded; please record a sentence first.")

    # Strip only the final extension. Splitting on the first '.' would
    # truncate any path that contains another dot (e.g. a dotted directory).
    base_path, _ = os.path.splitext(wav_file.name)
    resampled_path = base_path + "16k.wav"

    try:
        convert(wav_file.name, resampled_path)
        speech, _ = sf.read(resampled_path)
    finally:
        # The resampled copy is only needed for the read above; remove it so
        # temporary files do not pile up across requests.
        if os.path.exists(resampled_path):
            os.remove(resampled_path)

    # Transcribe to Hebrew. Inference only, so skip autograd bookkeeping.
    input_values = asr_processor(
        speech, sampling_rate=16_000, return_tensors="pt"
    ).input_values
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    print(transcription)

    # Translate to English.
    translated = he_en_translator(transcription)[0]['translation_text']
    print(translated)

    # Generate the image from the English text.
    return generate_image(translated)
# UI wiring: one microphone recording in, one generated image out.
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
output = gr.outputs.Image(label='')

# Build the interface first, then start it as a separate step.
interface = gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=output,
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
)
interface.launch(inline=False)