zhangshengdong commited on
Commit
ab4b0b8
·
1 Parent(s): d78d3dc

Upload smp_deeplabv3plus.py

Browse files
Files changed (1) hide show
  1. NETS/smp_deeplabv3plus.py +17 -0
NETS/smp_deeplabv3plus.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import segmentation_models_pytorch as smp
3
+
4
+
5
+ class Net(nn.Module):
6
+ def __init__(self, class_num, in_channels=4, encoder_name="resnet34"):
7
+ super(Net, self).__init__()
8
+ self.net = smp.deeplabv3.DeepLabV3Plus(
9
+ in_channels=in_channels,
10
+ classes=class_num,
11
+ encoder_name=encoder_name,
12
+ )
13
+
14
+ def forward(self, x):
15
+ x = self.net(x)
16
+
17
+ return x