"""Gradio demo: Hebrew speech -> Hebrew transcript -> English text -> image.

Pipeline, as wired below:
  1. Microphone audio is resampled to 16 kHz mono WAV with sox.
  2. A wav2vec2 CTC model transcribes the audio.
  3. A Helsinki-NLP translation pipeline turns the transcript into English.
  4. DALL-E mini + VQGAN generate an image from the English text.
"""

import os
import random

import soundfile as sf
import gradio as gr
import jax
import numpy as np
from PIL import Image
import sox
import torch
from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

# NOTE(review): `device` and `api_token` are computed here but nothing in this
# file uses them (the models below are never moved to `device`) -- confirm
# whether GPU placement / authenticated downloads were intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
api_token = os.getenv("API_TOKEN")

# Speech recognition (Hebrew) and Hebrew -> English translation models.
asr_processor = AutoProcessor.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")
asr_model = AutoModelForCTC.from_pretrained("imvladikon/wav2vec2-xls-r-300m-hebrew")
he_en_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-he-en")

# Model references

# dalle-mini, mega too large
# DALLE_MODEL = "dalle-mini/dalle-mini/mega-1-fp16:latest"
# can be wandb artifact or 🤗 Hub or local folder or google bucket
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN model
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)


def generate_image(text):
    """Generate a single image for ``text`` with DALL-E mini and VQGAN.

    Args:
        text: Prompt handed to the DALL-E processor (English, as produced by
            the translator in ``parse_transcription``).

    Returns:
        A ``PIL.Image.Image`` built from the first decoded sample.
    """
    tokenized_prompt = processor([text])

    # Sampling settings forwarded verbatim to ``model.generate``.
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0

    # Fresh seed per call so repeated prompts give different images. The upper
    # bound must be an int: random.randint rejects floats such as 1e7 on
    # Python 3.12+ (and warned about them from 3.10).
    seed = random.randint(0, 10**7)

    encoded_images = model.generate(
        **tokenized_prompt,
        prng_key=jax.random.PRNGKey(seed),
        params=model.params,
        top_k=gen_top_k,
        top_p=gen_top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )
    # Drop the first position of every generated sequence before decoding.
    encoded_images = encoded_images.sequences[..., 1:]

    decoded_images = vqgan.decode_code(encoded_images, vqgan.params)
    # Clamp to [0, 1] and lay the batch out as (N, 256, 256, 3).
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    img = decoded_images[0]
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))


def convert(inputfile, outfile):
    """Convert ``inputfile`` to a 16 kHz, mono, 16-bit signed-integer WAV.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path the converted WAV is written to.
    """
    sox_tfm = sox.Transformer()
    sox_tfm.set_output_format(
        file_type="wav", channels=1, encoding="signed-integer", rate=16000, bits=16
    )
    sox_tfm.build(inputfile, outfile)


def parse_transcription(wav_file):
    """Transcribe a recording, translate it to English and draw it.

    Args:
        wav_file: File-like object with a ``.name`` path, as supplied by the
            Gradio audio input declared below (``type="file"``).

    Returns:
        A ``(transcription, image)`` tuple: the decoded transcript and the
        image generated from its English translation.

    Raises:
        ValueError: If no recording was supplied (the input is optional, so
            ``wav_file`` can be ``None``).
    """
    if wav_file is None:
        raise ValueError("No audio recording was provided.")

    # Get the wav file from the microphone and resample it for the ASR model.
    # splitext only strips the final extension, so dots elsewhere in the path
    # (e.g. "./rec.wav" or a dotted temp directory) no longer truncate it.
    base_path, _ = os.path.splitext(wav_file.name)
    resampled_path = base_path + "16k.wav"
    try:
        convert(wav_file.name, resampled_path)
        speech, _ = sf.read(resampled_path)
    finally:
        # The resampled copy is only needed for the read above; remove it so a
        # long-running server does not accumulate one file per request.
        if os.path.exists(resampled_path):
            os.remove(resampled_path)

    # transcribe to hebrew
    input_values = asr_processor(
        speech, sampling_rate=16_000, return_tensors="pt"
    ).input_values
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    print(transcription)

    # translate to english
    translated = he_en_translator(transcription)[0]['translation_text']
    print(translated)

    # generate image
    image = generate_image(translated)
    return transcription, image


outputs = [gr.outputs.Textbox(label="transcript"), gr.outputs.Image(label='')]
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)

gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=outputs,
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
).launch(inline=False)