|
import torch |
|
import torch.nn as nn |
|
from .blocks.warp import warp |
|
from .blocks.raft import ( |
|
coords_grid, |
|
SmallUpdateBlock, BidirCorrBlock, BasicUpdateBlock |
|
) |
|
from .blocks.feat_enc import ( |
|
SmallEncoder, |
|
BasicEncoder, |
|
LargeEncoder |
|
) |
|
from .blocks.ifrnet import ( |
|
resize, |
|
Encoder, |
|
InitDecoder, |
|
IntermediateDecoder |
|
) |
|
from .blocks.multi_flow import ( |
|
multi_flow_combine, |
|
MultiFlowDecoder |
|
) |
|
|
|
from ..components import register |
|
|
|
from utils.padder import InputPadder |
|
|
|
|
|
def photometric_consistency(img0, img1, flow01):
    """Return the per-pixel absolute difference between ``img0`` and ``img1`` warped by ``flow01``.

    The absolute residual is summed over dim 1 and that dim is kept (size 1).
    NOTE(review): assumes ``warp`` backward-warps ``img1`` into ``img0``'s frame
    using ``flow01`` — confirm against ``blocks.warp``.
    """
    img1_warped = warp(img1, flow01)
    residual = img0 - img1_warped
    return residual.abs().sum(dim=1, keepdim=True)
|
|
|
|
|
def flow_consistency(flow01, flow10):
    """Return the forward-backward flow residual ``|flow01 + warp(flow10, flow01)|``.

    The absolute residual is summed over dim 1 and that dim is kept (size 1).
    NOTE(review): assumes ``warp`` samples ``flow10`` at the positions displaced
    by ``flow01``, so consistent flows cancel — confirm against ``blocks.warp``.
    """
    flow10_warped = warp(flow10, flow01)
    residual = flow01 + flow10_warped
    return residual.abs().sum(dim=1, keepdim=True)
