toto10's picture
c2dbfba56f79c968b0d59e1af1493bfa9e6cfbd19338ec4a352438df2c986640
5bfcfd2
raw
history blame
5.86 kB
import sys, os
import numpy as np
import cv2
from collections import namedtuple
import torch
import argparse
from RAFT.raft import RAFT
from RAFT.utils.utils import InputPadder
import modules.paths as ph
import gc
# Cached RAFT network shared by every call in this module. It is created on the
# first RAFT_estimate_flow() call and reset to None by RAFT_clear_memory().
RAFT_model = None
# Module-level MOG2 background subtractor (history=500, varThreshold=16,
# shadow detection enabled).
fgbg = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=16, detectShadows=True)
def background_subtractor(frame, fgbg):
    """Return `frame` restricted to the foreground mask produced by `fgbg`.

    Args:
        frame: Image passed both to ``fgbg.apply`` and to ``cv2.bitwise_and``.
        fgbg: Object whose ``apply(frame)`` returns a mask usable as the
            ``mask`` argument of ``cv2.bitwise_and`` (the parameter shadows
            the module-level ``fgbg`` instance).

    Returns:
        The result of AND-ing `frame` with itself under the foreground mask.
    """
    foreground_mask = fgbg.apply(frame)
    masked_frame = cv2.bitwise_and(frame, frame, mask=foreground_mask)
    return masked_frame
def RAFT_clear_memory():
    """Release the cached RAFT model and reclaim the memory it held.

    The module-level ``RAFT_model`` is rebound to ``None`` *before* the
    garbage collector and the CUDA allocator are invoked.  Rebinding (rather
    than ``del``-ing the global) guarantees the name always exists, so a
    failure inside ``gc.collect()`` / ``torch.cuda.empty_cache()`` cannot
    leave the module in a state where ``RAFT_estimate_flow`` raises
    ``NameError`` on its ``RAFT_model is None`` check.
    """
    global RAFT_model
    # Dropping the module-level reference is what makes the weights collectable.
    RAFT_model = None
    gc.collect()
    torch.cuda.empty_cache()
def RAFT_estimate_flow(frame1, frame2, device='cuda'):
    """Estimate forward and backward optical flow between two frames with RAFT.

    Both frames are resized so that width and height are multiples of 16
    (derived from ``frame1``'s shape) before being fed to the network.  The
    RAFT model is created lazily on the first call, cached in the module-level
    ``RAFT_model`` and reused afterwards; the checkpoint is downloaded first
    if it is not present on disk.

    Args:
        frame1: First frame as an H x W x C numpy array.
        frame2: Second frame; resized to the same working size as ``frame1``.
        device: Torch device used for the model and the input tensors.

    Returns:
        Tuple ``(next_flow, prev_flow, occlusion_mask)``:
        * ``next_flow`` - flow frame1 -> frame2, resized to frame1's original
          width/height (the vector values themselves are not rescaled).
        * ``prev_flow`` - flow frame2 -> frame1, resized the same way.
        * ``occlusion_mask`` - norm of ``next_flow + prev_flow`` repeated over
          3 channels, at the 16-aligned working resolution (it is *not*
          resized back to the original size).
    """
    global RAFT_model

    org_size = frame1.shape[1], frame1.shape[0]
    size = frame1.shape[1] // 16 * 16, frame1.shape[0] // 16 * 16
    frame1 = cv2.resize(frame1, size)
    frame2 = cv2.resize(frame2, size)

    if RAFT_model is None:
        # The checkpoint is only needed when the model has to be built, so the
        # filesystem check / download is done here rather than on every call.
        model_path = ph.models_path + '/RAFT/raft-things.pth'
        remote_model_path = 'https://drive.google.com/uc?id=1MqDajR89k-xLV0HIrmJ0k-n8ZpG6_suM'

        if not os.path.isfile(model_path):
            from basicsr.utils.download_util import load_file_from_url
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            load_file_from_url(remote_model_path, file_name=model_path)

        args = argparse.Namespace(**{
            'model': model_path,
            'mixed_precision': True,
            'small': False,
            'alternate_corr': False,
            'path': ""
        })

        # The state dict is loaded through DataParallel and the wrapper is
        # dropped right after via `.module`.
        model = torch.nn.DataParallel(RAFT(args))
        # map_location='cpu' lets the checkpoint load on hosts without CUDA;
        # the model is moved to the requested device below.
        # NOTE(security): torch.load unpickles the file; the checkpoint comes
        # from a remote URL, so only use it with a trusted source.
        model.load_state_dict(torch.load(args.model, map_location='cpu'))

        model = model.module
        model.eval()
        RAFT_model = model

    # Keep the cached model on the same device as the inputs, even when the
    # caller switches `device` between calls (no-op when already there).
    RAFT_model.to(device)

    with torch.no_grad():
        frame1_torch = torch.from_numpy(frame1).permute(2, 0, 1).float()[None].to(device)
        frame2_torch = torch.from_numpy(frame2).permute(2, 0, 1).float()[None].to(device)

        padder = InputPadder(frame1_torch.shape)
        image1, image2 = padder.pad(frame1_torch, frame2_torch)

        # estimate optical flow in both directions
        _, next_flow = RAFT_model(image1, image2, iters=20, test_mode=True)
        _, prev_flow = RAFT_model(image2, image1, iters=20, test_mode=True)

        next_flow = next_flow[0].permute(1, 2, 0).cpu().numpy()
        prev_flow = prev_flow[0].permute(1, 2, 0).cpu().numpy()

        # Magnitude of the summed forward and backward flow, returned as the
        # occlusion mask.
        fb_flow = next_flow + prev_flow
        fb_norm = np.linalg.norm(fb_flow, axis=2)

        occlusion_mask = fb_norm[..., None].repeat(3, axis=-1)

    next_flow = cv2.resize(next_flow, org_size)
    prev_flow = cv2.resize(prev_flow, org_size)

    return next_flow, prev_flow, occlusion_mask
