karimbenharrak's picture
Update handler.py
41ca444 verified
raw
history blame contribute delete
No virus
5.17 kB
from typing import Dict, List, Any
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, DiffusionPipeline, AutoencoderKL, DPMSolverMultistepScheduler, DDIMScheduler, StableDiffusionInpaintPipeline, AutoPipelineForInpainting, AutoPipelineForImage2Image, StableDiffusionControlNetInpaintPipeline, ControlNetModel
from PIL import Image
import base64
from io import BytesIO
from diffusers.image_processor import VaeImageProcessor
import numpy as np
# Pick the compute device up front; this handler is GPU-only and refuses
# to start on CPU so misconfigured deployments fail fast.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
if device.type != "cuda":
    raise ValueError("need to run on GPU")
class EndpointHandler():
    """Inference endpoint handler with two generation modes.

    - ``method == "rasterize"``: encode the input image to SDXL-VAE latents
      and refine it with the SDXL refiner pipeline.
    - any other method: ControlNet-guided inpainting with SD 1.5, where the
      mask is folded into the ControlNet conditioning image.
    """

    def __init__(self, path=""):
        # SDXL refiner used for the "rasterize" (smoothing) path.
        self.smooth_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16
        )
        self.smooth_pipe.to("cuda")

        # SDXL VAE used to encode the input image into latents before refining.
        self.vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="vae", use_safetensors=True,
        ).to("cuda")

        # NOTE(review): enable_model_cpu_offload() after .to("cuda") is
        # redundant per the diffusers docs (offload manages device placement
        # itself); kept to preserve the original runtime behavior.
        self.smooth_pipe.enable_model_cpu_offload()
        self.smooth_pipe.enable_xformers_memory_efficient_attention()

        # ControlNet + SD 1.5 inpainting pipeline for the non-rasterize path.
        self.controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
        )
        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5",
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        )
        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_model_cpu_offload()
        self.pipe.enable_xformers_memory_efficient_attention()

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        """Run one generation request.

        :param data: dict with an optional ``method`` ("rasterize" by default),
            a base64 ``image``, an optional base64 ``mask_image`` (inpainting
            only), a ``prompt`` and the usual generation parameters.
        :return: a list of PIL images for "rasterize", or a single PIL image
            for the inpainting path.
        :raises ValueError: if a required image/mask input is missing.
        """
        method = data.pop("method", "rasterize")

        if method == "rasterize":
            encoded_image = data.pop("image", None)
            prompt = data.pop("prompt", "")
            num_inference_steps = data.pop("num_inference_steps", 50)

            if encoded_image is None:
                # Bug fix: the original fell through with `out` undefined.
                raise ValueError("'image' is required for the rasterize method")

            image = self.decode_base64_image(encoded_image).convert("RGB")
            latents = VaeImageProcessor().preprocess(image).to(device="cuda")
            with torch.no_grad():
                # Encode to latent space and apply the VAE scaling factor.
                latents_dist = (
                    self.vae.encode(latents).latent_dist.sample()
                    * self.vae.config.scaling_factor
                )
            self.smooth_pipe.enable_xformers_memory_efficient_attention()
            out = self.smooth_pipe(
                prompt, image=latents_dist, num_inference_steps=num_inference_steps
            ).images
            return out

        # Inpainting path.
        encoded_image = data.pop("image", None)
        encoded_mask_image = data.pop("mask_image", None)
        prompt = data.pop("prompt", "")
        negative_prompt = data.pop("negative_prompt", "")
        strength = data.pop("strength", 0.2)
        guidance_scale = data.pop("guidance_scale", 8.0)
        num_inference_steps = data.pop("num_inference_steps", 20)

        if encoded_image is None or encoded_mask_image is None:
            # Bug fix: the original passed None into make_inpaint_condition
            # and crashed with AttributeError; fail with a clear message.
            raise ValueError(
                "'image' and 'mask_image' are required for inpainting"
            )

        image = self.decode_base64_image(encoded_image).convert("RGB")
        mask_image = self.decode_base64_image(encoded_mask_image).convert("RGB")
        control_image = self.make_inpaint_condition(image, mask_image)

        # Generate and return the first (only) image.
        image = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            num_inference_steps=num_inference_steps,
            eta=1.0,
            image=image,
            mask_image=mask_image,
            control_image=control_image,
            guidance_scale=guidance_scale,
            strength=strength,
        ).images[0]
        return image

    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image string into a PIL image."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image

    def make_inpaint_condition(self, image, image_mask):
        """Build the ControlNet-inpaint conditioning tensor.

        Pixels whose mask value exceeds 0.5 are set to -1.0 (the ControlNet
        inpaint convention for "masked"); returns a float tensor of shape
        (1, 3, H, W).
        """
        image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
        image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
        # Bug fix: the original compared shape[0:1] (height only); compare
        # both height and width, and raise instead of assert so the check
        # survives `python -O`.
        if image.shape[:2] != image_mask.shape[:2]:
            raise ValueError("image and image_mask must have the same image size")
        image[image_mask > 0.5] = -1.0  # set as masked pixel
        image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
        image = torch.from_numpy(image)
        return image