import os
import gradio as gr
from gradio_client import Client, handle_file
from pathlib import Path
from gradio.utils import get_cache_folder
import torch
import torchvision.transforms as transforms
from PIL import Image
import cv2
import numpy as np
import ast

# from zerogpu import init_zerogpu
# init_zerogpu()


class Examples(gr.helpers.Examples):
    """``gr.Examples`` variant whose cache folder can be overridden.

    Not used by ``run_demo_server`` below, which builds a plain ``gr.Examples``.
    """

    def __init__(self, *args, cached_folder=None, **kwargs):
        # `_initiated_directly=False` is a private Gradio flag; presumably it
        # defers the parent's setup so `cached_folder` can be set before
        # `create()` runs -- NOTE(review): verify against the installed Gradio.
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if cached_folder is not None:
            self.cached_folder = cached_folder
            # self.cached_file = Path(self.cached_folder) / "log.csv"
        self.create()


def postprocess(output, prompt):
    """Split the model's combined result image into one tile per task.

    The image at ``output`` is cut into ``len(prompt)`` vertical strips of
    equal width; the last strip also absorbs any remainder pixels.  Strip
    ``i`` is captioned with ``prompt[i]`` (assumes the server lays the task
    results out left-to-right in the order of ``prompt`` -- TODO confirm).

    Args:
        output: path to the result image returned by ``client.predict``.
        prompt: list of task names that were requested.

    Returns:
        A list of ``(PIL.Image, caption)`` tuples, empty when ``prompt`` is empty.
    """
    n = len(prompt)
    if n == 0:
        # Nothing was requested; avoid the `w // 0` division below.
        return []
    result = []
    # Context manager closes the underlying file; crop() copies the pixel
    # data, so the tiles stay valid after the file is closed.
    with Image.open(output) as image:
        w, h = image.size
        slice_width = w // n
        for i, caption in enumerate(prompt):
            left = i * slice_width
            right = (i + 1) * slice_width if i < n - 1 else w
            result.append((image.crop((left, 0, right, h)), caption))
    return result


