|
|
|
import math |
|
import torch |
|
import typing |
|
|
|
from ..components import register |
|
from .backwarp import * |
|
from .pwcnet import * |
|
from .softsplat import * |
|
from .unimatch.unimatch import UniMatch |
|
|
|
|
|
def photometric_consistency(img0, img1, flow01):
    """Return the per-pixel L1 difference between ``img0`` and warped ``img1``.

    ``img1`` is resampled with ``backwarp`` using ``flow01``; the absolute
    difference to ``img0`` is summed over dim 1, which is kept with size 1.
    """
    warped_img1 = backwarp(img1, flow01)
    residual = img0 - warped_img1
    return residual.abs().sum(dim=1, keepdims=True)
|
|
|
|
|
def flow_consistency(flow01, flow10):
    """Return the per-pixel L1 forward/backward flow disagreement.

    ``flow10`` is resampled with ``backwarp`` using ``flow01`` and added to
    ``flow01``; the absolute value of that sum is accumulated over dim 1,
    which is kept with size 1.
    """
    warped_flow10 = backwarp(flow10, flow01)
    round_trip = flow01 + warped_flow10
    return round_trip.abs().sum(dim=1, keepdims=True)
|
|
|
|
|
|
|
|
|
def gaussian(x):
    """Blur every channel of ``x`` independently with a 3x3 binomial kernel.

    The input is reflect-padded by one pixel on each side, so the output has
    the same spatial size as the input. The kernel is created with the dtype
    and device of ``x`` and replicated once per channel, so the depthwise
    convolution works for any channel count and on any device.

    Args:
        x: 4-D tensor indexed as (batch, channels, height, width). Reflect
            padding requires height and width of at least 2.

    Returns:
        Tensor with the same shape, dtype and device as ``x``.
    """
    channels = x.shape[1]
    kernel = torch.tensor([[1.0, 2.0, 1.0],
                           [2.0, 4.0, 2.0],
                           [1.0, 2.0, 1.0]], dtype=x.dtype, device=x.device) / 16
    # One (1, 3, 3) filter per channel; groups=channels makes the conv depthwise.
    kernel = kernel.repeat(channels, 1, 1, 1)
    padded = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect')
    return torch.nn.functional.conv2d(padded, kernel, groups=channels)
|
|
|
|
|
def variance_flow(flow):
    """Return a per-pixel local standard-deviation measure of ``flow``.

    Channel 0 is scaled by ``2 / (W - 1)`` and channel 1 by ``2 / (H - 1)``
    (presumably the x and y components; confirm against the flow convention
    in use). The local variance is then estimated as
    ``gaussian(flow ** 2) - gaussian(flow) ** 2``; ``1e-4`` is added before
    the square root, and the result is summed over dim 1, kept with size 1.
    """
    height, width = flow.shape[2], flow.shape[3]
    scale = torch.tensor(data=[2.0 / (width - 1.0), 2.0 / (height - 1.0)],
                         dtype=flow.dtype, device=flow.device)
    flow = flow * scale.view(1, 2, 1, 1)

    mean_of_squares = gaussian(flow ** 2)
    square_of_mean = gaussian(flow) ** 2
    local_std = (mean_of_squares - square_of_mean + 1e-4).sqrt()
    return local_std.abs().sum(dim=1, keepdim=True)
