|
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" |
|
|
|
import os |
|
|
|
import pandas as pd |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from scipy.io import loadmat |
|
from torch.nn.modules import BatchNorm2d |
|
|
|
from . import resnet |
|
from . import mobilenet |
|
|
|
|
|
# Number of semantic classes in the ADE20K segmentation dataset.
NUM_CLASS = 150

# Resolve the metadata files relative to this module so imports work
# regardless of the current working directory.
base_path = os.path.dirname(os.path.abspath(__file__))
colors_path = os.path.join(base_path, 'color150.mat')
classes_path = os.path.join(base_path, 'object150_info.csv')

# Per-class metadata for the 150 ADE20K classes:
#   colors  - RGB palette loaded from color150.mat (used for visualization)
#   classes - class info table read from object150_info.csv
segm_options = dict(colors=loadmat(colors_path)['colors'],
                    classes=pd.read_csv(classes_path),)
|
|
|
|
|
class NormalizeTensor:
    """Channel-wise normalization for batched image tensors.

    Mirrors :class:`torchvision.transforms.Normalize`, but expects a batch
    dimension, i.e. tensors of shape (B, C, H, W). By default the input is
    cloned first, so the caller's tensor is left untouched unless
    ``inplace=True`` is requested.
    """

    def __init__(self, mean, std, inplace=False):
        """
        Args:
            mean (sequence): per-channel means.
            std (sequence): per-channel standard deviations.
            inplace (bool, optional): normalize the input tensor itself
                instead of a copy.
        """
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, tensor):
        # Work on a copy unless the caller explicitly opted into mutation.
        target = tensor if self.inplace else tensor.clone()

        mean = torch.as_tensor(self.mean, dtype=target.dtype, device=target.device)
        std = torch.as_tensor(self.std, dtype=target.dtype, device=target.device)
        # Broadcast the (C,) statistics over the (B, C, H, W) layout.
        target.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
        return target
