|
from typing import List, Union |
|
|
|
import torch |
|
from diffusers import StableDiffusionInpaintPipeline |
|
|
|
from internals.pipelines.commons import AbstractPipeline |
|
from internals.util.commons import disable_safety_checker, download_image |
|
from internals.util.config import get_hf_cache_dir |
|
|
|
|
|
class InPainter(AbstractPipeline):
    """Image inpainting backed by diffusers' ``StableDiffusionInpaintPipeline``.

    The underlying pipeline is stored on ``self.pipe``. It is populated either
    by :meth:`load`, which downloads a dedicated inpainting checkpoint, or by
    :meth:`create`, which reuses the components of another loaded pipeline.
    :meth:`process` reads ``self.pipe``, so one of those must run first.
    """

    def load(self) -> None:
        """Load the ``jayparmr/icbinp_v8_inpaint_v2`` checkpoint onto CUDA.

        Weights are loaded in float16 using the project's Hugging Face cache
        directory (``get_hf_cache_dir()``). The safety checker is disabled
        afterwards.
        """
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "jayparmr/icbinp_v8_inpaint_v2",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")
        disable_safety_checker(self.pipe)

    def create(self, pipeline: AbstractPipeline) -> None:
        """Build the inpainting pipeline from another pipeline's components.

        The modules in ``pipeline.pipe.components`` are shared rather than
        loaded a second time. The new pipeline is then moved to CUDA and its
        safety checker is disabled.

        Args:
            pipeline: An already-loaded pipeline wrapper exposing ``.pipe``.
                Presumably ``pipeline.pipe`` is a diffusers Stable Diffusion
                pipeline whose components ``StableDiffusionInpaintPipeline``
                accepts. NOTE(review): confirm its UNet is inpainting-compatible.
        """
        self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
            "cuda"
        )
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        steps: int = 50,
    ):
        """Run the inpainting pipeline on a downloaded image/mask pair.

        Args:
            image_url: URL of the source image, fetched with ``download_image``.
            mask_image_url: URL of the mask image, fetched the same way.
            width: Output width. Both images are resized to ``(width, height)``.
            height: Output height.
            seed: Seed passed to ``torch.manual_seed`` before generation.
            prompt: Text prompt(s) forwarded to the pipeline.
            negative_prompt: Negative prompt(s) forwarded to the pipeline.
            steps: Number of denoising steps (``num_inference_steps``).

        Returns:
            The ``images`` attribute of the pipeline output. These are PIL
            images under the pipeline's default output type.
        """
        # This seeds torch's process-wide RNG. No dedicated ``generator`` is
        # passed to the pipeline, so its initial noise comes from this global
        # state.
        # NOTE(review): a per-call ``torch.Generator`` would avoid mutating
        # global state, but it would change the images produced for an
        # existing seed.
        torch.manual_seed(seed)

        # Presumably download_image returns a PIL image. PIL's resize takes a
        # (width, height) tuple. TODO: confirm.
        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        output = self.pipe(
            prompt=prompt,
            image=input_img,
            mask_image=mask_img,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
        )
        return output.images
|
|