# Source: CM2000112 / internals/pipelines/replace_background.py
# Hugging Face Hub file view -- uploaded by jayparmr using huggingface_hub
# (commit 22df957, verified; 4.78 kB).
from io import BytesIO
from typing import List, Optional, Union
import torch
from cv2 import inpaint
from diffusers import (
ControlNetModel,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
UniPCMultistepScheduler,
)
from PIL import Image, ImageFilter, ImageOps
import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.data.task import ModelType
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.inpainter import InPainter
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.upscaler import Upscaler
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
get_hf_cache_dir,
get_hf_token,
get_inpaint_model_path,
get_model_dir,
)
class ReplaceBackground(AbstractPipeline):
    """Generate a new background behind the foreground subject of an image.

    The subject is cut out with a background remover, its Canny edges are
    used to condition a Stable Diffusion ControlNet pipeline, and the cutout
    is pasted back on top of every generated image.
    """

    # Name-mangled flag; becomes an instance attribute once load() has run.
    __loaded = False

    def load(
        self,
        upscaler: Optional[Upscaler] = None,
        remove_background: Optional[RemoveBackgroundV2] = None,
        base: Optional[AbstractPipeline] = None,
        high_res: Optional[HighRes] = None,
    ):
        """Build the ControlNet pipeline and attach the helper pipelines.

        Does nothing if this instance is already loaded.

        Args:
            upscaler: Upscaler to reuse; a new one is created and loaded
                when omitted.
            remove_background: Background remover to reuse; a new
                ``RemoveBackgroundV2`` is created when omitted.
            base: Already-loaded pipeline whose ``pipe.components`` are
                reused so the base weights are not loaded a second time.
                When omitted the model is loaded from ``get_model_dir()``.
            high_res: HighRes helper to reuse; a new one is created and
                loaded when omitted.
        """
        if self.__loaded:
            return

        controlnet_model = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_canny",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")

        if base:
            # Share the components of the already-loaded base pipeline.
            pipe = StableDiffusionControlNetPipeline(
                **base.pipe.components,
                controlnet=controlnet_model,
            )
        else:
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                get_model_dir(),
                controlnet=controlnet_model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
            )

        pipe.enable_xformers_memory_efficient_attention()
        pipe.enable_vae_slicing()
        pipe.to("cuda")
        self.pipe = pipe

        if not high_res:
            high_res = HighRes()
            high_res.load()
        self.high_res = high_res

        if not upscaler:
            upscaler = Upscaler()
            upscaler.load()
        self.upscaler = upscaler

        if not remove_background:
            remove_background = RemoveBackgroundV2()
        self.remove_background = remove_background

        self.__loaded = True

    def unload(self):
        """Drop every held pipeline reference and free CUDA memory."""
        self.__loaded = False
        self.pipe = None
        self.high_res = None
        self.upscaler = None
        self.remove_background = None
        clear_cuda_and_gc()

    @torch.inference_mode()
    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        prompt: List[str],
        negative_prompt: List[str],
        conditioning_scale: float,
        seed: int,
        steps: int,
        apply_high_res: bool = False,
        model_type: ModelType = ModelType.REAL,
    ):
        """Replace the background of ``image`` with a generated one.

        Args:
            image: A PIL image, or a string that is handed to
                ``download_image`` to obtain one.
            width: Output width; coerced with ``int()``.
            height: Output height; coerced with ``int()``.
            prompt: Prompts describing the new background.
            negative_prompt: Negative prompts for the generation.
            conditioning_scale: ControlNet conditioning scale.
            seed: Seed applied to both the CPU and CUDA torch generators.
            steps: Number of inference steps.
            apply_high_res: Accepted for interface compatibility; not used
                by this method.
            model_type: Forwarded to the background remover.

        Returns:
            ``(images, has_nsfw)`` as unpacked from ``Result.from_result``.
            When ``has_nsfw`` is falsy each image has the foreground cutout
            pasted on top of it.
        """
        if isinstance(image, str):
            image = download_image(image)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        image = image.convert("RGB")
        if max(image.size) > 1024:
            # Shrink oversized inputs before background removal.
            image = ImageUtil.resize_image(image, dimension=1024)

        image = self.remove_background.remove(image, model_type=model_type)

        width = int(width)
        height = int(height)
        resolution = max(width, height)

        # Fit the cutout to the requested canvas.
        image = ImageUtil.resize_image(image, resolution)
        image = ImageUtil.padd_image(image, width, height)

        # NOTE: no inpainting mask is built here. The pipeline is a plain
        # ControlNet text-to-image pipeline conditioned on the edge map, and
        # the cutout is composited onto the output afterwards.
        condition_image = ControlNet.canny_detect_edge(image)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=condition_image,
            controlnet_conditioning_scale=conditioning_scale,
            guidance_scale=9,
            height=height,
            num_inference_steps=steps,
            width=width,
        )
        result = Result.from_result(result)

        images, has_nsfw = result
        if not has_nsfw:
            # The cutout doubles as its own paste mask, so only its
            # non-transparent pixels overwrite the generated background
            # (assumes the background remover returns an image with an
            # alpha channel -- TODO confirm).
            for generated in images:
                generated.paste(image, (0, 0), image)

        return (images, has_nsfw)