|
|
|
|
|
|
|
class ModelBuilder: |
|
|
|
@staticmethod |
|
def weights_init(m): |
|
classname = m.__class__.__name__ |
|
if classname.find('Conv') != -1: |
|
nn.init.kaiming_normal_(m.weight.data) |
|
elif classname.find('BatchNorm') != -1: |
|
m.weight.data.fill_(1.) |
|
m.bias.data.fill_(1e-4) |
|
|
|
@staticmethod |
|
def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''): |
|
pretrained = True if len(weights) == 0 else False |
|
arch = arch.lower() |
|
if arch == 'mobilenetv2dilated': |
|
orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained) |
|
net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8) |
|
elif arch == 'resnet18': |
|
orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) |
|
net_encoder = Resnet(orig_resnet) |
|
elif arch == 'resnet18dilated': |
|
orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained) |
|
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) |
|
elif arch == 'resnet50dilated': |
|
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) |
|
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) |
|
elif arch == 'resnet50': |
|
orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained) |
|
net_encoder = Resnet(orig_resnet) |
|
else: |
|
raise Exception('Architecture undefined!') |
|
|
|
|
|
|
|
if len(weights) > 0: |
|
print('Loading weights for net_encoder') |
|
net_encoder.load_state_dict( |
|
torch.load(weights, map_location=lambda storage, loc: storage), strict=False) |
|
return net_encoder |
|
|
|
@staticmethod |
|
def build_decoder(arch='ppm_deepsup', |
|
fc_dim=512, num_class=NUM_CLASS, |
|
weights='', use_softmax=False, drop_last_conv=False): |
|
arch = arch.lower() |
|
if arch == 'ppm_deepsup': |
|
net_decoder = PPMDeepsup( |
|
num_class=num_class, |
|
fc_dim=fc_dim, |
|
use_softmax=use_softmax, |
|
drop_last_conv=drop_last_conv) |
|
elif arch == 'c1_deepsup': |
|
net_decoder = C1DeepSup( |
|
num_class=num_class, |
|
fc_dim=fc_dim, |
|
use_softmax=use_softmax, |
|
drop_last_conv=drop_last_conv) |
|
else: |
|
raise Exception('Architecture undefined!') |
|
|
|
net_decoder.apply(ModelBuilder.weights_init) |
|
if len(weights) > 0: |
|
print('Loading weights for net_decoder') |
|
net_decoder.load_state_dict( |
|
torch.load(weights, map_location=lambda storage, loc: storage), strict=False) |
|
return net_decoder |
|
|
|
@staticmethod |
|
def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *arts, **kwargs): |
|
path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth') |
|
return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path, use_softmax=True, drop_last_conv=drop_last_conv) |
|
|
|
@staticmethod |
|
def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation, |
|
*arts, **kwargs): |
|
if segmentation: |
|
path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth') |
|
else: |
|
path = '' |
|
return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path) |
|
|
|
|
|
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    """3x3 convolution (padding 1, no bias) followed by batch norm and ReLU."""
    layers = (
        nn.Conv2d(in_planes, out_planes,
                  kernel_size=3, stride=stride, padding=1, bias=False),
        BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
    return nn.Sequential(*layers)
|
|
|
|
|
class SegmentationModule(nn.Module):
    """End-to-end ADE20K segmentation network (encoder + decoder pair).

    The decoder architecture and feature width are tied to ``arch_encoder``:
    'resnet50dilated' -> 'ppm_deepsup' (fc_dim 2048),
    'mobilenetv2dilated' -> 'c1_deepsup' (fc_dim 320).
    """

    def __init__(self,
                 weights_path,
                 num_classes=150,
                 arch_encoder="resnet50dilated",
                 drop_last_conv=False,
                 net_enc=None,
                 net_dec=None,
                 encode=None,
                 use_default_normalization=False,
                 return_feature_maps=False,
                 return_feature_maps_level=3,
                 return_feature_maps_only=True,
                 **kwargs,
                 ):
        """
        Args:
            weights_path: root directory containing the 'ade20k' checkpoints.
            num_classes: number of semantic classes (kept for caller compatibility).
            arch_encoder: encoder backbone; also determines the decoder.
            drop_last_conv: have the decoder return features instead of logits.
            net_enc / net_dec: pre-built encoder/decoder that bypass the factory.
            encode: stored as-is; opaque to this class.
            use_default_normalization: apply ImageNet normalization in predict().
            return_feature_maps: also return encoder feature maps.
            return_feature_maps_level: encoder stage (0..3) used by predict().
        """
        super().__init__()
        self.weights_path = weights_path
        self.drop_last_conv = drop_last_conv
        self.arch_encoder = arch_encoder
        # Decoder architecture and encoder output width follow the backbone.
        if self.arch_encoder == "resnet50dilated":
            self.arch_decoder = "ppm_deepsup"
            self.fc_dim = 2048
        elif self.arch_encoder == "mobilenetv2dilated":
            self.arch_decoder = "c1_deepsup"
            self.fc_dim = 320
        else:
            raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
        model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
                                    arch_decoder=self.arch_decoder,
                                    fc_dim=self.fc_dim,
                                    drop_last_conv=drop_last_conv,
                                    weights_path=self.weights_path)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = ModelBuilder.get_encoder(**model_builder_kwargs) if net_enc is None else net_enc
        self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec
        self.use_default_normalization = use_default_normalization
        # ImageNet channel statistics, as expected by the pretrained encoders.
        self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])

        self.encode = encode

        self.return_feature_maps = return_feature_maps

        assert 0 <= return_feature_maps_level <= 3
        self.return_feature_maps_level = return_feature_maps_level

    def normalize_input(self, tensor):
        """Apply ImageNet normalization; input values must already be in 0..1."""
        if tensor.min() < 0 or tensor.max() > 1:
            raise ValueError("Tensor should be 0..1 before using normalize_input")
        return self.default_normalization(tensor)

    @property
    def feature_maps_channels(self):
        # Encoder stages double their width per level: 256, 512, 1024, 2048.
        # NOTE(review): this matches a ResNet50 backbone; confirm before
        # relying on it with mobilenetv2dilated.
        return 256 * 2**(self.return_feature_maps_level)

    def forward(self, img_data, segSize=None):
        """Run encoder then decoder; returns (pred, fmaps) when feature maps are on."""
        if segSize is None:
            raise NotImplementedError("Please pass segSize param. By default: (300, 300)")

        fmaps = self.encoder(img_data, return_feature_maps=True)
        pred = self.decoder(fmaps, segSize=segSize)

        if self.return_feature_maps:
            return pred, fmaps

        return pred

    def multi_mask_from_multiclass(self, pred, classes):
        """Float mask of pixels whose predicted label is one of `classes`."""
        def isin(ar1, ar2):
            return (ar1[..., None] == ar2).any(-1).float()
        return isin(pred, torch.LongTensor(classes).to(self.device))

    @staticmethod
    def multi_mask_from_multiclass_probs(scores, classes):
        """Sum the per-class probability channels listed in `classes`.

        Returns a new tensor; `scores` is left unmodified.
        """
        res = None
        for c in classes:
            channel = scores[:, c]
            # Accumulate out-of-place: the original `res += channel` wrote
            # through the first channel *view* and corrupted `scores` itself.
            res = channel if res is None else res + channel
        return res

    def predict(self, tensor, imgSizes=(-1,),
                segSize=None):
        """Entry point for segmentation. Use this method instead of forward().

        Arguments:
            tensor {torch.Tensor} -- BCHW image batch

        Keyword Arguments:
            imgSizes {tuple or list} -- sizes the input is rescaled to before
                segmentation; -1 keeps the original size (the default).
                The original implementation averaged over
                (300, 375, 450, 525, 600).
            segSize -- output (H, W); defaults to the input's spatial size.
        """
        if segSize is None:
            segSize = (tensor.shape[2], tensor.shape[3])
        with torch.no_grad():
            if self.use_default_normalization:
                tensor = self.normalize_input(tensor)
            scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1]).to(self.device)
            features = torch.zeros(1, self.feature_maps_channels, segSize[0], segSize[1]).to(self.device)

            result = []
            for img_size in imgSizes:
                if img_size != -1:
                    img_data = F.interpolate(tensor.clone(), size=img_size)
                else:
                    img_data = tensor.clone()

                if self.return_feature_maps:
                    pred_current, fmaps = self.forward(img_data, segSize=segSize)
                else:
                    pred_current = self.forward(img_data, segSize=segSize)

                result.append(pred_current)
                # Average class scores (and optionally feature maps) over scales.
                scores = scores + pred_current / len(imgSizes)

                if self.return_feature_maps:
                    features = features + F.interpolate(fmaps[self.return_feature_maps_level], size=segSize) / len(imgSizes)

            _, pred = torch.max(scores, dim=1)

            if self.return_feature_maps:
                # NOTE(review): in feature-map mode only the averaged features
                # are returned; per-scale predictions are discarded.
                return features

            return pred, result

    def get_edges(self, t):
        """Half-precision mask marking label-boundary pixels of `t` (BCHW).

        NOTE(review): allocates via torch.cuda.ByteTensor, so a CUDA build is
        required — confirm callers only use this on GPU.
        """
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        # Mark a pixel wherever it differs from a horizontal or vertical neighbor.
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])

        # The original had an unreachable `return edge.float()` behind
        # `if True:`; half precision is the effective behavior, kept here.
        return edge.half()
