from io import BytesIO
from typing import List, Optional, Union
import torch
from diffusers import (
ControlNetModel,
StableDiffusionControlNetInpaintPipeline,
StableDiffusionControlNetPipeline,
StableDiffusionInpaintPipeline,
UniPCMultistepScheduler,
)
from PIL import Image, ImageFilter, ImageOps
import internals.util.image as ImageUtil
from internals.data.result import Result
from internals.data.task import ModelType
from internals.pipelines.commons import AbstractPipeline
from internals.pipelines.controlnets import ControlNet
from internals.pipelines.high_res import HighRes
from internals.pipelines.inpainter import InPainter
from internals.pipelines.remove_background import RemoveBackgroundV2
from internals.pipelines.upscaler import Upscaler
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image
from internals.util.config import (
get_hf_cache_dir,
get_hf_token,
get_inpaint_model_path,
get_model_dir,
)


class ReplaceBackground(AbstractPipeline):
    """Generates a new scene behind a subject with a canny ControlNet guided
    Stable Diffusion pipeline, then pastes the original subject back on top."""

    __loaded = False

    def load(
        self,
        upscaler: Optional[Upscaler] = None,
        remove_background: Optional[RemoveBackgroundV2] = None,
        base: Optional[AbstractPipeline] = None,
        high_res: Optional[HighRes] = None,
    ):
if self.__loaded:
return
        # Canny ControlNet that constrains generation to the subject's edges.
        controlnet_model = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_canny",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")

        # Reuse the already-loaded base pipeline's components when available;
        # otherwise load the checkpoint from disk.
        if base:
            pipe = StableDiffusionControlNetPipeline(
                **base.pipe.components,
                controlnet=controlnet_model,
            )
        else:
            pipe = StableDiffusionControlNetPipeline.from_pretrained(
                get_model_dir(),
                controlnet=controlnet_model,
                torch_dtype=torch.float16,
                cache_dir=get_hf_cache_dir(),
                use_auth_token=get_hf_token(),
            )
pipe.enable_xformers_memory_efficient_attention()
pipe.enable_vae_slicing()
pipe.to("cuda")
self.pipe = pipe
if not high_res:
high_res = HighRes()
high_res.load()
self.high_res = high_res
if not upscaler:
upscaler = Upscaler()
upscaler.load()
self.upscaler = upscaler
if not remove_background:
remove_background = RemoveBackgroundV2()
self.remove_background = remove_background
self.__loaded = True

    def unload(self):
self.__loaded = False
self.pipe = None
self.high_res = None
self.upscaler = None
self.remove_background = None
clear_cuda_and_gc()

    @torch.inference_mode()
def replace(
self,
image: Union[str, Image.Image],
width: int,
height: int,
prompt: List[str],
negative_prompt: List[str],
conditioning_scale: float,
seed: int,
steps: int,
apply_high_res: bool = False,
model_type: ModelType = ModelType.REAL,
):
        # Accept either an image URL or an already loaded PIL image.
        if isinstance(image, str):
            image = download_image(image)

        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        image = image.convert("RGB")
        if max(image.size) > 1024:
            image = ImageUtil.resize_image(image, dimension=1024)

        # Cut the subject out; the result carries an alpha channel that is
        # transparent where the background used to be.
        image = self.remove_background.remove(image, model_type=model_type)
        width = int(width)
        height = int(height)

        resolution = max(width, height)
        image = ImageUtil.resize_image(image, resolution)
        image = ImageUtil.padd_image(image, width, height)

        # Build a black-and-white silhouette mask from the alpha channel:
        # white where the background was removed, black over the subject.
        mask = image.copy()
        pixdata = mask.load()

        w, h = mask.size
        for y in range(h):
            for x in range(w):
                item = pixdata[x, y]
                if item[3] == 0:
                    pixdata[x, y] = (255, 255, 255, 255)
                else:
                    pixdata[x, y] = (0, 0, 0, 255)
        # Canny edges of the padded subject guide the ControlNet so the new
        # scene is generated around the subject's silhouette.
        condition_image = ControlNet.canny_detect_edge(image)
        mask = mask.convert("RGB")

        result = self.pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=condition_image,
            controlnet_conditioning_scale=conditioning_scale,
            guidance_scale=9,
            height=height,
            num_inference_steps=steps,
            width=width,
        )
        result = Result.from_result(result)
        images, has_nsfw = result

        # Paste the original subject back over each generated background,
        # using its alpha channel as the paste mask.
        if not has_nsfw:
            for i in range(len(images)):
                images[i].paste(image, (0, 0), image)

        return (images, has_nsfw)
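

# Minimal usage sketch, assuming a CUDA device and the checkpoints this repo
# expects (get_model_dir(), the canny ControlNet, background-removal and
# upscaler weights). The URL, prompts, sizes, and parameter values below are
# placeholders chosen for illustration only.
if __name__ == "__main__":
    replacer = ReplaceBackground()
    replacer.load()

    images, has_nsfw = replacer.replace(
        image="https://example.com/product.png",  # hypothetical input URL
        width=768,
        height=768,
        prompt=["product photo on a marble table, soft studio lighting"],
        negative_prompt=["blurry, low quality, distorted"],
        conditioning_scale=0.5,
        seed=42,
        steps=30,
    )
    if not has_nsfw:
        images[0].save("replaced_background.png")

    replacer.unload()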