Canyu committed on
Commit a51380e · 1 Parent(s): 9381693
Files changed (2)
  1. app.py +117 -30
  2. requirements.txt +2 -1
app.py CHANGED
@@ -8,6 +8,9 @@ import torch
 import torchvision.transforms as transforms
 from PIL import Image
 
+import cv2
+import numpy as np
+
 class Examples(gr.helpers.Examples):
     def __init__(self, *args, cached_folder=None, **kwargs):
         super().__init__(*args, **kwargs, _initiated_directly=False)
@@ -17,17 +20,53 @@ class Examples(gr.helpers.Examples):
         self.create()
 
 
-HF_TOKEN = os.environ.get('HF_KEY')
+# the user clicks the image to select points, which are then drawn on the image
+def get_point(img, sel_pix, evt: gr.SelectData):
+    if len(sel_pix) < 5:
+        sel_pix.append((evt.index, 1))    # default foreground_point
+    img = cv2.imread(img)
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    # draw points
+    for point, label in sel_pix:
+        cv2.drawMarker(img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
+    # if img[..., 0][0, 0] == img[..., 2][0, 0]:    # BGR to RGB
+    #     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    print(sel_pix)
+    return img, sel_pix
+
+
+# undo the last selected point
+def undo_points(orig_img, sel_pix):
+    if isinstance(orig_img, int):    # if orig_img is an int, the image was selected from the examples
+        temp = cv2.imread(image_examples[orig_img][0])
+        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
+    else:
+        temp = cv2.imread(orig_img)
+        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
+    # draw the remaining points
+    if len(sel_pix) != 0:
+        sel_pix.pop()
+        for point, label in sel_pix:
+            cv2.drawMarker(temp, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
+    if temp[..., 0][0, 0] == temp[..., 2][0, 0]:    # BGR to RGB
+        temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
+    return temp, sel_pix
 
-client = Client("Canyu/Diception",
-                max_workers=3,
-                hf_token=HF_TOKEN)
+# HF_TOKEN = os.environ.get('HF_KEY')
+
+# client = Client("Canyu/Diception",
+#                 max_workers=3,
+#                 hf_token=HF_TOKEN)
+
+colors = [(255, 0, 0), (0, 255, 0)]    # marker colors indexed by point label
+markers = [1, 5]                       # cv2 marker types indexed by point label
 
 map_prompt = {
     'depth': '[[image2depth]]',
     'normal': '[[image2normal]]',
-    'pose': '[[image2pose]]',
+    'human pose': '[[image2pose]]',
     'entity segmentation': '[[image2panoptic coarse]]',
     'point segmentation': '[[image2segmentation]]',
     'semantic segmentation': '[[image2semantic]]',
@@ -49,7 +88,13 @@ def load_additional_params(model_name):
     # return the loaded parameter contents
     return additional_params
 
-def process_image_check(path_input, prompt):
+def process_image_check(path_input, prompt, sel_points, semantic):
+    print('=========== PROCESS IMAGE CHECK ===========')
+    print(f"Image Path: {path_input}")
+    print(f"Prompt: {prompt}")
+    print(f"Selected Points (before processing): {sel_points}")
+    print(f"Semantic Input: {semantic}")
+    print('===========================================')
     if path_input is None:
         raise gr.Error(
             "Missing image in the left pane: please upload an image first."
@@ -58,6 +103,23 @@ def process_image_check(path_input, prompt):
         raise gr.Error(
             "At least 1 prediction type is needed."
         )
+    if 'point segmentation' in prompt and len(sel_points) == 0:
+        raise gr.Error(
+            "At least 1 point is needed."
+        )
+    if 'point segmentation' not in prompt and len(sel_points) != 0:
+        raise gr.Error(
+            "You must select 'point segmentation' when providing input points."
+        )
+
+    if 'semantic segmentation' in prompt and semantic is None:
+        raise gr.Error(
+            "Target category is needed."
+        )
+    if 'semantic segmentation' not in prompt and semantic is not None:
+        raise gr.Error(
+            "You must select 'semantic segmentation' when providing a target category."
+        )
 
 
 
@@ -83,10 +145,8 @@ def process_image_4(image_path, prompt):
     return inputs
 
 
-def inf(image_path, prompt):
-    print(image_path)
-    print(prompt)
-    inputs = process_image_4(image_path, prompt)
+def inf(image_path, prompt, sel_points, semantic):
+    inputs = process_image_4(image_path, prompt, sel_points, semantic)
     # return None
     return client.predict(
         image=handle_file(image_path),
@@ -98,26 +158,34 @@ def clear_cache():
     return None, None
 
 def run_demo_server():
-    options = ['depth', 'normal', 'entity', 'pose']
+    options = ['depth', 'normal', 'entity segmentation', 'human pose', 'point segmentation', 'semantic segmentation']
     gradio_theme = gr.themes.Default()
     with gr.Blocks(
         theme=gradio_theme,
         title="Matting",
     ) as demo:
+        selected_points = gr.State([])           # store points
+        original_image = gr.State(value=None)    # store the original image without points, default None
         with gr.Row():
             gr.Markdown("# Diception Demo")
         with gr.Row():
             gr.Markdown("### All results are generated using the same single model. To facilitate input processing, we separate point-prompted segmentation and semantic segmentation, as they require input points and segmentation targets.")
         with gr.Row():
             checkbox_group = gr.CheckboxGroup(choices=options, label="Select options:")
-
+        with gr.Row():
+            semantic_input = gr.Textbox(label="Category Name (for semantic segmentation only, in COCO)", placeholder="e.g. person/cat/dog/elephant......")
         with gr.Row():
             with gr.Column():
-                matting_image_input = gr.Image(
+                input_image = gr.Image(
                     label="Input Image",
                     type="filepath",
                 )
 
+            with gr.Column():
+                with gr.Row():
+                    gr.Markdown('You can click on the image to select point prompts. At most 5 points.')
+                    undo_button = gr.Button('Undo point')
+
         with gr.Row():
             matting_image_submit_btn = gr.Button(
                 value="Estimate Matting", variant="primary"
@@ -142,21 +210,18 @@ def run_demo_server():
 
 
 
-        img_clear_button.click(clear_cache, outputs=[matting_image_input, matting_image_output])
+        img_clear_button.click(clear_cache, outputs=[input_image, matting_image_output])
 
         matting_image_submit_btn.click(
             fn=process_image_check,
-            inputs=[matting_image_input, checkbox_group],
+            inputs=[input_image, checkbox_group, selected_points, semantic_input],
             outputs=None,
             preprocess=False,
             queue=False,
         ).success(
             # fn=process_pipe_matting,
             fn=inf,
-            inputs=[
-                matting_image_input,
-                checkbox_group
-            ],
+            inputs=[input_image, checkbox_group, selected_points, semantic_input],
             outputs=[matting_image_output],
             concurrency_limit=1,
         )
@@ -168,23 +233,45 @@ def run_demo_server():
             ),
             inputs=[],
             outputs=[
-                matting_image_input,
+                input_image,
                 matting_image_output,
             ],
             queue=False,
         )
 
-        gr.Examples(
-            fn=inf,
-            examples=[
-                ["assets/person.jpg", ['depth', 'normal', 'entity', 'pose']]
-            ],
-            inputs=[matting_image_input, checkbox_group],
-            outputs=[matting_image_output],
-            cache_examples=True,
-            # cache_examples=False,
-            # cached_folder="cache_dir",
+
+        # once the user uploads an image, store the original image in `original_image`
+        def store_img(img):
+            return img, []    # when a new image is uploaded, `selected_points` should be emptied
+        input_image.upload(
+            store_img,
+            [input_image],
+            [original_image, selected_points]
+        )
+
+        input_image.select(
+            get_point,
+            [input_image, selected_points],
+            [input_image, selected_points],
         )
+
+        undo_button.click(
+            undo_points,
+            [original_image, selected_points],
+            [input_image, selected_points]
+        )
+
+        # gr.Examples(
+        #     fn=inf,
+        #     examples=[
+        #         ["assets/person.jpg", ['depth', 'normal', 'entity segmentation', 'pose']]
+        #     ],
+        #     inputs=[input_image, checkbox_group],
+        #     outputs=[matting_image_output],
+        #     cache_examples=True,
+        #     # cache_examples=False,
+        #     # cached_folder="cache_dir",
+        # )
 
         demo.queue(
             api_open=False,
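
Note: the point-prompt UI added above follows Gradio's standard SelectData pattern: a gr.State list holds the clicked points, and the image's select event redraws markers on each click. Below is a minimal self-contained sketch of that pattern; the component and function names (add_point, MAX_POINTS) are illustrative, not from this commit, and it assumes gradio >= 4 with opencv-python installed.

import cv2
import gradio as gr

MAX_POINTS = 5    # the demo above also caps point prompts at 5

def add_point(img_path, points, evt: gr.SelectData):
    # evt.index is the (x, y) pixel coordinate of the click
    if len(points) < MAX_POINTS:
        points.append(tuple(evt.index))
    img = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
    for pt in points:
        cv2.drawMarker(img, pt, (0, 255, 0), markerType=cv2.MARKER_TILTED_CROSS,
                       markerSize=20, thickness=5)
    # returning an ndarray makes Gradio save it to a temp file,
    # so the next select event still receives a filepath
    return img, points

with gr.Blocks() as demo:
    points = gr.State([])    # per-session state, survives across events
    image = gr.Image(type="filepath", label="Click to add up to 5 points")
    image.select(add_point, [image, points], [image, points])

if __name__ == '__main__':
    demo.launch()

Because each click re-reads the currently displayed (already-marked) image, markers are simply redrawn on top of themselves; this is also why the commit keeps a separate original_image state, so undo_points can restore a clean copy before redrawing the remaining points.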
requirements.txt CHANGED
@@ -5,4 +5,5 @@ torch
 transformers
 xformers
 sentencepiece
-torchvision
+torchvision
+opencv-python
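
Note: the new opencv-python entry backs the cv2.imread / cv2.cvtColor / cv2.drawMarker calls introduced in app.py; torchvision only appears to move because the previous last line lacked a trailing newline. A quick sanity check after pip install -r requirements.txt:

# confirm the two dependencies touched by this commit import cleanly
import cv2
import torchvision

print('opencv:', cv2.__version__)
print('torchvision:', torchvision.__version__)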