import os
import gradio as gr
from gradio_client import Client, handle_file
from pathlib import Path
from gradio.utils import get_cache_folder
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
import ast
# from zerogpu import init_zerogpu
# init_zerogpu()
class Examples(gr.helpers.Examples):
    """``gr.helpers.Examples`` subclass whose cache folder can be overridden.

    The parent constructor is called with ``_initiated_directly=False``. The
    optional ``cached_folder`` is then assigned, and finally ``create()`` is
    invoked explicitly, so the override is already in place when
    ``create()`` runs.
    """

    def __init__(self, *args, cached_folder=None, **kwargs):
        super().__init__(*args, _initiated_directly=False, **kwargs)
        # Must happen after the parent __init__ and before create().
        if cached_folder is not None:
            self.cached_folder = cached_folder
        self.create()
def postprocess(output, prompt):
    """Split the server's side-by-side result image into one tile per task.

    Args:
        output: The image returned by ``client.predict`` (anything
            ``Image.open`` accepts, e.g. a local file path).
        prompt: Sequence of task names. One equal-width horizontal slice is
            cut per task, in order, and captioned with that task name.

    Returns:
        A list of ``(PIL.Image, caption)`` pairs for ``gr.Gallery``. The last
        slice absorbs any leftover columns when the width is not divisible
        by ``len(prompt)``. An empty ``prompt`` yields ``[]``.
    """
    if not prompt:
        # Nothing to split; also avoids dividing by zero below.
        return []
    result = []
    # The context manager closes the underlying file. crop() materialises
    # pixel data, so the tiles remain valid after the file is closed.
    with Image.open(output) as image:
        width, height = image.size
        count = len(prompt)
        slice_width = width // count
        for index, caption in enumerate(prompt):
            left = index * slice_width
            right = width if index == count - 1 else left + slice_width
            result.append((image.crop((left, 0, right, height)), caption))
    return result
# When the user clicks the image, record the clicked point and redraw all points on the image.
def get_point(img, sel_pix, evt: gr.SelectData):
    """Add the clicked point and return the original image with all points drawn.

    Bound to ``input_image.select``; Gradio supplies ``evt`` for the click.

    Args:
        img: File path of the unannotated image (the ``original_image``
            state), or None when nothing has been loaded.
        sel_pix: List of ``(point, label)`` pairs collected so far. A new
            point is appended in place only while fewer than 5 are stored,
            and the list is returned.
        evt: Selection event; ``evt.index`` is used as the marker position.

    Returns:
        ``(annotated_rgb_image, sel_pix)``.

    Raises:
        gr.Error: if there is no stored image or it cannot be read.
    """
    # cv2.imread raises on None and returns None for unreadable files, so
    # validate the image before touching the points list or calling cvtColor.
    frame = cv2.imread(img) if img is not None else None
    if frame is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if len(sel_pix) < 5:
        sel_pix.append((evt.index, 1))  # label 1: default foreground point
    for point, label in sel_pix:
        cv2.drawMarker(
            frame, point, colors[label],
            markerType=markers[label], markerSize=20, thickness=5,
        )
    print(sel_pix)
    return frame, sel_pix
def set_point(img, checkbox_group, sel_pix, semantic_input):
    """Load an example image and draw its stored points on it.

    Used as the ``fn`` of ``gr.Examples`` (run on click). ``checkbox_group``
    and ``semantic_input`` are accepted to match the example inputs but are
    not used here.

    Args:
        img: File path of the example image.
        sel_pix: String literal of a list of ``(point, label)`` pairs, e.g.
            ``'[([1287, 671], 1)]'``, parsed with ``ast.literal_eval``.

    Returns:
        ``(img, annotated_rgb_image, parsed_points)``: the untouched path, the
        image with markers drawn (only when 1 to 5 points are given), and the
        parsed point list.
    """
    points = ast.literal_eval(sel_pix)
    canvas = cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB)
    if 0 < len(points) <= 5:
        for xy, lbl in points:
            cv2.drawMarker(
                canvas, xy, colors[lbl],
                markerType=markers[lbl], markerSize=20, thickness=5,
            )
    return img, canvas, points
