from typing import AbstractSet, List, Literal, Optional, Union

import cv2
import numpy as np
import torch
from controlnet_aux import (
    HEDdetector,
    LineartDetector,
    OpenposeDetector,
    PidiNetDetector,
)
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionAdapterPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLAdapterPipeline,
    StableDiffusionXLControlNetPipeline,
    T2IAdapter,
    UniPCMultistepScheduler,
)
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from pydash import has
from torch.nn import Linear
from tqdm import gui
from transformers import pipeline

import internals.util.image as ImageUtil
from external.midas import apply_midas
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
    get_hf_cache_dir,
    get_hf_token,
    get_is_sdxl,
    get_model_dir,
)

CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"]


class StableDiffusionNetworkModelPipelineLoader:
    """Loads the pipeline for a network module, e.g. ControlNet or T2I-Adapter.

    Calling this class does not produce an instance of it: ``__new__`` builds
    and returns a diffusers pipeline already moved to ``"cuda"``. It does not
    throw on unsupported configurations; it prints a warning and returns None.
    """

    def __new__(
        cls,
        is_sdxl,
        is_img2img,
        network_model,
        pipeline_type,
        base_pipe: Optional[AbstractPipeline] = None,
    ):
        """Build the pipeline matching the requested configuration.

        Args:
            is_sdxl: Build an SDXL pipeline instead of an SD 1.x one.
            is_img2img: Request an image-to-image pipeline. SDXL img2img is
                rejected. Non-SDXL T2I falls back to the text-to-image
                adapter pipeline.
            network_model: The ControlNet / T2I-Adapter model to attach.
            pipeline_type: ``"controlnet"`` or ``"t2i"``.
            base_pipe: Optional already-loaded pipeline wrapper. When given,
                its ``.pipe.components`` are reused instead of loading
                weights from ``get_model_dir()``.

        Returns:
            The constructed pipeline on ``"cuda"``, or None when the
            configuration is unsupported.
        """
        if is_sdxl and is_img2img:
            # Does not matter pipeline type but tile upscale is not supported
            print("Warning: Tile upscale is not supported on SDXL")
            return None

        if base_pipe is None:
            pretrained = True
            kwargs = {
                "pretrained_model_name_or_path": get_model_dir(),
                "torch_dtype": torch.float16,
                "use_auth_token": get_hf_token(),
                "cache_dir": get_hf_cache_dir(),
            }
        else:
            # Reuse the components (unet, vae, text encoders, ...) of the
            # already-loaded base pipeline instead of reloading weights.
            pretrained = False
            kwargs = {
                **base_pipe.pipe.components,  # pyright: ignore
            }

        if is_sdxl:
            candidates = {
                "controlnet": StableDiffusionXLControlNetPipeline,
                "t2i": StableDiffusionXLAdapterPipeline,
            }
        elif is_img2img:
            candidates = {
                "controlnet": StableDiffusionControlNetImg2ImgPipeline,
                # No dedicated img2img adapter pipeline is used here; T2I
                # falls back to the text-to-image adapter pipeline.
                "t2i": StableDiffusionAdapterPipeline,
            }
        else:
            candidates = {
                "controlnet": StableDiffusionControlNetPipeline,
                "t2i": StableDiffusionAdapterPipeline,
            }

        pipeline_cls = candidates.get(pipeline_type)
        if pipeline_cls is None:
            print(
                f"Warning: Unsupported configuration {is_sdxl=}, {is_img2img=}, {pipeline_type=}"
            )
            return None

        factory = pipeline_cls.from_pretrained if pretrained else pipeline_cls
        # ControlNet pipelines take the module as `controlnet`, adapter
        # pipelines as `adapter`.
        network_kwarg = "controlnet" if pipeline_type == "controlnet" else "adapter"
        return factory(**{network_kwarg: network_model}, **kwargs).to("cuda")


