|
import torch |
|
from torchvision import transforms |
|
from torchvision.transforms.functional import to_pil_image |
|
from models import segmentation, inpainting |
|
from PIL import Image |
|
|
|
class ProductBackgroundModifier:
    """Regenerate the background of a product photo from a text prompt.

    The input image is resized and center-cropped to an
    ``image_size`` x ``image_size`` square. A segmentation model produces a
    mask from that square tensor. An inpainting model is then called with the
    same square image, the mask and the prompt.
    """

    def __init__(
        self,
        segmentation_model: segmentation.SegmentationModel,
        inpainting_model: inpainting.InpaintingModel,
        device: torch.device = torch.device("cpu"),
        image_size: int = 1024,
    ) -> None:
        """Store the models and build the preprocessing pipeline.

        Args:
            segmentation_model: Model whose ``generate`` is called with the
                preprocessed image tensor; its return value is used as the
                inpainting mask.
            inpainting_model: Model whose ``generate`` is called with the
                preprocessed image, the mask and the prompt.
            device: Device the image tensor is moved to before segmentation.
            image_size: Side length in pixels of the square the input is
                resized and center-cropped to. Defaults to 1024, the value
                that was previously hard-coded.
        """
        self.segmentation_model = segmentation_model
        self.inpainting_model = inpainting_model
        self.device = device
        self.image_size = image_size
        # ToTensor gives a CHW float tensor (scaled to [0, 1] for 8-bit
        # images). Resize(int) scales the shorter side to image_size, and
        # CenterCrop trims the longer side, so the output is always square.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(image_size),
            transforms.CenterCrop((image_size, image_size)),
        ])

    def generate(self, image: Image.Image, prompt: str) -> Image.Image:
        """Replace the background of ``image`` according to ``prompt``.

        Args:
            image: Input product photo of any size.
            prompt: Text prompt forwarded to the inpainting model.

        Returns:
            Whatever ``inpainting_model.generate`` returns for the
            preprocessed (resized + center-cropped) image. NOTE(review):
            presumably an ``image_size`` x ``image_size`` PIL image; confirm
            against the inpainting model.
        """
        # Preprocess on CPU so converting back to PIL below needs no
        # device-to-host copy; only the segmentation input goes to the device.
        image_tensor = self.transform(image)

        # The mask is computed from the resized/cropped tensor, so the
        # inpainting model must receive that same resized/cropped image.
        # Passing the original ``image`` misaligns image and mask for any
        # input that is not already image_size x image_size.
        model_image = to_pil_image(image_tensor)

        # Assumes segmentation_model.generate returns a mask spatially aligned
        # with its input tensor. TODO confirm against the model implementation.
        mask_image = self.segmentation_model.generate(image_tensor.to(self.device))

        return self.inpainting_model.generate(
            image=model_image,
            mask_image=mask_image,
            prompt=prompt,
        )
|
|
|
|