shunk031 committed on
Commit
ee3d6e8
1 Parent(s): 2637cc0

Upload model

Browse files
Files changed (3) hide show
  1. config.json +0 -1
  2. configuration_basnet.py +1 -4
  3. modeling_basnet.py +47 -21
config.json CHANGED
@@ -9,7 +9,6 @@
9
  "kernel_size": 3,
10
  "model_type": "basnet",
11
  "n_channels": 3,
12
- "resnet_model": "microsoft/resnet-34",
13
  "torch_dtype": "float32",
14
  "transformers_version": "4.42.4"
15
  }
 
9
  "kernel_size": 3,
10
  "model_type": "basnet",
11
  "n_channels": 3,
 
12
  "torch_dtype": "float32",
13
  "transformers_version": "4.42.4"
14
  }
configuration_basnet.py CHANGED
@@ -6,13 +6,10 @@ class BASNetConfig(PretrainedConfig):
6
 
7
  def __init__(
8
  self,
9
- resnet_model: str = "microsoft/resnet-34",
10
  n_channels: int = 3,
11
  kernel_size: int = 3,
12
  **kwargs,
13
  ) -> None:
14
  super().__init__(**kwargs)
15
- self.resnet_model = resnet_model
16
  self.n_channels = n_channels
17
-
18
- self.kernel_size = 3
 
6
 
7
  def __init__(
8
  self,
 
9
  n_channels: int = 3,
10
  kernel_size: int = 3,
11
  **kwargs,
12
  ) -> None:
13
  super().__init__(**kwargs)
 
14
  self.n_channels = n_channels
15
+ self.kernel_size = kernel_size
 
modeling_basnet.py CHANGED
@@ -1,16 +1,30 @@
1
  import logging
2
- from typing import Optional, Tuple
 
3
 
4
  import torch
5
  import torch.nn as nn
6
  import torchvision
7
  from transformers.modeling_utils import PreTrainedModel
 
8
 
9
  from .configuration_basnet import BASNetConfig
10
 
11
  logger = logging.getLogger(__name__)
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class RefUnet(nn.Module):
15
  def __init__(self, in_ch: int, inc_ch: int) -> None:
16
  super().__init__()
@@ -352,17 +366,8 @@ class BASNetModel(PreTrainedModel):
352
  self.post_init()
353
 
354
  def forward(
355
- self, pixel_values: torch.Tensor
356
- ) -> Tuple[
357
- torch.Tensor,
358
- torch.Tensor,
359
- torch.Tensor,
360
- torch.Tensor,
361
- torch.Tensor,
362
- torch.Tensor,
363
- torch.Tensor,
364
- torch.Tensor,
365
- ]:
366
  hx = pixel_values
367
 
368
  ## -------------Encoder-------------
@@ -452,15 +457,36 @@ class BASNetModel(PreTrainedModel):
452
  ## -------------Refine Module-------------
453
  dout = self.refunet(d1) # 256
454
 
455
- return (
456
- torch.sigmoid(dout),
457
- torch.sigmoid(d1),
458
- torch.sigmoid(d2),
459
- torch.sigmoid(d3),
460
- torch.sigmoid(d4),
461
- torch.sigmoid(d5),
462
- torch.sigmoid(d6),
463
- torch.sigmoid(db),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
464
  )
465
 
466
 
 
1
  import logging
2
+ from dataclasses import dataclass
3
+ from typing import Optional, Tuple, Union
4
 
5
  import torch
6
  import torch.nn as nn
7
  import torchvision
8
  from transformers.modeling_utils import PreTrainedModel
9
+ from transformers.utils import ModelOutput
10
 
11
  from .configuration_basnet import BASNetConfig
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
+ @dataclass
17
+ class BASNetModelOutput(ModelOutput):
18
+ dout: torch.Tensor
19
+ d1: Optional[torch.Tensor] = None
20
+ d2: Optional[torch.Tensor] = None
21
+ d3: Optional[torch.Tensor] = None
22
+ d4: Optional[torch.Tensor] = None
23
+ d5: Optional[torch.Tensor] = None
24
+ d6: Optional[torch.Tensor] = None
25
+ db: Optional[torch.Tensor] = None
26
+
27
+
28
  class RefUnet(nn.Module):
29
  def __init__(self, in_ch: int, inc_ch: int) -> None:
30
  super().__init__()
 
366
  self.post_init()
367
 
368
  def forward(
369
+ self, pixel_values: torch.Tensor, return_dict: Optional[bool] = None
370
+ ) -> Union[Tuple, BASNetModelOutput]:
 
 
 
 
 
 
 
 
 
371
  hx = pixel_values
372
 
373
  ## -------------Encoder-------------
 
457
  ## -------------Refine Module-------------
458
  dout = self.refunet(d1) # 256
459
 
460
+ dout_act = torch.sigmoid(dout)
461
+ d1_act = torch.sigmoid(d1)
462
+ d2_act = torch.sigmoid(d2)
463
+ d3_act = torch.sigmoid(d3)
464
+ d4_act = torch.sigmoid(d4)
465
+ d5_act = torch.sigmoid(d5)
466
+ d6_act = torch.sigmoid(d6)
467
+ db_act = torch.sigmoid(db)
468
+
469
+ if not return_dict:
470
+ return (
471
+ dout_act,
472
+ d1_act,
473
+ d2_act,
474
+ d3_act,
475
+ d4_act,
476
+ d5_act,
477
+ d6_act,
478
+ db_act,
479
+ )
480
+
481
+ return BASNetModelOutput(
482
+ dout=dout_act,
483
+ d1=d1_act,
484
+ d2=d2_act,
485
+ d3=d3_act,
486
+ d4=d4_act,
487
+ d5=d5_act,
488
+ d6=d6_act,
489
+ db=db_act,
490
  )
491
 
492