Spaces:
Runtime error
Runtime error
import torch | |
from diffusers import StableDiffusionXLImg2ImgPipeline, EulerDiscreteScheduler | |
from PIL import Image | |
from io import BytesIO | |
from utils import load_unet_model | |
class ImageToImage:
    """
    Class to handle Image-to-Image transformations using Stable Diffusion XL.

    The pipeline's UNet is replaced with the SDXL-Lightning 4-step checkpoint,
    so inference defaults follow that checkpoint's recommended settings
    (4 steps, classifier-free guidance disabled) rather than stock SDXL ones.
    """

    def __init__(self, device="cpu"):
        """
        Load the SDXL-Lightning UNet and build the img2img pipeline.

        Args:
            device (str): Torch device the UNet and pipeline are placed on
                (e.g. "cpu" or "cuda").
        """
        import threading

        # Model and repository details
        self.base = "stabilityai/stable-diffusion-xl-base-1.0"
        self.repo = "ByteDance/SDXL-Lightning"
        self.ckpt = "sdxl_lightning_4step_unet.safetensors"
        self.device = device
        # Inference runs in executor threads (see transform_image). The single
        # pipeline object (including its scheduler) is shared mutable state,
        # so calls into it are serialized with this lock.
        self._pipe_lock = threading.Lock()
        # Load the UNet model
        print("Loading Image-to-Image model...")
        self.unet = load_unet_model(self.base, self.repo, self.ckpt, device=self.device)
        # Initialize the pipeline with the Lightning UNet swapped in
        self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            self.base,
            unet=self.unet,
            torch_dtype=torch.float32,
        ).to(self.device)
        # Lightning checkpoints are meant to be sampled with "trailing"
        # timestep spacing on an Euler scheduler.
        self.pipe.scheduler = EulerDiscreteScheduler.from_config(
            self.pipe.scheduler.config,
            timestep_spacing="trailing",
        )
        print("Image-to-Image model loaded successfully.")

    async def transform_image(
        self,
        image,
        prompt,
        *,
        size=(768, 512),
        strength=0.75,
        guidance_scale=0.0,
        num_inference_steps=4,
    ):
        """
        Transform an uploaded image based on a text prompt.

        The diffusion run is synchronous and compute-bound. It is therefore
        executed in the event loop's default thread pool so the loop stays
        responsive while the image is generated.

        Args:
            image (PIL.Image): The input image to transform. Images whose mode
                is not "RGB" (e.g. RGBA, L, P) are converted to RGB first.
            prompt (str): The text prompt to guide the transformation.
            size (tuple[int, int]): (width, height) the input is resized to.
            strength (float): How much the input image is altered.
            guidance_scale (float): Classifier-free guidance scale. The default
                of 0 disables CFG, as the Lightning checkpoint expects.
            num_inference_steps (int): Scheduler steps. The default of 4
                matches the 4-step Lightning UNet loaded in __init__.

        Returns:
            PIL.Image: The transformed image.

        Raises:
            ValueError: If the image is missing or the prompt is empty/blank.
        """
        import asyncio
        import functools

        if image is None:
            raise ValueError("Image cannot be empty.")
        if not prompt or (isinstance(prompt, str) and not prompt.strip()):
            raise ValueError("Prompt cannot be empty.")

        # The SDXL VAE encodes 3-channel images; alpha/greyscale/palette
        # uploads would otherwise fail inside the pipeline.
        if image.mode != "RGB":
            image = image.convert("RGB")
        # Resize the image as required by the model
        init_image = image.resize(tuple(size))

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            functools.partial(
                self._run_pipeline,
                prompt=prompt,
                image=init_image,
                strength=strength,
                guidance_scale=guidance_scale,
                num_inference_steps=num_inference_steps,
            ),
        )

    def _run_pipeline(self, **pipe_kwargs):
        """Run the pipeline synchronously and return the first generated image."""
        # Grad mode is thread-local in PyTorch, so no_grad is entered here in
        # the worker thread that actually runs inference.
        with self._pipe_lock, torch.no_grad():
            return self.pipe(**pipe_kwargs).images[0]