|
|
|
|
|
|
|
class PPMDeepsup(nn.Module):
    """Pyramid Pooling Module decoder head with a deep-supervision branch."""

    def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 drop_last_conv=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        # One pooling branch per scale: adaptive pool -> 1x1 conv -> BN -> ReLU.
        self.ppm = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True),
            )
            for scale in pool_scales
        ])
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1),
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        height, width = conv5.size(2), conv5.size(3)

        # Concatenate the raw features with every upsampled pooling branch.
        pooled = [conv5]
        for branch in self.ppm:
            pooled.append(F.interpolate(
                branch(conv5), (height, width),
                mode='bilinear', align_corners=False))
        fused = torch.cat(pooled, 1)

        if self.drop_last_conv:
            # Expose the fused pyramid features instead of class scores.
            return fused

        x = self.conv_last(fused)

        if self.use_softmax:  # inference: upsample to segSize, probabilities
            x = F.interpolate(x, size=segSize, mode='bilinear', align_corners=False)
            return F.softmax(x, dim=1)

        # Training: auxiliary deep-supervision head on the penultimate stage.
        deepsup = self.cbr_deepsup(conv_out[-2])
        deepsup = self.dropout_deepsup(deepsup)
        deepsup = self.conv_last_deepsup(deepsup)

        x = F.log_softmax(x, dim=1)
        deepsup = F.log_softmax(deepsup, dim=1)
        return (x, deepsup)
|
|
|
|
|
class Resnet(nn.Module):
    """Plain ResNet encoder wrapper exposing each stage's feature map."""

    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()

        # Adopt the already-constructed layers of the source network;
        # attribute assignment registers them as submodules.
        for name in ('conv1', 'bn1', 'relu1',
                     'conv2', 'bn2', 'relu2',
                     'conv3', 'bn3', 'relu3',
                     'maxpool',
                     'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, name, getattr(orig_resnet, name))

    def forward(self, x, return_feature_maps=False):
        # Stem: three conv-bn-relu stages followed by max pooling.
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        conv_out = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            conv_out.append(x)

        return conv_out if return_feature_maps else [x]
|
|
|
|
|
class ResnetDilated(nn.Module):
    """ResNet encoder with strides converted to dilations (DeepLab-style)."""

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()
        from functools import partial

        # Convert the striding of the deeper stages into dilation so the
        # output stride becomes `dilate_scale` instead of 32.
        if dilate_scale == 8:
            orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2))

        # Adopt the (now modified) layers of the source network.
        for name in ('conv1', 'bn1', 'relu1',
                     'conv2', 'bn2', 'relu2',
                     'conv3', 'bn3', 'relu3',
                     'maxpool',
                     'layer1', 'layer2', 'layer3', 'layer4'):
            setattr(self, name, getattr(orig_resnet, name))

    def _nostride_dilate(self, m, dilate):
        # Only convolution layers are rewritten.
        if m.__class__.__name__.find('Conv') == -1:
            return

        if m.stride == (2, 2):
            # Downsampling conv: drop the stride, halve the dilation.
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            # Regular 3x3 conv: apply the full dilation.
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        # Stem: three conv-bn-relu stages followed by max pooling.
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        conv_out = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            conv_out.append(x)

        return conv_out if return_feature_maps else [x]