|
|
|
|
|
|
|
def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None):
    """Forward-splat two stacks of inputs and blend them into one output.

    Every argument except the metrics is indexed along its first dimension;
    entry ``idx`` of each stack is splatted with ``softsplat_func`` and the
    results of all entries and both directions are accumulated.

    Args:
        tenIn1, tenIn2: stacks of tensors to splat.
        tenFlow1, tenFlow2: stacks of flows; ``tenFlow1.shape[0]`` sets the
            number of entries that are processed.
        t1, t2: stacks of per-entry factors multiplied into the splat weights.
        tenMetric1, tenMetric2: optional stacks of log-weights. Each metric is
            clipped to ``[-20, 20]`` and exponentiated. ``None`` is treated as
            a zero metric, i.e. a weight of ``exp(0) == 1`` everywhere.

    Returns:
        A tuple ``(blended, hole_mask)``: the accumulated splat divided by the
        accumulated weight, and a boolean tensor that is true where the
        accumulated weight is below ``0.00001``.
    """

    def one_fdir(tenIn, tenFlow, td, tenMetric):
        # Splat the weighted input together with the weight itself so the
        # caller can normalise afterwards.
        if tenMetric is None:
            tenWeight = torch.ones_like(tenIn[:, :1, :, :])
        else:
            tenWeight = tenMetric.clip(-20.0, 20.0).exp()

        tenIn = torch.cat([tenIn * td * tenWeight, td * tenWeight], 1)
        tenOut = softsplat_func.apply(tenIn, tenFlow)

        # The last channel carries the splatted weight; the small constant
        # keeps the later division finite where nothing landed.
        return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001

    flow_num = tenFlow1.shape[0]
    tenOut = 0
    tenNormalize = 0
    for idx in range(flow_num):
        metric1 = None if tenMetric1 is None else tenMetric1[idx]
        metric2 = None if tenMetric2 is None else tenMetric2[idx]

        tenOutF, tenNormalizeF = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], metric1)
        tenOutB, tenNormalizeB = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], metric2)

        tenOut += tenOutF + tenOutB
        tenNormalize += tenNormalizeF + tenNormalizeB

    return tenOut / tenNormalize, tenNormalize < 0.00001
|
|
|
|
|
|
|
|
|
# Base channel width: ImgPyramid and EncDec size their layers as multiples of c.
c = 16
|
|
|
|
|
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased ``Conv2d`` followed by a per-channel ``PReLU``.

    Returns a ``torch.nn.Sequential`` whose element 0 is the convolution and
    whose element 1 is a ``PReLU`` with ``out_planes`` learnable slopes.
    """
    convolution = torch.nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    activation = torch.nn.PReLU(out_planes)
    return torch.nn.Sequential(convolution, activation)
|
|
|
|
|
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Build a biased ``ConvTranspose2d`` followed by a per-channel ``PReLU``.

    With the default ``kernel_size=4, stride=2, padding=1`` the block doubles
    the spatial size of its input.

    Returns a ``torch.nn.Sequential`` whose element 0 is the transposed
    convolution and whose element 1 is a ``PReLU`` with ``out_planes`` slopes.
    """
    return torch.nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
                                 kernel_size=kernel_size, stride=stride, padding=padding, bias=True),
        torch.nn.PReLU(out_planes)
    )
