from typing import List, Union

import torch
from diffusers import StableDiffusionInpaintPipeline

from internals.pipelines.commons import AbstractPipeline
from internals.util.commons import disable_safety_checker, download_image
from internals.util.config import get_hf_cache_dir


class InPainter(AbstractPipeline):
    """Inpainting wrapper around ``StableDiffusionInpaintPipeline``.

    Call either ``load`` or ``create`` to populate ``self.pipe``, then use
    ``process`` to inpaint a remote image with a remote mask.
    """

    def load(self):
        """Load the inpainting checkpoint in float16 and move it to CUDA.

        Weights are fetched into the directory returned by
        ``get_hf_cache_dir()``, and the pipeline's safety checker is turned
        off via ``disable_safety_checker``.
        """
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "jayparmr/icbinp_v8_inpaint_v2",
            torch_dtype=torch.float16,
            cache_dir=get_hf_cache_dir(),
        ).to("cuda")
        disable_safety_checker(self.pipe)

    def create(self, pipeline: AbstractPipeline):
        """Build the inpainting pipe from another pipeline's components.

        Args:
            pipeline: An already-initialised pipeline wrapper. The component
                objects in its ``pipe.components`` are handed directly to
                ``StableDiffusionInpaintPipeline``, so they are reused rather
                than loaded again from disk.
        """
        self.pipe = StableDiffusionInpaintPipeline(**pipeline.pipe.components).to(
            "cuda"
        )
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
        steps: int = 50,
    ):
        """Inpaint the image at ``image_url`` using the mask at ``mask_image_url``.

        Both images are downloaded and resized to ``(width, height)`` before
        being passed to the pipeline.

        Args:
            image_url: URL of the source image.
            mask_image_url: URL of the mask image.
            width: Target width; used both for resizing and as the pipeline
                output width.
            height: Target height; used both for resizing and as the
                pipeline output height.
            seed: Seed for the per-call random generator.
            prompt: Prompt (or list of prompts) to guide inpainting.
            negative_prompt: Negative prompt (or list of prompts).
            steps: Number of denoising steps (``num_inference_steps``).

        Returns:
            The ``.images`` attribute of the pipeline output.
        """
        # Seed a dedicated generator instead of calling torch.manual_seed:
        # reseeding the process-wide RNG on every request would let
        # concurrent requests / other pipelines in this process disturb each
        # other's randomness. Creating it on the pipeline's device keeps the
        # noise identical to what the reseeded default generator produced.
        generator = torch.Generator(device=self.pipe.device).manual_seed(seed)

        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        return self.pipe(
            prompt=prompt,
            image=input_img,
            mask_image=mask_img,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            generator=generator,
        ).images