|
import functools

import gradio as gr
import torch
from PIL import Image

from models.inpainting import KandingskyInpaintingModel
from models.product import ProductBackgroundModifier
from models.segmentation import SamSegmentationModel
|
|
|
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the segmentation + inpainting pipeline once and return it.

    The pipeline is constructed with a SAM ``vit_h`` segmentation model
    (checkpoint at ``model_checkpoints/sam_vit.pth``) and a
    ``KandingskyInpaintingModel``. The result is cached so every request
    reuses the same instance. Previously the pipeline was rebuilt (and,
    presumably, its weights reloaded) on each call.
    """
    # Prefer the first CUDA GPU when present; otherwise run on CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    return ProductBackgroundModifier(
        segmentation_model=SamSegmentationModel(
            model_type="vit_h",
            checkpoint_path="model_checkpoints/sam_vit.pth",
            device=device,
        ),
        inpainting_model=KandingskyInpaintingModel(),
        device=device,
    )


def generate(image: Image.Image, prompt: str):
    """Run the product-background pipeline on one uploaded image.

    Args:
        image: The uploaded image. Gradio passes ``None`` when the image
            input is left empty.
        prompt: Text prompt forwarded unchanged to the pipeline.

    Returns:
        Whatever ``ProductBackgroundModifier.generate`` returns. The
        interface declares a PIL image output (``gr.Image(type="pil")``),
        so a PIL image is presumably expected. TODO confirm.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message in the UI.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    # The pipeline is built on the first request and then reused.
    model = _load_model()
    return model.generate(image=image, prompt=prompt)
|
|
|
# Web UI: an image upload plus a text prompt in, one generated image out.
# The interface is bound to a module-level name so other code and tooling
# can reach it without launching it.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(),
    ],
    outputs=gr.Image(type="pil"),
)

# Start the server only when executed as a script. Importing this module
# should not have the side effect of launching a web server.
if __name__ == "__main__":
    demo.launch()