EUNSEO56 commited on
Commit
d6fb118
ยท
1 Parent(s): cebfdfe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -16
app.py CHANGED
@@ -12,7 +12,7 @@ feature_extractor = SegformerFeatureExtractor.from_pretrained(
12
  )
13
  model = TFSegformerForSemanticSegmentation.from_pretrained(
14
  "nickmuchi/segformer-b4-finetuned-segments-sidewalk",
15
- from_pt=True
16
  )
17
 
18
  def ade_palette():
@@ -71,7 +71,7 @@ def label_to_color_image(label):
71
  raise ValueError("label value too large.")
72
  return colormap[label]
73
 
74
- def draw_plot(pred_img, seg, cursor_pos):
75
  fig = plt.figure(figsize=(20, 15))
76
 
77
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
@@ -79,25 +79,27 @@ def draw_plot(pred_img, seg, cursor_pos):
79
  plt.subplot(grid_spec[0])
80
  plt.imshow(pred_img)
81
  plt.axis('off')
 
82
  LABEL_NAMES = np.asarray(labels_list)
83
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
84
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
85
 
86
  unique_labels = np.unique(seg.numpy().astype("uint8"))
87
  ax = plt.subplot(grid_spec[1])
88
- cursor_x, cursor_y = cursor_pos
89
 
90
- mask = seg.numpy() == seg.numpy()[cursor_x, cursor_y]
91
- mask_image = FULL_COLOR_MAP[mask].reshape(pred_img.shape)
 
 
 
 
 
92
 
93
- plt.imshow(mask_image.astype(np.uint8), interpolation="nearest")
94
- ax.yaxis.tick_right()
95
- plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
96
  plt.xticks([], [])
97
  ax.tick_params(width=0.0, labelsize=25)
98
  return fig
99
 
100
- def sepia(input_img, cursor_pos):
101
  input_img = Image.fromarray(input_img)
102
 
103
  inputs = feature_extractor(images=input_img, return_tensors="tf")
@@ -110,14 +112,22 @@ def sepia(input_img, cursor_pos):
110
  )
111
  seg = tf.math.argmax(logits, axis=-1)[0]
112
 
113
- fig = draw_plot(np.array(input_img), seg, cursor_pos)
114
  return fig
115
 
116
- demo = gr.Interface(fn=sepia,
117
- inputs=["image", "canvas"],
118
- outputs="plot",
119
- examples=[["side-1.jpg", [200, 300]], ["side-2.jpg", [150, 250]], ["side-3.jpg", [100, 200]], ["side-4.jpg", [250, 400]]],
120
- live=True,
121
- allow_flagging='never')
 
 
 
 
 
 
 
 
122
 
123
  demo.launch()
 
12
  )
13
  model = TFSegformerForSemanticSegmentation.from_pretrained(
14
  "nickmuchi/segformer-b4-finetuned-segments-sidewalk",
15
+
16
  )
17
 
18
  def ade_palette():
 
71
  raise ValueError("label value too large.")
72
  return colormap[label]
73
 
74
+ def draw_plot_with_label(pred_img, seg, selected_label):
75
  fig = plt.figure(figsize=(20, 15))
76
 
77
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
 
79
  plt.subplot(grid_spec[0])
80
  plt.imshow(pred_img)
81
  plt.axis('off')
82
+
83
  LABEL_NAMES = np.asarray(labels_list)
84
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
85
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
86
 
87
  unique_labels = np.unique(seg.numpy().astype("uint8"))
88
  ax = plt.subplot(grid_spec[1])
 
89
 
90
+ if selected_label in unique_labels:
91
+ mask = seg.numpy() == selected_label
92
+ mask_image = FULL_COLOR_MAP[mask].reshape(pred_img.shape)
93
+ plt.imshow(mask_image.astype(np.uint8), interpolation="nearest")
94
+ plt.yticks([0], [LABEL_NAMES[selected_label]])
95
+ else:
96
+ plt.text(0.5, 0.5, "Label not found", fontsize=20, ha='center')
97
 
 
 
 
98
  plt.xticks([], [])
99
  ax.tick_params(width=0.0, labelsize=25)
100
  return fig
101
 
102
+ def view_segmented_image(input_img, selected_label):
103
  input_img = Image.fromarray(input_img)
104
 
105
  inputs = feature_extractor(images=input_img, return_tensors="tf")
 
112
  )
113
  seg = tf.math.argmax(logits, axis=-1)[0]
114
 
115
+ fig = draw_plot_with_label(np.array(input_img), seg, selected_label)
116
  return fig
117
 
118
+ # ๊ทธ๋ž˜๋””์˜ค ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
119
+ demo = gr.Interface(
120
+ fn=view_segmented_image,
121
+ inputs=["image", "dropdown"],
122
+ outputs="plot",
123
+ examples=[
124
+ ["side-1.jpg", "sidewalk"],
125
+ ["side-2.jpg", "person"],
126
+ ["side-3.jpg", "car"],
127
+ ["side-4.jpg", "building"]
128
+ ],
129
+ live=True,
130
+ allow_flagging='never'
131
+ )
132
 
133
  demo.launch()