# Spaces:
# Runtime error
# Runtime error
import torch | |
import torch.nn as nn | |
from .skeleton_DME import SkeletonConv, SkeletonPool, SkeletonUnpool | |
def calc_node_depth(topology):
    """Return the depth of every node of a parent-index forest.

    ``topology[i]`` is the index of node ``i``'s parent; a negative value marks
    a root, whose depth is 0. Every other node's depth is its parent's depth
    plus one.

    Depths are memoised and parent chains are walked iteratively, so the total
    cost is O(len(topology)) and long chains cannot exceed Python's recursion
    limit.

    Args:
        topology: indexable sequence of parent indices (ints, or values
            convertible with ``int()``).

    Returns:
        list[int]: ``depth[i]`` for every node ``i``.

    Raises:
        ValueError: if following parent links revisits a node, i.e. the
            topology contains a cycle instead of forming a forest.
    """
    depth = [None] * len(topology)
    for start in range(len(topology)):
        # Climb from `start` until reaching a root or a node whose depth is
        # already known, remembering the nodes passed on the way.
        path = []
        seen = set()
        node = start
        while depth[node] is None:
            parent = int(topology[node])
            if parent < 0:
                depth[node] = 0
                break
            if node in seen:
                raise ValueError(f"topology contains a cycle through node {node}")
            seen.add(node)
            path.append(node)
            node = parent
        # Unwind: each node on the path is one deeper than the next one up.
        current = depth[node]
        for visited in reversed(path):
            current += 1
            depth[visited] = current
    return depth
def residual_ratio(k):
    """Return ``1 / (k + 1)`` using true division.

    For example ``k = 0`` gives ``1.0`` and ``k = 3`` gives ``0.25``.
    """
    denominator = k + 1
    return 1 / denominator
class Affine(nn.Module):
    """Learnable per-channel scale and/or shift applied along dimension 1.

    Each parameter has shape ``(num_parameters,)``. It is reshaped to
    ``(1, num_parameters, 1, ...)`` with as many dimensions as the input, so
    it broadcasts over dimension 1 of the input.

    Args:
        num_parameters: number of channels (entries in ``scale``/``bias``).
        scale: if True, register a learnable multiplicative ``scale``
            initialised to ``scale_init``; otherwise ``scale`` is None.
        bias: if True, register a learnable additive ``bias`` initialised to
            zeros; otherwise ``bias`` is None.
        scale_init: initial value of every ``scale`` entry.
    """

    def __init__(self, num_parameters, scale=True, bias=True, scale_init=1.0):
        super(Affine, self).__init__()
        if scale:
            self.scale = nn.Parameter(torch.ones(num_parameters) * scale_init)
        else:
            self.register_parameter("scale", None)
        if bias:
            self.bias = nn.Parameter(torch.zeros(num_parameters))
        else:
            self.register_parameter("bias", None)

    @staticmethod
    def _broadcastable(param, ndim):
        """Reshape a ``(C,)`` parameter to ``(1, C, 1, ...)`` with ``ndim`` dims."""
        shaped = param.unsqueeze(0)
        while shaped.dim() < ndim:
            shaped = shaped.unsqueeze(2)
        return shaped

    def forward(self, input):
        """Return ``input * scale + bias``, skipping whichever term is disabled."""
        output = input
        if self.scale is not None:
            output = output * self._broadcastable(self.scale, input.dim())
        if self.bias is not None:
            # Out-of-place add: when ``scale`` is disabled ``output`` is still
            # the caller's ``input`` tensor, and an in-place ``+=`` would
            # modify it (and fail under autograd for leaf tensors).
            output = output + self._broadcastable(self.bias, input.dim())
        return output
