# AutoLink/models/encoder.py
import torch
import torch.nn.functional as F
from torch import nn
import pytorch_lightning as pl


def gen_grid2d(grid_size: int, left_end: float = -1, right_end: float = 1) -> torch.Tensor:
    """
    Generate a grid of shape (grid_size, grid_size, 2) with coordinate values
    in the range [left_end, right_end].
    """
    x = torch.linspace(left_end, right_end, grid_size)
    # 'ij' indexing: the first coordinate indexes rows, the second indexes columns.
    x, y = torch.meshgrid([x, x], indexing='ij')
    grid = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1)), dim=1).reshape(grid_size, grid_size, 2)
    return grid
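
# For illustration (values not in the original file): gen_grid2d(2) produces
#   [[[-1., -1.], [-1.,  1.]],
#    [[ 1., -1.], [ 1.,  1.]]]
# i.e. grid[i, j, 0] is the row coordinate and grid[i, j, 1] the column coordinate.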


class ResBlock(nn.Module):
    """Residual block that halves the spatial resolution via bilinear downsampling."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        # Projection shortcut: match the channels and resolution of the main branch.
        self.conv_res = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
            nn.BatchNorm2d(out_channels)
        )
        # Main branch: conv -> 0.5x bilinear downsample -> conv.
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(out_channels)
        )
        self.relu = nn.LeakyReLU(0.2, True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = self.conv_res(x)
        x = self.net(x)
        return self.relu(x + res)
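
# Shape sketch (assumed 128x128 input, not stated in the original file):
# ResBlock(3, 64) maps (B, 3, 128, 128) -> (B, 64, 64, 64), i.e. each block
# halves the spatial resolution while changing the channel count.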


class TransposedBlock(nn.Module):
    """Upsampling block: 2x bilinear upsample followed by a 3x3 convolution."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class Detector(nn.Module):
    """Predicts keypoints as the soft-argmax of per-part probability maps."""

    def __init__(self, hyper_paras: pl.utilities.parsing.AttributeDict) -> None:
        super().__init__()
        self.n_parts = hyper_paras.n_parts
        self.output_size = 32
        # Encoder-decoder; the comments give the spatial size for a 128x128 input.
        self.conv = nn.Sequential(
            ResBlock(3, 64),  # 64
            ResBlock(64, 128),  # 32
            ResBlock(128, 256),  # 16
            ResBlock(256, 512),  # 8
            TransposedBlock(512, 256),  # 16
            TransposedBlock(256, 128),  # 32
            nn.Conv2d(128, self.n_parts, kernel_size=3, padding=1),
        )
        # Fixed coordinate grid over [-1, 1]^2, flattened to (1, 1, 32*32, 2).
        grid = gen_grid2d(self.output_size).reshape(1, 1, self.output_size ** 2, 2)
        self.coord = nn.Parameter(grid, requires_grad=False)

    def forward(self, input_dict: dict) -> dict:
        img = F.interpolate(input_dict['img'], size=(128, 128), mode='bilinear', align_corners=False)
        # Softmax over all 32*32 grid locations yields one probability map per part.
        prob_map = self.conv(img).reshape(img.shape[0], self.n_parts, -1, 1)
        prob_map = F.softmax(prob_map, dim=2)
        # Soft-argmax: each keypoint is the probability-weighted mean of the grid
        # coordinates, so it lies in [-1, 1]^2 and is differentiable.
        keypoints = self.coord * prob_map
        keypoints = keypoints.sum(dim=2)
        prob_map = prob_map.reshape(keypoints.shape[0], self.n_parts, self.output_size, self.output_size)
        return {'keypoints': keypoints, 'prob_map': prob_map}
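
# A minimal usage sketch (hyper-parameters assumed, not from the original file):
#     det = Detector(pl.utilities.parsing.AttributeDict(n_parts=4))
#     out = det({'img': torch.randn(1, 3, 128, 128)})
#     out['keypoints'].shape  # torch.Size([1, 4, 2]), coordinates in [-1, 1]
#     out['prob_map'].shape   # torch.Size([1, 4, 32, 32])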


class Encoder(nn.Module):
    """Wraps the Detector and optionally produces a block-masked copy of the input image."""

    def __init__(self, hyper_paras: pl.utilities.parsing.AttributeDict) -> None:
        super().__init__()
        self.detector = Detector(hyper_paras)
        self.missing = 0.8  # hyper_paras.missing
        self.block = 16  # hyper_paras.block
        # Kaiming initialization matched to the LeakyReLU(0.2) activations.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, a=0.2)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, input_dict: dict, need_masked_img: bool = False) -> dict:
        mask_batch = self.detector(input_dict)
        if need_masked_img:
            # Each cell of a coarse (block x block) grid survives with probability
            # (1 - missing); the boolean mask is then upsampled to the image size.
            damage_mask = torch.zeros(input_dict['img'].shape[0], 1, self.block, self.block,
                                      device=input_dict['img'].device).uniform_() > self.missing
            damage_mask = F.interpolate(damage_mask.to(input_dict['img']), size=input_dict['img'].shape[-1],
                                        mode='nearest')
            mask_batch['damaged_img'] = input_dict['img'] * damage_mask
        return mask_batch
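

if __name__ == "__main__":
    # Minimal smoke test, not part of the original file. It assumes the
    # hyper-parameter object only needs the `n_parts` field used above.
    hparams = pl.utilities.parsing.AttributeDict(n_parts=10)
    encoder = Encoder(hparams)
    batch = {'img': torch.randn(2, 3, 128, 128)}
    out = encoder(batch, need_masked_img=True)
    print(out['keypoints'].shape)    # expected: torch.Size([2, 10, 2])
    print(out['prob_map'].shape)     # expected: torch.Size([2, 10, 32, 32])
    print(out['damaged_img'].shape)  # expected: torch.Size([2, 3, 128, 128])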