CM2000112 / internals /pipelines /controlnets.py
jayparmr's picture
Upload folder using huggingface_hub
86248f3
raw
history blame
10.7 kB
from typing import List, Union
import cv2
import numpy as np
import torch
from controlnet_aux import HEDdetector, LineartDetector, OpenposeDetector
from diffusers import (ControlNetModel, DiffusionPipeline,
StableDiffusionControlNetPipeline,
UniPCMultistepScheduler)
from PIL import Image
from torch.nn import Linear
from tqdm import gui
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.tileUpscalePipeline import \
StableDiffusionControlNetImg2ImgPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
class ControlNet(AbstractPipeline):
    """Stable Diffusion ControlNet wrapper that hot-swaps the ControlNet model.

    Two pipelines are built by ``load`` and share the same components:

    * ``self.pipe``  -- ``StableDiffusionControlNetImg2ImgPipeline``, used by
      the tile upscaler.
    * ``self.pipe2`` -- ``StableDiffusionControlNetPipeline``, used by the
      canny, pose, scribble and lineart tasks.

    Exactly one ControlNet model is attached at a time. Call the matching
    ``load_*`` method before the corresponding ``process_*`` method; the
    ``process_*`` methods raise if a different model is currently loaded.
    """

    # Name of the ControlNet variant currently attached ("" when none is).
    __current_task_name = ""

    def load(self, model_dir: str):
        """Build both pipelines from the base model found at ``model_dir``.

        The canny ControlNet is loaded first so that the pipelines can be
        constructed with a ControlNet already attached.

        Args:
            model_dir: Path or hub id forwarded to ``from_pretrained``.
        """
        # we will load canny by default
        self.load_canny()

        # Img2img ControlNet pipeline (used by process_tile_upscaler).
        pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
            model_dir,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
        ).to("cuda")
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        self.pipe = pipe

        # Text-to-image ControlNet pipeline built from the same components
        # (used by canny, pose, scribble and lineart).
        pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
        pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
        pipe2.enable_xformers_memory_efficient_attention()
        self.pipe2 = pipe2

    def __load_controlnet(self, task_name: str, model_id: str) -> None:
        """Load ``model_id`` and attach it as the active ControlNet.

        Does nothing when ``task_name`` is already the active task. The
        pipelines are only updated if they have been created, so this is safe
        to call from ``load`` before ``self.pipe`` / ``self.pipe2`` exist.

        Args:
            task_name: Identifier stored as the current task name.
            model_id: Hub id passed to ``ControlNetModel.from_pretrained``.
        """
        if self.__current_task_name == task_name:
            return

        controlnet = ControlNetModel.from_pretrained(
            model_id, torch_dtype=torch.float16
        ).to("cuda")

        self.__current_task_name = task_name
        self.controlnet = controlnet
        if hasattr(self, "pipe"):
            self.pipe.controlnet = controlnet
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = controlnet
        # Drop the reference to the previous ControlNet before collecting.
        clear_cuda_and_gc()

    def __ensure_task(self, task_name: str) -> None:
        """Raise if the active ControlNet is not the one for ``task_name``.

        Raises:
            Exception: When a different (or no) ControlNet is loaded.
        """
        if self.__current_task_name != task_name:
            raise Exception(f"ControlNet is not loaded with {task_name} model")

    def load_canny(self):
        """Attach the canny-edge ControlNet (no-op if already active)."""
        self.__load_controlnet("canny", "lllyasviel/control_v11p_sd15_canny")

    def load_pose(self):
        """Attach the openpose ControlNet (no-op if already active)."""
        self.__load_controlnet("pose", "lllyasviel/sd-controlnet-openpose")

    def load_tile_upscaler(self):
        """Attach the tile ControlNet (no-op if already active)."""
        self.__load_controlnet(
            "tile_upscaler", "lllyasviel/control_v11f1e_sd15_tile"
        )

    def load_scribble(self):
        """Attach the scribble ControlNet (no-op if already active)."""
        self.__load_controlnet("scribble", "lllyasviel/control_v11p_sd15_scribble")

    def load_linearart(self):
        """Attach the lineart ControlNet (no-op if already active)."""
        self.__load_controlnet(
            "linearart", "ControlNet-1-1-preview/control_v11p_sd15_lineart"
        )

    def cleanup(self):
        """Detach the active ControlNet and free cached CUDA memory.

        Safe to call before ``load``: the pipelines are only touched when they
        exist, mirroring the guards used when a ControlNet is attached.
        """
        if hasattr(self, "pipe"):
            self.pipe.controlnet = None
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""

        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        guidance_scale: float,
        height: int,
        width: int,
    ):
        """Generate images conditioned on the Canny edges of a remote image.

        The image at ``imageUrl`` is downloaded, resized to
        ``(width, height)`` and converted to an edge map that is used as the
        ControlNet conditioning image.

        Returns:
            ``Result.from_result`` applied to the pipeline output.

        Raises:
            Exception: If the canny ControlNet is not the active model.
        """
        self.__ensure_task("canny")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        init_image = self.__canny_detect_edge(init_image)

        result = self.pipe2(
            prompt=prompt,
            image=init_image,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        guidance_scale: float,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        """Generate images conditioned on pre-computed pose images.

        ``image`` is forwarded to the pipeline unchanged; see ``detect_pose``
        for producing a pose image from a URL.

        Returns:
            ``Result.from_result`` applied to the pipeline output.

        Raises:
            Exception: If the pose ControlNet is not the active model.
        """
        self.__ensure_task("pose")

        torch.manual_seed(seed)

        result = self.pipe2(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float,
    ):
        """Upscale a remote image with the tile ControlNet (img2img pipeline).

        The downloaded image is first resized to ``(width, height)`` and then
        rescaled so that its shorter side is about ``resize_dimension``. The
        rescaled image is used both as the img2img input and as the ControlNet
        conditioning image, and its size is what the pipeline is asked to
        generate.

        Returns:
            ``Result.from_result`` applied to the pipeline output.

        Raises:
            Exception: If the tile ControlNet is not the active model.
        """
        self.__ensure_task("tile_upscaler")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            init_image, resize_dimension
        )
        out_width, out_height = condition_image.size

        result = self.pipe(
            image=condition_image,
            prompt=prompt,
            controlnet_conditioning_image=condition_image,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=out_height,
            width=out_width,
            strength=1.0,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_scribble(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
    ):
        """Generate images conditioned on a scribble derived from a remote image.

        The image at ``imageUrl`` is downloaded, resized to
        ``(width, height)`` and run through the HED detector in scribble mode
        to obtain the conditioning image.

        Returns:
            ``Result.from_result`` applied to the pipeline output.

        Raises:
            Exception: If the scribble ControlNet is not the active model.
        """
        self.__ensure_task("scribble")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__scribble_condition_image(init_image)

        result = self.pipe2(
            image=condition_image,
            prompt=prompt,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_linearart(
        self,
        imageUrl: str,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        steps: int,
        seed: int,
        height: int,
        width: int,
        guidance_scale: float = 7.5,
    ):
        """Generate images conditioned on line art derived from a remote image.

        The image at ``imageUrl`` is downloaded, resized to
        ``(width, height)`` and run through the lineart detector to obtain the
        conditioning image.

        Returns:
            ``Result.from_result`` applied to the pipeline output.

        Raises:
            Exception: If the lineart ControlNet is not the active model.
        """
        self.__ensure_task("linearart")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__linearart_condition_image(init_image)

        result = self.pipe2(
            image=condition_image,
            prompt=prompt,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    def detect_pose(self, imageUrl: str) -> Image.Image:
        """Download an image and return its openpose map (hands and face on).

        The detector is instantiated on every call.
        """
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        return detector(image, hand_and_face=True)

    def __scribble_condition_image(self, image: Image.Image) -> Image.Image:
        """Run the HED detector in scribble mode on ``image``.

        The detector is instantiated on every call.
        """
        processor = HEDdetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image, scribble=True)

    def __linearart_condition_image(self, image: Image.Image) -> Image.Image:
        """Run the lineart detector on ``image``.

        The detector is instantiated on every call.
        """
        processor = LineartDetector.from_pretrained("lllyasviel/Annotators")
        return processor(input_image=image)

    def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
        """Return a 3-channel Canny edge map of ``image``.

        Uses fixed hysteresis thresholds of 100 (low) and 200 (high). The
        single-channel edge map is replicated across three channels.
        """
        low_threshold = 100
        high_threshold = 200

        # Normalise the mode first so palette, alpha or non-8-bit images are
        # handed to OpenCV as a plain 8-bit RGB array.
        pixels = np.array(image.convert("RGB"))
        edges = cv2.Canny(pixels, low_threshold, high_threshold)

        # (H, W) -> (H, W, 3)
        edges = edges[:, :, None]
        edges = np.concatenate([edges, edges, edges], axis=2)
        return Image.fromarray(edges)

    def __resize_for_condition_image(
        self, image: Image.Image, resolution: int
    ) -> Image.Image:
        """Rescale ``image`` so its shorter side is roughly ``resolution``.

        The aspect ratio is preserved up to rounding: each side is rounded to
        the nearest multiple of 64 and clamped to at least 64, so that a small
        ``resolution`` or an extreme aspect ratio can never produce a
        zero-sized target.

        Args:
            image: Source image; converted to RGB before resizing.
            resolution: Desired length of the shorter side, in pixels.

        Returns:
            The resized RGB image (LANCZOS resampling).
        """
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / min(W, H)
        H = max(64, int(round(H * k / 64.0)) * 64)
        W = max(64, int(round(W * k / 64.0)) * 64)
        return input_image.resize((W, H), resample=Image.LANCZOS)