Spaces:
Runtime error
Runtime error
File size: 3,506 Bytes
f7c2e78 6c6d0a0 5a87575 f7c2e78 5a87575 f7c2e78 5a87575 be37091 f7c2e78 077c45d f7c2e78 5a87575 23dd537 6c6d0a0 7c6a43f 5a87575 2d31a01 5a87575 f7c2e78 e8b13db f7c2e78 e8b13db d8ec8f4 f7c2e78 d8ec8f4 be37091 23dd537 e8b13db 23dd537 e8b13db be37091 8ee95b0 f7c2e78 8ee95b0 f7c2e78 8ee95b0 f7c2e78 d8ec8f4 f7c2e78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 |
import soundfile as sf
import gradio as gr
import jax
import numpy as np
import os
from PIL import Image
import random
import sox
import torch
from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel
# Pick the torch device once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Optional token read from the environment; None when API_TOKEN is unset.
api_token = os.environ.get("API_TOKEN")

# Hebrew speech recognition (wav2vec2 CTC) and Hebrew -> English translation.
ASR_CHECKPOINT = "imvladikon/wav2vec2-xls-r-300m-hebrew"
TRANSLATION_CHECKPOINT = "Helsinki-NLP/opus-mt-tc-big-he-en"

asr_processor = AutoProcessor.from_pretrained(ASR_CHECKPOINT)
asr_model = AutoModelForCTC.from_pretrained(ASR_CHECKPOINT)
he_en_translator = pipeline(task="translation", model=TRANSLATION_CHECKPOINT)

# Text-to-image model reference. It can be a wandb artifact, a Hugging Face
# Hub repo, a local folder or a Google bucket. The larger
# "dalle-mini/dalle-mini/mega-1-fp16:latest" checkpoint is too large, so the
# mini checkpoint is used instead.
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN used to decode generated image tokens, pinned to a fixed commit.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
def generate_image(text):
    """Generate a single image from a text prompt.

    The prompt is tokenized with the module-level ``processor``, image
    tokens are sampled with ``model.generate`` using a freshly drawn random
    seed, and the tokens are decoded to pixels with ``vqgan``.

    Args:
        text: The prompt to illustrate.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded image, reshaped
        to 256x256 with 3 channels and scaled to 8-bit values.
    """
    tokenized_prompt = processor([text])

    # Sampling settings passed straight through to ``model.generate``.
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0

    # ``random.randint`` needs integer bounds: a float such as ``1e7`` is
    # deprecated since Python 3.10 and raises TypeError on Python 3.12+.
    seed = random.randint(0, 10**7)

    encoded_images = model.generate(
        **tokenized_prompt,
        prng_key=jax.random.PRNGKey(seed),
        params=model.params,
        top_k=gen_top_k,
        top_p=gen_top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )

    # Drop the first token of every generated sequence before decoding
    # (presumably the start-of-sequence token -- confirm against dalle-mini).
    encoded_images = encoded_images.sequences[..., 1:]

    decoded_images = vqgan.decode_code(encoded_images, vqgan.params)
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))

    # Only the first image of the batch is returned, converted from the
    # clipped [0, 1] range to uint8 pixel values.
    img = decoded_images[0]
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))
def convert(inputfile, outfile):
    """Transcode an audio file to 16 kHz mono 16-bit signed-integer WAV.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path the converted WAV file is written to.
    """
    output_spec = {
        "file_type": "wav",
        "channels": 1,
        "encoding": "signed-integer",
        "rate": 16000,
        "bits": 16,
    }
    transformer = sox.Transformer()
    transformer.set_output_format(**output_spec)
    transformer.build(inputfile, outfile)
def parse_transcription(wav_file):
    """Transcribe a Hebrew recording, translate it and draw the result.

    The recording is resampled to 16 kHz mono with ``convert``, transcribed
    with the Hebrew ASR model, translated to English, and the English text
    is passed to ``generate_image``. The transcription and the translation
    are both printed to stdout.

    Args:
        wav_file: File-like object whose ``name`` attribute is the path of
            the recorded audio. May be ``None`` when nothing was recorded.

    Returns:
        A ``(transcription, image)`` tuple: the Hebrew transcription string
        and the image generated from its English translation.

    Raises:
        ValueError: If ``wav_file`` is ``None``.
    """
    if wav_file is None:
        # The audio input is optional in the UI, so fail with a clear message
        # instead of an AttributeError on ``None.name``.
        raise ValueError("No audio was recorded; please record a clip and try again.")

    # Strip only the extension. Splitting on the first '.' would truncate
    # paths that contain a dot elsewhere (e.g. "./clip.wav" or "/tmp/a.b/c.wav").
    base_path, _ = os.path.splitext(wav_file.name)
    resampled_path = base_path + "16k.wav"

    convert(wav_file.name, resampled_path)
    try:
        speech, _ = sf.read(resampled_path)
    finally:
        # The resampled copy is only needed for reading; do not leave it behind.
        os.remove(resampled_path)

    # transcribe to hebrew
    input_values = asr_processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)
    print(transcription)

    # translate to english
    translated = he_en_translator(transcription)[0]['translation_text']
    print(translated)

    # generate image
    image = generate_image(translated)
    return transcription, image
# Gradio UI: microphone recording in, Hebrew transcript and generated image out.
outputs = [gr.outputs.Textbox(label="transcript"), gr.outputs.Image(label='')]
# The recording is handed to the callback as a file object (type="file").
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)

# Build the interface around parse_transcription and start serving it.
gr.Interface(
    parse_transcription,
    inputs=[input_mic],
    outputs=outputs,
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
).launch(inline=False)