# AutoLink/models/decoder.py
from typing import Union

import torch
import torch.nn.functional as F
from torch import nn


def gen_grid2d(grid_size: int, left_end: float = -1, right_end: float = 1) -> torch.Tensor:
"""
Generate a grid of size (grid_size, grid_size, 2) with coordinate values in the range [left_end, right_end]
"""
x = torch.linspace(left_end, right_end, grid_size)
x, y = torch.meshgrid([x, x], indexing='ij')
grid = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1)), dim=1).reshape(grid_size, grid_size, 2)
return grid
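
# Illustrative usage (hypothetical values): gen_grid2d(2) returns the four corners of
# [-1, 1]^2 as a (2, 2, 2) tensor, e.g. grid[0, 0] = (-1, -1) and grid[1, 1] = (1, 1);
# gen_grid2d(64, 0, 1) spans the unit square instead.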


def draw_lines(paired_joints: torch.Tensor, heatmap_size: int = 16, thick: Union[float, torch.Tensor] = 1e-2) -> torch.Tensor:
    """
    Draw each line segment on a square grid as a Gaussian heatmap.
    :param paired_joints: (batch_size, n_points, 2, 2) start/end coordinates of each segment
    :param heatmap_size: side length of the output grid
    :param thick: thickness of the drawn lines (variance of the Gaussian falloff)
    :return: (batch_size, n_points, heatmap_size, heatmap_size)
    """
bs, n_points, _, _ = paired_joints.shape
start = paired_joints[:, :, 0, :] # (batch_size, n_points, 2)
end = paired_joints[:, :, 1, :] # (batch_size, n_points, 2)
paired_diff = end - start # (batch_size, n_points, 2)
grid = gen_grid2d(heatmap_size).to(paired_joints.device).reshape(1, 1, -1, 2)
diff_to_start = grid - start.unsqueeze(-2) # (batch_size, n_points, heatmap_size**2, 2)
    # Normalized projection of each grid cell onto its segment: t in [0, 1] lies on the segment
    # (batch_size, n_points, heatmap_size**2)
    t = (diff_to_start @ paired_diff.unsqueeze(-1)).squeeze(-1) / (1e-8 + paired_diff.square().sum(dim=-1, keepdim=True))
    diff_to_end = grid - end.unsqueeze(-2)  # (batch_size, n_points, heatmap_size**2, 2)
    # Piecewise squared distance to the segment: distance to the nearer endpoint when the
    # projection falls outside [0, 1], perpendicular distance to the line otherwise
    before_start = (t <= 0).float() * diff_to_start.square().sum(dim=-1)
    after_end = (t >= 1).float() * diff_to_end.square().sum(dim=-1)
    between_start_end = (0 < t).float() * (t < 1).float() * (grid - (start.unsqueeze(-2) + t.unsqueeze(-1) * paired_diff.unsqueeze(-2))).square().sum(dim=-1)
    squared_dist = (before_start + after_end + between_start_end).reshape(bs, n_points, heatmap_size, heatmap_size)
    # Gaussian falloff turns squared distance into a soft line of thickness ~ sqrt(thick)
    heatmaps = torch.exp(-squared_dist / thick)
    return heatmaps
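
# Illustrative usage (hypothetical values): one diagonal segment rendered on a 16x16 grid.
#   seg = torch.tensor([[[[-0.5, -0.5], [0.5, 0.5]]]])  # (1, 1, 2, 2)
#   hm = draw_lines(seg, heatmap_size=16, thick=1e-2)   # (1, 1, 16, 16), values in (0, 1]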


class DownBlock(nn.Module):
    """Two 3x3 convs followed by 2x bilinear downsampling."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            # scale_factor=0.5 makes nn.Upsample act as a bilinear downsampler
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class UpBlock(nn.Module):
    """2x bilinear upsampling followed by two 3x3 convs."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Decoder(nn.Module):
    def __init__(self, hyper_paras: dict) -> None:
        # hyper_paras may be a plain dict or a LightningModule's hparams (an AttributeDict)
        super().__init__()
        self.n_parts = hyper_paras['n_parts']
        self.thick = hyper_paras['thick']
        self.sklr = hyper_paras['sklr']
        # Upper-triangular indices: one candidate edge per unordered pair of keypoints
        self.skeleton_idx = torch.triu_indices(self.n_parts, self.n_parts, offset=1)
        self.n_skeleton = len(self.skeleton_idx[0])
        # Learned gate on the masked input image
        self.alpha = nn.Parameter(torch.tensor(1.0), requires_grad=True)
        # Edge weights are pre-divided by the multiplier sklr so that
        # softplus(skeleton_scalar * sklr) starts near softplus(-4), i.e. nearly-off edges
        skeleton_scalar = (torch.randn(self.n_parts, self.n_parts) / 10 - 4) / self.sklr
        self.skeleton_scalar = nn.Parameter(skeleton_scalar, requires_grad=True)
        # U-Net over the concatenation of the RGB image and the skeleton heatmap;
        # the resolution comments below assume a 128x128 input
        self.down0 = nn.Sequential(
            nn.Conv2d(3 + 1, 64, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(0.2, True),
        )
        self.down1 = DownBlock(64, 128)  # 64
        self.down2 = DownBlock(128, 256)  # 32
        self.down3 = DownBlock(256, 512)  # 16
        self.down4 = DownBlock(512, 512)  # 8
        self.up1 = UpBlock(512, 512)  # 16
        self.up2 = UpBlock(512 + 512, 256)  # 32
        self.up3 = UpBlock(256 + 256, 128)  # 64
        self.up4 = UpBlock(128 + 128, 64)  # 128
        self.conv = nn.Conv2d(64 + 64, 3, kernel_size=(3, 3), padding=1)
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, a=0.2)
                if m.bias is not None:
                    m.bias.data.zero_()

    def skeleton_scalar_matrix(self) -> torch.Tensor:
        """
        Return the symmetric matrix of learned edge weights.
        :return: (n_parts, n_parts)
        """
        skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
        skeleton_scalar = torch.triu(skeleton_scalar, diagonal=1)
        # Mirror the upper triangle so edges (i, j) and (j, i) share one weight
        skeleton_scalar = skeleton_scalar + skeleton_scalar.transpose(1, 0)
        return skeleton_scalar

    def rasterize(self, keypoints: torch.Tensor, output_size: int = 128) -> torch.Tensor:
        """
        Generate an edge heatmap from keypoints, with each edge weighted by its learned scalar.
        :param keypoints: (batch_size, n_points, 2)
        :param output_size: side length of the output heatmap
        :return: (batch_size, 1, output_size, output_size)
        """
        # Gather the (start, end) coordinates of every candidate edge
        paired_joints = torch.stack([keypoints[:, self.skeleton_idx[0], :2], keypoints[:, self.skeleton_idx[1], :2]], dim=2)
        skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
        skeleton_scalar = torch.triu(skeleton_scalar, diagonal=1)
        skeleton_scalar = skeleton_scalar[self.skeleton_idx[0], self.skeleton_idx[1]].reshape(1, self.n_skeleton, 1, 1)
        skeleton_heatmap_sep = draw_lines(paired_joints, heatmap_size=output_size, thick=self.thick)
        skeleton_heatmap_sep = skeleton_heatmap_sep * skeleton_scalar
        # Merge all weighted edges into a single map with a pixel-wise max
        skeleton_heatmap = skeleton_heatmap_sep.max(dim=1, keepdim=True)[0]
        return skeleton_heatmap
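
    # Illustrative usage (hypothetical shapes): with keypoints in [-1, 1]^2,
    #   kp = torch.rand(2, 10, 2) * 2 - 1           # (batch, n_parts, 2)
    #   hm = decoder.rasterize(kp, output_size=64)  # (2, 1, 64, 64)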

    def forward(self, input_dict: dict) -> dict:
        skeleton_heatmap = self.rasterize(input_dict['keypoints'])
        # Condition the U-Net on the gated masked image plus the skeleton heatmap
        x = torch.cat([input_dict['damaged_img'] * self.alpha, skeleton_heatmap], dim=1)
        down_128 = self.down0(x)
        down_64 = self.down1(down_128)
        down_32 = self.down2(down_64)
        down_16 = self.down3(down_32)
        down_8 = self.down4(down_16)
        up_8 = down_8
        # Decoder path with skip connections from the matching encoder stages
        up_16 = torch.cat([self.up1(up_8), down_16], dim=1)
        up_32 = torch.cat([self.up2(up_16), down_32], dim=1)
        up_64 = torch.cat([self.up3(up_32), down_64], dim=1)
        up_128 = torch.cat([self.up4(up_64), down_128], dim=1)
        img = self.conv(up_128)
        input_dict['heatmap'] = skeleton_heatmap
        input_dict['img'] = img
        return input_dict


if __name__ == '__main__':
    # 'thick' and 'sklr' are required by Decoder; the values below are illustrative
    # placeholders, not the training configuration
    model = Decoder({'z_dim': 256, 'n_parts': 10, 'n_embedding': 128, 'tau': 0.01,
                     'thick': 1e-2, 'sklr': 512})
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
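
    # Minimal smoke test (assumed shapes: 128x128 RGB input, keypoints in [-1, 1]^2)
    dummy = {
        'keypoints': torch.rand(2, 10, 2) * 2 - 1,   # (batch, n_parts, 2)
        'damaged_img': torch.randn(2, 3, 128, 128),  # masked input image
    }
    out = model(dummy)
    print(out['heatmap'].shape)  # expected: torch.Size([2, 1, 128, 128])
    print(out['img'].shape)      # expected: torch.Size([2, 3, 128, 128])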