|
from fix_cache import remove_old_cache |
|
import gradio as gr |
|
from diffusers import StableDiffusionPipeline |
|
import torch |
|
|
|
# Housekeeping that runs before any model is loaded further down.
# remove_old_cache comes from the local fix_cache module (presumably prunes
# stale downloaded model/cache files — TODO confirm against fix_cache).
remove_old_cache()
|
|
|
# Dropdown label -> Hugging Face Hub repo id. load_model passes the repo id to
# StableDiffusionPipeline.from_pretrained; the labels are what the UI shows.
models = {
    # The original "runwayml/stable-diffusion-v1-5" repo was taken down from
    # the Hub; this mirror is the id the diffusers documentation now uses.
    "Stable Diffusion v1.5": "stable-diffusion-v1-5/stable-diffusion-v1-5",
    "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
    # NOTE(review): confirm this repo ships diffusers-format weights
    # (model_index.json); from_pretrained cannot load a checkpoint-only repo.
    "Anime Diffusion": "hakurei/waifu-diffusion-v1-4",
}
|
|
|
|
|
def load_model(model_name, device="cpu"):
    """Load the Stable Diffusion pipeline registered under a dropdown label.

    Args:
        model_name: A key of the module-level ``models`` dict (the label shown
            in the UI, not the Hub repo id).
        device: Torch device string the pipeline is moved to. Defaults to
            ``"cpu"``, which is what this function always used before.

    Returns:
        A ``StableDiffusionPipeline`` placed on ``device``.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    try:
        model_id = models[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {list(models)}"
        ) from None

    # float16 weights are only a good fit for CUDA. On CPU, some PyTorch
    # kernels have no half-precision implementation (depending on the torch
    # version) and the rest are far slower than float32, so fall back to
    # full precision everywhere else.
    is_cuda = torch.device(device).type == "cuda"
    dtype = torch.float16 if is_cuda else torch.float32

    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=dtype)
    return pipe.to(device)
|
|
|
|
|
# Load the default model once, at import time, so it is ready before the UI
# starts; generate_image reads this global and may replace it.
current_pipe = load_model(model_name="Stable Diffusion v1.5")
|
|
|
|
|
def generate_image(prompt, model_name):
    """Generate one image for ``prompt`` with the model chosen in the UI.

    The module-level ``current_pipe`` is reused while the same model stays
    selected. It is replaced only when ``model_name`` maps to a different Hub
    repo id than the one the current pipeline was loaded from.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        model_name: A key of ``models`` (the dropdown label, not the repo id).

    Returns:
        The first entry of the pipeline output's ``images`` list.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    import gc

    global current_pipe

    wanted_id = models[model_name]
    loaded_id = (
        current_pipe.config.get("_name_or_path")
        if current_pipe is not None
        else None
    )

    # Compare repo ids with each other. The previous check searched for the
    # display label inside the repo id, which never matches, so the model was
    # reloaded from disk on every single request.
    if loaded_id != wanted_id:
        # Drop the old pipeline before loading the next one so two full
        # pipelines are never held in memory at the same time. If loading
        # fails, current_pipe stays None and the next call retries the load.
        current_pipe = None
        gc.collect()
        current_pipe = load_model(model_name)

    return current_pipe(prompt).images[0]
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("### Multi-Model Text-to-Image Generator")

    with gr.Row():
        # Inputs: prompt text, model choice and the trigger button.
        with gr.Column():
            text_input = gr.Textbox(
                label="Enter a text prompt",
                placeholder="Describe the image you want...",
            )
            model_selector = gr.Dropdown(
                label="Select Model",
                choices=list(models.keys()),
                value="Stable Diffusion v1.5",
            )
            generate_button = gr.Button("Generate Image")

        # Receives the image returned by generate_image.
        with gr.Column():
            output_image = gr.Image(label="Generated Image")

        # NOTE(review): no event in this file writes to output_image2, so this
        # panel always stays empty — wire it to an output or remove it.
        with gr.Column():
            output_image2 = gr.Image(label="Generated Image 2")

    # Event listeners must be registered inside the Blocks context.
    generate_button.click(
        generate_image,
        inputs=[text_input, model_selector],
        outputs=output_image,
    )

# Enable the request queue explicitly. CPU generation can take minutes, and
# older Gradio versions time out long requests that bypass the queue. The
# queue also runs one event at a time by default, which keeps concurrent
# requests from racing on the global current_pipe inside generate_image.
demo.queue()
demo.launch(share=True)