File size: 7,260 Bytes
91fc62a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 |
import torch
import torch.nn.functional as F
from torch import nn
from typing import Union
import pytorch_lightning as pl
def gen_grid2d(grid_size: int, left_end: float=-1, right_end: float=1) -> torch.Tensor:
    """
    Build a square lattice of 2D coordinates.

    :param grid_size: number of samples along each axis
    :param left_end: coordinate of the first sample on each axis
    :param right_end: coordinate of the last sample on each axis
    :return: (grid_size, grid_size, 2) tensor whose entry [i, j] is (axis[i], axis[j])
    """
    axis = torch.linspace(left_end, right_end, grid_size)
    # 'ij' indexing: the first output varies along dim 0, the second along dim 1.
    first, second = torch.meshgrid(axis, axis, indexing='ij')
    return torch.stack((first, second), dim=-1)
def draw_lines(paired_joints: torch.Tensor, heatmap_size: int=16, thick: Union[float, torch.Tensor]=1e-2) -> torch.Tensor:
    """
    Rasterize line segments as soft heatmaps on a square grid spanning [-1, 1] on both axes.

    Every grid point receives exp(-d^2 / thick), where d is the Euclidean
    distance from that point to the closest point of the segment.

    :param paired_joints: (batch_size, n_points, 2, 2); [..., 0, :] is the segment start,
        [..., 1, :] the segment end
    :param heatmap_size: side length of the output grid
    :param thick: divisor of the squared distance inside the exponential; larger values
        make the drawn lines wider
    :return: (batch_size, n_points, heatmap_size, heatmap_size)
    """
    bs, n_points, _, _ = paired_joints.shape
    start = paired_joints[:, :, 0, :]   # (batch_size, n_points, 2)
    end = paired_joints[:, :, 1, :]     # (batch_size, n_points, 2)
    paired_diff = end - start           # (batch_size, n_points, 2)

    grid = gen_grid2d(heatmap_size).to(paired_joints.device)
    if paired_joints.is_floating_point():
        # gen_grid2d always yields the default float dtype. Without this cast a
        # half/bfloat16 input makes the matmul below fail on mismatched dtypes.
        grid = grid.to(paired_joints.dtype)
    grid = grid.reshape(1, 1, -1, 2)    # (1, 1, heatmap_size**2, 2)

    diff_to_start = grid - start.unsqueeze(-2)  # (batch_size, n_points, heatmap_size**2, 2)
    diff_to_end = grid - end.unsqueeze(-2)      # (batch_size, n_points, heatmap_size**2, 2)

    # Position of each grid point's projection along the segment: 0 at start,
    # 1 at end. The 1e-8 keeps the division finite for zero-length segments.
    # (batch_size, n_points, heatmap_size**2)
    t = (diff_to_start @ paired_diff.unsqueeze(-1)).squeeze(-1) / (1e-8+paired_diff.square().sum(dim=-1, keepdim=True))

    # The three masks partition the grid; the closest segment point is the start,
    # the end, or the orthogonal projection respectively.
    before_start = (t <= 0).to(t.dtype) * diff_to_start.square().sum(dim=-1)
    after_end = (t >= 1).to(t.dtype) * diff_to_end.square().sum(dim=-1)
    projection = start.unsqueeze(-2) + t.unsqueeze(-1) * paired_diff.unsqueeze(-2)
    between_start_end = (0 < t).to(t.dtype) * (t < 1).to(t.dtype) * (grid - projection).square().sum(dim=-1)

    squared_dist = (before_start + after_end + between_start_end).reshape(bs, n_points, heatmap_size, heatmap_size)
    heatmaps = torch.exp(- squared_dist / thick)
    return heatmaps
