root commited on
Commit
872b038
·
1 Parent(s): 889ce86

more than one user.

Browse files
Files changed (2) hide show
  1. app.py +16 -5
  2. segment.py +13 -9
app.py CHANGED
@@ -66,6 +66,18 @@ def load_image_ui(load_edit, input_folder="example_tmp"):
66
  # return None, None, None, None, None, None
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
71
  backimg_solid_np = np.array(backimg)
@@ -177,8 +189,6 @@ if os.path.isdir("./example_tmp"):
177
 
178
  from segment import run_segmentation
179
 
180
-
181
-
182
  with gr.Blocks() as demo:
183
  image = gr.State() # store mask
184
  image_loaded = gr.State()
@@ -218,9 +228,7 @@ with gr.Blocks() as demo:
218
  outputs= [canvas, label]
219
  )
220
 
221
- segment_button.click(run_segmentation,
222
- [canvas] ,
223
- [text_button, result_info0] )
224
 
225
 
226
  canvas.upload(image_change, inputs=[], outputs=[text_button])
@@ -361,6 +369,9 @@ with gr.Blocks() as demo:
361
  text_button.click(load_image_ui, [false] ,
362
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider, slider2] )
363
 
 
 
 
364
 
365
 
366
 
 
66
  # return None, None, None, None, None, None
67
 
68
 
69
+ def run_segmentation_wrapper(image):
70
+ mask_np_list,mask_label_list = run_segmentation(image)
71
+ for img_path in Path("example_tmp").iterdir():
72
+ if img_path.name in ["img_512.png"]:
73
+ image = Image.open(img_path)
74
+ image = image.convert('RGB')
75
+ segmentation = create_segmentation(mask_np_list)
76
+ print("!!", len(mask_np_list))
77
+ max_val = len(mask_np_list)-1
78
+ sliderup = gr.Slider(value = 0, minimum=0, maximum=max_val, step=1, interactive=True)
79
+ return image, segmentation, mask_np_list, mask_label_list, image, sliderup, sliderup , 'Segmentatin finish.'
80
+
81
 
82
  def transparent_paste_with_mask(backimg, foreimg, mask_np,transparency = 128):
83
  backimg_solid_np = np.array(backimg)
 
189
 
190
  from segment import run_segmentation
191
 
 
 
192
  with gr.Blocks() as demo:
193
  image = gr.State() # store mask
194
  image_loaded = gr.State()
 
228
  outputs= [canvas, label]
229
  )
230
 
231
+
 
 
232
 
233
 
234
  canvas.upload(image_change, inputs=[], outputs=[text_button])
 
369
  text_button.click(load_image_ui, [false] ,
370
  [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider, slider2] )
371
 
372
+ segment_button.click(run_segmentation_wrapper,
373
+ [canvas] ,
374
+ [image_loaded, segmentation, mask_np_list, mask_label_list, canvas, slider, slider2, result_info0] )
375
 
376
 
377
 
segment.py CHANGED
@@ -45,14 +45,17 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
45
  instances_counter = defaultdict(int)
46
  handles = []
47
  label_list = []
 
 
 
48
  if not noseg:
49
  if torch.min(segmentation) == 0:
50
  mask = segmentation==0
51
- mask = mask.cpu().detach().numpy() # [512,512] bool
52
  segment_label = "rest"
53
- np.save( os.path.join(save_folder, "mask{}_{}.npy".format(0,"rest")) , mask)
54
  color = viridis(0)
55
  label = f"{segment_label}-{0}"
 
56
  handles.append(mpatches.Patch(color=color, label=label))
57
  label_list.append(label)
58
 
@@ -61,20 +64,20 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
61
  mask = segmentation==segment_id
62
  if torch.min(segmentation) != 0:
63
  segment_id -= 1
64
- mask = mask.cpu().detach().numpy() # [512,512] bool
65
-
66
  segment_label = model.config.id2label[segment['label_id']]
67
  instances_counter[segment['label_id']] += 1
68
- np.save( os.path.join(save_folder, "mask{}_{}.npy".format(segment_id,segment_label)) , mask)
69
  color = viridis(segment_id)
70
 
71
  label = f"{segment_label}-{segment_id}"
72
  handles.append(mpatches.Patch(color=color, label=label))
73
  label_list.append(label)
74
  else:
75
- mask = np.full(segmentation.shape, True)
76
  segment_label = "all"
77
- np.save( os.path.join(save_folder, "mask{}_{}.npy".format(0,"all")) , mask)
78
  color = viridis(0)
79
  label = f"{segment_label}-{0}"
80
  handles.append(mpatches.Patch(color=color, label=label))
@@ -86,6 +89,7 @@ def draw_panoptic_segmentation(segmentation, segments_info,save_folder=None, nos
86
  ax.legend(handles=handles)
87
  plt.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500 )
88
  print("; ".join(label_list))
 
89
 
90
 
91
 
@@ -114,7 +118,7 @@ def run_segmentation(image, name="example_tmp", size = 512, noseg=False):
114
  panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
115
  save_folder = os.path.join(base_folder_path, name)
116
  os.makedirs(save_folder, exist_ok=True)
117
- draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
118
  print("Finish segment")
119
  #block_flag += 1
120
- return gr.Button("1.2 Load original masks",visible = True), "Segmentation finished."#, gr.Button.update("1.2 Load edited masks",visible = True), gr.Checkbox.update(label = "Show Segmentation",visible = True)
 
45
  instances_counter = defaultdict(int)
46
  handles = []
47
  label_list = []
48
+
49
+ mask_list = []
50
+
51
  if not noseg:
52
  if torch.min(segmentation) == 0:
53
  mask = segmentation==0
54
+ mask = mask.cpu().detach() # [512,512] bool
55
  segment_label = "rest"
 
56
  color = viridis(0)
57
  label = f"{segment_label}-{0}"
58
+ mask_list.append(mask)
59
  handles.append(mpatches.Patch(color=color, label=label))
60
  label_list.append(label)
61
 
 
64
  mask = segmentation==segment_id
65
  if torch.min(segmentation) != 0:
66
  segment_id -= 1
67
+ mask = mask.cpu().detach() # [512,512] bool
68
+ mask_list.append(mask)
69
  segment_label = model.config.id2label[segment['label_id']]
70
  instances_counter[segment['label_id']] += 1
71
+
72
  color = viridis(segment_id)
73
 
74
  label = f"{segment_label}-{segment_id}"
75
  handles.append(mpatches.Patch(color=color, label=label))
76
  label_list.append(label)
77
  else:
78
+ mask = torch.from_numpy(np.full(segmentation.shape, True))
79
  segment_label = "all"
80
+ mask_list.append(mask)
81
  color = viridis(0)
82
  label = f"{segment_label}-{0}"
83
  handles.append(mpatches.Patch(color=color, label=label))
 
89
  ax.legend(handles=handles)
90
  plt.savefig(os.path.join(save_folder, 'seg_init.png'), dpi=500 )
91
  print("; ".join(label_list))
92
+ return mask_list,label_list
93
 
94
 
95
 
 
118
  panoptic_segmentation = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
119
  save_folder = os.path.join(base_folder_path, name)
120
  os.makedirs(save_folder, exist_ok=True)
121
+ mask_list,label_list = draw_panoptic_segmentation(**panoptic_segmentation, save_folder = save_folder, noseg = noseg, model = model)
122
  print("Finish segment")
123
  #block_flag += 1
124
+ return mask_list,label_list#, gr.Button.update("1.2 Load edited masks",visible = True), gr.Checkbox.update(label = "Show Segmentation",visible = True)