|
|
|
class MobileNetV2Dilated(nn.Module):
    """MobileNetV2 encoder with strides converted to dilations."""

    def __init__(self, orig_net, dilate_scale=8):
        super(MobileNetV2Dilated, self).__init__()
        from functools import partial

        # Drop the final block of MobileNetV2's feature extractor.
        self.features = orig_net.features[:-1]

        self.total_idx = len(self.features)
        # Indices of the blocks that downsample; forward() taps outputs here.
        self.down_idx = [2, 4, 7, 14]

        # Replace striding by dilation in the deeper blocks so the output
        # stride becomes `dilate_scale`.
        if dilate_scale == 8:
            dilation_plan = [(self.down_idx[-2], self.down_idx[-1], 2),
                             (self.down_idx[-1], self.total_idx, 4)]
        elif dilate_scale == 16:
            dilation_plan = [(self.down_idx[-1], self.total_idx, 2)]
        else:
            dilation_plan = []
        for start, stop, dilate in dilation_plan:
            for i in range(start, stop):
                self.features[i].apply(partial(self._nostride_dilate, dilate=dilate))

    def _nostride_dilate(self, m, dilate):
        # Only convolution layers are rewritten.
        if m.__class__.__name__.find('Conv') == -1:
            return

        if m.stride == (2, 2):
            # Downsampling conv: drop the stride, halve the dilation.
            m.stride = (1, 1)
            if m.kernel_size == (3, 3):
                m.dilation = (dilate // 2, dilate // 2)
                m.padding = (dilate // 2, dilate // 2)
        elif m.kernel_size == (3, 3):
            # Regular 3x3 conv: apply the full dilation.
            m.dilation = (dilate, dilate)
            m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        if not return_feature_maps:
            return [self.features(x)]

        # Collect the output of every downsampling block plus the final one.
        conv_out = []
        for i, block in enumerate(self.features):
            x = block(x)
            if i in self.down_idx:
                conv_out.append(x)
        conv_out.append(x)
        return conv_out
|
|
|
|
|
|
|
class C1DeepSup(nn.Module):
    """Single-conv (C1) decoder head with a deep-supervision branch."""

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
        super(C1DeepSup, self).__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # 1x1 classifiers for the main and auxiliary heads.
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        features = self.cbr(conv_out[-1])

        if self.drop_last_conv:
            # Expose the decoder features instead of class scores.
            return features

        x = self.conv_last(features)

        if self.use_softmax:  # inference: upsample to segSize, probabilities
            x = F.interpolate(x, size=segSize, mode='bilinear', align_corners=False)
            return F.softmax(x, dim=1)

        # Training: auxiliary deep-supervision head on the penultimate stage.
        deepsup = self.cbr_deepsup(conv_out[-2])
        deepsup = self.conv_last_deepsup(deepsup)

        return (F.log_softmax(x, dim=1), F.log_softmax(deepsup, dim=1))
|
|
|
|
|
|
|
class C1(nn.Module):
    """Single-conv (C1) decoder head without deep supervision."""

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        super(C1, self).__init__()
        self.use_softmax = use_softmax

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)

        # 1x1 classifier producing per-class scores.
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        x = self.conv_last(self.cbr(conv_out[-1]))

        if self.use_softmax:  # inference: upsample to segSize, probabilities
            x = F.interpolate(x, size=segSize, mode='bilinear', align_corners=False)
            return F.softmax(x, dim=1)

        # Training: log-probabilities at the feature resolution.
        return F.log_softmax(x, dim=1)
|
|
|
|
|
|
|
class PPM(nn.Module):
    """Pyramid Pooling Module decoder head (no deep supervision)."""

    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super(PPM, self).__init__()
        self.use_softmax = use_softmax

        # One pooling branch per scale: adaptive pool -> 1x1 conv -> BN -> ReLU.
        self.ppm = nn.ModuleList([
            nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True),
            )
            for scale in pool_scales
        ])

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1),
        )

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        height, width = conv5.size(2), conv5.size(3)

        # Concatenate the raw features with every upsampled pooling branch.
        pooled = [conv5]
        for branch in self.ppm:
            pooled.append(nn.functional.interpolate(
                branch(conv5), (height, width),
                mode='bilinear', align_corners=False))
        x = self.conv_last(torch.cat(pooled, 1))

        if self.use_softmax:  # inference: upsample to segSize, probabilities
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
        else:  # training: log-probabilities at the feature resolution
            x = nn.functional.log_softmax(x, dim=1)
        return x
|
|