Spaces:
Running
Running
import os | |
import gradio as gr | |
from gradio_client import Client, handle_file | |
from pathlib import Path | |
from gradio.utils import get_cache_folder | |
import torch | |
import torchvision.transforms as transforms | |
from PIL import Image | |
import cv2 | |
import numpy as np | |
class Examples(gr.helpers.Examples):
    """Subclass of gradio's ``Examples`` that can override the cache folder."""

    def __init__(self, *args, cached_folder=None, **kwargs):
        """Build the examples and optionally override where cached outputs live.

        Args:
            *args: Forwarded unchanged to ``gr.helpers.Examples``.
            cached_folder: If not ``None``, assigned to ``self.cached_folder``
                before ``create()`` runs.
            **kwargs: Forwarded unchanged to ``gr.helpers.Examples``.
        """
        # `_initiated_directly=False` plus the explicit `create()` call below
        # presumably mirror how gradio's own `gr.Examples` factory builds an
        # Examples object (private API -- confirm against the installed gradio).
        super().__init__(*args, **kwargs, _initiated_directly=False)
        if cached_folder is not None:
            self.cached_folder = cached_folder
            # self.cached_file = Path(self.cached_folder) / "log.csv"
        # NOTE(review): indentation was lost in the source; `create()` is assumed
        # to run unconditionally, after any cache-folder override -- confirm.
        self.create()
# Click handler: record the clicked point and show all points on the image.
def get_point(img, sel_pix, evt: gr.SelectData):
    """Store a clicked point (at most five) and redraw every stored point.

    Args:
        img: File path of the image currently shown in the input component.
        sel_pix: Mutable list of ``(point, label)`` pairs; a new entry with
            label ``1`` is appended in place while fewer than five are stored.
        evt: Gradio select event; ``evt.index`` is the clicked position.

    Returns:
        ``(annotated_image, sel_pix)`` where ``annotated_image`` is an RGB
        array with a marker drawn for every stored point.
    """
    room_left = len(sel_pix) < 5  # the UI allows at most five points
    if room_left:
        sel_pix.append((evt.index, 1))  # label 1 = default foreground point

    canvas = cv2.cvtColor(cv2.imread(img), cv2.COLOR_BGR2RGB)
    for coords, lbl in sel_pix:
        cv2.drawMarker(
            canvas,
            coords,
            colors[lbl],
            markerType=markers[lbl],
            markerSize=20,
            thickness=5,
        )

    print(sel_pix)  # debug trace of the collected points
    return canvas, sel_pix
# Undo handler: drop the most recent point and redraw the rest on a clean image.
def undo_points(orig_img, sel_pix):
    """Remove the last selected point and redraw the remaining ones.

    Args:
        orig_img: File path of the originally uploaded (marker-free) image, or
            an integer index into ``image_examples`` when the image came from
            the examples gallery.
        sel_pix: Mutable list of ``(point, label)`` pairs; the last entry (if
            any) is popped in place.

    Returns:
        ``(image, sel_pix)`` where ``image`` is an RGB array with a marker for
        each remaining point.

    Raises:
        gr.Error: If there is no original image or it cannot be read.
    """
    if orig_img is None:
        raise gr.Error("No image to restore; please upload an image first.")

    if isinstance(orig_img, int):
        # Image selected from the examples gallery.
        # NOTE(review): `image_examples` is not defined in this file; this
        # branch only works if it is provided elsewhere.
        path = image_examples[orig_img][0]
    else:
        path = orig_img

    temp = cv2.imread(path)
    if temp is None:  # cv2.imread signals failure by returning None
        raise gr.Error(f"Could not read image: {path}")
    temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)

    if sel_pix:
        sel_pix.pop()
    for point, label in sel_pix:
        cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)

    # `temp` is already RGB. No second channel swap: guessing the channel order
    # from "R == B at pixel (0, 0)" re-swaps any image whose top-left pixel is
    # grey, black or white.
    return temp, sel_pix
# NOTE(review): the remote-client setup below is commented out; confirm that
# `client` / `HF_TOKEN` are defined somewhere before code that uses them runs.
# HF_TOKEN = os.environ.get('HF_KEY')
# client = Client("Canyu/Diception",
#                 max_workers=3,
#                 hf_token=HF_TOKEN)

# Marker colours indexed by point label; points are drawn on RGB images,
# so label 0 -> red and label 1 -> green.
colors = [(255, 0, 0), (0, 255, 0)]
# OpenCV marker types indexed by point label
# (1 = cv2.MARKER_TILTED_CROSS, 5 = cv2.MARKER_TRIANGLE_UP).
markers = [1, 5]
# UI task name (checkbox option) -> task prompt sent to the model.
map_prompt = {
    'depth': '[[image2depth]]',
    'normal': '[[image2normal]]',
    'human pose': '[[image2pose]]',
    'entity segmentation': '[[image2panoptic coarse]]',
    'point segmentation': '[[image2segmentation]]',
    'semantic segmentation': '[[image2semantic]]',
}
def download_additional_params(model_name, filename="add_params.bin"):
    """Download *filename* from the Hugging Face Hub repo *model_name*.

    Args:
        model_name: Repository id on the Hub.
        filename: Name of the file to fetch from the repository.

    Returns:
        Path of the downloaded file.
    """
    # No module-level HF_TOKEN is defined in this file, so read the token from
    # the HF_KEY environment variable directly. `token=` is the current
    # spelling of the deprecated `use_auth_token=` argument.
    token = os.environ.get('HF_KEY')
    # NOTE(review): `hf_hub_download` comes from `huggingface_hub`, which this
    # file does not import; add the import before calling this function.
    return hf_hub_download(repo_id=model_name, filename=filename, token=token)