|
|
|
|
|
# Normalized 3x3 binomial (Gaussian-like) kernel: the outer product of the
# taps [1, 2, 1] with themselves, divided by 16 so the weights sum to 1.
# Stacked twice along dim 0 -> shape (2, 1, 3, 3).
gaussian_kernel = torch.tensor([1.0, 2.0, 1.0]).unsqueeze(1) * torch.tensor([1.0, 2.0, 1.0]).unsqueeze(0)
gaussian_kernel = (gaussian_kernel / 16).repeat(2, 1, 1, 1)
gaussian_kernel = gaussian_kernel.to("cpu")
|
|
|
|
|
def gaussian(x):
    """Blur every channel of ``x`` independently with a 3x3 binomial kernel.

    The kernel is [[1, 2, 1], [2, 4, 2], [1, 2, 1]] / 16, the same weights as
    the module-level ``gaussian_kernel``. It is built here on ``x``'s device
    and dtype, so CUDA and non-float32 inputs work. It is replicated to one
    filter per input channel, so any channel count is supported. Borders use
    reflect padding, so the spatial size is preserved and constant regions
    stay constant.

    Args:
        x: tensor assumed to be laid out as (batch, channels, height, width);
            height and width must be at least 2 for reflect padding.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    channels = x.shape[1]
    # Separable binomial taps; [0.25, 0.5, 0.25] are exact in binary floating point.
    taps = x.new_tensor([1.0, 2.0, 1.0]) / 4.0
    # One depthwise filter per channel: weight shape (channels, 1, 3, 3).
    kernel = (taps[:, None] * taps[None, :]).repeat(channels, 1, 1, 1)
    x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect')
    return torch.nn.functional.conv2d(x, kernel, groups=channels)
|
|
|
|
|
def variance_flow(flow):
    """Return a per-pixel local spread of ``flow`` (summed over the flow components).

    The flow is first rescaled channel-wise by ``2 / (W - 1)`` (channel 0) and
    ``2 / (H - 1)`` (channel 1), with H = ``flow.shape[2]``, W = ``flow.shape[3]``.
    The local variance ``gaussian(f**2) - gaussian(f)**2`` is then computed.
    ``sqrt(variance + 1e-4)`` is summed over dim 1, which is kept (size 1).

    Args:
        flow: tensor presumably shaped (batch, 2, H, W) with channel 0 the
            horizontal component — TODO confirm with callers.

    Returns:
        Tensor of shape (batch, 1, H, W) with the same dtype/device as ``flow``.
    """
    height, width = flow.shape[2], flow.shape[3]
    scale = torch.tensor(data=[2.0 / (width - 1.0), 2.0 / (height - 1.0)], dtype=flow.dtype,
                         device=flow.device).view(1, 2, 1, 1)
    flow = flow * scale
    variance = gaussian(flow ** 2) - gaussian(flow) ** 2
    # E[f^2] - E[f]^2 is mathematically >= 0 but float cancellation can push it
    # slightly negative; clamp so sqrt never yields NaN (sqrt output is already
    # non-negative, so no abs() is needed).
    return (variance.clamp(min=0) + 1e-4).sqrt().sum(dim=1, keepdim=True)
|
|
|
@register('amt_splat')
class Model(nn.Module):
    """Bidirectional multi-flow frame interpolation network (registered as ``'amt_splat'``).

    ``forward`` pads and mean-centres the two frames. It builds a bidirectional
    correlation volume from ``feat_encoder`` features, then refines the
    forward/backward flows coarse-to-fine: zero flows at 1/8 resolution, then
    ``decoder3``, ``decoder2`` and ``decoder1``. Each level uses
    correlation-driven update blocks. Finally the frames are blended with
    ``multi_flow_combine``.

    Args:
        model_size: ``'S'``, ``'L'`` or ``'G'``; selects the feature encoder
            (SmallEncoder / BasicEncoder / LargeEncoder). ``'G'`` also adds the
            extra ``update3_high`` / ``update2_high`` refinement steps.
        corr_radius: correlation lookup radius.
        corr_lvls: number of correlation pyramid levels.
        num_flows: number of flows per direction produced by ``decoder1``.
        channels: channel widths for the context ``Encoder``; defaults to
            ``[20, 32, 44, 56]``.
        skip_channels: skip width passed to the decoders.
        scale_factor: processing scale; inputs are resized by this factor and
            the final flows/masks are resized back.

    Raises:
        ValueError: if ``model_size`` is not ``'S'``, ``'L'`` or ``'G'``.
    """

    def __init__(self,
                 model_size='S',
                 corr_radius=3,
                 corr_lvls=4,
                 num_flows=3,
                 channels=None,
                 skip_channels=20,
                 scale_factor=1):
        super().__init__()
        if channels is None:
            # Fresh list per instance rather than a shared mutable default.
            channels = [20, 32, 44, 56]
        self.model_size = model_size
        self.radius = corr_radius
        self.corr_levels = corr_lvls
        self.num_flows = num_flows
        self.channels = channels
        self.skip_channels = skip_channels
        self.scale_factor = scale_factor

        if self.model_size == 'S':
            self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.)
        elif self.model_size == 'L':
            self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.)
        elif self.model_size == 'G':
            self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.)
        else:
            # Fail fast: without a feature encoder forward() cannot run.
            raise ValueError(f"model_size must be 'S', 'L' or 'G', got {model_size!r}")
        self.encoder = Encoder(channels, large=True)

        self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels)
        self.decoder2 = IntermediateDecoder(channels[1] * 2, channels[0], skip_channels)
        self.decoder1 = MultiFlowDecoder(channels[0] * 2, skip_channels, num_flows)

        self.update4 = self._get_updateblock(channels[2])
        self.update3_low = self._get_updateblock(channels[1] * 2, 2)
        self.update2_low = self._get_updateblock(channels[0] * 2, 4)

        if self.model_size == 'G':
            self.update3_high = self._get_updateblock(channels[1] * 2, None)
            self.update2_high = self._get_updateblock(channels[0] * 2, None)

    def _get_updateblock(self, cdim, scale_factor=None):
        """Build a ``BasicUpdateBlock`` with ``cdim`` context channels and the given upsampling factor."""
        return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64,
                                corr_dim=256, corr_dim2=192, fc_dim=188,
                                scale_factor=scale_factor, corr_levels=self.corr_levels,
                                radius=self.radius)

    def _corr_scale_lookup(self, corr_fn, coord, flow_fwd, flow_bwd, embt, downsample=1):
        """Look up bidirectional correlations at ``coord`` displaced by the flows.

        When ``downsample != 1`` both flows are first resized by
        ``1 / downsample`` and their values multiplied by the same factor.

        Args:
            corr_fn: callable ``(coords_fwd, coords_bwd) -> (corr_fwd, corr_bwd)``.
            coord: base coordinate grid.
            flow_fwd, flow_bwd: forward / backward flows.
            embt: time step; accepted for interface compatibility, not used here.
            downsample: factor by which the flows' resolution exceeds ``coord``'s.

        Returns:
            ``(corr_fwd, corr_bwd, flow_fwd, flow_bwd)``, where the returned flows
            are the (possibly downsampled) ones used for the lookup.
        """
        if downsample != 1:
            inv = 1 / downsample
            flow_fwd = inv * resize(flow_fwd, scale_factor=inv)
            flow_bwd = inv * resize(flow_bwd, scale_factor=inv)

        corr_fwd, corr_bwd = corr_fn(coord + flow_fwd, coord + flow_bwd)
        return corr_fwd, corr_bwd, flow_fwd, flow_bwd

    @staticmethod
    def _refine(update_block, f0, f1, flow_fwd, flow_bwd, corr_fwd, corr_bwd,
                flow_fwd_in=None, flow_bwd_in=None):
        """Run ``update_block`` for both directions and add the predicted residuals.

        Frame-0 features drive the forward update and frame-1 features the
        backward update. ``flow_fwd_in`` / ``flow_bwd_in`` are the flows fed to
        the block (e.g. downsampled copies). They default to ``flow_fwd`` /
        ``flow_bwd``.

        Returns:
            ``(f0 + d_f0, f1 + d_f1, flow_fwd + d_flow_fwd, flow_bwd + d_flow_bwd)``.
        """
        if flow_fwd_in is None:
            flow_fwd_in = flow_fwd
        if flow_bwd_in is None:
            flow_bwd_in = flow_bwd
        delta_f0, delta_flow_fwd = update_block(f0, flow_fwd_in, corr_fwd)
        delta_f1, delta_flow_bwd = update_block(f1, flow_bwd_in, corr_bwd)
        return f0 + delta_f0, f1 + delta_f1, flow_fwd + delta_flow_fwd, flow_bwd + delta_flow_bwd

    def get_splat_weight(self, img0, img1, flow01, flow10):
        """Combine three consistency cues into a per-pixel splatting weight.

        Each cue (photometric, forward-backward flow, local flow variance) is
        detached and mapped to ``1 / (1 + alpha * cue)``. The three terms are
        summed and multiplied by ``self.alpha``.

        NOTE(review): ``alpha_splat_photo_consistency``,
        ``alpha_splat_flow_consistency``, ``alpha_splat_variation_flow`` and
        ``alpha`` are not assigned in ``__init__``, and ``forward`` does not
        call this method. It raises AttributeError unless those attributes are
        set elsewhere — TODO confirm.
        """
        M_splat = 1 / (1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \
                  1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \
                  1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach())
        return M_splat * self.alpha

    def forward(self, img0, img1, time_step, scale_factor=1.0, eval=False, **kwargs):
        """Predict the intermediate frame at ``time_step`` between ``img0`` and ``img1``.

        Args:
            img0, img1: input frames; presumably (B, 3, H, W) in [0, 1] given the
                final clamp — TODO confirm.
            time_step: interpolation time, forwarded to ``multi_flow_combine``.
            scale_factor: ignored; ``self.scale_factor`` is always used.
            eval: if True, return only the predicted frame.
            **kwargs: ignored.

        Returns:
            dict with ``'imgt_pred'`` (clamped to [0, 1], padding removed). Unless
            ``eval`` is set, it also holds the per-level flows (times 0.5) under
            ``'flow0_pred'`` / ``'flow1_pred'``, and the first of the
            ``num_flows`` final flows (times 0.5) under ``'flowfwd'`` / ``'flowbwd'``.
        """
        # The scale_factor argument is overridden by the configured value.
        scale_factor = self.scale_factor
        padder = InputPadder(img0.shape, divisor=int(16 / scale_factor))
        img0, img1 = padder.pad(img0, img1)
        # One mean per sample over both frames (channels, height and width).
        mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True)
        img0 = img0 - mean_
        img1 = img1 - mean_
        img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0
        img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1
        b, _, h, w = img0_.shape

        coords = coords_grid(b, h // 8, w // 8, img0.device)
        # The initial flows are added to `coords` in _corr_scale_lookup, so they
        # must be created on the same device.
        flow_fwd_4 = torch.zeros(b, 2, h // 8, w // 8, device=img0.device)
        flow_bwd_4 = torch.zeros(b, 2, h // 8, w // 8, device=img0.device)

        fmap0, fmap1 = self.feat_encoder([img0_, img1_])
        corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels)

        f0_1, f0_2, f0_3 = self.encoder(img0_)
        f1_1, f1_2, f1_3 = self.encoder(img1_)

        # ---- 1/8 resolution: refine the zero-initialised flows ----
        corr_fwd_4, corr_bwd_4, _, _ = self._corr_scale_lookup(corr_fn, coords, flow_fwd_4, flow_bwd_4, time_step)
        # Each direction is refined with its own frame's context features.
        up_f0_3, up_f1_3, flow_fwd_4, flow_bwd_4 = self._refine(
            self.update4, f0_3, f1_3, flow_fwd_4, flow_bwd_4, corr_fwd_4, corr_bwd_4)

        # ---- decoder3 + low/high refinement ----
        flow_fwd_3, flow_bwd_3, f0_2_, f1_2_ = self.decoder3(up_f0_3, up_f1_3, flow_fwd_4, flow_bwd_4)
        corr_fwd_3, corr_bwd_3, flow_fwd_3_, flow_bwd_3_ = self._corr_scale_lookup(
            corr_fn, coords, flow_fwd_3, flow_bwd_3, time_step, downsample=2)
        f0_2 = torch.cat([f0_2, f0_2_], dim=1)
        f1_2 = torch.cat([f1_2, f1_2_], dim=1)
        # The update block sees the downsampled flows; its residual is added to
        # the decoder-resolution flows.
        f0_2, f1_2, flow_fwd_3, flow_bwd_3 = self._refine(
            self.update3_low, f0_2, f1_2, flow_fwd_3, flow_bwd_3, corr_fwd_3, corr_bwd_3,
            flow_fwd_3_, flow_bwd_3_)
        if self.model_size == 'G':
            corr_fwd_3 = resize(corr_fwd_3, scale_factor=2.0)
            corr_bwd_3 = resize(corr_bwd_3, scale_factor=2.0)
            f0_2, f1_2, flow_fwd_3, flow_bwd_3 = self._refine(
                self.update3_high, f0_2, f1_2, flow_fwd_3, flow_bwd_3, corr_fwd_3, corr_bwd_3)

        # ---- decoder2 + low/high refinement ----
        flow_fwd_2, flow_bwd_2, f0_1_, f1_1_ = self.decoder2(f0_2, f1_2, flow_fwd_3, flow_bwd_3)
        corr_fwd_2, corr_bwd_2, flow_fwd_2_, flow_bwd_2_ = self._corr_scale_lookup(
            corr_fn, coords, flow_fwd_2, flow_bwd_2, time_step, downsample=4)
        f0_1 = torch.cat([f0_1, f0_1_], dim=1)
        f1_1 = torch.cat([f1_1, f1_1_], dim=1)
        f0_1, f1_1, flow_fwd_2, flow_bwd_2 = self._refine(
            self.update2_low, f0_1, f1_1, flow_fwd_2, flow_bwd_2, corr_fwd_2, corr_bwd_2,
            flow_fwd_2_, flow_bwd_2_)
        if self.model_size == 'G':
            corr_fwd_2 = resize(corr_fwd_2, scale_factor=4.0)
            corr_bwd_2 = resize(corr_bwd_2, scale_factor=4.0)
            f0_1, f1_1, flow_fwd_2, flow_bwd_2 = self._refine(
                self.update2_high, f0_1, f1_1, flow_fwd_2, flow_bwd_2, corr_fwd_2, corr_bwd_2)

        # ---- decoder1: final multi-flow prediction ----
        flow_fwd_1, flow_bwd_1, mask_fwd, mask_bwd = self.decoder1(f0_1, f1_1, flow_fwd_2, flow_bwd_2)

        if scale_factor != 1.0:
            flow_fwd_1 = resize(flow_fwd_1, scale_factor=(1.0 / scale_factor)) * (1.0 / scale_factor)
            flow_bwd_1 = resize(flow_bwd_1, scale_factor=(1.0 / scale_factor)) * (1.0 / scale_factor)
            mask_fwd = resize(mask_fwd, scale_factor=(1.0 / scale_factor))
            mask_bwd = resize(mask_bwd, scale_factor=(1.0 / scale_factor))

        imgt_pred = multi_flow_combine(img0, img1, flow_fwd_1, flow_bwd_1,
                                       mask_fwd, mask_bwd, time_step, mean_)
        imgt_pred = torch.clamp(imgt_pred, 0, 1)
        imgt_pred = padder.unpad(imgt_pred)

        if eval:
            return {'imgt_pred': imgt_pred}

        flow_fwd_1 = flow_fwd_1.reshape(b, self.num_flows, 2, int(h / scale_factor), int(w / scale_factor))
        flow_bwd_1 = flow_bwd_1.reshape(b, self.num_flows, 2, int(h / scale_factor), int(w / scale_factor))
        return {
            'imgt_pred': imgt_pred,
            'flow0_pred': [flow_fwd_1 * 0.5, flow_fwd_2 * 0.5, flow_fwd_3 * 0.5, flow_fwd_4 * 0.5],
            'flow1_pred': [flow_bwd_1 * 0.5, flow_bwd_2 * 0.5, flow_bwd_3 * 0.5, flow_bwd_4 * 0.5],
            'flowfwd': flow_fwd_1[:, 0] * 0.5,
            'flowbwd': flow_bwd_1[:, 0] * 0.5
        }
|
|