from typing import Dict, List, Any, Union
import base64
from io import BytesIO

from PIL import Image
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
import torch
import numpy as np
import cv2
import controlnet_hinter

# Inference is only supported on a CUDA device; fail fast at import time otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type != "cuda":
    raise ValueError("need to run on GPU")

# Mixed-precision dtype: bfloat16 when the GPU's major compute capability is 8,
# float16 on every other GPU.
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

# ControlNet checkpoints keyed by control type, with the matching hint function.
CONTROLNET_MAPPING = {
    "depth": {
        "model_id": "lllyasviel/sd-controlnet-depth",
        "hinter": controlnet_hinter.hint_depth,
    },
}

# Request-facing model name -> Stable Diffusion checkpoint id. Every entry listed
# here is loaded at start-up and becomes selectable through the `sd_model` field.
# NOTE(review): the disabled entries below were never loaded by this handler;
# confirm each one works with StableDiffusionControlNetPipeline before enabling.
SD_ID_MAPPING = {
    "default": "Lykon/dreamshaper-8",
    # "dreamshaper": "stablediffusionapi/dreamshaper-xl",
    # "juggernaut": "stablediffusionapi/juggernaut-xl-v8",
    # "realistic": "SG161222/Realistic_Vision_V1.4",
    # "rev": "s6yx/ReV_Animated",
}

# Safety checker attached to every pipeline.
SAFETY_CHECKER_ID = "CompVis/stable-diffusion-safety-checker"


class EndpointHandler():
    """Inference handler that runs depth-conditioned Stable Diffusion generation."""

    def __init__(self, path=""):
        """Load the ControlNet model and every pipeline listed in SD_ID_MAPPING.

        :param path: Unused; accepted so the handler can be constructed with a
            model directory argument.
        """
        self.control_type = "depth"
        self.controlnet = ControlNetModel.from_pretrained(
            CONTROLNET_MAPPING[self.control_type]["model_id"],
            torch_dtype=dtype,
        ).to(device)

        # Explicit registry of selectable pipelines. Requests are resolved
        # against this dict only, never against arbitrary instance attributes.
        self.pipelines = {
            name: self._load_pipeline(model_id)
            for name, model_id in SD_ID_MAPPING.items()
        }

        # Attributes and the "dreamshaper" name are kept for backward compatibility.
        self.stable_diffusion_id_0 = SD_ID_MAPPING["default"]
        self.dreamshaper = self.pipelines["default"]
        self.pipelines.setdefault("dreamshaper", self.dreamshaper)

        # Seeded generator shared by all requests. Its state advances with every
        # call, so repeating a request does not reproduce the same image.
        self.generator = torch.Generator(device=device.type).manual_seed(3)

    def _load_pipeline(self, model_id):
        """Build a ControlNet pipeline for `model_id` and move it to the GPU."""
        return StableDiffusionControlNetPipeline.from_pretrained(
            model_id,
            controlnet=self.controlnet,
            torch_dtype=dtype,
            safety_checker=StableDiffusionSafetyChecker.from_pretrained(
                SAFETY_CHECKER_ID, torch_dtype=dtype
            ),
        ).to(device)

    def __call__(self, data: Any) -> Union[Image.Image, Dict[str, str]]:
        """Generate one image from a prompt and a depth map.

        :param data: A dictionary with the required fields `inputs` (the prompt)
            and `image_depth_map` (base64-encoded image), plus the optional
            fields `sd_model`, `negative_prompt`, `steps`, `scale`, `height`,
            `width` and `controlnet_conditioning_scale`. Recognised fields are
            popped from the dictionary.
        :return: The first generated PIL image, or a dictionary with an `error`
            field when the request is invalid.
        """
        # hyperparameters
        sd_model = data.pop("sd_model", "default")
        prompt = data.pop("inputs", None)
        negative_prompt = data.pop("negative_prompt", None)
        image_depth_map = data.pop("image_depth_map", None)
        steps = data.pop("steps", 25)
        scale = data.pop("scale", 7)
        height = data.pop("height", None)
        width = data.pop("width", None)
        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)

        # Resolve the pipeline through the registry; the isinstance guard keeps
        # unhashable request values from raising inside dict.get.
        pipe = self.pipelines.get(sd_model) if isinstance(sd_model, str) else None
        if pipe is None:
            return {"error": "Modelo SD no especificado o no válido"}

        if prompt is None:
            return {"error": "Please provide a prompt"}

        if image_depth_map is None:
            return {"error": "Please provide a image_depth_map"}

        # process image; report undecodable input the same way as the other
        # invalid-request cases instead of raising out of the handler
        try:
            image = self.decode_base64_image(image_depth_map)
        except (TypeError, ValueError, OSError):
            return {"error": "image_depth_map is not a valid base64-encoded image"}

        # run inference pipeline
        out = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            num_inference_steps=steps,
            guidance_scale=scale,
            num_images_per_prompt=1,
            height=height,
            width=width,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
            generator=self.generator,
        )

        # return first generated PIL image
        return out.images[0]

    # helper to decode input image
    def decode_base64_image(self, image_string):
        """Decode a base64-encoded image string into a PIL image."""
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image