|
from io import BytesIO |
|
from typing import List, Optional, Union |
|
|
|
import torch |
|
from cv2 import inpaint |
|
from diffusers import ( |
|
ControlNetModel, |
|
StableDiffusionControlNetInpaintPipeline, |
|
StableDiffusionControlNetPipeline, |
|
StableDiffusionInpaintPipeline, |
|
UniPCMultistepScheduler, |
|
) |
|
from PIL import Image, ImageFilter, ImageOps |
|
|
|
import internals.util.image as ImageUtil |
|
from internals.data.result import Result |
|
from internals.data.task import ModelType |
|
from internals.pipelines.commons import AbstractPipeline |
|
from internals.pipelines.controlnets import ControlNet |
|
from internals.pipelines.high_res import HighRes |
|
from internals.pipelines.inpainter import InPainter |
|
from internals.pipelines.remove_background import RemoveBackgroundV2 |
|
from internals.pipelines.upscaler import Upscaler |
|
from internals.util.cache import clear_cuda_and_gc |
|
from internals.util.commons import download_image |
|
from internals.util.config import ( |
|
get_hf_cache_dir, |
|
get_hf_token, |
|
get_inpaint_model_path, |
|
get_model_dir, |
|
) |
|
|
|
|
|
class ReplaceBackground(AbstractPipeline):
    """Generate a new background behind the foreground subject of an image.

    The subject is isolated with a background remover, its canny edges are
    used as the ControlNet conditioning image for a Stable Diffusion
    ControlNet pipeline, and the cut-out subject is pasted back on top of
    every generated image.
    """

    # Guards against loading the models more than once.
    __loaded = False

    def load(
        self,
        upscaler: Optional[Upscaler] = None,
        remove_background: Optional[RemoveBackgroundV2] = None,
        base: Optional[AbstractPipeline] = None,
        high_res: Optional[HighRes] = None,
    ) -> None:
        """Load the ControlNet pipeline and its helpers; no-op if already loaded.

        Args:
            upscaler: Helper to reuse. When omitted a new ``Upscaler`` is
                created and loaded.
            remove_background: Helper to reuse. When omitted a new
                ``RemoveBackgroundV2`` is created.
            base: Pipeline whose ``pipe.components`` are reused to build the
                ControlNet pipeline. When omitted the pipeline is loaded from
                ``get_model_dir()``.
            high_res: Helper to reuse. When omitted a new ``HighRes`` is
                created and loaded.
        """
        if self.__loaded:
            return

        controlnet_model = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_canny",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")

        if base:
            # Build on the components of the already-loaded base pipeline
            # instead of loading them again from disk.
            pipe = StableDiffusionControlNetPipeline(
                **base.pipe.components,
                controlnet=controlnet_model,
            )
        else:
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                get_model_dir(),
                controlnet=controlnet_model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
            )
        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_vae_slicing()
        pipe.to("cuda")

        self.pipe = pipe

        if not high_res:
            high_res = HighRes()
            high_res.load()
        self.high_res = high_res

        if not upscaler:
            upscaler = Upscaler()
            upscaler.load()
        self.upscaler = upscaler

        if not remove_background:
            remove_background = RemoveBackgroundV2()
        self.remove_background = remove_background

        self.__loaded = True

    def unload(self) -> None:
        """Drop every model reference and free CUDA memory."""
        self.__loaded = False
        self.pipe = None
        self.high_res = None
        self.upscaler = None
        self.remove_background = None

        clear_cuda_and_gc()

    @torch.inference_mode()
    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        prompt: List[str],
        negative_prompt: List[str],
        conditioning_scale: float,
        seed: int,
        steps: int,
        apply_high_res: bool = False,
        model_type: ModelType = ModelType.REAL,
    ):
        """Generate new backgrounds for ``image`` and paste the subject on top.

        ``load()`` must have been called first; the method reads ``self.pipe``
        and ``self.remove_background``.

        Args:
            image: A PIL image, or a string that is handed to
                ``download_image`` to obtain one.
            width: Output width in pixels (coerced with ``int``).
            height: Output height in pixels (coerced with ``int``).
            prompt: Prompts forwarded to the diffusion pipeline.
            negative_prompt: Negative prompts forwarded to the pipeline.
            conditioning_scale: Forwarded as ``controlnet_conditioning_scale``.
            seed: Seed applied to the global torch CPU and CUDA generators.
            steps: Forwarded as ``num_inference_steps``.
            apply_high_res: Accepted for interface compatibility; currently
                not used by this method.
            model_type: Forwarded to the background remover.

        Returns:
            A ``(images, has_nsfw)`` tuple as unpacked from
            ``Result.from_result``. The subject is pasted onto the images
            only when ``has_nsfw`` is falsy.
        """
        if isinstance(image, str):
            image = download_image(image)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        image = image.convert("RGB")
        if max(image.size) > 1024:
            # Cap the input size before running background removal.
            image = ImageUtil.resize_image(image, dimension=1024)
        image = self.remove_background.remove(image, model_type=model_type)

        width = int(width)
        height = int(height)

        resolution = max(width, height)

        image = ImageUtil.resize_image(image, resolution)
        image = ImageUtil.padd_image(image, width, height)

        # Edges of the cut-out subject steer the generation so the new
        # background is composed around the subject's outline.
        condition_image = ControlNet.canny_detect_edge(image)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=condition_image,
            controlnet_conditioning_scale=conditioning_scale,
            guidance_scale=9,
            height=height,
            num_inference_steps=steps,
            width=width,
        )
        result = Result.from_result(result)

        images, has_nsfw = result

        if not has_nsfw:
            # The cut-out doubles as its own paste mask, so only its opaque
            # pixels overwrite the generated background.
            # NOTE(review): assumes the background remover returns an image
            # with an alpha channel - confirm.
            for generated in images:
                generated.paste(image, (0, 0), image)

        return (images, has_nsfw)
|
|