# Undo (remove) the most recently selected point.
def undo_points(orig_img, sel_pix):
    """Remove the most recently added point and redraw the remaining ones.

    Args:
        orig_img: File path of the unannotated image (the ``original_image``
            state), or None when nothing has been loaded.
        sel_pix: List of ``(point, label)`` pairs. The last entry is popped
            in place (no-op when empty) and the list is returned.

    Returns:
        ``(annotated_rgb_image, sel_pix)`` with the image in RGB order, the
        same channel order produced when points are added.

    Raises:
        gr.Error: if there is no stored image or it cannot be read.
    """
    # cv2.imread raises on None and returns None for unreadable files.
    frame = cv2.imread(orig_img) if orig_img is not None else None
    if frame is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    if sel_pix:
        sel_pix.pop()
    for point, label in sel_pix:
        cv2.drawMarker(
            frame, point, colors[label],
            markerType=markers[label], markerSize=20, thickness=5,
        )
    # Return RGB unconditionally: re-swapping channels based on a single
    # pixel would show colour-swapped images whenever that pixel has R == B.
    return frame, sel_pix
# Hugging Face token for the remote Space, read from the HF_KEY environment
# variable (None when unset).
HF_TOKEN = os.getenv('HF_KEY')
# Client for the remote Gradio app that performs the actual inference.
client = Client(
    "Canyu/Diception",
    hf_token=HF_TOKEN,
    max_workers=3,
)
# Marker styling, indexed by point label. Points are drawn on RGB images, so
# label 0 -> red tilted cross, label 1 (the default clicked point) -> green
# upward triangle. The cv2 constants equal the ints 1 and 5.
colors = [(255, 0, 0), (0, 255, 0)]
markers = [cv2.MARKER_TILTED_CROSS, cv2.MARKER_TRIANGLE_UP]
def process_image_check(path_input, prompt, sel_points, semantic):
    """Fail fast when the image or the task selection is missing.

    ``inf`` is chained after this handler with ``.success``, so inference only
    starts when this returns without raising. ``sel_points`` and ``semantic``
    are accepted for signature parity with ``inf`` but are not inspected here.

    Raises:
        gr.Error: if no image is loaded or no task is ticked.
    """
    problem = None
    if path_input is None:
        problem = "Missing image in the left pane: please upload an image first."
    elif len(prompt) == 0:
        problem = "At least 1 prediction type is needed."
    if problem is not None:
        raise gr.Error(problem)
def inf(image_path, prompt, sel_points, semantic):
    """Validate task-specific inputs, run remote inference, and split results.

    Args:
        image_path: File path of the unannotated image.
        prompt: List of selected task names.
        sel_points: List of ``(point, label)`` pairs, or its string-literal
            form (parsed with ``ast.literal_eval``).
        semantic: Category name for semantic segmentation ('' when unused).

    Returns:
        List of ``(image, caption)`` pairs for the results gallery, one per
        task in ``prompt``.

    Raises:
        gr.Error: when the image is missing, or when the points / category do
            not agree with the ticked tasks.
    """
    # sel_points normally arrives as a list from gr.State; also accept the
    # string-literal form used by the examples.
    if isinstance(sel_points, str):
        sel_points = ast.literal_eval(sel_points)
    print('=========== PROCESS IMAGE CHECK ===========')
    print(f"Image Path: {image_path}")
    print(f"Prompt: {prompt}")
    print(f"Selected Points (before processing): {sel_points}")
    print(f"Semantic Input: {semantic}")
    print('===========================================')
    if image_path is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    wants_points = 'point segmentation' in prompt
    wants_semantic = 'semantic segmentation' in prompt
    if wants_points and len(sel_points) == 0:
        raise gr.Error(
            "At least 1 point is needed."
        )
    if not wants_points and len(sel_points) != 0:
        raise gr.Error(
            "You must select 'point segmentation' when performing point segmentation."
        )
    if wants_semantic and semantic == '':
        raise gr.Error(
            "Target category is needed."
        )
    if not wants_semantic and semantic != '':
        raise gr.Error(
            "You must select 'semantic segmentation' when performing semantic segmentation."
        )
    # The remote endpoint receives the points as str(list), the same literal
    # format the examples use.
    result = client.predict(
        input_image=handle_file(image_path),
        checkbox_group=prompt,
        selected_points=str(sel_points),
        semantic_input=semantic,
        api_name="/inf",
    )
    return postprocess(result, prompt)
