from PIL import Image
import torch
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image, to_tensor

from models import segmentation, inpainting


class ProductBackgroundModifier:
    """Replace the background of a product photo via segmentation + inpainting.

    The input photo is resized and center-cropped to a square of side
    ``image_size``. ``segmentation_model`` produces a mask from that square
    image, and the *same* square image is handed to ``inpainting_model``
    together with the mask and a text prompt, so image and mask stay
    pixel-aligned.
    """

    def __init__(
        self,
        segmentation_model: segmentation.SegmentationModel,
        inpainting_model: inpainting.InpaintingModel,
        device: torch.device = torch.device("cpu"),
        image_size: int = 1024,
    ) -> None:
        """Store the models and build the preprocessing pipeline.

        Args:
            segmentation_model: Model whose ``generate(tensor)`` returns the mask
                passed on to the inpainting model.
            inpainting_model: Model whose ``generate(image=..., mask_image=...,
                prompt=...)`` returns the final image.
            device: Device the segmentation input tensor is moved to.
            image_size: Side length of the square working image (default 1024).

        Raises:
            ValueError: If ``image_size`` is not positive.
        """
        if image_size <= 0:
            raise ValueError(f"image_size must be positive, got {image_size}")
        self.segmentation_model = segmentation_model
        self.inpainting_model = inpainting_model
        self.device = device
        self.image_size = image_size
        # Geometry is handled on the PIL image: the shorter side is scaled to
        # image_size, then the center square is kept. Working in PIL space lets
        # the exact same pixels go to both models, and resizing before the
        # tensor conversion avoids materialising a full-resolution float tensor.
        self.preprocess = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop((image_size, image_size)),
        ])
        # Full PIL -> tensor pipeline, kept as a public attribute for callers
        # that used it directly.
        self.transform = transforms.Compose([self.preprocess, transforms.ToTensor()])

    def generate(self, image: Image.Image, prompt: str) -> Image.Image:
        """Generate a new background for the product shown in ``image``.

        Args:
            image: Source product photo. It is converted to RGB, resized so its
                shorter side equals ``image_size``, and center-cropped to an
                ``image_size`` x ``image_size`` square before use.
            prompt: Text prompt forwarded to the inpainting model.

        Returns:
            The image returned by ``inpainting_model.generate`` for the cropped
            square image and its segmentation mask.
        """
        # RGBA / palette / greyscale inputs would otherwise become 4- or
        # 1-channel tensors; normalise to 3 channels first.
        prepared_image = self.preprocess(image.convert("RGB"))
        image_tensor = to_tensor(prepared_image).to(self.device)
        mask_image = self.segmentation_model.generate(image_tensor)
        # Pass the cropped image (not the original) so it matches the mask's
        # geometry; the mask was computed from exactly these pixels.
        return self.inpainting_model.generate(
            image=prepared_image,
            mask_image=mask_image,
            prompt=prompt,
        )