|
from io import BytesIO |
|
from typing import List, Union |
|
|
|
import torch |
|
from diffusers import ( |
|
ControlNetModel, |
|
StableDiffusionControlNetInpaintPipeline, |
|
StableDiffusionInpaintPipeline, |
|
UniPCMultistepScheduler, |
|
) |
|
from PIL import Image, ImageFilter, ImageOps |
|
|
|
import internals.util.image as ImageUtil |
|
from internals.data.result import Result |
|
from internals.pipelines.commons import AbstractPipeline |
|
from internals.pipelines.controlnets import ControlNet |
|
from internals.pipelines.remove_background import RemoveBackgroundV2 |
|
from internals.pipelines.upscaler import Upscaler |
|
from internals.util.commons import download_image |
|
from internals.util.config import get_hf_cache_dir |
|
|
|
|
|
class ReplaceBackground(AbstractPipeline):
    """Replace the background of a product photo with ControlNet-guided inpainting.

    The product is cut out of the input, centred on a transparent canvas, and
    every transparent pixel is regenerated by a Stable Diffusion inpainting
    pipeline conditioned on a lineart ControlNet. The original product pixels
    are then pasted back over each result before upscaling.
    """

    def load(self, upscaler: Upscaler, remove_background: RemoveBackgroundV2):
        """Load the inpainting models onto CUDA and keep the helper pipelines.

        Args:
            upscaler: Upscaler applied to each generated image; its ``load()``
                is called here.
            remove_background: Background remover used to cut the product out
                of the input image.
        """
        controlnet = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_lineart",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")
        pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=controlnet,
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        )
        # Swap the pipeline's default scheduler for UniPC, reusing its config.
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        upscaler.load()

        self.pipe = pipe
        self.upscaler = upscaler
        self.remove_background = remove_background

    @staticmethod
    def _compose_canvas(
        product: Image.Image, width: int, height: int, n_width: int, n_height: int
    ) -> Image.Image:
        """Return a transparent RGBA ``width`` x ``height`` canvas with *product* centred.

        *product* is first passed through ``ImageUtil.padd_image`` with the
        ``n_width`` x ``n_height`` box size (presumably fitting it into that box
        — TODO confirm against ImageUtil).
        """
        product = ImageUtil.padd_image(product, n_width, n_height)
        canvas = Image.new("RGBA", (width, height), (0, 0, 0, 0))
        canvas.paste(product, ((width - n_width) // 2, (height - n_height) // 2))
        return canvas

    @staticmethod
    def _inpaint_mask(canvas: Image.Image) -> Image.Image:
        """Return an RGB mask that is white where *canvas* is fully transparent, black elsewhere.

        White marks the background area the inpainting pipeline should
        regenerate; any non-zero alpha (the product) is kept black.
        """
        # A 256-entry lookup over the alpha channel replaces a per-pixel Python loop.
        alpha = canvas.getchannel("A")
        return alpha.point(lambda a: 255 if a == 0 else 0).convert("RGB")

    def replace(
        self,
        image: Union[str, Image.Image],
        width: int,
        height: int,
        product_scale_width: float,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        resize_dimension: int,
        conditioning_scale: float,
        seed: int,
        steps: int,
    ):
        """Generate new backgrounds around the product in *image*.

        Args:
            image: Product image, or a URL/path string passed to ``download_image``.
            width: Output canvas width in pixels (cast with ``int``).
            height: Output canvas height in pixels (cast with ``int``).
            product_scale_width: Fraction of ``width`` the product box occupies.
                The box height keeps the canvas aspect ratio.
            prompt: Prompt(s) forwarded to the inpainting pipeline.
            negative_prompt: Negative prompt(s) forwarded to the pipeline.
            resize_dimension: Forwarded unchanged to ``Upscaler.upscale``.
            conditioning_scale: ControlNet conditioning scale.
            seed: Seed applied to the global torch CPU and CUDA RNGs.
            steps: Number of denoising steps (``num_inference_steps``).

        Returns:
            ``(images, has_nsfw)`` as unpacked from ``Result.from_result``. When
            ``has_nsfw`` is falsy, each image has the product pasted back on top
            and has been upscaled and converted to RGB. Otherwise the raw
            pipeline images are returned untouched.
        """
        import logging

        if isinstance(image, str):
            image = download_image(image)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        image = image.convert("RGB")
        if max(image.size) > 1536:
            image = ImageUtil.resize_image(image, dimension=1536)
        image = self.remove_background.remove(image)

        width = int(width)
        height = int(height)
        # Product box: a fraction of the canvas width, same aspect ratio as the canvas.
        n_width = int(width * product_scale_width)
        n_height = int(n_width * height // width)
        logging.getLogger(__name__).debug(
            "canvas=%dx%d product_box=%dx%d", width, height, n_width, n_height
        )

        image = self._compose_canvas(image, width, height, n_width, n_height)
        mask = self._inpaint_mask(image)
        condition_image = ControlNet.linearart_condition_image(image)

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=image,
            mask_image=mask,
            control_image=condition_image,
            controlnet_conditioning_scale=conditioning_scale,
            guidance_scale=9,
            strength=1,
            # Previously `steps` was accepted but ignored, so the diffusers default was used.
            num_inference_steps=int(steps),
            height=height,
            width=width,
        )
        images, has_nsfw = Result.from_result(result)

        if not has_nsfw:
            for i, generated in enumerate(images):
                # Restore the exact product pixels, using the canvas alpha as the paste mask.
                generated.paste(image, (0, 0), image)
                w, h = generated.size
                out_bytes = self.upscaler.upscale(
                    image=generated,
                    width=w,
                    height=h,
                    face_enhance=False,
                    resize_dimension=resize_dimension,
                )
                images[i] = Image.open(BytesIO(out_bytes)).convert("RGB")

        return (images, has_nsfw)
|
|