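"""Whisper App: record a voice prompt, transcribe (and, for non-English speech,
translate) it with OpenAI Whisper, then use the text as a Stable Diffusion prompt."""
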
import gradio as gr
import whisper
import torch
import os
from diffusers import StableDiffusionPipeline
from typing import Literal, Optional
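
# Helper functions: pick the torch device and read the Hugging Face token
# (the HUGGING_FACE_TOKEN Space secret) used to download the Stable Diffusion weights.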
def get_device() -> Literal['cuda', 'cpu']:
    return "cuda" if torch.cuda.is_available() else "cpu"

def get_token() -> Optional[str]:
    # Returns None when the HUGGING_FACE_TOKEN secret is not set.
    return os.environ.get("HUGGING_FACE_TOKEN")
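
# Text-to-image: run Stable Diffusion v1.4 on the (possibly translated) prompt
# and save each generated image to disk so the Gallery can display it.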
def generate_images(prompt: str, scale: float, iterations: int, seed: int, num_images: int) -> list:
    auth_token = get_token()
    device = get_device()
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        use_auth_token=auth_token,
    )
    pipe.to(device)
    # Gradio sliders can deliver floats, so cast before using the values as counts/seeds.
    generator = torch.Generator(device).manual_seed(int(seed))
    prompts = [prompt] * int(num_images)
    images = pipe(
        prompts,
        num_inference_steps=int(iterations),
        guidance_scale=scale,
        generator=generator,
    ).images
    output_file_names = []
    for idx, image in enumerate(images):
        filename = f"output{idx}.png"
        image.save(filename)
        output_file_names.append(filename)
    return output_file_names
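
# Speech-to-text: transcribe the recorded audio with Whisper and, if the detected
# language is not English, also translate it so the Stable Diffusion prompt is in English.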
def transcribe_audio(model_selected: str, audio_input: str) -> tuple:
    # `audio_input` is a file path because gr.Audio is configured with type="filepath".
    model = whisper.load_model(model_selected)
    audio = whisper.load_audio(audio_input)
    audio = whisper.pad_or_trim(audio)
    translation_output = ""
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    transcript_options = whisper.DecodingOptions(task="transcribe", fp16=False)
    transcription = whisper.decode(model, mel, transcript_options)
    prompt_for_sd = transcription.text
    if transcription.language != "en":
        translation_options = whisper.DecodingOptions(task="translate", fp16=False)
        translation = whisper.decode(model, mel, translation_options)
        translation_output = translation.text
        prompt_for_sd = translation_output
    return transcription.text, translation_output, str(transcription.language).upper(), prompt_for_sd
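
# Gradio UI: a Whisper model selector, a "Record Prompt" tab for transcription,
# and a "Stable Diffusion" tab for generation settings and the output gallery.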
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 90%; margin: 0 auto;">
          <div>
            <h1>Whisper App</h1>
          </div>
          <p style="margin-bottom: 10px; font-size: 100%">
            Try OpenAI Whisper on a recorded audio prompt to generate images with Stable Diffusion!
          </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Accordion(label="Whisper model selection"):
            with gr.Row():
                model_selection_radio = gr.Radio(['base', 'small', 'medium', 'large'], value='medium', interactive=True, label="Model")
with gr.Tab("Record Prompt"): | |
with gr.Row(): | |
recorded_audio_input = gr.Audio(source="microphone", type="filepath", label="Record your prompt to feed to Stable Diffusion!") | |
audio_transcribe_btn = gr.Button("Launch Whisper") | |
with gr.Row(): | |
transcribed_output_box = gr.TextArea(interactive=False, label="Transcription", placeholder="Transcription will appear here") | |
translated_output_box = gr.TextArea(interactive=True, label="Translated prompt") | |
detected_language_box = gr.Textbox(interactive=False, label="Detected Language") | |
with gr.Tab("Stable Diffusion"): | |
with gr.Row(): | |
prompt_box = gr.TextArea(interactive=False, label="Prompt") | |
with gr.Row(): | |
guidance_slider = gr.Slider(2, 15, value = 7, label = 'Guidance Scale', interactive=True) | |
iterations_slider = gr.Slider(10, 100, value = 25, step = 1, label = 'Number of Iterations', interactive=True) | |
seed_slider = gr.Slider( | |
label = "Seed", | |
minimum = 0, | |
maximum = 2147483647, | |
step = 1, | |
randomize = True, | |
interactive=True) | |
num_images_slider = gr.Slider(2, 8, value= 2, label = "Number of Images Asked", interactive=True) | |
with gr.Row(): | |
images_gallery = gr.Gallery(label="Generated Images").style(grid=[2]) | |
with gr.Row(): | |
generate_image_btn = gr.Button("Generate Images") | |
    #####################################################
    # Wire the buttons to the Whisper and Stable Diffusion functions.
    audio_transcribe_btn.click(
        transcribe_audio,
        inputs=[model_selection_radio, recorded_audio_input],
        outputs=[transcribed_output_box, translated_output_box, detected_language_box, prompt_box],
    )
    generate_image_btn.click(
        generate_images,
        inputs=[prompt_box, guidance_slider, iterations_slider, seed_slider, num_images_slider],
        outputs=images_gallery,
    )

demo.launch(enable_queue=True, debug=True)
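
# To run outside the Space (a sketch, assuming a comparable environment): set the
# HUGGING_FACE_TOKEN environment variable, install gradio, torch, diffusers,
# openai-whisper and ffmpeg, then run `python app.py` (adjust if the file is named differently).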