# --------------------------------------------------------
# Set-of-Mark (SoM) Prompting for Visual Grounding in GPT-4V
# Copyright (c) 2023 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by:
#   Jianwei Yang ([email protected])
#   Xueyan Zou ([email protected])
#   Hao Zhang ([email protected])
# --------------------------------------------------------
import io
import gradio as gr
import torch
import argparse
from PIL import Image

# seem
from seem.modeling.BaseModel import BaseModel as BaseModel_Seem
from seem.utils.distributed import init_distributed as init_distributed_seem
from seem.modeling import build_model as build_model_seem
from task_adapter.seem.tasks import interactive_seem_m2m_auto, inference_seem_pano, inference_seem_interactive

# semantic sam
from semantic_sam.BaseModel import BaseModel
from semantic_sam import build_model
from semantic_sam.utils.dist import init_distributed_mode
from semantic_sam.utils.arguments import load_opt_from_config_file
from semantic_sam.utils.constants import COCO_PANOPTIC_CLASSES
from task_adapter.semantic_sam.tasks import inference_semsam_m2m_auto, prompt_switch

# sam
from segment_anything import sam_model_registry
from task_adapter.sam.tasks.inference_sam_m2m_auto import inference_sam_m2m_auto
from task_adapter.sam.tasks.inference_sam_m2m_interactive import inference_sam_m2m_interactive

from task_adapter.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
metadata = MetadataCatalog.get('coco_2017_train_panoptic')

from scipy.ndimage import label
import numpy as np
from gpt4v import request_gpt4v
from openai import OpenAI
from pydub import AudioSegment
from pydub.playback import play

import matplotlib.colors as mcolors
css4_colors = mcolors.CSS4_COLORS
color_proposals = [list(mcolors.hex2color(color)) for color in css4_colors.values()]

client = OpenAI()
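# NOTE: OpenAI() reads credentials from the OPENAI_API_KEY environment
# variable; the GPT-4V requests themselves go through request_gpt4v (gpt4v.py).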
'''
build args
'''
semsam_cfg = "configs/semantic_sam_only_sa-1b_swinL.yaml"
seem_cfg = "configs/seem_focall_unicl_lang_v1.yaml"

semsam_ckpt = "./swinl_only_sam_many2many.pth"
sam_ckpt = "./sam_vit_h_4b8939.pth"
seem_ckpt = "./seem_focall_v1.pt"
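# NOTE: the three checkpoints above are expected to exist locally; download
# them beforehand (see the repository's setup instructions).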
opt_semsam = load_opt_from_config_file(semsam_cfg)
opt_seem = load_opt_from_config_file(seem_cfg)
opt_seem = init_distributed_seem(opt_seem)
'''
build model
'''
model_semsam = BaseModel(opt_semsam, build_model(opt_semsam)).from_pretrained(semsam_ckpt).eval().cuda()
model_sam = sam_model_registry["vit_h"](checkpoint=sam_ckpt).eval().cuda()
model_seem = BaseModel_Seem(opt_seem, build_model_seem(opt_seem)).from_pretrained(seem_ckpt).eval().cuda()

with torch.no_grad():
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        model_seem.model.sem_seg_head.predictor.lang_encoder.get_text_embeddings(COCO_PANOPTIC_CLASSES + ["background"], is_eval=True)
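# Pre-compute SEEM's text embeddings for the COCO panoptic vocabulary once at
# startup, so panoptic inference doesn't re-encode the class names per request.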
history_images = []
history_masks = []
history_texts = []
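# These module-level lists act as shared state between the Gradio callbacks:
# 'Run' (inference) stores the marked image and its masks, the chat callback
# sends that image to GPT-4V, and 'Highlight' maps mark ids cited in the reply
# back to masks. Note the state is process-wide, so concurrent users share it.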
def inference(image, slider, mode, alpha, label_mode, anno_mode, *args, **kwargs):
    global history_images; history_images = []
    global history_masks; history_masks = []
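    # gr.ImageMask returns a dict: 'background' is the uploaded image and
    # 'layers' holds the user's brush strokes (region prompts in Interactive mode).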
    _image = image['background'].convert('RGB')
    _mask = image['layers'][0].convert('L') if image['layers'] else None
    if slider < 1.5:
        model_name = 'seem'
    elif slider > 2.5:
        model_name = 'sam'
    else:
        if mode == 'Automatic':
            model_name = 'semantic-sam'
            if slider < 1.5 + 0.14:
                level = [1]
            elif slider < 1.5 + 0.28:
                level = [2]
            elif slider < 1.5 + 0.42:
                level = [3]
            elif slider < 1.5 + 0.56:
                level = [4]
            elif slider < 1.5 + 0.70:
                level = [5]
            elif slider < 1.5 + 0.84:
                level = [6]
            else:
                level = [6, 1, 2, 3, 4, 5]
        else:
            model_name = 'sam'
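    # Net effect of the granularity slider: [1, 1.5) -> SEEM, (2.5, 3] -> SAM,
    # and [1.5, 2.5] -> Semantic-SAM in Automatic mode, with granularity level
    # 1-6 picked by sub-range (the top of the range mixes all levels); in
    # Interactive mode the middle range falls back to SAM.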
    if label_mode == 'Alphabet':
        label_mode = 'a'
    else:
        label_mode = '1'

    text_size, hole_scale, island_scale = 640, 100, 100
    text, text_part, text_thresh = '', '', '0.0'
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        semantic = False

        if mode == "Interactive":
            labeled_array, num_features = label(np.asarray(_mask))
            spatial_masks = torch.stack([torch.from_numpy(labeled_array == i+1) for i in range(num_features)])
        if model_name == 'semantic-sam':
            model = model_semsam
            output, mask = inference_semsam_m2m_auto(model, _image, level, text, text_part, text_thresh, text_size, hole_scale, island_scale, semantic, label_mode=label_mode, alpha=alpha, anno_mode=anno_mode, *args, **kwargs)
        elif model_name == 'sam':
            model = model_sam
            if mode == "Automatic":
                output, mask = inference_sam_m2m_auto(model, _image, text_size, label_mode, alpha, anno_mode)
            elif mode == "Interactive":
                output, mask = inference_sam_m2m_interactive(model, _image, spatial_masks, text_size, label_mode, alpha, anno_mode)
        elif model_name == 'seem':
            model = model_seem
            if mode == "Automatic":
                output, mask = inference_seem_pano(model, _image, text_size, label_mode, alpha, anno_mode)
            elif mode == "Interactive":
                output, mask = inference_seem_interactive(model, _image, spatial_masks, text_size, label_mode, alpha, anno_mode)

    # keep the marked image (as a PIL image) and its masks for later callbacks
    history_masks.append(mask)
    history_images.append(Image.fromarray(output))
    return (output, [])
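
# The chat callback forwards the user's message together with the most recent
# marked image to GPT-4V, so the model can ground its answer in the mark ids.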
def gpt4v_response(message, history):
    global history_images
    global history_texts; history_texts = []
    try:
        res = request_gpt4v(message, history_images[0])
        history_texts.append(res)
        return res
    except Exception:
        # surface nothing on failure (e.g., no image has been run yet, or an API error)
        return None
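
# GPT-4V is expected to cite marks as bracketed ids such as "[3]"; highlight()
# extracts those ids from the last reply and overlays the corresponding masks.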
def highlight(image, mode, alpha, label_mode, anno_mode, *args, **kwargs):
    res = history_texts[0]
    # find the separate numbers in sentence res
    res = res.split(' ')
    res = [r.replace('.', '').replace(',', '').replace(')', '').replace('"', '') for r in res]
    # keep only the tokens that cite a mark in '[]', e.g. "[3]"
    res = [r for r in res if '[' in r]
    res = [r.split('[')[1] for r in res]
    res = [r.split(']')[0] for r in res]
    res = [r for r in res if r.isdigit()]
    res = list(set(res))
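    # A compact, roughly equivalent sketch of the extraction above with a regex
    # (assuming the reply cites marks strictly as "[<digits>]"):
    #   import re
    #   res = list(set(re.findall(r'\[(\d+)\]', history_texts[0])))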
    sections = []
    for r in res:
        # skip ids that don't correspond to an existing mark
        # (GPT-4V can cite numbers that were never drawn)
        if not 0 < int(r) <= len(history_masks[0]):
            continue
        mask_i = history_masks[0][int(r) - 1]['segmentation']
        sections.append((mask_i, r))
    return (history_images[0], sections)
'''
launch app
'''
demo = gr.Blocks()
image = gr.ImageMask(label="Input", type="pil", sources=["upload"], interactive=True, brush=gr.Brush(colors=["#FFFFFF"]))
slider = gr.Slider(1, 3, value=1.8, label="Granularity")  # info="Choose in [1, 1.5) for seem, [1.5, 2.5) for semantic-sam (multi-level), (2.5, 3] for sam"
mode = gr.Radio(['Automatic', 'Interactive'], value='Automatic', label="Segmentation Mode")
anno_mode = gr.CheckboxGroup(choices=["Mark", "Mask", "Box"], value=['Mark'], label="Annotation Mode")
image_out = gr.AnnotatedImage(label="SoM Visual Prompt", height=512)
runBtn = gr.Button("Run")
highlightBtn = gr.Button("Highlight")
bot = gr.Chatbot(label="GPT-4V + SoM", height=256)
slider_alpha = gr.Slider(0, 1, value=0.05, label="Mask Alpha")  # info="Choose in [0, 1]"
label_mode = gr.Radio(['Number', 'Alphabet'], value='Number', label="Mark Mode")
title = "Set-of-Mark (SoM) Visual Prompting for Extraordinary Visual Grounding in GPT-4V"
description = "This is a demo for SoM Prompting to unleash extraordinary visual grounding in GPT-4V. Please upload an image, then click the 'Run' button to get the image with marks, and chat with GPT-4V below!"
with demo:
    gr.Markdown("<h1 style='text-align: center'><img src='https://som-gpt4v.github.io/website/img/som_logo.png' style='height:50px;display:inline-block'/> Set-of-Mark (SoM) Prompting Unleashes Extraordinary Visual Grounding in GPT-4V</h1>")
    # gr.Markdown("<h2 style='text-align: center; margin-bottom: 1rem'>Project: <a href='https://som-gpt4v.github.io/'>link</a> arXiv: <a href='https://arxiv.org/abs/2310.11441'>link</a> Code: <a href='https://github.com/microsoft/SoM'>link</a></h2>")
    with gr.Row():
        with gr.Column():
            image.render()
            slider.render()
            with gr.Accordion("Detailed prompt settings (e.g., mark type)", open=False):
                with gr.Row():
                    mode.render()
                    anno_mode.render()
                with gr.Row():
                    slider_alpha.render()
                    label_mode.render()
        with gr.Column():
            image_out.render()
            runBtn.render()
            highlightBtn.render()
    with gr.Row():
        gr.ChatInterface(chatbot=bot, fn=gpt4v_response)

    runBtn.click(inference, inputs=[image, slider, mode, slider_alpha, label_mode, anno_mode],
                 outputs=image_out)
    highlightBtn.click(highlight, inputs=[image, mode, slider_alpha, label_mode, anno_mode],
                       outputs=image_out)
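
# share=True additionally exposes the app through a temporary public
# *.gradio.live URL; set share=False to serve only on localhost:6092.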
demo.queue().launch(share=True, server_port=6092)