# Load the additional-parameters file.
def load_additional_params(model_name):
    """Download the additional-parameters file of *model_name* and load it on CPU.

    Args:
        model_name: Repository id passed to ``download_additional_params``.

    Returns:
        The object deserialized by ``torch.load``.

    NOTE(security): ``torch.load`` unpickles arbitrary Python objects; only
    use this with trusted repositories (consider ``weights_only=True`` if the
    file holds only tensors).
    """
    return torch.load(download_additional_params(model_name), map_location='cpu')
def process_image_check(path_input, prompt, sel_points, semantic):
    """Validate the form before inference is launched.

    Args:
        path_input: Input image value; ``None`` when nothing was uploaded.
        prompt: Selected task names from the checkbox group.
        sel_points: Clicked ``(point, label)`` pairs.
        semantic: Target category text. An empty textbox can submit ``""``
            instead of ``None``, so blank/whitespace text counts as "no
            category".

    Raises:
        gr.Error: If a required input is missing or the inputs do not match
            the selected tasks.
    """
    print('=========== PROCESS IMAGE CHECK ===========')
    print(f"Image Path: {path_input}")
    print(f"Prompt: {prompt}")
    print(f"Selected Points (before processing): {sel_points}")
    print(f"Semantic Input: {semantic}")
    print('===========================================')

    # Comparing against None alone misclassifies an empty textbox ("") as a
    # provided category, which rejected every non-semantic request.
    has_category = semantic is not None and str(semantic).strip() != ""
    has_points = bool(sel_points)

    if path_input is None:
        raise gr.Error(
            "Missing image in the left pane: please upload an image first."
        )
    if not prompt:
        raise gr.Error(
            "At least 1 prediction type is needed."
        )
    if 'point segmentation' in prompt and not has_points:
        raise gr.Error(
            "At least 1 point is needed."
        )
    if 'point segmentation' not in prompt and has_points:
        raise gr.Error(
            "You must select 'point segmentation' when performing point segmentation."
        )
    if 'semantic segmentation' in prompt and not has_category:
        raise gr.Error(
            "Target category is needed."
        )
    if 'semantic segmentation' not in prompt and has_category:
        raise gr.Error(
            "You must select 'semantic segmentation' when performing semantic segmentation."
        )
def process_image_4(image_path, prompt, sel_points=None, semantic=None, prompt_map=None):
    """Build one request payload per selected task for the remote ``/inf`` call.

    Args:
        image_path: Path of the input image (not read here).
        prompt: Iterable of task names, i.e. keys of the prompt mapping.
        sel_points: Optional list of ``(point, label)`` pairs where ``point``
            is the clicked ``[x, y]`` position. Only attached to the
            'point segmentation' task.
        semantic: Optional target category name. Only used for the
            'semantic segmentation' task.
        prompt_map: Optional task-name -> task-prompt mapping; defaults to the
            module-level ``map_prompt``.

    Returns:
        A list of dicts with keys ``'prompt'``, ``'coor_point'`` and
        ``'point_labels'``, one per entry of *prompt*, in order.
    """
    mapping = map_prompt if prompt_map is None else prompt_map
    category = (semantic or '').strip()
    inputs = []
    for p in prompt:
        cur_p = mapping[p]
        coor_point = []
        point_labels = []
        if p == 'point segmentation':
            # NOTE(review): points are sent in input-image pixel coordinates;
            # confirm the server does not expect them rescaled (e.g. to 768x768).
            for point, label in sel_points or []:
                coor_point.append([int(point[0]), int(point[1])])
                point_labels.append(int(label))
        elif p == 'semantic segmentation' and category:
            # NOTE(review): the category is appended to the task prompt; confirm
            # this matches what the remote /inf endpoint expects.
            cur_p = f"{cur_p} {category}"
        inputs.append({
            # 'original_size': [[w,h]],
            # 'target_size': [[768, 768]],
            'prompt': [cur_p],
            'coor_point': coor_point,
            'point_labels': point_labels,
        })
    return inputs
def _get_client():
    """Return the ``gradio_client.Client`` for remote inference, creating it once.

    A module-level ``client`` is reused when one is already defined; otherwise
    a client for the ``Canyu/Diception`` Space is created with the token from
    the ``HF_KEY`` environment variable and stored in ``client``.
    """
    global client
    if globals().get("client") is None:
        client = Client("Canyu/Diception", max_workers=3, hf_token=os.environ.get('HF_KEY'))
    return client


