|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch.nn import SyncBatchNorm as BatchNorm2d |
|
import re |
|
import os
import sys
|
from six import moves |
|
|
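# ModuleParallel applies one shared module independently to each tensor in a
# list of parallel inputs (one list entry per parallel branch).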
|
class ModuleParallel(nn.Module): |
|
def __init__(self, module): |
|
super(ModuleParallel, self).__init__() |
|
self.module = module |
|
|
|
def forward(self, x_parallel): |
|
return [self.module(x) for x in x_parallel] |
|
|
|
|
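# BatchNorm2dParallel holds a separate BatchNorm2d (here SyncBatchNorm) per
# branch, so each branch keeps its own running statistics and affine weights.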
|
class BatchNorm2dParallel(nn.Module): |
|
def __init__(self, num_features, num_parallel): |
|
super(BatchNorm2dParallel, self).__init__() |
|
for i in range(num_parallel): |
|
setattr(self, 'bn_' + str(i), BatchNorm2d(num_features)) |
|
|
|
def forward(self, x_parallel): |
|
return [getattr(self, 'bn_' + str(i))(x) for i, x in enumerate(x_parallel)] |
|
|
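# Convenience wrapper around the refinenet() factory defined below: builds a
# single-branch model (num_parallel=1), loads ImageNet-pretrained backbone
# weights, and upsamples the logits back to the input resolution.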
|
class MyRefineNet(nn.Module): |
|
def __init__(self, num_layers, num_classes): |
|
super(MyRefineNet, self).__init__() |
|
self.model = refinenet(num_layers, num_classes, 1, None) |
|
self.model = model_init(self.model, num_layers, 1, imagenet=True) |
|
    def forward(self, data, get_sup_loss=False, gt=None, criterion=None):
|
b, c, h, w = data[0].shape |
|
pred = self.model(data) |
|
pred = F.interpolate(pred[0], size=(h, w), mode='bilinear', align_corners=True) |
|
        if self.training and get_sup_loss:

            return pred, self.get_sup_loss(pred, gt, criterion)

        return pred
|
|
|
def get_params(self): |
|
        enc_params, dec_params = [], []
|
for name, param in self.model.named_parameters(): |
|
            if re.match('.*conv1.*|.*bn1.*|.*layer.*', name):
|
enc_params.append(param) |
|
else: |
|
dec_params.append(param) |
|
return enc_params, dec_params |
|
|
|
def get_sup_loss(self, pred, gt, criterion): |
|
pred = pred[:gt.shape[0]] |
|
return criterion(pred, gt) |
|
|
|
"""RefineNet-LightWeight |
|
|
|
RefineNet-LightWeight PyTorch for non-commercial purposes
|
|
|
Copyright (c) 2018, Vladimir Nekrasov ([email protected]) |
|
All rights reserved. |
|
|
|
Redistribution and use in source and binary forms, with or without |
|
modification, are permitted provided that the following conditions are met: |
|
|
|
* Redistributions of source code must retain the above copyright notice, this |
|
list of conditions and the following disclaimer. |
|
|
|
* Redistributions in binary form must reproduce the above copyright notice, |
|
this list of conditions and the following disclaimer in the documentation |
|
and/or other materials provided with the distribution. |
|
|
|
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
|
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
|
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE |
|
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE |
|
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL |
|
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR |
|
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER |
|
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, |
|
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE |
|
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. |
|
""" |
|
|
|
|
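# Pretrained checkpoints: the VOC model released with Light-Weight RefineNet
# and the standard torchvision ImageNet ResNet weights.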
|
models_urls = { |
|
'101_voc' : 'https://cloudstor.aarnet.edu.au/plus/s/Owmttk9bdPROwc6/download', |
|
|
|
'18_imagenet' : 'https://download.pytorch.org/models/resnet18-f37072fd.pth', |
|
'50_imagenet' : 'https://download.pytorch.org/models/resnet50-19c8e357.pth', |
|
'101_imagenet': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', |
|
'152_imagenet': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', |
|
} |
|
|
|
bottleneck_idx = 0 |
|
save_idx = 0 |
|
|
|
|
|
def conv3x3(in_planes, out_planes, stride=1, bias=False): |
|
"3x3 convolution with padding" |
|
return ModuleParallel(nn.Conv2d(in_planes, out_planes, kernel_size=3, |
|
stride=stride, padding=1, bias=bias)) |
|
|
|
|
|
def conv1x1(in_planes, out_planes, stride=1, bias=False): |
|
"1x1 convolution" |
|
return ModuleParallel(nn.Conv2d(in_planes, out_planes, kernel_size=1, |
|
stride=stride, padding=0, bias=bias)) |
|
|
|
|
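# Chained Residual Pooling (CRP): a cascade of 5x5 max-pool + 3x3 conv stages
# whose outputs are summed back onto the block input, per parallel branch.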
|
class CRPBlock(nn.Module): |
|
def __init__(self, in_planes, out_planes, num_stages, num_parallel): |
|
super(CRPBlock, self).__init__() |
|
for i in range(num_stages): |
|
setattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'), |
|
conv3x3(in_planes if (i == 0) else out_planes, out_planes)) |
|
self.stride = 1 |
|
self.num_stages = num_stages |
|
self.num_parallel = num_parallel |
|
self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=5, stride=1, padding=2)) |
|
|
|
def forward(self, x): |
|
top = x |
|
for i in range(self.num_stages): |
|
top = self.maxpool(top) |
|
top = getattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'))(top) |
|
x = [x[l] + top[l] for l in range(self.num_parallel)] |
|
return x |
|
|
|
|
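# Residual Conv Unit (RCU): num_blocks residual blocks, each a chain of
# num_stages (ReLU + 3x3 conv) stages; stages_suffixes names the stage convs.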
|
stages_suffixes = {0: '_conv', 1: '_conv_relu_varout_dimred'}
|
|
|
class RCUBlock(nn.Module): |
|
def __init__(self, in_planes, out_planes, num_blocks, num_stages, num_parallel): |
|
super(RCUBlock, self).__init__() |
|
for i in range(num_blocks): |
|
for j in range(num_stages): |
|
setattr(self, '{}{}'.format(i + 1, stages_suffixes[j]), |
|
conv3x3(in_planes if (i == 0) and (j == 0) else out_planes, |
|
out_planes, bias=(j == 0))) |
|
self.stride = 1 |
|
self.num_blocks = num_blocks |
|
self.num_stages = num_stages |
|
self.num_parallel = num_parallel |
|
self.relu = ModuleParallel(nn.ReLU(inplace=True)) |
|
|
|
def forward(self, x): |
|
for i in range(self.num_blocks): |
|
residual = x |
|
for j in range(self.num_stages): |
|
x = self.relu(x) |
|
x = getattr(self, '{}{}'.format(i + 1, stages_suffixes[j]))(x) |
|
x = [x[l] + residual[l] for l in range(self.num_parallel)] |
|
return x |
|
|
|
|
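# ResNet BasicBlock (two 3x3 convs) operating on lists of parallel tensors;
# bn2_list collects the per-branch bn2 modules for external access (it is not
# used elsewhere in this file).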
|
class BasicBlock(nn.Module): |
|
expansion = 1 |
|
def __init__(self, inplanes, planes, num_parallel, bn_threshold, stride=1, downsample=None): |
|
super(BasicBlock, self).__init__() |
|
self.conv1 = conv3x3(inplanes, planes, stride) |
|
self.bn1 = BatchNorm2dParallel(planes, num_parallel) |
|
self.relu = ModuleParallel(nn.ReLU(inplace=True)) |
|
self.conv2 = conv3x3(planes, planes) |
|
self.bn2 = BatchNorm2dParallel(planes, num_parallel) |
|
self.num_parallel = num_parallel |
|
self.downsample = downsample |
|
self.stride = stride |
|
|
|
self.bn_threshold = bn_threshold |
|
self.bn2_list = [] |
|
for module in self.bn2.modules(): |
|
if isinstance(module, BatchNorm2d): |
|
self.bn2_list.append(module) |
|
|
|
def forward(self, x): |
|
residual = x |
|
|
|
out = self.conv1(x) |
|
out = self.bn1(out) |
|
out = self.relu(out) |
|
|
|
out = self.conv2(out) |
|
out = self.bn2(out) |
|
|
|
if self.downsample is not None: |
|
residual = self.downsample(x) |
|
|
|
out = [out[l] + residual[l] for l in range(self.num_parallel)] |
|
out = self.relu(out) |
|
|
|
return out |
|
|
|
|
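# ResNet Bottleneck (1x1 -> 3x3 -> 1x1 convs, 4x channel expansion), using the
# same list-of-parallel-tensors convention as BasicBlock.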
|
class Bottleneck(nn.Module): |
|
expansion = 4 |
|
def __init__(self, inplanes, planes, num_parallel, bn_threshold, stride=1, downsample=None): |
|
super(Bottleneck, self).__init__() |
|
self.conv1 = conv1x1(inplanes, planes) |
|
self.bn1 = BatchNorm2dParallel(planes, num_parallel) |
|
self.conv2 = conv3x3(planes, planes, stride=stride) |
|
self.bn2 = BatchNorm2dParallel(planes, num_parallel) |
|
self.conv3 = conv1x1(planes, planes * 4) |
|
self.bn3 = BatchNorm2dParallel(planes * 4, num_parallel) |
|
self.relu = ModuleParallel(nn.ReLU(inplace=True)) |
|
self.num_parallel = num_parallel |
|
self.downsample = downsample |
|
self.stride = stride |
|
|
|
self.bn_threshold = bn_threshold |
|
self.bn2_list = [] |
|
for module in self.bn2.modules(): |
|
if isinstance(module, BatchNorm2d): |
|
self.bn2_list.append(module) |
|
|
|
def forward(self, x): |
|
        residual = x

        out = self.conv1(x)
|
out = self.bn1(out) |
|
out = self.relu(out) |
|
|
|
out = self.conv2(out) |
|
out = self.bn2(out) |
|
out = self.relu(out) |
|
|
|
out = self.conv3(out) |
|
out = self.bn3(out) |
|
|
|
if self.downsample is not None: |
|
residual = self.downsample(x) |
|
|
|
out = [out[l] + residual[l] for l in range(self.num_parallel)] |
|
out = self.relu(out) |
|
|
|
return out |
|
|
|
|
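# Light-Weight RefineNet decoder on top of a ResNet-50/101/152 encoder: each
# encoder stage output is channel-reduced, refined by RCU blocks, fused with
# the upsampled coarser path, and passed through CRP pooling before the
# final classifier convolution.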
|
class RefineNet(nn.Module): |
|
def __init__(self, block, layers, num_parallel, num_classes=21, bn_threshold=2e-2): |
|
self.inplanes = 64 |
|
self.num_parallel = num_parallel |
|
super(RefineNet, self).__init__() |
|
self.dropout = ModuleParallel(nn.Dropout(p=0.5)) |
|
self.conv1 = ModuleParallel(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)) |
|
self.bn1 = BatchNorm2dParallel(64, num_parallel) |
|
self.relu = ModuleParallel(nn.ReLU(inplace=True)) |
|
self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) |
|
self.layer1 = self._make_layer(block, 64, layers[0], bn_threshold) |
|
self.layer2 = self._make_layer(block, 128, layers[1], bn_threshold, stride=2) |
|
self.layer3 = self._make_layer(block, 256, layers[2], bn_threshold, stride=2) |
|
self.layer4 = self._make_layer(block, 512, layers[3], bn_threshold, stride=2) |
|
|
|
self.p_ims1d2_outl1_dimred = conv3x3(2048, 512) |
|
self.adapt_stage1_b = self._make_rcu(512, 512, 2, 2) |
|
self.mflow_conv_g1_pool = self._make_crp(512, 512, 4) |
|
self.mflow_conv_g1_b = self._make_rcu(512, 512, 3, 2) |
|
self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(512, 256) |
|
|
|
self.p_ims1d2_outl2_dimred = conv3x3(1024, 256) |
|
self.adapt_stage2_b = self._make_rcu(256, 256, 2, 2) |
|
self.adapt_stage2_b2_joint_varout_dimred = conv3x3(256, 256) |
|
self.mflow_conv_g2_pool = self._make_crp(256, 256, 4) |
|
self.mflow_conv_g2_b = self._make_rcu(256, 256, 3, 2) |
|
self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(256, 256) |
|
|
|
self.p_ims1d2_outl3_dimred = conv3x3(512, 256) |
|
self.adapt_stage3_b = self._make_rcu(256, 256, 2, 2) |
|
self.adapt_stage3_b2_joint_varout_dimred = conv3x3(256, 256) |
|
self.mflow_conv_g3_pool = self._make_crp(256, 256, 4) |
|
self.mflow_conv_g3_b = self._make_rcu(256, 256, 3, 2) |
|
self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(256, 256) |
|
|
|
self.p_ims1d2_outl4_dimred = conv3x3(256, 256) |
|
self.adapt_stage4_b = self._make_rcu(256, 256, 2, 2) |
|
self.adapt_stage4_b2_joint_varout_dimred = conv3x3(256, 256) |
|
self.mflow_conv_g4_pool = self._make_crp(256, 256, 4) |
|
self.mflow_conv_g4_b = self._make_rcu(256, 256, 3, 2) |
|
|
|
self.clf_conv = conv3x3(256, num_classes, bias=True) |
|
|
|
def _make_crp(self, in_planes, out_planes, num_stages): |
|
layers = [CRPBlock(in_planes, out_planes, num_stages, self.num_parallel)] |
|
return nn.Sequential(*layers) |
|
|
|
def _make_rcu(self, in_planes, out_planes, num_blocks, num_stages): |
|
layers = [RCUBlock(in_planes, out_planes, num_blocks, num_stages, self.num_parallel)] |
|
return nn.Sequential(*layers) |
|
|
|
def _make_layer(self, block, planes, num_blocks, bn_threshold, stride=1): |
|
downsample = None |
|
if stride != 1 or self.inplanes != planes * block.expansion: |
|
downsample = nn.Sequential( |
|
conv1x1(self.inplanes, planes * block.expansion, stride=stride), |
|
BatchNorm2dParallel(planes * block.expansion, self.num_parallel) |
|
) |
|
|
|
layers = [] |
|
layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold, stride, downsample)) |
|
self.inplanes = planes * block.expansion |
|
for i in range(1, num_blocks): |
|
layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold)) |
|
|
|
return nn.Sequential(*layers) |
|
|
|
def forward(self, x): |
|
x = self.conv1(x) |
|
x = self.bn1(x) |
|
x = self.relu(x) |
|
x = self.maxpool(x) |
|
|
|
l1 = self.layer1(x) |
|
l2 = self.layer2(l1) |
|
l3 = self.layer3(l2) |
|
l4 = self.layer4(l3) |
|
|
|
l4 = self.dropout(l4) |
|
l3 = self.dropout(l3) |
|
|
|
x4 = self.p_ims1d2_outl1_dimred(l4) |
|
x4 = self.adapt_stage1_b(x4) |
|
x4 = self.relu(x4) |
|
x4 = self.mflow_conv_g1_pool(x4) |
|
x4 = self.mflow_conv_g1_b(x4) |
|
x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) |
|
x4 = [nn.Upsample(size=l3[0].size()[2:], mode='bilinear', align_corners=True)(x4_) for x4_ in x4] |
|
|
|
x3 = self.p_ims1d2_outl2_dimred(l3) |
|
x3 = self.adapt_stage2_b(x3) |
|
x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) |
|
x3 = [x3[l] + x4[l] for l in range(self.num_parallel)] |
|
x3 = self.relu(x3) |
|
x3 = self.mflow_conv_g2_pool(x3) |
|
x3 = self.mflow_conv_g2_b(x3) |
|
x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) |
|
x3 = [nn.Upsample(size=l2[0].size()[2:], mode='bilinear', align_corners=True)(x3_) for x3_ in x3] |
|
|
|
x2 = self.p_ims1d2_outl3_dimred(l2) |
|
x2 = self.adapt_stage3_b(x2) |
|
x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) |
|
x2 = [x2[l] + x3[l] for l in range(self.num_parallel)] |
|
x2 = self.relu(x2) |
|
x2 = self.mflow_conv_g3_pool(x2) |
|
x2 = self.mflow_conv_g3_b(x2) |
|
x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) |
|
x2 = [nn.Upsample(size=l1[0].size()[2:], mode='bilinear', align_corners=True)(x2_) for x2_ in x2] |
|
|
|
x1 = self.p_ims1d2_outl4_dimred(l1) |
|
x1 = self.adapt_stage4_b(x1) |
|
x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) |
|
x1 = [x1[l] + x2[l] for l in range(self.num_parallel)] |
|
x1 = self.relu(x1) |
|
x1 = self.mflow_conv_g4_pool(x1) |
|
x1 = self.mflow_conv_g4_b(x1) |
|
x1 = self.dropout(x1) |
|
|
|
out = self.clf_conv(x1) |
|
return out |
|
|
|
|
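# Same architecture as RefineNet above, with channel widths matched to the
# ResNet-18 encoder (512/256/128/64 instead of 2048/1024/512/256).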
|
class RefineNet_Resnet18(nn.Module): |
|
def __init__(self, block, layers, num_parallel, num_classes=21, bn_threshold=2e-2): |
|
self.inplanes = 64 |
|
self.num_parallel = num_parallel |
|
super(RefineNet_Resnet18, self).__init__() |
|
self.dropout = ModuleParallel(nn.Dropout(p=0.5)) |
|
self.conv1 = ModuleParallel(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)) |
|
self.bn1 = BatchNorm2dParallel(64, num_parallel) |
|
self.relu = ModuleParallel(nn.ReLU(inplace=True)) |
|
self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) |
|
self.layer1 = self._make_layer(block, 64, layers[0], bn_threshold) |
|
self.layer2 = self._make_layer(block, 128, layers[1], bn_threshold, stride=2) |
|
self.layer3 = self._make_layer(block, 256, layers[2], bn_threshold, stride=2) |
|
self.layer4 = self._make_layer(block, 512, layers[3], bn_threshold, stride=2) |
|
|
|
self.p_ims1d2_outl1_dimred = conv3x3(512, 256) |
|
self.adapt_stage1_b = self._make_rcu(256, 256, 2, 2) |
|
self.mflow_conv_g1_pool = self._make_crp(256, 256, 4) |
|
self.mflow_conv_g1_b = self._make_rcu(256, 256, 3, 2) |
|
self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(256, 64) |
|
|
|
self.p_ims1d2_outl2_dimred = conv3x3(256, 64) |
|
self.adapt_stage2_b = self._make_rcu(64, 64, 2, 2) |
|
self.adapt_stage2_b2_joint_varout_dimred = conv3x3(64, 64) |
|
self.mflow_conv_g2_pool = self._make_crp(64, 64, 4) |
|
self.mflow_conv_g2_b = self._make_rcu(64, 64, 3, 2) |
|
self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(64, 64) |
|
|
|
self.p_ims1d2_outl3_dimred = conv3x3(128, 64) |
|
self.adapt_stage3_b = self._make_rcu(64, 64, 2, 2) |
|
self.adapt_stage3_b2_joint_varout_dimred = conv3x3(64, 64) |
|
self.mflow_conv_g3_pool = self._make_crp(64, 64, 4) |
|
self.mflow_conv_g3_b = self._make_rcu(64, 64, 3, 2) |
|
self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(64, 64) |
|
|
|
self.p_ims1d2_outl4_dimred = conv3x3(64, 64) |
|
self.adapt_stage4_b = self._make_rcu(64, 64, 2, 2) |
|
self.adapt_stage4_b2_joint_varout_dimred = conv3x3(64, 64) |
|
self.mflow_conv_g4_pool = self._make_crp(64, 64, 4) |
|
self.mflow_conv_g4_b = self._make_rcu(64, 64, 3, 2) |
|
|
|
self.clf_conv = conv3x3(64, num_classes, bias=True) |
|
|
|
def _make_crp(self, in_planes, out_planes, num_stages): |
|
layers = [CRPBlock(in_planes, out_planes, num_stages, self.num_parallel)] |
|
return nn.Sequential(*layers) |
|
|
|
def _make_rcu(self, in_planes, out_planes, num_blocks, num_stages): |
|
layers = [RCUBlock(in_planes, out_planes, num_blocks, num_stages, self.num_parallel)] |
|
return nn.Sequential(*layers) |
|
|
|
def _make_layer(self, block, planes, num_blocks, bn_threshold, stride=1): |
|
downsample = None |
|
if stride != 1 or self.inplanes != planes * block.expansion: |
|
downsample = nn.Sequential( |
|
conv1x1(self.inplanes, planes * block.expansion, stride=stride), |
|
BatchNorm2dParallel(planes * block.expansion, self.num_parallel) |
|
) |
|
|
|
layers = [] |
|
layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold, stride, downsample)) |
|
self.inplanes = planes * block.expansion |
|
for i in range(1, num_blocks): |
|
layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold)) |
|
|
|
return nn.Sequential(*layers) |
|
|
|
def forward(self, x): |
|
x = self.conv1(x) |
|
x = self.bn1(x) |
|
x = self.relu(x) |
|
x = self.maxpool(x) |
|
|
|
l1 = self.layer1(x) |
|
l2 = self.layer2(l1) |
|
l3 = self.layer3(l2) |
|
l4 = self.layer4(l3) |
|
|
|
l4 = self.dropout(l4) |
|
l3 = self.dropout(l3) |
|
|
|
x4 = self.p_ims1d2_outl1_dimred(l4) |
|
x4 = self.adapt_stage1_b(x4) |
|
x4 = self.relu(x4) |
|
x4 = self.mflow_conv_g1_pool(x4) |
|
x4 = self.mflow_conv_g1_b(x4) |
|
x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) |
|
x4 = [nn.Upsample(size=l3[0].size()[2:], mode='bilinear', align_corners=True)(x4_) for x4_ in x4] |
|
|
|
x3 = self.p_ims1d2_outl2_dimred(l3) |
|
x3 = self.adapt_stage2_b(x3) |
|
x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) |
|
x3 = [x3[l] + x4[l] for l in range(self.num_parallel)] |
|
x3 = self.relu(x3) |
|
x3 = self.mflow_conv_g2_pool(x3) |
|
x3 = self.mflow_conv_g2_b(x3) |
|
x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) |
|
x3 = [nn.Upsample(size=l2[0].size()[2:], mode='bilinear', align_corners=True)(x3_) for x3_ in x3] |
|
|
|
x2 = self.p_ims1d2_outl3_dimred(l2) |
|
x2 = self.adapt_stage3_b(x2) |
|
x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) |
|
x2 = [x2[l] + x3[l] for l in range(self.num_parallel)] |
|
x2 = self.relu(x2) |
|
x2 = self.mflow_conv_g3_pool(x2) |
|
x2 = self.mflow_conv_g3_b(x2) |
|
x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) |
|
x2 = [nn.Upsample(size=l1[0].size()[2:], mode='bilinear', align_corners=True)(x2_) for x2_ in x2] |
|
|
|
x1 = self.p_ims1d2_outl4_dimred(l1) |
|
x1 = self.adapt_stage4_b(x1) |
|
x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) |
|
x1 = [x1[l] + x2[l] for l in range(self.num_parallel)] |
|
x1 = self.relu(x1) |
|
x1 = self.mflow_conv_g4_pool(x1) |
|
x1 = self.mflow_conv_g4_b(x1) |
|
x1 = self.dropout(x1) |
|
|
|
out = self.clf_conv(x1) |
|
|
|
return out |
|
|
|
|
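# Factory: maps an encoder depth (18/50/101/152) to its block type and layer
# configuration and instantiates the matching RefineNet variant.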
|
def refinenet(num_layers, num_classes, num_parallel, bn_threshold): |
|
    refinenet_class = RefineNet
|
if int(num_layers) == 18: |
|
layers = [2, 2, 2, 2] |
|
block = BasicBlock |
|
        refinenet_class = RefineNet_Resnet18
|
elif int(num_layers) == 50: |
|
layers = [3, 4, 6, 3] |
|
block = Bottleneck |
|
elif int(num_layers) == 101: |
|
layers = [3, 4, 23, 3] |
|
block = Bottleneck |
|
elif int(num_layers) == 152: |
|
layers = [3, 8, 36, 3] |
|
block = Bottleneck |
|
    else:

        raise ValueError('invalid num_layers: {}'.format(num_layers))
|
|
|
    model = refinenet_class(block, layers, num_parallel, num_classes, bn_threshold)
|
return model |
|
|
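# Downloads a checkpoint into the torch model cache (TORCH_MODEL_ZOO or
# $TORCH_HOME/models) if it is not already present, then loads it.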
|
def maybe_download(model_name, model_url, model_dir=None, map_location=None): |
|
if model_dir is None: |
|
torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) |
|
model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) |
|
if not os.path.exists(model_dir): |
|
os.makedirs(model_dir) |
|
filename = '{}.pth.tar'.format(model_name) |
|
cached_file = os.path.join(model_dir, filename) |
|
if not os.path.exists(cached_file): |
|
url = model_url |
|
sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) |
|
moves.urllib.request.urlretrieve(url, cached_file) |
|
return torch.load(cached_file, map_location=map_location) |
|
|
|
def model_init(model, num_layers, num_parallel, imagenet=False): |
|
if imagenet: |
|
key = str(num_layers) + '_imagenet' |
|
url = models_urls[key] |
|
state_dict = maybe_download(key, url) |
|
model_dict = expand_model_dict(model.state_dict(), state_dict, num_parallel) |
|
model.load_state_dict(model_dict, strict=True) |
|
return model |
|
|
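# Copies single-branch pretrained weights into the parallel model: each
# 'bn_<i>' parallel copy is initialised from the same pretrained BN entry.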
|
def expand_model_dict(model_dict, state_dict, num_parallel): |
|
model_dict_keys = model_dict.keys() |
|
state_dict_keys = state_dict.keys() |
|
for model_dict_key in model_dict_keys: |
|
model_dict_key_re = model_dict_key.replace('module.', '') |
|
if model_dict_key_re in state_dict_keys: |
|
model_dict[model_dict_key] = state_dict[model_dict_key_re] |
|
for i in range(num_parallel): |
|
bn = '.bn_%d' % i |
|
            replace = bn in model_dict_key_re
|
model_dict_key_re = model_dict_key_re.replace(bn, '') |
|
if replace and model_dict_key_re in state_dict_keys: |
|
model_dict[model_dict_key] = state_dict[model_dict_key_re] |
|
return model_dict |
|
|
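# Splits parameters into encoder (conv1/bn1/layer*) and decoder groups;
# slim_params additionally gathers alternating halves of every bn2 weight
# vector, e.g. for an L1 sparsity penalty on BN scale factors.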
|
def get_params(model): |
|
enc_params, dec_params, slim_params = [], [], [] |
|
for name, param in model.named_parameters(): |
|
        if re.match('.*conv1.*|.*bn1.*|.*layer.*', name):

            enc_params.append(param)

        else:
|
            dec_params.append(param)

        if param.requires_grad and name.endswith('weight') and 'bn2' in name:
|
if len(slim_params) % 2 == 0: |
|
slim_params.append(param[:len(param) // 2]) |
|
else: |
|
slim_params.append(param[len(param) // 2:]) |
|
return enc_params, dec_params, slim_params |
|
|
|
def get_sup_loss_from_output(criterion, outputs, target): |
|
target = target.squeeze(3) |
|
loss = 0 |
|
for output in outputs: |
|
        output = nn.functional.interpolate(output, size=target.size()[1:],
                                           mode='bilinear', align_corners=False)

        loss += criterion(output, target)
|
    return loss / len(outputs)
|
|
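# L1 (sum of absolute values) penalty, e.g. applied to the BN scale factors in
# slim_params to encourage channel sparsity.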
|
def L1_penalty(var): |
|
return torch.abs(var).sum() |
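

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes the ResNet-50 ImageNet checkpoint downloads successfully; note that
# SyncBatchNorm may require CUDA in some PyTorch versions, so move the model
# and input to GPU if a CPU error is raised.
if __name__ == '__main__':
    model = MyRefineNet(num_layers=50, num_classes=21)
    model.eval()
    with torch.no_grad():
        data = [torch.randn(1, 3, 224, 224)]   # one parallel branch
        pred = model(data)                      # logits at input resolution
    print(pred.shape)                           # torch.Size([1, 21, 224, 224])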