import os
from typing import Literal, Optional

import gradio as gr
import torch
import whisper
from diffusers import StableDiffusionPipeline
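
#####################################################
# Small helpers: device selection and Hub authentication.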
def get_device() -> Literal["cuda", "cpu"]:
    # Prefer the GPU when one is available; generation on CPU is very slow.
    return "cuda" if torch.cuda.is_available() else "cpu"

def get_token() -> Optional[str]:
    # Returns None when the HUGGING_FACE_TOKEN secret is not set on the Space.
    return os.environ.get("HUGGING_FACE_TOKEN")
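
#####################################################
# Stable Diffusion: turn the transcribed prompt into a batch of images.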
def generate_images(prompt: str, scale: float, iterations: int, seed: int, num_images: int) -> list:
    # The v1.4 weights are gated on the Hub, so pass the token from the Space secret.
    auth_token = get_token()
    device = get_device()
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4", use_auth_token=auth_token
    )
    pipe.to(device)
    # A seeded generator makes the output reproducible for a given seed value.
    generator = torch.Generator(device).manual_seed(int(seed))
    # The pipeline accepts a list of prompts, one per requested image.
    prompts = [prompt] * int(num_images)
    images = pipe(
        prompts,
        num_inference_steps=int(iterations),
        guidance_scale=scale,
        generator=generator,
    ).images
    # Save each image to disk and hand the filenames to the gallery.
    output_file_names = []
    for idx, image in enumerate(images):
        filename = f"output{idx}.png"
        image.save(filename)
        output_file_names.append(filename)
    return output_file_names
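
#####################################################
# Whisper: transcribe the recording and, for non-English speech, translate it.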
def transcribe_audio(model_selected: str, audio_input: str) -> tuple:
    # gr.Audio(type="filepath") passes the path of the recorded file.
    model = whisper.load_model(model_selected)
    audio = whisper.load_audio(audio_input)
    audio = whisper.pad_or_trim(audio)
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    transcript_options = whisper.DecodingOptions(task="transcribe", fp16=False)
    transcription = whisper.decode(model, mel, transcript_options)
    translation_output = ""
    prompt_for_sd = transcription.text
    # Stable Diffusion works best with English prompts, so translate
    # non-English speech and use the translation as the prompt instead.
    if transcription.language != "en":
        translation_options = whisper.DecodingOptions(task="translate", fp16=False)
        translation = whisper.decode(model, mel, translation_options)
        translation_output = translation.text
        prompt_for_sd = translation_output
    return transcription.text, translation_output, str(transcription.language).upper(), prompt_for_sd
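
#####################################################
# Gradio UI: one tab to record/transcribe, one tab to tweak and generate.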
with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 90%; margin: 0 auto;">
            <div>
                <h1>Whisper App</h1>
            </div>
            <p style="margin-bottom: 10px; font-size: 100%">
                Record an audio prompt, transcribe it with OpenAI Whisper, and generate images with Stable Diffusion!
            </p>
        </div>
        """
    )
    with gr.Row():
        with gr.Accordion(label="Whisper model selection"):
            with gr.Row():
                model_selection_radio = gr.Radio(
                    ["base", "small", "medium", "large"],
                    value="medium",
                    interactive=True,
                    label="Model",
                )
with gr.Tab("Record Prompt"):
with gr.Row():
recorded_audio_input = gr.Audio(source="microphone", type="filepath", label="Record your prompt to feed to Stable Diffusion!")
audio_transcribe_btn = gr.Button("Launch Whisper")
with gr.Row():
transcribed_output_box = gr.TextArea(interactive=False, label="Transcription", placeholder="Transcription will appear here")
translated_output_box = gr.TextArea(interactive=True, label="Translated prompt")
detected_language_box = gr.Textbox(interactive=False, label="Detected Language")
with gr.Tab("Stable Diffusion"):
with gr.Row():
prompt_box = gr.TextArea(interactive=False, label="Prompt")
with gr.Row():
guidance_slider = gr.Slider(2, 15, value = 7, label = 'Guidance Scale', interactive=True)
iterations_slider = gr.Slider(10, 100, value = 25, step = 1, label = 'Number of Iterations', interactive=True)
seed_slider = gr.Slider(
label = "Seed",
minimum = 0,
maximum = 2147483647,
step = 1,
randomize = True,
interactive=True)
num_images_slider = gr.Slider(2, 8, value= 2, label = "Number of Images Asked", interactive=True)
with gr.Row():
images_gallery = gr.Gallery(label="Generated Images").style(grid=[2])
with gr.Row():
generate_image_btn = gr.Button("Generate Images")

    #####################################################
    # Wire the buttons to their handlers.
    audio_transcribe_btn.click(
        transcribe_audio,
        inputs=[
            model_selection_radio,
            recorded_audio_input,
        ],
        outputs=[
            transcribed_output_box,
            translated_output_box,
            detected_language_box,
            prompt_box,
        ],
    )
    generate_image_btn.click(
        generate_images,
        inputs=[
            prompt_box,
            guidance_slider,
            iterations_slider,
            seed_slider,
            num_images_slider,
        ],
        outputs=images_gallery,
    )
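
# enable_queue routes requests through Gradio's queue so that long Whisper and
# Stable Diffusion runs do not hit HTTP request timeouts.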
demo.launch(enable_queue=True, debug=True)