import gradio as gr
from diffusers import DiffusionPipeline
import torch
import os

# SDXL base checkpoint on the Hugging Face Hub.
model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Pick the device and precision: fp16 on GPU, fp32 on CPU.
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
    print("Using CUDA (GPU).")
else:
    device = "cpu"
    dtype = torch.float32
    print("Using CPU.")

try:
    # Load the SDXL pipeline; request the fp16 weights variant only when running on GPU.
    pipe = DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        use_safetensors=True,
        variant="fp16" if device != "cpu" else None
    )
    pipe.to(device)

    if device == "cuda":
        try:
            # On GPUs with less than ~10 GB of VRAM, offload idle submodules to the CPU.
            total_vram_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
            if total_vram_gb < 10:
                print(f"Low VRAM ({total_vram_gb:.2f}GB detected). Enabling model CPU offload.")
                pipe.enable_model_cpu_offload()
        except Exception as offload_err:
            print(f"Could not check VRAM or enable offload: {offload_err}")

    print(f"SDXL pipeline loaded successfully on {device}.")

except Exception as e:
    print(f"Error loading SDXL pipeline: {e}")
    pipe = None
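
# Note: the first call to from_pretrained() downloads the SDXL weights (several GB)
# from the Hugging Face Hub and caches them locally, so later runs start much faster.
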
def generate_image(prompt):
    """Generates an image based on the text prompt."""
    if pipe is None:
        # The pipeline failed to load; return a placeholder image carrying the error text.
        from PIL import Image, ImageDraw, ImageFont
        img = Image.new('RGB', (512, 512), color=(200, 200, 200))
        d = ImageDraw.Draw(img)
        try:
            fnt = ImageFont.truetype("arial.ttf", 15)
        except IOError:
            fnt = ImageFont.load_default()
        d.text((10, 10), "Error: Model pipeline failed to load.", fill=(255, 0, 0), font=fnt)
        return img

    if not prompt:
        return None

    print(f"Generating image for prompt: '{prompt}'")
    try:
        # Run the diffusion loop without gradient tracking to save memory.
        with torch.inference_mode():
            image = pipe(prompt=prompt, num_inference_steps=30).images[0]
        print("Image generated successfully.")
        return image
    except Exception as e:
        print(f"Error during image generation: {e}")
        # Return a placeholder image describing the failure instead of raising.
        from PIL import Image, ImageDraw, ImageFont
        img = Image.new('RGB', (512, 512), color=(200, 200, 200))
        d = ImageDraw.Draw(img)
        try:
            fnt = ImageFont.truetype("arial.ttf", 15)
        except IOError:
            fnt = ImageFont.load_default()
        d.text((10, 10), f"Error generating image:\n{e}", fill=(255, 0, 0), font=fnt)
        return img
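
# The function can also be called directly, e.g. generate_image("a watercolor fox"),
# which returns a PIL image (or an error placeholder if the pipeline is unavailable).
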
# Wire the function into a simple Gradio UI: one text box in, one image out.
demo = gr.Interface(
    fn=generate_image,
    inputs=gr.Textbox(label="Enter Text Prompt", placeholder="e.g., 'An astronaut riding a green horse'"),
    outputs=gr.Image(label="Generated Image", type="pil"),
    title="Text-to-Image Generation with Stable Diffusion XL",
    description=f"Generate images from text prompts using the {model_id} model. Loading and inference might take a moment, especially on the first run or on CPU.",
    examples=["A high-tech cityscape at sunset, cinematic lighting"]
)

if __name__ == "__main__":
    demo.launch(debug=True)
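
# Optional: launch(share=True) creates a temporary public link, and
# launch(server_name="0.0.0.0") exposes the app on the local network.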