# Last commit 76f85f6 by ihsanvp: "fix: mask image showing"
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from models import segmentation, inpainting
from PIL import Image
class ProductBackgroundModifier:
    """Segment an image and inpaint the masked region according to a text prompt.

    Pipeline: preprocess the input into a 1024x1024 tensor, ask the
    segmentation model for a mask, then hand the identically preprocessed
    image, the mask and the prompt to the inpainting model.
    """

    def __init__(
        self,
        segmentation_model: segmentation.SegmentationModel,
        inpainting_model: inpainting.InpaintingModel,
        device: torch.device = torch.device("cpu"),
    ) -> None:
        """Store the models and build the preprocessing transform.

        Args:
            segmentation_model: Model whose ``generate`` turns the preprocessed
                image tensor into a mask.
            inpainting_model: Model whose ``generate`` takes ``image``,
                ``mask_image`` and ``prompt`` and returns the edited image.
            device: Device the preprocessed image tensor is moved to before
                segmentation. The models themselves are not moved here.
        """
        self.segmentation_model = segmentation_model
        self.inpainting_model = inpainting_model
        self.device = device
        # PIL image -> float tensor, shorter side resized to 1024 (aspect ratio
        # preserved), then center-cropped to a 1024x1024 square.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(1024),
            transforms.CenterCrop((1024, 1024)),
        ])

    def generate(self, image: Image.Image, prompt: str) -> Image.Image:
        """Inpaint ``image`` using a mask predicted by the segmentation model.

        Args:
            image: Input image. It is resized and center-cropped to 1024x1024
                before both segmentation and inpainting, so content outside
                the central square is discarded.
            prompt: Text prompt forwarded to the inpainting model.

        Returns:
            Whatever ``inpainting_model.generate`` returns (annotated here as
            a PIL image).
        """
        image_tensor = self.transform(image)
        # The mask is predicted on the resized/cropped tensor, so the
        # inpainting model must receive that same 1024x1024 view. Passing the
        # original ``image`` misaligns mask and image whenever the input is not
        # already a 1024x1024 square. Convert on the CPU, before the device move.
        cropped_image = to_pil_image(image_tensor)
        # NOTE(review): mask format is whatever segmentation_model.generate
        # returns; presumably it matches what inpainting_model.generate expects.
        mask_image = self.segmentation_model.generate(image_tensor.to(self.device))
        generated_image = self.inpainting_model.generate(
            image=cropped_image, mask_image=mask_image, prompt=prompt
        )
        return generated_image