|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from .backbone import CNNEncoder |
|
from .transformer import FeatureTransformer |
|
from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow, |
|
global_correlation_softmax_stereo, local_correlation_softmax_stereo, |
|
correlation_softmax_depth) |
|
from .attention import SelfAttnPropagation |
|
from .geometry import flow_warp, compute_flow_with_depth_pose |
|
from .reg_refine import BasicUpdateBlock |
|
from .utils import normalize_img, feature_add_position, upsample_flow_with_mask |
|
|
|
|
|
class UniMatch(nn.Module): |
|
def __init__(self, |
|
num_scales=1, |
|
feature_channels=128, |
|
upsample_factor=8, |
|
num_head=1, |
|
ffn_dim_expansion=4, |
|
num_transformer_layers=6, |
|
reg_refine=False, |
|
task='flow', |
|
): |
|
super(UniMatch, self).__init__() |
|
|
|
self.feature_channels = feature_channels |
|
self.num_scales = num_scales |
|
self.upsample_factor = upsample_factor |
|
self.reg_refine = reg_refine |
|
|
|
|
|
self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) |
|
|
|
|
|
self.transformer = FeatureTransformer(num_layers=num_transformer_layers, |
|
d_model=feature_channels, |
|
nhead=num_head, |
|
ffn_dim_expansion=ffn_dim_expansion, |
|
) |
|
|
|
|
|
self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels) |
|
|
|
if not self.reg_refine or task == 'depth': |
|
|
|
|
|
self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), |
|
nn.ReLU(inplace=True), |
|
nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0)) |
|
|
|
|
|
if reg_refine: |
|
|
|
self.refine_proj = nn.Conv2d(128, 256, 1) |
|
self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2, |
|
downsample_factor=upsample_factor, |
|
flow_dim=2 if task == 'flow' else 1, |
|
bilinear_up=task == 'depth', |
|
) |
|
|
|
def extract_feature(self, img0, img1): |
|
concat = torch.cat((img0, img1), dim=0) |
|
features = self.backbone(concat) |
|
|
|
|
|
features = features[::-1] |
|
|
|
feature0, feature1 = [], [] |
|
|
|
for i in range(len(features)): |
|
feature = features[i] |
|
chunks = torch.chunk(feature, 2, 0) |
|
feature0.append(chunks[0]) |
|
feature1.append(chunks[1]) |
|
|
|
return feature0, feature1 |
|
|
|
def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, |
|
is_depth=False): |
|
if bilinear: |
|
multiplier = 1 if is_depth else upsample_factor |
|
up_flow = F.interpolate(flow, scale_factor=upsample_factor, |
|
mode='bilinear', align_corners=True) * multiplier |
|
else: |
|
concat = torch.cat((flow, feature), dim=1) |
|
mask = self.upsampler(concat) |
|
up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor, |
|
is_depth=is_depth) |
|
|
|
return up_flow |
|
|
|
def forward(self, img0, img1, |
|
attn_type=None, |
|
attn_splits_list=None, |
|
corr_radius_list=None, |
|
prop_radius_list=None, |
|
num_reg_refine=1, |
|
pred_bidir_flow=False, |
|
task='flow', |
|
intrinsics=None, |
|
pose=None, |
|
min_depth=1. / 0.5, |
|
max_depth=1. / 10, |
|
num_depth_candidates=64, |
|
depth_from_argmax=False, |
|
pred_bidir_depth=False, |
|
first_scaling=None, |
|
**kwargs, |
|
): |
|
|
|
if 0.0 <= img0.max() <= 1.0: |
|
img0 = img0*255 |
|
img1 = img1*255 |
|
|
|
if first_scaling is not None: |
|
img0 = F.interpolate(img0, scale_factor=1/first_scaling, mode='bilinear') |
|
img1 = F.interpolate(img1, scale_factor=1/first_scaling, mode='bilinear') |
|
|
|
|
|
if pred_bidir_flow: |
|
assert task == 'flow' |
|
|
|
if task == 'depth': |
|
assert self.num_scales == 1 |
|
|
|
results_dict = {} |
|
flow_preds = [] |
|
|
|
if task == 'flow': |
|
|
|
img0, img1 = normalize_img(img0, img1) |
|
|
|
|
|
feature0_list, feature1_list = self.extract_feature(img0, img1) |
|
|
|
flow = None |
|
|
|
if task != 'depth': |
|
assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales |
|
else: |
|
assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1 |
|
|
|
for scale_idx in range(self.num_scales): |
|
feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] |
|
|
|
if pred_bidir_flow and scale_idx > 0: |
|
|
|
feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) |
|
|
|
feature0_ori, feature1_ori = feature0, feature1 |
|
|
|
upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) |
|
|
|
if task == 'depth': |
|
|
|
intrinsics_curr = intrinsics.clone() |
|
intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor |
|
|
|
if scale_idx > 0: |
|
assert task != 'depth' |
|
flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 |
|
|
|
if flow is not None: |
|
assert task != 'depth' |
|
flow = flow.detach() |
|
|
|
if task == 'stereo': |
|
|
|
|
|
zeros = torch.zeros_like(flow) |
|
|
|
displace = torch.cat((-flow, zeros), dim=1) |
|
feature1 = flow_warp(feature1, displace) |
|
elif task == 'flow': |
|
feature1 = flow_warp(feature1, flow) |
|
else: |
|
raise NotImplementedError |
|
|
|
attn_splits = attn_splits_list[scale_idx] |
|
if task != 'depth': |
|
corr_radius = corr_radius_list[scale_idx] |
|
prop_radius = prop_radius_list[scale_idx] |
|
|
|
|
|
feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) |
|
|
|
|
|
feature0, feature1 = self.transformer(feature0, feature1, |
|
attn_type=attn_type, |
|
attn_num_splits=attn_splits, |
|
) |
|
|
|
|
|
if task == 'depth': |
|
|
|
b, _, h, w = feature0.size() |
|
depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0) |
|
depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h, |
|
w) |
|
|
|
flow_pred = correlation_softmax_depth(feature0, feature1, |
|
intrinsics_curr, |
|
pose, |
|
depth_candidates=depth_candidates, |
|
depth_from_argmax=depth_from_argmax, |
|
pred_bidir_depth=pred_bidir_depth, |
|
)[0] |
|
|
|
else: |
|
if corr_radius == -1: |
|
if task == 'flow': |
|
flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] |
|
elif task == 'stereo': |
|
flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0] |
|
else: |
|
raise NotImplementedError |
|
else: |
|
if task == 'flow': |
|
flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] |
|
elif task == 'stereo': |
|
flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0] |
|
else: |
|
raise NotImplementedError |
|
|
|
|
|
flow = flow + flow_pred if flow is not None else flow_pred |
|
|
|
if task == 'stereo': |
|
flow = flow.clamp(min=0) |
|
|
|
|
|
if self.training: |
|
flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor, |
|
is_depth=task == 'depth') |
|
flow_preds.append(flow_bilinear) |
|
|
|
|
|
if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0: |
|
feature0 = torch.cat((feature0, feature1), dim=0) |
|
|
|
flow = self.feature_flow_attn(feature0, flow.detach(), |
|
local_window_attn=prop_radius > 0, |
|
local_window_radius=prop_radius, |
|
) |
|
|
|
|
|
if self.training and scale_idx < self.num_scales - 1: |
|
flow_up = self.upsample_flow(flow, feature0, bilinear=True, |
|
upsample_factor=upsample_factor, |
|
is_depth=task == 'depth') |
|
flow_preds.append(flow_up) |
|
|
|
if scale_idx == self.num_scales - 1: |
|
if not self.reg_refine: |
|
|
|
|
|
if task == 'stereo': |
|
flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1) |
|
flow_up_pad = self.upsample_flow(flow_pad, feature0) |
|
flow_up = -flow_up_pad[:, :1] |
|
elif task == 'depth': |
|
depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) |
|
depth_up_pad = self.upsample_flow(depth_pad, feature0, |
|
is_depth=True).clamp(min=min_depth, max=max_depth) |
|
flow_up = depth_up_pad[:, :1] |
|
else: |
|
flow_up = self.upsample_flow(flow, feature0) |
|
|
|
flow_preds.append(flow_up) |
|
else: |
|
|
|
|
|
if self.training: |
|
flow_up = self.upsample_flow(flow, feature0, bilinear=True, |
|
upsample_factor=upsample_factor, |
|
is_depth=task == 'depth') |
|
flow_preds.append(flow_up) |
|
|
|
assert num_reg_refine > 0 |
|
for refine_iter_idx in range(num_reg_refine): |
|
flow = flow.detach() |
|
|
|
if task == 'stereo': |
|
zeros = torch.zeros_like(flow) |
|
|
|
displace = torch.cat((-flow, zeros), dim=1) |
|
correlation = local_correlation_with_flow( |
|
feature0_ori, |
|
feature1_ori, |
|
flow=displace, |
|
local_radius=4, |
|
) |
|
elif task == 'depth': |
|
if pred_bidir_depth and refine_iter_idx == 0: |
|
intrinsics_curr = intrinsics_curr.repeat(2, 1, 1) |
|
pose = torch.cat((pose, torch.inverse(pose)), dim=0) |
|
|
|
feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori), |
|
dim=0), torch.cat((feature1_ori, |
|
feature0_ori), dim=0) |
|
|
|
flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1), |
|
intrinsics_curr, |
|
extrinsics_rel=pose, |
|
) |
|
|
|
correlation = local_correlation_with_flow( |
|
feature0_ori, |
|
feature1_ori, |
|
flow=flow_from_depth, |
|
local_radius=4, |
|
) |
|
|
|
else: |
|
correlation = local_correlation_with_flow( |
|
feature0_ori, |
|
feature1_ori, |
|
flow=flow, |
|
local_radius=4, |
|
) |
|
|
|
proj = self.refine_proj(feature0) |
|
|
|
net, inp = torch.chunk(proj, chunks=2, dim=1) |
|
|
|
net = torch.tanh(net) |
|
inp = torch.relu(inp) |
|
|
|
net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(), |
|
) |
|
|
|
if task == 'depth': |
|
flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth) |
|
else: |
|
flow = flow + residual_flow |
|
|
|
if task == 'stereo': |
|
flow = flow.clamp(min=0) |
|
|
|
if self.training or refine_iter_idx == num_reg_refine - 1: |
|
if task == 'depth': |
|
if refine_iter_idx < num_reg_refine - 1: |
|
|
|
flow_up = self.upsample_flow(flow, feature0, bilinear=True, |
|
upsample_factor=upsample_factor, |
|
is_depth=True) |
|
else: |
|
|
|
|
|
|
|
depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) |
|
depth_up_pad = self.upsample_flow(depth_pad, feature0, |
|
is_depth=True).clamp(min=min_depth, |
|
max=max_depth) |
|
flow_up = depth_up_pad[:, :1] |
|
|
|
else: |
|
flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor, |
|
is_depth=task == 'depth') |
|
|
|
flow_preds.append(flow_up) |
|
|
|
if first_scaling is not None: |
|
for i in range(len(flow_preds)): |
|
flow_preds[i] = F.interpolate(flow_preds[i], scale_factor=first_scaling, mode='bilinear') |
|
|
|
if task == 'stereo': |
|
for i in range(len(flow_preds)): |
|
flow_preds[i] = flow_preds[i].squeeze(1) |
|
|
|
|
|
if task == 'depth': |
|
for i in range(len(flow_preds)): |
|
flow_preds[i] = 1. / flow_preds[i].squeeze(1) |
|
|
|
results_dict.update({'flow_preds': flow_preds}) |
|
|
|
return results_dict |
|
|