phiyodr committed on
Commit
1d75c5c
1 Parent(s): bbee78d

Update: transparent image

Browse files
Files changed (1) hide show
  1. app.py +104 -17
app.py CHANGED
@@ -8,13 +8,14 @@ from torch import nn
8
  from transformers import SegformerForSemanticSegmentation
9
  import sys
10
  import io
11
-
12
 
13
  ###################
14
  # Setup label names
15
  target_list = ['Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket', 'Hollowareas', 'Cavity',
16
  'Spalling', 'Graffiti', 'Weathering', 'Restformwork', 'ExposedRebars',
17
  'Bearing', 'EJoint', 'Drainage', 'PEquipment', 'JTape', 'WConccor']
 
18
  classes, nclasses = target_list, len(target_list)
19
  label2id = dict(zip(classes, range(nclasses)))
20
  id2label = dict(zip(range(nclasses), classes))
@@ -48,7 +49,9 @@ model.eval()
48
  ##################
49
 
50
  to_tensor = transforms.ToTensor()
 
51
  resize = transforms.Resize((512, 512))
 
52
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
53
  std=[0.229, 0.224, 0.225])
54
 
@@ -58,11 +61,50 @@ def process_pil(img):
58
  img = normalize(img)
59
  return img
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  ###########
62
  # Inference
63
 
64
- def inference(img, name):
 
 
 
65
  img = process_pil(img)
 
66
  mask = model(img.unsqueeze(0)) # we need a batch, hence we introduce an extra dimension at position 0 (unsqueeze)
67
  mask = mask[0]
68
 
@@ -85,21 +127,39 @@ def inference(img, name):
85
  labs = ["ALL"] + target_list
86
 
87
  fig, axes = plt.subplots(5, 4, figsize = (10,10))
88
-
 
 
 
89
  for i, ax in enumerate(axes.flat):
90
  label = labs[i]
 
 
 
91
  ax.imshow(mask_preds[i])
92
  ax.set_title(label)
93
 
94
  plt.tight_layout()
95
 
96
-
97
  # plt to PIL
98
  img_buf = io.BytesIO()
99
  fig.savefig(img_buf, format='png')
100
  im = Image.open(img_buf)
101
- return im
102
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
 
105
  title = "dacl-challenge @ WACV2024"
@@ -141,15 +201,42 @@ description = """
141
  """
142
 
143
  article = "<p style='text-align: center'><a href='https://github.com/phiyodr/dacl10k-toolkit' target='_blank'>Github Repo</a></p>"
144
- examples=[['assets/dacl10k_v2_validation_0037.jpg', 'dacl10k_v2_validation_0037.jpg'],['assets/dacl10k_v2_validation_0068.jpg','dacl10k_v2_validation_0068.jpg'], ['assets/dacl10k_v2_validation_0053.jpg', 'dacl10k_v2_validation_0053.jpg']]
145
-
146
- demo = gr.Interface(
147
- fn=inference,
148
- inputs=gr.inputs.Image(type="pil"),
149
- outputs=gr.outputs.Image(type="pil"),
150
- title=title,
151
- description=description,
152
- article=article,
153
- examples=examples)
154
-
155
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  from transformers import SegformerForSemanticSegmentation
9
  import sys
10
  import io
11
+ import pdb
12
 
13
  ###################
14
  # Setup label names
15
  target_list = ['Crack', 'ACrack', 'Wetspot', 'Efflorescence', 'Rust', 'Rockpocket', 'Hollowareas', 'Cavity',
16
  'Spalling', 'Graffiti', 'Weathering', 'Restformwork', 'ExposedRebars',
17
  'Bearing', 'EJoint', 'Drainage', 'PEquipment', 'JTape', 'WConccor']
18
+ target_list_all = ["All"] + target_list
19
  classes, nclasses = target_list, len(target_list)
20
  label2id = dict(zip(classes, range(nclasses)))
21
  id2label = dict(zip(range(nclasses), classes))
 
49
  ##################
50
 
51
  to_tensor = transforms.ToTensor()
52
+ to_array = transforms.ToPILImage()
53
  resize = transforms.Resize((512, 512))
54
+ resize_small = transforms.Resize((369,369))
55
  normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
56
  std=[0.229, 0.224, 0.225])
57
 
 
61
  img = normalize(img)
62
  return img
63
 
64
+ # the background of the image
65
def resize_pil(img):
    """Return a resized PIL version of ``img`` to serve as the overlay background.

    The input is passed through the module-level transforms ``to_tensor``,
    ``resize_small`` and ``to_array``, in that order.
    """
    return to_array(resize_small(to_tensor(img)))
70
+
71
+ # combine the foreground (mask_all) and background (original image) to create one image
72
def transparent(fg, bg, alpha_factor):
    """Overlay ``fg`` on ``bg`` with a uniform opacity and return the result.

    Args:
        fg: Foreground (the mask); anything ``np.array`` can convert into an
            image array, e.g. a PIL image or a NumPy array.
        bg: Background (the original image); same accepted types as ``fg``.
        alpha_factor: Opacity of the foreground. Values outside ``[0, 1]``
            are clamped so the alpha value always stays within 0..255.

    Returns:
        A new PIL image: ``bg`` with ``fg`` pasted at the top-left corner.
    """
    # Build the PIL images from the converted arrays. Previously the
    # np.array() results were computed and then ignored in favour of the raw
    # arguments, which made the conversion dead code.
    foreground = Image.fromarray(np.array(fg))
    background = Image.fromarray(np.array(bg))

    opacity = int(255 * min(max(alpha_factor, 0.0), 1.0))
    foreground.putalpha(opacity)

    # The foreground is also passed as the mask so that its (now uniform)
    # alpha band controls how strongly it covers the background.
    background.paste(foreground, (0, 0), foreground)

    return background
