# Diception-Demo / app.py
import os
import gradio as gr
from gradio_client import Client, handle_file
from pathlib import Path
from gradio.utils import get_cache_folder
from huggingface_hub import hf_hub_download
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np


class Examples(gr.helpers.Examples):
    def __init__(self, *args, cached_folder=None, **kwargs):
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if cached_folder is not None:
            self.cached_folder = cached_folder
        # self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()
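
# A rough usage sketch (hypothetical arguments; 'cache_dir' is an assumed path):
# this subclass only adds `cached_folder`, so cached example outputs can live in
# a custom directory instead of Gradio's default cache location, e.g.
#   Examples(examples=[...], inputs=[...], outputs=[...], fn=inf,
#            cache_examples=True, cached_folder='cache_dir')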


# When the user clicks the image, record the point and draw all points on it.
def get_point(img, sel_pix, evt: gr.SelectData):
    if len(sel_pix) < 5:
        sel_pix.append((evt.index, 1))  # default to a foreground point
    img = cv2.imread(img)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # draw all selected points
    for point, label in sel_pix:
        cv2.drawMarker(img, tuple(point), colors[label], markerType=markers[label], markerSize=20, thickness=5)
    # if img[..., 0][0, 0] == img[..., 2][0, 0]:  # BGR to RGB
    #     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    print(sel_pix)
    return img, sel_pix
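
# Each entry in sel_pix is ([x, y], label): evt.index holds pixel coordinates,
# and the label indexes into `colors`/`markers` (1 = foreground point).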


# Undo the most recently selected point and redraw the remaining ones.
def undo_points(orig_img, sel_pix):
    if isinstance(orig_img, int):  # if orig_img is an int, the image was selected from the examples
        # NOTE: this branch assumes a module-level `image_examples` list of example image paths
        temp = cv2.imread(image_examples[orig_img][0])
        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    else:
        temp = cv2.imread(orig_img)
        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    # drop the last point and redraw the rest
    if len(sel_pix) != 0:
        sel_pix.pop()
        for point, label in sel_pix:
            cv2.drawMarker(temp, tuple(point), colors[label], markerType=markers[label], markerSize=20, thickness=5)
    if temp[..., 0][0, 0] == temp[..., 2][0, 0]:  # BGR to RGB
        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
    return temp, sel_pix


# The demo forwards requests to the hosted Diception Space; set HF_KEY in the environment.
HF_TOKEN = os.environ.get('HF_KEY')
client = Client("Canyu/Diception",
                max_workers=3,
                hf_token=HF_TOKEN)

colors = [(255, 0, 0), (0, 255, 0)]  # red = background (label 0), green = foreground (label 1)
markers = [1, 5]  # cv2.MARKER_TILTED_CROSS, cv2.MARKER_TRIANGLE_UP

# map each UI option to the prompt token the model expects
map_prompt = {
    'depth': '[[image2depth]]',
    'normal': '[[image2normal]]',
    'human pose': '[[image2pose]]',
    'entity segmentation': '[[image2panoptic coarse]]',
    'point segmentation': '[[image2segmentation]]',
    'semantic segmentation': '[[image2semantic]]',
}
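# Example: checking 'depth' and 'normal' in the UI maps to the prompt tokens
# '[[image2depth]]' and '[[image2normal]]'.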


def download_additional_params(model_name, filename="add_params.bin"):
    # download the file from the Hub and return its local path
    file_path = hf_hub_download(repo_id=model_name, filename=filename, use_auth_token=HF_TOKEN)
    return file_path


# load the additional-parameters file
def load_additional_params(model_name):
    # download add_params.bin
    params_path = download_additional_params(model_name)
    # load its contents with torch.load()
    additional_params = torch.load(params_path, map_location='cpu')
    # return the loaded parameters
    return additional_params
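
# Usage sketch (assumes the target repo actually ships an add_params.bin and
# that HF_TOKEN grants access; `model` is a hypothetical module to update):
#   params = load_additional_params('Canyu/Diception')
#   model.load_state_dict(params, strict=False)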


def process_image_check(path_input, prompt, sel_points, semantic):
    print('=========== PROCESS IMAGE CHECK ===========')
    print(f"Image Path: {path_input}")
    print(f"Prompt: {prompt}")
    print(f"Selected Points (before processing): {sel_points}")
    print(f"Semantic Input: {semantic}")
    print('===========================================')

    if path_input is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    if len(prompt) == 0:
        raise gr.Error(
            "At least one prediction type is needed."
        )
    if 'point segmentation' in prompt and len(sel_points) == 0:
        raise gr.Error(
            "At least one point is needed for point segmentation."
        )
    if 'point segmentation' not in prompt and len(sel_points) != 0:
        raise gr.Error(
            "Point prompts are only used for point segmentation: select 'point segmentation' or undo the points."
        )
    # an empty textbox arrives as '' rather than None, so test truthiness
    if 'semantic segmentation' in prompt and not semantic:
        raise gr.Error(
            "A target category is needed for semantic segmentation."
        )
    if 'semantic segmentation' not in prompt and semantic:
        raise gr.Error(
            "The category name is only used for semantic segmentation: select 'semantic segmentation' or clear the textbox."
        )


def process_image_4(image_path, prompt, sel_points=None, semantic=None):
    inputs = []
    for p in prompt:
        cur_p = map_prompt[p]
        # NOTE: point prompts and the semantic category are validated in the UI but
        # not yet forwarded here; coor_point/point_labels are sent empty for now.
        coor_point = []
        point_labels = []
        cur_input = {
            # 'original_size': [[w, h]],
            # 'target_size': [[768, 768]],
            'prompt': [cur_p],
            'coor_point': coor_point,
            'point_labels': point_labels,
        }
        inputs.append(cur_input)
    return inputs
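
# For example, process_image_4(path, ['depth', 'normal']) returns:
#   [{'prompt': ['[[image2depth]]'],  'coor_point': [], 'point_labels': []},
#    {'prompt': ['[[image2normal]]'], 'coor_point': [], 'point_labels': []}]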


def inf(image_path, prompt, sel_points, semantic):
    inputs = process_image_4(image_path, prompt, sel_points, semantic)
    return client.predict(
        image=handle_file(image_path),
        data=inputs,
        api_name="/inf"
    )
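
# handle_file wraps the local path so gradio_client uploads the image to the
# remote Space; the /inf endpoint is assumed to return something the output
# gr.Image can render directly.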


def clear_cache():
    return None, None


def run_demo_server():
    options = ['depth', 'normal', 'entity segmentation', 'human pose', 'point segmentation', 'semantic segmentation']
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="Diception Demo",
    ) as demo:
        selected_points = gr.State([])          # stores the clicked points
        original_image = gr.State(value=None)   # stores the original image without point markers

        with gr.Row():
            gr.Markdown("# Diception Demo")
        with gr.Row():
            gr.Markdown("### All results are generated by a single model. To facilitate input processing, point-prompted segmentation and semantic segmentation are listed separately, as they additionally require input points and a segmentation target.")
        with gr.Row():
            checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
        with gr.Row():
            semantic_input = gr.Textbox(label="Category Name (for semantic segmentation only, in COCO)", placeholder="e.g. person/cat/dog/elephant......")
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                with gr.Column():
                    with gr.Row():
                        gr.Markdown('You can click on the image to select point prompts (at most 5 points).')
                        undo_button = gr.Button('Undo point')
                    with gr.Row():
                        matting_image_submit_btn = gr.Button(
                            value="Estimate Matting", variant="primary"
                        )
                        matting_image_reset_btn = gr.Button(value="Reset")
                    with gr.Row():
                        img_clear_button = gr.Button("Clear Cache")
            with gr.Column():
                matting_image_output = gr.Image(label='Matting Output')

        img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])

        matting_image_submit_btn.click(
            fn=process_image_check,
            inputs=[input_image, checkbox_group, selected_points, semantic_input],
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=inf,
            inputs=[input_image, checkbox_group, selected_points, semantic_input],
            outputs=[matting_image_output],
            concurrency_limit=1,
        )

        matting_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
            ),
            inputs=[],
            outputs=[
                input_image,
                matting_image_output,
            ],
            queue=False,
        )

        # once the user uploads an image, keep the clean original in `original_image`
        def store_img(img):
            return img, []  # a new upload also clears `selected_points`

        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )
        input_image.select(
            get_point,
            [input_image, selected_points],
            [input_image, selected_points],
        )
        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image, selected_points]
        )

        # gr.Examples(
        #     fn=inf,
        #     examples=[
        #         ["assets/person.jpg", ['depth', 'normal', 'entity segmentation', 'pose']]
        #     ],
        #     inputs=[input_image, checkbox_group],
        #     outputs=[matting_image_output],
        #     cache_examples=True,
        # )

    demo.queue(
        api_open=False,
    ).launch()


if __name__ == '__main__':
    run_demo_server()