# Source: CM2000112/internals/pipelines/replace_background.py
# Uploaded with huggingface_hub by jayparmr (commit 1bc457e, 4.08 kB).
from io import BytesIO
from typing import List, Union
import torch
from diffusers import (
ControlNetModel,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionInpaintPipeline,
UniPCMultistepScheduler,
)
from PIL import Image, ImageFilter, ImageOps
import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.upscaler import Upscaler
from internals.util.commons import download_image
from internals.util.config import get_hf_cache_dir
class ReplaceBackground(AbstractPipeline):
    """Generate a new background around a product cut-out.

    The pipeline removes the existing background from the input image, places
    the cut-out on a transparent canvas of the requested size, and then runs a
    ControlNet (lineart) guided Stable Diffusion inpainting pass that fills
    only the transparent area. The original product pixels are pasted back on
    top of every generated image before it is upscaled.
    """

    def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
        """Load the models onto the GPU and store the helper pipelines.

        Args:
            upscaler: Upscaler used on the generated images; its ``load()`` is
                called here.
            remove_background: Background remover used to cut the product out
                of the input image. It is stored as-is (not loaded here).
        """
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_lineart",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=controlnet,
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        upscaler.load()

        self.pipe = pipe
        self.upscaler = upscaler
        self.remove_background = remove_background

    @staticmethod
    def _background_mask(image: Image.Image) -> Image.Image:
        """Return an RGB inpainting mask for an RGBA ``image``.

        Fully transparent pixels (alpha == 0) become white (area to repaint);
        every other pixel becomes black (area to keep).
        """
        # A 256-entry lookup on the alpha band replaces a per-pixel Python loop.
        alpha = image.getchannel("A")
        return alpha.point(lambda a: 255 if a == 0 else 0).convert("RGB")

    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        product_scale_width: float,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        resize_dimension: int,
        conditioning_scale: float,
        seed: int,
        steps: int,
    ):
        """Replace the background of ``image`` with a generated one.

        Args:
            image: Input image, or a string that is handed to
                ``download_image`` to fetch it.
            width: Output canvas width in pixels (before upscaling).
            height: Output canvas height in pixels (before upscaling).
            product_scale_width: Fraction of ``width`` that the product box
                occupies; the box keeps the canvas aspect ratio.
            prompt: Prompt(s) describing the new background.
            negative_prompt: Negative prompt(s).
            resize_dimension: Forwarded to the upscaler.
            conditioning_scale: ControlNet conditioning scale.
            seed: Seed applied to the global torch CPU and CUDA generators.
            steps: Number of denoising steps for the diffusion pipeline.

        Returns:
            A tuple ``(images, has_nsfw)`` as unpacked from
            ``Result.from_result``. When ``has_nsfw`` is falsy, each image has
            the product pasted back on top and has been upscaled and converted
            to RGB; otherwise the images are returned untouched.

        Raises:
            ValueError: If ``width`` or ``height`` is not a positive integer.
        """
        width = int(width)
        height = int(height)
        # Fail before any download / model work; width is also a divisor below.
        if width <= 0 or height <= 0:
            raise ValueError(
                f"width and height must be positive, got {width}x{height}"
            )

        if isinstance(image, str):
            image = download_image(image)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        image = image.convert("RGB")
        if max(image.size) > 1536:
            image = ImageUtil.resize_image(image, dimension=1536)
        # NOTE(review): remove() presumably returns an image with an alpha
        # channel where the background is transparent -- confirm.
        image = self.remove_background.remove(image)

        # Product box: scaled width, same aspect ratio as the output canvas.
        n_width = int(width * product_scale_width)
        n_height = n_width * height // width
        print(width, height, n_width, n_height)

        image = ImageUtil.padd_image(image, n_width, n_height)

        # Centre the product on a fully transparent canvas of the output size.
        canvas = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        canvas.paste(image, ((width - n_width) // 2, (height - n_height) // 2))
        image = canvas

        mask = self._background_mask(image)
        condition_image = ControlNet.linearart_condition_image(image)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask,
            control_image=condition_image,
            controlnet_conditioning_scale=conditioning_scale,
            guidance_scale=9,
            strength=1,
            height=height,
            width=width,
            # Previously `steps` was accepted but never forwarded.
            num_inference_steps=steps,
        )
        result = Result.from_result(result)

        images, has_nsfw = result
        if not has_nsfw:
            for i, generated in enumerate(images):
                # Restore the untouched product pixels over the generated
                # image, using the product's own alpha as the paste mask.
                generated.paste(image, (0, 0), image)
                w, h = generated.size
                out_bytes = self.upscaler.upscale(
                    image=generated,
                    width=w,
                    height=h,
                    face_enhance=False,
                    resize_dimension=resize_dimension,
                )
                images[i] = Image.open(BytesIO(out_bytes)).convert("RGB")

        return (images, has_nsfw)