class DownBlock(nn.Module):
    """
    Encoder stage: conv -> batch norm -> LeakyReLU -> conv -> batch norm,
    then a bilinear resize with scale factor 0.5 and a final LeakyReLU.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """
        :param in_channels: channels of the incoming feature map
        :param out_channels: channels produced by both convolutions
        """
        super().__init__()
        first_stage = [
            nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
        ]
        second_stage = [
            nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=1),
            nn.BatchNorm2d(out_channels),
            # Resize sits between the normalization and the activation.
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
            nn.LeakyReLU(0.2, True),
        ]
        self.net = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input feature map
        :return: output of the sequential stack applied to x
        """
        return self.net(x)
class UpBlock(nn.Module):
    """
    Decoder stage: bilinear resize with scale factor 2, followed by two
    conv -> batch norm -> LeakyReLU stages.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        """
        :param in_channels: channels of the incoming feature map
        :param out_channels: channels produced by both convolutions
        """
        super().__init__()
        layers = [nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)]
        # Only the first convolution changes the channel count.
        for conv_in in (in_channels, out_channels):
            layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=(3, 3), padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2, True))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: input feature map
        :return: output of the sequential stack applied to x
        """
        return self.net(x)
class Decoder(nn.Module):
    """
    U-Net style decoder.

    Keypoints are rasterized into a single-channel edge heatmap (one soft line
    per unordered pair of parts, weighted by a learned scalar), concatenated
    with the alpha-scaled damaged image, and passed through an encoder/decoder
    with skip connections that outputs a 3-channel image.
    """

    def __init__(self, hyper_paras: dict) -> None:
        """
        :param hyper_paras: mapping that provides 'n_parts' (number of keypoints),
            'thick' (line thickness forwarded to draw_lines) and 'sklr' (scale
            applied to the raw skeleton scalars before the softplus)
        """
        super().__init__()
        self.n_parts = hyper_paras['n_parts']
        self.thick = hyper_paras['thick']
        self.sklr = hyper_paras['sklr']

        # Index pairs (i, j) with i < j, one per candidate edge. Registered as a
        # non-persistent buffer so it follows the module across devices while
        # staying out of the state_dict (checkpoint keys are unchanged).
        self.register_buffer(
            'skeleton_idx',
            torch.triu_indices(self.n_parts, self.n_parts, offset=1),
            persistent=False,
        )
        self.n_skeleton = len(self.skeleton_idx[0])

        # Learned scale applied to the damaged image before concatenation.
        self.alpha = nn.Parameter(torch.tensor(1.0), requires_grad=True)

        # Raw edge weights; stored divided by sklr and multiplied back before the
        # softplus, so the initial softplus input is roughly -4.
        skeleton_scalar = (torch.randn(self.n_parts, self.n_parts) / 10 - 4) / self.sklr
        self.skeleton_scalar = nn.Parameter(skeleton_scalar, requires_grad=True)

        # Trailing comments give the spatial size of each stage's output for a
        # 128x128 input.
        self.down0 = nn.Sequential(
            nn.Conv2d(3 + 1, 64, kernel_size=(3, 3), padding=1),  # RGB image + heatmap
            nn.LeakyReLU(0.2, True),
        )
        self.down1 = DownBlock(64, 128)  # 64
        self.down2 = DownBlock(128, 256)  # 32
        self.down3 = DownBlock(256, 512)  # 16
        self.down4 = DownBlock(512, 512)  # 8

        # Input channels of up2..up4 and conv include the concatenated skip features.
        self.up1 = UpBlock(512, 512)  # 16
        self.up2 = UpBlock(512 + 512, 256)  # 32
        self.up3 = UpBlock(256 + 256, 128)  # 64
        self.up4 = UpBlock(128 + 128, 64)  # 128

        self.conv = nn.Conv2d(64+64, 3, kernel_size=(3, 3), padding=1)

        # Kaiming init matched to the LeakyReLU slope of 0.2; biases start at zero.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, a=0.2)
                if m.bias is not None:
                    m.bias.data.zero_()

    def skeleton_scalar_matrix(self) -> torch.Tensor:
        """
        Give the skeleton scalar matrix: softplus-activated edge weights,
        symmetric with a zero diagonal.

        :return: (n_parts, n_parts)
        """
        skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
        # Only the strict upper triangle is used; mirror it to make the matrix symmetric.
        skeleton_scalar = torch.triu(skeleton_scalar, diagonal=1)
        skeleton_scalar = skeleton_scalar + skeleton_scalar.transpose(1, 0)
        return skeleton_scalar

    def rasterize(self, keypoints: torch.Tensor, output_size: int=128) -> torch.Tensor:
        """
        Generate edge heatmap from keypoints, where edges are weighted by the learned scalars.

        :param keypoints: (batch_size, n_points, >=2); only the first two coordinates are used
        :param output_size: side length of the generated heatmap
        :return: (batch_size, 1, output_size, output_size)
        """
        idx_from, idx_to = self.skeleton_idx[0], self.skeleton_idx[1]
        # (batch_size, n_skeleton, 2, 2): start and end point of every candidate edge
        paired_joints = torch.stack([keypoints[:, idx_from, :2], keypoints[:, idx_to, :2]], dim=2)

        # The indices select strictly upper-triangular entries only, so no
        # explicit triu masking is needed before gathering the weights.
        skeleton_scalar = F.softplus(self.skeleton_scalar * self.sklr)
        skeleton_scalar = skeleton_scalar[idx_from, idx_to].reshape(1, self.n_skeleton, 1, 1)

        skeleton_heatmap_sep = draw_lines(paired_joints, heatmap_size=output_size, thick=self.thick)
        skeleton_heatmap_sep = skeleton_heatmap_sep * skeleton_scalar
        # Collapse the per-edge maps into one channel by taking the strongest edge per pixel.
        skeleton_heatmap = skeleton_heatmap_sep.max(dim=1, keepdim=True)[0]
        return skeleton_heatmap

    def forward(self, input_dict: dict) -> dict:
        """
        Reconstruct an image from the damaged image and the keypoint skeleton.

        :param input_dict: must contain 'keypoints' (see rasterize) and 'damaged_img',
            a 3-channel square image batch; the 128/64/... names below assume a
            128x128 input
        :return: the same dict object, with 'heatmap' (rasterized skeleton) and
            'img' (3-channel output) added
        """
        damaged_img = input_dict['damaged_img']
        # Rasterize at the image's own resolution so the concatenation below is
        # valid for any square input, not just 128x128.
        skeleton_heatmap = self.rasterize(input_dict['keypoints'], output_size=damaged_img.shape[-1])

        x = torch.cat([damaged_img * self.alpha, skeleton_heatmap], dim=1)

        down_128 = self.down0(x)
        down_64 = self.down1(down_128)
        down_32 = self.down2(down_64)
        down_16 = self.down3(down_32)
        down_8 = self.down4(down_16)

        # Each decoder stage is concatenated with the encoder features of the same size.
        up_8 = down_8
        up_16 = torch.cat([self.up1(up_8), down_16], dim=1)
        up_32 = torch.cat([self.up2(up_16), down_32], dim=1)
        up_64 = torch.cat([self.up3(up_32), down_64], dim=1)
        up_128 = torch.cat([self.up4(up_64), down_128], dim=1)

        img = self.conv(up_128)

        input_dict['heatmap'] = skeleton_heatmap
        input_dict['img'] = img
        return input_dict
if __name__ == '__main__':
    # Smoke test: build the decoder and print its trainable parameter count.
    # Decoder.__init__ reads 'n_parts', 'thick' and 'sklr'; the 'thick' and 'sklr'
    # values here are placeholders, since neither affects the number of parameters.
    model = Decoder({'z_dim': 256, 'n_parts': 10, 'n_embedding': 128, 'tau': 0.01, 'thick': 1e-2, 'sklr': 1.0})
    print(sum(p.numel() for p in model.parameters() if p.requires_grad))
|