File size: 3,411 Bytes
f7c2e78
 
5a87575
f7c2e78
5a87575
 
 
 
f7c2e78
5a87575
 
 
 
be37091
f7c2e78
 
 
 
 
 
077c45d
f7c2e78
5a87575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23dd537
7c6a43f
 
 
 
 
 
5a87575
 
 
 
 
 
 
f7c2e78
 
 
 
 
 
 
 
e8b13db
f7c2e78
 
 
e8b13db
 
d8ec8f4
 
f7c2e78
d8ec8f4
be37091
23dd537
 
e8b13db
 
 
23dd537
 
e8b13db
be37091
 
f7c2e78
be37091
f7c2e78
 
 
 
 
 
 
d8ec8f4
f7c2e78
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import soundfile as sf
import gradio as gr
import numpy as np
import os
from PIL import Image
import random
import sox
import torch

from transformers import AutoProcessor, AutoModelForCTC
from transformers import pipeline
from dalle_mini import DalleBart, DalleBartProcessor
from vqgan_jax.modeling_flax_vqgan import VQModel

# ---------------------------------------------------------------------------
# Model identifiers
# ---------------------------------------------------------------------------
# Hebrew speech recognition (wav2vec2, CTC head) and Hebrew -> English translation.
ASR_MODEL_ID = "imvladikon/wav2vec2-xls-r-300m-hebrew"
TRANSLATION_MODEL_ID = "Helsinki-NLP/opus-mt-tc-big-he-en"

# DALL-E mini checkpoint. The reference may be a wandb artifact, a Hugging Face
# Hub repo, a local folder or a Google bucket. The "mega" checkpoint
# ("dalle-mini/dalle-mini/mega-1-fp16:latest") was considered too large here,
# so the "mini" one is used instead.
DALLE_MODEL = "dalle-mini/dalle-mini/mini-1:v0"
DALLE_COMMIT_ID = None

# VQGAN checkpoint, pinned to a fixed commit.
VQGAN_REPO = "dalle-mini/vqgan_imagenet_f16_16384"
VQGAN_COMMIT_ID = "e93a26e7707683d349bf5d5c41c5b0ef69b677a9"

# ---------------------------------------------------------------------------
# Runtime objects (loaded once at import time)
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
api_token = os.getenv("API_TOKEN")

# Speech-to-text: feature processor plus CTC model.
asr_processor = AutoProcessor.from_pretrained(ASR_MODEL_ID)
asr_model = AutoModelForCTC.from_pretrained(ASR_MODEL_ID)

# Hebrew -> English translation pipeline.
he_en_translator = pipeline("translation", model=TRANSLATION_MODEL_ID)

# Text-to-image: DALL-E mini, its VQGAN weights, and the prompt processor.
model = DalleBart.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)
vqgan = VQModel.from_pretrained(VQGAN_REPO, revision=VQGAN_COMMIT_ID)
processor = DalleBartProcessor.from_pretrained(DALLE_MODEL, revision=DALLE_COMMIT_ID)

def generate_image(text):
    """Generate one 256x256 RGB image from a text prompt.

    The prompt is tokenized with the module-level ``processor``, image tokens
    are sampled with ``model.generate`` and decoded to pixels with
    ``model.decode`` using the VQGAN parameters.

    Args:
        text: Prompt describing the image to draw.

    Returns:
        A ``PIL.Image.Image`` built from the first decoded sample.
    """
    # Function-scope import: the file does not otherwise import JAX directly.
    import jax

    tokenized_prompt = processor([text])

    # Sampling settings forwarded to ``model.generate``.
    gen_top_k = None
    gen_top_p = None
    temperature = 0.85
    cond_scale = 3.0

    # ``random.randint`` needs integer bounds: a float bound such as ``1e7``
    # is deprecated since Python 3.10 and raises TypeError on 3.12+.
    seed = random.randint(0, 10**7)
    # The sampler is given a JAX PRNG key rather than a bare integer seed.
    prng_key = jax.random.PRNGKey(seed)

    encoded_images = model.generate(
        **tokenized_prompt,
        prng_key=prng_key,
        params=model.params,
        top_k=gen_top_k,
        top_p=gen_top_p,
        temperature=temperature,
        condition_scale=cond_scale,
    )
    # Drop the first position of every sequence (presumably the BOS token --
    # confirm against dalle-mini) before decoding.
    encoded_images = encoded_images.sequences[..., 1:]
    decoded_images = model.decode(encoded_images, vqgan.params)
    # Clamp to [0, 1] and lay the samples out as (N, 256, 256, 3).
    decoded_images = decoded_images.clip(0.0, 1.0).reshape((-1, 256, 256, 3))
    img = decoded_images[0]
    # Scale to 8-bit pixel values for PIL.
    return Image.fromarray(np.asarray(img * 255, dtype=np.uint8))

def convert(inputfile, outfile):
    """Re-encode an audio file as 16 kHz, mono, 16-bit signed-integer WAV.

    Args:
        inputfile: Path of the audio file to read.
        outfile: Path the converted WAV file is written to.
    """
    target_format = {
        "file_type": "wav",
        "channels": 1,
        "encoding": "signed-integer",
        "rate": 16000,
        "bits": 16,
    }
    transformer = sox.Transformer()
    transformer.set_output_format(**target_format)
    transformer.build(inputfile, outfile)

def parse_transcription(wav_file):
    """Turn a recorded Hebrew utterance into a generated image.

    Steps: convert the recording to 16 kHz mono WAV, transcribe it with the
    ASR model, translate the transcript to English, and pass the translation
    to ``generate_image``.

    Args:
        wav_file: Object exposing a ``name`` attribute holding the path of the
            recorded audio file.

    Returns:
        The image returned by ``generate_image`` for the translated text.

    Raises:
        ValueError: If ``wav_file`` is ``None`` (nothing was recorded).
    """
    if wav_file is None:
        # The microphone input is declared optional, so this callback can be
        # invoked without audio; fail with a clear message instead of an
        # AttributeError on ``None.name``.
        raise ValueError("No audio recording was provided.")

    # ``splitext`` strips only the final extension, so dots elsewhere in the
    # path (e.g. in a directory name or a leading "./") do not truncate it.
    base_path, _ = os.path.splitext(wav_file.name)
    resampled_path = base_path + "16k.wav"

    try:
        convert(wav_file.name, resampled_path)
        speech, _ = sf.read(resampled_path)
    finally:
        # The converted copy is only needed until it has been read.
        if os.path.exists(resampled_path):
            os.remove(resampled_path)

    # transcribe to hebrew
    input_values = asr_processor(speech, sampling_rate=16_000, return_tensors="pt").input_values
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = asr_model(input_values).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = asr_processor.decode(predicted_ids[0], skip_special_tokens=True)

    print(transcription)

    # translate to english
    translated = he_en_translator(transcription)[0]['translation_text']

    print(translated)

    # generate image
    return generate_image(translated)

# Gradio UI: a microphone recording goes in, the generated image comes out.
input_mic = gr.inputs.Audio(source="microphone", type="file", optional=True)
output = gr.outputs.Image(label='')

demo = gr.Interface(
    fn=parse_transcription,
    inputs=[input_mic],
    outputs=output,
    analytics_enabled=False,
    show_tips=False,
    theme='huggingface',
    layout='horizontal',
    title="Draw Me A Sheep in Hebrew",
    enable_queue=True,
)
demo.launch(inline=False)