import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import CNNEncoder
from .transformer import FeatureTransformer
from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow,
global_correlation_softmax_stereo, local_correlation_softmax_stereo,
correlation_softmax_depth)
from .attention import SelfAttnPropagation
from .geometry import flow_warp, compute_flow_with_depth_pose
from .reg_refine import BasicUpdateBlock
from .utils import normalize_img, feature_add_position, upsample_flow_with_mask
class UniMatch(nn.Module):
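    """UniMatch: a unified dense matching model for optical flow, stereo
    disparity, and depth estimation. The CNN backbone, Transformer, and
    self-attention propagation are shared and task-agnostic; the task is
    selected with `task` ('flow', 'stereo', or 'depth'), with an optional
    task-specific local regression refinement (`reg_refine`).
    """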
def __init__(self,
num_scales=1,
feature_channels=128,
upsample_factor=8,
num_head=1,
ffn_dim_expansion=4,
num_transformer_layers=6,
reg_refine=False, # optional local regression refinement
task='flow',
):
        super().__init__()
self.feature_channels = feature_channels
self.num_scales = num_scales
self.upsample_factor = upsample_factor
self.reg_refine = reg_refine
# CNN
self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales)
# Transformer
self.transformer = FeatureTransformer(num_layers=num_transformer_layers,
d_model=feature_channels,
nhead=num_head,
ffn_dim_expansion=ffn_dim_expansion,
)
# propagation with self-attn
self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels)
if not self.reg_refine or task == 'depth':
            # convex upsampling similar to RAFT
            # concat feature0 and low-res flow as input
self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
nn.ReLU(inplace=True),
                                           nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0))  # 9 convex weights per upsampled pixel
# thus far, all the learnable parameters are task-agnostic
if reg_refine:
# optional task-specific local regression refinement
self.refine_proj = nn.Conv2d(128, 256, 1)
self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2,
downsample_factor=upsample_factor,
flow_dim=2 if task == 'flow' else 1,
bilinear_up=task == 'depth',
)
def extract_feature(self, img0, img1):
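        """Extract per-scale CNN features for both images in a single backbone
        pass by concatenating them along the batch dimension, then split them
        back into two lists ordered from low to high resolution.
        """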
concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W]
features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low
# reverse: resolution from low to high
features = features[::-1]
feature0, feature1 = [], []
for i in range(len(features)):
feature = features[i]
chunks = torch.chunk(feature, 2, 0) # tuple
feature0.append(chunks[0])
feature1.append(chunks[1])
return feature0, feature1
def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8,
is_depth=False):
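        """Upsample a low-resolution prediction to full resolution, either by
        bilinear interpolation or with the learned convex upsampler. Flow and
        disparity are displacements in pixels, so they are scaled by the
        upsample factor; depth values (is_depth=True) are not scaled.
        """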
if bilinear:
multiplier = 1 if is_depth else upsample_factor
up_flow = F.interpolate(flow, scale_factor=upsample_factor,
mode='bilinear', align_corners=True) * multiplier
else:
concat = torch.cat((flow, feature), dim=1)
mask = self.upsampler(concat)
up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor,
is_depth=is_depth)
return up_flow
def forward(self, img0, img1,
attn_type=None,
attn_splits_list=None,
corr_radius_list=None,
prop_radius_list=None,
num_reg_refine=1,
pred_bidir_flow=False,
task='flow',
intrinsics=None,
pose=None, # relative pose transform
                min_depth=1. / 0.5,  # inverse depth range (depth is represented internally as inverse depth)
                max_depth=1. / 10,
num_depth_candidates=64,
depth_from_argmax=False,
pred_bidir_depth=False,
first_scaling=None,
**kwargs,
):
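        """Run the model for the given task. Returns a dict with key
        'flow_preds': a list of full-resolution predictions (intermediate ones
        are included during training), where the last entry is the final one.
        Predictions are [B, 2, H, W] flow for 'flow', and [B, H, W] disparity
        or depth for 'stereo' and 'depth' (depth is estimated internally as
        inverse depth and converted before returning).
        """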
        # rescale images normalized to [0, 1] back to [0, 255]
        if 0.0 <= img0.max() <= 1.0:
            img0 = img0 * 255
            img1 = img1 * 255
if first_scaling is not None:
img0 = F.interpolate(img0, scale_factor=1/first_scaling, mode='bilinear')
img1 = F.interpolate(img1, scale_factor=1/first_scaling, mode='bilinear')
if pred_bidir_flow:
assert task == 'flow'
if task == 'depth':
assert self.num_scales == 1 # multi-scale depth model is not supported yet
results_dict = {}
flow_preds = []
if task == 'flow':
# stereo and depth tasks have normalized img in dataloader
img0, img1 = normalize_img(img0, img1) # [B, 3, H, W]
        # extract CNN features; list of features, resolution from low to high
feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features
flow = None
if task != 'depth':
assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales
else:
assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1
for scale_idx in range(self.num_scales):
feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx]
if pred_bidir_flow and scale_idx > 0:
                # predict bidirectional flow by stacking forward and backward feature pairs along the batch dimension
feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0)
feature0_ori, feature1_ori = feature0, feature1
upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx))
if task == 'depth':
# scale intrinsics
intrinsics_curr = intrinsics.clone()
intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor
if scale_idx > 0:
assert task != 'depth' # not supported for multi-scale depth model
flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2
if flow is not None:
assert task != 'depth'
                flow = flow.detach()  # detach so the warping below does not backprop into earlier-scale predictions
if task == 'stereo':
# construct flow vector for disparity
# flow here is actually disparity
zeros = torch.zeros_like(flow) # [B, 1, H, W]
                    # NOTE: negate the disparity for warping, since disparity is defined positive
displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W]
feature1 = flow_warp(feature1, displace) # [B, C, H, W]
elif task == 'flow':
feature1 = flow_warp(feature1, flow) # [B, C, H, W]
else:
raise NotImplementedError
attn_splits = attn_splits_list[scale_idx]
if task != 'depth':
corr_radius = corr_radius_list[scale_idx]
prop_radius = prop_radius_list[scale_idx]
# add position to features
feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels)
# Transformer
feature0, feature1 = self.transformer(feature0, feature1,
attn_type=attn_type,
attn_num_splits=attn_splits,
)
# correlation and softmax
if task == 'depth':
                # first generate inverse-depth candidates, uniformly sampled between min_depth and max_depth
b, _, h, w = feature0.size()
depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0)
                depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h, w)  # [B, D, H, W]
flow_pred = correlation_softmax_depth(feature0, feature1,
intrinsics_curr,
pose,
depth_candidates=depth_candidates,
depth_from_argmax=depth_from_argmax,
pred_bidir_depth=pred_bidir_depth,
)[0]
else:
if corr_radius == -1: # global matching
if task == 'flow':
flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0]
elif task == 'stereo':
flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0]
else:
raise NotImplementedError
else: # local matching
if task == 'flow':
flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0]
elif task == 'stereo':
flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0]
else:
raise NotImplementedError
# flow or residual flow
flow = flow + flow_pred if flow is not None else flow_pred
if task == 'stereo':
flow = flow.clamp(min=0) # positive disparity
            # upsample to the original resolution for supervision, at training time only
if self.training:
flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor,
is_depth=task == 'depth')
flow_preds.append(flow_bilinear)
# flow propagation with self-attn
if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0:
feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation
flow = self.feature_flow_attn(feature0, flow.detach(),
local_window_attn=prop_radius > 0,
local_window_radius=prop_radius,
)
            # bilinear upsampling at all scales except the last (the last uses convex upsampling)
if self.training and scale_idx < self.num_scales - 1:
flow_up = self.upsample_flow(flow, feature0, bilinear=True,
upsample_factor=upsample_factor,
is_depth=task == 'depth')
flow_preds.append(flow_up)
if scale_idx == self.num_scales - 1:
if not self.reg_refine:
# upsample to the original image resolution
                    if task == 'stereo':
                        # pad the negated disparity to 2 channels so the flow upsampler can be reused
                        flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1)  # [B, 2, H, W]
                        flow_up_pad = self.upsample_flow(flow_pad, feature0)
                        flow_up = -flow_up_pad[:, :1]  # [B, 1, H, W]
                    elif task == 'depth':
                        # pad depth to 2 channels to match the flow layout for the upsampler
                        depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1)  # [B, 2, H, W]
                        depth_up_pad = self.upsample_flow(depth_pad, feature0,
                                                          is_depth=True).clamp(min=min_depth, max=max_depth)
                        flow_up = depth_up_pad[:, :1]  # [B, 1, H, W]
else:
flow_up = self.upsample_flow(flow, feature0)
flow_preds.append(flow_up)
else:
# task-specific local regression refinement
# supervise current flow
if self.training:
flow_up = self.upsample_flow(flow, feature0, bilinear=True,
upsample_factor=upsample_factor,
is_depth=task == 'depth')
flow_preds.append(flow_up)
assert num_reg_refine > 0
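                    # iterative refinement: at each step, build a local cost volume
                    # around the current estimate and regress a residual update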
for refine_iter_idx in range(num_reg_refine):
flow = flow.detach()
if task == 'stereo':
zeros = torch.zeros_like(flow) # [B, 1, H, W]
                            # NOTE: negate the disparity for warping, since disparity is defined positive
displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W]
correlation = local_correlation_with_flow(
feature0_ori,
feature1_ori,
flow=displace,
local_radius=4,
) # [B, (2R+1)^2, H, W]
elif task == 'depth':
if pred_bidir_depth and refine_iter_idx == 0:
intrinsics_curr = intrinsics_curr.repeat(2, 1, 1)
pose = torch.cat((pose, torch.inverse(pose)), dim=0)
                                feature0_ori, feature1_ori = (torch.cat((feature0_ori, feature1_ori), dim=0),
                                                              torch.cat((feature1_ori, feature0_ori), dim=0))
flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1),
intrinsics_curr,
extrinsics_rel=pose,
)
correlation = local_correlation_with_flow(
feature0_ori,
feature1_ori,
flow=flow_from_depth,
local_radius=4,
) # [B, (2R+1)^2, H, W]
else:
correlation = local_correlation_with_flow(
feature0_ori,
feature1_ori,
flow=flow,
local_radius=4,
) # [B, (2R+1)^2, H, W]
proj = self.refine_proj(feature0)
net, inp = torch.chunk(proj, chunks=2, dim=1)
net = torch.tanh(net)
inp = torch.relu(inp)
net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(),
)
if task == 'depth':
flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth)
else:
flow = flow + residual_flow
if task == 'stereo':
                                flow = flow.clamp(min=0)  # positive disparity
if self.training or refine_iter_idx == num_reg_refine - 1:
if task == 'depth':
if refine_iter_idx < num_reg_refine - 1:
# bilinear upsampling
flow_up = self.upsample_flow(flow, feature0, bilinear=True,
upsample_factor=upsample_factor,
is_depth=True)
else:
                                    # the last iteration uses convex upsampling
                                    # NOTE: clamp depth because of the zero padding in the unfold used by convex upsampling
                                    # pad depth to 2 channels to match the flow layout
depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W]
depth_up_pad = self.upsample_flow(depth_pad, feature0,
is_depth=True).clamp(min=min_depth,
max=max_depth)
flow_up = depth_up_pad[:, :1] # [B, 1, H, W]
else:
flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor,
is_depth=task == 'depth')
flow_preds.append(flow_up)
        if first_scaling is not None:
            # NOTE: predictions are resized back spatially; flow/disparity magnitudes are not rescaled here
            for i in range(len(flow_preds)):
                flow_preds[i] = F.interpolate(flow_preds[i], scale_factor=first_scaling, mode='bilinear')
if task == 'stereo':
for i in range(len(flow_preds)):
flow_preds[i] = flow_preds[i].squeeze(1) # [B, H, W]
# convert inverse depth to depth
if task == 'depth':
for i in range(len(flow_preds)):
flow_preds[i] = 1. / flow_preds[i].squeeze(1) # [B, H, W]
results_dict.update({'flow_preds': flow_preds})
return results_dict
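# A minimal usage sketch (not part of the original file). Since this module uses
# relative imports, it must be run from within its package. attn_type='swin' and
# the single-scale per-scale lists below are illustrative assumptions based on
# the forward() signature; corr_radius=-1 selects global matching and
# prop_radius=-1 selects global (non-local-window) propagation.
if __name__ == '__main__':
    model = UniMatch(num_scales=1, feature_channels=128, upsample_factor=8, task='flow')
    model.eval()
    img0 = torch.rand(1, 3, 384, 512) * 255  # [B, 3, H, W], values in [0, 255]
    img1 = torch.rand(1, 3, 384, 512) * 255
    with torch.no_grad():
        results = model(img0, img1,
                        attn_type='swin',
                        attn_splits_list=[2],
                        corr_radius_list=[-1],
                        prop_radius_list=[-1],
                        task='flow')
    flow = results['flow_preds'][-1]  # [B, 2, H, W], final flow prediction
    print(flow.shape)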