schirrmacher commited on
Commit
80fd191
·
verified ·
1 Parent(s): 78444e2

Upload folder using huggingface_hub

Browse files
Files changed (8) hide show
  1. .DS_Store +0 -0
  2. .gitattributes +1 -0
  3. README.md +6 -5
  4. app.py +92 -0
  5. input.png +3 -0
  6. ormbg.pth +3 -0
  7. ormbg.py +473 -0
  8. requirements.txt +10 -0
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ input.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
- title: Ormbg
3
- emoji: 📊
4
- colorFrom: indigo
5
- colorTo: blue
6
  sdk: gradio
7
- sdk_version: 4.31.5
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: ORMBG
3
+ emoji: 💻
4
+ colorFrom: red
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 4.16.0
8
  app_file: app.py
9
  pinned: false
10
+ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import gradio as gr
5
+ from ormbg import ORMBG
6
+ from PIL import Image
7
+
8
+
9
+ def preprocess_image(im: np.ndarray, model_input_size: list) -> torch.Tensor:
10
+ if len(im.shape) < 3:
11
+ im = im[:, :, np.newaxis]
12
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
13
+ im_tensor = F.interpolate(
14
+ torch.unsqueeze(im_tensor, 0), size=model_input_size, mode="bilinear"
15
+ ).type(torch.uint8)
16
+ image = torch.divide(im_tensor, 255.0)
17
+ return image
18
+
19
+
20
+ def postprocess_image(result: torch.Tensor, im_size: list) -> np.ndarray:
21
+ result = torch.squeeze(F.interpolate(result, size=im_size, mode="bilinear"), 0)
22
+ ma = torch.max(result)
23
+ mi = torch.min(result)
24
+ result = (result - mi) / (ma - mi)
25
+ im_array = (result * 255).permute(1, 2, 0).cpu().data.numpy().astype(np.uint8)
26
+ im_array = np.squeeze(im_array)
27
+ return im_array
28
+
29
+
30
+ def inference(orig_image):
31
+
32
+ model_path = "ormbg.pth"
33
+
34
+ net = ORMBG()
35
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
36
+
37
+ if torch.cuda.is_available():
38
+ net.load_state_dict(torch.load(model_path))
39
+ net = net.cuda()
40
+ else:
41
+ net.load_state_dict(torch.load(model_path, map_location="cpu"))
42
+ net.eval()
43
+
44
+ model_input_size = [1024, 1024]
45
+ orig_im_size = orig_image.shape[0:2]
46
+ image = preprocess_image(orig_image, model_input_size).to(device)
47
+
48
+ result = net(image)
49
+
50
+ # post process
51
+ result_image = postprocess_image(result[0][0], orig_im_size)
52
+
53
+ # save result
54
+ pil_im = Image.fromarray(result_image)
55
+ no_bg_image = Image.new("RGBA", pil_im.size, (0, 0, 0, 0))
56
+ no_bg_image.paste(orig_image, mask=pil_im)
57
+
58
+ return no_bg_image
59
+
60
+
61
+ gr.Markdown("## Open Remove Background Model (ormbg)")
62
+ gr.HTML(
63
+ """
64
+ <p style="margin-bottom: 10px; font-size: 94%">
65
+ This is a demo for Open Remove Background Model (ormbg) that using
66
+ <a href="https://huggingface.co/schirrmacher/ormbg" target="_blank">Open Remove Background Model (ormbg) model</a> as backbone.
67
+ </p>
68
+ """
69
+ )
70
+ title = "Background Removal"
71
+ description = r"""
72
+ This model is a fully open-source background remover optimized for images with humans.
73
+
74
+ It is based on <a href='https://github.com/xuebinqin/DIS' target='_blank'>Highly Accurate Dichotomous Image Segmentation research</a>.
75
+
76
+ You can find more about the model <a href='https://huggingface.co/schirrmacher/ormbg' target='_blank'>here</a>.
77
+ """
78
+ examples = [
79
+ ["./input.png"],
80
+ ]
81
+
82
+ demo = gr.Interface(
83
+ fn=inference,
84
+ inputs="image",
85
+ outputs="image",
86
+ examples=examples,
87
+ title=title,
88
+ description=description,
89
+ )
90
+
91
+ if __name__ == "__main__":
92
+ demo.launch(share=False)
input.png ADDED

