File size: 1,743 Bytes
9d9968c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
import spaces
from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler
from PIL import Image
from io import BytesIO
from utils import load_unet_model

# NOTE(review): `spaces.GPU` is documented as a decorator for the *functions*
# that need a ZeroGPU device; here it wraps the whole class. Confirm this is
# intended and behaves correctly on ZeroGPU hardware.
@spaces.GPU
class TextToImage:
    """Text-to-Image generation using Stable Diffusion XL with the SDXL-Lightning 4-step UNet.

    The constructor loads the UNet and builds the pipeline eagerly, so creating
    an instance is slow; ``generate_image`` then produces one image per prompt.
    """

    def __init__(self, device="cpu"):
        """Load the Lightning UNet, build the SDXL pipeline and set its scheduler.

        Args:
            device (str): Device passed to ``load_unet_model`` and used for
                ``pipe.to(...)``, e.g. ``"cpu"`` or ``"cuda"``. Defaults to ``"cpu"``.
        """
        import threading

        # Model and repository details: SDXL base weights plus the 4-step
        # Lightning UNet checkpoint that replaces the base UNet.
        self.base = "stabilityai/stable-diffusion-xl-base-1.0"
        self.repo = "ByteDance/SDXL-Lightning"
        self.ckpt = "sdxl_lightning_4step_unet.safetensors"
        self.device = device

        # Inference runs in executor threads (see generate_image); the shared
        # pipeline is not treated as thread-safe, so calls are serialized.
        self._pipe_lock = threading.Lock()

        # Load the UNet model
        print("Loading Text-to-Image model...")
        self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)

        # Build the SDXL pipeline around the Lightning UNet.
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            self.base,
            unet=self.unet,
            torch_dtype=torch.float32,
        ).to(self.device)

        # Euler scheduler with "trailing" timestep spacing, as in the
        # SDXL-Lightning usage example (presumably required by the few-step
        # checkpoint — confirm against the model card).
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
        print("Text-to-Image model loaded successfully.")

    async def generate_image(self, prompt):
        """Generate an image from a text prompt.

        The diffusion run is blocking, compute-heavy work, so it is executed in
        the running event loop's default executor rather than directly inside
        this coroutine; other tasks on the loop keep making progress while an
        image is generated. Concurrent calls are serialized on the pipeline.

        Args:
            prompt (str): The text prompt to generate the image.

        Returns:
            PIL.Image: The generated image (first image of the pipeline output).
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, self._generate_image_sync, prompt)

    def _generate_image_sync(self, prompt):
        """Run the pipeline synchronously for ``prompt`` and return its first image."""
        # torch.no_grad is thread-local, so it must be entered in the worker
        # thread that actually runs the pipeline, not in the event-loop thread.
        with self._pipe_lock, torch.no_grad():
            image = self.pipe(
                prompt,
                num_inference_steps=4,  # matches the 4-step Lightning checkpoint
                guidance_scale=0,  # 0 disables classifier-free guidance
            ).images[0]
        return image