EUNSEO56 commited on
Commit
8004aa4
·
1 Parent(s): ce400e3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -40
app.py CHANGED
@@ -72,45 +72,20 @@ def label_to_color_image(label):
72
  raise ValueError("label value too large.")
73
  return colormap[label]
74
 
75
- def draw_plot(pred_img, seg):
76
- fig = plt.figure(figsize=(20, 15))
77
-
78
- grid_spec = gridspec.GridSpec(1, 2, width_ratios=[6, 1])
79
-
80
- plt.subplot(grid_spec[0])
81
- plt.imshow(pred_img)
82
- plt.axis('off')
83
- LABEL_NAMES = np.asarray(labels_list)
84
- FULL_LABEL_MAP = np.arange(len(LABEL_NAMES)).reshape(len(LABEL_NAMES), 1)
85
- FULL_COLOR
86
-
87
- def get_labels():
88
- with open(r'labels.txt', 'r') as fp:
89
- labels = []
90
- for line in fp:
91
- labels.append(line[:-1])
92
- return labels
93
-
94
- # 레이블 목록 가져오기
95
- labels = get_labels()
96
-
97
- # 버튼 생성
98
- buttons = []
99
- for label in labels:
100
- button = gr.Button(label)
101
- def display_seg(label):
102
- seg = np.where(seg == labels.index(label), 1, 0)
103
- pred_img = np.array(input_img) * 0.5 + seg * 0.5
104
- pred_img = pred_img.astype(np.uint8)
105
- fig = draw_plot(pred_img, seg)
106
- return fig
107
-
108
- button.click(display_seg, label)
109
- buttons.append(button)
110
-
111
- demo.components.append(buttons)
112
- demo.launch()
113
-
114
-
115
 
116
 
 
72
  raise ValueError("label value too large.")
73
  return colormap[label]
74
 
75
+ # Create a Gradio interface
76
+ def image_segmentation(input_img):
77
+ # Perform image segmentation
78
+ inputs = feature_extractor(images=input_img, return_tensors="pt")
79
+ outputs = model(**inputs)
80
+ pred_label = np.argmax(outputs.logits[0].numpy(), axis=0)
81
+ pred_img = label_to_color_image(pred_label)
82
+
83
+ return pred_img
84
+
85
+ gr.Interface(
86
+ fn=image_segmentation,
87
+ inputs="image",
88
+ outputs="image"
89
+ ).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91