Update app.py
Browse files
app.py
CHANGED
@@ -72,45 +72,20 @@ def label_to_color_image(label):
|
|
72 |
raise ValueError("label value too large.")
|
73 |
return colormap[label]
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
for line in fp:
|
91 |
-
labels.append(line[:-1])
|
92 |
-
return labels
|
93 |
-
|
94 |
-
# 레이블 목록 가져오기
|
95 |
-
labels = get_labels()
|
96 |
-
|
97 |
-
# 버튼 생성
|
98 |
-
buttons = []
|
99 |
-
for label in labels:
|
100 |
-
button = gr.Button(label)
|
101 |
-
def display_seg(label):
|
102 |
-
seg = np.where(seg == labels.index(label), 1, 0)
|
103 |
-
pred_img = np.array(input_img) * 0.5 + seg * 0.5
|
104 |
-
pred_img = pred_img.astype(np.uint8)
|
105 |
-
fig = draw_plot(pred_img, seg)
|
106 |
-
return fig
|
107 |
-
|
108 |
-
button.click(display_seg, label)
|
109 |
-
buttons.append(button)
|
110 |
-
|
111 |
-
demo.components.append(buttons)
|
112 |
-
demo.launch()
|
113 |
-
|
114 |
-
|
115 |
|
116 |
|
|
|
72 |
raise ValueError("label value too large.")
|
73 |
return colormap[label]
|
74 |
|
75 |
+
# Create a Gradio interface
|
76 |
+
def image_segmentation(input_img):
|
77 |
+
# Perform image segmentation
|
78 |
+
inputs = feature_extractor(images=input_img, return_tensors="pt")
|
79 |
+
outputs = model(**inputs)
|
80 |
+
pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
|
81 |
+
pred_img = label_to_color_image(pred_label)
|
82 |
+
|
83 |
+
return pred_img
|
84 |
+
|
85 |
+
gr.Interface(
|
86 |
+
fn=image_segmentation,
|
87 |
+
inputs="image",
|
88 |
+
outputs="image"
|
89 |
+
).launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
|
91 |
|