import torch
import torch.nn as nn

from .skeleton_DME import SkeletonConv, SkeletonPool, SkeletonUnpool


def calc_node_depth(topology):
    """Depth of each joint in a parent-array topology (entries < 0 mark the root)."""

    def dfs(node, topology):
        if topology[node] < 0:
            return 0
        return 1 + dfs(topology[node], topology)

    depth = []
    for i in range(len(topology)):
        depth.append(dfs(i, topology))
    return depth


def residual_ratio(k):
    # Blend weight for the residual branch; the shortcut branch gets the complement.
    return 1 / (k + 1)


class Affine(nn.Module):
    """Learnable per-channel scale and bias, broadcast over batch and time dims."""

    def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
        super(Affine, self).__init__()
        if scale:
            self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
        else:
            self.register_parameter("scale", None)

        if bias:
            self.bias = nn.Parameter(torch.zeros(num_parameters))
        else:
            self.register_parameter("bias", None)

    def forward(self, input):
        output = input
        if self.scale is not None:
            scale = self.scale.unsqueeze(0)
            while scale.dim() < input.dim():
                scale = scale.unsqueeze(2)
            output = output.mul(scale)

        if self.bias is not None:
            bias = self.bias.unsqueeze(0)
            while bias.dim() < input.dim():
                bias = bias.unsqueeze(2)
            # Out-of-place add: when scale is None, output still aliases the
            # caller's tensor, so an in-place += would mutate the input.
            output = output + bias
        return output


class BatchStatistics(nn.Module):
    """Records a regularization loss pushing per-channel activations toward
    zero mean and unit standard deviation, then applies an optional Affine."""

    def __init__(self, affine=-1):
        super(BatchStatistics, self).__init__()
        self.affine = nn.Sequential() if affine == -1 else Affine(affine)
        self.loss = 0

    def clear_loss(self):
        self.loss = 0

    def compute_loss(self, input):
        input_flat = input.view(input.size(1), input.numel() // input.size(1))
        mu = input_flat.mean(1)
        # Note: despite the name, this is the log of the standard deviation.
        logvar = (input_flat.pow(2).mean(1) - mu.pow(2)).sqrt().log()
        self.loss = mu.pow(2).mean() + logvar.pow(2).mean()

    def forward(self, input):
        self.compute_loss(input)
        return self.affine(input)


class ResidualBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        residual_ratio,
        activation,
        batch_statistics=False,
        last_layer=False,
    ):
        super(ResidualBlock, self).__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        residual = []
        residual.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding))
        if batch_statistics:
            residual.append(BatchStatistics(out_channels))
        if not last_layer:
            residual.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        self.residual = nn.Sequential(*residual)

        self.shortcut = nn.Sequential(
            nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            BatchStatistics(out_channels) if (in_channels != out_channels and batch_statistics) else nn.Sequential(),
        )

    def forward(self, input):
        # Convex blend of the residual and shortcut branches.
        return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)


class ResidualBlockTranspose(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
        super(ResidualBlockTranspose, self).__init__()

        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        self.residual = nn.Sequential(
            nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding),
            nn.PReLU() if activation == "relu" else nn.Tanh(),
        )

        self.shortcut = nn.Sequential(
            nn.Upsample(scale_factor=2, mode="linear", align_corners=False) if stride == 2 else nn.Sequential(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
        )

    def forward(self, input):
        return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)
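
# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module): a minimal sketch
# of the shape contract of ResidualBlock / ResidualBlockTranspose on a dummy
# clip laid out as (batch, channels, frames). All sizes and kernel choices
# below are assumptions picked for the example.
# ---------------------------------------------------------------------------
def _demo_residual_blocks():
    x = torch.randn(8, 64, 32)  # 8 dummy clips, 64 channels, 32 frames

    # Stride-2 encoder block: halves the time axis, doubles the channels.
    down = ResidualBlock(
        in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1,
        residual_ratio=residual_ratio(1), activation="relu", batch_statistics=True,
    )
    h = down(x)
    assert h.shape == (8, 128, 16)

    # Matching decoder block: kernel 4 / stride 2 / padding 1 exactly doubles
    # the time axis, mirroring the Upsample(scale_factor=2) shortcut.
    up = ResidualBlockTranspose(
        in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1,
        residual_ratio=residual_ratio(1), activation="relu",
    )
    y = up(h)
    assert y.shape == x.shape
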
class SkeletonResidual(nn.Module):
    def __init__(
        self,
        topology,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_mode,
        activation,
        last_pool,
    ):
        super(SkeletonResidual, self).__init__()

        kernel_even = kernel_size % 2 == 0

        seq = []
        for _ in range(extra_conv):
            # (T, J, D) => (T, J, D)
            seq.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        # (T, J, D) => (T/2, J, 2D)
        seq.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        seq.append(nn.GroupNorm(10, out_channels))  # FIXME: REMEMBER TO CHANGE BACK !!!
        self.residual = nn.Sequential(*seq)

        # (T, J, D) => (T/2, J, 2D)
        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=True,
            add_offset=False,
        )

        seq = []
        # (T/2, J, 2D) => (T/2, J', 2D)
        pool = SkeletonPool(
            edges=topology,
            pooling_mode=pooling_mode,
            channels_per_edge=out_channels // len(neighbour_list),
            last_pool=last_pool,
        )
        if len(pool.pooling_list) != pool.edge_num:
            seq.append(pool)
        seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        self.common = nn.Sequential(*seq)

    def forward(self, input):
        output = self.residual(input) + self.shortcut(input)
        return self.common(output)


class SkeletonResidualTranspose(nn.Module):
    def __init__(
        self,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_list,
        upsampling,
        activation,
        last_layer,
    ):
        super(SkeletonResidualTranspose, self).__init__()

        kernel_even = kernel_size % 2 == 0

        seq = []
        # (T, J, D) => (2T, J, D)
        if upsampling is not None:
            seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False))
        # (2T, J, D) => (2T, J', D)
        unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
        if unpool.input_edge_num != unpool.output_edge_num:
            seq.append(unpool)
        self.common = nn.Sequential(*seq)

        seq = []
        for _ in range(extra_conv):
            # (2T, J', D) => (2T, J', D)
            seq.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        # (2T, J', D) => (2T, J', D/2)
        seq.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=kernel_size - 1 if kernel_even else kernel_size,
                stride=1,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        self.residual = nn.Sequential(*seq)

        # (2T, J', D) => (2T, J', D/2)
        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            add_offset=False,
        )

        if activation == "relu":
            self.activation = nn.PReLU() if not last_layer else None
        else:
            self.activation = nn.Tanh() if not last_layer else None

    def forward(self, input):
        output = self.common(input)
        output = self.residual(output) + self.shortcut(output)

        if self.activation is not None:
            return self.activation(output)
        return output
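
# ---------------------------------------------------------------------------
# Minimal sketch exercising the plain-torch helpers (assumed shapes/topology,
# for illustration only). The skeleton-aware modules are left out here since
# they depend on SkeletonConv/SkeletonPool from .skeleton_DME. Run with
# `python -m <package>.<this_module>` so the relative import above resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # topology is a parent array: -1 marks the root, entry i is joint i's parent.
    assert calc_node_depth([-1, 0, 1, 1]) == [0, 1, 2, 2]

    # Affine broadcasts its per-channel scale/bias over (batch, channels, time).
    aff = Affine(num_parameters=4, scale_init=2.0)
    z = aff(torch.ones(2, 4, 5))
    assert z.shape == (2, 4, 5) and torch.allclose(z, torch.full((2, 4, 5), 2.0))

    _demo_residual_blocks()
    print("smoke tests passed")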