import torch import math import numpy import torch.nn.functional as F import torch.nn as nn import torchvision.transforms.v2.functional as TF import modules.components.upr_net_mod2.correlation as correlation import modules.components.upr_net_mod2.softsplat as softsplat from modules.components.upr_net_mod2.m2m import * from modules.components.upr_net_mod2.backwarp import backwarp from .costvol import costvol_func from ..components import register from utils.padder import InputPadder from utils.vos.model.network import STCN from utils.vos.model.inference_core import InferenceCore # **************************************************************************************************# # => Feature Pyramid # **************************************************************************************************# def photometric_consistency(img0, img1, flow01): return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) def flow_consistency(flow01, flow10): return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) def gaussian(x): gaussian_kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]]) / 16 gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) gaussian_kernel = gaussian_kernel.to("cpu")#torch.cuda.current_device()) x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) return out def variance_flow(flow): flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, device=flow.device).view(1, 2, 1, 1) return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) class FeatPyramid(nn.Module): """A 3-level feature pyramid, which by default is shared by the motion estimator and synthesis network. """ def __init__(self): super(FeatPyramid, self).__init__() self.conv_stage0 = nn.Sequential( nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(num_features=32), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(num_features=32), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(num_features=32), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1)) self.conv_stage1 = nn.Sequential( nn.InstanceNorm2d(num_features=32), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), nn.InstanceNorm2d(num_features=64), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(num_features=64), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(num_features=64), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), ) self.conv_stage2 = nn.Sequential( nn.InstanceNorm2d(num_features=64), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1), nn.InstanceNorm2d(num_features=128), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(num_features=128), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.InstanceNorm2d(num_features=128), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), ) for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): if m.weight is not None: nn.init.constant_(m.weight, 1) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, img): C0 = self.conv_stage0(img) C1 = self.conv_stage1(C0) C2 = self.conv_stage2(C1) return [C0, C1, C2] # **************************************************************************************************# # => Motion Estimation # **************************************************************************************************# class MotionEstimator(nn.Module): """Bi-directional optical flow estimator 1) construct partial cost volume with the CNN features from the stage 2 of the feature pyramid; 2) estimate bi-directional flows, by feeding cost volume, CNN features for both warped images, CNN feature and estimated flow from previous iteration. """ def __init__(self): super(MotionEstimator, self).__init__() # 64 + 256 + 128 * 2 + 128 = 704 self.conv_flow = nn.Sequential( nn.Conv2d(4, 128, 7, padding=3), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(128, 64, 3, padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1) ) self.conv_corr = nn.Sequential( nn.Conv2d(81, 64, 1, padding=0), nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(64, 128, 3, padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1), ) self.conv_layer1 = nn.Sequential( nn.Conv2d(in_channels=704, out_channels=320, kernel_size=1, stride=1, padding=0), nn.LeakyReLU(inplace=False, negative_slope=0.1)) self.conv_layer2 = nn.Sequential( nn.Conv2d(in_channels=320, out_channels=256, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1)) self.conv_layer3 = nn.Sequential( nn.Conv2d(in_channels=256, out_channels=224, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1)) self.conv_layer4 = nn.Sequential( nn.Conv2d(in_channels=224, out_channels=192, kernel_size=3, stride=1, padding=1), nn.LeakyReLU(inplace=False, negative_slope=0.1)) self.conv_layer5 = nn.Sequential( nn.Conv2d(in_channels=192, out_channels=128, kernel_size=3, stride=1, padding=1)) self.conv_layer6 = nn.Sequential( nn.LeakyReLU(inplace=False, negative_slope=0.1), nn.Conv2d(in_channels=128, out_channels=4, kernel_size=3, stride=1, padding=1, bias=False)) self.upsampler = nn.Sequential( nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(256, 16 * 9, 1, padding=0) ) # for m in self.modules(): # if isinstance(m, nn.Conv2d): # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): # if m.weight is not None: # nn.init.constant_(m.weight, 1) # if m.bias is not None: # nn.init.constant_(m.bias, 0) def upsample(self, flow, mask): """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ N, _, H, W = flow.shape mask = mask.view(N, 1, 9, 4, 4, H, W) mask = torch.softmax(mask, dim=2) up_flow = F.unfold(4 * flow, [3, 3], padding=1) up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) up_flow = torch.sum(mask * up_flow, dim=2) up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) return up_flow.reshape(N, 4, 4 * H, 4 * W) def forward(self, feat0, feat1, last_feat, last_flow): corr_fn = correlation.FunctionCorrelation feat0_warp = backwarp(feat0, last_flow[:, :2]) feat1_warp = backwarp(feat1, last_flow[:, 2:]) volume0 = F.leaky_relu( input=costvol_func.apply(feat0_warp, feat1_warp), negative_slope=0.1, inplace=False) volume1 = F.leaky_relu( input=costvol_func.apply(feat1_warp, feat0_warp), negative_slope=0.1, inplace=False) corr0 = self.conv_corr(volume0) corr1 = self.conv_corr(volume1) flo = self.conv_flow(last_flow) input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo], 1) feat = self.conv_layer1(input_feat) feat = self.conv_layer2(feat) feat = self.conv_layer3(feat) feat = self.conv_layer4(feat) feat = self.conv_layer5(feat) flow_res = self.conv_layer6(feat) flow = last_flow + flow_res mask = self.upsampler(feat) * .25 flow = self.upsample(flow, mask) return flow, feat # **************************************************************************************************# # => Frame Synthesis # **************************************************************************************************# class SynthesisNetwork(nn.Module): def __init__(self, splat_mode='average'): super(SynthesisNetwork, self).__init__() input_channels = 9 + 4 + 6 self.encoder_conv = nn.Sequential( nn.Conv2d(in_channels=input_channels, out_channels=64, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=64), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=64)) self.encoder_down1 = nn.Sequential( nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, kernel_size=3, stride=2, padding=1), nn.PReLU(num_parameters=128), nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=128), nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=128)) self.encoder_down2 = nn.Sequential( nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, kernel_size=3, stride=2, padding=1), nn.PReLU(num_parameters=256), nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=256), nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=256)) self.decoder_up1 = nn.Sequential( torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, out_channels=128, kernel_size=4, stride=2, padding=1, bias=True), nn.PReLU(num_parameters=128), nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=128)) self.decoder_up2 = nn.Sequential( torch.nn.ConvTranspose2d(in_channels=128 + 128, out_channels=64, kernel_size=4, stride=2, padding=1, bias=True), nn.PReLU(num_parameters=64), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=64)) self.decoder_conv = nn.Sequential( nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=64), nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), nn.PReLU(num_parameters=64)) self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, stride=1, padding=1) self.splat_mode = splat_mode if self.splat_mode == 'softmax': # New params for splatting mask generation self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) def get_splat_weight(self, img0, img1, flow01, flow10): if self.splat_mode == 'softmax': M_splat = 1 / ( 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) return M_splat * self.alpha else: return None def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): flow_t0 = bi_flow[:, :2] * time_period * 2 flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 warped_c0 = backwarp(c0, flow_t0) warped_c1 = backwarp(c1, flow_t1) if (i0 is None) and (i1 is None): return warped_c0, warped_c1 else: warped_img0 = backwarp(i0, flow_t0) warped_img1 = backwarp(i1, flow_t1) scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1)#.cuda() flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ self.get_warped_representations( bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, time_period=time_period) input_feat = torch.cat( (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) s0 = self.encoder_conv(input_feat) # [B, 64,h,w] s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) # [B, 128,h/2,w/2] warped_c0, warped_c1 = self.get_warped_representations( bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, time_period=time_period) s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) # [B, 256,h/4,w/4] warped_c0, warped_c1 = self.get_warped_representations( bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, time_period=time_period) x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) x = self.decoder_up2(torch.cat((x, s1), 1)) x = self.decoder_conv(torch.cat((x, s0), 1)) # prediction refine = self.pred(x) refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 refine_mask = torch.sigmoid(refine[:, 3:]) merged_img = (warped_img0 * refine_mask + warped_img1 * (1 - refine_mask)) interp_img = merged_img + refine_res # interp_img = torch.clamp(interp_img, 0, 1) extra_dict = {} extra_dict["refine_res"] = refine_res extra_dict["refine_mask"] = refine_mask extra_dict["warped_img0"] = warped_img0 extra_dict["warped_img1"] = warped_img1 extra_dict["merged_img"] = merged_img extra_dict["c0_pyr"] = c0_pyr extra_dict["c1_pyr"] = c1_pyr extra_dict["syn_pyr"] = [s0,s1,s2] return interp_img, extra_dict # **************************************************************************************************# # => Unified model # **************************************************************************************************# @register('upr_net_mod2') class Model(nn.Module): def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average'): super(Model, self).__init__() print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back exp45@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') self.pyr_level = pyr_level self.feat_pyramid = FeatPyramid() self.nr_lvl_skipped = nr_lvl_skipped self.motion_estimator = MotionEstimator() self.synthesis_network = SynthesisNetwork(splat_mode) self.splat_mode = splat_mode def forward_one_lvl(self, img0, img1, last_feat, last_flow, last_interp=None, time_period=0.5, skip_me=False): # context feature extraction feat0_pyr = self.feat_pyramid(img0) feat1_pyr = self.feat_pyramid(img1) # bi-directional flow estimation if not skip_me: last_flow = F.interpolate( input=last_flow, scale_factor=0.25, mode="bilinear") * 0.25 flow, feat = self.motion_estimator( feat0_pyr[-1], feat1_pyr[-1], last_feat, last_flow) else: flow = last_flow feat = last_feat # frame synthesis ## optical flow is estimated at 1/4 resolution ori_resolution_flow = flow ## consturct 3-level flow pyramid for synthesis network bi_flow_pyr = [] tmp_flow = ori_resolution_flow bi_flow_pyr.append(tmp_flow) for i in range(2): tmp_flow = F.interpolate( input=tmp_flow, scale_factor=0.5, mode="bilinear") * 0.5 bi_flow_pyr.append(tmp_flow) ## merge warped frames as initial interpolation for frame synthesis if last_interp is None: flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 warped_img0 = backwarp(img0, flow_t0) warped_img1 = backwarp(img1, flow_t1) last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period ## do synthesis interp_img, extra_dict = self.synthesis_network( last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, time_period=time_period) return flow, feat, interp_img, extra_dict def forward(self, img0, img1, time_step, seg0=None, segt=None, seg1=None, pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): if pyr_level is None: pyr_level = self.pyr_level if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped N, _, H, W = img0.shape flow0_pred = [] flow1_pred = [] interp_imgs = [] skipped_levels = [] if nr_lvl_skipped == 0 else \ list(range(pyr_level))[::-1][-nr_lvl_skipped:] with torch.set_grad_enabled(False): tenStats = [img0, img1] tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) img0, img1 = padder.pad(img0, img1) N, _, H, W = img0.shape # The original input resolution corresponds to level 0. for level in list(range(pyr_level))[::-1]: if level != 0: scale_factor = 1 / 2 ** level img0_this_lvl = F.interpolate( input=img0, scale_factor=scale_factor, mode="bilinear", align_corners=False) img1_this_lvl = F.interpolate( input=img1, scale_factor=scale_factor, mode="bilinear", align_corners=False) else: img0_this_lvl = img0 img1_this_lvl = img1 # skip motion estimation, directly use up-sampled optical flow skip_me = False # the lowest-resolution pyramid level if level == pyr_level - 1: last_flow = torch.zeros( (N, 4, H // (2 ** (level)), W // (2 ** (level))) ).to(img0.device) last_feat = torch.zeros( (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) ).to(img0.device) last_interp = None # skip some levels for both motion estimation and frame synthesis elif level in skipped_levels[:-1]: continue # last level (original input resolution), only skip motion estimation elif (level == 0) and len(skipped_levels) > 0: if len(skipped_levels) == pyr_level: last_flow = torch.zeros( (N, 4, H, W)).to(img0.device) last_interp = None else: resize_factor = 2 ** len(skipped_levels) last_flow = F.interpolate( input=flow, scale_factor=resize_factor, mode="bilinear", align_corners=False) * resize_factor last_interp = F.interpolate( input=interp_img, scale_factor=resize_factor, mode="bilinear", align_corners=False) skip_me = True # last level (original input resolution), motion estimation + frame # synthesis else: last_flow = F.interpolate(input=flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2 last_feat = F.interpolate(input=feat, scale_factor=2.0, mode="bilinear", align_corners=False) last_interp = F.interpolate( input=interp_img, scale_factor=2.0, mode="bilinear", align_corners=False) flow, feat, interp_img, extra_dict = self.forward_one_lvl( img0_this_lvl, img1_this_lvl, last_feat, last_flow, last_interp, time_step, skip_me=skip_me) flow0_pred.append( padder.unpad(flow[:, :2])) flow1_pred.append( padder.unpad(flow[:, 2:])) interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) # directly up-sample estimated flow to full resolution with bi-linear # interpolation refine_res = padder.unpad(extra_dict["refine_res"]) refine_mask = padder.unpad(extra_dict["refine_mask"]) c0_pyr = [padder.unpad(cc) for cc in extra_dict["c0_pyr"]] c1_pyr = [padder.unpad(cc) for cc in extra_dict["c1_pyr"]] syn_pyr = [padder.unpad(cc) for cc in extra_dict["syn_pyr"]] warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ result_dict = { "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, 'warped_img1': warped_img1, 'merged_img': merged_img, 'c0_pyr': c0_pyr, 'c1_pyr': c1_pyr, 'syn_pyr': syn_pyr } return result_dict if __name__ == "__main__": pass