|
|
|
|
|
class Conv2(torch.nn.Module):
    """Two chained 3x3 ``conv`` blocks; only the first one applies ``stride``."""

    def __init__(self, in_planes, out_planes, stride=2):
        super().__init__()
        # First block changes the channel count and (by default) halves the resolution.
        self.conv1 = conv(in_planes, out_planes, kernel_size=3, stride=stride, padding=1)
        # Second block keeps both channel count and resolution.
        self.conv2 = conv(out_planes, out_planes, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Apply ``conv1`` then ``conv2`` to ``x``."""
        return self.conv2(self.conv1(x))
|
|
|
|
|
class Conv2n(torch.nn.Module):
    """Four chained ``conv`` blocks: two 3x3 blocks followed by two 1x1 blocks.

    The first block applies ``stride``; only the last block changes the
    channel count from ``in_planes`` to ``out_planes``.
    """

    def __init__(self, in_planes, out_planes, stride=2):
        super().__init__()
        self.conv1 = conv(in_planes, in_planes, kernel_size=3, stride=stride, padding=1)
        self.conv2 = conv(in_planes, in_planes, kernel_size=3, stride=1, padding=1)
        self.conv3 = conv(in_planes, in_planes, kernel_size=1, stride=1, padding=0)
        self.conv4 = conv(in_planes, out_planes, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Run ``x`` through ``conv1`` .. ``conv4`` in order."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
        return x
|
|
|
|
|
|
|
|
|
class ImgPyramid(torch.nn.Module):
    """Four-level feature pyramid built from chained ``Conv2`` stages.

    The stages map 3 -> c -> 2c -> 4c -> 8c channels; each ``Conv2`` uses its
    default stride of 2.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = Conv2(3, c)
        self.conv2 = Conv2(c, 2 * c)
        self.conv3 = Conv2(2 * c, 4 * c)
        self.conv4 = Conv2(4 * c, 8 * c)

    def forward(self, x):
        """Return the outputs of all four stages as a list, first stage first."""
        levels = []
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = stage(x)
            levels.append(x)
        return levels
|
|
|
|
|
class EncDec(torch.nn.Module):
    """Encoder/decoder that predicts flow residuals and a weight map.

    Both directions (0 and 1) are processed with the same shared layers. At
    every encoder level the features of one direction are concatenated with
    the features of the other direction warped by ``backwarp``. The deepest
    features are gated by a channel x height x width attention cube before
    decoding.
    """

    def __init__(self, branch):
        super().__init__()
        self.branch = branch

        # Encoder. down0 consumes the concatenation of a flow, an image and a
        # warped image (8 channels in total).
        self.down0 = Conv2(8, 2 * c)
        self.down1 = Conv2(6 * c, 4 * c)
        self.down2 = Conv2(12 * c, 8 * c)
        self.down3 = Conv2(24 * c, 16 * c)

        # Decoder with skip connections from the encoder.
        self.up0 = deconv(48 * c, 8 * c)
        self.up1 = deconv(16 * c, 4 * c)
        self.up2 = deconv(8 * c, 2 * c)
        self.up3 = deconv(4 * c, c)

        # Output heads: 2 * branch flow channels and a single-channel weight map.
        self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1)
        self.conv_m = torch.nn.Conv2d(c, 1, 3, 1, 1)

        # Attention over channels: global pooling, then 16 gates per channel.
        self.conv_C = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d(1),
            torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
            torch.nn.Sigmoid()
        )

        # Attention over rows: pool the width away, then 16 gates per row.
        self.conv_H = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d((None, 1)),
            torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
            torch.nn.Sigmoid()
        )

        # Attention over columns: pool the height away, then 16 gates per column.
        self.conv_W = torch.nn.Sequential(
            torch.nn.AdaptiveAvgPool2d((1, None)),
            torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True),
            torch.nn.Sigmoid()
        )

        self.sigmoid = torch.nn.Sigmoid()

    @staticmethod
    def _half_res(flow):
        """Bilinearly downsample ``flow`` by 2 and halve its values."""
        return torch.nn.functional.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5

    def _cube_gate(self, feat, batch):
        """Scale ``feat`` by the mean of 16 channel x height x width gate products."""
        gate_c = self.conv_C(feat).view(batch, 16, -1, 1, 1)
        gate_h = self.conv_H(feat).view(batch, 16, 1, -1, 1)
        gate_w = self.conv_W(feat).view(batch, 16, 1, 1, -1)
        cube = (gate_c * gate_h * gate_w).mean(1)
        return feat * cube

    def forward(self, flow0, flow1, im0, im1, c0, c1):
        """Return ``(x0, x1, m0, m1)`` for the two directions.

        ``x0``/``x1`` are the outputs of the ``2 * branch``-channel head and
        ``m0``/``m1`` are sigmoid maps rescaled to ``[0.1, 0.9]`` and repeated
        ``branch`` times along dim 1. ``c0``/``c1`` are indexed with 0..3 and
        concatenated with the encoder features at the matching level.
        """
        batch, _, _, _ = im0.shape

        # Level 0: each image together with the other image warped onto it.
        warped_im1 = backwarp(im1, flow0)
        warped_im0 = backwarp(im0, flow1)
        skips0 = [self.down0(torch.cat((flow0, im0, warped_im1), 1))]
        skips1 = [self.down0(torch.cat((flow1, im1, warped_im0), 1))]

        # Levels 1-3: halve the flows, exchange warped features, downsample.
        for level, down in enumerate((self.down1, self.down2, self.down3)):
            flow0 = self._half_res(flow0)
            flow1 = self._half_res(flow1)

            feat0 = torch.cat((skips0[-1], c0[level]), 1)
            feat1 = torch.cat((skips1[-1], c1[level]), 1)

            warped_feat0 = backwarp(feat0, flow1)
            warped_feat1 = backwarp(feat1, flow0)

            skips0.append(down(torch.cat((feat0, warped_feat1), 1)))
            skips1.append(down(torch.cat((feat1, warped_feat0), 1)))

        # Gate the deepest features of both directions.
        deep0 = self._cube_gate(skips0[3], batch)
        deep1 = self._cube_gate(skips1[3], batch)

        flow0 = self._half_res(flow0)
        flow1 = self._half_res(flow1)

        feat0 = torch.cat((deep0, c0[3]), 1)
        feat1 = torch.cat((deep1, c1[3]), 1)

        warped_feat0 = backwarp(feat0, flow1)
        warped_feat1 = backwarp(feat1, flow0)

        x0 = self.up0(torch.cat((feat0, warped_feat1), 1))
        x1 = self.up0(torch.cat((feat1, warped_feat0), 1))

        # Decode, concatenating the ungated encoder outputs as skip connections.
        for level, up in zip((2, 1, 0), (self.up1, self.up2, self.up3)):
            x0 = up(torch.cat((skips0[level], x0), 1))
            x1 = up(torch.cat((skips1[level], x1), 1))

        m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1
        m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1

        x0 = self.conv(x0)
        x1 = self.conv(x1)

        return x0, x1, m0.repeat(1, self.branch, 1, 1), m1.repeat(1, self.branch, 1, 1)