def inf(image_path, prompt, sel_points, semantic):
    """Run the selected perception tasks on *image_path* through the remote Space.

    Args:
        image_path: Path of the image to process.
        prompt: Selected task names.
        sel_points: Clicked ``(point, label)`` pairs (point segmentation).
        semantic: Target category text (semantic segmentation).

    Returns:
        The result of the remote ``/inf`` endpoint (shown in the output image
        component).
    """
    inputs = process_image_4(image_path, prompt, sel_points, semantic)
    return _get_client().predict(
        image=handle_file(image_path),
        data=inputs,
        api_name="/inf",
    )
def clear_cache():
    """Return a pair of ``None`` values used to blank out two Gradio components."""
    return (None,) * 2
def run_demo_server():
    """Build and launch the Diception Gradio demo.

    The page has task checkboxes, a semantic-category textbox, an input image
    that accepts clicked point prompts, action buttons and an output image.
    Submitting first validates the form with ``process_image_check`` and, on
    success, runs ``inf``.
    """
    options = ['depth', 'normal', 'entity segmentation', 'human pose', 'point segmentation', 'semantic segmentation']
    gradio_theme = gr.themes.Default()
    with gr.Blocks(
        theme=gradio_theme,
        title="Matting",
    ) as demo:
        selected_points = gr.State([])  # store points
        original_image = gr.State(value=None)  # store original image without points, default None
        with gr.Row():
            gr.Markdown("# Diception Demo")
        with gr.Row():
            gr.Markdown("### All results are generated using the same single model. To facilitate input processing, we separate point-prompted segmentation and semantic segmentation, as they require input points and segmentation targets.")
        with gr.Row():
            checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
        with gr.Row():
            semantic_input = gr.Textbox(label="Category Name (for semantic segmentation only, in COCO)", placeholder="e.g. person/cat/dog/elephant......")
        # NOTE(review): source indentation was lost; the Row/Column nesting
        # below is a reconstruction of the layout.
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(
                    label="Input Image",
                    type="filepath",
                )
                with gr.Column():
                    with gr.Row():
                        gr.Markdown('You can click on the image to select points prompt. At most 5 point.')
                        undo_button = gr.Button('Undo point')
                with gr.Row():
                    matting_image_submit_btn = gr.Button(
                        value="Estimate Matting", variant="primary"
                    )
                    matting_image_reset_btn = gr.Button(value="Reset")
                with gr.Row():
                    img_clear_button = gr.Button("Clear Cache")
            with gr.Column():
                matting_image_output = gr.Image(label='Matting Output')

        def reset_point_state():
            """Forget the stored clean image and the clicked points."""
            return None, []

        def run_inference(image_path, orig_path, prompt, sel_points, semantic):
            """Call ``inf`` on the marker-free image.

            Once points are clicked, ``input_image`` holds the copy returned by
            ``get_point`` with markers drawn in; sending it would feed those
            markers to the model. Prefer the stored clean upload in that case.
            """
            if sel_points and isinstance(orig_path, str):
                image_path = orig_path
            return inf(image_path, prompt, sel_points, semantic)

        img_clear_button.click(
            clear_cache, outputs=[input_image, matting_image_output]
        ).then(
            reset_point_state, outputs=[original_image, selected_points], api_name=False
        )
        matting_image_submit_btn.click(
            fn=process_image_check,
            inputs=[input_image, checkbox_group, selected_points, semantic_input],
            outputs=None,
            preprocess=False,
            queue=False,
        ).success(
            fn=run_inference,
            inputs=[input_image, original_image, checkbox_group, selected_points, semantic_input],
            outputs=[matting_image_output],
            concurrency_limit=1,
            api_name="inf",  # keep the endpoint name the direct `fn=inf` produced
        )
        # Reset also drops the stored original image and points so they cannot
        # leak into the next image.
        matting_image_reset_btn.click(
            fn=lambda: (
                None,
                None,
                None,
                [],
            ),
            inputs=[],
            outputs=[
                input_image,
                matting_image_output,
                original_image,
                selected_points,
            ],
            queue=False,
        )

        # once user upload an image, the original image is stored in `original_image`
        def store_img(img):
            return img, []  # when new image is uploaded, `selected_points` should be empty

        input_image.upload(
            store_img,
            [input_image],
            [original_image, selected_points]
        )
        input_image.clear(
            reset_point_state,
            outputs=[original_image, selected_points],
            api_name=False,
        )
        input_image.select(
            get_point,
            [input_image, selected_points],
            [input_image, selected_points],
        )
        undo_button.click(
            undo_points,
            [original_image, selected_points],
            [input_image, selected_points]
        )
        # gr.Examples(
        #     fn=inf,
        #     examples=[
        #         ["assets/person.jpg", ['depth', 'normal', 'entity segmentation', 'pose']]
        #     ],
        #     inputs=[input_image, checkbox_group],
        #     outputs=[matting_image_output],
        #     cache_examples=True,
        # )
    demo.queue(
        api_open=False,
    ).launch()
# Build and launch the demo only when this file is run as a script.
if __name__ == '__main__':
    run_demo_server()