class BatchStatistics(nn.Module):
    """Pass-through layer that records a per-channel standardisation loss.

    Every forward pass computes, for each channel (dimension 1 of an
    ``(N, C, ...)`` input), the mean ``mu`` and log standard deviation over
    all other dimensions. It stores
    ``self.loss = mean(mu ** 2) + mean(log_std ** 2)``, which is zero when
    every channel has mean 0 and standard deviation 1. The returned tensor is
    ``input`` passed through ``self.affine``.

    Args:
        affine: ``-1`` (default) to use an empty ``nn.Sequential`` (identity);
            any other value is passed as ``num_parameters`` to ``Affine``.
    """

    def __init__(self, affine=-1):
        super(BatchStatistics, self).__init__()
        self.affine = nn.Sequential() if affine == -1 else Affine(affine)
        self.loss = 0

    def clear_loss(self):
        """Reset the recorded loss to 0."""
        self.loss = 0

    def compute_loss(self, input):
        """Compute the standardisation loss of ``input`` and store it in ``self.loss``."""
        num_channels = input.size(1)
        # Move channels to the front before flattening. A plain ``view`` of an
        # (N, C, ...) tensor as (C, -1) mixes samples and channels whenever
        # N > 1. ``reshape`` also handles the non-contiguous transposed tensor.
        per_channel = input.transpose(0, 1).reshape(num_channels, -1)
        mu = per_channel.mean(1)
        # Centred second moment: never negative, unlike E[x^2] - mu^2, which
        # can dip below zero from rounding and produce NaN after sqrt.
        var = (per_channel - mu.unsqueeze(1)).pow(2).mean(1)
        log_std = var.sqrt().log()
        self.loss = mu.pow(2).mean() + log_std.pow(2).mean()

    def forward(self, input):
        """Record the loss for ``input`` and return ``self.affine(input)``."""
        self.compute_loss(input)
        return self.affine(input)
class ResidualBlock(nn.Module):
    """1-D residual block blending a convolution branch with a shortcut.

    ``forward(x) = residual(x) * residual_ratio + shortcut(x) * (1 - residual_ratio)``.

    * ``residual``: ``Conv1d(in_channels, out_channels, kernel_size, stride,
      padding)``, then ``BatchStatistics`` if ``batch_statistics`` is truthy,
      then an activation unless ``last_layer`` is set.
    * ``shortcut``: ``AvgPool1d(2)`` when ``stride == 2``, then a 1x1
      ``Conv1d`` to ``out_channels``, then ``BatchStatistics`` when
      ``batch_statistics`` is truthy and the channel count changes.

    ``activation == "relu"`` selects ``PReLU``; any other value selects ``Tanh``.
    """

    def __init__(
        self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation, batch_statistics=False, last_layer=False
    ):
        super(ResidualBlock, self).__init__()
        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        residual = [nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding)]
        if batch_statistics:
            residual.append(BatchStatistics(out_channels))
        if not last_layer:
            residual.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        self.residual = nn.Sequential(*residual)

        # Use the same truthiness test as the residual branch, so a truthy
        # non-bool flag (e.g. 1) enables statistics on both branches.
        shortcut_statistics = bool(batch_statistics) and in_channels != out_channels
        self.shortcut = nn.Sequential(
            nn.AvgPool1d(kernel_size=2) if stride == 2 else nn.Sequential(),
            nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            BatchStatistics(out_channels) if shortcut_statistics else nn.Sequential(),
        )

    def forward(self, input):
        """Return the ratio-weighted sum of the residual and shortcut branches."""
        return self.residual(input).mul(self.residual_ratio) + self.shortcut(input).mul(self.shortcut_ratio)
