import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as weight_init


def conv3x3(in_planes, out_planes, stride=1, padding=1, bias=False):
    """3x3 convolution with padding."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=(3, 3),
                     stride=(stride, stride), padding=padding, bias=bias)


def conv1x1(in_planes, out_planes, stride=1, padding=0, bias=False):
    """1x1 convolution."""
    return nn.Conv2d(in_planes, out_planes, kernel_size=(1, 1),
                     stride=(stride, stride), padding=padding, bias=bias)


class ConvBlock(nn.Module):
    """Residual block whose output concatenates three successive branches
    (out_planes/2 + out_planes/4 + out_planes/4 = out_planes channels)."""

    def __init__(self, in_planes, out_planes, lightweight=False):
        super(ConvBlock, self).__init__()
        if lightweight:
            # 1x1 convolutions only: cheaper, but no spatial mixing.
            self.conv1 = conv1x1(in_planes, int(out_planes / 2))
            self.conv2 = conv1x1(int(out_planes / 2), int(out_planes / 4))
            self.conv3 = conv1x1(int(out_planes / 4), int(out_planes / 4))
        else:
            self.conv1 = conv3x3(in_planes, int(out_planes / 2))
            self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
            self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
        self.bn1 = nn.BatchNorm2d(int(out_planes / 2))
        self.bn2 = nn.BatchNorm2d(int(out_planes / 4))
        self.bn3 = nn.BatchNorm2d(int(out_planes / 4))

        if in_planes != out_planes:
            # 1x1 projection so the skip connection matches the output width.
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=(1, 1),
                          stride=(1, 1), bias=False),
                nn.BatchNorm2d(out_planes),
                nn.ReLU6(True),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = x

        out1 = self.conv1(x)
        out1 = self.bn1(out1)
        out1 = F.relu6(out1, True)

        out2 = self.conv2(out1)
        out2 = self.bn2(out2)
        out2 = F.relu6(out2, True)

        out3 = self.conv3(out2)
        out3 = self.bn3(out3)
        out3 = F.relu6(out3, True)

        # Concatenate the three branch outputs back to out_planes channels.
        out3 = torch.cat((out1, out2, out3), 1)

        if self.downsample is not None:
            residual = self.downsample(residual)
        out3 += residual
        return out3


class HourGlass(nn.Module):
    """Recursive hourglass: each level pools, recurses, and upsamples,
    adding the result to a same-resolution skip branch."""

    def __init__(self, num_modules, depth, num_features, lightweight=False):
        super(HourGlass, self).__init__()
        self.num_modules = num_modules
        self.depth = depth
        self.features = num_features
        self.lightweight = lightweight

        self._generate_network(self.depth)

    def _generate_network(self, level):
        # Note that only the upper-branch block ('b1_') honors the
        # lightweight flag; the lower-branch blocks always use 3x3 convs.
        self.add_module('b1_' + str(level),
                        ConvBlock(self.features, self.features,
                                  lightweight=self.lightweight))
        self.add_module('b2_' + str(level),
                        ConvBlock(self.features, self.features))
        if level > 1:
            self._generate_network(level - 1)
        else:
            self.add_module('b2_plus_' + str(level),
                            ConvBlock(self.features, self.features))
        self.add_module('b3_' + str(level),
                        ConvBlock(self.features, self.features))

    def _forward(self, level, inp):
        # Upper branch: keeps the current resolution.
        up1 = self._modules['b1_' + str(level)](inp)

        # Lower branch: halve the resolution and recurse.
        low1 = F.max_pool2d(inp, 2, stride=2)
        low1 = self._modules['b2_' + str(level)](low1)

        if level > 1:
            low2 = self._forward(level - 1, low1)
        else:
            low2 = self._modules['b2_plus_' + str(level)](low1)

        low3 = self._modules['b3_' + str(level)](low2)
        up2 = F.interpolate(low3, scale_factor=2, mode='nearest')

        return up1 + up2

    def forward(self, x):
        return self._forward(self.depth, x)
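# Shape sketch (illustrative, assuming num_features=128 and a 64x64 input;
# these numbers are not fixed by the classes above):
#   ConvBlock(128, 128) keeps the spatial size and rebuilds the channel count
#   as cat(64, 32, 32) -> 128, so blocks can be stacked freely.
#   HourGlass(1, 4, 128) halves the resolution four times (64 -> 4) and then
#   upsamples back, so its input spatial dims should be divisible by 2**depth.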
class QFAN(nn.Module):
    """Stacked-hourglass network: a convolutional stem followed by
    num_modules hourglasses, each predicting num_out heatmaps with
    intermediate supervision."""

    def __init__(self, num_modules=1, num_in=3, num_features=128, num_out=68,
                 return_features=False):
        super(QFAN, self).__init__()
        self.num_modules = num_modules
        self.num_in = num_in
        self.num_features = num_features
        self.num_out = num_out
        self.return_features = return_features

        # Base part: stride-2 stem plus three ConvBlocks (with a max-pool),
        # bringing the input to 1/4 resolution and num_features channels.
        self.conv1 = nn.Conv2d(self.num_in, int(self.num_features / 2),
                               kernel_size=(7, 7), stride=(2, 2), padding=3)
        self.bn1 = nn.BatchNorm2d(int(self.num_features / 2))
        self.conv2 = ConvBlock(int(self.num_features / 2),
                               int(self.num_features / 2))
        self.conv3 = ConvBlock(int(self.num_features / 2), self.num_features)
        self.conv4 = ConvBlock(self.num_features, self.num_features)

        # Stacking part
        for hg_module in range(self.num_modules):
            self.add_module('m' + str(hg_module),
                            HourGlass(1, 4, self.num_features))
            self.add_module('top_m_' + str(hg_module),
                            ConvBlock(self.num_features, self.num_features))
            self.add_module('conv_last' + str(hg_module),
                            nn.Conv2d(self.num_features, self.num_features,
                                      kernel_size=(1, 1), stride=(1, 1),
                                      padding=0))
            self.add_module('bn_end' + str(hg_module),
                            nn.BatchNorm2d(self.num_features))
            self.add_module('l' + str(hg_module),
                            nn.Conv2d(self.num_features, self.num_out,
                                      kernel_size=(1, 1), stride=(1, 1),
                                      padding=0))

            if hg_module < self.num_modules - 1:
                # Remap features and heatmaps back to num_features channels
                # so they can be fed into the next hourglass.
                self.add_module(
                    'bl' + str(hg_module),
                    nn.Conv2d(self.num_features, self.num_features,
                              kernel_size=(1, 1), stride=(1, 1), padding=0))
                self.add_module(
                    'al' + str(hg_module),
                    nn.Conv2d(self.num_out, self.num_features,
                              kernel_size=(1, 1), stride=(1, 1), padding=0))

    def forward(self, x):
        features = []
        x = F.relu(self.bn1(self.conv1(x)), True)
        x = F.max_pool2d(self.conv2(x), 2, stride=2)
        x = self.conv3(x)
        x = self.conv4(x)
        if self.return_features:
            features.append(x)

        previous = x
        outputs = []
        for i in range(self.num_modules):
            hg = self._modules['m' + str(i)](previous)

            ll = self._modules['top_m_' + str(i)](hg)
            ll = F.relu(self._modules['bn_end' + str(i)](
                self._modules['conv_last' + str(i)](ll)), True)

            # Predict heatmaps
            tmp_out = self._modules['l' + str(i)](ll)
            outputs.append(tmp_out)

            if i < self.num_modules - 1:
                ll = self._modules['bl' + str(i)](ll)
                tmp_out_ = self._modules['al' + str(i)](tmp_out)
                previous = previous + ll + tmp_out_

        if self.return_features:
            return outputs, features
        return outputs


def init_weights(net, init_type='normal', gain=0.02):
    """Initialize Conv/Linear and BatchNorm2d weights in-place."""

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1
                                     or classname.find('Linear') != -1):
            if init_type == 'normal':
                weight_init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                weight_init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                weight_init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                weight_init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError(
                    'initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                weight_init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            weight_init.normal_(m.weight.data, 1.0, gain)
            weight_init.constant_(m.bias.data, 0.0)

    net.apply(init_func)
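# Usage note (an assumption about intent, not stated in the original code):
# calling e.g. init_weights(net, init_type='kaiming') re-initializes every
# Conv/Linear weight (Kaiming-normal in that case, biases to zero) and sets
# each BatchNorm2d to normal(1.0, gain) weights with zero bias. FANAU below
# calls it with the default 'normal' scheme.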
class FANAU(nn.Module):
    """Refinement head on top of a QFAN kept in eval mode: fuses the FAN's
    final heatmaps with its base features, then refines them through a
    lightweight hourglass to predict n_points heatmaps."""

    def __init__(self, num_modules=1, num_features=128, n_points=66,
                 block=ConvBlock):
        super(FANAU, self).__init__()
        self.num_modules = num_modules
        self.num_features = num_features
        self.fan = QFAN(num_modules=self.num_modules, return_features=True)
        # Resolve the block class if it was given by name (currently unused
        # below, but kept for configurability).
        block = eval(block) if isinstance(block, str) else block

        # Input features: 68 FAN heatmaps and num_features base features,
        # each projected to num_features channels and summed.
        self.conv1 = nn.Sequential(
            nn.Conv2d(68, self.num_features, 1, 1),
            nn.BatchNorm2d(self.num_features),
            nn.ReLU6())
        self.conv2 = nn.Sequential(
            nn.Conv2d(self.num_features, self.num_features, 1, 1),
            nn.BatchNorm2d(self.num_features),
            nn.ReLU6())
        self.net = HourGlass(1, 4, self.num_features, lightweight=True)
        self.conv_last = nn.Sequential(
            nn.Conv2d(self.num_features, self.num_features, 1, 1),
            nn.BatchNorm2d(self.num_features),
            nn.ReLU6())
        self.l = nn.Conv2d(self.num_features, n_points, 1, 1)
        # Note: this also re-initializes self.fan, so pretrained FAN weights
        # should be loaded after construction.
        init_weights(self)

    def forward(self, x):
        self.fan.eval()
        # with torch.no_grad():
        output, features = self.fan(x)
        out = output[-1]

        x = self.conv1(out) + self.conv2(features[0])
        x = self.net(x)
        x = self.conv_last(x)
        x = self.l(x)
        return x
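# ---------------------------------------------------------------------------
# Minimal smoke test: a sketch, assuming 256x256 RGB crops (the usual FAN
# input size; not fixed anywhere in this file). The stem halves the
# resolution twice, so heatmaps come out at 1/4 scale.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = FANAU(num_modules=1, num_features=128, n_points=66)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 256, 256)  # one RGB image
        heatmaps = model(dummy)
    print(heatmaps.shape)  # expected: torch.Size([1, 66, 64, 64])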