EUNSEO56 commited on
Commit
ce400e3
·
1 Parent(s): 15f4206

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -42
app.py CHANGED
@@ -82,49 +82,35 @@ def draw_plot(pred_img, seg):
82
  plt.axis('off')
83
  LABEL_NAMES = np.asarray(labels_list)
84
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
85
- FULL_COLOR_MAP = label_to_color_image(FULL_LABEL_MAP)
86
-
87
- unique_labels = np.unique(seg.numpy().astype("uint8"))
88
- ax = plt.subplot(grid_spec[1])
89
- plt.imshow(FULL_COLOR_MAP[unique_labels].astype(np.uint8), interpolation="nearest")
90
- ax.yaxis.tick_right()
91
- plt.yticks(range(len(unique_labels)), LABEL_NAMES[unique_labels])
92
- plt.xticks([], [])
93
- ax.tick_params(width=0.0, labelsize=25)
94
- return fig
95
-
96
-
97
- def sepia(input_img):
98
- input_img = Image.fromarray(input_img)
99
-
100
- inputs = feature_extractor(images=input_img, return_tensors="tf")
101
- outputs = model(**inputs)
102
- logits = outputs.logits
103
-
104
- logits = tf.transpose(logits, [0, 2, 3, 1])
105
- logits = tf.image.resize(
106
- logits, input_img.size[::-1]
107
- ) # We reverse the shape of `image` because `image.size` returns width and height.
108
- seg = tf.math.argmax(logits, axis=-1)[0]
109
-
110
- color_seg = np.zeros(
111
- (seg.shape[0], seg.shape[1], 3), dtype=np.uint8
112
- ) # height, width, 3
113
- for label, color in enumerate(colormap):
114
- color_seg[seg.numpy() == label, :] = color
115
-
116
- # Show image + mask
117
- pred_img = np.array(input_img) * 0.5 + color_seg * 0.5
118
- pred_img = pred_img.astype(np.uint8)
119
 
120
- fig = draw_plot(pred_img, seg)
121
- return fig
122
 
123
- demo = gr.Interface(fn=sepia,
124
- inputs=gr.Image(shape=(400, 600)),
125
- outputs=['plot'],
126
- examples=["side-1.jpg", "side-2.jpg", "side-3.jpg", "side-4.jpg"],
127
- allow_flagging='never')
128
 
129
 
130
- demo.launch()
 
82
  plt.axis('off')
83
  LABEL_NAMES = np.asarray(labels_list)
84
  FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
85
+ FULL_COLOR
86
+
87
+ def get_labels():
88
+ with open(r'labels.txt', 'r') as fp:
89
+ labels = []
90
+ for line in fp:
91
+ labels.append(line[:-1])
92
+ return labels
93
+
94
+ # 레이블 목록 가져오기
95
+ labels = get_labels()
96
+
97
+ # 버튼 생성
98
+ buttons = []
99
+ for label in labels:
100
+ button = gr.Button(label)
101
+ def display_seg(label):
102
+ seg = np.where(seg == labels.index(label), 1, 0)
103
+ pred_img = np.array(input_img) * 0.5 + seg * 0.5
104
+ pred_img = pred_img.astype(np.uint8)
105
+ fig = draw_plot(pred_img, seg)
106
+ return fig
107
+
108
+ button.click(display_seg, label)
109
+ buttons.append(button)
110
+
111
+ demo.components.append(buttons)
112
+ demo.launch()
 
 
 
 
 
 
113
 
 
 
114
 
 
 
 
 
 
115
 
116