# creativity_hub/image_to_image.py
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, EulerDiscreteScheduler
from PIL import Image
from utils import load_unet_model
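
# NOTE: `load_unet_model` lives in this project's utils module; its implementation
# is not shown here. It is assumed to follow the standard SDXL-Lightning loading
# pattern from the ByteDance model card, roughly:
#
#     unet = UNet2DConditionModel.from_config(base, subfolder="unet").to(device)
#     unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device))
#
# where load_file comes from safetensors.torch and hf_hub_download from huggingface_hub.
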
class ImageToImage:
"""
Class to handle Image-to-Image transformations using Stable Diffusion XL.
"""
def __init__(self, device="cpu"):
# Model and repository details
self.base = "stabilityai/stable-diffusion-xl-base-1.0"
self.repo = "ByteDance/SDXL-Lightning"
self.ckpt = "sdxl_lightning_4step_unet.safetensors"
self.device = device
# Load the UNet model
print("Loading Image-to-Image model...")
self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)
        # Build the img2img pipeline around the Lightning UNet. float32 is used so
        # the pipeline also runs on CPU; float16 would be the usual choice on GPU.
        self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            self.base,
            unet=self.unet,
            torch_dtype=torch.float32
        ).to(self.device)
        # SDXL-Lightning is distilled with trailing timestep spacing, so the
        # scheduler is reconfigured accordingly (as recommended in the model card).
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing"
        )
print("Image-to-Image model loaded successfully.")

    async def transform_image(self, image: Image.Image, prompt: str) -> Image.Image:
        """
        Transform an uploaded image based on a text prompt.

        Args:
            image (PIL.Image.Image): The input image to transform.
            prompt (str): The text prompt guiding the transformation.

        Returns:
            PIL.Image.Image: The transformed image.
        """
        if not prompt:
            raise ValueError("Prompt cannot be empty.")
        # Resize to a fixed working resolution. SDXL is trained at 1024x1024, so
        # 768x512 trades some quality for lower memory use and faster inference.
        init_image = image.resize((768, 512))
        with torch.no_grad():
            # Note: the SDXL-Lightning model card recommends ~4 inference steps and
            # guidance_scale=0 for the 4-step UNet; the values below follow the
            # standard (non-distilled) SDXL img2img defaults instead.
            transformed_image = self.pipe(
                prompt=prompt,
                image=init_image,
                strength=0.75,
                guidance_scale=7.5
            ).images[0]
return transformed_image
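

# Minimal usage sketch (an illustration, not part of the original module): the
# input file name and prompt are placeholders, and transform_image is a
# coroutine, so it is driven here with asyncio.
if __name__ == "__main__":
    import asyncio

    model = ImageToImage(device="cpu")
    source = Image.open("input.jpg").convert("RGB")  # hypothetical input file
    result = asyncio.run(
        model.transform_image(source, "a watercolor painting of a mountain lake")
    )
    result.save("output.jpg")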