|
import torch |
|
from diffusers import AutoPipelineForInpainting |
|
from torchvision.transforms.functional import to_pil_image |
|
from PIL import Image |
|
|
|
class InpaintingModel:
    """Minimal base class for inpainting back-ends.

    Both methods are placeholders: the constructor stores nothing and
    ``generate`` returns ``None``. Subclasses are expected to override them.
    """

    def __init__(self) -> None:
        """Initialise the base model; no state is set up here."""

    def generate(self, image: torch.Tensor, mask_image: torch.Tensor, prompt: str) -> Image.Image:
        """Produce an inpainted image for ``prompt`` (placeholder in the base class).

        Args:
            image: Source image to inpaint.
            mask_image: Mask paired with ``image``.
            prompt: Text prompt guiding the generation.

        Returns:
            The base implementation performs no work and returns ``None``.
        """
        # NOTE(review): the inputs are annotated as tensors here, whereas
        # KandingskyInpaintingModel.generate annotates PIL images -- confirm
        # which input type callers are meant to pass.
        return None
|
|
|
class KandingskyInpaintingModel(InpaintingModel):
    """Inpainting back-end backed by the Kandinsky 2.2 decoder-inpaint pipeline.

    The pipeline is loaded through ``AutoPipelineForInpainting`` from the
    ``kandinsky-community/kandinsky-2-2-decoder-inpaint`` checkpoint, and the
    same fixed negative prompt is sent with every generation request.
    """

    def __init__(
        self,
        device=torch.device("cpu"),
    ) -> None:
        """Load the pipeline and place it according to ``device``.

        Args:
            device: Target device, given as a ``torch.device`` or anything
                ``torch.device(...)`` accepts (e.g. ``"cuda:0"``). For a CUDA
                target the pipeline is loaded in ``float16`` with model CPU
                offload enabled; for any other target it is loaded in
                ``float32`` and moved to that device.
        """
        super().__init__()
        # Keep the caller's value untouched on the attribute; a normalised
        # copy is used below for the placement decisions.
        self.device = device
        target = torch.device(device)
        use_cuda = target.type == "cuda"

        # Half precision is only selected together with a CUDA target;
        # elsewhere full precision is the safe choice.
        self.model = AutoPipelineForInpainting.from_pretrained(
            "kandinsky-community/kandinsky-2-2-decoder-inpaint",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        )

        if use_cuda:
            # Model CPU offload keeps sub-models on the CPU and moves each to
            # the accelerator only while it runs, so it is enabled only when
            # a CUDA device was actually requested.
            if target.index is None:
                self.model.enable_model_cpu_offload()
            else:
                # Honour an explicit device index such as "cuda:1".
                self.model.enable_model_cpu_offload(gpu_id=target.index)
        else:
            self.model.to(target)

        self.negative_prompt = "deformed, ugly, disfigured"

    def generate(self, image: Image.Image, mask_image: Image.Image, prompt: str) -> Image.Image:
        """Run the inpainting pipeline and return its first output image.

        Args:
            image: Source image to inpaint.
            mask_image: Mask passed to the pipeline alongside ``image``.
            prompt: Text prompt describing the desired content.

        Returns:
            The first entry of the pipeline output's ``images``.
        """
        output = self.model(
            prompt=prompt,
            negative_prompt=self.negative_prompt,
            image=image,
            mask_image=mask_image,
        )
        return output.images[0]
|
|
|
|