def compute_diff_map(next_flow, prev_flow, prev_frame, cur_frame, prev_frame_styled, args_dict):
    """Build an occlusion/alpha mask and warp the previous styled frame.

    The previous frame (original and styled) is warped with ``prev_flow`` and
    three difference cues are combined by element-wise maximum:
    flow inconsistency, warped-original vs. current frame, and
    warped-styled vs. current frame.

    Args:
        next_flow: H_f x W_f x 2 flow array (channel 0 = x, channel 1 = y).
        prev_flow: Flow array of the same shape, used for the warp.
        prev_frame: Previous frame, H x W x C array.
        cur_frame: Current frame; its height/width define the output size.
        prev_frame_styled: Styled previous frame, same layout as ``prev_frame``.
        args_dict: Mapping providing ``occlusion_mask_flow_multiplier``,
            ``occlusion_mask_difo_multiplier``,
            ``occlusion_mask_difs_multiplier`` and ``occlusion_mask_blur``.

    Returns:
        Tuple ``(alpha_mask, warped_frame_styled)`` where ``alpha_mask`` is an
        h x w x 3 array clipped to [0, 1] and ``warped_frame_styled`` is the
        styled previous frame sampled through the flow grid.
    """
    h, w = cur_frame.shape[:2]
    flow_h, flow_w = next_flow.shape[:2]

    # Normalise displacements to fractions of the flow-field size:
    # x (channel 0) by the width, y (channel 1) by the height.
    flow_extent = np.array([flow_w, flow_h])
    next_flow = next_flow / flow_extent
    prev_flow = prev_flow / flow_extent

    # Flow-consistency cue: magnitude of the summed forward and backward flow,
    # faded out where the backward flow itself is large.
    fb_flow = next_flow + prev_flow
    fb_norm = np.linalg.norm(fb_flow, axis=2)

    zero_flow_mask = np.clip(1 - np.linalg.norm(prev_flow, axis=-1)[..., None] * 20, 0, 1)
    diff_mask_flow = fb_norm[..., None] * zero_flow_mask

    # The other two cues live at frame resolution; bring this one there too so
    # the element-wise maximum below works when flow and frame sizes differ.
    if diff_mask_flow.shape[:2] != (h, w):
        diff_mask_flow = cv2.resize(diff_mask_flow, (w, h)).reshape(h, w, 1)

    # Resize the flow to the frame size and convert back to pixel units.
    # x must be scaled by the frame width and y by the frame height.
    prev_flow = cv2.resize(prev_flow, (w, h))
    prev_flow = (prev_flow * np.array([w, h])).astype(np.float32)

    # Sampling grid: absolute pixel coordinates plus the backward flow,
    # normalised to [-1, 1] as expected by grid_sample(align_corners=True).
    grid_x = torch.arange(w, dtype=torch.float32).expand(h, w)
    grid_y = torch.arange(h, dtype=torch.float32).unsqueeze(1).expand(h, w)

    flow_grid = torch.stack((grid_x, grid_y), dim=0)
    flow_grid += torch.from_numpy(prev_flow).permute(2, 0, 1)
    flow_grid = flow_grid.unsqueeze(0)
    # max(..., 1) avoids a division by zero for 1-pixel-wide/high frames.
    flow_grid[:, 0, :, :] = 2 * flow_grid[:, 0, :, :] / max(w - 1, 1) - 1
    flow_grid[:, 1, :, :] = 2 * flow_grid[:, 1, :, :] / max(h - 1, 1) - 1
    flow_grid = flow_grid.permute(0, 2, 3, 1)

    prev_frame_torch = torch.from_numpy(prev_frame).float().unsqueeze(0).permute(0, 3, 1, 2)  # N, C, H, W
    prev_frame_styled_torch = torch.from_numpy(prev_frame_styled).float().unsqueeze(0).permute(0, 3, 1, 2)  # N, C, H, W

    warped_frame = torch.nn.functional.grid_sample(
        prev_frame_torch, flow_grid,
        mode="nearest", padding_mode="reflection", align_corners=True,
    ).permute(0, 2, 3, 1)[0].numpy()
    warped_frame_styled = torch.nn.functional.grid_sample(
        prev_frame_styled_torch, flow_grid,
        mode="nearest", padding_mode="reflection", align_corners=True,
    ).permute(0, 2, 3, 1)[0].numpy()

    # Per-pixel colour difference cues (max over channels, scaled to [0, 1]
    # for 8-bit inputs).
    diff_mask_org = np.abs(warped_frame.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_org = diff_mask_org.max(axis=-1, keepdims=True)

    diff_mask_stl = np.abs(warped_frame_styled.astype(np.float32) - cur_frame.astype(np.float32)) / 255
    diff_mask_stl = diff_mask_stl.max(axis=-1, keepdims=True)

    alpha_mask = np.maximum.reduce([
        diff_mask_flow * args_dict['occlusion_mask_flow_multiplier'] * 10,
        diff_mask_org * args_dict['occlusion_mask_difo_multiplier'],
        diff_mask_stl * args_dict['occlusion_mask_difs_multiplier'],
    ])
    alpha_mask = alpha_mask.repeat(3, axis=-1)

    if args_dict['occlusion_mask_blur'] > 0:
        # Kernel size must be odd; `| 1` forces the lowest bit.
        blur_filter_size = min(w, h) // 15 | 1
        # borderType has to be passed by keyword: the 4th positional argument
        # of cv2.GaussianBlur is `dst`, not the border mode.
        alpha_mask = cv2.GaussianBlur(
            alpha_mask,
            (blur_filter_size, blur_filter_size),
            args_dict['occlusion_mask_blur'],
            borderType=cv2.BORDER_REFLECT,
        )

    alpha_mask = np.clip(alpha_mask, 0, 1)

    return alpha_mask, warped_frame_styled
def frames_norm(frame):
    """Rescale `frame` so that 0 maps to -1 and 255 maps to 1."""
    scaled = frame / 127.5
    return scaled - 1
def flow_norm(flow):
    """Divide `flow` by 255."""
    normalised = flow / 255
    return normalised
def occl_norm(occl):
    """Rescale `occl` so that 0 maps to -1 and 255 maps to 1."""
    scaled = occl / 127.5
    return scaled - 1
def frames_renorm(frame):
    """Inverse of the [-1, 1] scaling: -1 maps to 0 and 1 maps to 255."""
    shifted = frame + 1
    return shifted * 127.5
def flow_renorm(flow):
    """Multiply `flow` by 255."""
    restored = flow * 255
    return restored
def occl_renorm(occl):
    """Inverse of the [-1, 1] scaling: -1 maps to 0 and 1 maps to 255."""
    shifted = occl + 1
    return shifted * 127.5