import torch
import spaces
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from PIL import Image
from io import BytesIO

from utils import load_unet_model


@spaces.GPU
class TextToImage:
    """
    Class to handle Text-to-Image generation using Stable Diffusion XL.
    """

    def __init__(self, device="cpu"):
        # Model and repository details
        self.base = "stabilityai/stable-diffusion-xl-base-1.0"
        self.repo = "ByteDance/SDXL-Lightning"
        self.ckpt = "sdxl_lightning_4step_unet.safetensors"
        self.device = device

        # Load the UNet model
        print("Loading Text-to-Image model...")
        self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)

        # Initialize the pipeline with the SDXL-Lightning UNet
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.base,
            unet=self.unet,
            torch_dtype=torch.float32,
        ).to(self.device)

        # Set the scheduler (trailing timestep spacing, as required by SDXL-Lightning)
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
        print("Text-to-Image model loaded successfully.")

    async def generate_image(self, prompt):
        """
        Generate an image from a text prompt.

        Args:
            prompt (str): The text prompt to generate the image.

        Returns:
            PIL.Image: The generated image.
        """
        with torch.no_grad():
            image = self.pipe(
                prompt,
                num_inference_steps=4,
                guidance_scale=0,
            ).images[0]
        return image
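
# Minimal usage sketch (an assumption, not part of the Space's app wiring):
# it drives the async generate_image method with asyncio and picks CUDA when
# available. The prompt and output filename are illustrative placeholders.
if __name__ == "__main__":
    import asyncio

    generator = TextToImage(device="cuda" if torch.cuda.is_available() else "cpu")
    # generate_image is a coroutine, so run it on an event loop.
    result = asyncio.run(
        generator.generate_image("a watercolor painting of a lighthouse at dawn")
    )
    result.save("output.png")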