EUNSEO56 commited on
Commit
3bc890b
1 Parent(s): 5ea6440

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -5
app.py CHANGED
@@ -64,6 +64,7 @@ with open(r'labels.txt', 'r') as fp:
64
 
65
  colormap = np.asarray(ade_palette())
66
 
 
67
  def label_to_color_image(label):
68
  if label.ndim != 2:
69
  raise ValueError("Expect 2-D input label")
@@ -71,8 +72,8 @@ def label_to_color_image(label):
71
  if np.max(label) >= len(colormap):
72
  raise ValueError("label value too large.")
73
  return colormap[label]
74
-
75
- def draw_plot(pred_img, seg, show_seg=False):
76
  fig = plt.figure(figsize=(20, 15))
77
 
78
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
@@ -80,8 +81,53 @@ def draw_plot(pred_img, seg, show_seg=False):
80
  plt.subplot(grid_spec[0])
81
  plt.imshow(pred_img)
82
  plt.axis('off')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
- if show_seg:
85
- unique_labels = np.unique(seg.numpy().astype("uint8"))
86
- ax = plt.subplot(grid_spec[1])
87
 
 
64
 
65
  colormap = np.asarray(ade_palette())
66
 
67
+
68
  def label_to_color_image(label):
69
  if label.ndim != 2:
70
  raise ValueError("Expect 2-D input label")
 
72
  if np.max(label) >= len(colormap):
73
  raise ValueError("label value too large.")
74
  return colormap[label]
75
+
76
+ def draw_plot(pred_img, seg):
77
  fig = plt.figure(figsize=(20, 15))
78
 
79
  grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
 
81
  plt.subplot(grid_spec[0])
82
  plt.imshow(pred_img)
83
  plt.axis('off')
84
+ LABEL_NAMES = np.asarray(labels_list)
85
+ FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
86
+ FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
87
+
88
+ unique_labels = np.unique(seg.numpy().astype("uint8"))
89
+ ax = plt.subplot(grid_spec[1])
90
+ plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
91
+ ax.yaxis.tick_right()
92
+ plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
93
+ plt.xticks([], [])
94
+ ax.tick_params(width=0.0, labelsize=25)
95
+ return fig
96
+
97
+
98
+ def sepia(input_img):
99
+ input_img = Image.fromarray(input_img)
100
+
101
+ inputs = feature_extractor(images=input_img, return_tensors="tf")
102
+ outputs = model(**inputs)
103
+ logits = outputs.logits
104
+
105
+ logits = tf.transpose(logits, [0, 2, 3, 1])
106
+ logits = tf.image.resize(
107
+ logits, input_img.size[::-1]
108
+ ) # We reverse the shape of `image` because `image.size` returns width and height.
109
+ seg = tf.math.argmax(logits, axis=-1)[0]
110
+
111
+ color_seg = np.zeros(
112
+ (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
113
+ ) # height, width, 3
114
+ for label, color in enumerate(colormap):
115
+ color_seg[seg.numpy() == label, :] = color
116
+
117
+ # Show image + mask
118
+ pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
119
+ pred_img = pred_img.astype(np.uint8)
120
+
121
+ fig = draw_plot(pred_img, seg)
122
+ return fig
123
+
124
+ demo = gr.Interface(fn=sepia,
125
+ inputs=gr.Image(shape=(400, 600)),
126
+ outputs=['plot'],
127
+ examples=["side-1.jpg", "side-2.jpg", "side-3.jpg", "side-4.jpg"],
128
+ allow_flagging='never')
129
+
130
+
131
+ demo.launch()
132
 
 
 
 
133