File size: 870 Bytes
bc05b03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 |
import gradio as gr
from PIL import Image
from models.segmentation import SamSegmentationModel
from models.inpainting import KandingskyInpaintingModel
from models.product import ProductBackgroundModifier
import torch
# Lazily-built pipeline shared by every request; stays None until first use.
_MODEL = None


def _get_model():
    """Return the shared ProductBackgroundModifier, building it on first use.

    The pipeline is configured identically for every request (same model
    type, same checkpoint path, same device choice), so it is constructed
    once and reused instead of being rebuilt inside the request handler.

    NOTE(review): initialisation is not guarded by a lock; if the handler
    can be entered concurrently before the first build finishes, the
    pipeline may be built more than once (the last one wins) -- confirm
    whether that matters for this deployment.
    """
    global _MODEL
    if _MODEL is None:
        # Use the first CUDA device when one is available, otherwise the CPU.
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        _MODEL = ProductBackgroundModifier(
            segmentation_model=SamSegmentationModel(
                model_type="vit_h",
                checkpoint_path="model_checkpoints/sam_vit.pth",
                device=device,
            ),
            inpainting_model=KandingskyInpaintingModel(),
            device=device,
        )
    return _MODEL


def generate(image: Image.Image, prompt: str):
    """Run the product-background pipeline on one image and prompt.

    Args:
        image: Input picture as a PIL image (the UI passes ``type="pil"``).
        prompt: Text prompt forwarded unchanged to the pipeline.

    Returns:
        Whatever ``ProductBackgroundModifier.generate`` returns; the UI
        declares a PIL image output, so presumably a PIL image -- confirm
        against the model implementation.
    """
    return _get_model().generate(image=image, prompt=prompt)
# Web UI: one image input and one text input feed `generate`; the result is
# displayed as an image. Both image components exchange PIL objects.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(),
    ],
    outputs=gr.Image(type="pil"),
)

# Start serving immediately when this script is executed.
demo.launch()