def _load_rgb(path):
    """Read the image file at ``path`` and return it as an RGB numpy array.

    Raises:
        gr.Error: if ``path`` is None or OpenCV cannot decode the file, so the
            user sees a readable message instead of a cv2 traceback.
    """
    if path is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    img = cv2.imread(path)
    if img is None:  # cv2.imread signals a missing/undecodable file with None
        raise gr.Error(f"Could not read image: {path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)


def _draw_points(img, points):
    """Draw one marker per ``(point, label)`` pair onto ``img`` in place and return it."""
    for point, label in points:
        # OpenCV requires integer pixel coordinates for the marker position.
        position = tuple(int(c) for c in point)
        cv2.drawMarker(img, position, colors[label],
                       markerType=markers[label], markerSize=20, thickness=5)
    return img


# user click the image to get points, and show the points on the image
def get_point(img, sel_pix, evt: gr.SelectData):
    """Record a clicked point and return the clean image with all points drawn.

    Args:
        img: filepath of the clean image (the ``original_image`` state).
        sel_pix: list of ``(point, label)`` pairs; mutated in place.
        evt: Gradio select event; ``evt.index`` is the clicked position.

    Returns:
        ``(annotated RGB array, sel_pix)``.  At most ``MAX_POINTS`` points are
        kept; further clicks only redraw.
    """
    # Load first so a missing image does not leave a half-updated point list.
    canvas = _load_rgb(img)
    if len(sel_pix) < MAX_POINTS:
        sel_pix.append((evt.index, 1))  # default foreground_point
    _draw_points(canvas, sel_pix)
    print(sel_pix)
    return canvas, sel_pix


def set_point(img, checkbox_group, sel_pix, semantic_input):
    """Load an example: parse its point string and draw the points.

    Args:
        img: filepath of the example image.
        checkbox_group: selected tasks (unused here; filled by gr.Examples).
        sel_pix: ``str()`` form of a list of ``(point, label)`` pairs.
        semantic_input: category name (unused here; filled by gr.Examples).

    Returns:
        ``(img, annotated RGB array, parsed points)``.  Markers are drawn only
        when there are between 1 and ``MAX_POINTS`` points.
    """
    # literal_eval only accepts Python literals, so it cannot execute code.
    points = ast.literal_eval(sel_pix) if sel_pix else []
    canvas = _load_rgb(img)
    if 0 < len(points) <= MAX_POINTS:
        _draw_points(canvas, points)
    return img, canvas, points


# undo the selected point
def undo_points(orig_img, sel_pix):
    """Remove the most recent point and redraw the remaining ones.

    Args:
        orig_img: filepath of the clean image (the ``original_image`` state).
        sel_pix: list of ``(point, label)`` pairs; mutated in place.

    Returns:
        ``(annotated RGB array, sel_pix)``.
    """
    # _load_rgb already returns RGB, so no further channel swap is needed;
    # the old "top-left pixel R == B" heuristic swapped colours on such images.
    canvas = _load_rgb(orig_img)
    if sel_pix:
        sel_pix.pop()
    _draw_points(canvas, sel_pix)
    return canvas, sel_pix


HF_TOKEN = os.environ.get('HF_KEY')
# Remote Space that runs the model; `inf` forwards requests to its "/inf" API.
client = Client("Canyu/Diception", max_workers=3, hf_token=HF_TOKEN)
# Marker colour (RGB, since images are converted to RGB) and OpenCV marker-type
# code, both indexed by point label (label 1 = default foreground point).
colors = [(255, 0, 0), (0, 255, 0)]
markers = [1, 5]
# Maximum number of prompt points a user can place.
MAX_POINTS = 5


def process_image_check(path_input, prompt, sel_points, semantic):
    """Fast pre-check run before ``inf``: require an image and at least one task.

    Raises:
        gr.Error: if no image is present or no task is selected.
    """
    if path_input is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    if not prompt:
        raise gr.Error(
            "At least 1 prediction type is needed."
        )


def inf(image_path, prompt, sel_points, semantic):
    """Validate the task/input combination, query the remote model, split its output.

    Args:
        image_path: filepath of the clean input image (the ``original_image`` state).
        prompt: list of selected task names.
        sel_points: list of ``(point, label)`` pairs, or its ``str()`` form.
        semantic: target category for semantic segmentation ('' when unused).

    Returns:
        A list of ``(PIL.Image, task_name)`` tuples for the results gallery.

    Raises:
        gr.Error: when points / category are given without the matching task
            (or vice versa), or when no clean image is available.
    """
    if isinstance(sel_points, str):
        # Parse the parameter itself; the previous code referenced an
        # undefined global (`selected_points`) and raised NameError here.
        sel_points = ast.literal_eval(sel_points) if sel_points else []
    semantic = semantic or ''
    print('=========== PROCESS IMAGE CHECK ===========')
    print(f"Image Path: {image_path}")
    print(f"Prompt: {prompt}")
    print(f"Selected Points (before processing): {sel_points}")
    print(f"Semantic Input: {semantic}")
    print('===========================================')

    if 'point segmentation' in prompt and len(sel_points) == 0:
        raise gr.Error(
            "At least 1 point is needed."
        )
    if 'point segmentation' not in prompt and len(sel_points) != 0:
        raise gr.Error(
            "You must select 'point segmentation' when performing point segmentation."
        )
    if 'semantic segmentation' in prompt and semantic == '':
        raise gr.Error(
            "Target category is needed."
        )
    if 'semantic segmentation' not in prompt and semantic != '':
        raise gr.Error(
            "You must select 'semantic segmentation' when performing semantic segmentation."
        )
    # process_image_check only inspects the displayed image; the clean copy
    # can still be missing, and handle_file(None) would fail obscurely.
    if image_path is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )

    # The remote API takes the point list as its string representation.
    points_str = str(sel_points)
    result = client.predict(
        input_image=handle_file(image_path),
        checkbox_group=prompt,
        selected_points=points_str,
        semantic_input=semantic,
        api_name="/inf"
    )
    return postprocess(result, prompt)


def clear_cache():
    return None, None


def dummy():
    pass