84
+
85
def show_img(all_imgs, dropdown, bg, alpha_factor):
    """Blend the mask selected in the dropdown over the background image.

    Args:
        all_imgs: Sequence indexed in the same order as ``target_list_all``;
            each entry is subscripted with ``"name"`` to get an image file
            path. Presumably the value of the hidden Gallery component;
            verify against the Gradio version in use.
        dropdown: Label of the mask to show; must be in ``target_list_all``.
        bg: Background image; anything ``np.array`` can convert into an
            image array.
        alpha_factor: Opacity of the mask, scaled by 255 for the alpha band.

    Returns:
        A new PIL image with the selected mask pasted over the background.

    Raises:
        ValueError: If ``dropdown`` is not an entry of ``target_list_all``.
    """
    idx = target_list_all.index(dropdown)
    mask_path = all_imgs[idx]["name"]

    # Decode the file inside the context manager so the handle is released;
    # convert() returns a fully loaded image that outlives the file.
    with Image.open(mask_path) as mask_file:
        foreground = mask_file.convert("RGBA")

    # Use the converted array (the conversion result used to be discarded).
    background = Image.fromarray(np.array(bg))

    # Keep the overlay aligned even when the rendered mask does not have
    # exactly the background's pixel size; this is a no-op when sizes match.
    if foreground.size != background.size:
        foreground = foreground.resize(background.size)

    foreground.putalpha(int(255 * alpha_factor))
    # The foreground doubles as the paste mask, so its alpha band sets the blend.
    background.paste(foreground, (0, 0), foreground)

    return background
98
+
99
  ###########
100
  # Inference
101
 
102
+
103
+ def inference(img, alpha_factor):
104
+ background = resize_pil(img)
105
+
106
  img = process_pil(img)
107
+
108
  mask = model(img.unsqueeze(0)) # we need a batch, hence we introduce an extra dimension at position 0 (unsqueeze)
109
  mask = mask[0]
110
 
 
127
  labs = ["ALL"] + target_list
128
 
129
  fig, axes = plt.subplots(5, 4, figsize = (10,10))
130
+
131
+ # save all mask_preds in all_masks
132
+ all_masks = []
133
+
134
  for i, ax in enumerate(axes.flat):
135
  label = labs[i]
136
+
137
+ all_masks.append(mask_preds[i])
138
+
139
  ax.imshow(mask_preds[i])
140
  ax.set_title(label)
141
 
142
  plt.tight_layout()
143
 
 
144
  # plt to PIL
145
  img_buf = io.BytesIO()
146
  fig.savefig(img_buf, format='png')
147
  im = Image.open(img_buf)
 
148
 
149
+ # Save every mask as its own image, with invisible x- and y-axes and without a white
150
+ # background.
151
+ all_images = []
152
+ for i in range(len(all_masks)):
153
+ plt.figure()
154
+ fig = plt.imshow(all_masks[i])
155
+ plt.axis('off')
156
+ fig.axes.get_xaxis().set_visible(False)
157
+ fig.axes.get_yaxis().set_visible(False)
158
+ img_buf = io.BytesIO()
159
+ plt.savefig(img_buf, bbox_inches='tight', pad_inches = 0, format='png')
160
+ all_images.append(Image.open(img_buf))
161
+
162
+ return im, all_images, background
163
 
164
 
165
  title = "dacl-challenge @ WACV2024"
 
201
  """
202
 
203
  article = "<p style='text-align: center'><a href='https://github.com/phiyodr/dacl10k-toolkit' target='_blank'>Github Repo</a></p>"
204
# Example rows offered in the UI: [path under assets/, bare file name] per image.
examples = [
    [f"assets/{file_name}", file_name]
    for file_name in (
        f"dacl10k_v2_validation_{image_id}.jpg"
        for image_id in (
            "0026", "0037", "0053", "0068", "0125", "0153",
            "0263", "0336", "0429", "0500", "0549", "0609",
        )
    )
]
218
+
219
+
220
+
221
# Gradio UI: run inference on an input image, then overlay a selected mask on it.
# NOTE(review): the nesting below is reconstructed; confirm which components
# belong inside each gr.Row() against the real app.py.
with gr.Blocks() as app:
    with gr.Row():
        input_img = gr.inputs.Image(type="pil", label="Original Image")
        gr.Examples(examples=examples, inputs=[input_img])
    with gr.Row():
        img = gr.outputs.Image(type="pil", label="All Masks")
        transparent_img = gr.outputs.Image(type="pil", label="Transparent Image")
    with gr.Row():
        slider = gr.Slider(minimum=0, maximum=1, value=0.5, label="Alpha Factor")
        dropdown = gr.Dropdown(choices=target_list_all, label="Pick image", value="All")

    # Hidden components that carry inference results over to the overlay step.
    all_masks = gr.Gallery(visible=False)
    background = gr.Image(visible=False)

    generate_mask_slider = gr.Button("Generate Masks")
    # inference(img, alpha_factor) takes two positional arguments, so the
    # slider has to be passed as well; with only the image wired in, the
    # handler is invoked with a missing argument.
    generate_mask_slider.click(
        inference,
        inputs=[input_img, slider],
        outputs=[img, all_masks, background],
    )

    submit_transparent_img = gr.Button("Generate Transparent Mask (with Alpha Factor)")
    submit_transparent_img.click(
        show_img,
        inputs=[all_masks, dropdown, background, slider],
        outputs=[transparent_img],
    )


app.launch()