dustin-cheng commited on
Commit
072446f
·
verified ·
1 Parent(s): 33b8475

Upload image_seg.py

Browse files
Files changed (1) hide show
  1. image_seg.py +467 -0
image_seg.py ADDED
@@ -0,0 +1,467 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from io import BytesIO
4
+ from PIL import Image
5
+ import requests
6
+ import hashlib
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from skimage import io
13
+ from torchvision.transforms.functional import normalize
14
+ import base64
15
+ import time
16
+ from scipy.ndimage import maximum_filter
17
+
18
+
19
+ class REBNCONV(nn.Module):
20
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
21
+ super(REBNCONV, self).__init__()
22
+
23
+ self.conv_s1 = nn.Conv2d(in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride)
24
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
25
+ self.relu_s1 = nn.ReLU(inplace=True)
26
+
27
+ def forward(self, x):
28
+
29
+ hx = x
30
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
31
+
32
+ return xout
33
+
34
+
35
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
36
+ def _upsample_like(src, tar):
37
+
38
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
39
+
40
+ return src
41
+
42
+
43
+ ### RSU-7 ###
44
+ class RSU7(nn.Module):
45
+
46
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
47
+ super(RSU7, self).__init__()
48
+
49
+ self.in_ch = in_ch
50
+ self.mid_ch = mid_ch
51
+ self.out_ch = out_ch
52
+
53
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
54
+
55
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
56
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
57
+
58
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
59
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
60
+
61
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
62
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
63
+
64
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
65
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
66
+
67
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
68
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
69
+
70
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
71
+
72
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
73
+
74
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
75
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
76
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
77
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
78
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
79
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
80
+
81
+ def forward(self, x):
82
+ b, c, h, w = x.shape
83
+
84
+ hx = x
85
+ hxin = self.rebnconvin(hx)
86
+
87
+ hx1 = self.rebnconv1(hxin)
88
+ hx = self.pool1(hx1)
89
+
90
+ hx2 = self.rebnconv2(hx)
91
+ hx = self.pool2(hx2)
92
+
93
+ hx3 = self.rebnconv3(hx)
94
+ hx = self.pool3(hx3)
95
+
96
+ hx4 = self.rebnconv4(hx)
97
+ hx = self.pool4(hx4)
98
+
99
+ hx5 = self.rebnconv5(hx)
100
+ hx = self.pool5(hx5)
101
+
102
+ hx6 = self.rebnconv6(hx)
103
+
104
+ hx7 = self.rebnconv7(hx6)
105
+
106
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
107
+ hx6dup = _upsample_like(hx6d, hx5)
108
+
109
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
110
+ hx5dup = _upsample_like(hx5d, hx4)
111
+
112
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
113
+ hx4dup = _upsample_like(hx4d, hx3)
114
+
115
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
116
+ hx3dup = _upsample_like(hx3d, hx2)
117
+
118
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
119
+ hx2dup = _upsample_like(hx2d, hx1)
120
+
121
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
122
+
123
+ return hx1d + hxin
124
+
125
+
126
+ ### RSU-6 ###
127
+ class RSU6(nn.Module):
128
+
129
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
130
+ super(RSU6, self).__init__()
131
+
132
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
133
+
134
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
135
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
136
+
137
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
138
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
139
+
140
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
141
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
142
+
143
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
144
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
145
+
146
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
147
+
148
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
149
+
150
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
151
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
152
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
153
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
154
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
155
+
156
+ def forward(self, x):
157
+
158
+ hx = x
159
+
160
+ hxin = self.rebnconvin(hx)
161
+
162
+ hx1 = self.rebnconv1(hxin)
163
+ hx = self.pool1(hx1)
164
+
165
+ hx2 = self.rebnconv2(hx)
166
+ hx = self.pool2(hx2)
167
+
168
+ hx3 = self.rebnconv3(hx)
169
+ hx = self.pool3(hx3)
170
+
171
+ hx4 = self.rebnconv4(hx)
172
+ hx = self.pool4(hx4)
173
+
174
+ hx5 = self.rebnconv5(hx)
175
+
176
+ hx6 = self.rebnconv6(hx5)
177
+
178
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
179
+ hx5dup = _upsample_like(hx5d, hx4)
180
+
181
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
182
+ hx4dup = _upsample_like(hx4d, hx3)
183
+
184
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
185
+ hx3dup = _upsample_like(hx3d, hx2)
186
+
187
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
188
+ hx2dup = _upsample_like(hx2d, hx1)
189
+
190
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
191
+
192
+ return hx1d + hxin
193
+
194
+
195
+ ### RSU-5 ###
196
+ class RSU5(nn.Module):
197
+
198
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
199
+ super(RSU5, self).__init__()
200
+
201
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
202
+
203
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
204
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
205
+
206
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
207
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
208
+
209
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
210
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
211
+
212
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
213
+
214
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
215
+
216
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
217
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
218
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
219
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
220
+
221
+ def forward(self, x):
222
+
223
+ hx = x
224
+
225
+ hxin = self.rebnconvin(hx)
226
+
227
+ hx1 = self.rebnconv1(hxin)
228
+ hx = self.pool1(hx1)
229
+
230
+ hx2 = self.rebnconv2(hx)
231
+ hx = self.pool2(hx2)
232
+
233
+ hx3 = self.rebnconv3(hx)
234
+ hx = self.pool3(hx3)
235
+
236
+ hx4 = self.rebnconv4(hx)
237
+
238
+ hx5 = self.rebnconv5(hx4)
239
+
240
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
241
+ hx4dup = _upsample_like(hx4d, hx3)
242
+
243
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
244
+ hx3dup = _upsample_like(hx3d, hx2)
245
+
246
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
247
+ hx2dup = _upsample_like(hx2d, hx1)
248
+
249
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
250
+
251
+ return hx1d + hxin
252
+
253
+
254
+ ### RSU-4 ###
255
+ class RSU4(nn.Module):
256
+
257
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
258
+ super(RSU4, self).__init__()
259
+
260
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
261
+
262
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
263
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
264
+
265
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
266
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
267
+
268
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
269
+
270
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
271
+
272
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
273
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
274
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
275
+
276
+ def forward(self, x):
277
+
278
+ hx = x
279
+
280
+ hxin = self.rebnconvin(hx)
281
+
282
+ hx1 = self.rebnconv1(hxin)
283
+ hx = self.pool1(hx1)
284
+
285
+ hx2 = self.rebnconv2(hx)
286
+ hx = self.pool2(hx2)
287
+
288
+ hx3 = self.rebnconv3(hx)
289
+
290
+ hx4 = self.rebnconv4(hx3)
291
+
292
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
293
+ hx3dup = _upsample_like(hx3d, hx2)
294
+
295
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
296
+ hx2dup = _upsample_like(hx2d, hx1)
297
+
298
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
299
+
300
+ return hx1d + hxin
301
+
302
+
303
+ ### RSU-4F ###
304
+ class RSU4F(nn.Module):
305
+
306
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
307
+ super(RSU4F, self).__init__()
308
+
309
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
310
+
311
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
312
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
313
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
314
+
315
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
316
+
317
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
318
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
319
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
320
+
321
+ def forward(self, x):
322
+
323
+ hx = x
324
+
325
+ hxin = self.rebnconvin(hx)
326
+
327
+ hx1 = self.rebnconv1(hxin)
328
+ hx2 = self.rebnconv2(hx1)
329
+ hx3 = self.rebnconv3(hx2)
330
+
331
+ hx4 = self.rebnconv4(hx3)
332
+
333
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
334
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
335
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
336
+
337
+ return hx1d + hxin
338
+
339
+
340
+ class myrebnconv(nn.Module):
341
+ def __init__(self, in_ch=3, out_ch=1, kernel_size=3, stride=1, padding=1, dilation=1, groups=1):
342
+ super(myrebnconv, self).__init__()
343
+
344
+ self.conv = nn.Conv2d(
345
+ in_ch, out_ch, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups
346
+ )
347
+ self.bn = nn.BatchNorm2d(out_ch)
348
+ self.rl = nn.ReLU(inplace=True)
349
+
350
+ def forward(self, x):
351
+ return self.rl(self.bn(self.conv(x)))
352
+
353
+
354
+ class BriaRMBG(nn.Module):
355
+
356
+ def __init__(self, in_ch=3, out_ch=1):
357
+ super(BriaRMBG, self).__init__()
358
+
359
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
360
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
361
+
362
+ self.stage1 = RSU7(64, 32, 64)
363
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
364
+
365
+ self.stage2 = RSU6(64, 32, 128)
366
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage3 = RSU5(128, 64, 256)
369
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
370
+
371
+ self.stage4 = RSU4(256, 128, 512)
372
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
373
+
374
+ self.stage5 = RSU4F(512, 256, 512)
375
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
376
+
377
+ self.stage6 = RSU4F(512, 256, 512)
378
+
379
+ # decoder
380
+ self.stage5d = RSU4F(1024, 256, 512)
381
+ self.stage4d = RSU4(1024, 128, 256)
382
+ self.stage3d = RSU5(512, 64, 128)
383
+ self.stage2d = RSU6(256, 32, 64)
384
+ self.stage1d = RSU7(128, 16, 64)
385
+
386
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
387
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
388
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
389
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
390
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
391
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
392
+
393
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
394
+
395
+ def forward(self, x):
396
+
397
+ hx = x
398
+
399
+ hxin = self.conv_in(hx)
400
+ # hx = self.pool_in(hxin)
401
+
402
+ # stage 1
403
+ hx1 = self.stage1(hxin)
404
+ hx = self.pool12(hx1)
405
+
406
+ # stage 2
407
+ hx2 = self.stage2(hx)
408
+ hx = self.pool23(hx2)
409
+
410
+ # stage 3
411
+ hx3 = self.stage3(hx)
412
+ hx = self.pool34(hx3)
413
+
414
+ # stage 4
415
+ hx4 = self.stage4(hx)
416
+ hx = self.pool45(hx4)
417
+
418
+ # stage 5
419
+ hx5 = self.stage5(hx)
420
+ hx = self.pool56(hx5)
421
+
422
+ # stage 6
423
+ hx6 = self.stage6(hx)
424
+ hx6up = _upsample_like(hx6, hx5)
425
+
426
+ # -------------------- decoder --------------------
427
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
428
+ hx5dup = _upsample_like(hx5d, hx4)
429
+
430
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
431
+ hx4dup = _upsample_like(hx4d, hx3)
432
+
433
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
434
+ hx3dup = _upsample_like(hx3d, hx2)
435
+
436
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
437
+ hx2dup = _upsample_like(hx2d, hx1)
438
+
439
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
440
+
441
+ # side output
442
+ d1 = self.side1(hx1d)
443
+ d1 = _upsample_like(d1, x)
444
+
445
+ d2 = self.side2(hx2d)
446
+ d2 = _upsample_like(d2, x)
447
+
448
+ d3 = self.side3(hx3d)
449
+ d3 = _upsample_like(d3, x)
450
+
451
+ d4 = self.side4(hx4d)
452
+ d4 = _upsample_like(d4, x)
453
+
454
+ d5 = self.side5(hx5d)
455
+ d5 = _upsample_like(d5, x)
456
+
457
+ d6 = self.side6(hx6)
458
+ d6 = _upsample_like(d6, x)
459
+
460
+ return [F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)], [
461
+ hx1d,
462
+ hx2d,
463
+ hx3d,
464
+ hx4d,
465
+ hx5d,
466
+ hx6,
467
+ ]