# Spaces: Sleeping
from torch import nn | |
import segmentation_models_pytorch as smp | |
class Net(nn.Module):
    """Segmentation network wrapping ``segmentation_models_pytorch.DeepLabV3Plus``.

    A thin ``nn.Module`` wrapper so callers can build a DeepLabV3+ model
    from three arguments.

    Args:
        class_num: Number of output classes; forwarded to smp as ``classes``.
        in_channels: Number of input image channels. Defaults to 4.
        encoder_name: Name of the smp encoder backbone. Defaults to
            ``"resnet34"``.
    """

    def __init__(self, class_num, in_channels=4, encoder_name="resnet34"):
        super().__init__()
        # Use the public top-level export rather than the internal
        # ``smp.deeplabv3`` submodule path. That path is a package-layout
        # detail and may not exist in other segmentation_models_pytorch
        # releases.
        self.net = smp.DeepLabV3Plus(
            in_channels=in_channels,
            classes=class_num,
            encoder_name=encoder_name,
        )

    def forward(self, x):
        """Run the wrapped DeepLabV3+ model on ``x`` and return its output.

        NOTE(review): ``x`` presumably has shape (N, in_channels, H, W), and
        the output presumably has shape (N, class_num, H, W) — confirm
        against the smp documentation and the callers.
        """
        return self.net(x)