Git LFS Details

  • SHA256: 42c8627c1ada7b69ef8561fcb5611cd8aa08af5eed211379a2619960524639c5
  • Pointer size: 132 Bytes
  • Size of remote file: 4.83 MB
ormbg.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba387a8348526875024f59aa97d23af9cacfff77abf4e9af14332bf477c088fa
3
+ size 176719216
ormbg.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
6
+
7
+
8
+ class REBNCONV(nn.Module):
9
+ def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
10
+ super(REBNCONV, self).__init__()
11
+
12
+ self.conv_s1 = nn.Conv2d(
13
+ in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
14
+ )
15
+ self.bn_s1 = nn.BatchNorm2d(out_ch)
16
+ self.relu_s1 = nn.ReLU(inplace=True)
17
+
18
+ def forward(self, x):
19
+
20
+ hx = x
21
+ xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
22
+
23
+ return xout
24
+
25
+
26
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
27
+ def _upsample_like(src, tar):
28
+
29
+ src = F.interpolate(src, size=tar.shape[2:], mode="bilinear")
30
+
31
+ return src
32
+
33
+
34
+ ### RSU-7 ###
35
+ class RSU7(nn.Module):
36
+
37
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
38
+ super(RSU7, self).__init__()
39
+
40
+ self.in_ch = in_ch
41
+ self.mid_ch = mid_ch
42
+ self.out_ch = out_ch
43
+
44
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
45
+
46
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
47
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
48
+
49
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
50
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
51
+
52
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
53
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
54
+
55
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
56
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
57
+
58
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
59
+ self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
60
+
61
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
62
+
63
+ self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
64
+
65
+ self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
66
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
67
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
68
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
69
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
70
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
71
+
72
+ def forward(self, x):
73
+ b, c, h, w = x.shape
74
+
75
+ hx = x
76
+ hxin = self.rebnconvin(hx)
77
+
78
+ hx1 = self.rebnconv1(hxin)
79
+ hx = self.pool1(hx1)
80
+
81
+ hx2 = self.rebnconv2(hx)
82
+ hx = self.pool2(hx2)
83
+
84
+ hx3 = self.rebnconv3(hx)
85
+ hx = self.pool3(hx3)
86
+
87
+ hx4 = self.rebnconv4(hx)
88
+ hx = self.pool4(hx4)
89
+
90
+ hx5 = self.rebnconv5(hx)
91
+ hx = self.pool5(hx5)
92
+
93
+ hx6 = self.rebnconv6(hx)
94
+
95
+ hx7 = self.rebnconv7(hx6)
96
+
97
+ hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
98
+ hx6dup = _upsample_like(hx6d, hx5)
99
+
100
+ hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
101
+ hx5dup = _upsample_like(hx5d, hx4)
102
+
103
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
104
+ hx4dup = _upsample_like(hx4d, hx3)
105
+
106
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
107
+ hx3dup = _upsample_like(hx3d, hx2)
108
+
109
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
110
+ hx2dup = _upsample_like(hx2d, hx1)
111
+
112
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
113
+
114
+ return hx1d + hxin
115
+
116
+
117
+ ### RSU-6 ###
118
+ class RSU6(nn.Module):
119
+
120
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
121
+ super(RSU6, self).__init__()
122
+
123
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
124
+
125
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
126
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
127
+
128
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
129
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
130
+
131
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
132
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
133
+
134
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
135
+ self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
136
+
137
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
138
+
139
+ self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
140
+
141
+ self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
142
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
143
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
144
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
145
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
146
+
147
+ def forward(self, x):
148
+
149
+ hx = x
150
+
151
+ hxin = self.rebnconvin(hx)
152
+
153
+ hx1 = self.rebnconv1(hxin)
154
+ hx = self.pool1(hx1)
155
+
156
+ hx2 = self.rebnconv2(hx)
157
+ hx = self.pool2(hx2)
158
+
159
+ hx3 = self.rebnconv3(hx)
160
+ hx = self.pool3(hx3)
161
+
162
+ hx4 = self.rebnconv4(hx)
163
+ hx = self.pool4(hx4)
164
+
165
+ hx5 = self.rebnconv5(hx)
166
+
167
+ hx6 = self.rebnconv6(hx5)
168
+
169
+ hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
170
+ hx5dup = _upsample_like(hx5d, hx4)
171
+
172
+ hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
173
+ hx4dup = _upsample_like(hx4d, hx3)
174
+
175
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
176
+ hx3dup = _upsample_like(hx3d, hx2)
177
+
178
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
179
+ hx2dup = _upsample_like(hx2d, hx1)
180
+
181
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
182
+
183
+ return hx1d + hxin
184
+
185
+
186
+ ### RSU-5 ###
187
+ class RSU5(nn.Module):
188
+
189
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
190
+ super(RSU5, self).__init__()
191
+
192
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
193
+
194
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
195
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
196
+
197
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
198
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
199
+
200
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
201
+ self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
202
+
203
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
204
+
205
+ self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
206
+
207
+ self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
208
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
209
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
210
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
211
+
212
+ def forward(self, x):
213
+
214
+ hx = x
215
+
216
+ hxin = self.rebnconvin(hx)
217
+
218
+ hx1 = self.rebnconv1(hxin)
219
+ hx = self.pool1(hx1)
220
+
221
+ hx2 = self.rebnconv2(hx)
222
+ hx = self.pool2(hx2)
223
+
224
+ hx3 = self.rebnconv3(hx)
225
+ hx = self.pool3(hx3)
226
+
227
+ hx4 = self.rebnconv4(hx)
228
+
229
+ hx5 = self.rebnconv5(hx4)
230
+
231
+ hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
232
+ hx4dup = _upsample_like(hx4d, hx3)
233
+
234
+ hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
235
+ hx3dup = _upsample_like(hx3d, hx2)
236
+
237
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
238
+ hx2dup = _upsample_like(hx2d, hx1)
239
+
240
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
241
+
242
+ return hx1d + hxin
243
+
244
+
245
+ ### RSU-4 ###
246
+ class RSU4(nn.Module):
247
+
248
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
249
+ super(RSU4, self).__init__()
250
+
251
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
252
+
253
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
254
+ self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
255
+
256
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
257
+ self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
258
+
259
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
260
+
261
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
262
+
263
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
264
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
265
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
266
+
267
+ def forward(self, x):
268
+
269
+ hx = x
270
+
271
+ hxin = self.rebnconvin(hx)
272
+
273
+ hx1 = self.rebnconv1(hxin)
274
+ hx = self.pool1(hx1)
275
+
276
+ hx2 = self.rebnconv2(hx)
277
+ hx = self.pool2(hx2)
278
+
279
+ hx3 = self.rebnconv3(hx)
280
+
281
+ hx4 = self.rebnconv4(hx3)
282
+
283
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
284
+ hx3dup = _upsample_like(hx3d, hx2)
285
+
286
+ hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
287
+ hx2dup = _upsample_like(hx2d, hx1)
288
+
289
+ hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
290
+
291
+ return hx1d + hxin
292
+
293
+
294
+ ### RSU-4F ###
295
+ class RSU4F(nn.Module):
296
+
297
+ def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
298
+ super(RSU4F, self).__init__()
299
+
300
+ self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
301
+
302
+ self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
303
+ self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
304
+ self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
305
+
306
+ self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
307
+
308
+ self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
309
+ self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
310
+ self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
311
+
312
+ def forward(self, x):
313
+
314
+ hx = x
315
+
316
+ hxin = self.rebnconvin(hx)
317
+
318
+ hx1 = self.rebnconv1(hxin)
319
+ hx2 = self.rebnconv2(hx1)
320
+ hx3 = self.rebnconv3(hx2)
321
+
322
+ hx4 = self.rebnconv4(hx3)
323
+
324
+ hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
325
+ hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
326
+ hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
327
+
328
+ return hx1d + hxin
329
+
330
+
331
+ class myrebnconv(nn.Module):
332
+ def __init__(
333
+ self,
334
+ in_ch=3,
335
+ out_ch=1,
336
+ kernel_size=3,
337
+ stride=1,
338
+ padding=1,
339
+ dilation=1,
340
+ groups=1,
341
+ ):
342
+ super(myrebnconv, self).__init__()
343
+
344
+ self.conv = nn.Conv2d(
345
+ in_ch,
346
+ out_ch,
347
+ kernel_size=kernel_size,
348
+ stride=stride,
349
+ padding=padding,
350
+ dilation=dilation,
351
+ groups=groups,
352
+ )
353
+ self.bn = nn.BatchNorm2d(out_ch)
354
+ self.rl = nn.ReLU(inplace=True)
355
+
356
+ def forward(self, x):
357
+ return self.rl(self.bn(self.conv(x)))
358
+
359
+
360
+ class ORMBG(nn.Module):
361
+
362
+ def __init__(self, in_ch=3, out_ch=1):
363
+ super(ORMBG, self).__init__()
364
+
365
+ self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
366
+ self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
367
+
368
+ self.stage1 = RSU7(64, 32, 64)
369
+ self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
370
+
371
+ self.stage2 = RSU6(64, 32, 128)
372
+ self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
373
+
374
+ self.stage3 = RSU5(128, 64, 256)
375
+ self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
376
+
377
+ self.stage4 = RSU4(256, 128, 512)
378
+ self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
379
+
380
+ self.stage5 = RSU4F(512, 256, 512)
381
+ self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
382
+
383
+ self.stage6 = RSU4F(512, 256, 512)
384
+
385
+ # decoder
386
+ self.stage5d = RSU4F(1024, 256, 512)
387
+ self.stage4d = RSU4(1024, 128, 256)
388
+ self.stage3d = RSU5(512, 64, 128)
389
+ self.stage2d = RSU6(256, 32, 64)
390
+ self.stage1d = RSU7(128, 16, 64)
391
+
392
+ self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
393
+ self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
394
+ self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
395
+ self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
396
+ self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
397
+ self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
398
+
399
+ # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
400
+
401
+ def forward(self, x):
402
+
403
+ hx = x
404
+
405
+ hxin = self.conv_in(hx)
406
+ # hx = self.pool_in(hxin)
407
+
408
+ # stage 1
409
+ hx1 = self.stage1(hxin)
410
+ hx = self.pool12(hx1)
411
+
412
+ # stage 2
413
+ hx2 = self.stage2(hx)
414
+ hx = self.pool23(hx2)
415
+
416
+ # stage 3
417
+ hx3 = self.stage3(hx)
418
+ hx = self.pool34(hx3)
419
+
420
+ # stage 4
421
+ hx4 = self.stage4(hx)
422
+ hx = self.pool45(hx4)
423
+
424
+ # stage 5
425
+ hx5 = self.stage5(hx)
426
+ hx = self.pool56(hx5)
427
+
428
+ # stage 6
429
+ hx6 = self.stage6(hx)
430
+ hx6up = _upsample_like(hx6, hx5)
431
+
432
+ # -------------------- decoder --------------------
433
+ hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
434
+ hx5dup = _upsample_like(hx5d, hx4)
435
+
436
+ hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
437
+ hx4dup = _upsample_like(hx4d, hx3)
438
+
439
+ hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
440
+ hx3dup = _upsample_like(hx3d, hx2)
441
+
442
+ hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
443
+ hx2dup = _upsample_like(hx2d, hx1)
444
+
445
+ hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
446
+
447
+ # side output
448
+ d1 = self.side1(hx1d)
449
+ d1 = _upsample_like(d1, x)
450
+
451
+ d2 = self.side2(hx2d)
452
+ d2 = _upsample_like(d2, x)
453
+
454
+ d3 = self.side3(hx3d)
455
+ d3 = _upsample_like(d3, x)
456
+
457
+ d4 = self.side4(hx4d)
458
+ d4 = _upsample_like(d4, x)
459
+
460
+ d5 = self.side5(hx5d)
461
+ d5 = _upsample_like(d5, x)
462
+
463
+ d6 = self.side6(hx6)
464
+ d6 = _upsample_like(d6, x)
465
+
466
+ return [
467
+ F.sigmoid(d1),
468
+ F.sigmoid(d2),
469
+ F.sigmoid(d3),
470
+ F.sigmoid(d4),
471
+ F.sigmoid(d5),
472
+ F.sigmoid(d6),
473
+ ], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ gradio_imageslider
3
+ torch
4
+ torchvision
5
+ scikit-image
6
+ pillow
7
+ numpy
8
+ typing
9
+ gitpython
10
+ huggingface_hub