import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, EulerDiscreteScheduler
from PIL import Image

from utils import load_unet_model

class ImageToImage:
    """

    Class to handle Image-to-Image transformations using Stable Diffusion XL.

    """
    def __init__(self, device="cpu"):
        # Model and repository details
        self.base = "stabilityai/stable-diffusion-xl-base-1.0"
        self.repo = "ByteDance/SDXL-Lightning"
        self.ckpt = "sdxl_lightning_4step_unet.safetensors"
        self.device = device

        # Load the UNet model
        print("Loading Image-to-Image model...")
        self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)
        
        # Initialize the img2img pipeline around the Lightning UNet
        # (float32 keeps CPU inference working; float16 is preferable on CUDA)
        self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            self.base,
            unet=self.unet,
            torch_dtype=torch.float32,
        ).to(self.device)
        
        # Use "trailing" timestep spacing, as recommended for SDXL-Lightning checkpoints
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
        print("Image-to-Image model loaded successfully.")
    
    
    async def transform_image(self, image: Image.Image, prompt: str) -> Image.Image:
        """Transform an uploaded image based on a text prompt.

        Args:
            image (PIL.Image.Image): The input image to transform.
            prompt (str): The text prompt to guide the transformation.

        Returns:
            PIL.Image.Image: The transformed image.
        """
        if not prompt:
            raise ValueError("Prompt cannot be empty.")

        # Resize to a fixed working resolution for the pipeline
        init_image = image.resize((768, 512))
        # Note: the pipeline call below is synchronous and will block the
        # event loop; consider offloading it to a thread executor in
        # production async code.
        with torch.no_grad():
            # SDXL-Lightning is distilled for guidance-free sampling, so
            # guidance_scale is set to 0. Img2img performs roughly
            # int(num_inference_steps * strength) denoising steps, so
            # 6 * 0.75 -> 4 steps, matching the 4-step checkpoint.
            transformed_image = self.pipe(
                prompt=prompt,
                image=init_image,
                strength=0.75,
                guidance_scale=0.0,
                num_inference_steps=6,
            ).images[0]
        return transformed_image
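

# A minimal usage sketch, not part of the original module: "input.jpg" and
# "output.jpg" are hypothetical paths, and asyncio.run() is used because
# transform_image is declared async.
if __name__ == "__main__":
    import asyncio

    model = ImageToImage(device="cpu")
    source = Image.open("input.jpg").convert("RGB")  # hypothetical input file
    result = asyncio.run(
        model.transform_image(source, "a watercolor painting of the same scene")
    )
    result.save("output.jpg")  # hypothetical output path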