schirrmacher commited on
Commit
08aed96
·
verified ·
1 Parent(s): d7b2280

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ example.png filter=lfs diff=lfs merge=lfs -text
37
+ no-background.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -17,7 +17,7 @@ This model is similar to [RMBG-1.4](https://huggingface.co/briaai/RMBG-1.4), but
17
  ## Inference
18
 
19
  ```
20
- test
21
  ```
22
 
23
  ## Training
 
17
  ## Inference
18
 
19
  ```
20
+ python utils/inference.py
21
  ```
22
 
23
  ## Training
example.png ADDED

Git LFS Details

  • SHA256: 42c8627c1ada7b69ef8561fcb5611cd8aa08af5eed211379a2619960524639c5
  • Pointer size: 132 Bytes
  • Size of remote file: 4.83 MB
no-background.png ADDED

Git LFS Details

  • SHA256: 3a5ab25f39d86fb209e8cf20c5b56fa74b4aafcac8bacd95dc39a5efa75ffc00
  • Pointer size: 132 Bytes
  • Size of remote file: 1.12 MB
utils/__pycache__/inference.cpython-312.pyc ADDED
Binary file (2.1 kB). View file
 
utils/__pycache__/ormbg.cpython-312.pyc ADDED
Binary file (20.7 kB). View file
 
utils/inference.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from skimage import io
5
+ from ormbg import ORMBG
6
+ import torch.nn.functional as F
7
+
8
+
9
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
10
+ if len(im.shape) < 3:
11
+ im = im[:, :, np.newaxis]
12
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
13
+ im_tensor = F.interpolate(
14
+ torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
15
+ ).type(torch.uint8)
16
+ image = torch.divide(im_tensor, 255.0)
17
+ return image
18
+
19
+
20
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
21
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
22
+ ma = torch.max(result)
23
+ mi = torch.min(result)
24
+ result = (result - mi) / (ma - mi)
25
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
26
+ im_array = np.squeeze(im_array)
27
+ return im_array
28
+
29
+
30
+ def example_inference():
31
+
32
+ image_path = "example.png"
33
+ result_name = "no-background.png"
34
+
35
+ net = ORMBG()
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+
38
+ if torch.cuda.is_available():
39
+ net.load_state_dict(torch.load("models/ormbg.pth"))
40
+ net = net.cuda()
41
+ else:
42
+ net.load_state_dict(torch.load("models/ormbg.pth", map_location="cpu"))
43
+ net.eval()
44
+
45
+ model_input_size = [1024, 1024]
46
+ orig_im = io.imread(image_path)
47
+ orig_im_size = orig_im.shape[0:2]
48
+ image = preprocess_image(orig_im, model_input_size).to(device)
49
+
50
+ result = net(image)
51
+
52
+ # post process
53
+ result_image = postprocess_image(result[0][0], orig_im_size)
54
+
55
+ # save result
56
+ pil_im = Image.fromarray(result_image)
57
+ no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
58
+ orig_image = Image.open(image_path)
59
+ no_bg_image.paste(orig_image, mask=pil_im)
60
+ no_bg_image.save(result_name)
61
+
62
+
63
+ if __name__ == "__main__":
64
+ example_inference()
utils/ormbg.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
6
+
7
+
8
+ class REBNCONV(nn.Module):
9
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
10
+ super(REBNCONV, self).__init__()
11
+
12
+ self.conv_s1 = nn.Conv2d(
13
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
14
+ )
15
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
16
+ self.relu_s1 = nn.ReLU(inplace=True)
17
+
18
+ def forward(self, x):
19
+
20
+ hx = x
21
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
22
+
23
+ return xout
24
+
25
+
26
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
27
+ def _upsample_like(src, tar):
28
+
29
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
30
+
31
+ return src
32
+
33
+
34
+ ### RSU-7 ###
35
+ class RSU7(nn.Module):
36
+
37
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
38
+ super(RSU7, self).__init__()
39
+
40
+ self.in_ch = in_ch
41
+ self.mid_ch = mid_ch
42
+ self.out_ch = out_ch
43
+
44
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
45
+
46
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
47
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
48
+
49
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
50
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
51
+
52
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
53
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
54
+
55
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
56
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
57
+
58
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
59
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
60
+
61
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
62
+
63
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
64
+
65
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
67
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
68
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
69
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
70
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
71
+
72
+ def forward(self, x):
73
+ b, c, h, w = x.shape
74
+
75
+ hx = x
76
+ hxin = self.rebnconvin(hx)
77
+
78
+ hx1 = self.rebnconv1(hxin)
79
+ hx = self.pool1(hx1)
80
+
81
+ hx2 = self.rebnconv2(hx)
82
+ hx = self.pool2(hx2)
83
+
84
+ hx3 = self.rebnconv3(hx)
85
+ hx = self.pool3(hx3)
86
+
87
+ hx4 = self.rebnconv4(hx)
88
+ hx = self.pool4(hx4)
89
+
90
+ hx5 = self.rebnconv5(hx)
91
+ hx = self.pool5(hx5)
92
+
93
+ hx6 = self.rebnconv6(hx)
94
+
95
+ hx7 = self.rebnconv7(hx6)
96
+
97
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
98
+ hx6dup = _upsample_like(hx6d, hx5)
99
+
100
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
101
+ hx5dup = _upsample_like(hx5d, hx4)
102
+
103
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
104
+ hx4dup = _upsample_like(hx4d, hx3)
105
+
106
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
107
+ hx3dup = _upsample_like(hx3d, hx2)
108
+
109
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
110
+ hx2dup = _upsample_like(hx2d, hx1)
111
+
112
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
113
+
114
+ return hx1d + hxin
115
+
116
+
117
+ ### RSU-6 ###
118
+ class RSU6(nn.Module):
119
+
120
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
121
+ super(RSU6, self).__init__()
122
+
123
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
124
+
125
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
126
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
136
+
137
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
138
+
139
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
140
+
141
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
143
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
144
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
145
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
146
+
147
+ def forward(self, x):
148
+
149
+ hx = x
150
+
151
+ hxin = self.rebnconvin(hx)
152
+
153
+ hx1 = self.rebnconv1(hxin)
154
+ hx = self.pool1(hx1)
155
+
156
+ hx2 = self.rebnconv2(hx)
157
+ hx = self.pool2(hx2)
158
+
159
+ hx3 = self.rebnconv3(hx)
160
+ hx = self.pool3(hx3)
161
+
162
+ hx4 = self.rebnconv4(hx)
163
+ hx = self.pool4(hx4)
164
+
165
+ hx5 = self.rebnconv5(hx)
166
+
167
+ hx6 = self.rebnconv6(hx5)
168
+
169
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
170
+ hx5dup = _upsample_like(hx5d, hx4)
171
+
172
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
173
+ hx4dup = _upsample_like(hx4d, hx3)
174
+
175
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
176
+ hx3dup = _upsample_like(hx3d, hx2)
177
+
178
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
179
+ hx2dup = _upsample_like(hx2d, hx1)
180
+
181
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
182
+
183
+ return hx1d + hxin
184
+
185
+
186
+ ### RSU-5 ###
187
+ class RSU5(nn.Module):
188
+
189
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
190
+ super(RSU5, self).__init__()
191
+
192
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
193
+
194
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
195
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
196
+
197
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
198
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
199
+
200
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
201
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
202
+
203
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
204
+
205
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
206
+
207
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
208
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
209
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
210
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
211
+
212
+ def forward(self, x):
213
+
214
+ hx = x
215
+
216
+ hxin = self.rebnconvin(hx)
217
+
218
+ hx1 = self.rebnconv1(hxin)
219
+ hx = self.pool1(hx1)
220
+
221
+ hx2 = self.rebnconv2(hx)
222
+ hx = self.pool2(hx2)
223
+
224
+ hx3 = self.rebnconv3(hx)
225
+ hx = self.pool3(hx3)
226
+
227
+ hx4 = self.rebnconv4(hx)
228
+
229
+ hx5 = self.rebnconv5(hx4)
230
+
231
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
232
+ hx4dup = _upsample_like(hx4d, hx3)
233
+
234
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
235
+ hx3dup = _upsample_like(hx3d, hx2)
236
+
237
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
238
+ hx2dup = _upsample_like(hx2d, hx1)
239
+
240
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
241
+
242
+ return hx1d + hxin
243
+
244
+
245
+ ### RSU-4 ###
246
+ class RSU4(nn.Module):
247
+
248
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
249
+ super(RSU4, self).__init__()
250
+
251
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
252
+
253
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
254
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
255
+
256
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
257
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
258
+
259
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
260
+
261
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
262
+
263
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
264
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
265
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
266
+
267
+ def forward(self, x):
268
+
269
+ hx = x
270
+
271
+ hxin = self.rebnconvin(hx)
272
+
273
+ hx1 = self.rebnconv1(hxin)
274
+ hx = self.pool1(hx1)
275
+
276
+ hx2 = self.rebnconv2(hx)
277
+ hx = self.pool2(hx2)
278
+
279
+ hx3 = self.rebnconv3(hx)
280
+
281
+ hx4 = self.rebnconv4(hx3)
282
+
283
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
284
+ hx3dup = _upsample_like(hx3d, hx2)
285
+
286
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
287
+ hx2dup = _upsample_like(hx2d, hx1)
288
+
289
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
290
+
291
+ return hx1d + hxin
292
+
293
+
294
+ ### RSU-4F ###
295
+ class RSU4F(nn.Module):
296
+
297
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
298
+ super(RSU4F, self).__init__()
299
+
300
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
301
+
302
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
303
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
304
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
305
+
306
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
307
+
308
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
309
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
310
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
311
+
312
+ def forward(self, x):
313
+
314
+ hx = x
315
+
316
+ hxin = self.rebnconvin(hx)
317
+
318
+ hx1 = self.rebnconv1(hxin)
319
+ hx2 = self.rebnconv2(hx1)
320
+ hx3 = self.rebnconv3(hx2)
321
+
322
+ hx4 = self.rebnconv4(hx3)
323
+
324
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
325
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
326
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
327
+
328
+ return hx1d + hxin
329
+
330
+
331
+ class myrebnconv(nn.Module):
332
+ def __init__(
333
+ self,
334
+ in_ch=3,
335
+ out_ch=1,
336
+ kernel_size=3,
337
+ stride=1,
338
+ padding=1,
339
+ dilation=1,
340
+ groups=1,
341
+ ):
342
+ super(myrebnconv, self).__init__()
343
+
344
+ self.conv = nn.Conv2d(
345
+ in_ch,
346
+ out_ch,
347
+ kernel_size=kernel_size,
348
+ stride=stride,
349
+ padding=padding,
350
+ dilation=dilation,
351
+ groups=groups,
352
+ )
353
+ self.bn = nn.BatchNorm2d(out_ch)
354
+ self.rl = nn.ReLU(inplace=True)
355
+
356
+ def forward(self, x):
357
+ return self.rl(self.bn(self.conv(x)))
358
+
359
+
360
+ class ORMBG(nn.Module):
361
+
362
+ def __init__(self, in_ch=3, out_ch=1):
363
+ super(ORMBG, self).__init__()
364
+
365
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
366
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage1 = RSU7(64, 32, 64)
369
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
370
+
371
+ self.stage2 = RSU6(64, 32, 128)
372
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
373
+
374
+ self.stage3 = RSU5(128, 64, 256)
375
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
376
+
377
+ self.stage4 = RSU4(256, 128, 512)
378
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
379
+
380
+ self.stage5 = RSU4F(512, 256, 512)
381
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
382
+
383
+ self.stage6 = RSU4F(512, 256, 512)
384
+
385
+ # decoder
386
+ self.stage5d = RSU4F(1024, 256, 512)
387
+ self.stage4d = RSU4(1024, 128, 256)
388
+ self.stage3d = RSU5(512, 64, 128)
389
+ self.stage2d = RSU6(256, 32, 64)
390
+ self.stage1d = RSU7(128, 16, 64)
391
+
392
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
393
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
394
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
395
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
396
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
397
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
398
+
399
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
400
+
401
+ def forward(self, x):
402
+
403
+ hx = x
404
+
405
+ hxin = self.conv_in(hx)
406
+ # hx = self.pool_in(hxin)
407
+
408
+ # stage 1
409
+ hx1 = self.stage1(hxin)
410
+ hx = self.pool12(hx1)
411
+
412
+ # stage 2
413
+ hx2 = self.stage2(hx)
414
+ hx = self.pool23(hx2)
415
+
416
+ # stage 3
417
+ hx3 = self.stage3(hx)
418
+ hx = self.pool34(hx3)
419
+
420
+ # stage 4
421
+ hx4 = self.stage4(hx)
422
+ hx = self.pool45(hx4)
423
+
424
+ # stage 5
425
+ hx5 = self.stage5(hx)
426
+ hx = self.pool56(hx5)
427
+
428
+ # stage 6
429
+ hx6 = self.stage6(hx)
430
+ hx6up = _upsample_like(hx6, hx5)
431
+
432
+ # -------------------- decoder --------------------
433
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
434
+ hx5dup = _upsample_like(hx5d, hx4)
435
+
436
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
437
+ hx4dup = _upsample_like(hx4d, hx3)
438
+
439
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
440
+ hx3dup = _upsample_like(hx3d, hx2)
441
+
442
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
443
+ hx2dup = _upsample_like(hx2d, hx1)
444
+
445
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
446
+
447
+ # side output
448
+ d1 = self.side1(hx1d)
449
+ d1 = _upsample_like(d1, x)
450
+
451
+ d2 = self.side2(hx2d)
452
+ d2 = _upsample_like(d2, x)
453
+
454
+ d3 = self.side3(hx3d)
455
+ d3 = _upsample_like(d3, x)
456
+
457
+ d4 = self.side4(hx4d)
458
+ d4 = _upsample_like(d4, x)
459
+
460
+ d5 = self.side5(hx5d)
461
+ d5 = _upsample_like(d5, x)
462
+
463
+ d6 = self.side6(hx6)
464
+ d6 = _upsample_like(d6, x)
465
+
466
+ return [
467
+ F.sigmoid(d1),
468
+ F.sigmoid(d2),
469
+ F.sigmoid(d3),
470
+ F.sigmoid(d4),
471
+ F.sigmoid(d5),
472
+ F.sigmoid(d6),
473
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
utils/pth_to_onnx.py CHANGED
@@ -1,11 +1,11 @@
1
  import torch
2
  import argparse
3
- from isnet import ISNetDIS
4
 
5
 
6
  def export_to_onnx(model_path, onnx_path):
7
 
8
- net = ISNetDIS()
9
 
10
  if torch.cuda.is_available():
11
  net.load_state_dict(torch.load(model_path))
 
1
  import torch
2
  import argparse
3
+ from ormbg import ORMBG
4
 
5
 
6
  def export_to_onnx(model_path, onnx_path):
7
 
8
+ net = ORMBG()
9
 
10
  if torch.cuda.is_available():
11
  net.load_state_dict(torch.load(model_path))
utils/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ numpy
2
+ pillow
3
+ scikit-image
4
+ torch
5
+ torchvision
6
+ typing