def clear_cache():
    """Return a pair of ``None`` values, used to empty two output components."""
    empty_image = empty_gallery = None
    return (empty_image, empty_gallery)
def dummy():
    """No-op placeholder callback; always returns None."""
    return None
def run_demo_server():
    """Build the DICEPTION Gradio UI, wire its events, and launch it.

    The queue is configured with ``api_open=False`` before ``launch()``.
    The two ``gr.State`` objects hold the per-session interaction data:
    ``original_image`` (path of the unannotated image) and
    ``selected_points`` (list of ``(point, label)`` pairs on that image).
    Both are reset whenever the image is replaced, reset, or cleared, so that
    "Undo point" never redraws a stale image.
    """
    options = ['depth', 'normal', 'entity segmentation', 'human pose', 'point segmentation', 'semantic segmentation']
    gradio_theme = gr.themes.Default()
    with gr.Blocks(
        theme=gradio_theme,
        title="Diception",
        css="""
#download {
height: 118px;
}
.slider .inner {
width: 5px;
background: #FFF;
}
.viewport {
aspect-ratio: 4/3;
}
.tabs button.selected {
font-size: 20px !important;
color: crimson !important;
}
h1 {
text-align: center;
display: block;
}
h2 {
text-align: center;
display: block;
}
h3 {
text-align: center;
display: block;
}
.md_feedback li {
margin-bottom: 0px !important;
}
.hideme {
display: none;
}
""",
        head="""
""",
    ) as demo:
        selected_points = gr.State([])  # (point, label) pairs for the current image
        original_image = gr.State(value=None)  # path of the image without drawn points
        gr.HTML(
            """
DICEPTION: A Generalist Diffusion Model for Vision Perception
One single model solves multiple perception tasks, producing impressive results!
Due to the GPU quota limit, if an error occurs, please wait for 5 minutes before retrying.
"""
        )
        # Hidden textbox that receives the examples' point lists as strings.
        selected_points_tmp = gr.Textbox(label="Points", elem_classes="hideme")
        with gr.Row():
            checkbox_group = gr.CheckboxGroup(choices=options, label="Task")
        with gr.Row():
            semantic_input = gr.Textbox(label="Category Name", placeholder="e.g. person/cat/dog/elephant...... (for semantic segmentation only, in COCO)")
        with gr.Row():
            gr.Markdown('For non-human image inputs, the pose results may have issues. Same when perform semantic segmentation with categories that are not in COCO.')
        with gr.Row():
            gr.Markdown('The results of semantic segmentation may be unstable because:')
        with gr.Row():
            gr.Markdown('- We only trained on COCO, whose quality and quantity are insufficient to meet the requirements.')
        with gr.Row():
            gr.Markdown('- Semantic segmentation is more complex than other tasks, as it requires accurately learning the relationship between semantics and objects.')
        with gr.Row():
            gr.Markdown('However, we are still able to produce some high-quality semantic segmentation results, strongly demonstrating the potential of our approach.')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                with gr.Column():
                    with gr.Row():
                        gr.Markdown('You can click on the image to select point prompts (at most 5 points).')
                    matting_image_submit_btn = gr.Button(
                        value="Run", variant="primary"
                    )
                    with gr.Row():
                        undo_button = gr.Button('Undo point')
                        matting_image_reset_btn = gr.Button(value="Reset")
            with gr.Column():
                matting_image_output = gr.Gallery(label="Results")

        # Cheap, unqueued validation first; inference only runs on success.
        matting_image_submit_btn.click(
            fn=process_image_check,
            inputs=[input_image, checkbox_group, selected_points, semantic_input],
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=inf,
            inputs=[original_image, checkbox_group, selected_points, semantic_input],
            outputs=[matting_image_output],
            concurrency_limit=1,
        )
        # Reset must also drop the stored original image; otherwise
        # "Undo point" would redraw the previous image after a reset.
        matting_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                [],
                None,
            ),
            inputs=[],
            outputs=[
                input_image,
                matting_image_output,
                selected_points,
                original_image,
            ],
            queue=False,
        )

        # Once the user uploads an image, the original is stored in `original_image`.
        def store_img(img):
            return img, []  # a new image starts with no selected points

        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )
        # Clearing the image via its clear button drops the stored original and
        # its points too, for the same reason as Reset.
        input_image.clear(
            lambda: (None, []),
            None,
            [original_image, selected_points],
            queue=False,
        )
        input_image.select(
            get_point,
            [original_image, selected_points],
            [input_image, selected_points],
        )
        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image, selected_points]
        )
        gr.Examples(
            fn=set_point,
            run_on_click=True,
            examples=[
                ["assets/woman.jpg", ['point segmentation', 'depth', 'normal', 'entity segmentation', 'human pose', 'semantic segmentation'], '[([2744, 975], 1), ([3440, 1954], 1), ([2123, 2405], 1), ([838, 1678], 1), ([4688, 1922], 1)]', 'person'],
                ["assets/woman2.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'human pose'], '[([687, 1416], 1), ([1021, 707], 1), ([1138, 1138], 1), ([1182, 1583], 1), ([1188, 2172], 1)]', 'person'],
                ["assets/board.jpg", ['point segmentation', 'depth', 'entity segmentation', 'normal'], '[([1003, 2163], 1)]', ''],
                ["assets/lion.jpg", ['point segmentation', 'depth', 'semantic segmentation'], '[([1287, 671], 1)]', 'lion'],
                ["assets/apple.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([3367, 1950], 1)]', 'apple'],
                ["assets/room.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([1308, 2215], 1)]', 'chair'],
                ["assets/car.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([1276, 1369], 1)]', 'car'],
                ["assets/person.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation', 'human pose'], '[([3253, 1459], 1)]', 'tie'],
                ["assets/woman3.jpg", ['point segmentation', 'depth', 'entity segmentation'], '[([420, 692], 1)]', ''],
                ["assets/cat.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation'], '[([756, 661], 1)]', 'cat'],
                ["assets/room2.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'normal'], '[([3946, 224], 1)]', 'laptop'],
                ["assets/cartoon_cat.png", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'normal'], '[([1478, 3048], 1)]', 'cat'],
                ["assets/sheep.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation'], '[([1789, 1791], 1), ([1869, 1333], 1)]', 'sheep'],
                ["assets/cartoon_girl.jpeg", ['point segmentation', 'depth', 'entity segmentation', 'normal', 'human pose', 'semantic segmentation'], '[([1208, 2089], 1), ([635, 2731], 1), ([1070, 2888], 1), ([1493, 2350], 1)]', 'person'],
            ],
            inputs=[input_image, checkbox_group, selected_points_tmp, semantic_input],
            outputs=[original_image, input_image, selected_points],
            cache_examples=False,
        )
    demo.queue(
        api_open=False,
    ).launch()
if __name__ == "__main__":
    # Build and launch the demo only when executed as a script.
    run_demo_server()