class ControlNet(AbstractPipeline):
    """ControlNet / T2I-Adapter generation for several conditioning tasks.

    ``self.pipe`` holds the image-to-image pipeline (used by the tile
    upscaler). ``self.pipe2`` holds the text-to-image pipeline (used by pose,
    canny, scribble and linearart). Both are built lazily by ``load_model``
    and reused while the pipeline type stays the same.
    """

    # Model ids per task for SD 1.x. A comma-separated value loads several
    # ControlNets into a MultiControlNetModel. A value equal to another key of
    # the same dict is followed as an alias of that task.
    __model_normal = {
        "pose": "lllyasviel/control_v11f1p_sd15_depth, lllyasviel/control_v11p_sd15_openpose",
        "canny": "lllyasviel/control_v11p_sd15_canny",
        "linearart": "lllyasviel/control_v11p_sd15_lineart",
        "scribble": "lllyasviel/control_v11p_sd15_scribble",
        "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
    }
    __model_normal_types = {
        "pose": "controlnet",
        "canny": "controlnet",
        "linearart": "controlnet",
        "scribble": "controlnet",
        "tile_upscaler": "controlnet",
    }
    # Model ids per task for SDXL; None marks an unsupported task.
    __model_sdxl = {
        "pose": "thibaud/controlnet-openpose-sdxl-1.0",
        "canny": "diffusers/controlnet-canny-sdxl-1.0",
        "linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
        "scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
        "tile_upscaler": None,
    }
    __model_sdxl_types = {
        "pose": "controlnet",
        "canny": "controlnet",
        "linearart": "t2i",
        "scribble": "t2i",
        "tile_upscaler": None,
    }

    __current_task_name = ""
    __loaded = False
    __pipe_type = None

    def init(self, pipeline: AbstractPipeline):
        """Remember the base pipeline whose components the network pipelines reuse."""
        # Stored under the literal attribute name "__pipeline". setattr avoids
        # name mangling; __load_pipeline reads it back with the same string.
        setattr(self, "__pipeline", pipeline)

    def load_model(self, task_name: CONTROLNET_TYPES):
        """Load the network module and pipelines for ``task_name``, and cache them for reuse.

        Raises:
            KeyError: if ``task_name`` is not a known task.
            Exception: if the task has no model in the current (SD 1.x / SDXL)
                mode, or its alias chain is cyclic.
        """
        is_sdxl = get_is_sdxl()
        config = self.__model_sdxl if is_sdxl else self.__model_normal
        types = self.__model_sdxl_types if is_sdxl else self.__model_normal_types

        if self.__current_task_name == task_name:
            return

        model = config[task_name]
        if not model:
            raise Exception(f"ControlNet is not supported for {task_name}")

        # Follow aliases (a value that is itself a task key), guarding
        # against cycles that would otherwise loop forever.
        seen = {task_name}
        while model in config:
            task_name = model  # pyright: ignore
            if task_name in seen:
                raise Exception(f"Cyclic ControlNet alias for {task_name}")
            seen.add(task_name)
            model = config[task_name]
        if not model:
            raise Exception(f"ControlNet is not supported for {task_name}")
        if self.__current_task_name == task_name:
            # The alias resolved to the task that is already loaded.
            return

        pipeline_type = types[task_name]

        if "," in model:
            model = [m.strip() for m in model.split(",")]

        network_model = self.__load_network_model(model, pipeline_type)
        self.__load_pipeline(network_model, pipeline_type)
        self.__current_task_name = task_name

        clear_cuda_and_gc()

    def __load_network_model(self, model_name, pipeline_type):
        """Load the network module, e.g. a ControlNet or a T2I-Adapter.

        Args:
            model_name: A single model id, or a list of ids (ControlNet only).
                A list is wrapped in a MultiControlNetModel.
            pipeline_type: ``"controlnet"`` or ``"t2i"``.

        Raises:
            Exception: for an invalid pipeline type / model spec, or a list
                of T2I adapters.
        """

        def load_controlnet(model):
            return ControlNetModel.from_pretrained(
                model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            ).to("cuda")

        def load_t2i(model):
            # NOTE(review): this previously passed `varient="fp16"`
            # (misspelled). That is not a documented from_pretrained
            # parameter; the documented one is `variant`. So no fp16 variant
            # was being selected, and weights are cast via torch_dtype. Switch
            # to `variant="fp16"` only after confirming the repos publish fp16
            # weight files.
            return T2IAdapter.from_pretrained(
                model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            ).to("cuda")

        if isinstance(model_name, str):
            if pipeline_type == "controlnet":
                return load_controlnet(model_name)
            if pipeline_type == "t2i":
                return load_t2i(model_name)
            raise Exception("Invalid pipeline type")

        if isinstance(model_name, list):
            if pipeline_type == "controlnet":
                return MultiControlNetModel(
                    [load_controlnet(name) for name in model_name]
                ).to("cuda")
            if pipeline_type == "t2i":
                raise Exception("Multi T2I adapters are not supported")

        raise Exception("Invalid pipeline type")

    def __load_pipeline(self, network_model, pipeline_type):
        """Load the base pipelines (if needed) and attach the network module to them.

        Both pipelines are rebuilt on first use or when ``pipeline_type``
        changes. Otherwise only the network module is swapped in.
        """

        def patch_pipe(pipe):
            if not pipe:  # cases where the loader may return None
                return None
            if get_is_sdxl():
                pipe.enable_vae_tiling()
                pipe.enable_vae_slicing()
            pipe.enable_xformers_memory_efficient_attention()
            return pipe

        # If the pipeline type is changed we should reload all the pipelines.
        if not self.__loaded or self.__pipe_type != pipeline_type:
            base_pipe = getattr(self, "__pipeline", None)

            # controlnet pipeline for tile upscaler
            pipe = patch_pipe(
                StableDiffusionNetworkModelPipelineLoader(
                    is_sdxl=get_is_sdxl(),
                    is_img2img=True,
                    network_model=network_model,
                    pipeline_type=pipeline_type,
                    base_pipe=base_pipe,
                )
            )
            if pipe:
                self.pipe = pipe

            # controlnet pipeline for canny and pose
            pipe2 = patch_pipe(
                StableDiffusionNetworkModelPipelineLoader(
                    is_sdxl=get_is_sdxl(),
                    is_img2img=False,
                    network_model=network_model,
                    pipeline_type=pipeline_type,
                    base_pipe=base_pipe,
                )
            )
            if pipe2:
                self.pipe2 = pipe2

            self.__loaded = True
            self.__pipe_type = pipeline_type

        # Set the network module on whichever pipelines exist.
        attr_name = {"controlnet": "controlnet", "t2i": "adapter"}.get(pipeline_type)
        for pipe_name in ("pipe", "pipe2"):
            if not hasattr(self, pipe_name):
                continue
            target = getattr(self, pipe_name)
            if attr_name:
                setattr(target, attr_name, network_model)
            setattr(self, pipe_name, target.to("cuda"))

        clear_cuda_and_gc()

    def process(self, **kwargs):
        """Dispatch to the ``process_*`` method of the currently loaded task."""
        handlers = {
            "pose": self.process_pose,
            "canny": self.process_canny,
            "scribble": self.process_scribble,
            "linearart": self.process_linearart,
            "tile_upscaler": self.process_tile_upscaler,
        }
        handler = handlers.get(self.__current_task_name)
        if handler is None:
            raise Exception("ControlNet is not loaded with any model")
        return handler(**kwargs)

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 9,
        **kwargs,
    ):
        """Generate from the Canny edge map of the image at ``imageUrl``.

        Extra ``kwargs`` are forwarded to the pipeline and override defaults.
        """
        if self.__current_task_name != "canny":
            raise Exception("ControlNet is not loaded with canny model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        init_image = ControlNet.canny_detect_edge(init_image)

        kwargs = {
            "prompt": prompt,
            "image": init_image,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "height": height,
            "width": width,
            **kwargs,
        }

        result = self.pipe2.__call__(**kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate 4 images for the first prompt from pre-computed condition image(s).

        On SD 1.x the pose task loads a depth + openpose MultiControlNet, so
        ``image`` presumably holds one image per ControlNet in that order
        (TODO confirm against callers).
        """
        if self.__current_task_name != "pose":
            raise Exception("ControlNet is not loaded with pose model")

        torch.manual_seed(seed)

        kwargs = {
            "prompt": prompt[0],
            "image": image,
            "num_images_per_prompt": 4,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt[0],
            "guidance_scale": guidance_scale,
            "height": height,
            "width": width,
            **kwargs,
        }
        result = self.pipe2.__call__(**kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Upscale the image at ``imageUrl`` with the tile ControlNet (img2img pipeline).

        The output size follows the resized condition image, not ``height``
        and ``width``; those are only used for the initial resize.
        """
        if self.__current_task_name != "tile_upscaler":
            raise Exception("ControlNet is not loaded with tile_upscaler model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            init_image, resize_dimension
        )

        kwargs = {
            "image": condition_image,
            "prompt": prompt,
            "control_image": condition_image,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": condition_image.size[1],
            "width": condition_image.size[0],
            "guidance_scale": guidance_scale,
            **kwargs,
        }
        result = self.pipe.__call__(**kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_scribble(
        self,
        image: List[Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate from pre-computed scribble condition image(s).

        On SDXL fixed adapter settings (and guidance_scale=6) replace the
        ``guidance_scale`` argument. Explicit ``kwargs`` still win.
        """
        if self.__current_task_name != "scribble":
            raise Exception("ControlNet is not loaded with scribble model")

        torch.manual_seed(seed)

        sdxl_args = (
            {
                "guidance_scale": 6,
                "adapter_conditioning_scale": 0.6,
                "adapter_conditioning_factor": 1.0,
            }
            if get_is_sdxl()
            else {}
        )

        kwargs = {
            "image": image,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            **sdxl_args,
            **kwargs,
        }
        result = self.pipe2.__call__(**kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_linearart(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate from the line-art map of the image at ``imageUrl``.

        The condition image is repeated 4 times in the ``image`` argument.
        """
        if self.__current_task_name != "linearart":
            raise Exception("ControlNet is not loaded with linearart model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = ControlNet.linearart_condition_image(init_image)

        # On SDXL a T2I adapter is used with fixed settings:
        # conditioning scale 0.5, conditioning factor 0.9, guidance_scale 6.
        sdxl_args = (
            {
                "guidance_scale": 6,
                "adapter_conditioning_scale": 0.5,
                "adapter_conditioning_factor": 0.9,
            }
            if get_is_sdxl()
            else {}
        )

        kwargs = {
            "image": [condition_image] * 4,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            **sdxl_args,
            **kwargs,
        }
        result = self.pipe2.__call__(**kwargs)
        return Result.from_result(result)

    def cleanup(self):
        """Doesn't do anything considering new diffusers has itself a cleanup mechanism after controlnet generation"""
        pass

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Return the OpenPose skeleton image for the image at ``imageUrl``."""
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        return detector.__call__(image)

    @staticmethod
    def scribble_image(image: Image.Image) -> Image.Image:
        """Return an HED-based scribble map of ``image``."""
        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        return processor.__call__(input_image=image, scribble=True)

    @staticmethod
    def linearart_condition_image(image: Image.Image) -> Image.Image:
        """Return the line-art condition map of ``image``."""
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        return processor.__call__(input_image=image)

    @staticmethod
    def depth_image(image: Image.Image) -> Image.Image:
        """Return a MiDaS depth map of ``image`` as an 8-bit grayscale image.

        The MiDaS model and transforms are loaded once and cached in module
        globals.
        """
        global midas, midas_transforms
        if "midas" not in globals():
            # torch.hub returns the model in training mode; eval() is needed
            # so normalisation layers use their stored statistics.
            midas = torch.hub.load("intel-isl/MiDaS", "MiDaS").to("cuda").eval()
            midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
        transform = midas_transforms.default_transform

        # PIL arrays are already RGB, which is what the MiDaS transforms
        # expect. A BGR->RGB conversion here would swap the channels.
        img = np.array(image.convert("RGB"))
        input_batch = transform(img).to("cuda")

        with torch.no_grad():
            prediction = midas(input_batch)
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze()

        output = prediction.cpu().numpy()
        peak = np.max(output)
        scaled = output * 255 / peak if peak > 0 else np.zeros_like(output)
        # Bicubic interpolation can overshoot below zero; clip so the uint8
        # cast does not wrap around.
        formatted = np.clip(scaled, 0, 255).astype("uint8")
        return Image.fromarray(formatted)

    @staticmethod
    def pidinet_image(image: Image.Image) -> Image.Image:
        """Return the PiDiNet edge map of ``image`` (with filtering applied)."""
        pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
        return pidinet.__call__(input_image=image, apply_filter=True)

    @staticmethod
    def canny_detect_edge(
        image: Image.Image, low_threshold: int = 100, high_threshold: int = 200
    ) -> Image.Image:
        """Return the Canny edge map of ``image`` as a 3-channel RGB image.

        Args:
            image: Source image in any PIL mode; it is converted to RGB first,
                so RGBA / palette / grayscale inputs are accepted.
            low_threshold: Lower hysteresis threshold passed to ``cv2.Canny``.
            high_threshold: Upper hysteresis threshold passed to ``cv2.Canny``.
        """
        image_array = np.array(image.convert("RGB"))
        edges = cv2.Canny(image_array, low_threshold, high_threshold)
        # Replicate the single-channel edge map into 3 identical channels.
        edges = np.stack([edges, edges, edges], axis=2)
        return Image.fromarray(edges)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Resize ``image`` so its longer side is about ``resolution``.

        Both sides are snapped to the nearest multiple of 64.
        """
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / max(W, H)
        # Never go below 64, so small targets do not collapse a side to zero.
        H = max(64, int(round(H * k / 64.0)) * 64)
        W = max(64, int(round(W * k / 64.0)) * 64)
        return input_image.resize((W, H), resample=Image.LANCZOS)