CM2000112 / internals /pipelines /controlnets.py
jayparmr's picture
Upload folder using huggingface_hub
22df957 verified
raw
history blame
20 kB
from typing import AbstractSet, List, Literal, Optional, Union
import cv2
import numpy as np
import torch
from controlnet_aux import (
HEDdetector,
LineartDetector,
OpenposeDetector,
PidiNetDetector,
)
from diffusers import (
ControlNetModel,
DiffusionPipeline,
EulerAncestralDiscreteScheduler,
StableDiffusionAdapterPipeline,
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLControlNetPipeline,
T2IAdapter,
UniPCMultistepScheduler,
)
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from pydash import has
from torch.nn import Linear
from tqdm import gui
from transformers import pipeline
import internals.util.image as ImageUtil
from external.midas import apply_midas
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
get_hf_cache_dir,
get_hf_token,
get_is_sdxl,
get_model_dir,
)
# Task names accepted by ControlNet.load_model. The same strings are used as the
# keys of ControlNet's per-task model tables, so the "linearart" spelling is part
# of the runtime contract and must not be "corrected" in isolation.
CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"]
class StableDiffusionNetworkModelPipelineLoader:
    """Build the diffusers pipeline that hosts a network module (ControlNet or T2I adapter).

    Calling this class never yields an instance of it: ``__new__`` returns the
    constructed diffusers pipeline (already moved to ``"cuda"``), or ``None``
    after printing a warning when the requested configuration is unsupported.
    It does not raise for unsupported configurations.
    """

    def __new__(
        cls,
        is_sdxl,
        is_img2img,
        network_model,
        pipeline_type,
        base_pipe: Optional["AbstractPipeline"] = None,
    ):
        """Select and construct the pipeline for the given configuration.

        Args:
            is_sdxl: Whether the SDXL pipeline family should be used.
            is_img2img: Whether an img2img pipeline is requested.
            network_model: The already-loaded ControlNet / T2I adapter module.
            pipeline_type: ``"controlnet"`` or ``"t2i"``.
            base_pipe: Optional wrapper exposing ``.pipe.components``. When
                given, the new pipeline is assembled from those components
                instead of being loaded with ``from_pretrained``.

        Returns:
            The pipeline on CUDA, or ``None`` for unsupported configurations.
        """
        if is_sdxl and is_img2img:
            # Does not matter pipeline type but tile upscale is not supported
            print("Warning: Tile upscale is not supported on SDXL")
            return None

        # Pick the pipeline class and the keyword under which it expects the
        # network module. Branch order matters: SDXL first, then img2img.
        if is_sdxl and pipeline_type == "controlnet":
            pipeline_cls, module_kwarg = StableDiffusionXLControlNetPipeline, "controlnet"
        elif is_sdxl and pipeline_type == "t2i":
            pipeline_cls, module_kwarg = StableDiffusionXLAdapterPipeline, "adapter"
        elif is_img2img and pipeline_type == "controlnet":
            pipeline_cls, module_kwarg = (
                StableDiffusionControlNetImg2ImgPipeline,
                "controlnet",
            )
        elif pipeline_type == "controlnet":
            pipeline_cls, module_kwarg = StableDiffusionControlNetPipeline, "controlnet"
        elif pipeline_type == "t2i":
            pipeline_cls, module_kwarg = StableDiffusionAdapterPipeline, "adapter"
        else:
            print(
                f"Warning: Unsupported configuration {is_sdxl=}, {is_img2img=}, {pipeline_type=}"
            )
            return None

        if base_pipe is None:
            # No base pipeline to borrow from: load the weights from disk/hub.
            factory = pipeline_cls.from_pretrained
            kwargs = {
                "pretrained_model_name_or_path": get_model_dir(),
                "torch_dtype": torch.float16,
                "use_auth_token": get_hf_token(),
                "cache_dir": get_hf_cache_dir(),
            }
        else:
            # Reuse the components of the already-loaded base pipeline.
            factory = pipeline_cls
            kwargs = dict(base_pipe.pipe.components)  # pyright: ignore

        # Assign through the dict rather than as a separate keyword so that a
        # base pipeline which already carries a "controlnet"/"adapter"
        # component is overridden instead of raising a duplicate-keyword
        # TypeError.
        kwargs[module_kwarg] = network_model
        return factory(**kwargs).to("cuda")
