|
from typing import List, Literal, Optional, Union

import cv2
import numpy as np
import torch
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    StableDiffusionXLControlNetPipeline,
    UniPCMultistepScheduler,
)
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import (
    MultiControlNetModel,
)
from PIL import Image

import internals.util.image as ImageUtil
from external.midas import apply_midas
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.tileUpscalePipeline import (
    StableDiffusionControlNetImg2ImgPipeline,
)
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
    get_hf_cache_dir,
    get_hf_token,
    get_is_sdxl,
    get_model_dir,
)
|
|
|
CONTROLNET_TYPES = Literal["pose", "canny", "scribble", "linearart", "tile_upscaler"] |
|
|
|
|
|
class ControlNet(AbstractPipeline): |
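    """Text-to-image generation conditioned by task-specific ControlNet models
    (pose, canny, scribble, linearart, tile upscaling) for SD 1.5 and SDXL."""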
|
__current_task_name = "" |
|
__loaded = False |
|
|
|
    __pipeline: Optional[AbstractPipeline] = None
|
|
|
def init(self, pipeline: AbstractPipeline): |
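        """Attach a base pipeline whose loaded components are reused when the
        ControlNet pipelines are built."""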
|
self.__pipeline = pipeline |
|
|
|
def load_model(self, task_name: CONTROLNET_TYPES): |
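        """Load the ControlNet weights for `task_name`, resolving alias entries
        (a task mapped to another task's name) and comma-separated entries
        (multiple models combined into a MultiControlNetModel)."""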
|
config = self.__model_sdxl if get_is_sdxl() else self.__model_normal |
|
if self.__current_task_name == task_name: |
|
return |
|
model = config[task_name] |
|
if not model: |
|
raise Exception(f"ControlNet is not supported for {task_name}") |
|
        while model in config:
|
task_name = model |
|
model = config[task_name] |
|
|
|
|
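        # Comma-separated model ids are loaded together as a MultiControlNetModel.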
|
if "," in model: |
|
model_names = [m.strip() for m in model.split(",")] |
|
controlnets = [] |
|
for name in model_names: |
|
cn = ControlNetModel.from_pretrained( |
|
name, |
|
torch_dtype=torch.float16, |
|
cache_dir=get_hf_cache_dir(), |
|
).to("cuda") |
|
controlnets.append(cn) |
|
controlnet = MultiControlNetModel(controlnets).to("cuda") |
|
|
|
else: |
|
controlnet = ControlNetModel.from_pretrained( |
|
model, |
|
torch_dtype=torch.float16, |
|
cache_dir=get_hf_cache_dir(), |
|
).to("cuda") |
|
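        # The stored name is the alias-resolved task, so aliased tasks (e.g.
        # "linearart" on SDXL) dispatch through their target processor.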
self.__current_task_name = task_name |
|
self.controlnet = controlnet |
|
|
|
self.__load() |
|
|
|
if hasattr(self, "pipe"): |
|
self.pipe.controlnet = controlnet |
|
if hasattr(self, "pipe2"): |
|
self.pipe2.controlnet = controlnet |
|
clear_cuda_and_gc() |
|
|
|
def __load(self): |
|
"Should not be called externally" |
|
if self.__loaded: |
|
return |
|
|
|
        if not hasattr(self, "controlnet"):
            # load_model() calls back into __load() once the weights are set,
            # so return here to avoid building the pipelines twice.
            self.load_model("pose")
            return
|
|
|
|
|
if get_is_sdxl(): |
|
print("Warning: Tile upscale is not supported on SDXL") |
|
|
|
if self.__pipeline: |
|
pipe = StableDiffusionXLControlNetPipeline( |
|
controlnet=self.controlnet, **self.__pipeline.pipe.components |
|
).to("cuda") |
|
else: |
|
pipe = StableDiffusionXLControlNetPipeline.from_pretrained( |
|
get_model_dir(), |
|
controlnet=self.controlnet, |
|
torch_dtype=torch.float16, |
|
use_auth_token=get_hf_token(), |
|
cache_dir=get_hf_cache_dir(), |
|
use_safetensors=True, |
|
).to("cuda") |
|
pipe.enable_vae_tiling() |
|
pipe.enable_vae_slicing() |
|
pipe.enable_xformers_memory_efficient_attention() |
|
self.pipe2 = pipe |
|
else: |
|
if hasattr(self, "__pipeline"): |
|
pipe = StableDiffusionControlNetImg2ImgPipeline( |
|
controlnet=self.controlnet, **self.__pipeline.pipe.components |
|
).to("cuda") |
|
else: |
|
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained( |
|
get_model_dir(), |
|
controlnet=self.controlnet, |
|
torch_dtype=torch.float16, |
|
use_auth_token=get_hf_token(), |
|
cache_dir=get_hf_cache_dir(), |
|
).to("cuda") |
|
|
|
pipe.enable_model_cpu_offload() |
|
pipe.enable_xformers_memory_efficient_attention() |
|
self.pipe = pipe |
|
|
|
|
|
pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda") |
|
pipe2.scheduler = UniPCMultistepScheduler.from_config( |
|
pipe2.scheduler.config |
|
) |
|
pipe2.enable_xformers_memory_efficient_attention() |
|
self.pipe2 = pipe2 |
|
|
|
self.__loaded = True |
|
|
|
def process(self, **kwargs): |
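        """Dispatch to the processor for the currently loaded task."""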
|
if self.__current_task_name == "pose": |
|
return self.process_pose(**kwargs) |
|
if self.__current_task_name == "canny": |
|
return self.process_canny(**kwargs) |
|
if self.__current_task_name == "scribble": |
|
return self.process_scribble(**kwargs) |
|
if self.__current_task_name == "linearart": |
|
return self.process_linearart(**kwargs) |
|
if self.__current_task_name == "tile_upscaler": |
|
return self.process_tile_upscaler(**kwargs) |
|
raise Exception("ControlNet is not loaded with any model") |
|
|
|
@torch.inference_mode() |
|
def process_canny( |
|
self, |
|
prompt: List[str], |
|
imageUrl: str, |
|
seed: int, |
|
num_inference_steps: int, |
|
negative_prompt: List[str], |
|
height: int, |
|
width: int, |
|
guidance_scale: float = 9, |
|
**kwargs, |
|
): |
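        """Generate images conditioned on Canny edges extracted from `imageUrl`."""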
|
if self.__current_task_name != "canny": |
|
raise Exception("ControlNet is not loaded with canny model") |
|
|
|
torch.manual_seed(seed) |
|
|
|
        init_image = download_image(imageUrl).resize((width, height))
        condition_image = ControlNet.canny_detect_edge(init_image)
|
|
|
kwargs = { |
|
"prompt": prompt, |
|
"image": init_image, |
|
"guidance_scale": guidance_scale, |
|
"num_images_per_prompt": 1, |
|
"negative_prompt": negative_prompt, |
|
"num_inference_steps": num_inference_steps, |
|
"height": height, |
|
"width": width, |
|
**kwargs, |
|
} |
|
|
|
        result = self.pipe2(**kwargs)
|
return Result.from_result(result) |
|
|
|
@torch.inference_mode() |
|
def process_pose( |
|
self, |
|
prompt: List[str], |
|
image: List[Image.Image], |
|
seed: int, |
|
num_inference_steps: int, |
|
negative_prompt: List[str], |
|
height: int, |
|
width: int, |
|
guidance_scale: float = 7.5, |
|
**kwargs, |
|
): |
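        """Generate images conditioned on pre-extracted pose/depth maps in `image`."""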
|
if self.__current_task_name != "pose": |
|
raise Exception("ControlNet is not loaded with pose model") |
|
|
|
torch.manual_seed(seed) |
|
|
|
kwargs = { |
|
"prompt": prompt[0], |
|
"image": image, |
|
"num_images_per_prompt": 4, |
|
"num_inference_steps": num_inference_steps, |
|
"negative_prompt": negative_prompt[0], |
|
"guidance_scale": guidance_scale, |
|
"control_guidance_end": [0.5, 1.0], |
|
"height": height, |
|
"width": width, |
|
**kwargs, |
|
} |
|
|
        result = self.pipe2(**kwargs)
|
return Result.from_result(result) |
|
|
|
@torch.inference_mode() |
|
def process_tile_upscaler( |
|
self, |
|
imageUrl: str, |
|
prompt: str, |
|
negative_prompt: str, |
|
num_inference_steps: int, |
|
seed: int, |
|
height: int, |
|
width: int, |
|
resize_dimension: int, |
|
guidance_scale: float = 7.5, |
|
**kwargs, |
|
): |
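        """Upscale `imageUrl` with the tile ControlNet via img2img."""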
|
if self.__current_task_name != "tile_upscaler": |
|
raise Exception("ControlNet is not loaded with tile_upscaler model") |
|
|
|
torch.manual_seed(seed) |
|
|
|
init_image = download_image(imageUrl).resize((width, height)) |
|
condition_image = self.__resize_for_condition_image( |
|
init_image, resize_dimension |
|
) |
|
|
|
kwargs = { |
|
"image": condition_image, |
|
"prompt": prompt, |
|
"controlnet_conditioning_image": condition_image, |
|
"num_inference_steps": num_inference_steps, |
|
"negative_prompt": negative_prompt, |
|
"height": condition_image.size[1], |
|
"width": condition_image.size[0], |
|
"guidance_scale": guidance_scale, |
|
**kwargs, |
|
} |
|
        result = self.pipe(**kwargs)
|
return Result.from_result(result) |
|
|
|
@torch.inference_mode() |
|
def process_scribble( |
|
self, |
|
imageUrl: Union[str, Image.Image], |
|
prompt: Union[str, List[str]], |
|
negative_prompt: Union[str, List[str]], |
|
num_inference_steps: int, |
|
seed: int, |
|
height: int, |
|
width: int, |
|
guidance_scale: float = 7.5, |
|
**kwargs, |
|
): |
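        """Generate images conditioned on a HED scribble map of the input."""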
|
if self.__current_task_name != "scribble": |
|
raise Exception("ControlNet is not loaded with scribble model") |
|
|
|
torch.manual_seed(seed) |
|
|
|
if isinstance(imageUrl, Image.Image): |
|
init_image = imageUrl.resize((width, height)) |
|
else: |
|
init_image = download_image(imageUrl).resize((width, height)) |
|
|
|
condition_image = self.__scribble_condition_image(init_image) |
|
|
|
kwargs = { |
|
"image": condition_image, |
|
"prompt": prompt, |
|
"num_inference_steps": num_inference_steps, |
|
"negative_prompt": negative_prompt, |
|
"height": height, |
|
"width": width, |
|
"guidance_scale": guidance_scale, |
|
**kwargs, |
|
} |
|
        result = self.pipe2(**kwargs)
|
return Result.from_result(result) |
|
|
|
@torch.inference_mode() |
|
def process_linearart( |
|
self, |
|
imageUrl: str, |
|
prompt: Union[str, List[str]], |
|
negative_prompt: Union[str, List[str]], |
|
num_inference_steps: int, |
|
seed: int, |
|
height: int, |
|
width: int, |
|
guidance_scale: float = 7.5, |
|
**kwargs, |
|
): |
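        """Generate images conditioned on a line-art map of the input."""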
|
if self.__current_task_name != "linearart": |
|
raise Exception("ControlNet is not loaded with linearart model") |
|
|
|
torch.manual_seed(seed) |
|
|
|
init_image = download_image(imageUrl).resize((width, height)) |
|
condition_image = ControlNet.linearart_condition_image(init_image) |
|
|
|
kwargs = { |
|
"image": condition_image, |
|
"prompt": prompt, |
|
"num_inference_steps": num_inference_steps, |
|
"negative_prompt": negative_prompt, |
|
"height": height, |
|
"width": width, |
|
"guidance_scale": guidance_scale, |
|
**kwargs, |
|
} |
|
        result = self.pipe2(**kwargs)
|
return Result.from_result(result) |
|
|
|
def cleanup(self): |
|
if hasattr(self, "pipe") and hasattr(self.pipe, "controlnet"): |
|
del self.pipe.controlnet |
|
if hasattr(self, "pipe2") and hasattr(self.pipe2, "controlnet"): |
|
del self.pipe2.controlnet |
|
if hasattr(self, "controlnet"): |
|
del self.controlnet |
|
self.__current_task_name = "" |
|
|
|
clear_cuda_and_gc() |
|
|
|
def detect_pose(self, imageUrl: str) -> Image.Image: |
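        """Download the image and extract an OpenPose skeleton map."""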
|
detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet") |
|
image = download_image(imageUrl) |
|
        image = detector(image)
|
return image |
|
|
|
def __scribble_condition_image(self, image: Image.Image) -> Image.Image: |
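        """Produce a scribble-style HED edge map used as the condition image."""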
|
processor = HEDdetector.from_pretrained("lllyasviel/Annotators") |
|
        image = processor(input_image=image, scribble=True)
|
return image |
|
|
|
@staticmethod |
|
def linearart_condition_image(image: Image.Image) -> Image.Image: |
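        """Produce a line-art condition image with the Lineart annotator."""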
|
processor = LineartDetector.from_pretrained("lllyasviel/Annotators") |
|
        image = processor(input_image=image)
|
return image |
|
|
|
@staticmethod |
|
def depth_image(image: Image.Image) -> Image.Image: |
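        """Estimate a MiDaS depth map and return it as a 3-channel image."""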
|
depth = np.array(image) |
|
depth = ImageUtil.HWC3(depth) |
|
depth, _ = apply_midas(depth) |
|
depth = ImageUtil.HWC3(depth) |
|
depth = Image.fromarray(depth) |
|
return depth |
|
|
|
@staticmethod |
|
def canny_detect_edge(image: Image.Image) -> Image.Image: |
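        """Run OpenCV Canny and stack the single edge channel into RGB."""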
|
image_array = np.array(image) |
|
|
|
low_threshold = 100 |
|
high_threshold = 200 |
|
|
|
image_array = cv2.Canny(image_array, low_threshold, high_threshold) |
|
image_array = image_array[:, :, None] |
|
image_array = np.concatenate([image_array, image_array, image_array], axis=2) |
|
canny_image = Image.fromarray(image_array) |
|
return canny_image |
|
|
|
    def __resize_for_condition_image(
        self, image: Image.Image, resolution: int
    ) -> Image.Image:
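        """Resize so the longer side equals `resolution`, rounding both sides
        to multiples of 64."""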
|
input_image = image.convert("RGB") |
|
W, H = input_image.size |
|
k = float(resolution) / max(W, H) |
|
H *= k |
|
W *= k |
|
H = int(round(H / 64.0)) * 64 |
|
W = int(round(W / 64.0)) * 64 |
|
img = input_image.resize((W, H), resample=Image.LANCZOS) |
|
return img |
|
|
|
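    # SD 1.5 ControlNet ids; the comma-separated "pose" entry loads depth and
    # openpose together as a MultiControlNetModel.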
__model_normal = { |
|
"pose": "lllyasviel/control_v11f1p_sd15_depth, lllyasviel/control_v11p_sd15_openpose", |
|
"canny": "lllyasviel/control_v11p_sd15_canny", |
|
"linearart": "lllyasviel/control_v11p_sd15_lineart", |
|
"scribble": "lllyasviel/control_v11p_sd15_scribble", |
|
"tile_upscaler": "lllyasviel/control_v11f1e_sd15_tile", |
|
} |
|
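    # SDXL ControlNet ids; "linearart" and "scribble" alias the canny model,
    # and tile upscaling has no SDXL model.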
__model_sdxl = { |
|
"pose": "thibaud/controlnet-openpose-sdxl-1.0", |
|
"canny": "diffusers/controlnet-canny-sdxl-1.0", |
|
"linearart": "canny", |
|
"scribble": "canny", |
|
"tile_upscaler": None, |
|
} |
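
# Minimal usage sketch (hypothetical `base_pipeline`: an AbstractPipeline with
# its Stable Diffusion weights already loaded; the URL is a placeholder):
#
#   controlnet = ControlNet()
#   controlnet.init(base_pipeline)
#   controlnet.load_model("canny")
#   result = controlnet.process_canny(
#       prompt=["a watercolor castle"],
#       imageUrl="https://example.com/input.png",
#       seed=42,
#       num_inference_steps=30,
#       negative_prompt=["blurry, low quality"],
#       height=512,
#       width=512,
#   )
#   controlnet.cleanup()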
|
|