class ResidualBlockTranspose(nn.Module):
    """1-D residual block with a transposed-convolution branch.

    ``forward(x) = residual(x) * residual_ratio + shortcut(x) * (1 - residual_ratio)``.

    * ``residual``: ``ConvTranspose1d(in_channels, out_channels, kernel_size,
      stride, padding)`` followed by ``PReLU`` (``activation == "relu"``) or
      ``Tanh`` (anything else).
    * ``shortcut``: linear ``Upsample`` by 2 when ``stride == 2`` (identity
      otherwise), then a 1x1 ``Conv1d`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, residual_ratio, activation):
        super(ResidualBlockTranspose, self).__init__()
        self.residual_ratio = residual_ratio
        self.shortcut_ratio = 1 - residual_ratio

        deconv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding)
        if activation == "relu":
            nonlinearity = nn.PReLU()
        else:
            nonlinearity = nn.Tanh()
        self.residual = nn.Sequential(deconv, nonlinearity)

        if stride == 2:
            resample = nn.Upsample(scale_factor=2, mode="linear", align_corners=False)
        else:
            resample = nn.Sequential()
        projection = nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.shortcut = nn.Sequential(resample, projection)

    def forward(self, input):
        """Return the ratio-weighted sum of the residual and shortcut branches."""
        weighted_residual = self.residual(input) * self.residual_ratio
        weighted_shortcut = self.shortcut(input) * self.shortcut_ratio
        return weighted_residual + weighted_shortcut
class SkeletonResidual(nn.Module):
    """Residual block of skeleton convolutions followed by optional skeleton pooling.

    ``forward(x) = common(residual(x) + shortcut(x))`` where:

    * ``residual``: ``extra_conv`` repetitions of (SkeletonConv keeping
      ``in_channels``, stride 1, activation), then a SkeletonConv to
      ``out_channels`` with the given ``stride`` and ``add_offset=False``,
      then ``GroupNorm(norm_groups, out_channels)``.
    * ``shortcut``: a kernel-size-1 SkeletonConv to ``out_channels`` with the
      same ``stride``, ``padding=0`` and ``bias=True``.
    * ``common``: a SkeletonPool over ``topology`` (only kept when
      ``len(pool.pooling_list) != pool.edge_num``), then an activation.

    ``activation == "relu"`` selects ``PReLU``; any other value selects ``Tanh``.

    Args:
        norm_groups: number of groups of the ``GroupNorm`` closing the
            residual branch (default 10). ``out_channels`` must be divisible
            by it (``nn.GroupNorm`` raises otherwise).
        The remaining arguments are forwarded to ``SkeletonConv`` /
        ``SkeletonPool`` as shown in the body.
    """

    def __init__(
        self,
        topology,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_mode,
        activation,
        last_pool,
        norm_groups=10,
    ):
        super(SkeletonResidual, self).__init__()
        # Even kernels are reduced by one for the stride-1 extra convolutions
        # (presumably so they preserve the temporal length with this padding —
        # confirm with callers).
        extra_kernel_size = kernel_size - 1 if kernel_size % 2 == 0 else kernel_size

        seq = []
        for _ in range(extra_conv):
            # (T, J, D) => (T, J, D)
            seq.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=extra_kernel_size,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        # (T, J, D) => (T/2, J, 2D)
        seq.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=kernel_size,
                stride=stride,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        # NOTE(review): this GroupNorm was marked "FIXME: REMEMBER TO CHANGE
        # BACK" upstream; the intended replacement is not recorded here.
        # The group count defaults to the previously hard-coded 10.
        seq.append(nn.GroupNorm(norm_groups, out_channels))
        self.residual = nn.Sequential(*seq)

        # (T, J, D) => (T/2, J, 2D)
        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=stride,
            padding=0,
            bias=True,
            add_offset=False,
        )

        seq = []
        # (T/2, J, 2D) => (T/2, J', 2D)
        pool = SkeletonPool(
            edges=topology, pooling_mode=pooling_mode, channels_per_edge=out_channels // len(neighbour_list), last_pool=last_pool
        )
        # Only keep the pool when it actually merges edges.
        if len(pool.pooling_list) != pool.edge_num:
            seq.append(pool)
        seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        self.common = nn.Sequential(*seq)

    def forward(self, input):
        """Sum the residual and shortcut branches, then pool and activate."""
        output = self.residual(input) + self.shortcut(input)
        return self.common(output)
class SkeletonResidualTranspose(nn.Module):
    """Residual block that optionally upsamples/unpools, then applies skeleton convolutions.

    ``forward(x)``:

    1. ``y = common(x)``: ``nn.Upsample(scale_factor=2, mode=upsampling)``
       when ``upsampling`` is not None, then a SkeletonUnpool built from
       ``pooling_list`` (only kept when its input and output edge counts
       differ).
    2. ``z = residual(y) + shortcut(y)``. ``residual`` is ``extra_conv``
       repetitions of (SkeletonConv keeping ``in_channels``, activation)
       followed by a SkeletonConv to ``out_channels``. ``shortcut`` is a
       kernel-size-1 SkeletonConv to ``out_channels``.
    3. Apply the activation to ``z`` unless ``last_layer`` is true.

    ``activation == "relu"`` selects ``PReLU``; any other value selects ``Tanh``.
    ``upsampling`` is an ``nn.Upsample`` mode name or None.
    """

    def __init__(
        self,
        neighbour_list,
        joint_num,
        in_channels,
        out_channels,
        kernel_size,
        padding,
        padding_mode,
        bias,
        extra_conv,
        pooling_list,
        upsampling,
        activation,
        last_layer,
    ):
        super(SkeletonResidualTranspose, self).__init__()
        # Every conv here has stride 1 and even kernels are reduced by one
        # (presumably so the temporal length is preserved — confirm with callers).
        conv_kernel_size = kernel_size - 1 if kernel_size % 2 == 0 else kernel_size

        seq = []
        # (T, J, D) => (2T, J, D)
        if upsampling is not None:
            # nn.Upsample accepts ``align_corners`` only for interpolating
            # modes; for "nearest", "area", ... it must be None, otherwise the
            # forward pass raises ValueError.
            interpolating = upsampling in ("linear", "bilinear", "bicubic", "trilinear")
            seq.append(nn.Upsample(scale_factor=2, mode=upsampling, align_corners=False if interpolating else None))
        # (2T, J, D) => (2T, J', D)
        unpool = SkeletonUnpool(pooling_list, in_channels // len(neighbour_list))
        if unpool.input_edge_num != unpool.output_edge_num:
            seq.append(unpool)
        self.common = nn.Sequential(*seq)

        seq = []
        for _ in range(extra_conv):
            # (2T, J', D) => (2T, J', D)
            seq.append(
                SkeletonConv(
                    neighbour_list,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    joint_num=joint_num,
                    kernel_size=conv_kernel_size,
                    stride=1,
                    padding=padding,
                    padding_mode=padding_mode,
                    bias=bias,
                )
            )
            seq.append(nn.PReLU() if activation == "relu" else nn.Tanh())
        # (2T, J', D) => (2T, J', D/2)
        seq.append(
            SkeletonConv(
                neighbour_list,
                in_channels=in_channels,
                out_channels=out_channels,
                joint_num=joint_num,
                kernel_size=conv_kernel_size,
                stride=1,
                padding=padding,
                padding_mode=padding_mode,
                bias=bias,
                add_offset=False,
            )
        )
        self.residual = nn.Sequential(*seq)

        # (2T, J', D) => (2T, J', D/2)
        self.shortcut = SkeletonConv(
            neighbour_list,
            in_channels=in_channels,
            out_channels=out_channels,
            joint_num=joint_num,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
            add_offset=False,
        )

        # No activation after the final layer.
        if last_layer:
            self.activation = None
        else:
            self.activation = nn.PReLU() if activation == "relu" else nn.Tanh()

    def forward(self, input):
        """Upsample/unpool, add residual and shortcut branches, then activate if configured."""
        output = self.common(input)
        output = self.residual(output) + self.shortcut(output)
        if self.activation is not None:
            return self.activation(output)
        return output