joshuaberkowitzus committed on
Commit 36c749c · verified · 1 Parent(s): 938487f
Files changed (1)
  1. app.py +116 -0
app.py ADDED
@@ -0,0 +1,116 @@
import gradio as gr
import torch
from diffusers import DiffusionPipeline
from PIL import Image, ImageDraw, ImageFont

# Ensure the necessary libraries are installed:
# pip install diffusers --upgrade
# pip install invisible_watermark transformers accelerate safetensors gradio torch

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Determine device and dtype
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
    print("Using CUDA (GPU).")
# elif torch.backends.mps.is_available():  # Uncomment for macOS Metal support
#     device = "mps"
#     dtype = torch.float16
#     print("Using MPS (Apple Silicon GPU).")
else:
    device = "cpu"
    dtype = torch.float32
    print("Using CPU.")
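
# Note (an assumption worth verifying, not a claim from this commit): float16 on
# MPS has been unreliable with some torch/diffusers versions; if outputs look
# wrong there, fall back to torch.float32.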
25

# Load the Stable Diffusion XL pipeline.
# float16 and safetensors are used for efficiency on GPU;
# variant="fp16" loads the fp16 weights.
try:
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
        variant="fp16" if device != "cpu" else None,  # fp16 variant only off CPU
    )
    pipe.to(device)

    # Optional: enable CPU offloading if VRAM is limited (only works on CUDA)
    if device == "cuda":
        try:
            # Check VRAM -- a rough estimate; adjust the threshold as needed
            total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            if total_vram_gb < 10:  # example threshold: less than 10 GB VRAM
                print(f"Low VRAM detected ({total_vram_gb:.2f} GB). Enabling model CPU offload.")
                pipe.enable_model_cpu_offload()
        except Exception as offload_err:
            print(f"Could not check VRAM or enable offload: {offload_err}")

    # Optional: use torch.compile for a speedup (requires torch >= 2.0)
    # if device != "cpu" and hasattr(torch, "compile"):
    #     try:
    #         print("Attempting to compile the UNet...")
    #         pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    #         print("UNet compiled successfully.")
    #     except Exception as compile_err:
    #         print(f"Torch compile failed: {compile_err}")
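
    # Optional memory savers (a minimal sketch; both are standard diffusers
    # pipeline methods, not part of this commit). Slicing trades some speed
    # for lower peak memory; uncomment if generation runs out of memory.
    # pipe.enable_attention_slicing()
    # pipe.enable_vae_slicing()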
59

    print(f"SDXL pipeline loaded successfully on {device}.")

except Exception as e:
    print(f"Error loading SDXL pipeline: {e}")
    pipe = None

65

def _error_image(message):
    """Return a placeholder image that displays an error message."""
    img = Image.new("RGB", (512, 512), color=(200, 200, 200))
    draw = ImageDraw.Draw(img)
    try:
        # Try to load a default TrueType font
        fnt = ImageFont.truetype("arial.ttf", 15)
    except IOError:
        fnt = ImageFont.load_default()
    draw.text((10, 10), message, fill=(255, 0, 0), font=fnt)
    return img


def generate_image(prompt):
    """Generates an image based on the text prompt."""
    if pipe is None:
        # The pipeline failed to load; return a placeholder error image
        return _error_image("Error: Model pipeline failed to load.")

    if not prompt:
        return None  # Return nothing if the prompt is empty

    print(f"Generating image for prompt: '{prompt}'")
    try:
        # Generate the image with the default guidance scale and 30 denoising steps
        with torch.inference_mode():  # inference mode for efficiency
            image = pipe(prompt=prompt, num_inference_steps=30).images[0]  # .images is a list
        print("Image generated successfully.")
        return image
    except Exception as e:
        print(f"Error during image generation: {e}")
        return _error_image(f"Error generating image:\n{e}")

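
# The pipeline call above accepts further standard SDXL arguments if you want
# to customize generation beyond the defaults, e.g. (a sketch):
# image = pipe(
#     prompt="A high-tech cityscape at sunset",
#     negative_prompt="blurry, low quality",
#     num_inference_steps=30,
#     guidance_scale=7.5,
# ).images[0]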
103

# Create the Gradio interface
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Enter Text Prompt", placeholder="e.g., 'An astronaut riding a green horse'"),
    outputs=gr.Image(label="Generated Image", type="pil"),
    title="Text-to-Image Generation with Stable Diffusion XL",
    description=f"Generate images from text prompts using the {model_id} model. Loading and inference might take a moment, especially on the first run or on CPU.",
    examples=["A high-tech cityscape at sunset, cinematic lighting"],
)

if __name__ == "__main__":
    # Launch the Gradio app
    demo.launch(debug=True)
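
# To try this locally (assuming the pip installs listed in the header comments):
#   python app.py
# Gradio prints a local URL once the server is running.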