endyaris commited on
Commit
db8548c
·
verified ·
1 Parent(s): c2a9e07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -14
app.py CHANGED
@@ -2,27 +2,50 @@ import gradio as gr
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
 
5
- # Load the Stable Diffusion pipeline
6
- pipe = StableDiffusionPipeline.from_pretrained(
7
- "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
8
- )
9
- pipe = pipe.to("cuda") # Use GPU for faster generation
 
10
 
11
- def generate_image(prompt):
12
- # Generate an image based on the text prompt
13
- image = pipe(prompt).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  return image
15
 
16
- # Create the Gradio interface
17
  with gr.Blocks() as demo:
18
- gr.Markdown("### Text-to-Image Generator")
19
  with gr.Row():
20
  with gr.Column():
21
- text_input = gr.Textbox(label="Enter a text prompt")
22
- submit_button = gr.Button("Generate Image")
 
 
 
23
  with gr.Column():
24
  output_image = gr.Image(label="Generated Image")
25
-
26
- submit_button.click(generate_image, inputs=text_input, outputs=output_image)
 
 
27
 
28
  demo.launch()
 
2
  from diffusers import StableDiffusionPipeline
3
  import torch
4
 
5
+ # Load different models
6
+ models = {
7
+ "Stable Diffusion v1.5": "runwayml/stable-diffusion-v1-5",
8
+ "Stable Diffusion v2.1": "stabilityai/stable-diffusion-2-1",
9
+ "Anime Diffusion": "hakurei/waifu-diffusion-v1-4",
10
+ }
11
 
12
+ # Function to load the selected model
13
+ def load_model(model_name):
14
+ model_id = models[model_name]
15
+ pipe = StableDiffusionPipeline.from_pretrained(
16
+ model_id, torch_dtype=torch.float16
17
+ )
18
+ pipe = pipe.to("cuda") # Use GPU
19
+ return pipe
20
+
21
+ # Load the default model
22
+ current_pipe = load_model("Stable Diffusion v1.5")
23
+
24
+ # Function to generate image
25
+ def generate_image(prompt, model_name):
26
+ global current_pipe
27
+ # Reload pipeline if the model changes
28
+ if model_name not in current_pipe.config["_name_or_path"]:
29
+ current_pipe = load_model(model_name)
30
+ # Generate the image
31
+ image = current_pipe(prompt).images[0]
32
  return image
33
 
34
+ # Create Gradio interface
35
  with gr.Blocks() as demo:
36
+ gr.Markdown("### Multi-Model Text-to-Image Generator")
37
  with gr.Row():
38
  with gr.Column():
39
+ text_input = gr.Textbox(label="Enter a text prompt", placeholder="Describe the image you want...")
40
+ model_selector = gr.Dropdown(
41
+ label="Select Model", choices=list(models.keys()), value="Stable Diffusion v1.5"
42
+ )
43
+ generate_button = gr.Button("Generate Image")
44
  with gr.Column():
45
  output_image = gr.Image(label="Generated Image")
46
+
47
+ generate_button.click(
48
+ generate_image, inputs=[text_input, model_selector], outputs=output_image
49
+ )
50
 
51
  demo.launch()