from typing import Optional

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from PIL import Image

import internals.util.image as ImageUtil
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet
from internals.util.config import get_hf_cache_dir


class RealtimeDraw(AbstractPipeline):
    """ControlNet img2img pipelines for realtime drawing.

    ``load`` builds two pipelines that share the components of an already
    loaded base pipeline:

    * ``self.pipe`` -- conditioned on the segmentation ControlNet only
      (used by ``process_seg``).
    * ``self.pipe2`` -- conditioned on scribble + segmentation ControlNets
      (used by ``process_img``).
    """

    def load(self, pipeline: AbstractPipeline):
        """Build the ControlNet pipelines from ``pipeline``'s components.

        Does nothing if ``self.pipe`` already exists, so repeated calls are
        cheap.

        Args:
            pipeline: An already-loaded pipeline whose ``pipe.components``
                (a diffusers component dict) are reused for both ControlNet
                pipelines.
        """
        if hasattr(self, "pipe"):
            return

        self.__controlnet_scribble = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_scribble",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        )
        self.__controlnet_seg = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_seg",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        )

        kwargs = {**pipeline.pipe.components}  # pyright: ignore
        # The ControlNet img2img constructor is given the base components
        # minus the image encoder.
        kwargs.pop("image_encoder", None)

        self.pipe = StableDiffusionControlNetImg2ImgPipeline(
            **kwargs, controlnet=self.__controlnet_seg
        ).to("cuda")
        self.pipe.safety_checker = None

        self.pipe2 = StableDiffusionControlNetImg2ImgPipeline(
            **kwargs,
            controlnet=[self.__controlnet_scribble, self.__controlnet_seg],
        ).to("cuda")
        self.pipe2.safety_checker = None

    def process_seg(
        self,
        image: Image.Image,
        prompt: str,
        negative_prompt: str,
        seed: int,
    ):
        """Render ``image`` through the segmentation-conditioned pipeline.

        The (resized) input is used both as the img2img init image and as
        the segmentation control image.

        Args:
            image: Segmentation-style input drawing.
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            seed: Value passed to ``torch.manual_seed`` (global RNG) before
                sampling.

        Returns:
            The first generated image from the pipeline output.
        """
        torch.manual_seed(seed)

        # NOTE(review): resize_image(..., 512) presumably scales to a
        # 512-pixel target -- confirm against internals.util.image.
        image = ImageUtil.resize_image(image, 512)

        result = self.pipe(
            image=image,
            control_image=image,
            prompt=prompt,
            num_inference_steps=15,
            negative_prompt=negative_prompt,
            guidance_scale=10,
            strength=0.8,
        )
        return result.images[0]

    def process_img(
        self,
        prompt: str,
        negative_prompt: str,
        seed: int,
        image: Optional[Image.Image] = None,
        image2: Optional[Image.Image] = None,
    ):
        """Render a scribble + segmentation drawing via the dual-ControlNet pipeline.

        Either input may be omitted. A missing input is replaced by a black
        RGB canvas sized to match the other input, or 512x512 if both are
        missing.

        Args:
            prompt: Positive text prompt.
            negative_prompt: Negative text prompt.
            seed: Value passed to ``torch.manual_seed`` (global RNG) before
                sampling.
            image: Drawing used as the img2img init image and as the source
                for the scribble control image.
            image2: Segmentation control image.

        Returns:
            The first generated image from the pipeline output.
        """
        torch.manual_seed(seed)

        # Substitute black canvases for missing inputs. After the first
        # branch ``image`` is always set, so ``image2`` can simply copy its
        # size.
        if image is None:
            size = image2.size if image2 is not None else (512, 512)
            image = Image.new("RGB", size, color=0)
        if image2 is None:
            image2 = Image.new("RGB", image.size, color=0)

        image = ImageUtil.resize_image(image, 512)
        # NOTE(review): ControlNet.scribble_image presumably extracts a
        # scribble/edge map from the drawing -- confirm in controlnets.py.
        scribble = ControlNet.scribble_image(image)
        image2 = ImageUtil.resize_image(image2, 512)

        result = self.pipe2(
            image=image,
            control_image=[scribble, image2],
            prompt=prompt,
            num_inference_steps=15,
            negative_prompt=negative_prompt,
            guidance_scale=10,
            strength=0.9,
            width=image.size[0],
            height=image.size[1],
            # Order matches the controlnet list in load(): scribble, seg.
            controlnet_conditioning_scale=[1.0, 0.8],
        )
        return result.images[0]