kevinssy committed on
Commit
a63184f
1 Parent(s): 5379278

Update app.py

Files changed (1)
  1. app.py +46 -55
app.py CHANGED
@@ -23,6 +23,7 @@ torch.manual_seed(0)
 
 CFG_PATH = "configs/demo/pokemon.yaml"
 
+
 def generate_distinct_colors(n):
     colors = []
     # generate a random number from 0 to 1
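Note: the body of generate_distinct_colors is cut off by the hunk boundary; only its first comment ("generate a random number from 0 to 1") is visible. A minimal sketch of what such a helper typically does, assuming evenly spaced HSV hues (an illustration, not the code from this commit):

    import colorsys

    def generate_distinct_colors_sketch(n):
        # Hypothetical helper, not this repository's implementation: spread n
        # hues evenly around the HSV wheel so each mask gets a distinct color.
        colors = []
        for i in range(n):
            hue = i / max(n, 1)
            colors.append(colorsys.hsv_to_rgb(hue, 0.8, 0.9))
        return colors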
@@ -79,7 +80,6 @@ def visualize_segmentation(image,
     # Create a figure and axis
     fig, ax = plt.subplots(1, figsize=(12, 9))
     # Display the image
-    # ax.imshow(image)
     # Generate distinct colors for each mask
     final_mask = np.zeros(
         (masks.shape[1], masks.shape[2], 3), dtype=np.float32)
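Note: the remainder of visualize_segmentation falls outside this hunk. A minimal sketch of the overlay step it builds toward, assuming masks is an (N, H, W) binary array and colors a list of RGB triples (both assumptions; the full function is not shown here):

    import numpy as np
    import matplotlib.pyplot as plt

    def overlay_masks_sketch(image, masks, colors, alpha=0.5):
        # Illustrative only: paint each binary mask with its color into one
        # (H, W, 3) overlay, then alpha-blend it over the input image.
        overlay = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.float32)
        for mask, color in zip(masks, colors):
            overlay[mask > 0] = color
        fig, ax = plt.subplots(1, figsize=(12, 9))
        ax.imshow(image)
        ax.imshow(overlay, alpha=alpha)
        ax.axis("off")
        return fig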
@@ -115,18 +115,11 @@ def get_sam_masks(cfg,
                   image_path=None,
                   img_sam=None,
                   pipeline=None):
-    # image_id = image_path.split('/')[-1].split('.')[0]
-    # sam_mask_path = os.path.join(cfg.test.sam_mask_root, f'{image_id}.npz')
-    # if os.path.exists(sam_mask_path):
-    #     sam_mask_masks = np.load(sam_mask_path, allow_pickle=True)
-    #     mask_tensor = torch.from_numpy(sam_mask_masks['mask_tensor'])
-    #     mask_list = sam_mask_path['mask_list']
-    # else:
     print("generating sam masks online")
     if img_sam is None and image_path is not None:
         raise ValueError(
             'Please provide either the image path or the image numpy array.')
-
+
     mask_tensor, mask_list = generate_masks_from_sam(
         image_path,
         save_path='./',
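Note: generate_masks_from_sam is defined elsewhere in the repository, so only its call site appears here. For orientation, an equivalent online mask proposal with the official segment_anything package looks roughly like the sketch below; the checkpoint path is a placeholder and the thresholds simply mirror the demo's slider defaults rather than values from this commit:

    import numpy as np
    from segment_anything import SamAutomaticMaskGenerator, sam_model_registry

    # Placeholder checkpoint; thresholds mirror the demo sliders, not cfg values.
    sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
    mask_generator = SamAutomaticMaskGenerator(sam,
                                               stability_score_thresh=0.95,
                                               box_nms_thresh=0.7)
    masks = mask_generator.generate(np.asarray(img_sam))  # H x W x 3 uint8 RGB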
@@ -168,6 +161,7 @@ def load_sam(cfg, device):
     )
     return pipeline
 
+
 def generate(img,
              class_names,
              clip_thresh,
@@ -192,7 +186,6 @@ def generate(img,
             seg_mode='semantic',
             device=device)
 
-
     # resize image by dividing 2 if the size is larger than 1000
     if img.size[0] > 1000:
         img = img.resize((img.size[0] // 2, img.size[1] // 2))
@@ -203,8 +196,7 @@ def generate(img,
 
     # class_names = ['the women chatting', 'the women chatting', 'table', 'fridge', 'cooking pot']
 
-    pseudo_masks, _, _ = car_model(
-        img, sentences, 1)
+    pseudo_masks, _ = car_model(img, sentences)
 
     if post_process == 'SAM':
        pipeline = load_sam(cfg, device)
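Note: the functional change in this hunk drops the third positional argument and the extra return value, leaving pseudo_masks, _ = car_model(img, sentences). How sentences is built is not part of the hunk; a plausible reading of the interface description below (comma-separated class names) is:

    # Inferred pre-processing; the actual splitting code is not shown in this diff.
    class_names = "Pikachu,Eevee"
    sentences = [name.strip() for name in class_names.split(",")]
    # pseudo_masks, _ = car_model(img, sentences)  # the updated call above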
@@ -230,7 +222,6 @@
     return demo_img
 
 
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser('car')
     parser.add_argument("--cfg-path",
@@ -238,48 +229,48 @@ if __name__ == "__main__":
                         help="path to configuration file.")
     args = parser.parse_args()
 
-    demo = gr.Interface(generate,
-                        inputs=[gr.Image(label="upload an image", type="pil"),
-                                "text",
-                                gr.Slider(label="clip thresh",
-                                          minimum=0,
-                                          maximum=1,
-                                          value=0.4,
-                                          step=0.1,
-                                          info="the threshold for clip-es adversarial heatmap clipping"),
-                                gr.Slider(label="mask thresh",
-                                          minimum=0,
-                                          maximum=1,
-                                          value=0.6,
-                                          step=0.1,
-                                          info="the binariation threshold for the mask to generate visual prompt"),
-                                gr.Slider(label="confidence thresh",
-                                          minimum=0,
-                                          maximum=1,
-                                          value=0,
-                                          step=0.1,
-                                          info="the threshold for filtering the proposed classes"),
-                                gr.Radio(["CRF", "SAM"], label="post process", value="CRF", info="choose the post process method"),
-                                gr.Slider(label="stability score thresh for SAM mask proposal \n(only when SAM is chosen for post process)",
-                                          minimum=0,
-                                          maximum=1,
-                                          value=0.95,
-                                          step=0.1),
-                                gr.Slider(label="box nms thresh for SAM mask proposal \n(only when SAM is chosen for post process)", minimum=0, maximum=1, value=0.7, step=0.1),
-                                gr.Slider(label="intersection over mask threshold for SAM mask proposal \n(only when SAM is chosen for post process)", minimum=0, maximum=1, value=0.5, step=0.1),
-                                gr.Slider(label="minimum prediction threshold for SAM mask proposal \n(only when SAM is chosen for post process)", minimum=0, maximum=1, value=0.03, step=0.01)],
-                        outputs="image",
-                        title="CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor",
-                        description="This is the official demo for CLIP as RNN. Please upload an image and type in the class names (connected by ',' e.g. cat,dog,human) you want to segment. The model will generate the segmentation masks for the input image. You can also adjust the clip thresh, mask thresh and confidence thresh to get better results.",
-                        examples=[["demo/pokemon1.jpg", "Charmander,Bulbasaur,Squirtle", 0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01],
-                                  ["demo/batman.jpg", "Batman,Joker,Cat Woman", 0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01],
-                                  ["demo/avengers1.jpg", "Thor,Captain America,Hulk,Iron Man", 0.6, 0.6, 0, "SAM", 0.89, 0.65, 0.5, 0.03],
-
-                                  ])
+    demo = gr.Interface(generate,
+                        inputs=[gr.Image(label="upload an image", type="pil"),
+                                "text",
+                                gr.Slider(label="clip thresh",
+                                          minimum=0,
+                                          maximum=1,
+                                          value=0.4,
+                                          step=0.1,
+                                          info="the threshold for clip-es adversarial heatmap clipping"),
+                                gr.Slider(label="mask thresh",
+                                          minimum=0,
+                                          maximum=1,
+                                          value=0.6,
+                                          step=0.1,
+                                          info="the binariation threshold for the mask to generate visual prompt"),
+                                gr.Slider(label="confidence thresh",
+                                          minimum=0,
+                                          maximum=1,
+                                          value=0,
+                                          step=0.1,
+                                          info="the threshold for filtering the proposed classes"),
+                                gr.Radio(["CRF", "SAM"], label="post process",
+                                         value="CRF", info="choose the post process method"),
+                                gr.Slider(label="stability score thresh for SAM mask proposal \n(only when SAM is chosen for post process)",
+                                          minimum=0,
+                                          maximum=1,
+                                          value=0.95,
+                                          step=0.1),
+                                gr.Slider(label="box nms thresh for SAM mask proposal \n(only when SAM is chosen for post process)",
+                                          minimum=0, maximum=1, value=0.7, step=0.1),
+                                gr.Slider(label="intersection over mask threshold for SAM mask proposal \n(only when SAM is chosen for post process)",
+                                          minimum=0, maximum=1, value=0.5, step=0.1),
+                                gr.Slider(label="minimum prediction threshold for SAM mask proposal \n(only when SAM is chosen for post process)", minimum=0, maximum=1, value=0.03, step=0.01)],
+                        outputs="image",
+                        title="CLIP as RNN: Segment Countless Visual Concepts without Training Endeavor",
+                        description="This is the official demo for CLIP as RNN. Please upload an image and type in the class names (connected by ',' e.g. cat,dog,human) you want to segment. The model will generate the segmentation masks for the input image. You can also adjust the clip thresh, mask thresh and confidence thresh to get better results.",
+                        examples=[["demo/pokemon.jpg", "Pikachu,Eevee", 0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01],
+                                  ["demo/Eiffel_tower.jpg", "Eiffel Tower",
+                                   0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01],
+                                  ["demo/superhero.jpeg", "Batman,Superman,Wonder Woman,Flash,Cyborg",
+                                   0.6, 0.6, 0, "SAM", 0.89, 0.65, 0.5, 0.03],
+                                  ])
     demo.launch(share=True)
 
-
     # device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-    stop = 0
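Note: each row in examples supplies the ten inputs in order, so the first row corresponds to a call roughly like the sketch below; every argument after clip_thresh is an assumption, since the full generate() signature is not visible in this diff:

    from PIL import Image

    # Hypothetical direct call mirroring the first example row.
    img = Image.open("demo/pokemon.jpg")
    result = generate(img, "Pikachu,Eevee", 0.6, 0.6, 0, "SAM", 0.95, 0.7, 0.6, 0.01)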
 