from io import BytesIO
from typing import List, Optional, Union

import torch
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetInpaintPipeline,
    StableDiffusionInpaintPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image, ImageFilter, ImageOps

import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.upscaler import Upscaler
from internals.util.commons import download_image
from internals.util.config import get_hf_cache_dir, get_model_dir


class ReplaceBackground(AbstractPipeline):
    """Generate a new background around a product cut-out.

    The input image has its background removed, is placed on a transparent
    canvas, and everything outside the product is inpainted with a
    line-art ControlNet conditioned on the cut-out.
    """

    # Guards against loading the (heavy) models more than once.
    __loaded = False

    def load(
        self,
        upscaler: Optional[Upscaler] = None,
        remove_background: Optional[RemoveBackgroundV2] = None,
        controlnet: Optional[ControlNet] = None,
        high_res: Optional[HighRes] = None,
    ):
        """Build the inpainting pipeline and attach the helper pipelines.

        Does nothing if this instance has already been loaded.

        Args:
            upscaler: Upscaler to reuse; a new one is created and loaded
                when omitted.
            remove_background: Background remover to reuse; a new one is
                created when omitted.
            controlnet: When given, its ``pipe.components`` are reused to
                build the inpaint pipeline instead of loading the base
                model from ``get_model_dir()``.
            high_res: High-res helper to reuse; a new one is created and
                loaded when omitted.
        """
        if self.__loaded:
            return

        controlnet_model = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_lineart",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")

        if controlnet:
            # Reuse the components of the already-built ControlNet pipeline,
            # then swap in the line-art ControlNet loaded above.
            controlnet.load_linearart()
            pipe = StableDiffusionControlNetInpaintPipeline(
                **controlnet.pipe.components
            )
            pipe.controlnet = controlnet_model
        else:
            pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
                get_model_dir(),
                controlnet=controlnet_model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
            )

        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")
        self.pipe = pipe

        if not high_res:
            high_res = HighRes()
            high_res.load()
        self.high_res = high_res

        if not upscaler:
            upscaler = Upscaler()
            upscaler.load()
        self.upscaler = upscaler

        if not remove_background:
            remove_background = RemoveBackgroundV2()
        self.remove_background = remove_background

        self.__loaded = True

    @torch.inference_mode()
    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        product_scale_width: float,
        prompt: List[str],
        negative_prompt: List[str],
        resize_dimension: int,
        conditioning_scale: float,
        seed: int,
        steps: int,
        apply_high_res: bool = False,
    ):
        """Inpaint a new background around the product found in ``image``.

        Args:
            image: A PIL image, or a string that is passed to
                ``download_image`` to obtain one.
            width: Output canvas width in pixels.
            height: Output canvas height in pixels.
            product_scale_width: Fraction of ``width`` the product box
                occupies; the box height keeps the canvas aspect ratio.
            prompt: Prompts forwarded to the pipeline.
            negative_prompt: Negative prompts forwarded to the pipeline.
            resize_dimension: Currently unused; kept for interface
                compatibility.
            conditioning_scale: ``controlnet_conditioning_scale`` for the
                pipeline.
            seed: Seed applied to the torch CPU and CUDA generators.
            steps: Number of inference steps.
            apply_high_res: When true (and a high-res helper is attached),
                generate at an intermediate size first and then run the
                high-res pass to the requested size.

        Returns:
            A ``(images, has_nsfw)`` tuple as unpacked from
            ``Result.from_result``. When ``has_nsfw`` is falsy, the product
            cut-out has been pasted back on top of every generated image.
        """
        if isinstance(image, str):
            image = download_image(image)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        image = image.convert("RGB")
        if max(image.size) > 1536:
            image = ImageUtil.resize_image(image, dimension=1536)
        image = self.remove_background.remove(image)

        width = int(width)
        height = int(height)

        # Product box: scaled width, height following the canvas aspect ratio.
        n_width = int(width * product_scale_width)
        n_height = int(n_width * height // width)
        print(width, height, n_width, n_height)

        # Centre the product on a fully transparent canvas of the output size.
        image = ImageUtil.padd_image(image, n_width, n_height)
        f_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        f_image.paste(image, ((width - n_width) // 2, (height - n_height) // 2))
        image = f_image

        # Inpaint mask: white where the canvas is fully transparent (area to
        # generate), black wherever alpha is non-zero (area to keep).
        alpha = image.getchannel("A")
        mask = alpha.point(lambda a: 255 if a == 0 else 0).convert("RGB")

        condition_image = ControlNet.linearart_condition_image(image)

        # Arguments shared by both generation paths.
        pipe_kwargs = dict(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask,
            control_image=condition_image,
            controlnet_conditioning_scale=conditioning_scale,
            guidance_scale=9,
            strength=1,
            num_inference_steps=steps,
        )

        if apply_high_res and hasattr(self, "high_res"):
            # NOTE(review): assumes get_intermediate_dimension returns
            # (width, height), matching the order of its arguments.
            (w, h) = self.high_res.get_intermediate_dimension(width, height)
            images = self.pipe(width=w, height=h, **pipe_kwargs).images
            # The second pass must target the requested canvas size so the
            # cut-out pasted below lines up with the generated image.
            result = self.high_res.apply(
                prompt=prompt,
                negative_prompt=negative_prompt,
                images=images,
                width=width,
                height=height,
                steps=steps,
            )
        else:
            result = self.pipe(width=width, height=height, **pipe_kwargs)

        result = Result.from_result(result)
        images, has_nsfw = result

        if not has_nsfw:
            # Restore the original product pixels over the generated output,
            # using the cut-out's own alpha channel as the paste mask.
            for generated in images:
                generated.paste(image, (0, 0), image)

        return (images, has_nsfw)