shunk031 committed on
Commit 695da21
1 Parent(s): 34c0c8a

Upload model

Files changed (5)
  1. README.md +199 -0
  2. config.json +15 -0
  3. configuration_basnet.py +18 -0
  4. model.safetensors +3 -0
  5. modeling_basnet.py +481 -0
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed to the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. More information is needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
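+ The snippet below is a minimal sketch rather than a verified recipe: the repository id is a placeholder (this card does not state one), and `trust_remote_code=True` is required because this repo's `config.json` maps `AutoConfig`/`AutoModel` to the custom `BASNetConfig`/`BASNetModel` classes shipped alongside the weights.
+
+ ```python
+ import torch
+ from transformers import AutoModel
+
+ # Placeholder repo id -- replace with this model's actual Hub repository.
+ model = AutoModel.from_pretrained("<user>/<repo>", trust_remote_code=True)
+ model.eval()
+
+ # BASNet expects 3-channel inputs; 256x256 matches the size comments in modeling_basnet.py.
+ pixel_values = torch.randn(1, 3, 256, 256)
+ with torch.no_grad():
+     outputs = model(pixel_values)
+
+ # outputs is a tuple of 8 sigmoid maps; the first is the refined saliency prediction.
+ saliency = outputs[0]
+ mask = (saliency > 0.5).float()
+ ```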
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "architectures": [
+     "BASNetModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_basnet.BASNetConfig",
+     "AutoModel": "modeling_basnet.BASNetModel"
+   },
+   "kernel_size": 3,
+   "model_type": "basnet",
+   "n_channels": 3,
+   "resnet_model": "microsoft/resnet-34",
+   "torch_dtype": "float32",
+   "transformers_version": "4.42.4"
+ }
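The `auto_map` block above is what lets the generic `Auto*` loaders resolve to the custom classes uploaded in this commit. A minimal sketch (the repository id is a placeholder, and `trust_remote_code=True` is required so the mapped modules may be imported):

```python
from transformers import AutoConfig

# Placeholder repo id -- replace with the actual Hub repository.
config = AutoConfig.from_pretrained("<user>/<repo>", trust_remote_code=True)
print(type(config).__name__)  # BASNetConfig, resolved via auto_map
```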
configuration_basnet.py ADDED
@@ -0,0 +1,18 @@
+ from transformers.configuration_utils import PretrainedConfig
+
+
+ class BASNetConfig(PretrainedConfig):
+     model_type = "basnet"
+
+     def __init__(
+         self,
+         resnet_model: str = "microsoft/resnet-34",
+         n_channels: int = 3,
+         kernel_size: int = 3,
+         **kwargs,
+     ) -> None:
+         super().__init__(**kwargs)
+         self.resnet_model = resnet_model
+         self.n_channels = n_channels
+
+         self.kernel_size = kernel_size
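A quick sketch of the config class in isolation, assuming `configuration_basnet.py` is on your import path; the output directory name is arbitrary:

```python
from configuration_basnet import BASNetConfig

# The defaults mirror the values serialized in config.json above.
config = BASNetConfig()
assert config.model_type == "basnet"
assert (config.resnet_model, config.n_channels, config.kernel_size) == (
    "microsoft/resnet-34", 3, 3,
)

# Writes a config.json like the one in this commit (minus the model-side fields).
config.save_pretrained("basnet-config")
```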
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9d56514d048060e04c6df305ebc898d538e7a9ae1a8211c86a9613eef8c16452
+ size 348466168
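The LFS pointer records the blob's sha256 and byte size, so a downloaded copy can be checked against it. A hedged sketch (the local path is a placeholder):

```python
import hashlib
from pathlib import Path

# Placeholder path to a locally downloaded copy of the weights file.
path = Path("model.safetensors")

# Reads the whole ~348 MB file into memory; fine for a one-off check.
digest = hashlib.sha256(path.read_bytes()).hexdigest()
assert digest == "9d56514d048060e04c6df305ebc898d538e7a9ae1a8211c86a9613eef8c16452"
assert path.stat().st_size == 348466168
```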
modeling_basnet.py ADDED
@@ -0,0 +1,481 @@
+ import logging
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.nn as nn
+ import torchvision
+ from transformers.modeling_utils import PreTrainedModel
+
+ from .configuration_basnet import BASNetConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ class RefUnet(nn.Module):
+     def __init__(self, in_ch: int, inc_ch: int) -> None:
+         super().__init__()
+
+         self.conv0 = nn.Conv2d(in_ch, inc_ch, kernel_size=3, padding=1)
+
+         self.conv1 = nn.Conv2d(inc_ch, 64, kernel_size=3, padding=1)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.relu1 = nn.ReLU(inplace=True)
+
+         self.pool1 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+         self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+         self.bn2 = nn.BatchNorm2d(64)
+         self.relu2 = nn.ReLU(inplace=True)
+
+         self.pool2 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+         self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+         self.bn3 = nn.BatchNorm2d(64)
+         self.relu3 = nn.ReLU(inplace=True)
+
+         self.pool3 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+         self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+         self.bn4 = nn.BatchNorm2d(64)
+         self.relu4 = nn.ReLU(inplace=True)
+
+         self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+         #####
+
+         self.conv5 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
+         self.bn5 = nn.BatchNorm2d(64)
+         self.relu5 = nn.ReLU(inplace=True)
+
+         #####
+
+         self.conv_d4 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
+         self.bn_d4 = nn.BatchNorm2d(64)
+         self.relu_d4 = nn.ReLU(inplace=True)
+
+         self.conv_d3 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
+         self.bn_d3 = nn.BatchNorm2d(64)
+         self.relu_d3 = nn.ReLU(inplace=True)
+
+         self.conv_d2 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
+         self.bn_d2 = nn.BatchNorm2d(64)
+         self.relu_d2 = nn.ReLU(inplace=True)
+
+         self.conv_d1 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
+         self.bn_d1 = nn.BatchNorm2d(64)
+         self.relu_d1 = nn.ReLU(inplace=True)
+
+         self.conv_d0 = nn.Conv2d(64, 1, kernel_size=3, padding=1)
+
+         self.upscore2 = nn.Upsample(
+             scale_factor=2, mode="bilinear", align_corners=False
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         hx = x
+         hx = self.conv0(hx)
+
+         hx1 = self.relu1(self.bn1(self.conv1(hx)))
+         hx = self.pool1(hx1)
+
+         hx2 = self.relu2(self.bn2(self.conv2(hx)))
+         hx = self.pool2(hx2)
+
+         hx3 = self.relu3(self.bn3(self.conv3(hx)))
+         hx = self.pool3(hx3)
+
+         hx4 = self.relu4(self.bn4(self.conv4(hx)))
+         hx = self.pool4(hx4)
+
+         hx5 = self.relu5(self.bn5(self.conv5(hx)))
+
+         hx = self.upscore2(hx5)
+
+         d4 = self.relu_d4(self.bn_d4(self.conv_d4(torch.cat((hx, hx4), 1))))
+         hx = self.upscore2(d4)
+
+         d3 = self.relu_d3(self.bn_d3(self.conv_d3(torch.cat((hx, hx3), 1))))
+         hx = self.upscore2(d3)
+
+         d2 = self.relu_d2(self.bn_d2(self.conv_d2(torch.cat((hx, hx2), 1))))
+         hx = self.upscore2(d2)
+
+         d1 = self.relu_d1(self.bn_d1(self.conv_d1(torch.cat((hx, hx1), 1))))
+
+         residual = self.conv_d0(d1)
+
+         return x + residual
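+
+ # Shape sketch (editorial note, not part of the upstream BASNet source): for a
+ # (B, 1, 256, 256) input map, the four ceil-mode poolings bring the features
+ # down to (B, 64, 16, 16); the decoder then upsamples back in four 2x steps,
+ # concatenating the matching encoder features at each scale, and conv_d0
+ # predicts a 1-channel residual that is added to the input, so the refined
+ # output keeps the input's shape.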
+
+
+ def conv3x3(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+     """3x3 convolution with padding."""
+     return nn.Conv2d(
+         in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
+     )
+
+
+ class BasicBlock(nn.Module):
+     expansion: int = 1
+
+     def __init__(self, inplanes: int, planes: int, stride: int = 1, downsample: Optional[nn.Module] = None) -> None:
+         super().__init__()
+         self.conv1 = conv3x3(inplanes, planes, stride)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.relu = nn.ReLU(inplace=True)
+         self.conv2 = conv3x3(planes, planes)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.downsample = downsample
+         self.stride = stride
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         residual = x
+
+         out = self.conv1(x)
+         out = self.bn1(out)
+         out = self.relu(out)
+
+         out = self.conv2(out)
+         out = self.bn2(out)
+
+         if self.downsample is not None:
+             residual = self.downsample(x)
+
+         out += residual
+         out = self.relu(out)
+
+         return out
+
+
+ class BASNetModel(PreTrainedModel):
+     config_class = BASNetConfig  # lets from_pretrained/save_pretrained pair this model with its config
+
+     def __init__(self, config: BASNetConfig) -> None:
+         super().__init__(config)
+
+         # Encoder layers come from torchvision's resnet34; config.resnet_model
+         # is recorded in the config but not consumed here.
+         resnet = torchvision.models.resnet34(
+             weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1
+         )
+
+         ## -------------Encoder--------------
+
+         self.inconv = nn.Conv2d(
+             config.n_channels, 64, kernel_size=config.kernel_size, padding=1
+         )
+         self.inbn = nn.BatchNorm2d(64)
+         self.inrelu = nn.ReLU(inplace=True)
+
+         # stage 1
+         self.encoder1 = resnet.layer1  # 256
+         # stage 2
+         self.encoder2 = resnet.layer2  # 128
+         # stage 3
+         self.encoder3 = resnet.layer3  # 64
+         # stage 4
+         self.encoder4 = resnet.layer4  # 32
+
+         self.pool4 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+         # stage 5
+         self.resb5_1 = BasicBlock(512, 512)
+         self.resb5_2 = BasicBlock(512, 512)
+         self.resb5_3 = BasicBlock(512, 512)  # 16
+
+         self.pool5 = nn.MaxPool2d(2, 2, ceil_mode=True)
+
+         # stage 6
+         self.resb6_1 = BasicBlock(512, 512)
+         self.resb6_2 = BasicBlock(512, 512)
+         self.resb6_3 = BasicBlock(512, 512)  # 8
+
+         ## -------------Bridge--------------
+
+         # stage Bridge
+         self.convbg_1 = nn.Conv2d(
+             512, 512, kernel_size=config.kernel_size, dilation=2, padding=2
+         )  # 8
+         self.bnbg_1 = nn.BatchNorm2d(512)
+         self.relubg_1 = nn.ReLU(inplace=True)
+         self.convbg_m = nn.Conv2d(
+             512, 512, kernel_size=config.kernel_size, dilation=2, padding=2
+         )
+         self.bnbg_m = nn.BatchNorm2d(512)
+         self.relubg_m = nn.ReLU(inplace=True)
+         self.convbg_2 = nn.Conv2d(
+             512, 512, kernel_size=config.kernel_size, dilation=2, padding=2
+         )
+         self.bnbg_2 = nn.BatchNorm2d(512)
+         self.relubg_2 = nn.ReLU(inplace=True)
+
+         ## -------------Decoder--------------
+
+         # stage 6d
+         self.conv6d_1 = nn.Conv2d(
+             1024, 512, kernel_size=config.kernel_size, padding=1
+         )  # 16
+         self.bn6d_1 = nn.BatchNorm2d(512)
+         self.relu6d_1 = nn.ReLU(inplace=True)
+
+         self.conv6d_m = nn.Conv2d(
+             512, 512, kernel_size=config.kernel_size, dilation=2, padding=2
+         )  ###
+         self.bn6d_m = nn.BatchNorm2d(512)
+         self.relu6d_m = nn.ReLU(inplace=True)
+
+         self.conv6d_2 = nn.Conv2d(
+             512, 512, kernel_size=config.kernel_size, dilation=2, padding=2
+         )
+         self.bn6d_2 = nn.BatchNorm2d(512)
+         self.relu6d_2 = nn.ReLU(inplace=True)
+
+         # stage 5d
+         self.conv5d_1 = nn.Conv2d(
+             1024, 512, kernel_size=config.kernel_size, padding=1
+         )  # 16
+         self.bn5d_1 = nn.BatchNorm2d(512)
+         self.relu5d_1 = nn.ReLU(inplace=True)
+
+         self.conv5d_m = nn.Conv2d(
+             512, 512, kernel_size=config.kernel_size, padding=1
+         )  ###
+         self.bn5d_m = nn.BatchNorm2d(512)
+         self.relu5d_m = nn.ReLU(inplace=True)
+
+         self.conv5d_2 = nn.Conv2d(512, 512, kernel_size=config.kernel_size, padding=1)
+         self.bn5d_2 = nn.BatchNorm2d(512)
+         self.relu5d_2 = nn.ReLU(inplace=True)
+
+         # stage 4d
+         self.conv4d_1 = nn.Conv2d(
+             1024, 512, kernel_size=config.kernel_size, padding=1
+         )  # 32
+         self.bn4d_1 = nn.BatchNorm2d(512)
+         self.relu4d_1 = nn.ReLU(inplace=True)
+
+         self.conv4d_m = nn.Conv2d(
+             512, 512, kernel_size=config.kernel_size, padding=1
+         )  ###
+         self.bn4d_m = nn.BatchNorm2d(512)
+         self.relu4d_m = nn.ReLU(inplace=True)
+
+         self.conv4d_2 = nn.Conv2d(512, 256, kernel_size=config.kernel_size, padding=1)
+         self.bn4d_2 = nn.BatchNorm2d(256)
+         self.relu4d_2 = nn.ReLU(inplace=True)
+
+         # stage 3d
+         self.conv3d_1 = nn.Conv2d(
+             512, 256, kernel_size=config.kernel_size, padding=1
+         )  # 64
+         self.bn3d_1 = nn.BatchNorm2d(256)
+         self.relu3d_1 = nn.ReLU(inplace=True)
+
+         self.conv3d_m = nn.Conv2d(
+             256, 256, kernel_size=config.kernel_size, padding=1
+         )  ###
+         self.bn3d_m = nn.BatchNorm2d(256)
+         self.relu3d_m = nn.ReLU(inplace=True)
+
+         self.conv3d_2 = nn.Conv2d(256, 128, kernel_size=config.kernel_size, padding=1)
+         self.bn3d_2 = nn.BatchNorm2d(128)
+         self.relu3d_2 = nn.ReLU(inplace=True)
+
+         # stage 2d
+
+         self.conv2d_1 = nn.Conv2d(
+             256, 128, kernel_size=config.kernel_size, padding=1
+         )  # 128
+         self.bn2d_1 = nn.BatchNorm2d(128)
+         self.relu2d_1 = nn.ReLU(inplace=True)
+
+         self.conv2d_m = nn.Conv2d(
+             128, 128, kernel_size=config.kernel_size, padding=1
+         )  ###
+         self.bn2d_m = nn.BatchNorm2d(128)
+         self.relu2d_m = nn.ReLU(inplace=True)
+
+         self.conv2d_2 = nn.Conv2d(128, 64, kernel_size=config.kernel_size, padding=1)
+         self.bn2d_2 = nn.BatchNorm2d(64)
+         self.relu2d_2 = nn.ReLU(inplace=True)
+
+         # stage 1d
+         self.conv1d_1 = nn.Conv2d(
+             128, 64, kernel_size=config.kernel_size, padding=1
+         )  # 256
+         self.bn1d_1 = nn.BatchNorm2d(64)
+         self.relu1d_1 = nn.ReLU(inplace=True)
+
+         self.conv1d_m = nn.Conv2d(
+             64, 64, kernel_size=config.kernel_size, padding=1
+         )  ###
+         self.bn1d_m = nn.BatchNorm2d(64)
+         self.relu1d_m = nn.ReLU(inplace=True)
+
+         self.conv1d_2 = nn.Conv2d(64, 64, kernel_size=config.kernel_size, padding=1)
+         self.bn1d_2 = nn.BatchNorm2d(64)
+         self.relu1d_2 = nn.ReLU(inplace=True)
+
+         ## -------------Bilinear Upsampling--------------
+         self.upscore6 = nn.Upsample(
+             scale_factor=32, mode="bilinear", align_corners=False
+         )  ###
+         self.upscore5 = nn.Upsample(
+             scale_factor=16, mode="bilinear", align_corners=False
+         )
+         self.upscore4 = nn.Upsample(
+             scale_factor=8, mode="bilinear", align_corners=False
+         )
+         self.upscore3 = nn.Upsample(
+             scale_factor=4, mode="bilinear", align_corners=False
+         )
+         self.upscore2 = nn.Upsample(
+             scale_factor=2, mode="bilinear", align_corners=False
+         )
+
+         ## -------------Side Output--------------
+         self.outconvb = nn.Conv2d(512, 1, kernel_size=3, padding=1)
+         self.outconv6 = nn.Conv2d(512, 1, kernel_size=3, padding=1)
+         self.outconv5 = nn.Conv2d(512, 1, kernel_size=3, padding=1)
+         self.outconv4 = nn.Conv2d(256, 1, kernel_size=3, padding=1)
+         self.outconv3 = nn.Conv2d(128, 1, kernel_size=3, padding=1)
+         self.outconv2 = nn.Conv2d(64, 1, kernel_size=3, padding=1)
+         self.outconv1 = nn.Conv2d(64, 1, kernel_size=3, padding=1)
+
+         ## -------------Refine Module-------------
+         self.refunet = RefUnet(1, 64)
+
+         self.post_init()
+
+     def forward(
+         self, pixel_values: torch.Tensor
+     ) -> Tuple[
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+         torch.Tensor,
+     ]:
+         hx = pixel_values
+
+         ## -------------Encoder-------------
+         hx = self.inconv(hx)
+         hx = self.inbn(hx)
+         hx = self.inrelu(hx)
+
+         h1 = self.encoder1(hx)  # 256
+         h2 = self.encoder2(h1)  # 128
+         h3 = self.encoder3(h2)  # 64
+         h4 = self.encoder4(h3)  # 32
+
+         hx = self.pool4(h4)  # 16
+
+         hx = self.resb5_1(hx)
+         hx = self.resb5_2(hx)
+         h5 = self.resb5_3(hx)
+
+         hx = self.pool5(h5)  # 8
+
+         hx = self.resb6_1(hx)
+         hx = self.resb6_2(hx)
+         h6 = self.resb6_3(hx)
+
+         ## -------------Bridge-------------
+         hx = self.relubg_1(self.bnbg_1(self.convbg_1(h6)))  # 8
+         hx = self.relubg_m(self.bnbg_m(self.convbg_m(hx)))
+         hbg = self.relubg_2(self.bnbg_2(self.convbg_2(hx)))
+
+         ## -------------Decoder-------------
+
+         hx = self.relu6d_1(self.bn6d_1(self.conv6d_1(torch.cat((hbg, h6), 1))))
+         hx = self.relu6d_m(self.bn6d_m(self.conv6d_m(hx)))
+         hd6 = self.relu6d_2(self.bn6d_2(self.conv6d_2(hx)))
+
+         hx = self.upscore2(hd6)  # 8 -> 16
+
+         hx = self.relu5d_1(self.bn5d_1(self.conv5d_1(torch.cat((hx, h5), 1))))
+         hx = self.relu5d_m(self.bn5d_m(self.conv5d_m(hx)))
+         hd5 = self.relu5d_2(self.bn5d_2(self.conv5d_2(hx)))
+
+         hx = self.upscore2(hd5)  # 16 -> 32
+
+         hx = self.relu4d_1(self.bn4d_1(self.conv4d_1(torch.cat((hx, h4), 1))))
+         hx = self.relu4d_m(self.bn4d_m(self.conv4d_m(hx)))
+         hd4 = self.relu4d_2(self.bn4d_2(self.conv4d_2(hx)))
+
+         hx = self.upscore2(hd4)  # 32 -> 64
+
+         hx = self.relu3d_1(self.bn3d_1(self.conv3d_1(torch.cat((hx, h3), 1))))
+         hx = self.relu3d_m(self.bn3d_m(self.conv3d_m(hx)))
+         hd3 = self.relu3d_2(self.bn3d_2(self.conv3d_2(hx)))
+
+         hx = self.upscore2(hd3)  # 64 -> 128
+
+         hx = self.relu2d_1(self.bn2d_1(self.conv2d_1(torch.cat((hx, h2), 1))))
+         hx = self.relu2d_m(self.bn2d_m(self.conv2d_m(hx)))
+         hd2 = self.relu2d_2(self.bn2d_2(self.conv2d_2(hx)))
+
+         hx = self.upscore2(hd2)  # 128 -> 256
+
+         hx = self.relu1d_1(self.bn1d_1(self.conv1d_1(torch.cat((hx, h1), 1))))
+         hx = self.relu1d_m(self.bn1d_m(self.conv1d_m(hx)))
+         hd1 = self.relu1d_2(self.bn1d_2(self.conv1d_2(hx)))
+
+         ## -------------Side Output-------------
+         db = self.outconvb(hbg)
+         db = self.upscore6(db)  # 8->256
+
+         d6 = self.outconv6(hd6)
+         d6 = self.upscore6(d6)  # 8->256
+
+         d5 = self.outconv5(hd5)
+         d5 = self.upscore5(d5)  # 16->256
+
+         d4 = self.outconv4(hd4)
+         d4 = self.upscore4(d4)  # 32->256
+
+         d3 = self.outconv3(hd3)
+         d3 = self.upscore3(d3)  # 64->256
+
+         d2 = self.outconv2(hd2)
+         d2 = self.upscore2(d2)  # 128->256
+
+         d1 = self.outconv1(hd1)  # 256
+
+         ## -------------Refine Module-------------
+         dout = self.refunet(d1)  # 256
+
+         # Refined map first, then the raw side outputs from fine to coarse.
+         return (
+             torch.sigmoid(dout),
+             torch.sigmoid(d1),
+             torch.sigmoid(d2),
+             torch.sigmoid(d3),
+             torch.sigmoid(d4),
+             torch.sigmoid(d5),
+             torch.sigmoid(d6),
+             torch.sigmoid(db),
+         )
+
+
+ def convert_from_checkpoint(
+     repo_id: str, filename: str, config: Optional[BASNetConfig] = None
+ ) -> BASNetModel:
+     from huggingface_hub import hf_hub_download
+
+     checkpoint_path = hf_hub_download(repo_id=repo_id, filename=filename)
+
+     config = config or BASNetConfig()
+     model = BASNetModel(config)
+
+     logger.info(f"Loading checkpoint from {checkpoint_path}")
+     state_dict = torch.load(checkpoint_path, map_location="cpu")
+
+     model.load_state_dict(state_dict, strict=True)
+     model.eval()
+
+     return model
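A short usage sketch for `convert_from_checkpoint`. The repo id and checkpoint filename below are placeholders (the commit does not name the original `.pth` source), and because `modeling_basnet.py` uses a relative import, the sketch assumes both module files live in an importable package (here a hypothetical `basnet_pkg/` directory with an `__init__.py`). `strict=True` inside the function means the checkpoint's keys must match this module layout exactly.

```python
from basnet_pkg.modeling_basnet import convert_from_checkpoint

# Placeholder repo id and filename -- point these at the repo hosting
# the original BASNet .pth state dict.
model = convert_from_checkpoint(
    repo_id="<user>/<checkpoint-repo>", filename="basnet.pth"
)

# Re-export in transformers format: writes config.json plus
# model.safetensors, matching the files uploaded in this commit.
model.save_pretrained("basnet-hf")
```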