import segmentation_models_pytorch as smp
import torch
from torch import nn


class Net(nn.Module):
    """Thin ``nn.Module`` wrapper around smp's ``DeepLabV3Plus`` model.

    Args:
        class_num: Number of output classes, forwarded as ``classes``.
        in_channels: Number of channels of the input tensor (default 4).
        encoder_name: Name of the smp encoder backbone (default ``"resnet34"``).
        **kwargs: Any additional keyword arguments accepted by
            ``smp.DeepLabV3Plus`` (e.g. ``encoder_weights``), forwarded as-is.
            NOTE(review): smp's own defaults apply when these are omitted —
            per smp docs ``encoder_weights`` defaults to ``"imagenet"``,
            which may trigger a weight download; confirm this is intended.
    """

    def __init__(
        self,
        class_num: int,
        in_channels: int = 4,
        encoder_name: str = "resnet34",
        **kwargs,
    ) -> None:
        super().__init__()
        # Use the top-level export: the ``smp.deeplabv3`` submodule path is an
        # internal layout detail that moved in later smp releases.
        self.net = smp.DeepLabV3Plus(
            in_channels=in_channels,
            classes=class_num,
            encoder_name=encoder_name,
            **kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped DeepLabV3+ model on ``x`` and return its output.

        Assumes ``x`` is a (batch, in_channels, H, W) image tensor — TODO
        confirm against callers. Per smp docs the output is presumably a
        (batch, class_num, H, W) tensor of unnormalized scores.
        """
        return self.net(x)