sky24h's picture
fix filename
065b69d verified
import os
import cv2
import spaces
import gradio as gr
from PIL import Image
from omegaconf import OmegaConf
# set up environment
from utils.env_utils import set_random_seed, use_lower_vram
from utils.timer_utils import Timer
# Seed the RNGs via utils.env_utils.set_random_seed before any models are
# loaded further down (presumably for reproducible demo output).
set_random_seed(1024)
# Shared timer: started once at import time and passed to run_grounded_sam in process().
timer = Timer()
timer.start()
# Optional: uncomment to enable utils.env_utils.use_lower_vram().
# use_lower_vram()
# import functions
from utils.labels_utils import Labels
from utils.ram_utils import ram_inference
from utils.blip2_utils import blip2_caption
from utils.llms_utils import pre_refinement, make_prompt, init_model
from utils.grounded_sam_utils import run_grounded_sam
# Hard-coded Grounded-SAM thresholds, passed to run_grounded_sam() in process().
box_threshold = 0.18
text_threshold = 0.15
iou_threshold = 0.8
# Load Llama-3 once at import time so it is not reloaded during inference.
llm = init_model("Meta-Llama-3-8B-Instruct")
# Per-config state, (re)built by process() through load_config() whenever the
# selected config changes. The empty string guarantees a load on the first call.
current_config = ""
L = None
system_prompt = None
def load_config(config_type):
    """Build the label set and LLM system prompt for one dataset setting.

    Reads ``configs/<config_type>.yaml`` next to this file, wraps it in
    ``Labels``, and builds the system prompt from the comma-joined label names.

    Args:
        config_type: base name of the YAML config (e.g. ``"COCO-81"``).

    Returns:
        A ``(labels, system_prompt)`` tuple.
    """
    here = os.path.dirname(__file__)
    yaml_path = os.path.join(here, f"configs/{config_type}.yaml")
    labels = Labels(config=OmegaConf.load(yaml_path))
    # The online demo only runs Meta-Llama-3-8B-Instruct (loaded once at import
    # time); the released code lets you use other models in a local environment.
    prompt = make_prompt(", ".join(labels.LABELS))
    return labels, prompt
@spaces.GPU(duration=120)
def process(image_ori, config_type):
    """Segment one image with the label set of the chosen dataset setting.

    Pipeline: RAM tags + BLIP-2 caption -> LLM label refinement ->
    Grounded-SAM segmentation -> mask overlay.

    Args:
        image_ori: value of the ``gr.Image(type="numpy")`` input; Gradio
            delivers an RGB ndarray, or ``None`` when no image was given.
        config_type: config name selected in the dropdown (e.g. ``"COCO-81"``).

    Returns:
        The overlay returned by ``L.draw_mask``, converted with
        ``cv2.COLOR_BGR2RGB`` for display.

    Raises:
        gr.Error: if no image was provided.
    """
    # Only rebound names need ``global``; ``llm`` is read-only here.
    global current_config, L, system_prompt

    if image_ori is None:
        raise gr.Error("Please upload or select an image first.")

    # Rebuild labels and the LLM system prompt only when the setting changes.
    if current_config != config_type:
        L, system_prompt = load_config(config_type)
        current_config = config_type

    # Gradio's numpy input is already RGB, which is what PIL (and the RAM /
    # BLIP-2 / G-SAM models fed from image_pil) expect: do NOT channel-swap it
    # before wrapping, or the models see red and blue exchanged.
    image_pil = Image.fromarray(image_ori)
    # draw_mask keeps receiving the channel-swapped (BGR) array; its result is
    # swapped back to RGB on return.
    image_bgr = cv2.cvtColor(image_ori, cv2.COLOR_RGB2BGR)

    labels_ram = ram_inference(image_pil) + ": " + blip2_caption(image_pil)
    converted_labels, llm_output = pre_refinement([labels_ram], system_prompt, llm=llm)
    labels_llm = L.check_labels(converted_labels)[0]
    print("labels_ram: ", labels_ram)
    print("llm_output: ", llm_output)
    print("labels_llm: ", labels_llm)

    # run Grounded-SAM with the refined labels as the text prompt
    label_res, bboxes, output_labels, output_prob_maps, output_points = run_grounded_sam(
        input_image={"image": image_pil, "mask": None},
        text_prompt=labels_llm,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        iou_threshold=iou_threshold,
        LABELS=L.LABELS,
        IDS=L.IDS,
        llm=llm,
        timer=timer,
    )
    # draw the masks and hand an RGB image back to Gradio
    ours = L.draw_mask(label_res, image_bgr, print_label=True, tag="Ours")
    return cv2.cvtColor(ours, cv2.COLOR_BGR2RGB)
if __name__ == "__main__":
    # Dataset settings offered in the UI; load_config() maps each name to
    # configs/<name>.yaml.
    dropdown_options = ["COCO-81", "Cityscapes", "DRAM", "VOC2012"]
    default_option = "COCO-81"
    with gr.Blocks() as demo:
        gr.HTML(
            """
<h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
Training-Free Zero-Shot Semantic Segmentation with LLM Refinement
</h1>
<p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
<a style="text-align: center; display:inline-block"
href="https://sky24h.github.io/websites/bmvc2024_training-free-semseg-with-LLM/">
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
alt="Paper Page">
</a>
<a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/Training-Free_Zero-Shot_Semantic_Segmentation_with_LLM_Refinement?duplicate=true">
<img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
</a>
</p>
"""
        )
        gr.Interface(
            fn=process,
            inputs=[
                # Pass the height as an int (pixels); a string is interpreted
                # as a CSS value and a unitless "384" is invalid CSS.
                gr.Image(type="numpy", height=384),
                gr.Dropdown(choices=dropdown_options, label="Refinement Type", value=default_option),
            ],
            outputs="image",
            description="""<html>
<p style="text-align:center;"> This is an online demo for the paper "Training-Free Zero-Shot Semantic Segmentation with LLM Refinement" (BMVC 2024). </p>
<p style="text-align:center;"> Usage: Please select or upload an image and choose a dataset setting for semantic segmentation refinement.</p>
</html>""",
            allow_flagging='never',
            examples=[
                ["examples/Cityscapes_eg.jpg", "Cityscapes"],
                ["examples/DRAM_eg.jpg", "DRAM"],
                ["examples/COCO-81_eg.jpg", "COCO-81"],
                ["examples/VOC2012_eg.jpg", "VOC2012"],
            ],
            # Precompute process() outputs for the examples so clicking one is fast.
            cache_examples=True,
        )
    # Allow at most 10 pending requests in the queue.
    demo.queue(max_size=10).launch()