# Hugging Face Space by merve (HF staff).
# Last commit: be35f94 "Update app.py" (verified).
from transformers import pipeline, SegGptImageProcessor, SegGptForImageSegmentation
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import spaces
# Use the GPU when torch can see one; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Depth Anything (small) turns the prompt image into the mask fed to SegGPT.
depth_anything = pipeline(
    task="depth-estimation",
    model="nielsr/depth-anything-small",
    device=device,
)

# SegGPT processor and model share one checkpoint; the model lives on `device`.
checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
model = SegGptForImageSegmentation.from_pretrained(checkpoint).to(device)
def infer_seggpt(
    image_input,
    image_prompt,
    mask_prompt,
    num_labels=100,
    output_path="masks.png",
):
    """Run SegGPT one-shot segmentation and save the mask overlaid on the input.

    Args:
        image_input: Image to segment, as an array whose first two dimensions
            are used as the (height, width) of the predicted mask.
        image_prompt: Example image passed to ``image_processor`` as the prompt.
        mask_prompt: Example mask for *image_prompt*, passed to
            ``image_processor`` as the prompt mask.
        num_labels: Label count used when encoding the prompt, decoding the
            prediction and building the colour palette.
        output_path: File the overlay PNG is written to.

    Returns:
        *output_path*, after the overlay image has been saved there.
    """
    inputs = image_processor(
        images=image_input,
        prompt_images=image_prompt,
        prompt_masks=mask_prompt,
        return_tensors="pt",
        num_labels=num_labels,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Post-process back to the input image's (height, width).
    target_sizes = [image_input.shape[:2]]
    mask = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes, num_labels=num_labels
    )[0]

    palette = image_processor.get_palette(num_labels)
    mask_rgb = image_processor.mask_to_rgb(
        mask.cpu().numpy(), palette, data_format="channels_last"
    )

    fig, ax = plt.subplots()
    try:
        ax.imshow(Image.fromarray(image_input))
        # mask_rgb is already colour-coded, so no colormap applies; alpha
        # blends it over the input image.
        ax.imshow(mask_rgb, alpha=0.6)
        ax.axis("off")
        ax.margins(0)
        # Save via the figure object and skip plt.show(): on an interactive
        # backend show() can block or leave nothing to save.
        fig.savefig(output_path, bbox_inches="tight", pad_inches=0)
    finally:
        # pyplot keeps figures alive until closed; close it so repeated
        # requests don't accumulate open figures.
        plt.close(fig)
    return output_path
@spaces.GPU
def infer(image_input, image_prompt, mask_prompt=None):
    """Segment in *image_input* what the depth map of *image_prompt* highlights.

    The mask prompt is not taken from the caller. It is generated from
    *image_prompt* with Depth Anything, converted to RGB, and handed to SegGPT.

    Args:
        image_input: PIL image to segment (the Interface uses ``type="pil"``).
        image_prompt: PIL example image whose depth map becomes the prompt mask.
        mask_prompt: Ignored and overwritten. It is optional because the Gradio
            Interface only wires two inputs; it is kept for call compatibility.

    Returns:
        The file path returned by ``infer_seggpt``.
    """
    # Any caller-supplied mask is replaced by the depth map of the prompt image.
    # NOTE(review): the depth map is a soft map rather than a binary mask.
    mask_prompt = depth_anything(image_prompt)["depth"].convert("RGB")
    return infer_seggpt(
        np.asarray(image_input),
        np.asarray(image_prompt),
        np.asarray(mask_prompt),
    )
import gradio as gr

# Only two image inputs: infer() derives the mask prompt itself from the
# prompt image via Depth Anything.
demo = gr.Interface(
    infer,
    inputs=[
        gr.Image(type="pil", label="Image Input"),
        gr.Image(type="pil", label="Image Prompt"),
    ],
    outputs=[gr.Image(type="filepath", label="Mask Output")],
    title="SegGPT 🤝 Depth Anything: Speak to Segmentation in Image",
    description="SegGPT is a one-shot image segmentation model where one could ask model what to segment through uploading an example image and an example mask, and ask to segment the same thing in another image. In this demo, we have combined SegGPT and Depth Anything to automatically generate the mask for most outstanding object and segment the same thing in another image for you. You can see how it works by trying the example.",
    examples=[
        ["./cats.png", "./cat.png"],
    ],
)

# Launch only when run as a script (`python app.py`), so importing the module
# (e.g. by Gradio's reload mode) doesn't start a second server.
if __name__ == "__main__":
    demo.launch(debug=True)