from typing import List, Union

import cv2
import numpy as np
import torch
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image
from torch.nn import Linear
from tqdm import gui

from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.tileUpscalePipeline import (
    StableDiffusionControlNetImg2ImgPipeline,
)
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image


class ControlNet(AbstractPipeline):
    """Stable Diffusion ControlNet wrapper that swaps between several ControlNet models.

    Exactly one ControlNet model is held at a time. The ``load_*`` methods
    replace it (and re-point both pipelines at the new model); each
    ``process_*`` method refuses to run unless the matching model is the one
    currently loaded.
    """

    # Name of the ControlNet variant currently loaded ("" when none is).
    __current_task_name = ""

    def load(self, model_dir: str):
        """Load the default ControlNet model and build both pipelines.

        Args:
            model_dir: Model location handed to ``from_pretrained`` for the
                base Stable Diffusion weights.
        """
        # we will load scribble by default
        self.load_scribble()

        # controlnet pipeline for tile upscaler
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            model_dir,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        ).to("cuda")
        # pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        self.pipe = pipe

        # controlnet pipeline for canny and pose
        # Built from the first pipeline's components rather than loaded again.
        pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
        pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
        pipe2.enable_xformers_memory_efficient_attention()
        self.pipe2 = pipe2

    def __load_controlnet(self, task_name: str, model_id: str) -> None:
        """Make ``model_id`` the active ControlNet model, unless it already is.

        Args:
            task_name: Label recorded as the current task; the ``process_*``
                methods compare against it.
            model_id: Identifier passed to ``ControlNetModel.from_pretrained``.
        """
        if self.__current_task_name == task_name:
            return

        controlnet = ControlNetModel.from_pretrained(
            model_id, torch_dtype=torch.float16
        ).to("cuda")

        self.__current_task_name = task_name
        self.controlnet = controlnet
        # The pipelines do not exist yet when this runs from inside load().
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = controlnet

        clear_cuda_and_gc()

    def __require_task(self, task_name: str) -> None:
        """Raise if ``task_name`` is not the currently loaded ControlNet model.

        Raises:
            Exception: When a different model (or none) is loaded.
        """
        if self.__current_task_name != task_name:
            raise Exception(f"ControlNet is not loaded with {task_name} model")

    def load_canny(self):
        """Switch the active ControlNet model to the canny model."""
        self.__load_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self):
        """Switch the active ControlNet model to the openpose model."""
        self.__load_controlnet("pose", "lllyasviel/sd-controlnet-openpose")

    def load_tile_upscaler(self):
        """Switch the active ControlNet model to the tile model."""
        self.__load_controlnet(
            "tile_upscaler", "lllyasviel/control_v11f1e_sd15_tile"
        )

    def load_scribble(self):
        """Switch the active ControlNet model to the scribble model."""
        self.__load_controlnet("scribble", "lllyasviel/control_v11p_sd15_scribble")

    def load_linearart(self):
        """Switch the active ControlNet model to the lineart model."""
        self.__load_controlnet(
            "linearart", "ControlNet-1-1-preview/control_v11p_sd15_lineart"
        )

    def cleanup(self):
        """Drop the active ControlNet model and reset the current task name.

        Safe to call before ``load()``: pipelines that were never created are
        skipped instead of raising ``AttributeError``.
        """
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""

        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        guidance_scale: float,
        height: int,
        width: int,
    ):
        """Generate from an edge map of the image at ``imageUrl``.

        The image is downloaded, resized to ``(width, height)`` and turned
        into a Canny edge map that conditions the pipeline.

        Returns:
            The pipeline output wrapped by ``Result.from_result``.

        Raises:
            Exception: If the canny model is not the one currently loaded.
        """
        self.__require_task("canny")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        init_image = self.__canny_detect_edge(init_image)

        result = self.pipe2(
            prompt=prompt,
            image=init_image,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        guidance_scale: float,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate from already-prepared conditioning images.

        ``image`` is passed to the pipeline as-is; no detection or resizing is
        done here (see ``detect_pose`` for producing a pose image).

        Returns:
            The pipeline output wrapped by ``Result.from_result``.

        Raises:
            Exception: If the pose model is not the one currently loaded.
        """
        self.__require_task("pose")

        torch.manual_seed(seed)

        result = self.pipe2(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float,
    ):
        """Run the img2img pipeline on a rescaled copy of the image at ``imageUrl``.

        The image is downloaded, resized to ``(width, height)`` and then
        rescaled so its shorter side is about ``resize_dimension``. That
        rescaled image is used both as the input image and as the ControlNet
        conditioning image, and its size becomes the output size.

        Returns:
            The pipeline output wrapped by ``Result.from_result``.

        Raises:
            Exception: If the tile_upscaler model is not the one currently loaded.
        """
        self.__require_task("tile_upscaler")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            init_image, resize_dimension
        )
        out_width, out_height = condition_image.size

        result = self.pipe(
            image=condition_image,
            prompt=prompt,
            controlnet_conditioning_image=condition_image,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=out_height,
            width=out_width,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_scribble(
        self,
        imageUrl: Union[str, Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
    ):
        """Generate from a scribble-style conditioning image.

        Args:
            imageUrl: Either an image URL to download or an already-loaded
                PIL image. In both cases it is resized to ``(width, height)``
                before the conditioning image is derived from it.

        Returns:
            The pipeline output wrapped by ``Result.from_result``.

        Raises:
            Exception: If the scribble model is not the one currently loaded.
        """
        self.__require_task("scribble")

        torch.manual_seed(seed)

        source = (
            imageUrl
            if isinstance(imageUrl, Image.Image)
            else download_image(imageUrl)
        )
        init_image = source.resize((width, height))
        condition_image = self.__scribble_condition_image(init_image)

        result = self.pipe2(
            image=condition_image,
            prompt=prompt,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_linearart(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
    ):
        """Generate from a lineart conditioning image of the image at ``imageUrl``.

        Returns:
            The pipeline output wrapped by ``Result.from_result``.

        Raises:
            Exception: If the linearart model is not the one currently loaded.
        """
        self.__require_task("linearart")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = ControlNet.linearart_condition_image(init_image)

        result = self.pipe2(
            image=condition_image,
            prompt=prompt,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download the image at ``imageUrl`` and run the openpose detector on it.

        The detector is called with ``hand_and_face=True``. It is loaded anew
        on every call.
        """
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        return detector(image, hand_and_face=True)

    def __scribble_condition_image(self, image: Image.Image) -> Image.Image:
        """Run the HED detector in scribble mode on ``image``.

        The detector is loaded anew on every call.
        """
        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image, scribble=True)

    @staticmethod
    def linearart_condition_image(image: Image.Image) -> Image.Image:
        """Run the lineart detector on ``image``.

        The detector is loaded anew on every call.
        """
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image)

    def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
        """Return a three-channel Canny edge map of ``image``.

        Uses thresholds 100 (low) and 200 (high); the single-channel edge map
        is copied into three identical channels.
        """
        low_threshold = 100
        high_threshold = 200

        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # Same result as adding a trailing axis and concatenating three copies.
        stacked = np.stack([edges, edges, edges], axis=2)
        return Image.fromarray(stacked)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Rescale ``image`` so its shorter side is about ``resolution`` pixels.

        The image is converted to RGB, scaled by ``resolution / min(W, H)``
        and each side is then rounded to the nearest multiple of 64. A side
        that would round down to 0 is raised to 64 so the resize always gets
        a non-empty target size.

        Args:
            image: Image to rescale.
            resolution: Target length of the shorter side, in pixels.

        Returns:
            The resized RGB image (LANCZOS resampling).

        Raises:
            ValueError: If ``resolution`` is not positive.
        """
        if resolution <= 0:
            raise ValueError("resolution must be a positive number")

        input_image = image.convert("RGB")
        width, height = input_image.size
        scale = float(resolution) / min(width, height)

        # Snap to multiples of 64, never going below one 64-pixel step.
        target_w = max(64, int(round(width * scale / 64.0)) * 64)
        target_h = max(64, int(round(height * scale / 64.0)) * 64)

        return input_image.resize((target_w, target_h), resample=Image.LANCZOS)