class ControlNet(AbstractPipeline):
    """Loads ControlNet / T2I-adapter modules and runs the matching generation task.

    Two pipelines are kept on the instance once loaded:

    * ``pipe``  - built with ``is_img2img=True``; used by ``process_tile_upscaler``.
    * ``pipe2`` - built with ``is_img2img=False``; used by the canny, pose,
      scribble and linearart tasks.

    Either attribute may be absent when the loader reports the configuration
    as unsupported (for example ``pipe`` is never created on SDXL).
    """

    # Task name -> model repo for the non-SDXL models. A comma separated value
    # loads several ControlNets as one MultiControlNetModel. A value that is
    # itself a key of the table is treated as an alias by load_model.
    __model_normal = {
        "pose": "lllyasviel/control_v11f1p_sd15_depth, lllyasviel/control_v11p_sd15_openpose",
        "canny": "lllyasviel/control_v11p_sd15_canny",
        "linearart": "lllyasviel/control_v11p_sd15_lineart",
        "scribble": "lllyasviel/control_v11p_sd15_scribble",
        "tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
    }
    # Task name -> kind of network module used for the non-SDXL models.
    __model_normal_types = {
        "pose": "controlnet",
        "canny": "controlnet",
        "linearart": "controlnet",
        "scribble": "controlnet",
        "tile_upscaler": "controlnet",
    }
    # Task name -> model repo for SDXL; None marks an unsupported task.
    __model_sdxl = {
        "pose": "thibaud/controlnet-openpose-sdxl-1.0",
        "canny": "diffusers/controlnet-canny-sdxl-1.0",
        "linearart": "TencentARC/t2i-adapter-lineart-sdxl-1.0",
        "scribble": "TencentARC/t2i-adapter-sketch-sdxl-1.0",
        "tile_upscaler": None,
    }
    # Task name -> kind of network module used for SDXL.
    __model_sdxl_types = {
        "pose": "controlnet",
        "canny": "controlnet",
        "linearart": "t2i",
        "scribble": "t2i",
        "tile_upscaler": None,
    }

    # Name of the task whose network module is currently attached ("" = none).
    __current_task_name = ""
    # True once the pipelines have been built at least once (reset by unload).
    __loaded = False
    # Pipeline kind ("controlnet" / "t2i") the current pipelines were built for.
    __pipe_type = None

    def init(self, pipeline: AbstractPipeline):
        """Remember the base pipeline whose components are reused when loading."""
        # Stored via setattr with a literal string on purpose: this bypasses
        # private-name mangling, and __load_pipeline reads it back the same
        # way with getattr(self, "__pipeline", None).
        setattr(self, "__pipeline", pipeline)

    def unload(self):
        """Unload the network module and pipelines, then clear the CUDA cache."""
        if not self.__loaded:
            return

        self.__loaded = False
        self.__pipe_type = None
        self.__current_task_name = ""

        for attr in ("pipe", "pipe2"):
            if hasattr(self, attr):
                delattr(self, attr)

        clear_cuda_and_gc()

    def load_model(self, task_name: CONTROLNET_TYPES):
        """Load the network module for ``task_name`` and attach it to the pipelines.

        Does nothing when ``task_name`` is already the active task.

        Raises:
            Exception: if the task is unknown or has no model in the active
                (SDXL / non-SDXL) table.
        """
        if self.__current_task_name == task_name:
            return

        is_sdxl = get_is_sdxl()
        config = self.__model_sdxl if is_sdxl else self.__model_normal
        type_table = self.__model_sdxl_types if is_sdxl else self.__model_normal_types

        # .get() so an unknown task yields the same "not supported" error as a
        # task that is explicitly mapped to None, instead of a bare KeyError.
        model = config.get(task_name)
        if not model:
            raise Exception(f"ControlNet is not supported for {task_name}")

        # A value that is itself a task name is an alias: follow it.
        while model in config:
            task_name = model  # pyright: ignore
            model = config[task_name]
        if not model:
            # An alias chain may end in an unsupported (None) entry.
            raise Exception(f"ControlNet is not supported for {task_name}")

        pipeline_type = type_table[task_name]

        if "," in model:
            model = [name.strip() for name in model.split(",")]

        network_model = self.__load_network_model(model, pipeline_type)
        self.__load_pipeline(network_model, pipeline_type)

        self.__current_task_name = task_name

        clear_cuda_and_gc()

    def __load_network_model(self, model_name, pipeline_type):
        """Load the network module(s), i.e. ControlNet or T2I adapter, onto CUDA.

        ``model_name`` is either one repo id or a list of repo ids. A list is
        only supported for ControlNets and yields a ``MultiControlNetModel``.

        Raises:
            Exception: for an unknown ``pipeline_type`` or a list of T2I adapters.
        """

        def load_controlnet(model):
            return ControlNetModel.from_pretrained(
                model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            ).to("cuda")

        def load_t2i(model):
            # The keyword is `variant`; the previous spelling `varient` meant
            # the fp16 weight variant was never requested. cache_dir is passed
            # so adapters land in the same cache as the ControlNets.
            return T2IAdapter.from_pretrained(
                model,
                torch_dtype=torch.float16,
                variant="fp16",
                cache_dir=get_hf_cache_dir(),
            ).to("cuda")

        if isinstance(model_name, str):
            if pipeline_type == "controlnet":
                return load_controlnet(model_name)
            if pipeline_type == "t2i":
                return load_t2i(model_name)
            raise Exception("Invalid pipeline type")

        if isinstance(model_name, list):
            if pipeline_type == "controlnet":
                controlnets = [load_controlnet(name) for name in model_name]
                return MultiControlNetModel(controlnets).to("cuda")
            if pipeline_type == "t2i":
                raise Exception("Multi T2I adapters are not supported")
            raise Exception("Invalid pipeline type")

        # Any other input type is left unhandled, as before.
        return None

    def __load_pipeline(self, network_model, pipeline_type):
        """Build the pipelines if needed and attach ``network_model`` to them.

        The pipelines are (re)built when nothing is loaded yet or when the
        pipeline kind differs from the one they were built for; otherwise only
        the network module is swapped on the existing pipelines.
        """

        def patch_pipe(pipe):
            if not pipe:
                # cases where the loader may return None
                return None

            if get_is_sdxl():
                pipe.enable_vae_tiling()
                pipe.enable_vae_slicing()
                pipe.enable_xformers_memory_efficient_attention()
                # this scheduler produces good outputs for t2i adapters
                pipe.scheduler = EulerAncestralDiscreteScheduler.from_config(
                    pipe.scheduler.config
                )
            else:
                pipe.enable_xformers_memory_efficient_attention()
            return pipe

        # If the pipeline type is changed we should reload all
        # the pipelines
        if not self.__loaded or self.__pipe_type != pipeline_type:
            base_pipe = getattr(self, "__pipeline", None)

            # "pipe":  img2img controlnet pipeline for the tile upscaler
            # "pipe2": controlnet/adapter pipeline for the other tasks
            for attr, is_img2img in (("pipe", True), ("pipe2", False)):
                built = StableDiffusionNetworkModelPipelineLoader(
                    is_sdxl=get_is_sdxl(),
                    is_img2img=is_img2img,
                    network_model=network_model,
                    pipeline_type=pipeline_type,
                    base_pipe=base_pipe,
                )
                built = patch_pipe(built)
                if built:
                    setattr(self, attr, built)

            self.__loaded = True
            self.__pipe_type = pipeline_type

        # Set the network module in the pipeline
        module_attr = {"controlnet": "controlnet", "t2i": "adapter"}.get(pipeline_type)
        if module_attr is not None:
            for attr in ("pipe", "pipe2"):
                if hasattr(self, attr):
                    setattr(getattr(self, attr), module_attr, network_model)

        for attr in ("pipe", "pipe2"):
            if hasattr(self, attr):
                setattr(self, attr, getattr(self, attr).to("cuda"))

        clear_cuda_and_gc()

    def __ensure_task(self, task_name):
        """Raise unless ``task_name`` is the task the model is currently loaded for."""
        if self.__current_task_name != task_name:
            raise Exception(f"ControlNet is not loaded with {task_name} model")

    @staticmethod
    def __sdxl_adapter_args():
        """Return the extra call arguments used for the scribble/linearart tasks on SDXL.

        Empty outside SDXL. On SDXL these are merged after the explicit
        ``guidance_scale`` entry, so the ``guidance_scale`` argument of those
        tasks is replaced by 6.
        """
        if not get_is_sdxl():
            return {}
        return {
            "guidance_scale": 6,
            "adapter_conditioning_scale": 1.0,
            "adapter_conditioning_factor": 1.0,
        }

    def process(self, **kwargs):
        """Dispatch to the ``process_*`` method of the currently loaded task.

        Raises:
            Exception: if no task has been loaded.
        """
        handlers = {
            "pose": self.process_pose,
            "canny": self.process_canny,
            "scribble": self.process_scribble,
            "linearart": self.process_linearart,
            "tile_upscaler": self.process_tile_upscaler,
        }
        handler = handlers.get(self.__current_task_name)
        if handler is None:
            raise Exception("ControlNet is not loaded with any model")
        return handler(**kwargs)

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 9,
        **kwargs,
    ):
        """Generate from the canny edge map of the image at ``imageUrl``.

        The image is downloaded, resized to ``(width, height)`` and converted
        to an edge map before being passed to ``pipe2``. Extra ``kwargs`` are
        forwarded to the pipeline and override the defaults built here.
        """
        self.__ensure_task("canny")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        init_image = ControlNet.canny_detect_edge(init_image)

        kwargs = {
            "prompt": prompt,
            "image": init_image,
            "guidance_scale": guidance_scale,
            "num_images_per_prompt": 1,
            "negative_prompt": negative_prompt,
            "num_inference_steps": num_inference_steps,
            "height": height,
            "width": width,
            **kwargs,
        }

        result = self.pipe2.__call__(**kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        num_inference_steps: int,
        negative_prompt: List[str],
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate 4 images for the first prompt, conditioned on ``image``.

        Only ``prompt[0]`` and ``negative_prompt[0]`` are used. Extra
        ``kwargs`` are forwarded to ``pipe2`` and override the defaults.
        """
        self.__ensure_task("pose")

        torch.manual_seed(seed)

        kwargs = {
            "prompt": prompt[0],
            "image": image,
            "num_images_per_prompt": 4,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt[0],
            "guidance_scale": guidance_scale,
            "height": height,
            "width": width,
            **kwargs,
        }
        # Logs the full call arguments (kept from the original behaviour).
        print(kwargs)
        result = self.pipe2.__call__(**kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Upscale the image at ``imageUrl`` with the tile ControlNet (``pipe``).

        The image is resized to ``(width, height)`` and then rescaled so its
        longer side is about ``resize_dimension``; that rescaled image is used
        both as the input and as the control image, and its size is the
        output size.
        """
        self.__ensure_task("tile_upscaler")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            init_image, resize_dimension
        )

        kwargs = {
            "image": condition_image,
            "prompt": prompt,
            "control_image": condition_image,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": condition_image.size[1],
            "width": condition_image.size[0],
            "guidance_scale": guidance_scale,
            **kwargs,
        }
        result = self.pipe.__call__(**kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_scribble(
        self,
        image: List[Image.Image],
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate from already-prepared scribble condition images via ``pipe2``.

        On SDXL the adapter defaults (including a guidance scale of 6) replace
        the ``guidance_scale`` argument. Extra ``kwargs`` override everything.
        """
        self.__ensure_task("scribble")

        torch.manual_seed(seed)

        kwargs = {
            "image": image,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            **self.__sdxl_adapter_args(),
            **kwargs,
        }
        result = self.pipe2.__call__(**kwargs)
        return Result.from_result(result)

    @torch.inference_mode()
    def process_linearart(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        num_inference_steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
        **kwargs,
    ):
        """Generate from the line-art rendering of the image at ``imageUrl``.

        The condition image is repeated 4 times for the pipeline call. On SDXL
        the adapter defaults (including a guidance scale of 6) replace the
        ``guidance_scale`` argument. Extra ``kwargs`` override everything.
        """
        self.__ensure_task("linearart")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = ControlNet.linearart_condition_image(init_image)

        # NOTE(review): an earlier comment here said that with the t2i adapter
        # the conditioning scale should always be 0.8, but the value actually
        # passed on SDXL is 1.0 - confirm which one is intended.
        kwargs = {
            "image": [condition_image] * 4,
            "prompt": prompt,
            "num_inference_steps": num_inference_steps,
            "negative_prompt": negative_prompt,
            "height": height,
            "width": width,
            "guidance_scale": guidance_scale,
            **self.__sdxl_adapter_args(),
            **kwargs,
        }
        result = self.pipe2.__call__(**kwargs)
        return Result.from_result(result)

    def cleanup(self):
        """Doesn't do anything considering new diffusers has itself a cleanup mechanism
        after controlnet generation"""
        pass

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download the image and return the OpenPose detector's output for it."""
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        image = detector.__call__(image)
        return image

    @staticmethod
    def scribble_image(image: Image.Image) -> Image.Image:
        """Return the HED detector's scribble-style output for ``image``."""
        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        image = processor.__call__(input_image=image, scribble=True)
        return image

    @staticmethod
    def linearart_condition_image(image: Image.Image, **kwargs) -> Image.Image:
        """Return the line-art detector's output for ``image``.

        On SDXL ``detect_resolution`` defaults to 384 unless given in ``kwargs``.
        """
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        if get_is_sdxl():
            kwargs = {"detect_resolution": 384, **kwargs}

        image = processor.__call__(input_image=image, **kwargs)
        return image

    @staticmethod
    def depth_image(image: Image.Image) -> Image.Image:
        """Return an 8-bit depth map of ``image`` estimated with MiDaS.

        The MiDaS model and its transforms are loaded from torch.hub on first
        use and cached in module globals for later calls.
        """
        global midas, midas_transforms

        if "midas" not in globals():
            model = torch.hub.load("intel-isl/MiDaS", "MiDaS").to("cuda")
            # Inference only: put normalisation/dropout layers in eval mode.
            model.eval()
            transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
            # Publish both globals together so a failed second load cannot
            # leave `midas` defined without `midas_transforms`.
            midas, midas_transforms = model, transforms

        transform = midas_transforms.default_transform

        # A PIL image is already in RGB channel order, so no BGR->RGB swap is
        # needed (that step only applies to arrays read with cv2.imread).
        img = np.array(image.convert("RGB"))

        input_batch = transform(img).to("cuda")
        with torch.no_grad():
            prediction = midas(input_batch)

            # Resize the prediction back to the input image's height/width.
            prediction = torch.nn.functional.interpolate(
                prediction.unsqueeze(1),
                size=img.shape[:2],
                mode="bicubic",
                align_corners=False,
            ).squeeze()

        output = prediction.cpu().numpy()
        peak = np.max(output)
        if peak > 0:
            formatted = (output * 255 / peak).astype("uint8")
        else:
            # Avoid dividing by zero on a degenerate (all non-positive) map.
            formatted = np.zeros_like(output, dtype="uint8")
        img = Image.fromarray(formatted)
        return img

    @staticmethod
    def pidinet_image(image: Image.Image) -> Image.Image:
        """Return the PidiNet detector's output for ``image`` (filter applied)."""
        pidinet = PidiNetDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
        image = pidinet.__call__(input_image=image, apply_filter=True)
        return image

    @staticmethod
    def canny_detect_edge(image: Image.Image) -> Image.Image:
        """Return a 3-channel Canny edge image (thresholds 100 / 200) of ``image``."""
        low_threshold = 100
        high_threshold = 200

        edges = cv2.Canny(np.array(image), low_threshold, high_threshold)
        # Canny yields a single channel; replicate it into three channels.
        edges = edges[:, :, None]
        edges = np.concatenate([edges, edges, edges], axis=2)
        return Image.fromarray(edges)

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        """Scale ``image`` so its longer side is about ``resolution``.

        Both sides are rounded to the nearest multiple of 64 and clamped to at
        least 64 so a very small ``resolution`` or an extreme aspect ratio
        cannot produce a zero-sized image.
        """
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / max(W, H)
        H = max(64, int(round(H * k / 64.0)) * 64)
        W = max(64, int(round(W * k / 64.0)) * 64)
        img = input_image.resize((W, H), resample=Image.LANCZOS)
        return img