|
|
|
|
|
@register('m2m_unimatch')
class M2M_PWC(torch.nn.Module):
    """Synthesise an intermediate frame from two images by forward splatting.

    A frozen ``UniMatch`` network estimates flow on inputs downscaled by
    ``ratio``. ``MotionRefineNet`` upsamples each estimate and expands it into
    ``branch`` flow fields plus weight maps per direction. Both input images
    are then splatted to the requested time step with ``forwarp_mframe_mask``.
    """

    def __init__(self, ratio=4):
        super().__init__()
        self.branch = 4
        self.ratio = ratio

        self.netFlow = UniMatch(num_scales=2, feature_channels=128, upsample_factor=4,
                                num_head=1, ffn_dim_expansion=4, num_transformer_layers=6,
                                reg_refine=True, task='flow')
        # The flow estimator is frozen; only the refinement network and the
        # alpha parameters below receive gradients.
        for p in self.netFlow.parameters():
            p.requires_grad = False

        class MotionRefineNet(torch.nn.Module):
            """Upsample coarse flows and add ``branch`` learned residuals to each."""

            def __init__(self, branch):
                super().__init__()
                self.branch = branch
                self.img_pyramid = ImgPyramid()
                self.motion_encdec = EncDec(branch)

            def forward(self, flow0, flow1, im0, im1, ratio):
                # Bring the flows to the image resolution; the values are
                # multiplied by the same factor as the spatial size.
                flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear',
                                                                align_corners=False)
                flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear',
                                                                align_corners=False)

                c0 = self.img_pyramid(im0)
                c1 = self.img_pyramid(im1)

                flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1)

                # Every branch is the upsampled flow plus its own residual.
                flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0]
                flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1]

                return flow0, flow1, flow_res[2], flow_res[3]

        self.MRN = MotionRefineNet(self.branch)

        self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1))
        self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1))
        self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1))
        self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1))

    def get_splat_weight(self, img0, img1, flow01, flow10):
        """Combine three reliability cues into one splatting metric.

        Each cue (photometric error, flow inconsistency, local flow deviation)
        is detached and mapped through ``1 / (1 + alpha_i * cue)``. The three
        terms are summed and scaled by ``self.alpha``.
        """
        M_splat = (
            1 / (1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach())
            + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach())
            + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach())
        )
        return M_splat * self.alpha

    def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs):
        """Predict the frame at ``time_step`` between ``img0`` and ``img1``.

        Args:
            img0, img1: 3-channel image batches of identical size (the code
                reshapes them to ``(N, branch, 3, H, W)``).
            time_step: a tensor that is repeated along dim 1 and reshaped to
                one value per batch entry and branch (presumably shaped
                ``(N, 1, 1, 1)`` - confirm against the caller). A number or
                sequence is also accepted and converted to such a tensor; a
                single value is shared by the whole batch. The list default
                is never mutated.
            ratio: downscale factor for flow estimation; ``self.ratio`` when
                ``None``.
            **kwargs: accepted and ignored.

        Returns:
            dict with keys ``'flowfwd'``, ``'flowbwd'`` (last flow estimate,
            upsampled, cropped and detached), ``'imgt_preds'`` (one image per
            flow estimate), ``'imgt_pred'`` (the last of them),
            ``'flow0_pred'`` and ``'flow1_pred'`` (flow lists, reversed).
        """
        if ratio is None:
            ratio = self.ratio

        # ``and`` yields img1's size whenever img0's is non-zero; both frames
        # are expected to share one resolution.
        intWidth = img0.shape[3] and img1.shape[3]
        intHeight = img0.shape[2] and img1.shape[2]

        # Pad right/bottom so both sides become multiples of ratio * 16.
        intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16)
        intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16)

        img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate')
        img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate')

        N_, _, H_, W_ = img0.shape

        # The code below calls tensor methods on time_step, so plain numbers
        # and sequences (including the default) are converted first. Tensor
        # inputs are left exactly as they were passed in.
        if not torch.is_tensor(time_step):
            time_step = torch.as_tensor(time_step, dtype=img0.dtype, device=img0.device).reshape(-1, 1, 1, 1)
            if time_step.shape[0] == 1:
                time_step = time_step.expand(N_, 1, 1, 1)

        outputs = []
        result_dict = {}

        im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=1.0 / ratio, mode='bilinear',
                                               align_corners=False)
        im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=1.0 / ratio, mode='bilinear',
                                               align_corners=False)

        flow_preds = self.netFlow(im0_, im1_, 'swin', [2, 8], [-1, 4], [-1, 1], 6, True)
        tenFwds, tenBwds = [], []
        for flow_pred in flow_preds:
            # Each prediction is split in half along the batch axis.
            tenFwd, tenBwd = torch.chunk(flow_pred, 2, dim=0)
            tenFwds.append(tenFwd)
            tenBwds.append(tenBwd)

        with torch.set_grad_enabled(False):
            # Joint per-sample mean and standard deviation over both frames.
            tenStats = [img0, img1]
            tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats)
            tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + (
                    tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt()

            im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001)
            im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001)

            img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001)
            img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001)

        # Report the last flow estimate at input resolution, padding removed.
        result_dict['flowfwd'] = torch.nn.functional.interpolate(
            tenFwds[-1], scale_factor=ratio, mode='bilinear', align_corners=False
        )[:, :, :intHeight, :intWidth].clone().detach() * ratio
        result_dict['flowbwd'] = torch.nn.functional.interpolate(
            tenBwds[-1], scale_factor=ratio, mode='bilinear', align_corners=False
        )[:, :, :intHeight, :intWidth].clone().detach() * ratio

        for coarse_fwd, coarse_bwd in zip(tenFwds, tenBwds):
            tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(coarse_fwd, coarse_bwd, img0, img1, ratio)

            img0_ = im0_o.repeat(1, self.branch, 1, 1)
            img1_ = im1_o.repeat(1, self.branch, 1, 1)
            fltTime = time_step.repeat(1, self.branch, 1, 1)

            # Fold the branch axis into the batch axis.
            tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_)
            tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_)

            WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_)
            WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_)

            img0_ = img0_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_)
            img1_ = img1_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_)

            fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1)

            tenPhotoone = self.get_splat_weight(img0_, img1_, tenFwd, tenBwd) * WeiMF
            tenPhototwo = self.get_splat_weight(img1_, img0_, tenBwd, tenFwd) * WeiMB

            # Scale each flow by the fraction of the interval it has to cover.
            t0 = fltTime
            flow0 = tenFwd * t0
            metric0 = tenPhotoone

            t1 = 1.0 - fltTime
            flow1 = tenBwd * t1
            metric1 = tenPhototwo

            # Put the branch axis first, as forwarp_mframe_mask indexes it.
            flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4)
            flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4)

            metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4)
            metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4)

            img0_ = img0_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4)
            img1_ = img1_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4)

            t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4)
            t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4)

            # Each frame is weighted by the time distance to the other frame.
            tenOutput, mask = forwarp_mframe_mask(img0_, flow0, t1, img1_, flow1, t0, metric0, metric1)

            # Where the mask is set, add the linear blend of the two frames.
            tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o)

            # Undo the normalisation and crop the padding away.
            output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_
            outputs.append(output[:, :, :intHeight, :intWidth])

        # NOTE(review): the nesting of the statements below was reconstructed;
        # they are assumed to run once, after the loop. Confirm.
        result_dict['imgt_preds'] = outputs
        result_dict['imgt_pred'] = outputs[-1]
        tenFwds.append(tenFwd.reshape(N_, self.branch, 2, H_, W_))
        tenBwds.append(tenBwd.reshape(N_, self.branch, 2, H_, W_))
        result_dict['flow0_pred'] = tenFwds[::-1]
        result_dict['flow1_pred'] = tenBwds[::-1]

        return result_dict
|
|
|
|