# internals/pipelines/controlnets.py
from typing import List, Literal, Union
import cv2
import numpy as np
import torch
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLControlNetPipeline,
    UniPCMultistepScheduler,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
MultiControlNetModel,
)
from PIL import Image
import internals.util.image as ImageUtil
from external.midas import apply_midas
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.tileUpscalePipeline import (
StableDiffusionControlNetImg2ImgPipeline,
)
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
get_hf_cache_dir,
get_hf_token,
get_is_sdxl,
get_model_dir,
)
CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"]
class ControlNet(AbstractPipeline):
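    """Thin wrapper around diffusers ControlNet pipelines.

    Lazily loads ControlNet weights for a given task ("pose", "canny",
    "scribble", "linearart", "tile_upscaler") and routes process() calls
    to the matching task-specific method.
    """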
    __current_task_name = ""
    __loaded = False
    __pipeline: Union[AbstractPipeline, None] = None
def init(self, pipeline: AbstractPipeline):
self.__pipeline = pipeline
def load_model(self, task_name: CONTROLNET_TYPES):
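        """Load (or swap in) the ControlNet weights for `task_name`.

        No-op when the requested task is already loaded.
        """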
config = self.__model_sdxl if get_is_sdxl() else self.__model_normal
if self.__current_task_name == task_name:
return
model = config[task_name]
if not model:
raise Exception(f"ControlNet is not supported for {task_name}")
        while model in config:
task_name = model # pyright: ignore
model = config[task_name]
# Multi controlnet
if "," in model:
model_names = [m.strip() for m in model.split(",")]
controlnets = []
for name in model_names:
cn = ControlNetModel.from_pretrained(
name,
torch_dtype=torch.float16,
cache_dir=get_hf_cache_dir(),
).to("cuda")
controlnets.append(cn)
controlnet = MultiControlNetModel(controlnets).to("cuda")
# Single controlnet
else:
controlnet = ControlNetModel.from_pretrained(
model,
torch_dtype=torch.float16,
cache_dir=get_hf_cache_dir(),
).to("cuda")
self.__current_task_name = task_name
self.controlnet = controlnet
self.__load()
if hasattr(self, "pipe"):
self.pipe.controlnet = controlnet
if hasattr(self, "pipe2"):
self.pipe2.controlnet = controlnet
clear_cuda_and_gc()
def __load(self):
"Should not be called externally"
if self.__loaded:
return
if not hasattr(self, "controlnet"):
self.load_model("pose")
# controlnet pipeline for tile upscaler
if get_is_sdxl():
print("Warning: Tile upscale is not supported on SDXL")
if self.__pipeline:
pipe = StableDiffusionXLControlNetPipeline(
controlnet=self.controlnet, **self.__pipeline.pipe.components
).to("cuda")
else:
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
get_model_dir(),
controlnet=self.controlnet,
torch_dtype=torch.float16,
use_auth_token=get_hf_token(),
cache_dir=get_hf_cache_dir(),
use_safetensors=True,
).to("cuda")
pipe.enable_vae_tiling()
pipe.enable_vae_slicing()
pipe.enable_xformers_memory_efficient_attention()
self.pipe2 = pipe
else:
if hasattr(self, "__pipeline"):
pipe = StableDiffusionControlNetImg2ImgPipeline(
controlnet=self.controlnet, **self.__pipeline.pipe.components
).to("cuda")
else:
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
get_model_dir(),
controlnet=self.controlnet,
torch_dtype=torch.float16,
use_auth_token=get_hf_token(),
cache_dir=get_hf_cache_dir(),
).to("cuda")
# pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()
pipe.enable_xformers_memory_efficient_attention()
self.pipe = pipe
# controlnet pipeline for canny and pose
pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
pipe2.scheduler = UniPCMultistepScheduler.from_config(
pipe2.scheduler.config
)
pipe2.enable_xformers_memory_efficient_attention()
self.pipe2 = pipe2
self.__loaded = True
def process(self, **kwargs):
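        """Dispatch to the handler for the currently loaded task."""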
if self.__current_task_name == "pose":
return self.process_pose(**kwargs)
if self.__current_task_name == "canny":
return self.process_canny(**kwargs)
if self.__current_task_name == "scribble":
return self.process_scribble(**kwargs)
if self.__current_task_name == "linearart":
return self.process_linearart(**kwargs)
if self.__current_task_name == "tile_upscaler":
return self.process_tile_upscaler(**kwargs)
raise Exception("ControlNet is not loaded with any model")
@torch.inference_mode()
def process_canny(
self,
prompt: List[str],
imageUrl: str,
seed: int,
num_inference_steps: int,
negative_prompt: List[str],
height: int,
width: int,
guidance_scale: float = 9,
**kwargs,
):
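        """Text-guided generation conditioned on Canny edges of the input image."""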
if self.__current_task_name != "canny":
raise Exception("ControlNet is not loaded with canny model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
init_image = ControlNet.canny_detect_edge(init_image)
kwargs = {
"prompt": prompt,
"image": init_image,
"guidance_scale": guidance_scale,
"num_images_per_prompt": 1,
"negative_prompt": negative_prompt,
"num_inference_steps": num_inference_steps,
"height": height,
"width": width,
**kwargs,
}
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result)
@torch.inference_mode()
def process_pose(
self,
prompt: List[str],
image: List[Image.Image],
seed: int,
num_inference_steps: int,
negative_prompt: List[str],
height: int,
width: int,
guidance_scale: float = 7.5,
**kwargs,
):
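        """Text-guided generation conditioned on pre-computed pose/depth condition images."""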
if self.__current_task_name != "pose":
raise Exception("ControlNet is not loaded with pose model")
torch.manual_seed(seed)
kwargs = {
"prompt": prompt[0],
"image": image,
"num_images_per_prompt": 4,
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt[0],
"guidance_scale": guidance_scale,
"control_guidance_end": [0.5, 1.0],
"height": height,
"width": width,
**kwargs,
}
print(kwargs)
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result)
@torch.inference_mode()
def process_tile_upscaler(
self,
imageUrl: str,
prompt: str,
negative_prompt: str,
num_inference_steps: int,
seed: int,
height: int,
width: int,
resize_dimension: int,
guidance_scale: float = 7.5,
**kwargs,
):
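        """Upscale an image with the tile ControlNet, conditioning on a resized copy of it."""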
if self.__current_task_name != "tile_upscaler":
raise Exception("ControlNet is not loaded with tile_upscaler model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
condition_image = self.__resize_for_condition_image(
init_image, resize_dimension
)
kwargs = {
"image": condition_image,
"prompt": prompt,
"controlnet_conditioning_image": condition_image,
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt,
"height": condition_image.size[1],
"width": condition_image.size[0],
"guidance_scale": guidance_scale,
**kwargs,
}
result = self.pipe.__call__(**kwargs)
return Result.from_result(result)
@torch.inference_mode()
def process_scribble(
self,
imageUrl: Union[str, Image.Image],
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
num_inference_steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
**kwargs,
):
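        """Text-guided generation conditioned on a HED scribble map of the input image."""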
if self.__current_task_name != "scribble":
raise Exception("ControlNet is not loaded with scribble model")
torch.manual_seed(seed)
if isinstance(imageUrl, Image.Image):
init_image = imageUrl.resize((width, height))
else:
init_image = download_image(imageUrl).resize((width, height))
condition_image = self.__scribble_condition_image(init_image)
kwargs = {
"image": condition_image,
"prompt": prompt,
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"guidance_scale": guidance_scale,
**kwargs,
}
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result)
@torch.inference_mode()
def process_linearart(
self,
imageUrl: str,
prompt: Union[str, List[str]],
negative_prompt: Union[str, List[str]],
num_inference_steps: int,
seed: int,
height: int,
width: int,
guidance_scale: float = 7.5,
**kwargs,
):
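        """Text-guided generation conditioned on a line-art rendering of the input image."""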
if self.__current_task_name != "linearart":
raise Exception("ControlNet is not loaded with linearart model")
torch.manual_seed(seed)
init_image = download_image(imageUrl).resize((width, height))
condition_image = ControlNet.linearart_condition_image(init_image)
kwargs = {
"image": condition_image,
"prompt": prompt,
"num_inference_steps": num_inference_steps,
"negative_prompt": negative_prompt,
"height": height,
"width": width,
"guidance_scale": guidance_scale,
**kwargs,
}
result = self.pipe2.__call__(**kwargs)
return Result.from_result(result)
def cleanup(self):
if hasattr(self, "pipe") and hasattr(self.pipe, "controlnet"):
del self.pipe.controlnet
if hasattr(self, "pipe2") and hasattr(self.pipe2, "controlnet"):
del self.pipe2.controlnet
if hasattr(self, "controlnet"):
del self.controlnet
self.__current_task_name = ""
clear_cuda_and_gc()
def detect_pose(self, imageUrl: str) -> Image.Image:
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
image = download_image(imageUrl)
image = detector.__call__(image)
return image
def __scribble_condition_image(self, image: Image.Image) -> Image.Image:
processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
image = processor.__call__(input_image=image, scribble=True)
return image
@staticmethod
def linearart_condition_image(image: Image.Image) -> Image.Image:
processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
image = processor.__call__(input_image=image)
return image
@staticmethod
def depth_image(image: Image.Image) -> Image.Image:
depth = np.array(image)
depth = ImageUtil.HWC3(depth)
depth, _ = apply_midas(depth)
depth = ImageUtil.HWC3(depth)
depth = Image.fromarray(depth)
return depth
@staticmethod
def canny_detect_edge(image: Image.Image) -> Image.Image:
image_array = np.array(image)
low_threshold = 100
high_threshold = 200
image_array = cv2.Canny(image_array, low_threshold, high_threshold)
image_array = image_array[:, :, None]
image_array = np.concatenate([image_array, image_array, image_array], axis=2)
canny_image = Image.fromarray(image_array)
return canny_image
def __resize_for_condition_image(self, image: Image.Image, resolution: int):
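        """Scale so the longer side equals `resolution`, rounded to multiples of 64."""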
input_image = image.convert("RGB")
W, H = input_image.size
k = float(resolution) / max(W, H)
H *= k
W *= k
H = int(round(H / 64.0)) * 64
W = int(round(W / 64.0)) * 64
img = input_image.resize((W, H), resample=Image.LANCZOS)
return img
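    # A comma-separated value loads several ControlNets as one MultiControlNetModel
    # (the SD 1.5 "pose" task pairs depth with openpose).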
__model_normal = {
"pose": "lllyasviel/control_v11f1p_sd15_depth, lllyasviel/control_v11p_sd15_openpose",
"canny": "lllyasviel/control_v11p_sd15_canny",
"linearart": "lllyasviel/control_v11p_sd15_lineart",
"scribble": "lllyasviel/control_v11p_sd15_scribble",
"tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile",
}
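    # A value that names another key is an alias resolved by load_model;
    # None marks the task as unsupported.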
__model_sdxl = {
"pose": "thibaud/controlnet-openpose-sdxl-1.0",
"canny": "diffusers/controlnet-canny-sdxl-1.0",
"linearart": "canny",
"scribble": "canny",
"tile_upscaler": None,
}
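

# Example usage -- a minimal sketch, assuming an already-constructed
# AbstractPipeline (`base_pipeline` below is a hypothetical instance) and a
# reachable image URL:
#
#   controlnet = ControlNet()
#   controlnet.init(base_pipeline)
#   controlnet.load_model("canny")
#   result = controlnet.process_canny(
#       prompt=["a watercolor landscape"],
#       negative_prompt=["blurry, low quality"],
#       imageUrl="https://example.com/input.png",  # hypothetical URL
#       seed=42,
#       num_inference_steps=30,
#       height=512,
#       width=512,
#   )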