def run_demo_server():
    """Build the DICEPTION Gradio UI, wire its events and launch the server."""
    options = ['depth', 'normal', 'entity segmentation', 'human pose',
               'point segmentation', 'semantic segmentation']
    gradio_theme = gr.themes.Default()
    with gr.Blocks(
        theme=gradio_theme,
        title="Diception",
        css="""
            #download { height: 118px; }
            .slider .inner { width: 5px; background: #FFF; }
            .viewport { aspect-ratio: 4/3; }
            .tabs button.selected { font-size: 20px !important; color: crimson !important; }
            h1 { text-align: center; display: block; }
            h2 { text-align: center; display: block; }
            h3 { text-align: center; display: block; }
            .md_feedback li { margin-bottom: 0px !important; }
            .hideme { display: none; }
        """,
        head=""" """,
    ) as demo:
        selected_points = gr.State([])  # store points
        original_image = gr.State(value=None)  # store original image without points, default None
        gr.HTML(
            """

DICEPTION: A Generalist Diffusion Model for Vision Perception

One single model solves multiple perception tasks, producing impressive results!

Due to the GPU quota limit, if an error occurs, please wait for 5 minutes before retrying.

badge-github-stars

"""
        )
        # Hidden textbox that only carries the point string of an example
        # into `set_point` (styled invisible by the `.hideme` CSS class).
        selected_points_tmp = gr.Textbox(label="Points", elem_classes="hideme")

        with gr.Row():
            checkbox_group = gr.CheckboxGroup(choices=options, label="Task")
        with gr.Row():
            semantic_input = gr.Textbox(label="Category Name", placeholder="e.g. person/cat/dog/elephant...... (for semantic segmentation only, in COCO)")
        with gr.Row():
            gr.Markdown('For non-human image inputs, the pose results may have issues. Same when perform semantic segmentation with categories that are not in COCO.')
        with gr.Row():
            gr.Markdown('The results of semantic segmentation may be unstable because:')
        with gr.Row():
            gr.Markdown('- We only trained on COCO, whose quality and quantity are insufficient to meet the requirements.')
        with gr.Row():
            gr.Markdown('- Semantic segmentation is more complex than other tasks, as it requires accurately learning the relationship between semantics and objects.')
        with gr.Row():
            gr.Markdown('However, we are still able to produce some high-quality semantic segmentation results, strongly demonstrating the potential of our approach.')
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
            with gr.Column():
                with gr.Row():
                    gr.Markdown('You can click on the image to select points prompt. At most 5 point.')
                matting_image_submit_btn = gr.Button(
                    value="Run", variant="primary"
                )
                with gr.Row():
                    undo_button = gr.Button('Undo point')
                    matting_image_reset_btn = gr.Button(value="Reset")
            with gr.Column():
                matting_image_output = gr.Gallery(label="Results")

        # img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])

        # Cheap un-queued check first; `.success` runs `inf` only if it did not raise.
        matting_image_submit_btn.click(
            fn=process_image_check,
            inputs=[input_image, checkbox_group, selected_points, semantic_input],
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=inf,
            inputs=[original_image, checkbox_group, selected_points, semantic_input],
            outputs=[matting_image_output],
            concurrency_limit=1,
        )

        # Also clear `original_image`, otherwise "Undo point" after a reset
        # would redraw the previous image.
        matting_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                [],
                None,
            ),
            inputs=[],
            outputs=[
                input_image,
                matting_image_output,
                selected_points,
                original_image,
            ],
            queue=False,
        )

        # once user upload an image, the original image is stored in `original_image`
        def store_img(img):
            return img, []  # when new image is uploaded, `selected_points` should be empty

        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )

        input_image.select(
            get_point,
            [original_image, selected_points],
            [input_image, selected_points],
        )

        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image, selected_points]
        )

        gr.Examples(
            fn=set_point,
            run_on_click=True,
            examples=[
                ["assets/woman.jpg", ['point segmentation', 'depth', 'normal', 'entity segmentation', 'human pose', 'semantic segmentation'], '[([2744, 975], 1), ([3440, 1954], 1), ([2123, 2405], 1), ([838, 1678], 1), ([4688, 1922], 1)]', 'person'],
                ["assets/woman2.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'human pose'], '[([687, 1416], 1), ([1021, 707], 1), ([1138, 1138], 1), ([1182, 1583], 1), ([1188, 2172], 1)]', 'person'],
                ["assets/board.jpg", ['point segmentation', 'depth', 'entity segmentation', 'normal'], '[([1003, 2163], 1)]', ''],
                ["assets/lion.jpg", ['point segmentation', 'depth', 'semantic segmentation'], '[([1287, 671], 1)]', 'lion'],
                ["assets/apple.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([3367, 1950], 1)]', 'apple'],
                ["assets/room.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([1308, 2215], 1)]', 'chair'],
                ["assets/car.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation'], '[([1276, 1369], 1)]', 'car'],
                ["assets/person.jpg", ['point segmentation', 'depth', 'semantic segmentation', 'normal', 'entity segmentation', 'human pose'], '[([3253, 1459], 1)]', 'tie'],
                ["assets/woman3.jpg", ['point segmentation', 'depth', 'entity segmentation'], '[([420, 692], 1)]', ''],
                ["assets/cat.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation'], '[([756, 661], 1)]', 'cat'],
                ["assets/room2.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'normal'], '[([3946, 224], 1)]', 'laptop'],
                ["assets/cartoon_cat.png", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation', 'normal'], '[([1478, 3048], 1)]', 'cat'],
                ["assets/sheep.jpg", ['point segmentation', 'depth', 'entity segmentation', 'semantic segmentation'], '[([1789, 1791], 1), ([1869, 1333], 1)]', 'sheep'],
                ["assets/cartoon_girl.jpeg", ['point segmentation', 'depth', 'entity segmentation', 'normal', 'human pose', 'semantic segmentation'], '[([1208, 2089], 1), ([635, 2731], 1), ([1070, 2888], 1), ([1493, 2350], 1)]', 'person'],
            ],
            inputs=[input_image, checkbox_group, selected_points_tmp, semantic_input],
            outputs=[original_image, input_image, selected_points],
            cache_examples=False,
        )

    demo.queue(
        api_open=False,
    ).launch()


if __name__ == '__main__':
    run_demo_server()