EUNSEO56 commited on
Commit
9f16c92
ยท
1 Parent(s): d6fb118

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -29
app.py CHANGED
@@ -71,7 +71,7 @@ def label_to_color_image(label):
71
  raise ValueError("label value too large.")
72
  return colormap[label]
73
 
74
- def draw_plot_with_label(pred_img, seg, selected_label):
75
  fig = plt.figure(figsize=(20, 15))
76
 
77
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
@@ -79,27 +79,21 @@ def draw_plot_with_label(pred_img, seg, selected_label):
79
  plt.subplot(grid_spec[0])
80
  plt.imshow(pred_img)
81
  plt.axis('off')
82
-
83
  LABEL_NAMES = np.asarray(labels_list)
84
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
85
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
86
 
87
  unique_labels = np.unique(seg.numpy().astype("uint8"))
88
  ax = plt.subplot(grid_spec[1])
89
-
90
- if selected_label in unique_labels:
91
- mask = seg.numpy() == selected_label
92
- mask_image = FULL_COLOR_MAP[mask].reshape(pred_img.shape)
93
- plt.imshow(mask_image.astype(np.uint8), interpolation="nearest")
94
- plt.yticks([0], [LABEL_NAMES[selected_label]])
95
- else:
96
- plt.text(0.5, 0.5, "Label not found", fontsize=20, ha='center')
97
-
98
  plt.xticks([], [])
99
  ax.tick_params(width=0.0, labelsize=25)
100
  return fig
101
 
102
- def view_segmented_image(input_img, selected_label):
 
103
  input_img = Image.fromarray(input_img)
104
 
105
  inputs = feature_extractor(images=input_img, return_tensors="tf")
@@ -109,25 +103,27 @@ def view_segmented_image(input_img, selected_label):
109
  logits = tf.transpose(logits, [0, 2, 3, 1])
110
  logits = tf.image.resize(
111
  logits, input_img.size[::-1]
112
- )
113
  seg = tf.math.argmax(logits, axis=-1)[0]
114
 
115
- fig = draw_plot_with_label(np.array(input_img), seg, selected_label)
 
 
 
 
 
 
 
 
 
 
116
  return fig
117
 
118
- # ๊ทธ๋ž˜๋””์˜ค ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
119
- demo = gr.Interface(
120
- fn=view_segmented_image,
121
- inputs=["image", "dropdown"],
122
- outputs="plot",
123
- examples=[
124
- ["side-1.jpg", "sidewalk"],
125
- ["side-2.jpg", "person"],
126
- ["side-3.jpg", "car"],
127
- ["side-4.jpg", "building"]
128
- ],
129
- live=True,
130
- allow_flagging='never'
131
- )
132
 
133
- demo.launch()
 
71
  raise ValueError("label value too large.")
72
  return colormap[label]
73
 
74
+ def draw_plot(pred_img, seg):
75
  fig = plt.figure(figsize=(20, 15))
76
 
77
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
 
79
  plt.subplot(grid_spec[0])
80
  plt.imshow(pred_img)
81
  plt.axis('off')
 
82
  LABEL_NAMES = np.asarray(labels_list)
83
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
84
  FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
85
 
86
  unique_labels = np.unique(seg.numpy().astype("uint8"))
87
  ax = plt.subplot(grid_spec[1])
88
+ plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
89
+ ax.yaxis.tick_right()
90
+ plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
 
 
 
 
 
 
91
  plt.xticks([], [])
92
  ax.tick_params(width=0.0, labelsize=25)
93
  return fig
94
 
95
+
96
+ def sepia(input_img):
97
  input_img = Image.fromarray(input_img)
98
 
99
  inputs = feature_extractor(images=input_img, return_tensors="tf")
 
103
  logits = tf.transpose(logits, [0, 2, 3, 1])
104
  logits = tf.image.resize(
105
  logits, input_img.size[::-1]
106
+ ) # We reverse the shape of `image` because `image.size` returns width and height.
107
  seg = tf.math.argmax(logits, axis=-1)[0]
108
 
109
+ color_seg = np.zeros(
110
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
111
+ ) # height, width, 3
112
+ for label, color in enumerate(colormap):
113
+ color_seg[seg.numpy() == label, :] = color
114
+
115
+ # Show image + mask
116
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
117
+ pred_img = pred_img.astype(np.uint8)
118
+
119
+ fig = draw_plot(pred_img, seg)
120
  return fig
121
 
122
+ demo = gr.Interface(fn=sepia,
123
+ inputs=gr.Image(shape=(400, 600)),
124
+ outputs=['plot'],
125
+ examples=["person-1.jpg", "person-2.jpg", "person-3.jpg", "person-4.jpg", "person-5.jpg"],
126
+ allow_flagging='never')
127
+
 
 
 
 
 
 
 
 
128
 
129
+ demo.launch()