import torch import torch.nn as nn import torch.nn.functional as F from torch.nn import SyncBatchNorm as BatchNorm2d import re import os, sys # from six import moves class Exchange(nn.Module): def __init__(self): super(Exchange, self).__init__() def forward(self, x, bn, bn_threshold): bn1, bn2 = bn[0].weight.abs(), bn[1].weight.abs() x1, x2 = torch.zeros_like(x[0]), torch.zeros_like(x[1]) x1[:, bn1 >= bn_threshold] = x[0][:, bn1 >= bn_threshold] x1[:, bn1 < bn_threshold] = x[1][:, bn1 < bn_threshold] x2[:, bn2 >= bn_threshold] = x[1][:, bn2 >= bn_threshold] x2[:, bn2 < bn_threshold] = x[0][:, bn2 < bn_threshold] return [x1, x2] class ModuleParallel(nn.Module): def __init__(self, module): super(ModuleParallel, self).__init__() self.module = module def forward(self, x_parallel): return [self.module(x) for x in x_parallel] class BatchNorm2dParallel(nn.Module): def __init__(self, num_features, num_parallel): super(BatchNorm2dParallel, self).__init__() for i in range(num_parallel): setattr(self, 'bn_' + str(i), BatchNorm2d(num_features)) def forward(self, x_parallel): return [getattr(self, 'bn_' + str(i))(x) for i, x in enumerate(x_parallel)] class ChannelExchangingNetwork(nn.Module): def __init__(self, num_layers, num_classes, num_parallel, l1_lambda, bn_threshold): super(ChannelExchangingNetwork, self).__init__() self.model = refinenet(num_layers, num_classes, num_parallel, bn_threshold) self.model = model_init(self.model, num_layers, num_parallel, imagenet=True) #Only initializes the encoder self.l1_lambda = l1_lambda def get_slim_params(self): slim_params = [] for name, param in self.model.named_parameters(): if param.requires_grad and name.endswith('weight') and 'bn2' in name: if len(slim_params) % 2 == 0: slim_params.append(param[:len(param) // 2]) else: slim_params.append(param[len(param) // 2:]) return slim_params def forward(self, data, get_sup_loss = False, gt = None, criterion = None): b, c, h, w = data[0].shape #rgb is the 0th element pred = self.model(data) for i in range(len(pred)): pred[i] = F.interpolate(pred[i], size=(h, w), mode='bilinear', align_corners=True) if not self.training: return pred else: # training if get_sup_loss: l1_loss = self.l1_lambda * self.get_l1_loss(data[0].get_device()) sup_loss = self.get_sup_loss(pred, gt, criterion) return pred, sup_loss + l1_loss else: return pred def get_sup_loss(self, pred, gt, criterion): sup_loss = 0 for p in pred: p = p[:gt.shape[0]] #Getting loss for only those examples in batch where gt exists. Won't get sup loss for unlabeled data. sup_loss += criterion(p, gt) return sup_loss / len(pred) def get_params(self): self.slim_params = self.get_slim_params() #Doing it here and not in __init__ because first the model should be put in appropriate device before accumulating slim_params # enc_params, dec_params = [], [] # for name, param in self.model.named_parameters(): # if bool(re.match('.*conv1.*|.*bn1.*|.*layer.*', name)): # enc_params.append(param) # else: # dec_params.append(param) # return enc_params, dec_params param_groups = [[], [], []] for name, param in self.model.named_parameters(): if "norm" in name: param_groups[1].append(param) elif bool(re.match('.*conv1.*|.*bn1.*|.*layer.*', name)): param_groups[0].append(param) else: param_groups[2].append(param) return param_groups def get_l1_loss(self, device): L1_norm = sum([L1_penalty(m, device) for m in self.slim_params]) if L1_norm > 0: return L1_norm.to(device) else: return torch.tensor(0).to(device) """RefineNet-LightWeight RefineNet-LigthWeight PyTorch for non-commercial purposes Copyright (c) 2018, Vladimir Nekrasov (vladimir.nekrasov@adelaide.edu.au) All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ models_urls = { '101_voc' : 'https://cloudstor.aarnet.edu.au/plus/s/Owmttk9bdPROwc6/download', '18_imagenet' : 'https://download.pytorch.org/models/resnet18-f37072fd.pth', '50_imagenet' : 'https://download.pytorch.org/models/resnet50-19c8e357.pth', '101_imagenet': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', '152_imagenet': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', } bottleneck_idx = 0 save_idx = 0 def conv3x3(in_planes, out_planes, stride=1, bias=False): "3x3 convolution with padding" return ModuleParallel(nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=bias)) def conv1x1(in_planes, out_planes, stride=1, bias=False): "1x1 convolution" return ModuleParallel(nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=bias)) class CRPBlock(nn.Module): def __init__(self, in_planes, out_planes, num_stages, num_parallel): super(CRPBlock, self).__init__() for i in range(num_stages): setattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'), conv3x3(in_planes if (i == 0) else out_planes, out_planes)) self.stride = 1 self.num_stages = num_stages self.num_parallel = num_parallel self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=5, stride=1, padding=2)) def forward(self, x): top = x for i in range(self.num_stages): top = self.maxpool(top) top = getattr(self, '{}_{}'.format(i + 1, 'outvar_dimred'))(top) x = [x[l] + top[l] for l in range(self.num_parallel)] return x stages_suffixes = {0 : '_conv', 1 : '_conv_relu_varout_dimred'} class RCUBlock(nn.Module): def __init__(self, in_planes, out_planes, num_blocks, num_stages, num_parallel): super(RCUBlock, self).__init__() for i in range(num_blocks): for j in range(num_stages): setattr(self, '{}{}'.format(i + 1, stages_suffixes[j]), conv3x3(in_planes if (i == 0) and (j == 0) else out_planes, out_planes, bias=(j == 0))) self.stride = 1 self.num_blocks = num_blocks self.num_stages = num_stages self.num_parallel = num_parallel self.relu = ModuleParallel(nn.ReLU(inplace=True)) def forward(self, x): for i in range(self.num_blocks): residual = x for j in range(self.num_stages): x = self.relu(x) x = getattr(self, '{}{}'.format(i + 1, stages_suffixes[j]))(x) x = [x[l] + residual[l] for l in range(self.num_parallel)] return x class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, num_parallel, bn_threshold, stride=1, downsample=None): super(BasicBlock, self).__init__() self.conv1 = conv3x3(inplanes, planes, stride) self.bn1 = BatchNorm2dParallel(planes, num_parallel) self.relu = ModuleParallel(nn.ReLU(inplace=True)) self.conv2 = conv3x3(planes, planes) self.bn2 = BatchNorm2dParallel(planes, num_parallel) self.num_parallel = num_parallel self.downsample = downsample self.stride = stride self.exchange = Exchange() self.bn_threshold = bn_threshold self.bn2_list = [] for module in self.bn2.modules(): if isinstance(module, BatchNorm2d): self.bn2_list.append(module) def forward(self, x): residual = x out = self.conv1(x) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if len(x) > 1: out = self.exchange(out, self.bn2_list, self.bn_threshold) if self.downsample is not None: residual = self.downsample(x) out = [out[l] + residual[l] for l in range(self.num_parallel)] out = self.relu(out) return out class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, num_parallel, bn_threshold, stride=1, downsample=None): super(Bottleneck, self).__init__() self.conv1 = conv1x1(inplanes, planes) self.bn1 = BatchNorm2dParallel(planes, num_parallel) self.conv2 = conv3x3(planes, planes, stride=stride) self.bn2 = BatchNorm2dParallel(planes, num_parallel) self.conv3 = conv1x1(planes, planes * 4) self.bn3 = BatchNorm2dParallel(planes * 4, num_parallel) self.relu = ModuleParallel(nn.ReLU(inplace=True)) self.num_parallel = num_parallel self.downsample = downsample self.stride = stride self.exchange = Exchange() self.bn_threshold = bn_threshold self.bn2_list = [] for module in self.bn2.modules(): if isinstance(module, BatchNorm2d): self.bn2_list.append(module) def forward(self, x): residual = x out = x out = self.conv1(out) out = self.bn1(out) out = self.relu(out) out = self.conv2(out) out = self.bn2(out) if len(x) > 1: out = self.exchange(out, self.bn2_list, self.bn_threshold) out = self.relu(out) out = self.conv3(out) out = self.bn3(out) if self.downsample is not None: residual = self.downsample(x) out = [out[l] + residual[l] for l in range(self.num_parallel)] out = self.relu(out) return out class RefineNet(nn.Module): def __init__(self, block, layers, num_parallel, num_classes=21, bn_threshold=2e-2): self.inplanes = 64 self.num_parallel = num_parallel super(RefineNet, self).__init__() self.dropout = ModuleParallel(nn.Dropout(p=0.5)) self.conv1 = ModuleParallel(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)) self.bn1 = BatchNorm2dParallel(64, num_parallel) self.relu = ModuleParallel(nn.ReLU(inplace=True)) self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) self.layer1 = self._make_layer(block, 64, layers[0], bn_threshold) self.layer2 = self._make_layer(block, 128, layers[1], bn_threshold, stride=2) self.layer3 = self._make_layer(block, 256, layers[2], bn_threshold, stride=2) self.layer4 = self._make_layer(block, 512, layers[3], bn_threshold, stride=2) self.p_ims1d2_outl1_dimred = conv3x3(2048, 512) self.adapt_stage1_b = self._make_rcu(512, 512, 2, 2) self.mflow_conv_g1_pool = self._make_crp(512, 512, 4) self.mflow_conv_g1_b = self._make_rcu(512, 512, 3, 2) self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(512, 256) self.p_ims1d2_outl2_dimred = conv3x3(1024, 256) self.adapt_stage2_b = self._make_rcu(256, 256, 2, 2) self.adapt_stage2_b2_joint_varout_dimred = conv3x3(256, 256) self.mflow_conv_g2_pool = self._make_crp(256, 256, 4) self.mflow_conv_g2_b = self._make_rcu(256, 256, 3, 2) self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(256, 256) self.p_ims1d2_outl3_dimred = conv3x3(512, 256) self.adapt_stage3_b = self._make_rcu(256, 256, 2, 2) self.adapt_stage3_b2_joint_varout_dimred = conv3x3(256, 256) self.mflow_conv_g3_pool = self._make_crp(256, 256, 4) self.mflow_conv_g3_b = self._make_rcu(256, 256, 3, 2) self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(256, 256) self.p_ims1d2_outl4_dimred = conv3x3(256, 256) self.adapt_stage4_b = self._make_rcu(256, 256, 2, 2) self.adapt_stage4_b2_joint_varout_dimred = conv3x3(256, 256) self.mflow_conv_g4_pool = self._make_crp(256, 256, 4) self.mflow_conv_g4_b = self._make_rcu(256, 256, 3, 2) self.clf_conv = conv3x3(256, num_classes, bias=True) self.alpha = nn.Parameter(torch.ones(num_parallel, requires_grad=True)) # self.alpha = nn.Parameter(torch.ones([1, num_parallel, 157, 157], requires_grad=True)) self.register_parameter('alpha', self.alpha) def _make_crp(self, in_planes, out_planes, num_stages): layers = [CRPBlock(in_planes, out_planes, num_stages, self.num_parallel)] return nn.Sequential(*layers) def _make_rcu(self, in_planes, out_planes, num_blocks, num_stages): layers = [RCUBlock(in_planes, out_planes, num_blocks, num_stages, self.num_parallel)] return nn.Sequential(*layers) def _make_layer(self, block, planes, num_blocks, bn_threshold, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride=stride), BatchNorm2dParallel(planes * block.expansion, self.num_parallel) ) layers = [] layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, num_blocks): layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) l1 = self.layer1(x) l2 = self.layer2(l1) l3 = self.layer3(l2) l4 = self.layer4(l3) l4 = self.dropout(l4) l3 = self.dropout(l3) x4 = self.p_ims1d2_outl1_dimred(l4) x4 = self.adapt_stage1_b(x4) x4 = self.relu(x4) x4 = self.mflow_conv_g1_pool(x4) x4 = self.mflow_conv_g1_b(x4) x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) x4 = [nn.Upsample(size=l3[0].size()[2:], mode='bilinear', align_corners=True)(x4_) for x4_ in x4] x3 = self.p_ims1d2_outl2_dimred(l3) x3 = self.adapt_stage2_b(x3) x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) x3 = [x3[l] + x4[l] for l in range(self.num_parallel)] x3 = self.relu(x3) x3 = self.mflow_conv_g2_pool(x3) x3 = self.mflow_conv_g2_b(x3) x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) x3 = [nn.Upsample(size=l2[0].size()[2:], mode='bilinear', align_corners=True)(x3_) for x3_ in x3] x2 = self.p_ims1d2_outl3_dimred(l2) x2 = self.adapt_stage3_b(x2) x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) x2 = [x2[l] + x3[l] for l in range(self.num_parallel)] x2 = self.relu(x2) x2 = self.mflow_conv_g3_pool(x2) x2 = self.mflow_conv_g3_b(x2) x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) x2 = [nn.Upsample(size=l1[0].size()[2:], mode='bilinear', align_corners=True)(x2_) for x2_ in x2] x1 = self.p_ims1d2_outl4_dimred(l1) x1 = self.adapt_stage4_b(x1) x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) x1 = [x1[l] + x2[l] for l in range(self.num_parallel)] x1 = self.relu(x1) x1 = self.mflow_conv_g4_pool(x1) x1 = self.mflow_conv_g4_b(x1) x1 = self.dropout(x1) out = self.clf_conv(x1) ens = 0 alpha_soft = F.softmax(self.alpha, dim = 0) for l in range(self.num_parallel): ens += alpha_soft[l] * out[l].detach() # alpha_soft = F.softmax(self.alpha, dim=1) # for l in range(self.num_parallel): # print(out[l].shape, l) # ens += alpha_soft[:, l].unsqueeze(1) * out[l].detach() out.append(ens) # return out, alpha_soft return out class RefineNet_Resnet18(nn.Module): def __init__(self, block, layers, num_parallel, num_classes=21, bn_threshold=2e-2): self.inplanes = 64 self.num_parallel = num_parallel super(RefineNet_Resnet18, self).__init__() self.dropout = ModuleParallel(nn.Dropout(p=0.5)) self.conv1 = ModuleParallel(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)) self.bn1 = BatchNorm2dParallel(64, num_parallel) self.relu = ModuleParallel(nn.ReLU(inplace=True)) self.maxpool = ModuleParallel(nn.MaxPool2d(kernel_size=3, stride=2, padding=1)) self.layer1 = self._make_layer(block, 64, layers[0], bn_threshold) self.layer2 = self._make_layer(block, 128, layers[1], bn_threshold, stride=2) self.layer3 = self._make_layer(block, 256, layers[2], bn_threshold, stride=2) self.layer4 = self._make_layer(block, 512, layers[3], bn_threshold, stride=2) self.p_ims1d2_outl1_dimred = conv3x3(512, 256) self.adapt_stage1_b = self._make_rcu(256, 256, 2, 2) self.mflow_conv_g1_pool = self._make_crp(256, 256, 4) self.mflow_conv_g1_b = self._make_rcu(256, 256, 3, 2) self.mflow_conv_g1_b3_joint_varout_dimred = conv3x3(256, 64) self.p_ims1d2_outl2_dimred = conv3x3(256, 64) self.adapt_stage2_b = self._make_rcu(64, 64, 2, 2) self.adapt_stage2_b2_joint_varout_dimred = conv3x3(64, 64) self.mflow_conv_g2_pool = self._make_crp(64, 64, 4) self.mflow_conv_g2_b = self._make_rcu(64, 64, 3, 2) self.mflow_conv_g2_b3_joint_varout_dimred = conv3x3(64, 64) self.p_ims1d2_outl3_dimred = conv3x3(128, 64) self.adapt_stage3_b = self._make_rcu(64, 64, 2, 2) self.adapt_stage3_b2_joint_varout_dimred = conv3x3(64, 64) self.mflow_conv_g3_pool = self._make_crp(64, 64, 4) self.mflow_conv_g3_b = self._make_rcu(64, 64, 3, 2) self.mflow_conv_g3_b3_joint_varout_dimred = conv3x3(64, 64) self.p_ims1d2_outl4_dimred = conv3x3(64, 64) self.adapt_stage4_b = self._make_rcu(64, 64, 2, 2) self.adapt_stage4_b2_joint_varout_dimred = conv3x3(64, 64) self.mflow_conv_g4_pool = self._make_crp(64, 64, 4) self.mflow_conv_g4_b = self._make_rcu(64, 64, 3, 2) self.clf_conv = conv3x3(64, num_classes, bias=True) self.alpha = nn.Parameter(torch.ones(num_parallel, requires_grad=True)) # self.alpha = nn.Parameter(torch.ones([1, num_parallel, 157, 157], requires_grad=True)) self.register_parameter('alpha', self.alpha) def _make_crp(self, in_planes, out_planes, num_stages): layers = [CRPBlock(in_planes, out_planes, num_stages, self.num_parallel)] return nn.Sequential(*layers) def _make_rcu(self, in_planes, out_planes, num_blocks, num_stages): layers = [RCUBlock(in_planes, out_planes, num_blocks, num_stages, self.num_parallel)] return nn.Sequential(*layers) def _make_layer(self, block, planes, num_blocks, bn_threshold, stride=1): downsample = None if stride != 1 or self.inplanes != planes * block.expansion: downsample = nn.Sequential( conv1x1(self.inplanes, planes * block.expansion, stride=stride), BatchNorm2dParallel(planes * block.expansion, self.num_parallel) ) layers = [] layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold, stride, downsample)) self.inplanes = planes * block.expansion for i in range(1, num_blocks): layers.append(block(self.inplanes, planes, self.num_parallel, bn_threshold)) return nn.Sequential(*layers) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu(x) x = self.maxpool(x) l1 = self.layer1(x) l2 = self.layer2(l1) l3 = self.layer3(l2) l4 = self.layer4(l3) l4 = self.dropout(l4) l3 = self.dropout(l3) x4 = self.p_ims1d2_outl1_dimred(l4) x4 = self.adapt_stage1_b(x4) x4 = self.relu(x4) x4 = self.mflow_conv_g1_pool(x4) x4 = self.mflow_conv_g1_b(x4) x4 = self.mflow_conv_g1_b3_joint_varout_dimred(x4) x4 = [nn.Upsample(size=l3[0].size()[2:], mode='bilinear', align_corners=True)(x4_) for x4_ in x4] x3 = self.p_ims1d2_outl2_dimred(l3) x3 = self.adapt_stage2_b(x3) x3 = self.adapt_stage2_b2_joint_varout_dimred(x3) x3 = [x3[l] + x4[l] for l in range(self.num_parallel)] x3 = self.relu(x3) x3 = self.mflow_conv_g2_pool(x3) x3 = self.mflow_conv_g2_b(x3) x3 = self.mflow_conv_g2_b3_joint_varout_dimred(x3) x3 = [nn.Upsample(size=l2[0].size()[2:], mode='bilinear', align_corners=True)(x3_) for x3_ in x3] x2 = self.p_ims1d2_outl3_dimred(l2) x2 = self.adapt_stage3_b(x2) x2 = self.adapt_stage3_b2_joint_varout_dimred(x2) x2 = [x2[l] + x3[l] for l in range(self.num_parallel)] x2 = self.relu(x2) x2 = self.mflow_conv_g3_pool(x2) x2 = self.mflow_conv_g3_b(x2) x2 = self.mflow_conv_g3_b3_joint_varout_dimred(x2) x2 = [nn.Upsample(size=l1[0].size()[2:], mode='bilinear', align_corners=True)(x2_) for x2_ in x2] x1 = self.p_ims1d2_outl4_dimred(l1) x1 = self.adapt_stage4_b(x1) x1 = self.adapt_stage4_b2_joint_varout_dimred(x1) x1 = [x1[l] + x2[l] for l in range(self.num_parallel)] x1 = self.relu(x1) x1 = self.mflow_conv_g4_pool(x1) x1 = self.mflow_conv_g4_b(x1) x1 = self.dropout(x1) out = self.clf_conv(x1) ens = 0 alpha_soft = F.softmax(self.alpha, dim = 0) for l in range(self.num_parallel): ens += alpha_soft[l] * out[l].detach() # alpha_soft = F.softmax(self.alpha, dim=1) # for l in range(self.num_parallel): # print(out[l].shape, l) # ens += alpha_soft[:, l].unsqueeze(1) * out[l].detach() out.append(ens) return out, alpha_soft def refinenet(num_layers, num_classes, num_parallel, bn_threshold): refinnetClass = RefineNet if int(num_layers) == 18: layers = [2, 2, 2, 2] block = BasicBlock refinnetClass = RefineNet_Resnet18 elif int(num_layers) == 50: layers = [3, 4, 6, 3] block = Bottleneck elif int(num_layers) == 101: layers = [3, 4, 23, 3] block = Bottleneck elif int(num_layers) == 152: layers = [3, 8, 36, 3] block = Bottleneck else: print('invalid num_layers') model = refinnetClass(block, layers, num_parallel, num_classes, bn_threshold) return model def maybe_download(model_name, model_url, model_dir=None, map_location=None): if model_dir is None: torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) if not os.path.exists(model_dir): os.makedirs(model_dir) filename = '{}.pth.tar'.format(model_name) cached_file = os.path.join(model_dir, filename) if not os.path.exists(cached_file): # url = model_url # sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) # moves.urllib.request.urlretrieve(url, cached_file) raise Exception('cached file not found, maybe_download failed') return torch.load(cached_file, map_location=map_location) def model_init(model, num_layers, num_parallel, imagenet=False): if imagenet: key = str(num_layers) + '_imagenet' url = models_urls[key] state_dict = maybe_download(key, url) model_dict = expand_model_dict(model.state_dict(), state_dict, num_parallel) model.load_state_dict(model_dict, strict=True) return model def expand_model_dict(model_dict, state_dict, num_parallel): model_dict_keys = model_dict.keys() state_dict_keys = state_dict.keys() for model_dict_key in model_dict_keys: model_dict_key_re = model_dict_key.replace('module.', '') if model_dict_key_re in state_dict_keys: model_dict[model_dict_key] = state_dict[model_dict_key_re] for i in range(num_parallel): bn = '.bn_%d' % i replace = True if bn in model_dict_key_re else False model_dict_key_re = model_dict_key_re.replace(bn, '') if replace and model_dict_key_re in state_dict_keys: model_dict[model_dict_key] = state_dict[model_dict_key_re] return model_dict def L1_penalty(var, device): return torch.abs(var).sum().to(device)