import torch
import torch.nn as nn
import torch.nn.functional as F

from .warp import warp


def resize(x, scale_factor):
    """Bilinearly resize a feature map by `scale_factor`."""
    return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False)


def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    """Conv2d (3x3 by default) followed by a channel-wise PReLU."""
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
        nn.PReLU(out_channels)
    )


class ResBlock(nn.Module):
    """Residual block that routes the last `side_channels` channels through
    extra side convolutions before merging them back into the main branch."""

    def __init__(self, in_channels, side_channels, bias=True):
        super().__init__()
        self.side_channels = side_channels
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(in_channels)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(side_channels)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(in_channels)
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias),
            nn.PReLU(side_channels)
        )
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias)
        self.prelu = nn.PReLU(in_channels)

    def forward(self, x):
        out = self.conv1(x)

        # Split off the side branch, refine it, and merge it back (twice).
        res_feat = out[:, :-self.side_channels, ...]
        side_feat = out[:, -self.side_channels:, ...]
        side_feat = self.conv2(side_feat)
        out = self.conv3(torch.cat([res_feat, side_feat], 1))

        res_feat = out[:, :-self.side_channels, ...]
        side_feat = out[:, -self.side_channels:, ...]
        side_feat = self.conv4(side_feat)
        out = self.conv5(torch.cat([res_feat, side_feat], 1))

        # Residual connection around the whole block.
        return self.prelu(x + out)


class Encoder(nn.Module):
    """Feature pyramid encoder: stage `pyramid{i}` halves the spatial
    resolution and produces `channels[i-1]` feature channels."""

    def __init__(self, channels, large=False):
        super().__init__()
        self.channels = channels
        prev_ch = 3
        for idx, ch in enumerate(channels, 1):
            # Use a larger 7x7 stem for the first stage when `large` is set.
            k = 7 if large and idx == 1 else 3
            p = 3 if k == 7 else 1
            self.register_module(f'pyramid{idx}', nn.Sequential(
                convrelu(prev_ch, ch, k, 2, p),
                convrelu(ch, ch, 3, 1, 1)
            ))
            prev_ch = ch

    def forward(self, in_x):
        # Collect the features of every pyramid level, finest to coarsest.
        fs = []
        for idx in range(len(self.channels)):
            out_x = getattr(self, f'pyramid{idx + 1}')(in_x)
            fs.append(out_x)
            in_x = out_x
        return fs


class InitDecoder(nn.Module):
    """Coarsest-level decoder: predicts the initial pair of flows and an
    intermediate feature map from the two encoder features and the time
    embedding `embt`."""

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch * 2 + 1, in_ch * 2),
            ResBlock(in_ch * 2, skip_ch),
            # Upsample 2x; the first 4 output channels are the two 2-channel flows.
            nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True)
        )

    def forward(self, f0, f1, embt):
        h, w = f0.shape[2:]
        # Broadcast the scalar time embedding to the spatial resolution.
        embt = embt.repeat(1, 1, h, w)
        out = self.convblock(torch.cat([f0, f1, embt], 1))
        flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1)
        ft_ = out[:, 4:, ...]
        return flow0, flow1, ft_


class IntermediateDecoder(nn.Module):
    """Refines the two flows one pyramid level at a time, using features
    warped with the current flow estimates."""

    def __init__(self, in_ch, out_ch, skip_ch) -> None:
        super().__init__()
        self.convblock = nn.Sequential(
            convrelu(in_ch * 2 + 2, in_ch * 2),
            ResBlock(in_ch * 2, skip_ch),
            nn.ConvTranspose2d(in_ch * 2, out_ch, 4, 2, 1, bias=True)
        )
        self.conv_flow = nn.Conv2d(out_ch, 2, 3, 1, 1)

    def forward(self, f0, f1, flow_fwd, flow_bwd):
        f0_warp = warp(f0, flow_bwd)
        f1_warp = warp(f1, flow_fwd)
        f0_in = torch.cat([f0, f1_warp, flow_fwd], 1)
        f1_in = torch.cat([f1, f0_warp, flow_bwd], 1)
        out0 = self.convblock(f0_in)
        out1 = self.convblock(f1_in)
        # Upsample each flow 2x (doubling its magnitude to match the new
        # resolution) and add the predicted residual flow.
        flow_fwd = 2.0 * resize(flow_fwd, scale_factor=2.0) + self.conv_flow(out0)
        flow_bwd = 2.0 * resize(flow_bwd, scale_factor=2.0) + self.conv_flow(out1)
        return flow_fwd, flow_bwd, out0, out1
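

# ---------------------------------------------------------------------------
# Minimal shape smoke test (illustrative only). The channel widths, skip
# sizes, and input resolution below are assumptions for demonstration, not
# values mandated by this module. Run it as a module from the package root
# (e.g. `python -m <package>.<module>`) so the relative `.warp` import
# resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    x0 = torch.randn(1, 3, 64, 64)          # frame 0
    x1 = torch.randn(1, 3, 64, 64)          # frame 1
    embt = torch.full((1, 1, 1, 1), 0.5)    # interpolation instant t = 0.5

    encoder = Encoder(channels=[32, 48, 72, 96])
    f0_1, f0_2, f0_3, f0_4 = encoder(x0)    # strides 2, 4, 8, 16
    f1_1, f1_2, f1_3, f1_4 = encoder(x1)

    # Coarsest level: 96-channel features at 4x4 -> flows and features at 8x8.
    init_dec = InitDecoder(in_ch=96, out_ch=72, skip_ch=8)
    flow0, flow1, ft_ = init_dec(f0_4, f1_4, embt)
    assert flow0.shape == (1, 2, 8, 8)
    assert ft_.shape == (1, 72, 8, 8)

    # Next level: refine the flows with the 72-channel features at 8x8.
    # For a pure shape check the fwd/bwd pairing of flow0/flow1 is immaterial.
    mid_dec = IntermediateDecoder(in_ch=72, out_ch=48, skip_ch=8)
    flow_fwd, flow_bwd, out0, out1 = mid_dec(f0_3, f1_3, flow0, flow1)
    assert flow_fwd.shape == (1, 2, 16, 16)
    assert out0.shape == (1, 48, 16, 16)
    print("smoke test passed")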