diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..3677ed68d8ae9b24d633c771bcd3071a2668531d --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +*.ipynb +*.pkl +*.pth +*.pt +*/*.ipynb +*/*.pkl +*/*.pth +*/*.pt +*/*/*.ipynb +*/*/*.pkl +*/*/*.pth +*/*/*.pt \ No newline at end of file diff --git a/README.md b/README.md index 34510cd11b4d6b1870a9b2d7aaedd0f3a779c61d..2be6790284f2725eed34dad20c0f060eac2b7b13 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,11 @@ --- title: VfiTest -emoji: ๐Ÿข -colorFrom: blue -colorTo: purple +emoji: ๐Ÿ‘ +colorFrom: yellow +colorTo: blue sdk: gradio sdk_version: 4.19.2 app_file: app.py pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +license: unknown +--- \ No newline at end of file diff --git a/START_COMMAND.txt b/START_COMMAND.txt new file mode 100644 index 0000000000000000000000000000000000000000..138edcb5bb33b6a4f813aa53a722eddb729911f2 --- /dev/null +++ b/START_COMMAND.txt @@ -0,0 +1,17 @@ +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq_unimatch_exp001.yaml +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq_unimatch_exp002.yaml +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq003.yaml +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq004.yaml +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq005.yaml +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq006.yaml +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq007.yaml +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq008.yaml + +CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node 4 main.py --cfg cfgs/upr_freq_unimatch_exp002_stopmotion.yaml + +**inference ๊ณ ๋ ค์‚ฌํ•ญ** +class Model์—์„œ ํŒจ๋”ฉ ์œ„์น˜: normalization -> padding +utils.padder.py: constant padding mode +pyr_level (Vimeo: 3, UCF: 3, SNU-FILM: 5) +trainer.py: train function์—์„œ epoch ๋“ค์–ด๊ฐ€๊ธฐ ์ „์— return +yaml: [results.txt, ์ด๋ฏธ์ง€๋“ค ์ €์žฅํ•˜๊ณ  ์‹ถ์œผ๋ฉด][~~~every ์ „๋ถ€ 1๋กœ ์„ค์ •] /// [๊ทธ๋ ‡์ง€ ์•Š์œผ๋ฉด][~~~every ์ „๋ถ€ ๋งˆ์ง€๋ง‰ ์—ํญ ์•ˆ ๋‚˜๋ˆ ๋–จ์–ด์ง€๋Š” ์ˆ˜๋กœ ์„ค์ •] \ No newline at end of file diff --git a/UPR_BASIC_INFERENCE_SPEED.txt b/UPR_BASIC_INFERENCE_SPEED.txt new file mode 100644 index 0000000000000000000000000000000000000000..9f511c0f4fbb331685c9235cb2da3a9bb72c5d03 --- /dev/null +++ b/UPR_BASIC_INFERENCE_SPEED.txt @@ -0,0 +1,30 @@ +Test: [ 0/237] eta: 0:30:27 psnr: 25.80 ssim: 0.79 time: 7.7092 data: 0.3797 max mem: 12902 +Test: [ 10/237] eta: 0:04:25 psnr: 25.25 ssim: 0.73 time: 1.1688 data: 0.0348 max mem: 12902 +Test: [ 20/237] eta: 0:02:57 psnr: 24.94 ssim: 0.71 time: 0.4738 data: 0.0002 max mem: 12902 +Test: [ 30/237] eta: 0:02:23 psnr: 25.29 ssim: 0.72 time: 0.4329 data: 0.0002 max mem: 12902 +Test: [ 40/237] eta: 0:02:03 psnr: 25.64 ssim: 0.77 time: 0.4309 data: 0.0002 max mem: 12902 +Test: [ 50/237] eta: 0:01:50 psnr: 27.04 ssim: 0.80 time: 0.4291 data: 0.0002 max mem: 12902 +Test: [ 60/237] eta: 0:01:39 psnr: 24.79 ssim: 0.76 time: 0.4310 data: 0.0002 max mem: 12902 +Test: [ 70/237] eta: 0:01:31 psnr: 25.93 ssim: 0.78 time: 0.4323 data: 0.0002 max mem: 12902 +Test: [ 80/237] eta: 0:01:23 psnr: 25.98 ssim: 0.76 time: 0.4320 data: 0.0002 max mem: 12902 +Test: [ 90/237] eta: 0:01:17 psnr: 24.59 ssim: 0.77 time: 0.4666 data: 0.0002 max mem: 12902 +Test: [100/237] eta: 0:01:11 psnr: 24.06 ssim: 0.72 time: 0.4651 data: 0.0002 max mem: 12902 +Test: [110/237] eta: 0:01:04 psnr: 24.65 ssim: 0.67 time: 0.4289 data: 0.0002 max mem: 12902 +Test: [120/237] eta: 0:00:58 psnr: 23.71 ssim: 0.73 time: 0.4287 data: 0.0002 max mem: 12902 +Test: [130/237] eta: 0:00:53 psnr: 28.94 ssim: 0.87 time: 0.4288 data: 0.0002 max mem: 12902 +Test: [140/237] eta: 0:00:47 psnr: 26.36 ssim: 0.73 time: 0.4301 data: 0.0002 max mem: 12902 +Test: [150/237] eta: 0:00:42 psnr: 29.65 ssim: 0.85 time: 0.4299 data: 0.0002 max mem: 12902 +Test: [160/237] eta: 0:00:37 psnr: 25.18 ssim: 0.74 time: 0.4287 data: 0.0002 max mem: 12902 +Test: [170/237] eta: 0:00:32 psnr: 25.57 ssim: 0.77 time: 0.4618 data: 0.0002 max mem: 12902 +Test: [180/237] eta: 0:00:27 psnr: 27.12 ssim: 0.80 time: 0.4618 data: 0.0002 max mem: 12902 +Test: [190/237] eta: 0:00:22 psnr: 23.36 ssim: 0.59 time: 0.4290 data: 0.0002 max mem: 12902 +Test: [200/237] eta: 0:00:17 psnr: 25.49 ssim: 0.78 time: 0.4296 data: 0.0002 max mem: 12902 +Test: [210/237] eta: 0:00:12 psnr: 27.89 ssim: 0.82 time: 0.4322 data: 0.0002 max mem: 12902 +Test: [220/237] eta: 0:00:08 psnr: 28.20 ssim: 0.81 time: 0.4336 data: 0.0002 max mem: 12902 +Test: [230/237] eta: 0:00:03 psnr: 26.18 ssim: 0.78 time: 0.4322 data: 0.0002 max mem: 12902 +Test: [236/237] eta: 0:00:00 psnr: 25.18 ssim: 0.77 time: 0.5410 data: 0.0002 max mem: 12902 +Test: Total time: 0:01:53 (0.4800 s / it) +[12-01 16:11:10] Averaged validate stats:psnr: 25.48485653640427 ssim: 0.7470527462248057 +[12-01 16:11:10] best performance achieved at epoch -1 with performance of 25.48485653640427 + +MEMORY: 6818MB \ No newline at end of file diff --git a/__pycache__/prepare_extra_training_dataset.cpython-310.pyc b/__pycache__/prepare_extra_training_dataset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f41488c71b66d378c0546d1645dfa8b68363fa8 Binary files /dev/null and b/__pycache__/prepare_extra_training_dataset.cpython-310.pyc differ diff --git a/__pycache__/prepare_extra_training_dataset.cpython-37.pyc b/__pycache__/prepare_extra_training_dataset.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38dec195f82e86f392f844ceb758842ee18fe82a Binary files /dev/null and b/__pycache__/prepare_extra_training_dataset.cpython-37.pyc differ diff --git a/__pycache__/vfi_inference.cpython-310.pyc b/__pycache__/vfi_inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a56abbe66953faa8d532f5e9d60ad1ba584d640c Binary files /dev/null and b/__pycache__/vfi_inference.cpython-310.pyc differ diff --git a/__pycache__/vfi_inference_triplet.cpython-310.pyc b/__pycache__/vfi_inference_triplet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0947064471c346b5b504f3a8a247c6066d35163 Binary files /dev/null and b/__pycache__/vfi_inference_triplet.cpython-310.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..665df9ad0690543e034d9ccdb17b2a4ee0978e60 --- /dev/null +++ b/app.py @@ -0,0 +1,408 @@ +import gradio as gr +import numpy as np +import cv2 +import os +import glob +import torch +import shutil +from PIL import Image +from tqdm import tqdm +from torch.nn import functional as F +from torchvision.transforms import functional as TF +from matplotlib import pyplot as plt +from modules.components.upr_net_freq import upr_freq as upr_freq002 +from modules.components.upr_basic import upr as upr_basic +import datetime +import zipfile + +os.system('python -m pip install --upgrade pip') + +#from scipy.interpolate import make_interp_spline + +# python3 -m vfi_inference_triplet --cuda_index 0 \ +# --root ../VFI_Inference/thistriplet_notarget --pretrain_path ./pretrained/upr_freq002.pth \ +# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 + +# ์•„์ด๋””, ๋น„๋ฐ€๋ฒˆํ˜ธ ํŠœํ”Œ, ๋ฆฌ์ŠคํŠธ์— ์ถ”๊ฐ€ํ•˜๋ฉด ์—ฌ๋Ÿฌ ์‚ฌ์šฉ์ž๊ฐ€ ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•ฉ๋‹ˆ๋‹ค. +# ๋‹ค๋ฅธ ํŒŒ์ผ๋กœ ๋งŒ๋“ค์–ด ์‚ฌ์šฉํ•˜๊ธฐ๋ฅผ ๊ถŒ์žฅํ•ฉ๋‹ˆ๋‹ค. +KEY = [("test", "test"), + ] + +# ๋กœ๊ทธ์ธ ์‹œ ํ˜ธ์ถœ๋˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค. +#ํ˜น์‹œ ๋กœ๊ทธ์ธ์— ๋Œ€ํ•œ ์ •๋ณด๋‚˜, ๋‹คip ๋“ฑ์„ ์–ป๊ณ  ์‹ถ์œผ๋ฉด ์ด ๋ถ€๋ถ„ ์ˆ˜์ •๋ฐ”๋ž๋‹ˆ๋‹ค. +def check_valid_login(user_name, password): + #client_ip = request.client.host + #print(client_ip) + flag = (user_name, password) in KEY + return flag + +# ๋น„๋””์˜ค์—์„œ ์ฒ˜์Œ ๋ช‡ ํ”„๋ ˆ์ž„์„ ์ž๋ฅผ์ง€ ๋ณ€์ˆ˜์ž…๋‹ˆ๋‹ค. +MAX_FRAME = 24 + +#VFI inference ์ฝ”๋“œ๋ฅผ ๊ทธ๋Œ€๋กœ ๊ฐ€์ ธ์™”์Šต๋‹ˆ๋‹ค. +DEVICE = 0#"cuda" +torch.cuda.set_device(DEVICE) +#ROOT = args.root +#SAVE_ROOT = f'output' +SCALE = 1 +pyr_level = 7 +nr_lvl_skipped = 0 +splat_mode = "average" +pretrain_path = "./pretrained/upr_freq002.pth" + +model = upr_freq002.Model(pyr_level=pyr_level, + nr_lvl_skipped=nr_lvl_skipped, + splat_mode=splat_mode) +sd = torch.load(pretrain_path, map_location='cpu') +sd = sd['model'] if 'model' in sd.keys() else sd +print(model.load_state_dict(sd)) +model = model.to(DEVICE) + +def get_sorted_img(file_path): + return sorted(glob.glob(os.path.join(file_path, f"*.png")), key = lambda x : float(x.split("_")[-1][:-4])) + +def multiple_pad(image, multiple): + _,_,H,W = image.size() + pad1 = multiple-(H%multiple) if H%multiple!=0 else 0 + pad2 = multiple-(W%multiple) if W%multiple!=0 else 0 + return TF.pad(image, (0,0,pad2,pad1)) + +#์ด๋ฏธ์ง€ 1(path1), 2๋ฅผ VFIํ•˜์—ฌ ๊ฐ€์šด๋ฐ ์ด๋ฏธ์ง€๋ฅผ ์ƒ์„ฑํ•˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค. +def multiple_VFIx2(path1, path2, output_name): + file_list = [path1, path2] + img_list = [(torch.from_numpy(cv2.imread(file)[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE) for file in file_list] + img_list = [multiple_pad(img, SCALE) for k, img in enumerate(img_list)] + img_list = [F.interpolate(img, scale_factor=1/SCALE, mode='bicubic') for k, img in enumerate(img_list)] + img0,img1 = img_list + _,_,Hori,Wori = img0.size() + + with torch.no_grad(): + result_dict, extra_dict = model(img0, img1, pyr_level=pyr_level, nr_lvl_skipped=nr_lvl_skipped, time_step=0.5) + out = F.interpolate(result_dict['imgt_pred'], scale_factor=SCALE, mode='bicubic')[:,:,:Hori,:Wori].clamp(0,1) + cv2.imwrite(output_name, (out[0].cpu().permute(1,2,0)*255).numpy().astype(np.uint8)[:,:,[2,1,0]]) + torch.cuda.empty_cache() + +#1, 2๋ฅผ 3๋ฒˆ VFIํ•˜์—ฌ 3์žฅ์„ ๋งŒ๋“œ๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค. +""" +def multiple_VFIx4(path1, path2, name1, name2, name3): + multiple_VFIx2(path1, path2, name2) + multiple_VFIx2(path1, name2, name1) + multiple_VFIx2(name2, path2, name3) +""" +def multiple_VFIx4(path1, path2): + frac = [".25", ".5", ".75"] + name1 , name2, name3 = [f"{path1[:-4]}{f}.png" for f in frac] + multiple_VFIx2(path1, path2, name2) + multiple_VFIx2(path1, name2, name1) + multiple_VFIx2(name2, path2, name3) + +#0, 0.125 , 0.25, 0.5, 0.75, 0.875, 1๋กœ 5์žฅ ์ƒ์„ฑ +def multiple_VFIx6(path1, path2): + frac = [".125", ".25", ".75", ".875"] + name_inf1 , name1, name2, name_inf2 = [f"{path1[:-4]}{f}.png" for f in frac] + multiple_VFIx4(path1, path2) + multiple_VFIx2(path1, name1, name_inf1) + multiple_VFIx2(name2, path2, name_inf2) + +#๋น„๋””์˜ค์—์„œ fix๋ฅผ ํ•˜์—ฌ ์ด๋ฏธ์ง€๋ฅผ ๋Œ€์ฒดํ•˜์—ฌ ์ถœ๋ ฅํ•˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค. +def fix_img(idx, fixed_list, input_dir = "input", output_dir = "output"): + idx = int(idx) + #์˜ฌ๋ฐ”๋ฅด์ง€ ์•Š๊ฑฐ๋‚˜, ์ด๋ฏธ fix ํ–ˆ๋‹ค๋ฉด ๋ณ€ํ™” x + if idx < 1 or idx > MAX_FRAME - 2 or fixed_list[idx] == 1: + return { + fix_result_gallery : gr.Gallery(), + fix_result_group : gr.Group(), + fixed_frame : gr.Text() + } + now_time = os.path.basename(input_dir) + output_dir = os.path.join(output_dir, f"fix_{now_time}") + os.makedirs(output_dir, exist_ok = True) + output_name = os.path.join(output_dir, f"img_{idx:03d}.png") + + multiple_VFIx2(os.path.join(input_dir, f"img_{idx - 1:03d}.png"), + os.path.join(input_dir, f"img_{idx + 1:03d}.png"), + output_name) + fixed_list[idx] = 1 + + fixed_frame_string = "" + result_list = [] + name_list = [] + #์ˆœ์ฐจ์ ์œผ๋กœ ๊ฒฐ๊ณผ ๊ฐค๋Ÿฌ๋ฆฌ ๊ฐฑ์‹  + for i in range(MAX_FRAME): + if fixed_list[i] == 1: + name_list.append(f"(fixed) frame {i}") + result_list.append(os.path.join(output_dir, f"img_{i:03d}.png")) + fixed_frame_string += f"{i}, " + else: + name_list.append(f"frame {i}") + result_list.append(os.path.join(input_dir, f"img_{i:03d}.png")) + return { + fix_result_gallery : gr.Gallery(value = [(img, name) for img, name in zip(result_list, name_list)], selected_index = idx), + fix_result_group : gr.Group(visible=True), + fixed_frame : gr.Text(visible=True, value = fixed_frame_string[:-2]), + } + +#์ฃผ์–ด์ง„ ease_val ๋ฆฌ์ŠคํŠธ์˜ ๊ฐ’ ๋ฐ”ํƒ•์œผ๋กœ ease๋ฅผ ์‹คํ–‰์‹œํ‚ค๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค. +def ease_frames(ease_val, input_dir = "input", output_dir = "output", progress=gr.Progress(track_tqdm=False)): + #now = os.path.basename(input_dir) + now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + output_dir = os.path.join(output_dir, f"ease_{now}") + os.makedirs(output_dir, exist_ok = True) + out_frame_list = [os.path.join(output_dir, f"img_{i:03d}.png") for i in range(MAX_FRAME)] + for i, f in enumerate([os.path.join(input_dir, f"img_{i:03d}.png") for i in range(MAX_FRAME)]): + shutil.copyfile(f, out_frame_list[i]) + img_name = [] + for i in progress.tqdm(range(MAX_FRAME - 1), desc = "VFI frames..."): + img_name.append(f"frame {i}") + if ease_val[i] == 1: pass + #x1๋Š” ์•„๋ฌด๊ฒƒ๋„, x2๋Š” ํ•œ ์žฅ, x4๋Š” 3์žฅ + # ์•„๋ž˜ ๊ธ€์ž ์ถ”๊ฐ€ ๋ถ€๋ถ„์€ ์ƒˆ๋กœ์šด ์ด๋ฏธ์ง€์˜ ์ œ๋ชฉ ๋ฐ”๊พธ๋Š” ๋ถ€๋ถ„์ž…๋‹ˆ๋‹ค. + elif ease_val[i] == 2: + multiple_VFIx2(out_frame_list[i], out_frame_list[i + 1] + , os.path.join(output_dir, f"img_{i:03d}.5.png")) + img_name.append(f"(new) frame {i + 0.5}") + elif ease_val[i] == 3: + multiple_VFIx4(out_frame_list[i], out_frame_list[i + 1]) + img_name.append(f"(new) frame {i + 0.25}") + img_name.append(f"(new) frame {i + 0.5}") + img_name.append(f"(new) frame {i + 0.75}") + img_name.append(f"frame {MAX_FRAME - 1}") + files = get_sorted_img(output_dir) + + #๋‹ค์šด๋กœ๋“œ์šฉ zip ํŒŒ์ผ + zip_name = os.path.join(output_dir,"frame_list.zip") + with zipfile.ZipFile(zip_name, 'w', compression=zipfile.ZIP_DEFLATED) as new_zip: + for x in progress.tqdm(files, desc ="compress file..."): + new_zip.write(x, os.path.basename(x)) + + return { + ease_result_gallery : [(file, name) for file, name in zip(files, img_name)], + ease_make_video : gr.Accordion(visible = True), + last_ease_dir : output_dir, + ease_zip : gr.File(value = zip_name) + } +# ์ด๋ฏธ์ง€ ๋‘ ์žฅ์„ ๋ฐ›์•„ VFI๋ฅผ ์ˆ˜ํ–‰ํ•˜๋Š” ํ•จ์ˆ˜์ž…๋‹ˆ๋‹ค. +def VFI_two(l, r, flag ,output_dir = "output"): + now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + output_dir = os.path.join(output_dir, f"fix_img_{now}") + os.makedirs(output_dir, exist_ok = True) + l = Image.fromarray(l) + r = Image.fromarray(r) + #๋ฉ”๋ชจ๋ฆฌ ์ดˆ๊ณผ๋ฅผ ๋ง‰๊ธฐ ์œ„ํ•ด ์ ๋‹นํ•œ ํฌ๊ธฐ ํ”ฝ์…€ ์ดํ•˜๊ฐ€ ๋˜๋„๋ก ๊ด€๋ฆฌ + W, H = l.size + #1920 * 1080 * 1.2 * 1.2 ๊ฐ€ ๋Œ€์ถฉ 3e6๋ผ ๊ทธ๊ฑธ ๊ธฐ์ค€์œผ๋กœ ์žก์•˜์Šต๋‹ˆ๋‹ค. + mul = ((3e6) / (W * H)) ** (1/2) + H, W = int(H * mul), int(W * mul) + #์ด๋ฏธ์ง€๊ฐ€ ์ปค์„œ ์ค„์—ฌ์•ผ ํ•œ๋‹ค๋ฉด ๊ฐ์†Œ, ์•„๋‹˜ ๊ทธ๋ƒฅ ์ž…๋ ฅ + if mul < 1: + l = l.resize((W, H)) + r = r.resize((W, H)) + l_name, r_name = f"{output_dir}/img_000.png", f"{output_dir}/img_001.png" + l.save(l_name) + r.save(r_name) + + if flag == "x4": + multiple_VFIx4(l_name, r_name) + elif flag == "x2": + output_name = f"{output_dir}/img_000.5.png" + multiple_VFIx2(l_name, r_name, output_name) + else: + multiple_VFIx6(l_name, r_name) + + return { + frame_gen_result_gallery : gr.Gallery(visible=True, value=get_sorted_img(output_dir)) + } +#๋‹ค๋ฅธ ์ด๋ฏธ์ง€ ์ž…๋ ฅ์„ ์œ„ํ•ด ์ž…๋ ฅ๋œ ์ด๋ฏธ์ง€๋ฅผ ๋‚ ๋ฆฌ๋Š” ํ•ฉ์ˆ˜์ž…๋‹ˆ๋‹ค. +def clear_fix(): + return{ + img_0 : gr.Image(label="start image", sources =["upload"], value = None), + img_1 : gr.Image(label="end image", sources =["upload"], value = None), + frame_gen_result_gallery : gr.Gallery(visible=True, value=None) + } + +with gr.Blocks(theme=gr.themes.Default(), title = "Inshorts Animator V. 0.5") as demo: + def info(request: gr.Request): + #ip๋ฅผ ์–ป๋Š” ๋ถ€๋ถ„์ž…๋‹ˆ๋‹ค. + #์ถ”ํ›„ ํŠน์ • ip ํ—ˆ์šฉ, ์ฐจ๋‹จ ๋“ฑ์ด ํ•„์š”ํ•˜๋ฉด ์ด์ชฝ ์ฐธ๊ณ ํ•ด ์ฃผ์„ธ์š” + headers = request.headers + print(headers["x-forwarded-for"].split(",")) + demo.load(info, None) + gr.Markdown(f"""# Inshorts Animator V. 0.5 WebUI (Permitted User Only)""") + with gr.Tab("Mid Frame Generator"): + with gr.Column(): + + with gr.Row(): + img_0 = gr.Image(label="start image", sources =["upload"]) + img_1 = gr.Image(label="end image", sources =["upload"]) + with gr.Row(): + VFI_flag = gr.Radio(["x2", "x4", "x6(side ease)"], label="VFI ratio", value = "x2", interactive = True) + image_button = gr.Button("Run model") + frame_gen_result_gallery = gr.Gallery(visible=True, + label="result", columns=[5], rows=[1], object_fit="contain", height="auto", preview = True, + interactive = False) + image_button.click(VFI_two, inputs=[img_0, img_1, VFI_flag], + outputs=[frame_gen_result_gallery]) + clear_button = gr.Button("Clear images") + clear_button.click(clear_fix, inputs=[], + outputs=[img_0, img_1, frame_gen_result_gallery]) + with gr.Tab("Video"): + with gr.Group(visible=True) as video_input_group: + gr.Markdown(f"""#### only can handle {MAX_FRAME} frames""") + with gr.Column(): + input_dir = gr.State("") + fps = gr.Number(visible=False) + video_input = gr.Video(label="Input Video", interactive=True, sources=['upload']) + gr.Markdown(f"""If video frame size is big, it will be resized""") + upload_button = gr.Button("upload video") + with gr.Group(visible=False) as image_edit_group: + with gr.Row(): + with gr.Column(): + with gr.Tab("Original Frame (for Monitoring)"): + fixed_list = gr.State([0] * (MAX_FRAME)) + selected = gr.Number(visible=False, label = "selected frame", interactive = False) + image_gallery = gr.Gallery( + label="inputs", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True, + show_download_button=False) + clear_video = gr.Button("clear video") + with gr.Tab("Frame Fixer"): + with gr.Column(): + #with gr.Row(): + #with gr.Row(): + with gr.Group(visible=True) as fix_result_group: + fix_result_gallery = gr.Gallery( + label="result", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True, + interactive = False) + fix_button = gr.Button(visible = True) + with gr.Row(): + fixed_frame = gr.Text(visible=False, label = "fixed frame", interactive = False) + fix_button.click(fix_img, inputs=[selected, fixed_list, input_dir], + outputs=[fix_result_gallery, fix_result_group, fixed_frame]) + def update_fix_button_visible(evt: gr.SelectData): + flag = 0 < evt.index < MAX_FRAME - 1 + msg = f"fix frame {evt.index}" if flag else f"can only fix 1 ~ {MAX_FRAME - 2}" + return { + fix_button:gr.Button(msg, visible=True), + selected : evt.index, + fix_result_gallery : gr.Gallery(selected_index = evt.index) + } + image_gallery.select(update_fix_button_visible, None, [fix_button, selected, fix_result_gallery]) + with gr.Tab("Motion easer"): + with gr.Column(): + with gr.Column(): + #with gr.Row(): + with gr.Group(visible=True) as ease_result_group: + last_ease_dir = gr.State("") + ease_result_gallery = gr.Gallery( + label="result", columns=[MAX_FRAME], rows=[4], object_fit="contain", height="auto", preview = True, + interactive = False) + ease_button = gr.Button("ease", visible = True) + plt_data = gr.State([1] * (MAX_FRAME - 1)) + VFI_x = gr.Radio([("x1", 1), ("x2", 2), ("x4", 3)], value = 1, label="Slow ratio", info="adjust Slow ratio", interactive = True) + with gr.Row(): + edit_one_button = gr.Button("edit one scale", visible = True) + edit_all_button = gr.Button("edit all scale", visible = True) + now_frame = gr.Slider(0, MAX_FRAME - 1 - 1, step=1, label="Start frame", info="Choose Start frame to make slow. Interpolation will apply to (frame ~ frame + 1)") + + def plt_edit(data): + fig = plt.figure() + x = np.arange(0, MAX_FRAME - 1) + 0.5 + y = np.array(data) + plt.plot(x , y, color = 'black', marker = "o", linewidth = "2.5") + plt.xticks(np.arange(0, MAX_FRAME)) + plt.yticks([1, 2, 3], ["x1", "x2\nslow", "x4\nslow"]) + plt.gca().invert_yaxis() + plt.grid(True) + plt.rcParams['xtick.bottom'] = plt.rcParams['xtick.labelbottom'] = False + plt.rcParams['xtick.top'] = plt.rcParams['xtick.labeltop'] = True + return fig + ease_plot = gr.Plot(value = plt_edit(plt_data.value), show_label=False) + with gr.Accordion("get result", visible = False) as ease_make_video: + ease_zip = gr.File(label = "Download all image frames in Zip", interactive = False) + make_video_button = gr.Button("make video") + result_video = gr.Video(interactive = False) + def make_video(frame_dir, fps): + t = os.path.basename(frame_dir) + output_name = f"{frame_dir}/{t}.mp4" + if os.path.exists(output_name): + os.remove(output_name) + frame_list = get_sorted_img(frame_dir) + with open(f"{frame_dir}/input.txt", "w") as f: + for line in frame_list: + f.write(f"file '{os.path.basename(line)}'\n") + cmd = f'ffmpeg -r {fps} -f concat -safe 0 -i {frame_dir}/input.txt -c:v libx264 -preset veryslow -crf 10 {output_name}' + os.system(cmd) + return output_name + make_video_button.click(make_video, inputs = [last_ease_dir, fps], outputs = [result_video]) + ease_button.click(ease_frames, inputs=[plt_data, input_dir], outputs=[ease_result_gallery, ease_make_video, last_ease_dir, ease_zip]) + + def edit_one_scale(data, idx, x): + if idx < MAX_FRAME - 1: + data[idx] = x if x else 1 + return plt_edit(data) + edit_one_button.click(edit_one_scale, inputs=[plt_data, now_frame, VFI_x] , outputs=[ease_plot]) + def edit_all_scale(data, x): + for i in range(len(data)): data[i] = x if x else 1 + return plt_edit(data) + edit_all_button.click(edit_all_scale, inputs=[plt_data, VFI_x], outputs=[ease_plot]) + def clear_vd(plt_data, fixed_list): + for i in range(len(plt_data)): plt_data[i] = 1 + for i in range(len(fixed_list)): fixed_list[i] = 0 + return {video_input:gr.Video(label="Input Video", interactive=True, sources=['upload'], value = None), + ease_result_gallery : gr.Gallery( + label="result", columns=[MAX_FRAME], rows=[4], object_fit="contain", height="auto", preview = True, + interactive = False, value = None), + fix_result_gallery : gr.Gallery( + label="result", columns=[MAX_FRAME], rows=[1], object_fit="contain", height="auto", preview = True, + interactive = False, value = None), + fixed_frame : gr.Text(visible=False, label = "fixed frame", interactive = False, value = None), + ease_make_video : gr.Accordion(visible = True), + video_input_group:gr.Group(visible=True), + image_edit_group:gr.Group(visible=False), + ease_plot : gr.Plot(value = plt_edit(plt_data))} + clear_video.click(clear_vd, inputs=[plt_data, fixed_list],outputs=[video_input, ease_result_gallery, fix_result_gallery, fixed_frame, ease_make_video, video_input_group, image_edit_group, ease_plot]) + def update_video_visible(video): + if not video: + return {video_input_group:gr.Group(visible=True), + image_edit_group:gr.Group(visible=False), + image_gallery:[], + input_dir : "", + fps : 0 + } + now = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + input_now = os.path.join("input", now) + os.makedirs(input_now, exist_ok = True) + cap = cv2.VideoCapture(video) + frame_count = 0 + video_fps = cap.get(cv2.CAP_PROP_FPS) + #print('video fps:', video_fps) + + H = cap.get(cv2.CAP_PROP_FRAME_HEIGHT) + W = cap.get(cv2.CAP_PROP_FRAME_WIDTH) + mul = ((3e6) / (W * H)) ** (1/2) + H, W = int(H * mul), int(W * mul) + + frame_name_list = [] + while cap.isOpened(): + ret, frame = cap.read() + if not ret: + break + img_name = os.path.join(input_now, f"img_{frame_count:03d}.png") + if mul < 1: + frame = cv2.resize(frame, (W, H), interpolation=cv2.INTER_CUBIC) + cv2.imwrite(img_name, frame) + frame_name_list.append((img_name, f"frame {frame_count}")) + frame_count += 1 + if frame_count >= MAX_FRAME: break + cap.release() + return {video_input_group:gr.Group(visible=False), + image_edit_group:gr.Group(visible=True), + image_gallery:frame_name_list, + input_dir : input_now, + fps : video_fps + } + upload_button.click(update_video_visible, + [video_input], + [video_input_group, image_edit_group, image_gallery, input_dir, fps]) + +if __name__ == '__main__': + demo.launch(allowed_paths=["./input", "./output"], auth = check_valid_login, auth_message = "Inshorts Animator V. 0.5 WebUI (Permitted User Only)", share = True) diff --git a/cfgs/upr_backward_unimatch_exp1.yaml b/cfgs/upr_backward_unimatch_exp1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..331f9b69584a9d203a1572e6387a1c28b6c2e958 --- /dev/null +++ b/cfgs/upr_backward_unimatch_exp1.yaml @@ -0,0 +1,75 @@ +exp_name: backwarp_resize + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/upr_backward + split: train + patch_size: 256 + flow: none + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + loader: + batch_size: 32 + num_workers: 8 + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_mod + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + + +optimizer: + name: adamW + args: {lr: 0.00015, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 0.00015 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +max_epoch: 838 + +validate_every: 5 +save_every: 10 +vis_every: 40 + +seed: 1234 + +dist_url: env:// +resume: ./save/upr_backward_unimatch_exp1_backwarp_resize/checkpoints/model_200.pth +# pretrained: ./pretrained/upr_backwarp_corr_resize_model.pth diff --git a/cfgs/upr_backward_unimatch_exp2.yaml b/cfgs/upr_backward_unimatch_exp2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c478a4949d1f3b92c8f8513af09d02fe028688a --- /dev/null +++ b/cfgs/upr_backward_unimatch_exp2.yaml @@ -0,0 +1,76 @@ +exp_name: backwarp_resize + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90Kx2 + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 512 + flow: t0 + loader: + batch_size: 16 + num_workers: 16 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + loader: + batch_size: 32 + num_workers: 8 + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_mod + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + + +optimizer: + name: adamW + args: {lr: 0.00015, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 0.00015 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +max_epoch: 250 + +validate_every: 5 +save_every: 10 +vis_every: 40 + +seed: 1234 + +dist_url: env:// +resume: ./save/upr_backward_unimatch_exp2_backwarp_resize/checkpoints/model_30.pth +# pretrained: ./pretrained/upr_backwarp_corr_resize_model.pth diff --git a/cfgs/upr_basic.yaml b/cfgs/upr_basic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd7ccdc328c381b89830fa0ecb08776e0497b363 --- /dev/null +++ b/cfgs/upr_basic.yaml @@ -0,0 +1,72 @@ +exp_name: UPR_basic + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: none + use_distance: False + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + use_distance: False + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_basic + args: + pyr_level: 3 + nr_lvl_skipped: 0 + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_custom_4090_exp1.yaml b/cfgs/upr_custom_4090_exp1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a67668fd15b84871109d0a973a1e42952e44f801 --- /dev/null +++ b/cfgs/upr_custom_4090_exp1.yaml @@ -0,0 +1,70 @@ +exp_name: backwarp_resize + +mode: train + +train_dataset: + name: vimeo + args: + root_path: /mnt/nas/512_new + split: train + patch_size: 512 + flow: none + loader: + batch_size: 16 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: /mnt/nas/vimeo_triplet + split: val + loader: + batch_size: 16 + num_workers: 8 + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_mod + args: + pyr_level: 5 + nr_lvl_skipped: 0 + splat_mode: average + + +optimizer: + name: adamW + args: {lr: 0.000075, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 0.000075 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +max_epoch: 533 + +validate_every: 1 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// +pretrained: ./pretrained/upr_backwarp_corr_resize_model.pth diff --git a/cfgs/upr_custom_4090_exp1_10epoch_test.yaml b/cfgs/upr_custom_4090_exp1_10epoch_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a67668fd15b84871109d0a973a1e42952e44f801 --- /dev/null +++ b/cfgs/upr_custom_4090_exp1_10epoch_test.yaml @@ -0,0 +1,70 @@ +exp_name: backwarp_resize + +mode: train + +train_dataset: + name: vimeo + args: + root_path: /mnt/nas/512_new + split: train + patch_size: 512 + flow: none + loader: + batch_size: 16 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: /mnt/nas/vimeo_triplet + split: val + loader: + batch_size: 16 + num_workers: 8 + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_mod + args: + pyr_level: 5 + nr_lvl_skipped: 0 + splat_mode: average + + +optimizer: + name: adamW + args: {lr: 0.000075, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 0.000075 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +max_epoch: 533 + +validate_every: 1 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// +pretrained: ./pretrained/upr_backwarp_corr_resize_model.pth diff --git a/cfgs/upr_custom_4090_exp1_extract.yaml b/cfgs/upr_custom_4090_exp1_extract.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3b28c27f4486127c8d66b162bfd3eb0c7f44828a --- /dev/null +++ b/cfgs/upr_custom_4090_exp1_extract.yaml @@ -0,0 +1,70 @@ +exp_name: backwarp_resize + +mode: train + +train_dataset: + name: vimeo + args: + root_path: /mnt/nas/512_new + split: train + patch_size: 512 + flow: none + loader: + batch_size: 16 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + loader: + batch_size: 4 + num_workers: 8 + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_mod + args: + pyr_level: 5 + nr_lvl_skipped: 0 + splat_mode: average + + +optimizer: + name: adamW + args: {lr: 0.000075, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 0.000075 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +max_epoch: 533 + +validate_every: 1 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// +pretrained: ./pretrained/upr_backwarp_corr_resize_model.pth diff --git a/cfgs/upr_custom_4090_exp2.yaml b/cfgs/upr_custom_4090_exp2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f188f189ad877d4b80afc9b22912c981539e86a1 --- /dev/null +++ b/cfgs/upr_custom_4090_exp2.yaml @@ -0,0 +1,70 @@ +exp_name: backwarp_resize + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/upr_backward + split: train + patch_size: 256 + flow: none + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + loader: + batch_size: 32 + num_workers: 8 + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_mod + args: + pyr_level: 5 + nr_lvl_skipped: 0 + splat_mode: average + + +optimizer: + name: adamW + args: {lr: 0.00015, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 0.00015 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +max_epoch: 533 + +validate_every: 1 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// +pretrained: ./pretrained/upr_backwarp_corr_resize_model.pth diff --git a/cfgs/upr_freq003.yaml b/cfgs/upr_freq003.yaml new file mode 100644 index 0000000000000000000000000000000000000000..37a44bb86023749e5b7732acd283d84ebbc0db40 --- /dev/null +++ b/cfgs/upr_freq003.yaml @@ -0,0 +1,88 @@ +exp_name: FET-VFI_BiFreqLoss_Distance + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: 't0' + use_distance: True + distance_root: ../datasets/distance_map + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + use_distance: True + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq004.yaml b/cfgs/upr_freq004.yaml new file mode 100644 index 0000000000000000000000000000000000000000..37a44bb86023749e5b7732acd283d84ebbc0db40 --- /dev/null +++ b/cfgs/upr_freq004.yaml @@ -0,0 +1,88 @@ +exp_name: FET-VFI_BiFreqLoss_Distance + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: 't0' + use_distance: True + distance_root: ../datasets/distance_map + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + use_distance: True + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq005.yaml b/cfgs/upr_freq005.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38321739147182317c09ba78ebdf6f3edb270fad --- /dev/null +++ b/cfgs/upr_freq005.yaml @@ -0,0 +1,87 @@ +exp_name: FET-VFI_EncFreqs_AsymFreqDec + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: 't0' + use_distance: False + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + use_distance: False + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq006.yaml b/cfgs/upr_freq006.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e3360c9c23bddfc9c371937b9f1bb9ca916414bf --- /dev/null +++ b/cfgs/upr_freq006.yaml @@ -0,0 +1,83 @@ +exp_name: FET-VFI_Basic_ChangeFreqCNN + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: none + use_distance: False + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + use_distance: False + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq2 + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: True + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq007.yaml b/cfgs/upr_freq007.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aafba8818260373095d5ddd5a1d83e7c67b8a45d --- /dev/null +++ b/cfgs/upr_freq007.yaml @@ -0,0 +1,83 @@ +exp_name: FET-VFI_FETAmpPhaResid + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: none + use_distance: False + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + use_distance: False + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq2 + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: True + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq008.yaml b/cfgs/upr_freq008.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4db5e51689cb1413d6bffd78947578a157991c40 --- /dev/null +++ b/cfgs/upr_freq008.yaml @@ -0,0 +1,83 @@ +exp_name: FET-VFI_FETAmpPhaResid + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: none + use_distance: False + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + use_distance: False + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq2 + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp001.yaml b/cfgs/upr_freq_unimatch_exp001.yaml new file mode 100644 index 0000000000000000000000000000000000000000..85cbe83fc0613dbbf1d570686d846359170b02a6 --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp001.yaml @@ -0,0 +1,83 @@ +exp_name: FreqeuncyEnhancementTransformer + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// + +resume: ./save/upr_freq_unimatch_exp001_FreqeuncyEnhancementTransformer/checkpoints/model_420.pth \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002.yaml b/cfgs/upr_freq_unimatch_exp002.yaml new file mode 100644 index 0000000000000000000000000000000000000000..297770219d6192624984be6e6eea7786259afc51 --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002.yaml @@ -0,0 +1,85 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + flow_root: ../datasets/unimatch_flow + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3d15edd60768039d357edd72808181e671a7093f --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames1_triplet + flow_root: ../datasets/frames1_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 4.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 4.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 300 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_1.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68ba756cad6536fca4816d992a4dadb3ee2d67c6 --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_1.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames2_triplet + flow_root: ../datasets/frames2_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 7.5e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 7.5e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 1000 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_2.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f61d6733279d6150e2b39f7d6733d3730982b86c --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_2.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames2_triplet + flow_root: ../datasets/frames2_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 4.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 4.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 1000 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_3.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35e6faf27bfd8df006c77fcc020fc124747fc1de --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_3.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames2_triplet + flow_root: ../datasets/frames2_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 1.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 10000 + +validate_every: 1000 +save_every: 1000 +vis_every: 2000 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_4.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ffd690f73d32a96b2bbec7fc27c696cb94b131c2 --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_2_4.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames2_triplet + flow_root: ../datasets/frames2_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 4.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 4.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 10000 + +validate_every: 1000 +save_every: 1000 +vis_every: 2000 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_3_1.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_3_1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8e8248174758e14e55d5405d47033d27f4cf948c --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_3_1.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames3_triplet + flow_root: ../datasets/frames3_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 7.5e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 7.5e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 1000 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_3_2.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_3_2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..115520234957438d149970d602198392ff1222c3 --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining1000_3_2.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames3_triplet + flow_root: ../datasets/frames3_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 4.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 4.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 1000 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining2.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6d0f529196f83e9b916157a7f1a9d5eed44cd17 --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining2.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames2_triplet + flow_root: ../datasets/frames2_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 4.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 4.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 300 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining3.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining3.yaml new file mode 100644 index 0000000000000000000000000000000000000000..04971b0f86a225948a52335acb0acfc2ff5130ec --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining3.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames3_triplet + flow_root: ../datasets/frames3_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 16 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 4.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 4.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 300 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_20240111_extratraining4.yaml b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining4.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93f25c1e44115270ac32265ba62b6ec5531e1d12 --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_20240111_extratraining4.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/frames4_triplet + flow_root: ../datasets/frames4_unimatch_flow + tri_trainlist: tri_trainlist.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 4 + num_workers: 4 + +# test_dataset: +# name: vimeo +# args: +# root_path: ../datasets/Vimeo90K +# split: val +# loader: +# batch_size: 16 +# num_workers: 4 +# save_imgs: False + +# demo_dataset: +# name: demo +# args: +# root_path: ../data/animation +# split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 1.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 300 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_freq_unimatch_exp002_stopmotion.yaml b/cfgs/upr_freq_unimatch_exp002_stopmotion.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1fa3e5ba325d33286665d5108a5886ca53e4ab43 --- /dev/null +++ b/cfgs/upr_freq_unimatch_exp002_stopmotion.yaml @@ -0,0 +1,86 @@ +exp_name: FET-VFI_BiFrequencyLoss + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../datasets/256_new + flow_root: ../datasets/256_new_unimatch_flow + tri_trainlist: tri_trainlist_011.txt + split: train + patch_size: 256 + flow: 't0' + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../datasets/Vimeo90K + split: val + loader: + batch_size: 16 + num_workers: 4 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_freq + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + fftshift: False + + +optimizer: + name: adamW + args: {lr: 4.e-6, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 4.e-6 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: frequency, + args: { weight: 0.01 } + } + - { + name: bi_frequency, + args: { weight: 0.01 } + } + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + +pretrained: ./pretrained/upr_freq002.pth +max_epoch: 300 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/cfgs/upr_vimeo_exp45.yaml b/cfgs/upr_vimeo_exp45.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fb4bb1c922d57203ec5a33b48920ad8adf2966b8 --- /dev/null +++ b/cfgs/upr_vimeo_exp45.yaml @@ -0,0 +1,76 @@ +exp_name: upsample_guided_filter_v2 + +mode: train + +train_dataset: + name: vimeo + args: + root_path: ../data/vimeo_triplet + split: train + patch_size: 256 + flow: 't0' + seg: False + loader: + batch_size: 32 + num_workers: 8 + +test_dataset: + name: vimeo + args: + root_path: ../data/vimeo_triplet + split: val + loader: + batch_size: 16 + num_workers: 8 + save_imgs: False + +demo_dataset: + name: demo + args: + root_path: ../data/animation + split: animation + +model: + name: upr_net_mod2 + args: + pyr_level: 3 + nr_lvl_skipped: 0 + splat_mode: average + + +optimizer: + name: adamW + args: {lr: 1.5e-4, weight_decay: 1.e-4} + +lr_scheduler: + name: one_cycle_lr + args: + max_lr: 1.5e-4 + pct_start: 0.01 + cycle_momentum: False + anneal_strategy: cos + +loss: + - { + name: multiple_flow, + args: { weight: 0.005 } + } + - { + name: charbonnier, + args: { weight: 1 } + } + - { + name: ternary, + args: { weight: 1 } + } + + +max_epoch: 540 + +validate_every: 10 +save_every: 10 +vis_every: 20 + +seed: 1234 + +dist_url: env:// \ No newline at end of file diff --git a/check_images.py b/check_images.py new file mode 100644 index 0000000000000000000000000000000000000000..ed165950915ad43c7527a7cc2e8094225cd7ff5f --- /dev/null +++ b/check_images.py @@ -0,0 +1,39 @@ +import os +import cv2 + +def process_directory(directory, output_file): + for root, dirs, files in os.walk(directory): + if not dirs: + # ๋” ์ด์ƒ ํ•˜์œ„ ๋””๋ ‰ํ† ๋ฆฌ๊ฐ€ ์—†๋Š” ๊ฒฝ์šฐ + print("Directory:", root) + + # ์ด๋ฏธ์ง€ ํŒŒ์ผ ๊ฒฝ๋กœ ๊ฒ€์‚ฌ + image_found = False + for filename in files: + if filename in ["im1.png","im1.jpg"]: + image_found = True + image_path = os.path.join(root, filename) + + # ์ด๋ฏธ์ง€ ์ฝ๊ธฐ + image = cv2.imread(image_path)[:, :, ::-1] + if image is not None: + print("Image:", image_path) + + # ์ด๋ฏธ์ง€๋ฅผ ์ฐพ์ง€ ๋ชปํ•œ ๊ฒฝ์šฐ ๊ฒฝ๋กœ๋ฅผ ํŒŒ์ผ์— ์“ฐ๊ธฐ + if not image_found: + with open(output_file, "a") as f: + f.write(root + "\n") + +if __name__ == "__main__": + # ๋””๋ ‰ํ† ๋ฆฌ ๊ฒฝ๋กœ ์ง€์ • + start_directory = "../../ins4/Triplet_250p_new/sequences" + + # ๊ฒฐ๊ณผ๋ฅผ ์ €์žฅํ•  ํ…์ŠคํŠธ ํŒŒ์ผ ๊ฒฝ๋กœ ์ง€์ • + output_file = "missing_images.txt" + + # ํ…์ŠคํŠธ ํŒŒ์ผ ์ดˆ๊ธฐํ™” + open(output_file, "w").close() + + # ์‹œ์ž‘ ๋””๋ ‰ํ† ๋ฆฌ์—์„œ ์ˆœํšŒ ์‹œ์ž‘ + process_directory(start_directory, output_file) + diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d1911e4de3849013ceb1d5e5222efd26a9e79a71 --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,5 @@ +from .datasets import * +from .vimeo import * +from .snu_film import * +from .xiph import * +from .ucf101 import * \ No newline at end of file diff --git a/datasets/__pycache__/__init__.cpython-310.pyc b/datasets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a28c589dd71fe5fc1dc51c62cb1adacfa7b42c8 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-310.pyc differ diff --git a/datasets/__pycache__/__init__.cpython-38.pyc b/datasets/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1a3e27f4a8bedd5b4eb90cd9d15d5e0cb67dd99 Binary files /dev/null and b/datasets/__pycache__/__init__.cpython-38.pyc differ diff --git a/datasets/__pycache__/data_utils.cpython-310.pyc b/datasets/__pycache__/data_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca8d7de63200c62aadd0507d15974478b085b200 Binary files /dev/null and b/datasets/__pycache__/data_utils.cpython-310.pyc differ diff --git a/datasets/__pycache__/data_utils.cpython-38.pyc b/datasets/__pycache__/data_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc53cdcf20ecf7db28bd2a288dd48d10d2e5527d Binary files /dev/null and b/datasets/__pycache__/data_utils.cpython-38.pyc differ diff --git a/datasets/__pycache__/datasets.cpython-310.pyc b/datasets/__pycache__/datasets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dca8d435e942ab5ae058910382f5467611f15fc4 Binary files /dev/null and b/datasets/__pycache__/datasets.cpython-310.pyc differ diff --git a/datasets/__pycache__/datasets.cpython-38.pyc b/datasets/__pycache__/datasets.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1d523ca6541084579164146082b9ae56be0f97b Binary files /dev/null and b/datasets/__pycache__/datasets.cpython-38.pyc differ diff --git a/datasets/__pycache__/snu_film.cpython-310.pyc b/datasets/__pycache__/snu_film.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..494e7189e7eed700f08cc30f65c714187d519a7f Binary files /dev/null and b/datasets/__pycache__/snu_film.cpython-310.pyc differ diff --git a/datasets/__pycache__/snu_film.cpython-38.pyc b/datasets/__pycache__/snu_film.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c97c9dd1b5b103b8d3290c38805a29a2bae0bb16 Binary files /dev/null and b/datasets/__pycache__/snu_film.cpython-38.pyc differ diff --git a/datasets/__pycache__/ucf101.cpython-310.pyc b/datasets/__pycache__/ucf101.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..994ff08aca90cddee1fe6f3be512694a49a1b8d9 Binary files /dev/null and b/datasets/__pycache__/ucf101.cpython-310.pyc differ diff --git a/datasets/__pycache__/ucf101.cpython-38.pyc b/datasets/__pycache__/ucf101.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a548499ebc9507a912b0e79f34c117a24ace86b5 Binary files /dev/null and b/datasets/__pycache__/ucf101.cpython-38.pyc differ diff --git a/datasets/__pycache__/vimeo.cpython-310.pyc b/datasets/__pycache__/vimeo.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..245d282b8606cf26a017710c7c54ff0b343bc5f1 Binary files /dev/null and b/datasets/__pycache__/vimeo.cpython-310.pyc differ diff --git a/datasets/__pycache__/vimeo.cpython-38.pyc b/datasets/__pycache__/vimeo.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f75e3fd0f8fe47851158e3269a09e1663de2aef Binary files /dev/null and b/datasets/__pycache__/vimeo.cpython-38.pyc differ diff --git a/datasets/__pycache__/xiph.cpython-310.pyc b/datasets/__pycache__/xiph.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4163e832bb18d7dde56cefb8e16f2129e7aec08 Binary files /dev/null and b/datasets/__pycache__/xiph.cpython-310.pyc differ diff --git a/datasets/__pycache__/xiph.cpython-38.pyc b/datasets/__pycache__/xiph.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6489a84d8fb9291f41476f52bf13f943c38f2df3 Binary files /dev/null and b/datasets/__pycache__/xiph.cpython-38.pyc differ diff --git a/datasets/data_utils.py b/datasets/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a85ed38e6741aa65e53aae6858349880017b8502 --- /dev/null +++ b/datasets/data_utils.py @@ -0,0 +1,110 @@ +import cv2 +import numpy as np +import random + +perm = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)] +rotate = [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE] + + +def random_crop(img0, imgt, img1, crop_size, flowt0=None, flowt1=None, distance=None): + im_h, im_w = img0.shape[:2] + crop_h, crop_w = crop_size, crop_size + i = random.randint(0, im_h - crop_h) + j = random.randint(0, im_w - crop_w) + img0 = img0[i:i + crop_h, j:j + crop_w] + imgt = imgt[i:i + crop_h, j:j + crop_w] + img1 = img1[i:i + crop_h, j:j + crop_w] + if flowt0 is not None and flowt1 is not None: + flowt0 = flowt0[i:i + crop_h, j:j + crop_w] + flowt1 = flowt1[i:i + crop_h, j:j + crop_w] + if distance is not None: + distance = distance[i:i + crop_h, j:j + crop_w] + return img0, imgt, img1, flowt0, flowt1, distance + + +def random_hor_flip(img0, imgt, img1, flowt0=None, flowt1=None, distance=None): + img0, imgt, img1 = img0[::-1, :, :], imgt[::-1, :, :], img1[::-1, :, :] + if flowt0 is not None and flowt1 is not None: + flowt0, flowt1 = flowt0[::-1, :, :] * np.array([1, -1]).reshape(1, 1, 2), \ + flowt1[::-1, :, :] * np.array([1, -1]).reshape(1, 1, 2) + if distance is not None: + distance = distance[::-1, :, :] + return img0, imgt, img1, flowt0, flowt1, distance + + +def random_ver_flip(img0, imgt, img1, flowt0=None, flowt1=None, distance=None): + img0, imgt, img1 = img0[:, ::-1, :], imgt[:, ::-1, :], img1[:, ::-1, :] + if flowt0 is not None and flowt1 is not None: + flowt0, flowt1 = flowt0[:, ::-1, :] * np.array([-1, 1]).reshape(1, 1, 2), \ + flowt1[:, ::-1, :] * np.array([-1, 1]).reshape(1, 1, 2) + if distance is not None: + distance = distance[:, ::-1, :] + return img0, imgt, img1, flowt0, flowt1, distance + + +def random_color_permutation(img0, imgt, img1): + perm_idx = random.randint(0, 5) + img0, imgt, img1 = img0[:, :, perm[perm_idx]], imgt[:, :, perm[perm_idx]], img1[:, :, perm[perm_idx]] + return img0, imgt, img1 + + +def random_temporal_flip(img0, imgt, img1, time_step, flowt0=None, flowt1=None): + tmp = img1 + img1 = img0 + img0 = tmp + time_step = 1 - time_step + if flowt0 is not None and flowt1 is not None: + tmp = flowt0 + flowt0 = flowt1 + flowt1 = tmp + return img0, imgt, img1, time_step, flowt0, flowt1 + + +def random_rotation(img0, imgt, img1, degree, flowt0=None, flowt1=None, distance=None): + if degree != 3: + img0 = cv2.rotate(img0, rotate[degree]) + imgt = cv2.rotate(imgt, rotate[degree]) + img1 = cv2.rotate(img1, rotate[degree]) + if flowt0 is not None and flowt1 is not None: + flowt0 = cv2.rotate(flowt0, rotate[degree]) + flowt1 = cv2.rotate(flowt1, rotate[degree]) + if degree == 0: + flowt0 = np.concatenate((-flowt0[:, :, 1:2], flowt0[:, :, 0:1]), 2) + flowt1 = np.concatenate((-flowt1[:, :, 1:2], flowt1[:, :, 0:1]), 2) + elif degree == 1: + flowt0 = -flowt0 + flowt1 = -flowt1 + elif degree == 2: + flowt0 = np.concatenate((flowt0[:, :, 1:2], -flowt0[:, :, 0:1]), 2) + flowt1 = np.concatenate((flowt1[:, :, 1:2], -flowt1[:, :, 0:1]), 2) + if distance is not None: + H,W,_ = distance.shape + distance = cv2.rotate(distance, rotate[degree]).reshape(H,W,1) + return img0, imgt, img1, flowt0, flowt1, distance + + +def random_resize(img0, imgt, img1, flowt0=None, flowt1=None, distance=None): + ''' + img0 = cv2.resize(img0, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) + imgt = cv2.resize(imgt, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) + img1 = cv2.resize(img1, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) + if flowt0 is not None and flowt1 is not None: + flowt0 = cv2.resize(flowt0, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) * 2.0 + flowt1 = cv2.resize(flowt1, dsize=None, fx=2.0, fy=2.0, interpolation=cv2.INTER_LINEAR) * 2.0 + ''' + return img0, imgt, img1, flowt0, flowt1, distance + + +def read_flow(name): + with open(name, "rb") as f: + header = f.read(4) + if header.decode("utf-8") != 'PIEH': + raise Exception('Flow file header does not contain PIEH') + + width = np.fromfile(f, np.int32, 1).squeeze() + height = np.fromfile(f, np.int32, 1).squeeze() + + flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) + + return flow.astype(np.float32) + diff --git a/datasets/datasets.py b/datasets/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..0c6772348651dbc9e19393cd93e58fe173a9f1b4 --- /dev/null +++ b/datasets/datasets.py @@ -0,0 +1,20 @@ +import copy + + +datasets = {} + +def register(name): + def decorator(cls): + datasets[name] = cls + return cls + return decorator + + +def make(dataset_spec, args=None): + if args is not None: + dataset_args = copy.deepcopy(dataset_spec['args']) + dataset_args.update(args) + else: + dataset_args = dataset_spec['args'] + dataset = datasets[dataset_spec['name']](**dataset_args) + return dataset diff --git a/datasets/snu_film.py b/datasets/snu_film.py new file mode 100644 index 0000000000000000000000000000000000000000..dc53145a34bb8e9f6396674dcb4f5cf1918141c6 --- /dev/null +++ b/datasets/snu_film.py @@ -0,0 +1,78 @@ +from pathlib import Path +import os +from PIL import Image +import random +import numpy as np +import cv2 + +from datasets import register + +import torch +import torchvision +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset + + +@register('snu_film') +class SNUFilm(Dataset): + def __init__(self, root_path, split="extreme", use_distance=False, distance_root=None): + self.data_root = root_path + self.data_type = split + assert split in ["easy", "medium", "hard", "extreme"] + self.load_data() + self.use_distance = use_distance + self.distance_root = distance_root + + def __len__(self): + return len(self.meta_data) + + def load_data(self): + if self.data_type == "easy": + easy_file = os.path.join(self.data_root, "eval_modes/test-easy.txt") + with open(easy_file, 'r') as f: + self.meta_data = f.read().splitlines() + if self.data_type == "medium": + medium_file = os.path.join(self.data_root, "eval_modes/test-medium.txt") + with open(medium_file, 'r') as f: + self.meta_data = f.read().splitlines() + if self.data_type == "hard": + hard_file = os.path.join(self.data_root, "eval_modes/test-hard.txt") + with open(hard_file, 'r') as f: + self.meta_data = f.read().splitlines() + if self.data_type == "extreme": + extreme_file = os.path.join(self.data_root, "eval_modes/test-extreme.txt") + with open(extreme_file, 'r') as f: + self.meta_data = f.read().splitlines() + + def get_img(self, index): + imgpath = self.meta_data[index] + imgpaths = imgpath.split() + + # Load images + img0 = cv2.imread(os.path.join(self.data_root, '/'.join(imgpaths[0].split('/')[2:])))[:, :, ::-1] + gt = cv2.imread(os.path.join(self.data_root, '/'.join(imgpaths[1].split('/')[2:])))[:, :, ::-1] + img1 = cv2.imread(os.path.join(self.data_root, '/'.join(imgpaths[2].split('/')[2:])))[:, :, ::-1] + + return img0, gt, img1, '/'.join(imgpaths[1].split('/')[3:]) + + def __getitem__(self, index): + img0, imgt, img1, scene_name = self.get_img(index) + img0 = TF.to_tensor(img0.copy()) + img1 = TF.to_tensor(img1.copy()) + imgt = TF.to_tensor(imgt.copy()) + time_step = torch.Tensor([0.5]).reshape(1, 1, 1) + _,H,W = img0.size() + input_dict = {'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step, 'scene_name': ''.join(scene_name.split('.')[:-1])} + if self.use_distance: + if self.distance_root is not None: + distance_path = os.path.join(self.data_root.replace('SNU-FILM', 'snux_distance_map'), + '/'.join(self.meta_data[index].split()[1].split('/')[2:]).replace('.png', ''), + 'distance_for.npy') + distance = np.load(distance_path).astype(np.float32).reshape(H,W,1) + distance = TF.to_tensor(distance.copy()) + else: + distance = np.array(0.5).reshape(1,1,1).repeat(H, axis=0).repeat(W, axis=1) + distance = torch.from_numpy(distance).type(torch.float32).permute(2,0,1) + input_dict['distance'] = distance + return input_dict \ No newline at end of file diff --git a/datasets/ucf101.py b/datasets/ucf101.py new file mode 100644 index 0000000000000000000000000000000000000000..05ed76ac6362b9594c4026fd6abe5ec84cbfa9bb --- /dev/null +++ b/datasets/ucf101.py @@ -0,0 +1,46 @@ +import os +import cv2 +from glob import glob + +from datasets import register + +import torch +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset + + +@register('ucf101') +class UCF101(Dataset): + def __init__(self, root_path, **kwargs): + self.data_root = root_path + self.load_data() + + def __len__(self): + return len(self.meta_data) + + def load_data(self): + triplet_dirs = glob(os.path.join(self.data_root, "*")) + self.meta_data = triplet_dirs + + def get_img(self, index): + img_path = self.meta_data[index] + img_paths = [os.path.join(img_path, 'im1.png'), + os.path.join(img_path, 'im2.png'), + os.path.join(img_path, 'im3.png')] + + # Load images + img0 = cv2.imread(img_paths[0])[:,:,::-1] + imgt = cv2.imread(img_paths[1])[:,:,::-1] + img1 = cv2.imread(img_paths[2])[:,:,::-1] + return img0, imgt, img1 + + def __getitem__(self, index): + img0, imgt, img1 = self.get_img(index) + img0 = TF.to_tensor(img0.copy()) + img1 = TF.to_tensor(img1.copy()) + imgt = TF.to_tensor(imgt.copy()) + time_step = torch.Tensor([0.5]).reshape(1, 1, 1) + return { + 'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step, 'scene_name': self.meta_data[index] + } + diff --git a/datasets/vimeo.py b/datasets/vimeo.py new file mode 100644 index 0000000000000000000000000000000000000000..668b68db3a7aac1c04eb140777df7b5245bab9ad --- /dev/null +++ b/datasets/vimeo.py @@ -0,0 +1,139 @@ +from pathlib import Path +import os +from PIL import Image +import random +import numpy as np +import cv2 + +from datasets import register +from .data_utils import * + +import torch +import torchvision +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset + +perm = [(0, 1, 2), (0, 2, 1), (1, 0, 2), (1, 2, 0), (2, 0, 1), (2, 1, 0)] +rotate = [cv2.ROTATE_90_CLOCKWISE, cv2.ROTATE_180, cv2.ROTATE_90_COUNTERCLOCKWISE] + + +@register('vimeo') +class Vimeo(Dataset): + def __init__(self, root_path, patch_size=(224, 224), split='train', flow="none", flow_root=None, + use_distance=False, distance_root=None, tri_trainlist='tri_trainlist.txt'): + super(Vimeo, self).__init__() + self.data_root = root_path + self.mode = split + self.patch_size = patch_size + train_fn = os.path.join(self.data_root, tri_trainlist) + test_fn = os.path.join(self.data_root, 'tri_testlist.txt') +# self.flow = 't0' + self.flow = flow + self.flow_root = flow_root if flow!='none' else None + self.use_distance = use_distance + self.distance_root = distance_root + + with open(train_fn, "r") as f: + self.trainlist = [line.strip() for line in f.readlines() if len(line.strip())>0] +# self.trainlist = [line.strip() for line in f.readlines() if len(line.strip())>0 and line.strip().endswith(('e', 'n'))] + with open(test_fn, "r") as f: + self.testlist = [line.strip() for line in f.readlines() if len(line.strip())>0] + #cnt = int(len(self.trainlist) * 0.95) + if self.mode == "train": + #self.img_list = self.trainlist[:cnt] + self.img_list = self.trainlist + elif self.mode == "test": + self.img_list = self.testlist + else: + self.img_list = self.testlist + #self.img_list = self.trainlist[cnt:] + + def get_img(self, index): + img_path = os.path.join(self.data_root, "sequences", self.img_list[index]) + if os.path.exists(os.path.join(img_path, "im1.png")): + img0 = cv2.imread(os.path.join(img_path, "im1.png"))[:, :896, ::-1] + imgt = cv2.imread(os.path.join(img_path, "im2.png"))[:, :896, ::-1] + img1 = cv2.imread(os.path.join(img_path, "im3.png"))[:, :896, ::-1] + + elif os.path.exists(os.path.join(img_path, "im1.jpg")): + img0 = cv2.imread(os.path.join(img_path, "im1.jpg"))[:, :, ::-1] + imgt = cv2.imread(os.path.join(img_path, "im2.jpg"))[:, :, ::-1] + img1 = cv2.imread(os.path.join(img_path, "im3.jpg"))[:, :, ::-1] + else: + print(img_path,"ํŒŒ์ผ์ด ์™œ ์—†์ง€?") + +# print(f'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!{self.flow}!!!!!!!!!!!!!!!!!!!!!!!!!!!!!') + if self.flow == 't0': +# if not os.path.exists(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_t0.flo")): + if not os.path.exists(os.path.join(self.flow_root, 'sequences', self.img_list[index], 'flowt0.npy')): + print(self.img_list[index]) +# flowt0 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_t0.flo")) +# flowt1 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_t1.flo")) + flowt0 = np.load(os.path.join(self.flow_root, 'sequences', self.img_list[index], 'flowt0.npy')).astype(np.float32) + flowt1 = np.load(os.path.join(self.flow_root, 'sequences', self.img_list[index], 'flowt1.npy')).astype(np.float32) + elif self.flow == '01': + flowt0 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_01.flo")) + flowt1 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_10.flo")) + elif self.flow == '0t': + flowt0 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_0t.flo")) + flowt1 = read_flow(os.path.join(self.data_root, "flow_flow_former", self.img_list[index], "flow_1t.flo")) + else: + flowt0 = None + flowt1 = None + return img0, imgt, img1, flowt0, flowt1 + + def __getitem__(self, item): + img0, imgt, img1, flowt0, flowt1 = self.get_img(item) + distance = None + H,W,_ = img0.shape + time_step = torch.Tensor([0.5]).reshape(1, 1, 1) + if self.mode == "train": + if random.random() > 0.5: + img0, imgt, img1, time_step, flowt0, flowt1 = random_temporal_flip(img0, imgt, img1, time_step, flowt0, flowt1) + if self.use_distance and self.distance_root is not None: + distance_path = os.path.join(self.distance_root, "sequences", self.img_list[item]) + distance = np.load(os.path.join(distance_path, 'distance_rev.npy')).astype(np.float32).reshape(H,W,1) + asdf = 'distance_rev.npy' + else: + if self.use_distance and self.distance_root is not None: + distance_path = os.path.join(self.distance_root, "sequences", self.img_list[item]) + distance = np.load(os.path.join(distance_path, 'distance_for.npy')).astype(np.float32).reshape(H,W,1) + asdf = 'distance_for.npy' + if random.random() > 0.9: + img0, imgt, img1, flowt0, flowt1, distance = random_resize(img0, imgt, img1, flowt0, flowt1, distance) + img0, imgt, img1, flowt0, flowt1, distance = random_crop(img0, imgt, img1, self.patch_size, flowt0, flowt1, distance) + if random.random() > 0.5: + img0, imgt, img1, flowt0, flowt1, distance = random_hor_flip(img0, imgt, img1, flowt0, flowt1, distance) + if random.random() > 0.5: + img0, imgt, img1, flowt0, flowt1, distance = random_ver_flip(img0, imgt, img1, flowt0, flowt1, distance) + if random.random() > 0.5: + img0, imgt, img1 = random_color_permutation(img0, imgt, img1) + degree = random.randint(0, 3) + img0, imgt, img1, flowt0, flowt1, distance = random_rotation(img0, imgt, img1, degree, flowt0, flowt1, distance) + else: + if self.distance_root is not None: + distance_path = os.path.join(self.distance_root, "sequences", self.img_list[item], 'distance_for.npy') + distance = np.load(distance_path).astype(np.float32).reshape(H,W,1) + img0, imgt, img1 = TF.to_tensor(img0.copy()), TF.to_tensor(imgt.copy()), TF.to_tensor(img1.copy()) + input_dict = { + 'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step, 'scene_name': self.img_list[item] + } + if flowt0 is not None and flowt1 is not None: + flowt0 = torch.from_numpy(flowt0).type(torch.float32).permute(2, 0, 1) + flowt1 = torch.from_numpy(flowt1).type(torch.float32).permute(2, 0, 1) + input_dict['flowt0'] = flowt0 + input_dict['flowt1'] = flowt1 + if self.use_distance: + if self.distance_root is not None: + distance = TF.to_tensor(distance.copy()) + if torch.any(torch.isnan(distance)): + print(f'@@@@@@@@@@@@@@@@@@@@@@{self.img_list[item]}, {asdf}@@@@@@@@@@@@@@@@@@@@@@') + else: + distance = np.array(0.5).reshape(1,1,1).repeat(H, axis=0).repeat(W, axis=1) + distance = torch.from_numpy(distance).type(torch.float32).permute(2,0,1) + input_dict['distance'] = distance + return input_dict + + def __len__(self): + return len(self.img_list) diff --git a/datasets/x4k1000fps.py b/datasets/x4k1000fps.py new file mode 100644 index 0000000000000000000000000000000000000000..2557f1bf0dad42427089857e73ab5f615fe7273f --- /dev/null +++ b/datasets/x4k1000fps.py @@ -0,0 +1,95 @@ +import numpy as np +import cv2 +from glob import glob +import os + +import torch +from torch.utils.data import Dataset + +from .datasets import register + + +@register('x4k1000fps') +class X_Test(Dataset): + def __init__(self, test_data_path, multiple): + self.test_data_path = test_data_path + self.multiple = multiple + self.testPath = self.make_2_d_dataset_x_test( + self.test_data_path, multiple, t_step_size=32) + + self.nIterations = len(self.testPath) + + # Raise error if no images found in test_data_path. + if len(self.testPath) == 0: + raise (RuntimeError("Found 0 files in subfolders of: " \ + + self.test_data_path + "\n")) + + def make_2_d_dataset_x_test(self, test_data_path, multiple, t_step_size): + """ make [I0,I1,It,t,scene_folder] """ + """ 1D (accumulated) """ + testPath = [] + t = np.linspace( + (1 / multiple), (1 - (1 / multiple)), (multiple - 1) + ) + for type_folder in sorted(glob(os.path.join(test_data_path, '*', ''))): # [type1,type2,type3,...] + for scene_folder in sorted(glob(os.path.join(type_folder, '*', ''))): # [scene1,scene2,..] + frame_folder = sorted(glob(scene_folder + '*.png')) # 32 multiple, ['00000.png',...,'00032.png'] + for idx in range(0, len(frame_folder), t_step_size): # 0,32,64,... + if idx == len(frame_folder) - 1: + break + for mul in range(multiple - 1): + I0I1It_paths = [] + I0I1It_paths.append(frame_folder[idx]) # I0 (fix) + I0I1It_paths.append(frame_folder[idx + t_step_size]) # I1 (fix) + I0I1It_paths.append(frame_folder[idx + int((t_step_size // multiple) * (mul + 1))]) # It + I0I1It_paths.append(t[mul]) + I0I1It_paths.append(scene_folder.split(os.path.join(test_data_path, ''))[-1]) # type1/scene1 + testPath.append(I0I1It_paths) + return testPath + + def frames_loader_test(self, I0I1It_Path): + frames = [] + for path in I0I1It_Path: + frame = cv2.imread(path) + frames.append(frame) + (ih, iw, c) = frame.shape + frames = np.stack(frames, axis=0) # (T, H, W, 3) + + """ np2Tensor [-1,1] normalized """ + frames = X_Test.RGBframes_np2Tensor(frames) + + return frames + + def RGBframes_np2Tensor(self, imgIn, channel=3): + ## input : T, H, W, C + if channel == 1: + # rgb --> Y (gray) + imgIn = np.sum( + imgIn * np.reshape( + [65.481, 128.553, 24.966], [1, 1, 1, 3] + ) / 255.0, + axis=3, + keepdims=True) + 16.0 + + # to Tensor + ts = (3, 0, 1, 2) ############# dimension order should be [C, T, H, W] + imgIn = torch.Tensor(imgIn.transpose(ts).astype(float)).mul_(1.0) + + return imgIn + + def __getitem__(self, idx): + I0, I1, It, t_value, scene_name = self.testPath[idx] + + I0I1It_Path = [I0, I1, It] + frames = self.frames_loader_test(I0I1It_Path) + # including "np2Tensor [-1,1] normalized" + + I0_path = I0.split(os.sep)[-1] + I1_path = I1.split(os.sep)[-1] + It_path = It.split(os.sep)[-1] + + return frames, np.expand_dims(np.array(t_value, dtype=np.float32), 0), \ + scene_name, [It_path, I0_path, I1_path] + + def __len__(self): + return self.nIterations diff --git a/datasets/xiph.py b/datasets/xiph.py new file mode 100644 index 0000000000000000000000000000000000000000..bf366f7a68e8dd93107d282c022564f78468fddd --- /dev/null +++ b/datasets/xiph.py @@ -0,0 +1,67 @@ +from pathlib import Path +import os +from PIL import Image +import random +import numpy as np +import cv2 + +from datasets import register + +import torch +import torchvision +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from torch.utils.data import Dataset + + +@register('xiph') +class Xiph(Dataset): + def __init__(self, root_path, split="resized-2k"): + self.data_root = root_path + self.split = split + assert split in ["resized-2k", "cropped-4k"] + self.load_data() + + def __len__(self): + return len(self.imgt_path_list) + + def load_data(self): + self.img0_path_list = [] + self.imgt_path_list = [] + self.img1_path_list = [] + for flie_name in os.listdir(self.data_root): + for intFrame in range(2, 99, 2): + self.img0_path_list.append(f'{flie_name}/{intFrame - 1:03d}.png') + self.imgt_path_list.append(f'{flie_name}/{intFrame:03d}.png') + self.img1_path_list.append(f'{flie_name}/{intFrame + 1:03d}.png') + + def get_img(self, index): + img0_path = os.path.join(self.data_root, self.img0_path_list[index]) + imgt_path = os.path.join(self.data_root, self.imgt_path_list[index]) + img1_path = os.path.join(self.data_root, self.img1_path_list[index]) + + # Load images + img0 = cv2.imread(img0_path)[:, :, ::-1] + imgt = cv2.imread(imgt_path)[:, :, ::-1] + img1 = cv2.imread(img1_path)[:, :, ::-1] + + return img0, imgt, img1 + + def __getitem__(self, index): + img0, imgt, img1 = self.get_img(index) + if self.split == 'resized-2k': + img0 = cv2.resize(src=img0, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA) + img1 = cv2.resize(src=img1, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA) + imgt = cv2.resize(src=imgt, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA) + + elif self.split == 'cropped-4k': + img0 = img0[540:-540, 1024:-1024, :] + img1 = img1[540:-540, 1024:-1024, :] + imgt = imgt[540:-540, 1024:-1024, :] + img0 = TF.to_tensor(img0.copy()) + img1 = TF.to_tensor(img1.copy()) + imgt = TF.to_tensor(imgt.copy()) + time_step = torch.Tensor([0.5]).reshape(1, 1, 1) + return { + 'img0': img0, 'imgt': imgt, 'img1': img1, 'time_step': time_step, 'scene_name': self.imgt_path_list[index] + } diff --git a/demo.sh b/demo.sh new file mode 100644 index 0000000000000000000000000000000000000000..fe3a6a6f54419652a2d7ee075f3fc5260c0ca411 --- /dev/null +++ b/demo.sh @@ -0,0 +1,4 @@ +#CUDA_VISIBLE_DEVICES=1 python main.py --cfg cfgs/upr_vimeo_exp0_demo.yaml +#CUDA_VISIBLE_DEVICES=1 python main.py --cfg cfgs/upr_vimeo_exp5_demo.yaml +#CUDA_VISIBLE_DEVICES=1 python main.py --cfg cfgs/amt_vimeo_exp5_demo.yaml +CUDA_VISIBLE_DEVICES=1 python main.py --cfg cfgs/m2m_flow_former_vimeo_exp1_demo.yaml diff --git a/figure.py b/figure.py new file mode 100644 index 0000000000000000000000000000000000000000..83797fe5b0d88fcac1c6620a9267d6bd3e8d1f90 --- /dev/null +++ b/figure.py @@ -0,0 +1,500 @@ +import os + +import math +import warnings +# import tkinter as tk +from collections import OrderedDict + +# from tkinter import ttk +from PIL import Image, ImageTk + +dataset_paths = ['vimeo_test', 'snu_film_easy', 'snu_film_medium', 'snu_film_hard', 'snu_film_extreme', + 'xiph_cropped-4k', 'xiph_resized-2k'] + + +class AutoScrollbar(ttk.Scrollbar): + """ A scrollbar that hides itself if it's not needed. Works only for grid geometry manager """ + + def set(self, lo, hi): + if float(lo) <= 0.0 and float(hi) >= 1.0: + self.grid_remove() + else: + self.grid() + ttk.Scrollbar.set(self, lo, hi) + + def pack(self, **kw): + raise tk.TclError('Cannot use pack with the widget ' + self.__class__.__name__) + + def place(self, **kw): + raise tk.TclError('Cannot use place with the widget ' + self.__class__.__name__) + + +class CanvasImage: + """ Display and zoom image """ + + def __init__(self, placeholder, path, row, column, width, height): + """ Initialize the ImageFrame """ + self.imscale = 1.0 # scale for the canvas image zoom, public for outer classes + self.__delta = 1.3 # zoom magnitude + self.__filter = Image.Resampling.BICUBIC # could be: NEAREST, BILINEAR, BICUBIC and ANTIALIAS + self.__previous_state = 0 # previous state of the keyboard + self.path = path # path to the image, should be public for outer classes + # Create ImageFrame in placeholder widget + self.__imframe = placeholder # placeholder of the ImageFrame object + # Vertical and horizontal scrollbars for canvas + hbar = AutoScrollbar(self.__imframe, orient='horizontal') + vbar = AutoScrollbar(self.__imframe, orient='vertical') + hbar.grid(row=row + 1, column=column, columnspan=2, sticky='we') + vbar.grid(row=row, column=column + 2, sticky='ns') + # Create canvas and bind it with scrollbars. Public for outer classes + with warnings.catch_warnings(): # suppress DecompressionBombWarning + warnings.simplefilter('ignore') + self.__image = Image.open(self.path) # open image, but down't load it + self.imwidth, self.imheight = self.__image.size # public for outer classes + self.canvas = tk.Canvas(self.__imframe, highlightthickness=0, + xscrollcommand=hbar.set, yscrollcommand=vbar.set, + width=width, height=height) + self.canvas.grid(row=row, column=column, columnspan=2, sticky='nswe') + self.canvas.update() # wait till canvas is created + hbar.configure(command=self.__scroll_x) # bind scrollbars to the canvas + vbar.configure(command=self.__scroll_y) + # Bind events to the Canvas + self.canvas.bind('', lambda event: self.__show_image()) # canvas is resized + self.canvas.bind('', self.__move_from) # remember canvas position + self.canvas.bind('', self.__default) # remember canvas position + self.canvas.bind('', self.__move_to) # move canvas to the new position + self.canvas.bind('', self.__wheel) # zoom for Windows and MacOS, but not Linux + self.canvas.bind('', self.__wheel) # zoom for Linux, wheel scroll down + self.canvas.bind('', self.__wheel) # zoom for Linux, wheel scroll up + # Handle keystrokes in idle mode, because program slows down on a weak computers, + # when too many key stroke events in the same time + self.canvas.bind('', lambda event: self.canvas.after_idle(self.__keystroke, event)) + # Decide if this image huge or not + self.__huge = False # huge or not + self.__huge_size = 14000 # define size of the huge image + self.__band_width = 1024 # width of the tile band + Image.MAX_IMAGE_PIXELS = 1000000000 # suppress DecompressionBombError for the big image + if self.imwidth * self.imheight > self.__huge_size * self.__huge_size and \ + self.__image.tile[0][0] == 'raw': # only raw images could be tiled + self.__huge = True # image is huge + self.__offset = self.__image.tile[0][2] # initial tile offset + self.__tile = [self.__image.tile[0][0], # it have to be 'raw' + [0, 0, self.imwidth, 0], # tile extent (a rectangle) + self.__offset, + self.__image.tile[0][3]] # list of arguments to the decoder + self.__min_side = min(self.imwidth, self.imheight) # get the smaller image side + # Create image pyramid + self.__pyramid = [self.smaller()] if self.__huge else [Image.open(self.path)] + # Set ratio coefficient for image pyramid + self.__ratio = max(self.imwidth, self.imheight) / self.__huge_size if self.__huge else 1.0 + self.__curr_img = 0 # current image from the pyramid + self.__scale = self.imscale * self.__ratio # image pyramide scale + self.__reduction = 2 # reduction degree of image pyramid + w, h = self.__pyramid[-1].size + while w > 512 and h > 512: # top pyramid image is around 512 pixels in size + w /= self.__reduction # divide on reduction degree + h /= self.__reduction # divide on reduction degree + self.__pyramid.append(self.__pyramid[-1].resize((int(w), int(h)), self.__filter)) + # Put image into container rectangle and use it to set proper coordinates to the image + self.container = self.canvas.create_rectangle((0, 0, self.imwidth, self.imheight), width=0) + self.__default() # show image on the canvas + self.canvas.focus_set() # set focus on the canvas + + def smaller(self): + """ Resize image proportionally and return smaller image """ + w1, h1 = float(self.imwidth), float(self.imheight) + w2, h2 = float(self.__huge_size), float(self.__huge_size) + aspect_ratio1 = w1 / h1 + aspect_ratio2 = w2 / h2 # it equals to 1.0 + if aspect_ratio1 == aspect_ratio2: + image = Image.new('RGB', (int(w2), int(h2))) + k = h2 / h1 # compression ratio + w = int(w2) # band length + elif aspect_ratio1 > aspect_ratio2: + image = Image.new('RGB', (int(w2), int(w2 / aspect_ratio1))) + k = h2 / w1 # compression ratio + w = int(w2) # band length + else: # aspect_ratio1 < aspect_ration2 + image = Image.new('RGB', (int(h2 * aspect_ratio1), int(h2))) + k = h2 / h1 # compression ratio + w = int(h2 * aspect_ratio1) # band length + i, j, n = 0, 1, round(0.5 + self.imheight / self.__band_width) + while i < self.imheight: + print('\rOpening image: {j} from {n}'.format(j=j, n=n), end='') + band = min(self.__band_width, self.imheight - i) # width of the tile band + self.__tile[1][3] = band # set band width + self.__tile[2] = self.__offset + self.imwidth * i * 3 # tile offset (3 bytes per pixel) + self.__image.close() + self.__image = Image.open(self.path) # reopen / reset image + self.__image.size = (self.imwidth, band) # set size of the tile band + self.__image.tile = [self.__tile] # set tile + cropped = self.__image.crop((0, 0, self.imwidth, band)) # crop tile band + image.paste(cropped.resize((w, int(band * k) + 1), self.__filter), (0, int(i * k))) + i += band + j += 1 + print('\r' + 30 * ' ' + '\r', end='') # hide printed string + return image + + def redraw_figures(self): + """ Dummy function to redraw figures in the children classes """ + pass + + def __default(self, *args, **kw): + self.imscale = min(self.canvas.winfo_height() / self.imheight, self.canvas.winfo_width() / self.imwidth) + k = self.imscale * self.__ratio # temporary coefficient + self.__curr_img = min((-1) * int(math.log(k, self.__reduction)), len(self.__pyramid) - 1) + self.__scale = k * math.pow(self.__reduction, max(0, self.__curr_img)) + x, y = 0, (self.canvas.winfo_height() - self.imscale * self.imheight) / 2 + self.canvas.scale('all', x, y, self.imscale, self.imscale) # rescale all objects + # Redraw some figures before showing image on the screen + self.redraw_figures() # method for child classes + self.__show_image() + + def grid(self, **kw): + """ Put CanvasImage widget on the parent widget """ + self.__imframe.grid(**kw) # place CanvasImage widget on the grid + self.__imframe.grid(sticky='nswe') # make frame container sticky + self.__imframe.rowconfigure(0, weight=1) # make canvas expandable + self.__imframe.columnconfigure(0, weight=1) + + def pack(self, **kw): + """ Exception: cannot use pack with this widget """ + raise Exception('Cannot use pack with the widget ' + self.__class__.__name__) + + def place(self, **kw): + """ Exception: cannot use place with this widget """ + raise Exception('Cannot use place with the widget ' + self.__class__.__name__) + + # noinspection PyUnusedLocal + def __scroll_x(self, *args, **kwargs): + """ Scroll canvas horizontally and redraw the image """ + self.canvas.xview(*args) # scroll horizontally + self.__show_image() # redraw the image + + # noinspection PyUnusedLocal + def __scroll_y(self, *args, **kwargs): + """ Scroll canvas vertically and redraw the image """ + self.canvas.yview(*args) # scroll vertically + self.__show_image() # redraw the image + + def __show_image(self): + """ Show image on the Canvas. Implements correct image zoom almost like in Google Maps """ + box_image = self.canvas.coords(self.container) # get image area + box_canvas = (self.canvas.canvasx(0), # get visible area of the canvas + self.canvas.canvasy(0), + self.canvas.canvasx(self.canvas.winfo_width()), + self.canvas.canvasy(self.canvas.winfo_height())) + box_img_int = tuple(map(int, box_image)) # convert to integer or it will not work properly + # Get scroll region box + box_scroll = [min(box_img_int[0], box_canvas[0]), min(box_img_int[1], box_canvas[1]), + max(box_img_int[2], box_canvas[2]), max(box_img_int[3], box_canvas[3])] + # Horizontal part of the image is in the visible area + if box_scroll[0] == box_canvas[0] and box_scroll[2] == box_canvas[2]: + box_scroll[0] = box_img_int[0] + box_scroll[2] = box_img_int[2] + # Vertical part of the image is in the visible area + if box_scroll[1] == box_canvas[1] and box_scroll[3] == box_canvas[3]: + box_scroll[1] = box_img_int[1] + box_scroll[3] = box_img_int[3] + # Convert scroll region to tuple and to integer + self.canvas.configure(scrollregion=tuple(map(int, box_scroll))) # set scroll region + x1 = max(box_canvas[0] - box_image[0], 0) # get coordinates (x1,y1,x2,y2) of the image tile + y1 = max(box_canvas[1] - box_image[1], 0) + x2 = min(box_canvas[2], box_image[2]) - box_image[0] + y2 = min(box_canvas[3], box_image[3]) - box_image[1] + if int(x2 - x1) > 0 and int(y2 - y1) > 0: # show image if it in the visible area + if self.__huge and self.__curr_img < 0: # show huge image + h = int((y2 - y1) / self.imscale) # height of the tile band + self.__tile[1][3] = h # set the tile band height + self.__tile[2] = self.__offset + self.imwidth * int(y1 / self.imscale) * 3 + self.__image.close() + self.__image = Image.open(self.path) # reopen / reset image + self.__image.size = (self.imwidth, h) # set size of the tile band + self.__image.tile = [self.__tile] + image = self.__image.crop((int(x1 / self.imscale), 0, int(x2 / self.imscale), h)) + else: # show normal image + image = self.__pyramid[max(0, self.__curr_img)].crop( # crop current img from pyramid + (int(x1 / self.__scale), int(y1 / self.__scale), + int(x2 / self.__scale), int(y2 / self.__scale))) + # + imagetk = ImageTk.PhotoImage(image.resize((int(x2 - x1), int(y2 - y1)), self.__filter)) + imageid = self.canvas.create_image(max(box_canvas[0], box_img_int[0]), + max(box_canvas[1], box_img_int[1]), + anchor='nw', image=imagetk) + self.canvas.lower(imageid) # set image into background + self.canvas.imagetk = imagetk # keep an extra reference to prevent garbage-collection + + def __move_from(self, event): + """ Remember previous coordinates for scrolling with the mouse """ + self.canvas.scan_mark(event.x, event.y) + + def __move_to(self, event): + """ Drag (move) canvas to the new position """ + self.canvas.scan_dragto(event.x, event.y, gain=1) + self.__show_image() # zoom tile and show it on the canvas + + def outside(self, x, y): + """ Checks if the point (x,y) is outside the image area """ + bbox = self.canvas.coords(self.container) # get image area + if bbox[0] < x < bbox[2] and bbox[1] < y < bbox[3]: + return False # point (x,y) is inside the image area + else: + return True # point (x,y) is outside the image area + + def __wheel(self, event): + """ Zoom with mouse wheel """ + x = self.canvas.canvasx(event.x) # get coordinates of the event on the canvas + y = self.canvas.canvasy(event.y) + if self.outside(x, y): return # zoom only inside image area + scale = 1.0 + # Respond to Linux (event.num) or Windows (event.delta) wheel event + if event.num == 5 or event.delta == -120: # scroll down, smaller + if round(self.__min_side * self.imscale) < 30: return # image is less than 30 pixels + self.imscale /= self.__delta + scale /= self.__delta + if event.num == 4 or event.delta == 120: # scroll up, bigger + i = min(self.canvas.winfo_width(), self.canvas.winfo_height()) >> 1 + if i < self.imscale: return # 1 pixel is bigger than the visible area + self.imscale *= self.__delta + scale *= self.__delta + # Take appropriate image from the pyramid + k = self.imscale * self.__ratio # temporary coefficient + self.__curr_img = min((-1) * int(math.log(k, self.__reduction)), len(self.__pyramid) - 1) + self.__scale = k * math.pow(self.__reduction, max(0, self.__curr_img)) + # + self.canvas.scale('all', x, y, scale, scale) # rescale all objects + # Redraw some figures before showing image on the screen + self.redraw_figures() # method for child classes + self.__show_image() + + def __keystroke(self, event): + """ Scrolling with the keyboard. + Independent from the language of the keyboard, CapsLock, +, etc. """ + if event.state - self.__previous_state == 4: # means that the Control key is pressed + pass # do nothing if Control key is pressed + else: + self.__previous_state = event.state # remember the last keystroke state + # Up, Down, Left, Right keystrokes + if event.keycode in [68, 39, 102]: # scroll right: keys 'D', 'Right' or 'Numpad-6' + self.__scroll_x('scroll', 1, 'unit', event=event) + elif event.keycode in [65, 37, 100]: # scroll left: keys 'A', 'Left' or 'Numpad-4' + self.__scroll_x('scroll', -1, 'unit', event=event) + elif event.keycode in [87, 38, 104]: # scroll up: keys 'W', 'Up' or 'Numpad-8' + self.__scroll_y('scroll', -1, 'unit', event=event) + elif event.keycode in [83, 40, 98]: # scroll down: keys 'S', 'Down' or 'Numpad-2' + self.__scroll_y('scroll', 1, 'unit', event=event) + + def crop(self, bbox): + """ Crop rectangle from the image and return it """ + if self.__huge: # image is huge and not totally in RAM + band = bbox[3] - bbox[1] # width of the tile band + self.__tile[1][3] = band # set the tile height + self.__tile[2] = self.__offset + self.imwidth * bbox[1] * 3 # set offset of the band + self.__image.close() + self.__image = Image.open(self.path) # reopen / reset image + self.__image.size = (self.imwidth, band) # set size of the tile band + self.__image.tile = [self.__tile] + return self.__image.crop((bbox[0], 0, bbox[2], band)) + else: # image is totally in RAM + return self.__pyramid[0].crop(bbox) + + def destroy(self): + """ ImageFrame destructor """ + self.__image.close() + map(lambda i: i.close, self.__pyramid) # close all pyramid images + del self.__pyramid[:] # delete pyramid list + del self.__pyramid # delete pyramid variable + self.canvas.destroy() + + +class FirstWindow(ttk.Frame): + def __init__(self, mainframe): + ttk.Frame.__init__(self, master=mainframe) + + +class SampleApp(tk.Tk): + def __init__(self): + tk.Tk.__init__(self) + self.figure_path = '/hdd/spocklabs/save_backup/' + self._exp_frame = None + self._dataset_frame = None + self._figure_frame = None + self.make_exp_frame() + + def make_exp_frame(self): + self.destroy_all() + self._exp_frame = ExpFrame(self) + self._exp_frame.pack() + + def make_dataset_frame(self): + self.destroy_all() + self._dataset_frame = DatsetFrame(self) + self._dataset_frame.pack() + + def make_figure_frame(self): + self.destroy_all() + self._figure_frame = FigureFrame(self) + self._figure_frame.pack() + + def destroy_all(self): + if self._exp_frame is not None: + self._exp_frame.destroy() + if self._dataset_frame is not None: + self._dataset_frame.destroy() + if self._figure_frame is not None: + self._figure_frame.destroy() + + +class ExpFrame(tk.Frame): + def __init__(self, master): + tk.Frame.__init__(self, master) + self.master.title('Select experiments') + self.master.geometry('800x600') + self.experiments = os.listdir(self.master.figure_path) + self.CheckVarietys = [tk.IntVar() for _ in range(len(self.experiments))] + self.checkbuttons = [tk.Checkbutton(self, text=self.experiments[i], variable=self.CheckVarietys[i]) for i in + range(len(self.experiments))] + for checkbutton in self.checkbuttons: + checkbutton.pack() + + start_figure = tk.Button(self, text='Make Figure', command=self.make_figure) + start_figure.pack() + + def make_figure(self): + self.master.figure_experiments = [] + for i in range(len(self.CheckVarietys)): + if self.CheckVarietys[i].get() == 1: + self.master.figure_experiments.append(self.experiments[i]) + self.master.make_dataset_frame() + + +class DatsetFrame(tk.Frame): + def __init__(self, master): + tk.Frame.__init__(self, master) + self.master.title('Select dataset') + button_go_to_dataset = tk.Button(self, text='Go to select dataset', command=self.master.make_exp_frame, width=20) + button_go_to_dataset.grid(row=0, column=0, sticky='w') + self.listbox = tk.Listbox(self, selectmode='extended') + self.listbox.insert(0, 'vimeo test') + self.listbox.insert(1, 'SNU FILM easy') + self.listbox.insert(2, 'SNU FILM medium') + self.listbox.insert(3, 'SNU FILM hard') + self.listbox.insert(4, 'SNU FILM extreme') + self.listbox.insert(5, 'xiph cropped 4k') + self.listbox.insert(6, 'xiph resized 2k') + self.listbox.grid(row=1, column=0, sticky='nsew') + + select_button = tk.Button(self, text='Select dataset', command=self.select_dataset) + select_button.grid(row=2, column=0) + + def select_dataset(self): + self.master.ind = self.listbox.curselection()[0] + self.master.psnr_lines = [] + self.master.dataset_path = dataset_paths[self.master.ind] + for i in range(len(self.master.figure_experiments)): + with open(os.path.join(self.master.figure_path, self.master.figure_experiments[i], 'output/imgs_test', + self.master.dataset_path, + 'results.txt')) as f: + psnrs = f.readlines() + file_psnr = {psnrs[j].split(":")[0]: float(psnrs[j].split(":")[1]) for j in range(len(psnrs))} + self.master.psnr_lines.append(file_psnr) + if i == 0: + self.master.file_order_list = [psnrs[j].split(":")[0] for j in range(len(psnrs))] + + self.master.make_figure_frame() + + +class FigureFrame(tk.Frame): + def __init__(self, master): + tk.Frame.__init__(self, master) + self.counter = 0 + self.master.title('Figure') + self.width = min(480, 3840 // (len(self.master.figure_experiments) + 2)) + self.height = self.width * 270 // 480 + self.master.geometry(f'{self.width * (len(self.master.figure_experiments) + 2)+100}x{self.height + 100}+100+100') + + self.make_figure() + + def prev(self): + self.counter -= 1 + if self.counter < 0: + self.counter = len(self.master.file_order_list) - 1 + self.destroy_all() + self.make_figure() + + def next(self): + self.counter += 1 + if self.counter >= len(self.master.file_order_list): + self.counter = 0 + self.destroy_all() + self.make_figure() + + def go_to(self, _): + key = self.entry_goto.get() + self.counter = self.master.file_order_list.index(key) + self.destroy_all() + self.make_figure() + + def open_eog(self): + img_path = self.master.file_order_list[self.counter] + overlayed_path = os.path.join(self.master.figure_path, self.master.figure_experiments[0], 'output/imgs_test', + self.master.dataset_path, img_path, + 'overlayedd.png') + os.system(f"eog {overlayed_path}") + for i in range(len(self.master.figure_experiments)): + imgt_pred_path = os.path.join(self.master.figure_path, self.master.figure_experiments[i], 'output/imgs_test', + self.master.dataset_path, img_path, + 'imgt_pred.png') + os.system(f"eog {imgt_pred_path}") + imgt_path = os.path.join(self.master.figure_path, self.master.figure_experiments[0], 'output/imgs_test', + self.master.dataset_path, img_path, + 'imgt.png') + os.system(f"eog {imgt_path}") + def make_figure(self): + button_go_to_dataset = tk.Button(self, text='go to select dataset', command=self.master.make_exp_frame, width=20) + button_go_to_exp = tk.Button(self, text='go to select exp', command=self.master.make_dataset_frame, width=20) + button_go_to_dataset.grid(row=0, column=0, sticky='nsew') + button_go_to_exp.grid(row=0, column=1, sticky='nsew') + img_path = self.master.file_order_list[self.counter] + # photo_list = [] + self.overlayed_label = tk.Label(self, text=f'Overlayed {img_path}') + self.overlayed_label.grid(row=1, column=0, columnspan=2) + self.overlayed = CanvasImage(self, os.path.join(self.master.figure_path, self.master.figure_experiments[0], 'output/imgs_test', + self.master.dataset_path, img_path, + 'overlayedd.png'), 2, 0, self.width, self.height) + self.imgt_pred_labels = [ + tk.Label(self, text=f'{self.master.figure_experiments[i]} {float(self.master.psnr_lines[i][img_path]):2f}') + for i in range(len(self.master.figure_experiments))] + for i, imgt_pred_label in enumerate(self.imgt_pred_labels): + imgt_pred_label.grid(row=1, column=3 * (i + 1), columnspan=2) + self.imgt_preds = [CanvasImage(self, os.path.join(self.master.figure_path, self.master.figure_experiments[i], 'output/imgs_test', + self.master.dataset_path, img_path, + 'imgt_pred.png'), 2, 3 * (i + 1), self.width, self.height) for i in + range(len(self.master.figure_experiments))] + self.imgt_label = tk.Label(self, text=f'GT') + self.imgt_label.grid(row=1, column=3 * (len(self.master.figure_experiments) + 1), columnspan=2) + self.imgt = CanvasImage(self, os.path.join(self.master.figure_path, self.master.figure_experiments[0], 'output/imgs_test', + self.master.dataset_path, img_path, + 'imgt.png'), 2, 3 * (len(self.master.figure_experiments) + 1), self.width, self.height) + button_next = tk.Button(self, text='next', command=self.next, width=10) + button_prev = tk.Button(self, text='prev', command=self.prev, width=10) + self.entry_goto = tk.Entry(self, width=20) + self.entry_goto.bind('', self.go_to) + button_prev.grid(row=4, column=0, sticky='nsew') + button_next.grid(row=4, column=1, sticky='nsew') + self.label = tk.Label(self, text='go to:', width=10) + self.label.grid(row=4, column=3, sticky='nsew') + self.entry_goto.grid(row=4, column=4, sticky='nsew') + self.button_open_eog = tk.Button(self, text='open in eog', command=self.open_eog) + self.button_open_eog.grid(row=5, column=0, sticky='nsew') + + def destroy_all(self): + self.overlayed.destroy() + self.imgt.destroy() + for imgt_pred in self.imgt_preds: + imgt_pred.destroy() + + +if __name__ == "__main__": + app = SampleApp() + app.mainloop() diff --git a/figures/flow_backward_0.png b/figures/flow_backward_0.png new file mode 100644 index 0000000000000000000000000000000000000000..c0e97ba03f7c957255e2865932dd0aae3988735d Binary files /dev/null and b/figures/flow_backward_0.png differ diff --git a/figures/flow_backward_4.png b/figures/flow_backward_4.png new file mode 100644 index 0000000000000000000000000000000000000000..5ea80e22e85037e4fb6b3fb332d9e313e25c3ce8 Binary files /dev/null and b/figures/flow_backward_4.png differ diff --git a/figures/flow_forward_0.png b/figures/flow_forward_0.png new file mode 100644 index 0000000000000000000000000000000000000000..accc5e4d7b98874df4d61954588bc6c1f2940135 Binary files /dev/null and b/figures/flow_forward_0.png differ diff --git a/figures/flow_forward_4.png b/figures/flow_forward_4.png new file mode 100644 index 0000000000000000000000000000000000000000..9d0ff12e409b0cdc44dc8c76f0baa82fc1a43bdd Binary files /dev/null and b/figures/flow_forward_4.png differ diff --git a/figures/img0_0.png b/figures/img0_0.png new file mode 100644 index 0000000000000000000000000000000000000000..720227482134f9f25072d3d9048ca5de10ec3ad0 Binary files /dev/null and b/figures/img0_0.png differ diff --git a/figures/img0_1.png b/figures/img0_1.png new file mode 100644 index 0000000000000000000000000000000000000000..6a7ba05993be5851d395d3820d32ce9d4ea2c417 Binary files /dev/null and b/figures/img0_1.png differ diff --git a/figures/img0_2.png b/figures/img0_2.png new file mode 100644 index 0000000000000000000000000000000000000000..9cfc134b3286ef5bb0301c6cf5dc3686fe402223 Binary files /dev/null and b/figures/img0_2.png differ diff --git a/figures/img1_0.png b/figures/img1_0.png new file mode 100644 index 0000000000000000000000000000000000000000..2969936a5c06e5dd4afdf71219cceb8f1102dc4f Binary files /dev/null and b/figures/img1_0.png differ diff --git a/figures/img1_1.png b/figures/img1_1.png new file mode 100644 index 0000000000000000000000000000000000000000..1622276dbd4a9e0d80819c6304aa820fd65e0e4b Binary files /dev/null and b/figures/img1_1.png differ diff --git a/figures/img1_2.png b/figures/img1_2.png new file mode 100644 index 0000000000000000000000000000000000000000..3a85fec25bd1b77bd067a7e7162d6a52765c3c49 Binary files /dev/null and b/figures/img1_2.png differ diff --git a/figures/imgt_0.png b/figures/imgt_0.png new file mode 100644 index 0000000000000000000000000000000000000000..09303efcec95076fca6b9788d1ee5f4ae777e279 Binary files /dev/null and b/figures/imgt_0.png differ diff --git a/figures/imgt_1.png b/figures/imgt_1.png new file mode 100644 index 0000000000000000000000000000000000000000..c3aedc74cb1084f4878d4fa59c6c890c5b82a346 Binary files /dev/null and b/figures/imgt_1.png differ diff --git a/figures/imgt_2.png b/figures/imgt_2.png new file mode 100644 index 0000000000000000000000000000000000000000..756515c283b5d50663fc9f198b436ed35410a914 Binary files /dev/null and b/figures/imgt_2.png differ diff --git a/figures/imgt_3.png b/figures/imgt_3.png new file mode 100644 index 0000000000000000000000000000000000000000..6800427c2c1092c0616d62fb9d29aef8db542019 Binary files /dev/null and b/figures/imgt_3.png differ diff --git a/inference_video.sh b/inference_video.sh new file mode 100644 index 0000000000000000000000000000000000000000..ff6ccdcfc2ec36ce8d9e925da7175eeb47d2418b --- /dev/null +++ b/inference_video.sh @@ -0,0 +1,11 @@ +IN_PATH="save/upr_vimeo_exp0_upr_original/output/demo/" +for entry in "$IN_PATH"* +do + if [[ ${entry} != *"flow" ]];then + echo $entry + OUT_PATH="${entry}/" + ffmpeg -framerate 120 -pattern_type glob -i "${OUT_PATH}*.png" -c:v libx265 -qp 8 -pix_fmt yuv420p ${entry}_120fps.mp4 + + fi +done + diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..72fc3983fbe2567376435cc84e92ebbae07c5e52 --- /dev/null +++ b/main.py @@ -0,0 +1,83 @@ +import argparse +import os + +import yaml +import torch +import torch.multiprocessing as mp + +import utils +from utils.experiment import * +from trainer import Trainer + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--cfg') + parser.add_argument('--load-root', default='data') + parser.add_argument('--save-root', default='save') + parser.add_argument('--name', '-n', default=None) + parser.add_argument('--tag', default=None) + parser.add_argument('--cudnn', action='store_true') + parser.add_argument('--port-offset', '-p', type=int, default=0) + parser.add_argument('--wandb-upload', '-w', action='store_true') + args = parser.parse_args() + + return args + + +def make_cfg(args): + with open(args.cfg, 'r') as f: + cfg = yaml.load(f, Loader=yaml.FullLoader) + + def translate_cfg_(d): + for k, v in d.items(): + if isinstance(v, dict): + translate_cfg_(v) + elif isinstance(v, str): + d[k] = v.replace('$load_root$', args.load_root) + translate_cfg_(cfg) + + if args.name is None: + exp_name = os.path.basename(args.cfg).split('.')[0].replace('_benchmark', '').replace('_demo', '') + else: + exp_name = args.name + if args.tag is not None: + exp_name += '_' + args.tag + + env = dict() + env['exp_name'] = exp_name + '_' + cfg['exp_name'] + env['save_dir'] = os.path.join(args.save_root, env['exp_name']) + env['tot_gpus'] = torch.cuda.device_count() + env['cudnn'] = args.cudnn + env['port'] = str(29600 + args.port_offset) + env['wandb_upload'] = args.wandb_upload + cfg['env'] = env + + return cfg + + +def main(): + args = parse_args() + + cfgs = make_cfg(args) + + init_experiment(cfgs) + init_distributed_mode(cfgs) + print('here') + init_deterministic(cfgs['seed']) + + trainer = Trainer(cfgs) + + if cfgs['mode'] == 'train': + trainer.train() + elif cfgs['mode'] == 'validate': + trainer.validate() + elif cfgs['mode'] == 'test': + trainer.test() + elif cfgs['mode'] == 'demo': + trainer.demo() + + + +if __name__ == '__main__': + main() diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2558ffdef518fd8b5f64602f544ec06fd5942d95 --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1,3 @@ +from .models import register, make +from .lr_scheduler import * +from .optimizer import * diff --git a/modules/__pycache__/__init__.cpython-310.pyc b/modules/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5273e13261e980e22d4dd2327ac1c2814e35e89 Binary files /dev/null and b/modules/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/__pycache__/__init__.cpython-38.pyc b/modules/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4555efc2049ee0bb8e63153c513d862885052124 Binary files /dev/null and b/modules/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/__pycache__/__init__.cpython-39.pyc b/modules/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83a6eb1028569436563fa00d68df167fee475916 Binary files /dev/null and b/modules/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/__pycache__/loss.cpython-310.pyc b/modules/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bf973fa8d1724661fa349128302420360c44eeb Binary files /dev/null and b/modules/__pycache__/loss.cpython-310.pyc differ diff --git a/modules/__pycache__/loss.cpython-38.pyc b/modules/__pycache__/loss.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a8623fa8209a4ba6fecf58eb0ea56245d5d11ca Binary files /dev/null and b/modules/__pycache__/loss.cpython-38.pyc differ diff --git a/modules/__pycache__/loss.cpython-39.pyc b/modules/__pycache__/loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff17ad99629c99d5fe115794d64ed10d2cfde237 Binary files /dev/null and b/modules/__pycache__/loss.cpython-39.pyc differ diff --git a/modules/__pycache__/lr_scheduler.cpython-310.pyc b/modules/__pycache__/lr_scheduler.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75d2561a64b9cdf2035ead036beded782ff645fe Binary files /dev/null and b/modules/__pycache__/lr_scheduler.cpython-310.pyc differ diff --git a/modules/__pycache__/lr_scheduler.cpython-38.pyc b/modules/__pycache__/lr_scheduler.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f433a7ca5227ce40f891ceb2868bee645e38afa1 Binary files /dev/null and b/modules/__pycache__/lr_scheduler.cpython-38.pyc differ diff --git a/modules/__pycache__/lr_scheduler.cpython-39.pyc b/modules/__pycache__/lr_scheduler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87d6ba35c72a77d194126dec8b0535020c176c58 Binary files /dev/null and b/modules/__pycache__/lr_scheduler.cpython-39.pyc differ diff --git a/modules/__pycache__/optimizer.cpython-310.pyc b/modules/__pycache__/optimizer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74a396301dd9dc38753425bf885bb603b4ed1adf Binary files /dev/null and b/modules/__pycache__/optimizer.cpython-310.pyc differ diff --git a/modules/__pycache__/optimizer.cpython-38.pyc b/modules/__pycache__/optimizer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61bbeb04f3bc0f80c99263b40bfe5e0c02c4206a Binary files /dev/null and b/modules/__pycache__/optimizer.cpython-38.pyc differ diff --git a/modules/__pycache__/optimizer.cpython-39.pyc b/modules/__pycache__/optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e3b9c6dd1a17b9d60e734a050fc204501349cfc Binary files /dev/null and b/modules/__pycache__/optimizer.cpython-39.pyc differ diff --git a/modules/components/__init__.py b/modules/components/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0504b415bd286148436e841a4376e32e8214cb2b --- /dev/null +++ b/modules/components/__init__.py @@ -0,0 +1,14 @@ +from .components import * +from .m2m_pwc import * +from .amt import * +from .upr_basic import * +from .upr_net import * +from .upr_net_mod import * +from .upr_net_mod2 import * +from .upr_net_freq import * +from .upr_net_freq2 import * +# from .m2m_flow_former import * +from .amt_flowformer import * +from .upr_net_multi_flow import * +from .amt_bilateral import * +from .amt_splat import * \ No newline at end of file diff --git a/modules/components/__pycache__/__init__.cpython-310.pyc b/modules/components/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4eca1506a278c940406fb29479768fe54ed7804 Binary files /dev/null and b/modules/components/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/__pycache__/__init__.cpython-38.pyc b/modules/components/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3655946ff53b46c26f3f2afcf6d50aee3815eb Binary files /dev/null and b/modules/components/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/__pycache__/__init__.cpython-39.pyc b/modules/components/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca4fda6972e679bfb10f2c223648cc2e6c1a9c3d Binary files /dev/null and b/modules/components/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/__pycache__/components.cpython-310.pyc b/modules/components/__pycache__/components.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..168373dc8515e2f2cd09e80ace9ab15a4c6c8699 Binary files /dev/null and b/modules/components/__pycache__/components.cpython-310.pyc differ diff --git a/modules/components/__pycache__/components.cpython-38.pyc b/modules/components/__pycache__/components.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e2fb6c26d8125b69a3e8f360a1db10c2659644a Binary files /dev/null and b/modules/components/__pycache__/components.cpython-38.pyc differ diff --git a/modules/components/__pycache__/components.cpython-39.pyc b/modules/components/__pycache__/components.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06a41b12e2dc378388004c6c5069b412d6839038 Binary files /dev/null and b/modules/components/__pycache__/components.cpython-39.pyc differ diff --git a/modules/components/amt/AMT.py b/modules/components/amt/AMT.py new file mode 100644 index 0000000000000000000000000000000000000000..0cc95ee50fbd16c327a68bc6c75ba98a6ef2ec40 --- /dev/null +++ b/modules/components/amt/AMT.py @@ -0,0 +1,197 @@ +import torch +import torch.nn as nn +from modules.components.amt.blocks.raft import ( + coords_grid, + SmallUpdateBlock, BidirCorrBlock, BasicUpdateBlock +) +from .blocks.feat_enc import ( + SmallEncoder, + BasicEncoder, + LargeEncoder +) +from .blocks.ifrnet import ( + resize, + Encoder, + InitDecoder, + IntermediateDecoder +) +from .blocks.multi_flow import ( + multi_flow_combine, + MultiFlowDecoder +) + +from ..components import register + +from utils.padder import InputPadder + +@register('amt') +class Model(nn.Module): + def __init__(self, + model_size='S', + corr_radius=3, + corr_lvls=4, + num_flows=3, + channels=[20, 32, 44, 56], + skip_channels=20, + scale_factor=1): + super(Model, self).__init__() + self.model_size = model_size + self.radius = corr_radius + self.corr_levels = corr_lvls + self.num_flows = num_flows + self.channels = channels + self.skip_channels = skip_channels + self.scale_factor = scale_factor + if self.model_size == 'S': + self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.) + elif self.model_size == 'L': + self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.) + elif self.model_size == 'G': + self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.) + self.encoder = Encoder(channels, large=True) + + self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) + self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) + self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) + self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) + + self.update4 = self._get_updateblock(channels[2]) + self.update3_low = self._get_updateblock(channels[1], 2) + self.update2_low = self._get_updateblock(channels[0], 4) + + if self.model_size == 'G': + self.update3_high = self._get_updateblock(channels[1], None) + self.update2_high = self._get_updateblock(channels[0], None) + + self.comb_block = nn.Sequential( + nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), + nn.PReLU(6 * self.num_flows), + nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), + ) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, + corr_dim=256, corr_dim2=192, fc_dim=188, + scale_factor=scale_factor, corr_levels=self.corr_levels, + radius=self.radius) + + def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): + # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 + # based on linear assumption + t1_scale = 1. / embt + t0_scale = 1. / (1. - embt) + if downsample != 1: + inv = 1 / downsample + flow0 = inv * resize(flow0, scale_factor=inv) + flow1 = inv * resize(flow1, scale_factor=inv) + + corr0, corr1 = corr_fn(coord + flow1 * t1_scale, coord + flow0 * t0_scale) + corr = torch.cat([corr0, corr1], dim=1) + flow = torch.cat([flow0, flow1], dim=1) + return corr, flow + + def forward(self, img0, img1, time_step, scale_factor=None, eval=False, **kwargs): + scale_factor = self.scale_factor if scale_factor is None else scale_factor + padder = InputPadder(img0.shape, divisor=int(16 / scale_factor)) + img0, img1 = padder.pad(img0, img1) + mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + img0 = img0 - mean_ + img1 = img1 - mean_ + img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 + img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 + b, _, h, w = img0_.shape + coord = coords_grid(b, h // 8, w // 8, img0.device) + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + + # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] + # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] + f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) + f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) + + ######################################### the 4th decoder ######################################### + up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, time_step) + corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, + up_flow0_4, up_flow1_4, + time_step, downsample=1) + + # residue update with lookup corr + delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) + delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) + up_flow0_4 = up_flow0_4 + delta_flow0_4 + up_flow1_4 = up_flow1_4 + delta_flow1_4 + ft_3_ = ft_3_ + delta_ft_3_ + + ######################################### the 3rd decoder ######################################### + up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) + corr_3, flow_3 = self._corr_scale_lookup(corr_fn, + coord, up_flow0_3, up_flow1_3, + time_step, downsample=2) + + # residue update with lookup corr + delta_ft_2_, delta_flow_3 = self.update3_low(ft_2_, flow_3, corr_3) + delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) + up_flow0_3 = up_flow0_3 + delta_flow0_3 + up_flow1_3 = up_flow1_3 + delta_flow1_3 + ft_2_ = ft_2_ + delta_ft_2_ + + if self.model_size == 'G': + # residue update with lookup corr (hr) + corr_3 = resize(corr_3, scale_factor=2.0) + up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1) + delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3) + ft_2_ += delta_ft_2_ + up_flow0_3 += delta_up_flow_3[:, 0:2] + up_flow1_3 += delta_up_flow_3[:, 2:4] + + ######################################### the 2nd decoder ######################################### + up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) + corr_2, flow_2 = self._corr_scale_lookup(corr_fn, + coord, up_flow0_2, up_flow1_2, + time_step, downsample=4) + + # residue update with lookup corr + delta_ft_1_, delta_flow_2 = self.update2_low(ft_1_, flow_2, corr_2) + delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) + up_flow0_2 = up_flow0_2 + delta_flow0_2 + up_flow1_2 = up_flow1_2 + delta_flow1_2 + ft_1_ = ft_1_ + delta_ft_1_ + + if self.model_size == 'G': + # residue update with lookup corr (hr) + corr_2 = resize(corr_2, scale_factor=4.0) + up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1) + delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2) + ft_1_ += delta_ft_1_ + up_flow0_2 += delta_up_flow_2[:, 0:2] + up_flow1_2 += delta_up_flow_2[:, 2:4] + + ######################################### the 1st decoder ######################################### + up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) + + if scale_factor != 1.0: + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + mask = resize(mask, scale_factor=(1.0/scale_factor)) + img_res = resize(img_res, scale_factor=(1.0/scale_factor)) + + # Merge multiple predictions + imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, + mask, img_res, mean_) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + imgt_pred = padder.unpad(imgt_pred) + + if eval: + return { 'imgt_pred': imgt_pred, } + else: + up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, int(h / scale_factor), int(w / scale_factor)) + up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, int(h / scale_factor), int(w / scale_factor)) + return { + 'imgt_pred': imgt_pred, + 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + 'flowfwd': up_flow0_1[:, 0], + 'flowbwd': up_flow1_1[:, 0], + 'ft_pred': [ft_1_, ft_2_, ft_3_], + } diff --git a/modules/components/amt/__init__.py b/modules/components/amt/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..387563589f3cc4a3dc664a5d8d4f5557b0100996 --- /dev/null +++ b/modules/components/amt/__init__.py @@ -0,0 +1 @@ +from .AMT import Model \ No newline at end of file diff --git a/modules/components/amt/__pycache__/AMT.cpython-310.pyc b/modules/components/amt/__pycache__/AMT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..207575c42bc4befbcaf329775f5f8063db6810bc Binary files /dev/null and b/modules/components/amt/__pycache__/AMT.cpython-310.pyc differ diff --git a/modules/components/amt/__pycache__/AMT.cpython-38.pyc b/modules/components/amt/__pycache__/AMT.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc0f7ef6a364036ded86e272ad2f00076375f4a5 Binary files /dev/null and b/modules/components/amt/__pycache__/AMT.cpython-38.pyc differ diff --git a/modules/components/amt/__pycache__/AMT.cpython-39.pyc b/modules/components/amt/__pycache__/AMT.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59743eeae2f9b947f50179504a7f9cd44942894a Binary files /dev/null and b/modules/components/amt/__pycache__/AMT.cpython-39.pyc differ diff --git a/modules/components/amt/__pycache__/__init__.cpython-310.pyc b/modules/components/amt/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41cb62fa07907c493e81116d63a044d3e6c448cb Binary files /dev/null and b/modules/components/amt/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/amt/__pycache__/__init__.cpython-38.pyc b/modules/components/amt/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e93dfe9147dae70e88bc1afc3bbaf2c7e9b322f Binary files /dev/null and b/modules/components/amt/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/amt/__pycache__/__init__.cpython-39.pyc b/modules/components/amt/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b2f5c4b7a931cf3d1c52504607834edd46697e5 Binary files /dev/null and b/modules/components/amt/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/amt/blocks/__init__.py b/modules/components/amt/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/components/amt/blocks/__pycache__/__init__.cpython-310.pyc b/modules/components/amt/blocks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76f99571346ad4c2d862b2f2e41e353c1055aa71 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/__init__.cpython-38.pyc b/modules/components/amt/blocks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..665817f3c48eb9271cf37eb828977544c60d6b2e Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/__init__.cpython-39.pyc b/modules/components/amt/blocks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b573a33491939f3aee32bc9c3a082887c8085bb1 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/feat_enc.cpython-310.pyc b/modules/components/amt/blocks/__pycache__/feat_enc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab6926df9f07786e69593ed545636dc871a557fc Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/feat_enc.cpython-310.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/feat_enc.cpython-38.pyc b/modules/components/amt/blocks/__pycache__/feat_enc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91837360217aeeda0703054b9867af65ad223a92 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/feat_enc.cpython-38.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/feat_enc.cpython-39.pyc b/modules/components/amt/blocks/__pycache__/feat_enc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3297acbf026b88a16ab5b3b5680f9e49999176c2 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/feat_enc.cpython-39.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/ifrnet.cpython-310.pyc b/modules/components/amt/blocks/__pycache__/ifrnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..788e18ccd7797407311049d38beff0797314bb68 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/ifrnet.cpython-310.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/ifrnet.cpython-38.pyc b/modules/components/amt/blocks/__pycache__/ifrnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be8cf441a6f1db73fe3ffb147952f8da46e7000e Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/ifrnet.cpython-38.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/ifrnet.cpython-39.pyc b/modules/components/amt/blocks/__pycache__/ifrnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d42fb9633be7aa9c6ccc527af2d6b71a071542c7 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/ifrnet.cpython-39.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/multi_flow.cpython-310.pyc b/modules/components/amt/blocks/__pycache__/multi_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b24cc8c54d2d633f0c45b9ab4013066a2956888a Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/multi_flow.cpython-310.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/multi_flow.cpython-38.pyc b/modules/components/amt/blocks/__pycache__/multi_flow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a78776e49ac634a093330b45081da785fd6dc744 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/multi_flow.cpython-38.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/multi_flow.cpython-39.pyc b/modules/components/amt/blocks/__pycache__/multi_flow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee9c0f891ca103eeb3a0a4a9c81aa36eca6c2a9e Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/multi_flow.cpython-39.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/raft.cpython-310.pyc b/modules/components/amt/blocks/__pycache__/raft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e94797f7b7305ac7e059aa8bbffbe4911c9dd813 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/raft.cpython-310.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/raft.cpython-38.pyc b/modules/components/amt/blocks/__pycache__/raft.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f627e3159d3a90d5c6e270854773a39b35a9603 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/raft.cpython-38.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/raft.cpython-39.pyc b/modules/components/amt/blocks/__pycache__/raft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40a438eaa02ee0f5516bf9f3db8e8cd3f6ad6fb0 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/raft.cpython-39.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/warp.cpython-310.pyc b/modules/components/amt/blocks/__pycache__/warp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f98bca953c288bd440ac5380e25f1e4516402af6 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/warp.cpython-310.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/warp.cpython-38.pyc b/modules/components/amt/blocks/__pycache__/warp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2926b24fe0ce2a66bd12c24774916040cb01caba Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/warp.cpython-38.pyc differ diff --git a/modules/components/amt/blocks/__pycache__/warp.cpython-39.pyc b/modules/components/amt/blocks/__pycache__/warp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72891a38e799b6a1680d5f2283d104ff4a1daf12 Binary files /dev/null and b/modules/components/amt/blocks/__pycache__/warp.cpython-39.pyc differ diff --git a/modules/components/amt/blocks/feat_enc.py b/modules/components/amt/blocks/feat_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..3805bd315422703c19bf6a4d0962ee75002d92aa --- /dev/null +++ b/modules/components/amt/blocks/feat_enc.py @@ -0,0 +1,343 @@ +import torch +import torch.nn as nn + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(72, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class LargeEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(LargeEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(112, stride=2) + self.layer3 = self._make_layer(160, stride=2) + self.layer3_2 = self._make_layer(160, stride=1) + + # output convolution + self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer3_2(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/modules/components/amt/blocks/ifrnet.py b/modules/components/amt/blocks/ifrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..5d6030fdb7bb6c15450ec46a470aa366bdf390be --- /dev/null +++ b/modules/components/amt/blocks/ifrnet.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .warp import warp + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), + nn.PReLU(out_channels) + ) + + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + + +class Encoder(nn.Module): + def __init__(self, channels, large=False): + super(Encoder, self).__init__() + self.channels = channels + prev_ch = 3 + for idx, ch in enumerate(channels, 1): + k = 7 if large and idx == 1 else 3 + p = 3 if k == 7 else 1 + self.register_module(f'pyramid{idx}', + nn.Sequential( + convrelu(prev_ch, ch, k, 2, p), + convrelu(ch, ch, 3, 1, 1) + )) + prev_ch = ch + + def forward(self, in_x): + fs = [] + for idx in range(len(self.channels)): + out_x = getattr(self, f'pyramid{idx + 1}')(in_x) + fs.append(out_x) + in_x = out_x + return fs + + +class InitDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch * 2 + 1, in_ch * 2), + ResBlock(in_ch * 2, skip_ch), + nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True) + ) + + def forward(self, f0, f1, embt): + h, w = f0.shape[2:] + embt = embt.repeat(1, 1, h, w) + out = self.convblock(torch.cat([f0, f1, embt], 1)) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + return flow0, flow1, ft_ + + +class IntermediateDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch * 3 + 4, in_ch * 3), + ResBlock(in_ch * 3, skip_ch), + nn.ConvTranspose2d(in_ch * 3, out_ch + 4, 4, 2, 1, bias=True) + ) + + def forward(self, ft_, f0, f1, flow0_in, flow1_in): + f0_warp = warp(f0, flow0_in) + f1_warp = warp(f1, flow1_in) + f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1) + out = self.convblock(f_in) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0) + flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0) + return flow0, flow1, ft_ diff --git a/modules/components/amt/blocks/multi_flow.py b/modules/components/amt/blocks/multi_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..21097edfa11300d4a11a63e17b3dc793fe0b893d --- /dev/null +++ b/modules/components/amt/blocks/multi_flow.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from .warp import warp +from .ifrnet import ( + convrelu, resize, + ResBlock, +) + + +def multi_flow_combine(comb_block, img0, img1, flow0, flow1, + mask=None, img_res=None, mean=None): + ''' + A parallel implementation of multiple flow field warping + comb_block: An nn.Seqential object. + img shape: [b, c, h, w] + flow shape: [b, 2*num_flows, h, w] + mask (opt): + If 'mask' is None, the function conduct a simple average. + img_res (opt): + If 'img_res' is None, the function adds zero instead. + mean (opt): + If 'mean' is None, the function adds zero instead. + ''' + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = mask.reshape(b, num_flows, 1, h, w + ).reshape(-1, 1, h, w) if mask is not None else None + img_res = img_res.reshape(b, num_flows, 3, h, w + ).reshape(-1, 3, h, w) if img_res is not None else 0 + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 + ) if mean is not None else 0 + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) + return imgt_pred + + +class MultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(MultiFlowDecoder, self).__init__() + self.num_flows = num_flows + self.convblock = nn.Sequential( + convrelu(in_ch*3+4, in_ch*3), + ResBlock(in_ch*3, skip_ch), + nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True) + ) + + def forward(self, ft_, f0, f1, flow0, flow1): + n = self.num_flows + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) + delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) + mask = torch.sigmoid(mask) + + flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + + return flow0, flow1, mask, img_res \ No newline at end of file diff --git a/modules/components/amt/blocks/raft.py b/modules/components/amt/blocks/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..0529ddc785b9645fb5fa17afa41efa3d257984b8 --- /dev/null +++ b/modules/components/amt/blocks/raft.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def bilinear_sampler(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), + torch.arange(wd, device=device), + indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim, + corr_levels=4, radius=3, scale_factor=None): + super(SmallUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + self.scale_factor = scale_factor + + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + + return delta_net, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2, + fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1): + super(BasicUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) ** 2 + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + return delta_net, delta_flow + + +class BidirCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + self.corr_pyramid_T = [] + + corr = BidirCorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) + + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1) + + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + for _ in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_T = F.avg_pool2d(corr_T, 2, stride=2) + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + def __call__(self, coords0, coords1): + r = self.radius + coords0 = coords0.permute(0, 2, 3, 1) + coords1 = coords1.permute(0, 2, 3, 1) + assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" + batch, h1, w1, _ = coords0.shape + + out_pyramid = [] + out_pyramid_T = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + corr_T = self.corr_pyramid_T[i] + + dx = torch.linspace(-r, r, 2*r+1, device=coords0.device) + dy = torch.linspace(-r, r, 2*r+1, device=coords0.device) + delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1) + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + + centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i + centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i + coords_lvl_0 = centroid_lvl_0 + delta_lvl + coords_lvl_1 = centroid_lvl_1 + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl_0) + corr_T = bilinear_sampler(corr_T, coords_lvl_1) + corr = corr.view(batch, h1, w1, -1) + corr_T = corr_T.view(batch, h1, w1, -1) + out_pyramid.append(corr) + out_pyramid_T.append(corr_T) + + out = torch.cat(out_pyramid, dim=-1) + out_T = torch.cat(out_pyramid_T, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) \ No newline at end of file diff --git a/modules/components/amt/blocks/warp.py b/modules/components/amt/blocks/warp.py new file mode 100644 index 0000000000000000000000000000000000000000..89c63449c52bc12b73cc94b29c1d96a305365270 --- /dev/null +++ b/modules/components/amt/blocks/warp.py @@ -0,0 +1,13 @@ +import torch +import torch.nn.functional as F + + +def warp(img, flow): + B, _, H, W = flow.shape + xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) + yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) + grid = torch.cat([xx, yy], 1).to(img) + flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) + grid_ = (grid + flow_).permute(0, 2, 3, 1) + output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) + return output diff --git a/modules/components/amt_bilateral/AMT.py b/modules/components/amt_bilateral/AMT.py new file mode 100644 index 0000000000000000000000000000000000000000..2fda5c96551d9e6d4fc6eedcfb9a258b25733675 --- /dev/null +++ b/modules/components/amt_bilateral/AMT.py @@ -0,0 +1,201 @@ +import torch +import torch.nn as nn +from modules.components.amt_bilateral.blocks.raft import ( + coords_grid, + SmallUpdateBlock, BidirCorrBlock, BasicUpdateBlock +) +from .blocks.feat_enc import ( + SmallEncoder, + BasicEncoder, + LargeEncoder +) +from .blocks.ifrnet import ( + resize, + Encoder, + InitDecoder, + IntermediateDecoder +) +from .blocks.multi_flow import ( + multi_flow_combine, + MultiFlowDecoder +) + +from .blocks.BilateralCorrelation_NN import bilateralcorrelation_nn + +from ..components import register + +from utils.padder import InputPadder + + +@register('amt_bilateral') +class Model(nn.Module): + def __init__(self, + model_size='S', + corr_radius=3, + corr_lvls=4, + num_flows=3, + channels=[20, 32, 44, 56], + skip_channels=20, + scale_factor=1): + super(Model, self).__init__() + self.model_size = model_size + self.radius = corr_radius + self.corr_levels = corr_lvls + self.num_flows = num_flows + self.channels = channels + self.skip_channels = skip_channels + self.scale_factor = scale_factor + if self.model_size == 'S': + self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.) + elif self.model_size == 'L': + self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.) + elif self.model_size == 'G': + self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.) + self.encoder = Encoder(channels, large=True) + + self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) + self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) + self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) + self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) + + self.update4 = self._get_updateblock(channels[2]) + self.update3 = self._get_updateblock(channels[1], 2) + self.update2 = self._get_updateblock(channels[0], 4) + + if self.model_size == 'G': + self.update3_high = self._get_updateblock(channels[1], None) + self.update2_high = self._get_updateblock(channels[0], None) + + self.comb_block = nn.Sequential( + nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), + nn.PReLU(6 * self.num_flows), + nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), + ) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, + corr_dim=256, corr_dim2=192, fc_dim=188, + scale_factor=scale_factor, corr_levels=self.corr_levels, + radius=self.radius) + + def _corr_scale_lookup(self, corr_fn, coord, flow0, flow1, embt, downsample=1): + # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 + # based on linear assumption + # t1_scale = 1. / (1. - embt) + # t0_scale = 1. / embt + if downsample != 1: + inv = 1 / downsample + flow0 = inv * resize(flow0, scale_factor=inv) + flow1 = inv * resize(flow1, scale_factor=inv) + + corr0, corr1 = corr_fn(flow0, flow1, embt) + corr = torch.cat([corr0, corr1], dim=1) + flow = torch.cat([flow0, flow1], dim=1) + return corr, flow + + def forward(self, img0, img1, time_step, scale_factor=1.0, eval=False, **kwargs): + scale_factor = self.scale_factor + padder = InputPadder(img0.shape, divisor=int(16 / scale_factor)) + img0, img1 = padder.pad(img0, img1) + mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + img0 = img0 - mean_ + img1 = img1 - mean_ + img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 + img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 + b, _, h, w = img0_.shape + coord = coords_grid(b, h // 8, w // 8, img0.device) + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + + # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] + # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] + f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) + f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) + + ######################################### the 4th decoder ######################################### + up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, time_step) + corr_4, flow_4 = self._corr_scale_lookup(corr_fn, coord, + up_flow0_4, up_flow1_4, + time_step, downsample=1) + + # residue update with lookup corr + delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) + delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) + up_flow0_4 = up_flow0_4 + delta_flow0_4 + up_flow1_4 = up_flow1_4 + delta_flow1_4 + ft_3_ = ft_3_ + delta_ft_3_ + + ######################################### the 3rd decoder ######################################### + up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) + corr_3, flow_3 = self._corr_scale_lookup(corr_fn, + coord, up_flow0_3, up_flow1_3, + time_step, downsample=2) + + # residue update with lookup corr + delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3) + delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) + up_flow0_3 = up_flow0_3 + delta_flow0_3 + up_flow1_3 = up_flow1_3 + delta_flow1_3 + ft_2_ = ft_2_ + delta_ft_2_ + + if self.model_size == 'G': + # residue update with lookup corr (hr) + corr_3 = resize(corr_3, scale_factor=2.0) + up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1) + delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3) + ft_2_ += delta_ft_2_ + up_flow0_3 += delta_up_flow_3[:, 0:2] + up_flow1_3 += delta_up_flow_3[:, 2:4] + + ######################################### the 2nd decoder ######################################### + up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) + corr_2, flow_2 = self._corr_scale_lookup(corr_fn, + coord, up_flow0_2, up_flow1_2, + time_step, downsample=4) + + # residue update with lookup corr + delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2) + delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) + up_flow0_2 = up_flow0_2 + delta_flow0_2 + up_flow1_2 = up_flow1_2 + delta_flow1_2 + ft_1_ = ft_1_ + delta_ft_1_ + + if self.model_size == 'G': + # residue update with lookup corr (hr) + corr_2 = resize(corr_2, scale_factor=4.0) + up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1) + delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2) + ft_1_ += delta_ft_1_ + up_flow0_2 += delta_up_flow_2[:, 0:2] + up_flow1_2 += delta_up_flow_2[:, 2:4] + + ######################################### the 1st decoder ######################################### + up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) + + if scale_factor != 1.0: + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + mask = resize(mask, scale_factor=(1.0/scale_factor)) + img_res = resize(img_res, scale_factor=(1.0/scale_factor)) + # up_flow0_1, up_flow1_1, mask, img_res = padder.unpad(up_flow0_1, up_flow1_1, mask, img_res) + + # Merge multiple predictions + imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, + mask, img_res, mean_) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + imgt_pred = padder.unpad(imgt_pred) + + if eval: + return { 'imgt_pred': imgt_pred, } + else: + up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, int(h/scale_factor), int(w/scale_factor)) + up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, int(h/scale_factor), int(w/scale_factor)) + return { + 'imgt_pred': imgt_pred, + 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + 'flowfwd': up_flow0_1[:, 0], + 'flowbwd': up_flow1_1[:, 0], + 'ft_pred': [ft_1_, ft_2_, ft_3_], + } diff --git a/modules/components/amt_bilateral/__init__.py b/modules/components/amt_bilateral/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..387563589f3cc4a3dc664a5d8d4f5557b0100996 --- /dev/null +++ b/modules/components/amt_bilateral/__init__.py @@ -0,0 +1 @@ +from .AMT import Model \ No newline at end of file diff --git a/modules/components/amt_bilateral/__pycache__/AMT.cpython-310.pyc b/modules/components/amt_bilateral/__pycache__/AMT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06e06f1d44fd0e28460002c90671a3ff6c53a673 Binary files /dev/null and b/modules/components/amt_bilateral/__pycache__/AMT.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/__pycache__/AMT.cpython-38.pyc b/modules/components/amt_bilateral/__pycache__/AMT.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3913688abe2751ca950e017b75eddeeb2b83f2ff Binary files /dev/null and b/modules/components/amt_bilateral/__pycache__/AMT.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/__pycache__/AMT.cpython-39.pyc b/modules/components/amt_bilateral/__pycache__/AMT.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f19fbc660e6a417ed52b91c724a528e68d932b4f Binary files /dev/null and b/modules/components/amt_bilateral/__pycache__/AMT.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/__pycache__/__init__.cpython-310.pyc b/modules/components/amt_bilateral/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..675b80fd7ff220a3b6594e684e49a3ed39a960c9 Binary files /dev/null and b/modules/components/amt_bilateral/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/__pycache__/__init__.cpython-38.pyc b/modules/components/amt_bilateral/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9744c2a839f3c2546e99d0dea47c04ad668a87f1 Binary files /dev/null and b/modules/components/amt_bilateral/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/__pycache__/__init__.cpython-39.pyc b/modules/components/amt_bilateral/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b918d36a547a4f86022060baab79e34a37adb0b Binary files /dev/null and b/modules/components/amt_bilateral/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/blocks/BilateralCorrelation_NN.py b/modules/components/amt_bilateral/blocks/BilateralCorrelation_NN.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1a12e2dbe0a09a01c106e6dc1423fa746602af --- /dev/null +++ b/modules/components/amt_bilateral/blocks/BilateralCorrelation_NN.py @@ -0,0 +1,548 @@ +import cupy +import torch +import re +import math + +correlation_forward = ''' + extern "C" __global__ void correlation_forward( + const int n, + const float* feature1, + const float* feature2, + const float* flow, + const float* time, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + + float fltOutput = 0.0; + + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + + int k = (intC % F_SIZE) - F_SIZE_H; + int l = (intC / F_SIZE) - F_SIZE_H; + + float t = VALUE_2(time,intN, 0); + + float ratio_x = (float) (SIZE_3(feature1)) / SIZE_3(flow); + float ratio_y = (float) (SIZE_2(feature1)) / SIZE_2(flow); + + float fltX1 = ((float) intX + VALUE_4(flow, intN, 0, intY, intX) * -2 * t) * ratio_x - k * -2 * t; + float fltY1 = ((float) intY + VALUE_4(flow, intN, 1, intY, intX) * -2 * t) * ratio_y - l * -2 * t; + float fltX2 = ((float) intX + VALUE_4(flow, intN, 0, intY, intX) * 2 * (1 - t)) * ratio_x - k * 2 * (1 - t); + float fltY2 = ((float) intY + VALUE_4(flow, intN, 1, intY, intX) * 2 * (1 - t)) * ratio_y - l * 2 * (1 - t); + + int intLX1 = (int) (floor(fltX1)); + int intTY1 = (int) (floor(fltY1)); + int intRX1 = intLX1 + 1; + int intBY1 = intTY1 + 1; + + int intRX2 = (int) (ceil(fltX2)); + int intBY2 = (int) (ceil(fltY2)); + int intLX2 = intRX2 - 1; + int intTY2 = intBY2 - 1; + + float fltnw = ((float) intRX1 - fltX1) * ((float) intBY1 - fltY1); + float fltne = (fltX1 - (float) intLX1) * ((float) intBY1 - fltY1); + float fltsw = ((float) intRX1 - fltX1) * (fltY1 - (float) intTY1); + float fltse = (fltX1 - (float) intLX1) * (fltY1 - (float) intTY1); + + if ((intRX1 >= 0) & (intBY1 >= 0) & (intLX1 < SIZE_3(feature1)) & (intTY1 < SIZE_2(feature1))) { + if ((intRX2 >= 0) & (intBY2 >= 0) & (intLX2 < SIZE_3(feature1)) & (intTY2 < SIZE_2(feature1))) { + for (int intChannel = 0; intChannel < SIZE_1(feature1); intChannel += 1) { + float fltF1 = 0.0; + float fltF2 = 0.0; + if ((intLX1 >= 0) & (intTY1 >= 0)) { + fltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intLX1) * fltnw; + } + if ((intRX2 < SIZE_3(feature1)) & (intBY2 < SIZE_2(feature1))) { + fltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intRX2) * fltnw; + } + if ((intLX1 >= 0) & (intBY1 < SIZE_2(feature1))) { + fltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intLX1) * fltsw; + } + if ((intRX2 < SIZE_3(feature2)) & (intTY2 >= 0)) { + fltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intRX2) * fltsw; + } + if ((intRX1 < SIZE_3(feature1)) & (intTY1 >= 0)) { + fltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intRX1) * fltne; + } + if ((intLX2 >= 0) & (intBY2 < SIZE_2(feature1))) { + fltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intLX2) * fltne; + } + if ((intRX1 < SIZE_3(feature1)) & (intBY1 < SIZE_2(feature1))) { + fltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intRX1) * fltse; + } + if ((intLX2 >= 0) & (intTY2 >= 0)) { + fltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intLX2) * fltse; + } + fltOutput += fltF1 * fltF2; + } + } + } + output[intIndex] = fltOutput; + } } +''' + +correlation_coords_forward = ''' + extern "C" __global__ void correlation_coords_forward( + const int n, + const float* feature1, + const float* feature2, + const float* coords_bw, + const float* coords_fw, + const float* time, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + + float fltOutput = 0.0; + + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + int k = (intC % F_SIZE) - F_SIZE_H; + int l = (intC / F_SIZE) - F_SIZE_H; + + float t = VALUE_2(time,intN, 0); + + float fltX1 = VALUE_4(coords_bw, intN, 0, intY, intX) - (-2 * t * k); + + int intLX1 = (int) (floor(fltX1)); + int intTY1 = (int) (floor(fltY1)); + int intRX1 = intLX1 + 1; + int intBY1 = intTY1 + 1; + + int intRX2 = (int) (ceil(fltX2)); + int intBY2 = (int) (ceil(fltY2)); + int intLX2 = intRX2 - 1; + int intTY2 = intBY2 - 1; + + float fltnw = ((float) intRX1 - fltX1) * ((float) intBY1 - fltY1); + float fltne = (fltX1 - (float) intLX1) * ((float) intBY1 - fltY1); + float fltsw = ((float) intRX1 - fltX1) * (fltY1 - (float) intTY1); + float fltse = (fltX1 - (float) intLX1) * (fltY1 - (float) intTY1); + + if ((intRX1 >= 0) & (intBY1 >= 0) & (intLX1 < SIZE_3(output)) & (intTY1 < SIZE_2(output))) { + if ((intRX2 >= 0) & (intBY2 >= 0) & (intLX2 < SIZE_3(output)) & (intTY2 < SIZE_2(output))) { + for (int intChannel = 0; intChannel < SIZE_1(feature1); intChannel += 1) { + if ((intLX1 >= 0) & (intTY1 >= 0) & (intRX2 < SIZE_3(output)) & (intBY2 < SIZE_2(output))) { + fltOutput += VALUE_4(feature1, intN, intChannel, intTY1, intLX1) * VALUE_4(feature2, intN, intChannel, intBY2, intRX2) * fltnw; + } + if ((intLX1 >= 0) & (intBY1 < SIZE_2(feature1)) & (intRX2 < SIZE_3(feature2)) & (intTY2 >= 0)) { + fltOutput += VALUE_4(feature1, intN, intChannel, intBY1, intLX1) * VALUE_4(feature2, intN, intChannel, intTY2, intRX2) * fltsw; + } + if ((intRX1 < SIZE_3(output)) & (intTY1 >= 0) & (intLX2 >= 0) & (intBY2 < SIZE_2(output))) { + fltOutput += VALUE_4(feature1, intN, intChannel, intTY1, intRX1) * VALUE_4(feature2, intN, intChannel, intBY2, intLX2) * fltne; + } + if ((intRX1 < SIZE_3(output)) & (intBY1 < SIZE_2(output)) & (intLX2 >= 0) & (intTY2 >= 0)) { + fltOutput += VALUE_4(feature1, intN, intChannel, intBY1, intRX1) * VALUE_4(feature2, intN, intChannel, intTY2, intLX2) * fltse; + } + } + } + } + output[intIndex] = fltOutput; + } } +''' + +correlation_backward_feature = ''' + extern "C" __global__ void correlation_backward_feature( + const int n, + const float* gradLoss, + const float* feature1, + const float* feature2, + const float* flow, + const float* time, + float* gradInput1, + float* gradInput2 + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + + const int intN = ( intIndex / SIZE_3(gradLoss) / SIZE_2(gradLoss) / SIZE_1(gradLoss) ) % SIZE_0(gradLoss); + const int intC = ( intIndex / SIZE_3(gradLoss) / SIZE_2(gradLoss) ) % SIZE_1(gradLoss); + const int intY = ( intIndex / SIZE_3(gradLoss) ) % SIZE_2(gradLoss); + const int intX = ( intIndex ) % SIZE_3(gradLoss); + + int k = (intC % F_SIZE) - F_SIZE_H; + int l = (intC / F_SIZE) - F_SIZE_H; + + float t = VALUE_2(time,intN, 0); + + float ratio_x = (float) (SIZE_3(feature1)) / SIZE_3(flow); + float ratio_y = (float) (SIZE_2(feature1)) / SIZE_2(flow); + + float fltX1 = ((float) intX + VALUE_4(flow, intN, 0, intY, intX) * -2 * t) * ratio_x - k * -2 * t; + float fltY1 = ((float) intY + VALUE_4(flow, intN, 1, intY, intX) * -2 * t) * ratio_y - l * -2 * t; + float fltX2 = ((float) intX + VALUE_4(flow, intN, 0, intY, intX) * 2 * (1 - t)) * ratio_x - k * 2 * (1 - t); + float fltY2 = ((float) intY + VALUE_4(flow, intN, 1, intY, intX) * 2 * (1 - t)) * ratio_y - l * 2 * (1 - t); + + int intLX1 = (int) (floor(fltX1)); + int intTY1 = (int) (floor(fltY1)); + int intRX1 = intLX1 + 1; + int intBY1 = intTY1 + 1; + + int intRX2 = (int) (ceil(fltX2)); + int intBY2 = (int) (ceil(fltY2)); + int intLX2 = intRX2 - 1; + int intTY2 = intBY2 - 1; + + float fltnw = ((float) intRX1 - fltX1) * ((float) intBY1 - fltY1); + float fltne = (fltX1 - (float) intLX1) * ((float) intBY1 - fltY1); + float fltsw = ((float) intRX1 - fltX1) * (fltY1 - (float) intTY1); + float fltse = (fltX1 - (float) intLX1) * (fltY1 - (float) intTY1); + + if ((intRX1 >= 0) & (intBY1 >= 0) & (intLX1 < SIZE_3(feature1)) & (intTY1 < SIZE_2(feature1))) { + if ((intRX2 >= 0) & (intBY2 >= 0) & (intLX2 < SIZE_3(feature1)) & (intTY2 < SIZE_2(feature1))) { + float fltLoss = VALUE_4(gradLoss, intN, intC, intY, intX); + for (int intChannel = 0; intChannel < SIZE_1(feature1); intChannel += 1) { + float fltF1 = 0.0; + float fltF2 = 0.0; + if ((intLX1 >= 0) & (intTY1 >= 0)) { + fltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intLX1) * fltnw; + } + if ((intRX2 < SIZE_3(feature1)) & (intBY2 < SIZE_2(feature1))) { + fltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intRX2) * fltnw; + } + if ((intLX1 >= 0) & (intBY1 < SIZE_2(feature1))) { + fltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intLX1) * fltsw; + } + if ((intRX2 < SIZE_3(feature2)) & (intTY2 >= 0)) { + fltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intRX2) * fltsw; + } + if ((intRX1 < SIZE_3(feature1)) & (intTY1 >= 0)) { + fltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intRX1) * fltne; + } + if ((intLX2 >= 0) & (intBY2 < SIZE_2(feature1))) { + fltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intLX2) * fltne; + } + if ((intRX1 < SIZE_3(feature1)) & (intBY1 < SIZE_2(feature1))) { + fltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intRX1) * fltse; + } + if ((intLX2 >= 0) & (intTY2 >= 0)) { + fltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intLX2) * fltse; + } + if ((intLX1 >= 0) & (intTY1 >= 0)) { + atomicAdd(&gradInput1[OFFSET_4(gradInput1, intN, intChannel, intTY1, intLX1)], fltF2 * fltnw * fltLoss); + } + if ((intRX2 < SIZE_3(feature1)) & (intBY2 < SIZE_2(feature1))) { + atomicAdd(&gradInput2[OFFSET_4(gradInput2, intN, intChannel, intBY2, intRX2)], fltF1 * fltnw * fltLoss); + } + if ((intLX1 >= 0) & (intBY1 < SIZE_2(feature1))) { + atomicAdd(&gradInput1[OFFSET_4(gradInput1, intN, intChannel, intBY1, intLX1)], fltF2 * fltsw * fltLoss); + } + if ((intRX2 < SIZE_3(feature2)) & (intTY2 >= 0)) { + atomicAdd(&gradInput2[OFFSET_4(gradInput2, intN, intChannel, intTY2, intRX2)], fltF1 * fltsw * fltLoss); + } + if ((intRX1 < SIZE_3(feature1)) & (intTY1 >= 0)) { + atomicAdd(&gradInput1[OFFSET_4(gradInput1, intN, intChannel, intTY1, intRX1)], fltF2 * fltne * fltLoss); + } + if ((intLX2 >= 0) & (intBY2 < SIZE_2(feature1))) { + atomicAdd(&gradInput2[OFFSET_4(gradInput2, intN, intChannel, intBY2, intLX2)], fltF1 * fltne * fltLoss); + } + if ((intRX1 < SIZE_3(feature1)) & (intBY1 < SIZE_2(feature1))) { + atomicAdd(&gradInput1[OFFSET_4(gradInput1, intN, intChannel, intBY1, intRX1)], fltF2 * fltse * fltLoss); + } + if ((intLX2 >= 0) & (intTY2 >= 0)) { + atomicAdd(&gradInput2[OFFSET_4(gradInput2, intN, intChannel, intTY2, intLX2)], fltF1 * fltse * fltLoss); + } + } + } + } + } } +''' + +correlation_backward_flow = ''' + extern "C" __global__ void correlation_backward_flow( + const int n, + const float* gradLoss, + const float* feature1, + const float* feature2, + const float* flow, + const float* time, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + + const int intN = ( intIndex / SIZE_3(gradLoss) / SIZE_2(gradLoss) / SIZE_1(gradLoss) ) % SIZE_0(gradLoss); + const int intC = ( intIndex / SIZE_3(gradLoss) / SIZE_2(gradLoss) ) % SIZE_1(gradLoss); + const int intY = ( intIndex / SIZE_3(gradLoss) ) % SIZE_2(gradLoss); + const int intX = ( intIndex ) % SIZE_3(gradLoss); + + int k = (intC % F_SIZE) - F_SIZE_H; + int l = (intC / F_SIZE) - F_SIZE_H; + + float t = VALUE_2(time,intN, 0); + + float ratio_x = (float) (SIZE_3(feature1)) / SIZE_3(flow); + float ratio_y = (float) (SIZE_2(feature1)) / SIZE_2(flow); + + float fltX1 = ((float) intX + VALUE_4(flow, intN, 0, intY, intX) * -2 * t) * ratio_x - k * -2 * t; + float fltY1 = ((float) intY + VALUE_4(flow, intN, 1, intY, intX) * -2 * t) * ratio_y - l * -2 * t; + float fltX2 = ((float) intX + VALUE_4(flow, intN, 0, intY, intX) * 2 * (1 - t)) * ratio_x - k * 2 * (1 - t); + float fltY2 = ((float) intY + VALUE_4(flow, intN, 1, intY, intX) * 2 * (1 - t)) * ratio_y - l * 2 * (1 - t); + + int intLX1 = (int) (floor(fltX1)); + int intTY1 = (int) (floor(fltY1)); + int intRX1 = intLX1 + 1; + int intBY1 = intTY1 + 1; + + int intRX2 = (int) (ceil(fltX2)); + int intBY2 = (int) (ceil(fltY2)); + int intLX2 = intRX2 - 1; + int intTY2 = intBY2 - 1; + + float fltnw = ((float) intRX1 - fltX1) * ((float) intBY1 - fltY1); + float fltne = (fltX1 - (float) intLX1) * ((float) intBY1 - fltY1); + float fltsw = ((float) intRX1 - fltX1) * (fltY1 - (float) intTY1); + float fltse = (fltX1 - (float) intLX1) * (fltY1 - (float) intTY1); + + float fltnwx = (-1.0) * ((float) intBY1 - fltY1) * (-2.0) * t * ratio_x; + float fltnex = (+1.0) * ((float) intBY1 - fltY1) * (-2.0) * t * ratio_x; + float fltswx = (-1.0) * (fltY1 - (float) intTY1) * (-2.0) * t * ratio_x; + float fltsex = (+1.0) * (fltY1 - (float) intTY1) * (-2.0) * t * ratio_x; + + float fltnwy = ((float) intRX1 - fltX1) * (-1.0) * (-2.0) * t * ratio_y; + float fltney = (fltX1 - (float) intLX1) * (-1.0) * (-2.0) * t * ratio_y; + float fltswy = ((float) intRX1 - fltX1) * (+1.0) * (-2.0) * t * ratio_y; + float fltsey = (fltX1 - (float) intLX1) * (+1.0) * (-2.0) * t * ratio_y; + + + if ((intRX1 >= 0) & (intBY1 >= 0) & (intLX1 < SIZE_3(feature1)) & (intTY1 < SIZE_2(feature1))) { + if ((intRX2 >= 0) & (intBY2 >= 0) & (intLX2 < SIZE_3(feature1)) & (intTY2 < SIZE_2(feature1))) { + float fltLoss = VALUE_4(gradLoss, intN, intC, intY, intX); + for (int intChannel = 0; intChannel < SIZE_1(feature1); intChannel += 1) { + float fltF1 = 0.0; + float fltF2 = 0.0; + float dxFltF1 = 0.0; + float dyFltF1 = 0.0; + float dxFltF2 = 0.0; + float dyFltF2 = 0.0; + if ((intLX1 >= 0) & (intTY1 >= 0)) { + fltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intLX1) * fltnw; + dxFltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intLX1) * fltnwx; + dyFltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intLX1) * fltnwy; + } + if ((intRX2 < SIZE_3(feature1)) & (intBY2 < SIZE_2(feature1))) { + fltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intRX2) * fltnw; + dxFltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intRX2) * fltnwx; + dyFltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intRX2) * fltnwy; + } + if ((intLX1 >= 0) & (intBY1 < SIZE_2(feature1))) { + fltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intLX1) * fltsw; + dxFltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intLX1) * fltswx; + dyFltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intLX1) * fltswy; + } + if ((intRX2 < SIZE_3(feature2)) & (intTY2 >= 0)) { + fltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intRX2) * fltsw; + dxFltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intRX2) * fltswx; + dyFltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intRX2) * fltswy; + } + if ((intRX1 < SIZE_3(feature1)) & (intTY1 >= 0)) { + fltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intRX1) * fltne; + dxFltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intRX1) * fltnex; + dyFltF1 += VALUE_4(feature1, intN, intChannel, intTY1, intRX1) * fltney; + } + if ((intLX2 >= 0) & (intBY2 < SIZE_2(feature1))) { + fltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intLX2) * fltne; + dxFltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intLX2) * fltnex; + dyFltF2 += VALUE_4(feature2, intN, intChannel, intBY2, intLX2) * fltney; + } + if ((intRX1 < SIZE_3(feature1)) & (intBY1 < SIZE_2(feature1))) { + fltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intRX1) * fltse; + dxFltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intRX1) * fltsex; + dyFltF1 += VALUE_4(feature1, intN, intChannel, intBY1, intRX1) * fltsey; + } + if ((intLX2 >= 0) & (intTY2 >= 0)) { + fltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intLX2) * fltse; + dxFltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intLX2) * fltsex; + dyFltF2 += VALUE_4(feature2, intN, intChannel, intTY2, intLX2) * fltsey; + } + atomicAdd(&gradFlow[OFFSET_4(gradFlow, intN, 0, intY, intX)], (fltF1 * dxFltF2 + fltF2 * dxFltF1) * fltLoss); + atomicAdd(&gradFlow[OFFSET_4(gradFlow, intN, 1, intY, intX)], (fltF1 * dyFltF2 + fltF2 * dyFltF1) * fltLoss); + } + } + } + } } +''' + + +def cupy_kernel(strFunction, intWindowSize, objectVariables): + strKernel = globals()[strFunction] + + strKernel = strKernel.replace('F_SIZE_H', str((intWindowSize - 1) // 2)) + strKernel = strKernel.replace('F_SIZE', str(intWindowSize)) + + while True: + objectMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objectMatch is None: + break + # end + + intArg = int(objectMatch.group(2)) + + strTensor = objectMatch.group(4) + intSizes = objectVariables[strTensor].size() + + strKernel = strKernel.replace(objectMatch.group(), str(intSizes[intArg])) + # end + + while True: + objectMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objectMatch is None: + break + # end + + intArgs = int(objectMatch.group(2)) + strArgs = objectMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objectVariables[strTensor].stride() + strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg]) + ')' for intArg in range(intArgs)] + + strKernel = strKernel.replace(objectMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objectVariables[strTensor].stride() + strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg]) + ')' for intArg in range(intArgs)] + + strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') + + return strKernel + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +class bilateralcorrelation_nn(torch.autograd.Function): + + @staticmethod + def forward(ctx, feature1, feature2, SBM, time, md=2): + ctx.save_for_backward(feature1, feature2, SBM, time) + ctx.md = md + + intInputBatch, _, intInputHeight, intInputWidth = SBM.size() # (+) feature1 -> SBM + intInputChannel = feature1.size(1) # (+) intInputChannel is a channel size of the feature map + intWindowSize = (2 * ctx.md + 1) + + output = feature1.new_zeros(intInputBatch, intWindowSize ** 2, intInputHeight, intInputWidth) + + assert feature1.size() == feature2.size() + assert SBM.size(1) == 2 + assert (feature1.is_contiguous() == True) + assert (feature2.is_contiguous() == True) + assert (SBM.is_contiguous() == True) + assert (time.is_contiguous() == True) + assert feature1.device == feature2.device and feature1.device == SBM.device and feature1.device == time.device + + if feature1.is_cuda == True and feature2.is_cuda == True: + class Stream: + ptr = torch.cuda.current_stream().cuda_stream + + # end + + n = output.nelement() + cupy_launch('correlation_forward', + cupy_kernel('correlation_forward', intWindowSize, { + 'feature1': feature1, + 'feature2': feature2, + 'flow': SBM, + 'time': time, + 'output': output + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, feature1.data_ptr(), feature2.data_ptr(), SBM.data_ptr(), time.data_ptr(), output.data_ptr()], + stream=Stream + ) + + # end + return output / torch.sqrt(torch.tensor(intInputChannel).float()) + + @staticmethod + def backward(ctx, gradOutput): + feature1, feature2, SBM, time = ctx.saved_tensors + + # intInputBatch, _, intInputHeight, intInputWidth = SBM.size() # (+) feature1 -> SBM + intInputChannel = feature1.size(1) # (+) intInputChannel is a channel size of the feature map + intWindowSize = (2 * ctx.md + 1) + + gradInput1 = feature1.new_zeros(feature1.size()) if \ + ctx.needs_input_grad[0] == True else None + gradInput2 = feature2.new_zeros(feature2.size()) if \ + ctx.needs_input_grad[1] == True else None + gradFlow = SBM.new_zeros(SBM.size()) if \ + ctx.needs_input_grad[2] == True else None + + gradOutput = gradOutput / torch.sqrt(torch.tensor(intInputChannel).float()) + + if feature1.is_cuda == True and feature2.is_cuda == True: + class Stream: + ptr = torch.cuda.current_stream().cuda_stream + + # end + + # weight grad + n = gradOutput.nelement() + cupy_launch('correlation_backward_feature', + cupy_kernel('correlation_backward_feature', intWindowSize, { + 'gradLoss': gradOutput, + 'feature1': feature1, + 'feature2': feature2, + 'flow': SBM, + 'time': time, + 'gradInput1': gradInput1, + 'gradInput2': gradInput2 + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, gradOutput.data_ptr(), feature1.data_ptr(), feature2.data_ptr(), SBM.data_ptr(), + time.data_ptr(), + gradInput1.data_ptr(), gradInput2.data_ptr()], + stream=Stream + ) + + if gradFlow is not None: + class Stream: + ptr = torch.cuda.current_stream().cuda_stream + + n = gradOutput.nelement() + cupy_launch('correlation_backward_flow', + cupy_kernel('correlation_backward_flow', intWindowSize, { + 'gradLoss': gradOutput, + 'feature1': feature1, + 'feature2': feature2, + 'flow': SBM, + 'time': time, + 'gradFlow': gradFlow + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, gradOutput.data_ptr(), feature1.data_ptr(), feature2.data_ptr(), SBM.data_ptr(), + time.data_ptr(), + gradFlow.data_ptr()], + stream=Stream + ) + + # end + + return gradInput1, gradInput2, gradFlow, None, None + +# end \ No newline at end of file diff --git a/modules/components/amt_bilateral/blocks/__init__.py b/modules/components/amt_bilateral/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-310.pyc b/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8acdc8113053e129f39d034c362243e76c43db4 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-38.pyc b/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e82becbe75a2780912d81a74a627a33735d621a4 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-39.pyc b/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c988194ab4425712489c027eac2728c919a127b Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/BilateralCorrelation_NN.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-310.pyc b/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04c6debda1b284fd2665dfbe9005d1eb502d0c95 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-38.pyc b/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31f36ef90fd250ceb45cb3815d723b09c32cceb2 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-39.pyc b/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b58a273f3be197bed6093ace0c27fe2d88e0bbb Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-310.pyc b/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8559caa2b6e93f232c599aff9e61fd8fbb4c1ce2 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-38.pyc b/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85b2afbd58c6bd140f5370721b98d7824bf11da1 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-39.pyc b/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..249765eb0bdfb0b61befbc80377589d9c0ca0cf3 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/feat_enc.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-310.pyc b/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f0233a2d3588a7e5769dd9aefec9876eefc7cca Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-38.pyc b/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba3235e378cd7c8d582b60d9bcd2a4d4a2cd60da Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-39.pyc b/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdda3567e77d7d613e93f735f13e819dbce15599 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/ifrnet.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-310.pyc b/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efecb44539a4eccc0bdc3724c92c1b1477a6ea29 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-38.pyc b/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d021d3caa8c5ba50f43765ebd76d872a92f211f1 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-39.pyc b/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41c252d0591fb29dc81de831db015ecd499bdc68 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/multi_flow.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-310.pyc b/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd887779dc35711c4a5cd35052f28f288f77ab99 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-38.pyc b/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77d43edcfad28aeeeddb44ccc13dc4dac8bf75cf Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-39.pyc b/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2a672c82012bdb7e4e246983b0629b78c6ddca5 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/raft.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-310.pyc b/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5ae9b8c615a5ce4adc5d8ce79f2499ec0373203 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-310.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-38.pyc b/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f5625ad1617849ffbc85a49eb4b4ca9a9f46f12 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-38.pyc differ diff --git a/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-39.pyc b/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea628f2f6b7d678f0a974333ee2af03ffb82f414 Binary files /dev/null and b/modules/components/amt_bilateral/blocks/__pycache__/warp.cpython-39.pyc differ diff --git a/modules/components/amt_bilateral/blocks/feat_enc.py b/modules/components/amt_bilateral/blocks/feat_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..3805bd315422703c19bf6a4d0962ee75002d92aa --- /dev/null +++ b/modules/components/amt_bilateral/blocks/feat_enc.py @@ -0,0 +1,343 @@ +import torch +import torch.nn as nn + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(72, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class LargeEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(LargeEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(112, stride=2) + self.layer3 = self._make_layer(160, stride=2) + self.layer3_2 = self._make_layer(160, stride=1) + + # output convolution + self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer3_2(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/modules/components/amt_bilateral/blocks/ifrnet.py b/modules/components/amt_bilateral/blocks/ifrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f374fafc182fe725af3b53e5c211050061a7da1d --- /dev/null +++ b/modules/components/amt_bilateral/blocks/ifrnet.py @@ -0,0 +1,122 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .warp import warp + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), + nn.PReLU(out_channels) + ) + + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + + +class Encoder(nn.Module): + def __init__(self, channels, large=False): + super(Encoder, self).__init__() + self.channels = channels + prev_ch = 3 + for idx, ch in enumerate(channels, 1): + k = 7 if large and idx == 1 else 3 + p = 3 if k == 7 else 1 + self.register_module(f'pyramid{idx}', + nn.Sequential( + convrelu(prev_ch, ch, k, 2, p), + convrelu(ch, ch, 3, 1, 1) + )) + prev_ch = ch + + def forward(self, in_x): + fs = [] + for idx in range(len(self.channels)): + out_x = getattr(self, f'pyramid{idx + 1}')(in_x) + fs.append(out_x) + in_x = out_x + return fs + + +class InitDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch * 2 + 1, in_ch * 2), + ResBlock(in_ch * 2, skip_ch), + nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1), + # nn.ConvTranspose2d(in_ch * 2, in_ch * 2, 4, 2, 1), + # nn.Conv2d(in_ch * 2, out_ch + 4, 3, 1, 1, bias=True) + ) + + def forward(self, f0, f1, embt): + h, w = f0.shape[2:] + embt = embt.repeat(1, 1, h, w) + out = self.convblock(torch.cat([f0, f1, embt], 1)) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + return flow0, flow1, ft_ + + +class IntermediateDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch * 3 + 4, in_ch * 3), + ResBlock(in_ch * 3, skip_ch), + nn.ConvTranspose2d(in_ch * 3, out_ch + 4, 4, 2, 1), + # nn.ConvTranspose2d(in_ch * 3, in_ch * 3, 4, 2, 1), + # nn.Conv2d(in_ch * 3, out_ch + 4, 3, 1, 1, bias=True) + ) + + def forward(self, ft_, f0, f1, flow0_in, flow1_in): + f0_warp = warp(f0, flow0_in) + f1_warp = warp(f1, flow1_in) + f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1) + out = self.convblock(f_in) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0) + flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0) + return flow0, flow1, ft_ diff --git a/modules/components/amt_bilateral/blocks/multi_flow.py b/modules/components/amt_bilateral/blocks/multi_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca89162697384c3c59cf0b7d0164f7fc10790e0 --- /dev/null +++ b/modules/components/amt_bilateral/blocks/multi_flow.py @@ -0,0 +1,71 @@ +import torch +import torch.nn as nn +from .warp import warp +from .ifrnet import ( + convrelu, resize, + ResBlock, +) + + +def multi_flow_combine(comb_block, img0, img1, flow0, flow1, + mask=None, img_res=None, mean=None): + ''' + A parallel implementation of multiple flow field warping + comb_block: An nn.Seqential object. + img shape: [b, c, h, w] + flow shape: [b, 2*num_flows, h, w] + mask (opt): + If 'mask' is None, the function conduct a simple average. + img_res (opt): + If 'img_res' is None, the function adds zero instead. + mean (opt): + If 'mean' is None, the function adds zero instead. + ''' + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = mask.reshape(b, num_flows, 1, h, w + ).reshape(-1, 1, h, w) if mask is not None else None + img_res = img_res.reshape(b, num_flows, 3, h, w + ).reshape(-1, 3, h, w) if img_res is not None else 0 + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 + ) if mean is not None else 0 + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) + return imgt_pred + + +class MultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(MultiFlowDecoder, self).__init__() + self.num_flows = num_flows + self.convblock = nn.Sequential( + convrelu(in_ch*3+4, in_ch*3), + ResBlock(in_ch*3, skip_ch), + nn.ConvTranspose2d(in_ch * 3, 8 * num_flows, 4, 2, 1) + # nn.ConvTranspose2d(in_ch * 3, in_ch * 3, 4, 2, 1), + # nn.Conv2d(in_ch*3, 8*num_flows, 3, 1, 1, bias=True) + ) + + def forward(self, ft_, f0, f1, flow0, flow1): + n = self.num_flows + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) + delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) + mask = torch.sigmoid(mask) + + flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + + return flow0, flow1, mask, img_res \ No newline at end of file diff --git a/modules/components/amt_bilateral/blocks/raft.py b/modules/components/amt_bilateral/blocks/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7d3997613f4aee03106ad9071872bd0643d4d4 --- /dev/null +++ b/modules/components/amt_bilateral/blocks/raft.py @@ -0,0 +1,184 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .BilateralCorrelation_NN import bilateralcorrelation_nn as bicorr_nn + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def bilinear_sampler(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), + torch.arange(wd, device=device), + indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim, + corr_levels=4, radius=3, scale_factor=None): + super(SmallUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + self.scale_factor = scale_factor + + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + + return delta_net, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2, + fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1): + super(BasicUpdateBlock, self).__init__() + cor_planes = (2 * radius + 1) ** 2 * corr_levels + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + return delta_net, delta_flow + + +class BidirCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.fmap1_pyramid = [fmap1] + self.fmap2_pyramid = [fmap2] + + for _ in range(self.num_levels - 1): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.fmap1_pyramid.append(fmap1) + self.fmap2_pyramid.append(fmap2) + + def __call__(self, flowt0, flowt1, time_step): + r = self.radius + + out_pyramid = [] + out_pyramid_T = [] + flowt0 = flowt0.contiguous() + flowt1 = flowt1.contiguous() + for i in range(self.num_levels): + fmap1 = self.fmap1_pyramid[i] + fmap2 = self.fmap2_pyramid[i] + corr0 = bicorr_nn.apply(fmap2, fmap1, flowt0, time_step, self.radius) + corr1 = bicorr_nn.apply(fmap1, fmap2, flowt1, time_step, self.radius) + out_pyramid.append(corr0) + out_pyramid_T.append(corr1) + + out = torch.cat(out_pyramid, dim=1) + out_T = torch.cat(out_pyramid_T, dim=1) + return out.contiguous().float(), out_T.contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) \ No newline at end of file diff --git a/modules/components/amt_bilateral/blocks/warp.py b/modules/components/amt_bilateral/blocks/warp.py new file mode 100644 index 0000000000000000000000000000000000000000..89c63449c52bc12b73cc94b29c1d96a305365270 --- /dev/null +++ b/modules/components/amt_bilateral/blocks/warp.py @@ -0,0 +1,13 @@ +import torch +import torch.nn.functional as F + + +def warp(img, flow): + B, _, H, W = flow.shape + xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) + yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) + grid = torch.cat([xx, yy], 1).to(img) + flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) + grid_ = (grid + flow_).permute(0, 2, 3, 1) + output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) + return output diff --git a/modules/components/amt_flowformer/AMT.py b/modules/components/amt_flowformer/AMT.py new file mode 100644 index 0000000000000000000000000000000000000000..9d5d44b196502e08e6c2d34024ba41951b4b7207 --- /dev/null +++ b/modules/components/amt_flowformer/AMT.py @@ -0,0 +1,242 @@ +import torch +import torch.nn as nn +from modules.components.amt.blocks.raft import ( + coords_grid, bilinear_sampler, + SmallUpdateBlock, BidirCorrBlock, BasicUpdateBlock +) +from .blocks.feat_enc import ( + SmallEncoder, + BasicEncoder, + LargeEncoder +) +from .blocks.ifrnet import ( + resize, + Encoder, + InitDecoder, + IntermediateDecoder +) +from .blocks.multi_flow import ( + multi_flow_combine, + MultiFlowDecoder +) + +from .blocks.memory_enc import MemoryEncoder +from .blocks.decoder import MemoryDecoderLayer + +from ..components import register + + +@register('amt_flowformer') +class Model(nn.Module): + def __init__(self, + model_size='S', + corr_radius=3, + corr_lvls=4, + num_flows=3, + channels=[20, 32, 44, 56], + skip_channels=20): + super(Model, self).__init__() + self.model_size = model_size + self.radius = corr_radius + self.corr_levels = corr_lvls + self.num_flows = num_flows + self.channels = channels + self.skip_channels = skip_channels + if self.model_size == 'S': + self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.) + elif self.model_size == 'L': + self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.) + elif self.model_size == 'G': + self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.) + self.encoder = Encoder(channels, large=True) + + ################################ Added modules ###################################### + self.memory_encoder = MemoryEncoder(128, 1, False, 8, 3, 8, 128, 64, 'linear', 0) + self.flow_token_encoder = nn.Sequential( + nn.Conv2d(81, 64, 1, 1), + nn.GELU(), + nn.Conv2d(64, 64, 1, 1) + ) + self.decoder_layer = MemoryDecoderLayer(8) + ##################################################################################### + + self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) + self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) + self.decoder2 = IntermediateDecoder(channels[1], channels[0], skip_channels) + self.decoder1 = MultiFlowDecoder(channels[0], skip_channels, num_flows) + + self.update4 = self._get_updateblock(channels[2]) + self.update3 = self._get_updateblock(channels[1], 2) + self.update2 = self._get_updateblock(channels[0], 4) + + if self.model_size == 'G': + self.update3_high = self._get_updateblock(channels[1], None) + self.update2_high = self._get_updateblock(channels[0], None) + + self.comb_block = nn.Sequential( + nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), + nn.PReLU(6 * self.num_flows), + nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), + ) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, + corr_dim=256, corr_dim2=192, fc_dim=188, + scale_factor=scale_factor, corr_levels=self.corr_levels, + radius=self.radius) + + def _corr_scale_lookup(self, cost_map, cost_map_t, cost_memory, cost_memory_t, size, coord, flow0, flow1, embt, downsample=1): + # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 + # based on linear assumption + t1_scale = 1. / (1. - embt) + t0_scale = 1. / embt + if downsample != 1: + inv = 1 / downsample + flow0 = inv * resize(flow0, scale_factor=inv) + flow1 = inv * resize(flow1, scale_factor=inv) + # scaled lookup + corr0 = self.encode_flow_token(cost_map, coord + flow1 * t1_scale) + corr1 = self.encode_flow_token(cost_map_t, coord + flow0 * t0_scale) + + # make query token from lookup-ed feature + query0 = self.flow_token_encoder(corr0) + query1 = self.flow_token_encoder(corr1) + net_size = query1.shape + query0 = query0.permute(0, 2, 3, 1).contiguous().view(-1, 1, 64) + query1 = query1.permute(0, 2, 3, 1).contiguous().view(-1, 1, 64) + + # get decoded cost volume query + cost_global0, _, _ = self.decoder_layer(query0, None, None, cost_memory, coord + flow1 * t1_scale, net_size, size) + cost_global1, _, _ = self.decoder_layer(query1, None, None, cost_memory_t, coord + flow0 * t0_scale, net_size, size) + + # pass both lookup-ed correlation and decoded cost memory + corr = torch.cat([cost_global0, corr0, cost_global1, corr1], dim=1) + flow = torch.cat([flow0, flow1], dim=1) + return corr, flow + + def encode_flow_token(self, cost_maps, coords): + """ + cost_maps - B*H1*W1, cost_heads_num, H2, W2 + coords - B, 2, H1, W1 + """ + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + r = 4 + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid = coords.reshape(batch*h1*w1, 1, 1, 2) + delta = delta.view(1, 2*r+1, 2*r+1, 2) + coords = centroid + delta + cost_maps = cost_maps.permute(0, 4, 5, 1, 2, 3).reshape(-1, 1, h1, w1) + corr = bilinear_sampler(cost_maps, coords) + corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2) + return corr + + def forward(self, img0, img1, time_step, scale_factor=1.0, eval=False, **kwargs): + mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + img0 = img0 - mean_ + img1 = img1 - mean_ + img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 + img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 + b, _, h, w = img0_.shape + coord = coords_grid(b, h // 8, w // 8, img0.device) + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + ### Fixed code + cost_memory, cost_volume, size = self.memory_encoder(fmap0, fmap1) + cost_memory_t, cost_volume_t, _ = self.memory_encoder(fmap1, fmap0) + # corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + ### + + # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] + # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] + f0_1, f0_2, f0_3, f0_4 = self.encoder(img0_) + f1_1, f1_2, f1_3, f1_4 = self.encoder(img1_) + + ######################################### the 4th decoder ######################################### + up_flow0_4, up_flow1_4, ft_3_ = self.decoder4(f0_4, f1_4, time_step) + corr_4, flow_4 = self._corr_scale_lookup(cost_volume, cost_volume_t, cost_memory, cost_memory_t, size, coord, + up_flow0_4, up_flow1_4, + time_step, downsample=1) + + # residue update with lookup corr + delta_ft_3_, delta_flow_4 = self.update4(ft_3_, flow_4, corr_4) + delta_flow0_4, delta_flow1_4 = torch.chunk(delta_flow_4, 2, 1) + up_flow0_4 = up_flow0_4 + delta_flow0_4 + up_flow1_4 = up_flow1_4 + delta_flow1_4 + ft_3_ = ft_3_ + delta_ft_3_ + + ######################################### the 3rd decoder ######################################### + up_flow0_3, up_flow1_3, ft_2_ = self.decoder3(ft_3_, f0_3, f1_3, up_flow0_4, up_flow1_4) + corr_3, flow_3 = self._corr_scale_lookup(cost_volume, cost_volume_t, cost_memory, cost_memory_t, size, + coord, up_flow0_3, up_flow1_3, + time_step, downsample=2) + + # residue update with lookup corr + delta_ft_2_, delta_flow_3 = self.update3(ft_2_, flow_3, corr_3) + delta_flow0_3, delta_flow1_3 = torch.chunk(delta_flow_3, 2, 1) + up_flow0_3 = up_flow0_3 + delta_flow0_3 + up_flow1_3 = up_flow1_3 + delta_flow1_3 + ft_2_ = ft_2_ + delta_ft_2_ + + if self.model_size == 'G': + # residue update with lookup corr (hr) + corr_3 = resize(corr_3, scale_factor=2.0) + up_flow_3 = torch.cat([up_flow0_3, up_flow1_3], dim=1) + delta_ft_2_, delta_up_flow_3 = self.update3_high(ft_2_, up_flow_3, corr_3) + ft_2_ += delta_ft_2_ + up_flow0_3 += delta_up_flow_3[:, 0:2] + up_flow1_3 += delta_up_flow_3[:, 2:4] + + ######################################### the 2nd decoder ######################################### + up_flow0_2, up_flow1_2, ft_1_ = self.decoder2(ft_2_, f0_2, f1_2, up_flow0_3, up_flow1_3) + corr_2, flow_2 = self._corr_scale_lookup(cost_volume, cost_volume_t, cost_memory, cost_memory_t, size, + coord, up_flow0_2, up_flow1_2, + time_step, downsample=4) + + # residue update with lookup corr + delta_ft_1_, delta_flow_2 = self.update2(ft_1_, flow_2, corr_2) + delta_flow0_2, delta_flow1_2 = torch.chunk(delta_flow_2, 2, 1) + up_flow0_2 = up_flow0_2 + delta_flow0_2 + up_flow1_2 = up_flow1_2 + delta_flow1_2 + ft_1_ = ft_1_ + delta_ft_1_ + + if self.model_size == 'G': + # residue update with lookup corr (hr) + corr_2 = resize(corr_2, scale_factor=4.0) + up_flow_2 = torch.cat([up_flow0_2, up_flow1_2], dim=1) + delta_ft_1_, delta_up_flow_2 = self.update2_high(ft_1_, up_flow_2, corr_2) + ft_1_ += delta_ft_1_ + up_flow0_2 += delta_up_flow_2[:, 0:2] + up_flow1_2 += delta_up_flow_2[:, 2:4] + + ######################################### the 1st decoder ######################################### + up_flow0_1, up_flow1_1, mask, img_res = self.decoder1(ft_1_, f0_1, f1_1, up_flow0_2, up_flow1_2) + + if scale_factor != 1.0: + up_flow0_1 = resize(up_flow0_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + up_flow1_1 = resize(up_flow1_1, scale_factor=(1.0/scale_factor)) * (1.0/scale_factor) + mask = resize(mask, scale_factor=(1.0/scale_factor)) + img_res = resize(img_res, scale_factor=(1.0/scale_factor)) + + # Merge multiple predictions + imgt_pred = multi_flow_combine(self.comb_block, img0, img1, up_flow0_1, up_flow1_1, + mask, img_res, mean_) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + + if eval: + return { 'imgt_pred': imgt_pred, } + else: + up_flow0_1 = up_flow0_1.reshape(b, self.num_flows, 2, int(h/scale_factor), int(w/scale_factor)) + up_flow1_1 = up_flow1_1.reshape(b, self.num_flows, 2, int(h/scale_factor), int(w/scale_factor)) + return { + 'imgt_pred': imgt_pred, + 'flow0_pred': [up_flow0_1, up_flow0_2, up_flow0_3, up_flow0_4], + 'flow1_pred': [up_flow1_1, up_flow1_2, up_flow1_3, up_flow1_4], + 'flowfwd': up_flow0_1[:, 0], + 'flowbwd': up_flow1_1[:, 0], + 'ft_pred': [ft_1_, ft_2_, ft_3_], + } diff --git a/modules/components/amt_flowformer/__init__.py b/modules/components/amt_flowformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..387563589f3cc4a3dc664a5d8d4f5557b0100996 --- /dev/null +++ b/modules/components/amt_flowformer/__init__.py @@ -0,0 +1 @@ +from .AMT import Model \ No newline at end of file diff --git a/modules/components/amt_flowformer/__pycache__/AMT.cpython-310.pyc b/modules/components/amt_flowformer/__pycache__/AMT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ef46f5c114ab5575312332db63d3d651de621e5 Binary files /dev/null and b/modules/components/amt_flowformer/__pycache__/AMT.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/__pycache__/AMT.cpython-38.pyc b/modules/components/amt_flowformer/__pycache__/AMT.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fb6ca910007702a33608c0b516e5a0290ba4bc6 Binary files /dev/null and b/modules/components/amt_flowformer/__pycache__/AMT.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/__pycache__/AMT.cpython-39.pyc b/modules/components/amt_flowformer/__pycache__/AMT.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c9b2b2ca432e70cc4f997c70452aac7a96e1e78 Binary files /dev/null and b/modules/components/amt_flowformer/__pycache__/AMT.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/__pycache__/__init__.cpython-310.pyc b/modules/components/amt_flowformer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b468d7477a633cf5dd7085699084602a9f4448cb Binary files /dev/null and b/modules/components/amt_flowformer/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/__pycache__/__init__.cpython-38.pyc b/modules/components/amt_flowformer/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eee28c36a32c381d638ec1fac1fda1d3e995f77d Binary files /dev/null and b/modules/components/amt_flowformer/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/__pycache__/__init__.cpython-39.pyc b/modules/components/amt_flowformer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6464f90af6fed68362550ac54cf9c462f39714f Binary files /dev/null and b/modules/components/amt_flowformer/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__init__.py b/modules/components/amt_flowformer/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e36eac4537b2daca2ff55baf7101efd1698c9ca6 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0970c15a3d49c0df392c27fc160754585fa2086 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f211fc20f3af6d39da61fcce83751a4b71dd1f14 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dd06d70679caa41388e0cc9c9e5fa0748c8351c Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d78f4bff086bec92f0be4ac2bfdae8bb422523c1 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adc08ec984719cdc1af35c80eb02a8292b525bb2 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/attention.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..42df7056ef3e092200e894eb1870c0f5fa5e9fdb Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dbe27c18abab8a29e2db5bf6e4cf2f01228ba40 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e639f4907768a8eb98f3fda5fb04854971d291e3 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/decoder.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97b9c349211af5ab44d30af6347d461d7467419d Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb41f3b5f58f7f9c233695aa2d12b07992e6c1f0 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2883b2d74737f06af547bbc36b6ab62725c012e9 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/feat_enc.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..574c1e14ccfa89a0c6caaccabfb2fe981caac237 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c3fe02407d91d511883e91a67bfa57e6022f7e8 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e75eed43c3a780f23dee6a64f2c94331691e1569 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/gma.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9453d7dae257c7f7fda227626e8a89f38c7981d8 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc76d4931d833af8026bd1d5dd894319b8bf6569 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b19225f2e6720cea372ef2996148fd29ed1026d Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/gru.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b16168cb264dd8745636b753596882f21b8db9f Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..876aaa51bba6ffa2b3db11b41fd1c23073acabeb Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..628f7aa1912f1ecfb2df9a4b21a5f817c794b6b3 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/ifrnet.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6e53b10c306e21a43ae1ec84adf520129a324f1 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2652d232a8f072ec786e6926e9546d493144614 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f529b5b124bace535c71f08ddd686856a4e6353 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/memory_enc.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33be4b8bae71299d769fe863b6b0cc4fc298a8d5 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..156bbead85172bcb33158e356c9c9babb69c77f4 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3963330e49a1e2fa6f6e5385a0cccc38bee39942 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/multi_flow.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..587475600e061b085a55405980e0b6c2f95bcf4c Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd1bd0f88050e52c94c64aadb86513cc271ce7a6 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4a00f97051140b008904fd22d3d9878af8c7abf Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/twins.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3654c41974a4f676011239eb47cc325aeb267cd Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7852d2cf637929d5268c28176811b8f7f8a85dc2 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b8b9746cf9972c7f3b383b4635c32e7cba1a1a5 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/utils.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-310.pyc b/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa5562651cff5fbef6414d897189546d16907b40 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-310.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-38.pyc b/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ba1fb35a7c6471e25b482bb45c96e887a815efb Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-38.pyc differ diff --git a/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-39.pyc b/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cac1bdb1e8bf5423e279f4832e48e4071b5f3a9 Binary files /dev/null and b/modules/components/amt_flowformer/blocks/__pycache__/warp.cpython-39.pyc differ diff --git a/modules/components/amt_flowformer/blocks/attention.py b/modules/components/amt_flowformer/blocks/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3b3625ea48fd1e8584c387f1a1a22f236dedb6c3 --- /dev/null +++ b/modules/components/amt_flowformer/blocks/attention.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +class BroadMultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(BroadMultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q.squeeze(), 'i (heads d) -> heads i d', heads=self.heads) + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) + + dots = einsum('hid, bhjd -> bhij', Q, K) * self.scale # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, _, _ = K.shape + _, N, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads n d -> b n (heads d)', b=B, n=N) + + return out + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) + + dots = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, HW, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + + return out + +# class MultiHeadAttentionRelative_encoder(nn.Module): +# def __init__(self, dim, heads): +# super(MultiHeadAttentionRelative, self).__init__() +# self.dim = dim +# self.heads = heads +# self.scale = (dim/heads) ** -0.5 +# self.attend = nn.Softmax(dim=-1) + +# def attend_with_rpe(self, Q, K, Q_r, K_r): +# """ +# Q: [BH1W1, H3W3, dim] +# K: [BH1W1, H3W3, dim] +# Q_r: [BH1W1, H3W3, H3W3, dim] +# K_r: [BH1W1, H3W3, H3W3, dim] +# """ + +# Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + +# # context-context similarity +# c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads H3W3 H3W3] +# # context-position similarity +# c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3] +# # position-context similarity +# p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:]) +# p_c = torch.squeeze(p_c, dim=4) +# p_c = p_c.permute(0, 1, 3, 2) +# dots = c_c + c_p + p_c +# return self.attend(dots) + +# def forward(self, Q, K, V, Q_r, K_r): +# attn = self.attend_with_rpe(Q, K, Q_r, K_r) +# B, HW, _ = Q.shape + +# V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + +# out = einsum('bhij, bhjd -> bhid', attn, V) +# out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + +# return out + +class MultiHeadAttentionRelative(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttentionRelative, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K, Q_r, K_r): + """ + Q: [BH1W1, 1, dim] + K: [BH1W1, H3W3, dim] + Q_r: [BH1W1, H3W3, dim] + K_r: [BH1W1, H3W3, dim] + """ + + Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, 1, dim] + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + + # context-context similarity + c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads 1 H3W3] + # context-position similarity + c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3] + # position-context similarity + p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:]) * self.scale + p_c = torch.squeeze(p_c, dim=4) + p_c = p_c.permute(0, 1, 3, 2) + dots = c_c + c_p + p_c + return self.attend(dots) + + def forward(self, Q, K, V, Q_r, K_r): + attn = self.attend_with_rpe(Q, K, Q_r, K_r) + B, HW, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + + return out + +def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1/200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim//4-1, dim//4).to(x.device) + return torch.cat([torch.sin(3.14*x[..., -2:-1]*freq_bands*NORMALIZE_FACOR), torch.cos(3.14*x[..., -2:-1]*freq_bands*NORMALIZE_FACOR), torch.sin(3.14*x[..., -1:]*freq_bands*NORMALIZE_FACOR), torch.cos(3.14*x[..., -1:]*freq_bands*NORMALIZE_FACOR)], dim=-1) + +def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1/200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim//4-1, dim//4).to(x.device) + return torch.cat([torch.sin(x[..., -2:-1]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.cos(x[..., -2:-1]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.sin(x[..., -1:]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.cos(x[..., -1:]*(NORMALIZE_FACOR * 2 ** freq_bands))], dim=-1) \ No newline at end of file diff --git a/modules/components/amt_flowformer/blocks/cnn.py b/modules/components/amt_flowformer/blocks/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..47b184570c3cd771580c72c4009107a580612a3b --- /dev/null +++ b/modules/components/amt_flowformer/blocks/cnn.py @@ -0,0 +1,577 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +import math +import numpy as np + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + mul = input_dim // 3 + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64 * mul) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64 * mul) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64 * mul) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, 64 * mul, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 * mul + self.layer1 = self._make_layer(64 * mul, stride=1) + self.layer2 = self._make_layer(96 * mul, stride=2) + self.layer3 = self._make_layer(128 * mul, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class ConvNets(nn.Module): + def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1): + super(ConvNets, self).__init__() + + self.conv_first = nn.Conv2d(in_dim, inter_dim, kernel_size=3, padding=1, stride=stride) + self.conv_last = nn.Conv2d(inter_dim, out_dim, kernel_size=3, padding=1, stride=stride) + self.relu = nn.ReLU(inplace=True) + self.inter_convs = nn.ModuleList( + [ResidualBlock(inter_dim, inter_dim, norm_fn='none', stride=1) for i in range(depth)]) + + def forward(self, x): + x = self.relu(self.conv_first(x)) + for inter_conv in self.inter_convs: + x = inter_conv(x) + x = self.conv_last(x) + return x + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.motion_feature_dim + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicFuseMotion(nn.Module): + def __init__(self, args): + super(BasicFuseMotion, self).__init__() + cor_planes = args.motion_feature_dim + out_planes = args.query_latent_dim + + self.normf1 = nn.InstanceNorm2d(128) + self.normf2 = nn.InstanceNorm2d(128) + + self.convf1 = nn.Conv2d(2, 128, 3, padding=1) + self.convf2 = nn.Conv2d(128, 128, 3, padding=1) + self.convf3 = nn.Conv2d(128, 64, 3, padding=1) + + s = 1 + self.normc1 = nn.InstanceNorm2d(256*s) + self.normc2 = nn.InstanceNorm2d(256*s) + self.normc3 = nn.InstanceNorm2d(256*s) + + self.convc1 = nn.Conv2d(cor_planes+128, 256*s, 1, padding=0) + self.convc2 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.convc3 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.convc4 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.conv = nn.Conv2d(256*s + 64, out_planes, 1, padding=0) + + def forward(self, flow, feat, context1=None): + flo = F.relu(self.normf1(self.convf1(flow))) + flo = F.relu(self.normf2(self.convf2(flo))) + flo = self.convf3(flo) + + feat = torch.cat([feat, context1], dim=1) + feat = F.relu(self.normc1(self.convc1(feat))) + feat = F.relu(self.normc2(self.convc2(feat))) + feat = F.relu(self.normc3(self.convc3(feat))) + feat = self.convc4(feat) + + feat = torch.cat([flo, feat], dim=1) + feat = F.relu(self.conv(feat)) + + return feat + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + +class DirectMeanMaskPredictor(nn.Module): + def __init__(self, args): + super(DirectMeanMaskPredictor, self).__init__() + self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256) + self.mask = nn.Sequential( + nn.Conv2d(args.predictor_dim, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, motion_features): + delta_flow = self.flow_head(motion_features) + mask = .25 * self.mask(motion_features) + + return mask, delta_flow + +class BaiscMeanPredictor(nn.Module): + def __init__(self, args, hidden_dim=128): + super(BaiscMeanPredictor, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, latent, flow): + motion_features = self.encoder(flow, latent) + delta_flow = self.flow_head(motion_features) + mask = .25 * self.mask(motion_features) + + return mask, delta_flow + +class BasicRPEEncoder(nn.Module): + def __init__(self, args): + super(BasicRPEEncoder, self).__init__() + self.args = args + dim = args.query_latent_dim + self.encoder = nn.Sequential( + nn.Linear(2, dim // 2), + nn.ReLU(inplace=True), + nn.Linear(dim // 2, dim), + nn.ReLU(inplace=True), + nn.Linear(dim, dim) + ) + + def forward(self, rpe_tokens): + return self.encoder(rpe_tokens) + +from .twins import Block, CrossBlock + +class TwinsSelfAttentionLayer(nn.Module): + def __init__(self, args): + super(TwinsSelfAttentionLayer, self).__init__() + self.args = args + embed_dim = 256 + num_heads = 8 + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = 0. + attn_drop_rate=0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True) + self.global_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward(self, x, tgt, size): + x = self.local_block(x, size) + x = self.global_block(x, size) + + tgt = self.local_block(tgt, size) + tgt = self.global_block(tgt, size) + return x, tgt + +class TwinsCrossAttentionLayer(nn.Module): + def __init__(self, args): + super(TwinsCrossAttentionLayer, self).__init__() + self.args = args + embed_dim = 256 + num_heads = 8 + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = 0. + attn_drop_rate=0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True) + self.global_block = CrossBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward(self, x, tgt, size): + x = self.local_block(x, size) + tgt = self.local_block(tgt, size) + x, tgt = self.global_block(x, tgt, size) + + return x, tgt diff --git a/modules/components/amt_flowformer/blocks/convnext.py b/modules/components/amt_flowformer/blocks/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..5114ee8191d9da78a3fac2195787c941c6f560aa --- /dev/null +++ b/modules/components/amt_flowformer/blocks/convnext.py @@ -0,0 +1,87 @@ +from turtle import forward +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + +class ConvNextLayer(nn.Module): + def __init__(self, dim, depth=4): + super().__init__() + self.net = nn.Sequential( + *[ConvNextBlock(dim=dim) for j in range(depth)] + ) + + def forward(self, x): + return self.net(x) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + +class ConvNextBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + def __init__(self, dim, layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + # print(f"conv next layer") + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + x + return x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x \ No newline at end of file diff --git a/modules/components/amt_flowformer/blocks/decoder.py b/modules/components/amt_flowformer/blocks/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..11968a7962f52b19bab33e74f4ea4b2a69b8a21c --- /dev/null +++ b/modules/components/amt_flowformer/blocks/decoder.py @@ -0,0 +1,229 @@ +import loguru +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +from .utils import coords_grid, bilinear_sampler, upflow8 +from .attention import MultiHeadAttention, LinearPositionEmbeddingSine, ExpPositionEmbeddingSine +from typing import Optional, Tuple + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from .gru import BasicUpdateBlock, GMAUpdateBlock +from .gma import Attention + +def initialize_flow(img): + """ Flow is represented as difference between two means flow = mean1 - mean0""" + N, C, H, W = img.shape + mean = coords_grid(N, H, W).to(img.device) + mean_init = coords_grid(N, H, W).to(img.device) + + # optical flow computed as difference: flow = mean1 - mean0 + return mean, mean_init + +class CrossAttentionLayer(nn.Module): + # def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + def __init__(self, qk_dim, v_dim, query_token_dim, tgt_token_dim, add_flow_token=True, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0., pe='linear'): + super(CrossAttentionLayer, self).__init__() + + head_dim = qk_dim // num_heads + self.scale = head_dim ** -0.5 + self.query_token_dim = query_token_dim + self.pe = pe + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads) + self.q, self.k, self.v = nn.Linear(query_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, v_dim, bias=True) + + self.proj = nn.Linear(v_dim*2, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout) + ) + self.add_flow_token = add_flow_token + self.dim = qk_dim + def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3): + """ + query_coord [B, 2, H1, W1] + """ + B, _, H1, W1 = query_coord.shape + + if key is None and value is None: + key = self.k(memory) + value = self.v(memory) + + # [B, 2, H1, W1] -> [BH1W1, 1, 2] + query_coord = query_coord.contiguous() + query_coord = query_coord.view(B, 2, -1).permute(0, 2, 1)[:,:,None,:].contiguous().view(B*H1*W1, 1, 2) + if self.pe == 'linear': + query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim) + elif self.pe == 'exp': + query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim) + + short_cut = query + query = self.norm1(query) + + if self.add_flow_token: + q = self.q(query+query_coord_enc) + else: + q = self.q(query_coord_enc) + k, v = key, value + + x = self.multi_head_attn(q, k, v) + + x = self.proj(torch.cat([x, short_cut],dim=2)) + x = short_cut + self.proj_drop(x) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x, k, v + +class MemoryDecoderLayer(nn.Module): + def __init__(self, patch_size, query_latent_dim=64, cost_latent_dim=128): + super(MemoryDecoderLayer, self).__init__() + self.patch_size = patch_size # for converting coords into H2', W2' space + self.query_latent_dim = query_latent_dim + query_token_dim, tgt_token_dim = query_latent_dim, cost_latent_dim + qk_dim, v_dim = query_token_dim, query_token_dim + self.cross_attend = CrossAttentionLayer(qk_dim, v_dim, query_token_dim, tgt_token_dim, add_flow_token=True) + + def forward(self, query, key, value, memory, coords1, size, size_h3w3): + """ + x: [B*H1*W1, 1, C] + memory: [B*H1*W1, H2'*W2', C] + coords1 [B, 2, H2, W2] + size: B, C, H1, W1 + 1. Note that here coords0 and coords1 are in H2, W2 space. + Should first convert it into H2', W2' space. + 2. We assume the upper-left point to be [0, 0], instead of letting center of upper-left patch to be [0, 0] + """ + x_global, k, v = self.cross_attend(query, key, value, memory, coords1, self.patch_size, size_h3w3) + B, C, H1, W1 = size + C = self.query_latent_dim + x_global = x_global.view(B, H1, W1, C).permute(0, 3, 1, 2) + return x_global, k, v + + + +class MemoryDecoder(nn.Module): + def __init__(self, cfg): + super(MemoryDecoder, self).__init__() + dim = self.dim = cfg.query_latent_dim + self.cfg = cfg + + self.flow_token_encoder = nn.Sequential( + nn.Conv2d(81*cfg.cost_heads_num, dim, 1, 1), + nn.GELU(), + nn.Conv2d(dim, dim, 1, 1) + ) + self.proj = nn.Conv2d(256, 256, 1) + self.depth = cfg.decoder_depth + self.decoder_layer = MemoryDecoderLayer(dim, cfg) + + if self.cfg.gma: + self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128) + self.att = Attention(args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128) + else: + self.update_block = BasicUpdateBlock(self.cfg, hidden_dim=128) + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + def encode_flow_token(self, cost_maps, coords): + """ + cost_maps - B*H1*W1, cost_heads_num, H2, W2 + coords - B, 2, H1, W1 + """ + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + r = 4 + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid = coords.reshape(batch*h1*w1, 1, 1, 2) + delta = delta.view(1, 2*r+1, 2*r+1, 2) + coords = centroid + delta + corr = bilinear_sampler(cost_maps, coords) + corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2) + return corr + + def forward(self, cost_memory, context, data={}, flow_init=None): + """ + memory: [B*H1*W1, H2'*W2', C] + context: [B, D, H1, W1] + """ + cost_maps = data['cost_maps'] + coords0, coords1 = initialize_flow(context) + + if flow_init is not None: + #print("[Using warm start]") + coords1 = coords1 + flow_init + + #flow = coords1 + + flow_predictions = [] + + context = self.proj(context) + net, inp = torch.split(context, [128, 128], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + if self.cfg.gma: + attention = self.att(inp) + + size = net.shape + key, value = None, None + + for idx in range(self.depth): + coords1 = coords1.detach() + + cost_forward = self.encode_flow_token(cost_maps, coords1) + #cost_backward = self.reverse_cost_extractor(cost_maps, coords0, coords1) + + query = self.flow_token_encoder(cost_forward) + query = query.permute(0, 2, 3, 1).contiguous().view(size[0]*size[2]*size[3], 1, self.dim) + cost_global, key, value = self.decoder_layer(query, key, value, cost_memory, coords1, size, data['H3W3']) + if self.cfg.only_global: + corr = cost_global + else: + corr = torch.cat([cost_global, cost_forward], dim=1) + + flow = coords1 - coords0 + + if self.cfg.gma: + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention) + else: + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # flow = delta_flow + coords1 = coords1 + delta_flow + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + flow_predictions.append(flow_up) + + if self.training: + return flow_predictions + else: + return flow_predictions[-1:] diff --git a/modules/components/amt_flowformer/blocks/encoders.py b/modules/components/amt_flowformer/blocks/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..f132aaf1f2bf7e86fe9713c95ec14c2f5991ab4c --- /dev/null +++ b/modules/components/amt_flowformer/blocks/encoders.py @@ -0,0 +1,92 @@ +import torch +import torch.nn as nn +import timm +import numpy as np + +class twins_svt_large(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model('twins_svt_large', pretrained=pretrained) + + del self.svt.head + del self.svt.patch_embeds[2] + del self.svt.patch_embeds[2] + del self.svt.blocks[2] + del self.svt.blocks[2] + del self.svt.pos_block[2] + del self.svt.pos_block[2] + + def forward(self, x, data=None, layer=2): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j==0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + if i == layer-1: + break + + return x + + def compute_params(self, layer=2): + num = 0 + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + for param in embed.parameters(): + num += np.prod(param.size()) + + for param in drop.parameters(): + num += np.prod(param.size()) + + for param in blocks.parameters(): + num += np.prod(param.size()) + + for param in pos_blk.parameters(): + num += np.prod(param.size()) + + if i == layer-1: + break + + for param in self.svt.head.parameters(): + num += np.prod(param.size()) + + return num + +class twins_svt_large_context(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model('twins_svt_large_context', pretrained=pretrained) + + def forward(self, x, data=None, layer=2): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j==0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + if i == layer-1: + break + + return x + + +if __name__ == "__main__": + m = twins_svt_large() + input = torch.randn(2, 3, 400, 800) + out = m.extract_feature(input) + print(out.shape) diff --git a/modules/components/amt_flowformer/blocks/feat_enc.py b/modules/components/amt_flowformer/blocks/feat_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..3805bd315422703c19bf6a4d0962ee75002d92aa --- /dev/null +++ b/modules/components/amt_flowformer/blocks/feat_enc.py @@ -0,0 +1,343 @@ +import torch +import torch.nn as nn + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(72, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class LargeEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(LargeEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(112, stride=2) + self.layer3 = self._make_layer(160, stride=2) + self.layer3_2 = self._make_layer(160, stride=1) + + # output convolution + self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer3_2(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/modules/components/amt_flowformer/blocks/gma.py b/modules/components/amt_flowformer/blocks/gma.py new file mode 100644 index 0000000000000000000000000000000000000000..dd712e8b06c166ac8ca42c99307892212362150a --- /dev/null +++ b/modules/components/amt_flowformer/blocks/gma.py @@ -0,0 +1,123 @@ +import torch +from torch import nn, einsum +from einops import rearrange + + +class RelPosEmb(nn.Module): + def __init__( + self, + max_pos_size, + dim_head + ): + super().__init__() + self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) + self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) + + deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) + rel_ind = deltas + max_pos_size - 1 + self.register_buffer('rel_ind', rel_ind) + + def forward(self, q): + batch, heads, h, w, c = q.shape + height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) + width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) + + height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) + width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) + + height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) + width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) + + return height_score + width_score + + +class Attention(nn.Module): + def __init__( + self, + *, + args, + dim, + max_pos_size = 100, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) + + self.pos_emb = RelPosEmb(max_pos_size, dim_head) + + def forward(self, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + q, k = self.to_qk(fmap).chunk(2, dim=1) + + q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) + q = self.scale * q + + # if self.args.position_only: + # sim = self.pos_emb(q) + + # elif self.args.position_and_content: + # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + # sim_pos = self.pos_emb(q) + # sim = sim_content + sim_pos + + # else: + sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + + sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') + attn = sim.softmax(dim=-1) + + return attn + + +class Aggregate(nn.Module): + def __init__( + self, + args, + dim, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) + + self.gamma = nn.Parameter(torch.zeros(1)) + + if dim != inner_dim: + self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) + else: + self.project = None + + def forward(self, attn, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + v = self.to_v(fmap) + v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) + + if self.project is not None: + out = self.project(out) + + out = fmap + self.gamma * out + + return out + + +if __name__ == "__main__": + att = Attention(dim=128, heads=1) + fmap = torch.randn(2, 128, 40, 90) + out = att(fmap) + + print(out.shape) \ No newline at end of file diff --git a/modules/components/amt_flowformer/blocks/gru.py b/modules/components/amt_flowformer/blocks/gru.py new file mode 100644 index 0000000000000000000000000000000000000000..92802c76e91471573551164ff60fbd94c26b4424 --- /dev/null +++ b/modules/components/amt_flowformer/blocks/gru.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + if args.only_global: + print("[Decoding with only global cost]") + cor_planes = args.query_latent_dim + else: + cor_planes = 81+args.query_latent_dim + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + +from .gma import Aggregate +class GMAUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128): + super().__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) + + def forward(self, net, inp, corr, flow, attention): + motion_features = self.encoder(flow, corr) + motion_features_global = self.aggregator(attention, motion_features) + inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) + + # Attentional update + net = self.gru(net, inp_cat) + + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow \ No newline at end of file diff --git a/modules/components/amt_flowformer/blocks/ifrnet.py b/modules/components/amt_flowformer/blocks/ifrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c00cfba925abe7e51bc558812b4bbea40611cd0d --- /dev/null +++ b/modules/components/amt_flowformer/blocks/ifrnet.py @@ -0,0 +1,111 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .warp import warp + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + +def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), + nn.PReLU(out_channels) + ) + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + +class Encoder(nn.Module): + def __init__(self, channels, large=False): + super(Encoder, self).__init__() + self.channels = channels + prev_ch = 3 + for idx, ch in enumerate(channels, 1): + k = 7 if large and idx == 1 else 3 + p = 3 if k ==7 else 1 + self.register_module(f'pyramid{idx}', + nn.Sequential( + convrelu(prev_ch, ch, k, 2, p), + convrelu(ch, ch, 3, 1, 1) + )) + prev_ch = ch + + def forward(self, in_x): + fs = [] + for idx in range(len(self.channels)): + out_x = getattr(self, f'pyramid{idx+1}')(in_x) + fs.append(out_x) + in_x = out_x + return fs + +class InitDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch*2+1, in_ch*2), + ResBlock(in_ch*2, skip_ch), + nn.ConvTranspose2d(in_ch*2, out_ch+4, 4, 2, 1, bias=True) + ) + def forward(self, f0, f1, embt): + h, w = f0.shape[2:] + embt = embt.repeat(1, 1, h, w) + out = self.convblock(torch.cat([f0, f1, embt], 1)) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + return flow0, flow1, ft_ + +class IntermediateDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch*3+4, in_ch*3), + ResBlock(in_ch*3, skip_ch), + nn.ConvTranspose2d(in_ch*3, out_ch+4, 4, 2, 1, bias=True) + ) + def forward(self, ft_, f0, f1, flow0_in, flow1_in): + f0_warp = warp(f0, flow0_in) + f1_warp = warp(f1, flow1_in) + f_in = torch.cat([ft_, f0_warp, f1_warp, flow0_in, flow1_in], 1) + out = self.convblock(f_in) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + flow0 = flow0 + 2.0 * resize(flow0_in, scale_factor=2.0) + flow1 = flow1 + 2.0 * resize(flow1_in, scale_factor=2.0) + return flow0, flow1, ft_ \ No newline at end of file diff --git a/modules/components/amt_flowformer/blocks/memory_enc.py b/modules/components/amt_flowformer/blocks/memory_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd8fabee1c808ceeb8e3c3d2ddca7c7e65ec399 --- /dev/null +++ b/modules/components/amt_flowformer/blocks/memory_enc.py @@ -0,0 +1,291 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +import numpy as np + +from einops import rearrange + +from .utils import coords_grid, bilinear_sampler, upflow8 +from .attention import BroadMultiHeadAttention, MultiHeadAttention, LinearPositionEmbeddingSine, \ + ExpPositionEmbeddingSine +from typing import Optional, Tuple +from .twins import Size_, PosConv + +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ + + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe='linear'): + super().__init__() + self.patch_size = patch_size + self.dim = embed_dim + self.pe = pe + + # assert patch_size == 8 + if patch_size == 8: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim // 4, embed_dim // 2, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim // 2, embed_dim, kernel_size=6, stride=2, padding=2), + ) + elif patch_size == 4: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim // 4, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim // 4, embed_dim, kernel_size=6, stride=2, padding=2), + ) + else: + print(f"patch size = {patch_size} is unacceptable.") + + self.ffn_with_coord = nn.Sequential( + nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1), + nn.ReLU(), + nn.Conv2d(embed_dim * 2, embed_dim * 2, kernel_size=1) + ) + self.norm = nn.LayerNorm(embed_dim * 2) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape # C == 1 + + pad_l = pad_t = 0 + pad_r = (self.patch_size - W % self.patch_size) % self.patch_size + pad_b = (self.patch_size - H % self.patch_size) % self.patch_size + x = F.pad(x, (pad_l, pad_r, pad_t, pad_b)) + + x = self.proj(x) + out_size = x.shape[2:] + + patch_coord = coords_grid(B, out_size[0], out_size[1]).to( + x.device) * self.patch_size + self.patch_size / 2 # in feature coordinate space + patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1) + if self.pe == 'linear': + patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim) + elif self.pe == 'exp': + patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim) + patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(B, -1, out_size[0], out_size[1]) + + x_pe = torch.cat([x, patch_coord_enc], dim=1) + x = self.ffn_with_coord(x_pe) + x = self.norm(x.flatten(2).transpose(1, 2)) + + return x, out_size + + +from .twins import Block, CrossBlock + + +class VerticalSelfAttentionLayer(nn.Module): + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(VerticalSelfAttentionLayer, self).__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + embed_dim = dim + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = dropout + attn_drop_rate = 0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True) + self.global_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True) + + def forward(self, x, size, context=None): + x = self.local_block(x, size, context) + x = self.global_block(x, size, context) + + return x + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + +class SelfAttentionLayer(nn.Module): + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(SelfAttentionLayer, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.multi_head_attn = MultiHeadAttention(dim, num_heads) + self.q, self.k, self.v = nn.Linear(dim, dim, bias=True), nn.Linear(dim, dim, bias=True), nn.Linear(dim, dim, + bias=True) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = x + x = self.norm1(x) + + q, k, v = self.q(x), self.k(x), self.v(x) + + x = self.multi_head_attn(q, k, v) + + x = self.proj(x) + x = short_cut + self.proj_drop(x) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + +class CrossAttentionLayer(nn.Module): + def __init__(self, qk_dim, v_dim, query_token_dim, tgt_token_dim, num_heads=8, attn_drop=0., proj_drop=0., + drop_path=0., dropout=0.): + super(CrossAttentionLayer, self).__init__() + assert qk_dim % num_heads == 0, f"dim {qk_dim} should be divided by num_heads {num_heads}." + assert v_dim % num_heads == 0, f"dim {v_dim} should be divided by num_heads {num_heads}." + """ + Query Token: [N, C] -> [N, qk_dim] (Q) + Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V) + """ + self.num_heads = num_heads + head_dim = qk_dim // num_heads + self.scale = head_dim ** -0.5 + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads) + self.q, self.k, self.v = nn.Linear(query_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, qk_dim, + bias=True), nn.Linear( + tgt_token_dim, v_dim, bias=True) + + self.proj = nn.Linear(v_dim, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout) + ) + + def forward(self, query, tgt_token): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = query + query = self.norm1(query) + + q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token) + + x = self.multi_head_attn(q, k, v) + + x = short_cut + self.proj_drop(self.proj(x)) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + +class CostPerceiverEncoder(nn.Module): + def __init__(self, patch_size, encoder_depth, cost_latent_token_num, cost_latent_dim, cost_latent_input_dim, pe, + dropout): + super(CostPerceiverEncoder, self).__init__() + self.cost_latent_token_num = cost_latent_token_num + self.patch_size = patch_size + self.patch_embed = PatchEmbed(in_chans=1, patch_size=8, + embed_dim=cost_latent_input_dim, pe=pe) + + self.depth = encoder_depth + + self.latent_tokens = nn.Parameter(torch.randn(1, cost_latent_token_num, cost_latent_dim)) + + query_token_dim, tgt_token_dim = cost_latent_dim, cost_latent_input_dim * 2 + qk_dim, v_dim = query_token_dim, query_token_dim + self.input_layer = CrossAttentionLayer(qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=dropout) + + self.encoder_layers = nn.ModuleList( + [SelfAttentionLayer(cost_latent_dim, dropout=dropout) for _ in range(self.depth)]) + + self.vertical_encoder_layers = nn.ModuleList( + [VerticalSelfAttentionLayer(cost_latent_dim, dropout=dropout) for _ in range(self.depth)]) + + def forward(self, cost_volume, context=None): + B, heads, H1, W1, H2, W2 = cost_volume.shape + cost_maps = cost_volume.permute(0, 2, 3, 1, 4, 5).contiguous().view(B * H1 * W1, 1, H2, W2) + + x, size = self.patch_embed(cost_maps) # B*H1*W1, size[0]*size[1], C + + x = self.input_layer(self.latent_tokens, x) + + short_cut = x + + for idx, layer in enumerate(self.encoder_layers): + x = layer(x) + x = x.view(B, H1 * W1, self.cost_latent_token_num, -1).permute(0, 2, 1, 3).reshape( + B * self.cost_latent_token_num, H1 * W1, -1) + x = self.vertical_encoder_layers[idx](x, (H1, W1), context) + x = x.view(B, self.cost_latent_token_num, H1 * W1, -1).permute(0, 2, 1, 3).reshape(B * H1 * W1, + self.cost_latent_token_num, + -1) + + x = x + short_cut + return x, size + + +class MemoryEncoder(nn.Module): + def __init__(self, encoder_latent_dim, cost_heads_num, feat_cross_attn, patch_size, encoder_depth, + cost_latent_token_num, cost_latent_dim, cost_latent_input_dim, pe, dropout): + super(MemoryEncoder, self).__init__() + self.feat_cross_attn = feat_cross_attn + self.cost_heads_num = cost_heads_num + self.channel_convertor = nn.Conv2d(encoder_latent_dim, encoder_latent_dim, 1, padding=0, bias=False) + self.cost_perceiver_encoder = CostPerceiverEncoder(patch_size, encoder_depth, cost_latent_token_num, + cost_latent_dim, cost_latent_input_dim, pe, dropout) + + def corr(self, fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = rearrange(fmap1, 'b (heads d) h w -> b heads (h w) d', heads=self.cost_heads_num) + fmap2 = rearrange(fmap2, 'b (heads d) h w -> b heads (h w) d', heads=self.cost_heads_num) + corr = einsum('bhid, bhjd -> bhij', fmap1, fmap2) + corr = corr.permute(0, 2, 1, 3).view(batch * ht * wd, self.cost_heads_num, ht, wd) + corr = corr.view(batch, ht * wd, self.cost_heads_num, ht * wd).permute(0, 2, 1, 3) + corr = corr.view(batch, self.cost_heads_num, ht, wd, ht, wd) + + return corr + + def forward(self, feat_s, feat_t, context=None): + cost_volume = self.corr(feat_s, feat_t) + x, size = self.cost_perceiver_encoder(cost_volume, context) + + return x, cost_volume, size diff --git a/modules/components/amt_flowformer/blocks/multi_flow.py b/modules/components/amt_flowformer/blocks/multi_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..21097edfa11300d4a11a63e17b3dc793fe0b893d --- /dev/null +++ b/modules/components/amt_flowformer/blocks/multi_flow.py @@ -0,0 +1,69 @@ +import torch +import torch.nn as nn +from .warp import warp +from .ifrnet import ( + convrelu, resize, + ResBlock, +) + + +def multi_flow_combine(comb_block, img0, img1, flow0, flow1, + mask=None, img_res=None, mean=None): + ''' + A parallel implementation of multiple flow field warping + comb_block: An nn.Seqential object. + img shape: [b, c, h, w] + flow shape: [b, 2*num_flows, h, w] + mask (opt): + If 'mask' is None, the function conduct a simple average. + img_res (opt): + If 'img_res' is None, the function adds zero instead. + mean (opt): + If 'mean' is None, the function adds zero instead. + ''' + b, c, h, w = flow0.shape + num_flows = c // 2 + flow0 = flow0.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + flow1 = flow1.reshape(b, num_flows, 2, h, w).reshape(-1, 2, h, w) + + mask = mask.reshape(b, num_flows, 1, h, w + ).reshape(-1, 1, h, w) if mask is not None else None + img_res = img_res.reshape(b, num_flows, 3, h, w + ).reshape(-1, 3, h, w) if img_res is not None else 0 + img0 = torch.stack([img0] * num_flows, 1).reshape(-1, 3, h, w) + img1 = torch.stack([img1] * num_flows, 1).reshape(-1, 3, h, w) + mean = torch.stack([mean] * num_flows, 1).reshape(-1, 1, 1, 1 + ) if mean is not None else 0 + + img0_warp = warp(img0, flow0) + img1_warp = warp(img1, flow1) + img_warps = mask * img0_warp + (1 - mask) * img1_warp + mean + img_res + img_warps = img_warps.reshape(b, num_flows, 3, h, w) + imgt_pred = img_warps.mean(1) + comb_block(img_warps.view(b, -1, h, w)) + return imgt_pred + + +class MultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(MultiFlowDecoder, self).__init__() + self.num_flows = num_flows + self.convblock = nn.Sequential( + convrelu(in_ch*3+4, in_ch*3), + ResBlock(in_ch*3, skip_ch), + nn.ConvTranspose2d(in_ch*3, 8*num_flows, 4, 2, 1, bias=True) + ) + + def forward(self, ft_, f0, f1, flow0, flow1): + n = self.num_flows + f0_warp = warp(f0, flow0) + f1_warp = warp(f1, flow1) + out = self.convblock(torch.cat([ft_, f0_warp, f1_warp, flow0, flow1], 1)) + delta_flow0, delta_flow1, mask, img_res = torch.split(out, [2*n, 2*n, n, 3*n], 1) + mask = torch.sigmoid(mask) + + flow0 = delta_flow0 + 2.0 * resize(flow0, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + flow1 = delta_flow1 + 2.0 * resize(flow1, scale_factor=2.0 + ).repeat(1, self.num_flows, 1, 1) + + return flow0, flow1, mask, img_res \ No newline at end of file diff --git a/modules/components/amt_flowformer/blocks/raft.py b/modules/components/amt_flowformer/blocks/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb85ad6556a28f5b80034c595be539fd700ad48 --- /dev/null +++ b/modules/components/amt_flowformer/blocks/raft.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def bilinear_sampler(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), + torch.arange(wd, device=device), + indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim, + corr_levels=4, radius=3, scale_factor=None): + super(SmallUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + self.scale_factor = scale_factor + + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + + return delta_net, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2, + fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1): + super(BasicUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4*out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + return delta_net, delta_flow + + +class BidirCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + self.corr_pyramid_T = [] + + corr = BidirCorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) + + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1) + + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + for _ in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_T = F.avg_pool2d(corr_T, 2, stride=2) + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + def __call__(self, coords0, coords1): + r = self.radius + coords0 = coords0.permute(0, 2, 3, 1) + coords1 = coords1.permute(0, 2, 3, 1) + assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" + batch, h1, w1, _ = coords0.shape + + out_pyramid = [] + out_pyramid_T = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + corr_T = self.corr_pyramid_T[i] + + dx = torch.linspace(-r, r, 2*r+1, device=coords0.device) + dy = torch.linspace(-r, r, 2*r+1, device=coords0.device) + delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1) + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + + centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i + centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i + coords_lvl_0 = centroid_lvl_0 + delta_lvl + coords_lvl_1 = centroid_lvl_1 + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl_0) + corr_T = bilinear_sampler(corr_T, coords_lvl_1) + corr = corr.view(batch, h1, w1, -1) + corr_T = corr_T.view(batch, h1, w1, -1) + out_pyramid.append(corr) + out_pyramid_T.append(corr_T) + + out = torch.cat(out_pyramid, dim=-1) + out_T = torch.cat(out_pyramid_T, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) \ No newline at end of file diff --git a/modules/components/amt_flowformer/blocks/twins.py b/modules/components/amt_flowformer/blocks/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..03eefed6c18746ba07168f8a1f6f0b17277195e1 --- /dev/null +++ b/modules/components/amt_flowformer/blocks/twins.py @@ -0,0 +1,1031 @@ +""" Twins +A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers` + - https://arxiv.org/pdf/2104.13840.pdf +Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below +""" +# -------------------------------------------------------- +# Twins +# Copyright (c) 2021 Meituan +# Licensed under The Apache 2.0 License [see LICENSE for details] +# Written by Xinjie Li, Xiangxiang Chu +# -------------------------------------------------------- +import math +from copy import deepcopy +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import Attention +from .attention import MultiHeadAttention, LinearPositionEmbeddingSine +from .utils import coords_grid, bilinear_sampler, upflow8 + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'twins_pcpvt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth', + ), + 'twins_pcpvt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth', + ), + 'twins_pcpvt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth', + ), + 'twins_svt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth', + ), + 'twins_svt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth', + ), + 'twins_svt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth', + ), +} + +Size_ = Tuple[int, int] + + +class GroupAttnRPEContext(nn.Module): + """ Latent cost tokens attend to different group + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1, vert_c_dim=0): + super(GroupAttnRPEContext, self).__init__() + assert ws != 1 + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + assert vert_c_dim > 0, "vert_c_dim should not be 0" + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.vert_c_dim = vert_c_dim + + self.context_proj = nn.Linear(256, vert_c_dim) + self.q = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + C_qk = C + self.vert_c_dim + H, W = size + batch_num = B // 5 + + context = context.repeat(B // context.shape[0], 1, 1, 1) + context = context.view(B, -1, H * W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + padded_N = Hp * Wp + + coords = coords_grid(B, Hp, Wp).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk) + coords_enc = coords_enc.reshape(B, Hp, Wp, C_qk) + + q = self.q(x_qk + coords_enc).reshape(B, _h, self.ws, _w, self.ws, self.num_heads, + C // self.num_heads).transpose(2, 3) + q = q.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = self.v(x) + k = self.k(x_qk + coords_enc) + # concate and do shifting operation together + kv = torch.cat([k, v], dim=-1) + kv_up = torch.cat( + [kv[:batch_num, self.ws:Hp, :, :], + kv[:batch_num, Hp - self.ws:Hp, :, :]], dim=1) + kv_down = torch.cat( + [kv[batch_num:batch_num * 2, :self.ws, :, :], + kv[batch_num:batch_num * 2, :Hp - self.ws, :, :]], dim=1) + kv_left = torch.cat( + [kv[batch_num * 2:batch_num * 3, :, self.ws:Wp, :], + kv[batch_num * 2:batch_num * 3, :, Wp - self.ws:Wp, :]], dim=2) + kv_right = torch.cat( + [kv[batch_num * 3:batch_num * 4, :, :self.ws, :], + kv[batch_num * 3:batch_num * 4, :, :Wp - self.ws, :]], dim=2) + kv_center = kv[batch_num * 4:batch_num * 5, :, :, :] + kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0) + k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1) + + k = k.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + k = k.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = v.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + v = v.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GroupAttnRPE(nn.Module): + """ Latent cost tokens attend to different group + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + super(GroupAttnRPE, self).__init__() + assert ws != 1 + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + H, W = size + batch_num = B // 5 + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + padded_N = Hp * Wp + + coords = coords_grid(B, Hp, Wp).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + coords_enc = coords_enc.reshape(B, Hp, Wp, C) + + q = self.q(x + coords_enc).reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose( + 2, 3) + q = q.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = self.v(x) + k = self.k(x + coords_enc) + # concate and do shifting operation together + kv = torch.cat([k, v], dim=-1) + kv_up = torch.cat([kv[:batch_num, self.ws:Hp, :, :], kv[:batch_num, Hp - self.ws:Hp, :, :]], dim=1) + kv_down = torch.cat( + [kv[batch_num:batch_num * 2, :self.ws, :, :], kv[batch_num:batch_num * 2, :Hp - self.ws, :, :]], dim=1) + kv_left = torch.cat( + [kv[batch_num * 2:batch_num * 3, :, self.ws:Wp, :], kv[batch_num * 2:batch_num * 3, :, Wp - self.ws:Wp, :]], + dim=2) + kv_right = torch.cat( + [kv[batch_num * 3:batch_num * 4, :, :self.ws, :], kv[batch_num * 3:batch_num * 4, :, :Wp - self.ws, :]], + dim=2) + kv_center = kv[batch_num * 4:batch_num * 5, :, :, :] + kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0) + k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1) + + k = k.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + k = k.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = v.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + v = v.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LocallyGroupedAttnRPEContext(nn.Module): + """ LSA: self attention within a group + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1, vert_c_dim=0): + assert ws != 1 + super(LocallyGroupedAttnRPEContext, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.vert_c_dim = vert_c_dim + + self.context_proj = nn.Linear(256, vert_c_dim) + # context are not added to value + self.q = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + C_qk = C + self.vert_c_dim + + context = context.repeat(B // context.shape[0], 1, 1, 1) + context = context.view(B, -1, H * W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + x_qk = x_qk.reshape(B, _h, self.ws, _w, self.ws, C_qk).transpose(2, 3) + + v = self.v(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + + coords = coords_grid(B, self.ws, self.ws).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk).view(B, self.ws, self.ws, C_qk) + # coords_enc: B, ws, ws, C + # x: B, _h, _w, self.ws, self.ws, C + x_qk = x_qk + coords_enc[:, None, None, :, :, :] + + q = self.q(x_qk).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + k = self.k(x_qk).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalSubSampleAttnRPEContext(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1, vert_c_dim=0): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.vert_c_dim = vert_c_dim + self.context_proj = nn.Linear(256, vert_c_dim) + self.q = nn.Linear(dim + vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr_key = nn.Conv2d(dim + vert_c_dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.sr_value = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + C_qk = C + self.vert_c_dim + H, W = size + context = context.repeat(B // context.shape[0], 1, 1, 1) + context = context.view(B, -1, H * W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + pad_l = pad_t = 0 + pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio + pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + padded_size = (Hp, Wp) + padded_N = Hp * Wp + x = x.view(B, -1, C) + x_qk = x_qk.view(B, -1, C_qk) + + coords = coords_grid(B, *padded_size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk) + # coords_enc: B, Hp*Wp, C + # x: B, Hp*Wp, C + q = self.q(x_qk + coords_enc).reshape(B, padded_N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_key is not None: + x = x.permute(0, 2, 1).reshape(B, C, *padded_size) + x_qk = x_qk.permute(0, 2, 1).reshape(B, C_qk, *padded_size) + x = self.sr_value(x).reshape(B, C, -1).permute(0, 2, 1) + x_qk = self.sr_key(x_qk).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + x_qk = self.norm(x_qk) + + coords = coords_grid(B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(x_qk + coords_enc).reshape(B, (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio), + self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(x).reshape(B, (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio), self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class LocallyGroupedAttnRPE(nn.Module): + """ LSA: self attention within a group + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(LocallyGroupedAttnRPE, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + v = self.v(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + + coords = coords_grid(B, self.ws, self.ws).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C).view(B, self.ws, self.ws, C) + # coords_enc: B, ws, ws, C + # x: B, _h, _w, self.ws, self.ws, C + x = x + coords_enc[:, None, None, :, :, :] + + q = self.q(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + k = self.k(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalSubSampleAttnRPE(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio + pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + padded_size = (Hp, Wp) + padded_N = Hp * Wp + x = x.view(B, -1, C) + + coords = coords_grid(B, *padded_size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, Hp*Wp, C + # x: B, Hp*Wp, C + q = self.q(x + coords_enc).reshape(B, padded_N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *padded_size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + + coords = coords_grid(B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(x + coords_enc).reshape(B, (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio), + self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(x).reshape(B, (padded_size[0] // self.sr_ratio) * (padded_size[1] // self.sr_ratio), self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossGlobalSubSampleAttnRPE(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + coords = coords_grid(B, *size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, H*W, C + # x: B, H*W, C + q = self.q(x + coords_enc).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + coords = coords_grid(B, size[0] // self.sr_ratio, size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(tgt + coords_enc).reshape(B, (size[0] // self.sr_ratio) * (size[1] // self.sr_ratio), self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(tgt).reshape(B, (size[0] // self.sr_ratio) * (size[1] // self.sr_ratio), self.num_heads, + C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class LocallyGroupedAttn(nn.Module): + """ LSA: self attention within a group + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(LocallyGroupedAttn, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + qkv = self.qkv(x).reshape( + B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class GlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossGlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + kv = self.kv(tgt).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class CrossBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None, with_rpe=True): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossGlobalSubSampleAttnRPE(dim, num_heads, attn_drop, drop, sr_ratio) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, src, tgt, size: Size_): + src_shortcut, tgt_shortcut = src, tgt + + src, tgt = self.norm1(src), self.norm1(tgt) + src = src_shortcut + self.drop_path(self.attn(src, tgt, size)) + tgt = tgt_shortcut + self.drop_path(self.attn(tgt, src, size)) + + src = src + self.drop_path(self.mlp(self.norm2(src))) + tgt = tgt + self.drop_path(self.mlp(self.norm2(tgt))) + return src, tgt + + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None, with_rpe=False, vert_c_dim=0): + super().__init__() + self.norm1 = norm_layer(dim) + if ws == 1: + if with_rpe: + if vert_c_dim > 0: + self.attn = GlobalSubSampleAttnRPEContext(dim, num_heads, attn_drop, drop, sr_ratio, vert_c_dim) + else: + self.attn = GlobalSubSampleAttnRPE(dim, num_heads, attn_drop, drop, sr_ratio) + else: + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) + else: + if with_rpe: + if vert_c_dim > 0: + self.attn = LocallyGroupedAttnRPEContext(dim, num_heads, attn_drop, drop, ws, vert_c_dim) + else: + self.attn = LocallyGroupedAttnRPE(dim, num_heads, attn_drop, drop, ws) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, size: Size_, context=None): + x = x + self.drop_path(self.attn(self.norm1(x), size, context)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), ) + self.stride = stride + + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ['proj.%d.weight' % i for i in range(4)] + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + out_size = (H // self.patch_size[0], W // self.patch_size[1]) + + return x, out_size + + +class Twins(nn.Module): + """ Twins Vision Transfomer (Revisiting Spatial Attention) + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ + + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None, + block_cls=Block, init_weight=True): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + + img_size = to_2tuple(img_size) + prev_chs = in_chans + self.patch_embeds = nn.ModuleList() + self.pos_drops = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + + self.blocks = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList([block_cls( + dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])]) + self.blocks.append(_block) + cur += depths[k] + + self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) + + self.norm = norm_layer(self.num_features) + + # classification head + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # init weights + if init_weight: + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()]) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x.mean(dim=1) # GAP here + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + +# def _create_twins(variant, pretrained=False, **kwargs): +# if kwargs.get('features_only', None): +# raise RuntimeError('features_only not implemented for Vision Transformer models.') + +# model = build_model_with_cfg( +# Twins, variant, pretrained, +# default_cfg=default_cfgs[variant], +# **kwargs) +# return model + + +# @register_model +# def twins_pcpvt_small(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_pcpvt_base(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_pcpvt_large(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_small(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_base(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_large(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) + +# @register_model +# def twins_svt_large_context(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], in_chans=6, init_weight=False, **kwargs) +# return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) +# # def twins_svt_large_context(pretrained=False, **kwargs): +# # model_kwargs = dict( +# # patch_size=4, embed_dims=[128, 256], num_heads=[4, 8], mlp_ratios=[4, 4], +# # depths=[2, 2], wss=[7, 7], sr_ratios=[8, 4], in_chans=6, init_weight=False, **kwargs) +# # return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) diff --git a/modules/components/amt_flowformer/blocks/utils.py b/modules/components/amt_flowformer/blocks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f001b224ae7923fd22bf02269a5f259e0e163b9 --- /dev/null +++ b/modules/components/amt_flowformer/blocks/utils.py @@ -0,0 +1,101 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + elif mode == 'kitti400': + self._pad = [0, 0, 0, 400 - self.ht] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def indexing(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + """ + TODO: directly indexing features instead of sampling + """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True, mode='nearest') + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/modules/components/amt_flowformer/blocks/warp.py b/modules/components/amt_flowformer/blocks/warp.py new file mode 100644 index 0000000000000000000000000000000000000000..89c63449c52bc12b73cc94b29c1d96a305365270 --- /dev/null +++ b/modules/components/amt_flowformer/blocks/warp.py @@ -0,0 +1,13 @@ +import torch +import torch.nn.functional as F + + +def warp(img, flow): + B, _, H, W = flow.shape + xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) + yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) + grid = torch.cat([xx, yy], 1).to(img) + flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) + grid_ = (grid + flow_).permute(0, 2, 3, 1) + output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) + return output diff --git a/modules/components/amt_splat/AMT.py b/modules/components/amt_splat/AMT.py new file mode 100644 index 0000000000000000000000000000000000000000..e46f41db23047478f207d07eaade477d14dc34a8 --- /dev/null +++ b/modules/components/amt_splat/AMT.py @@ -0,0 +1,246 @@ +import torch +import torch.nn as nn +from .blocks.warp import warp +from .blocks.raft import ( + coords_grid, + SmallUpdateBlock, BidirCorrBlock, BasicUpdateBlock +) +from .blocks.feat_enc import ( + SmallEncoder, + BasicEncoder, + LargeEncoder +) +from .blocks.ifrnet import ( + resize, + Encoder, + InitDecoder, + IntermediateDecoder +) +from .blocks.multi_flow import ( + multi_flow_combine, + MultiFlowDecoder +) + +from ..components import register + +from utils.padder import InputPadder + + +def photometric_consistency(img0, img1, flow01): + return (img0 - warp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + warp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 +gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) +gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + + +def gaussian(x): + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + +@register('amt_splat') +class Model(nn.Module): + def __init__(self, + model_size='S', + corr_radius=3, + corr_lvls=4, + num_flows=3, + channels=[20, 32, 44, 56], + skip_channels=20, + scale_factor=1): + super(Model, self).__init__() + self.model_size = model_size + self.radius = corr_radius + self.corr_levels = corr_lvls + self.num_flows = num_flows + self.channels = channels + self.skip_channels = skip_channels + self.scale_factor = scale_factor + if self.model_size == 'S': + self.feat_encoder = SmallEncoder(output_dim=84, norm_fn='instance', dropout=0.) + elif self.model_size == 'L': + self.feat_encoder = BasicEncoder(output_dim=128, norm_fn='instance', dropout=0.) + elif self.model_size == 'G': + self.feat_encoder = LargeEncoder(output_dim=128, norm_fn='instance', dropout=0.) + self.encoder = Encoder(channels, large=True) + + # self.decoder4 = InitDecoder(channels[3], channels[2], skip_channels) + self.decoder3 = IntermediateDecoder(channels[2], channels[1], skip_channels) + self.decoder2 = IntermediateDecoder(channels[1] * 2, channels[0], skip_channels) + self.decoder1 = MultiFlowDecoder(channels[0] * 2, skip_channels, num_flows) + + self.update4 = self._get_updateblock(channels[2]) + self.update3_low = self._get_updateblock(channels[1] * 2, 2) + self.update2_low = self._get_updateblock(channels[0] * 2, 4) + + if self.model_size == 'G': + self.update3_high = self._get_updateblock(channels[1] * 2, None) + self.update2_high = self._get_updateblock(channels[0] * 2, None) + # self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + # self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + # self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + # self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + # self.comb_block = nn.Sequential( + # nn.Conv2d(3 * self.num_flows, 6 * self.num_flows, 7, 1, 3), + # nn.PReLU(6 * self.num_flows), + # nn.Conv2d(6 * self.num_flows, 3, 7, 1, 3), + # ) + + def _get_updateblock(self, cdim, scale_factor=None): + return BasicUpdateBlock(cdim=cdim, hidden_dim=192, flow_dim=64, + corr_dim=256, corr_dim2=192, fc_dim=188, + scale_factor=scale_factor, corr_levels=self.corr_levels, + radius=self.radius) + + def _corr_scale_lookup(self, corr_fn, coord, flow_fwd, flow_bwd, embt, downsample=1): + # convert t -> 0 to 0 -> 1 | convert t -> 1 to 1 -> 0 + # based on linear assumption + t1_scale = 1. / embt + t0_scale = 1. / (1. - embt) + if downsample != 1: + inv = 1 / downsample + flow_fwd = inv * resize(flow_fwd, scale_factor=inv) + flow_bwd = inv * resize(flow_bwd, scale_factor=inv) + + corr_fwd, corr_bwd = corr_fn(coord + flow_fwd, coord + flow_bwd) + return corr_fwd, corr_bwd, flow_fwd, flow_bwd + + def get_splat_weight(self, img0, img1, flow01, flow10): + M_splat = 1 / (1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + + + def forward(self, img0, img1, time_step, scale_factor=1.0, eval=False, **kwargs): + scale_factor = self.scale_factor + padder = InputPadder(img0.shape, divisor=int(16 / scale_factor)) + img0, img1 = padder.pad(img0, img1) + mean_ = torch.cat([img0, img1], 2).mean(1, keepdim=True).mean(2, keepdim=True).mean(3, keepdim=True) + img0 = img0 - mean_ + img1 = img1 - mean_ + img0_ = resize(img0, scale_factor) if scale_factor != 1.0 else img0 + img1_ = resize(img1, scale_factor) if scale_factor != 1.0 else img1 + b, _, h, w = img0_.shape + coords = coords_grid(b, h // 8, w // 8, img0.device) + flow_fwd_4, flow_bwd_4 = torch.zeros(b, 2, h // 8, w // 8).cuda(), torch.zeros(b, 2, h // 8, w // 8).cuda() + + fmap0, fmap1 = self.feat_encoder([img0_, img1_]) # [1, 128, H//8, W//8] + corr_fn = BidirCorrBlock(fmap0, fmap1, radius=self.radius, num_levels=self.corr_levels) + + # f0_1: [1, c0, H//2, W//2] | f0_2: [1, c1, H//4, W//4] + # f0_3: [1, c2, H//8, W//8] | f0_4: [1, c3, H//16, W//16] + f0_1, f0_2, f0_3 = self.encoder(img0_) + f1_1, f1_2, f1_3 = self.encoder(img1_) + + ######################################### the 4th decoder ######################################### + corr_fwd_4, corr_bwd_4, _, _ = self._corr_scale_lookup(corr_fn, coords, flow_fwd_4, flow_bwd_4, time_step) + + # residue update with lookup corr + delta_f0_3_, delta_flow_fwd_4 = self.update4(f0_3, flow_fwd_4, corr_fwd_4) + delta_f1_3_, delta_flow_bwd_4 = self.update4(f0_3, flow_bwd_4, corr_bwd_4) + up_f0_3 = f0_3 + delta_f0_3_ + up_f1_3 = f1_3 + delta_f1_3_ + flow_fwd_4 = flow_fwd_4 + delta_flow_fwd_4 + flow_bwd_4 = flow_bwd_4 + delta_flow_bwd_4 + + ######################################### the 3rd decoder ######################################### + flow_fwd_3, flow_bwd_3, f0_2_, f1_2_ = self.decoder3(up_f0_3, up_f1_3, flow_fwd_4, flow_bwd_4) + corr_fwd_3, corr_bwd_3, flow_fwd_3_, flow_bwd_3_ = self._corr_scale_lookup(corr_fn, + coords, flow_fwd_3, flow_bwd_3, + time_step, downsample=2) + + # residue update with lookup corr + f0_2 = torch.cat([f0_2, f0_2_], dim=1) + f1_2 = torch.cat([f1_2, f1_2_], dim=1) + delta_f0_2_, delta_flow_fwd_3 = self.update3_low(f0_2, flow_fwd_3_, corr_fwd_3) + delta_f1_2_, delta_flow_bwd_3 = self.update3_low(f1_2, flow_bwd_3_, corr_bwd_3) + f0_2 = f0_2 + delta_f0_2_ + f1_2 = f1_2 + delta_f1_2_ + flow_fwd_3 = flow_fwd_3 + delta_flow_fwd_3 + flow_bwd_3 = flow_bwd_3 + delta_flow_bwd_3 + + if self.model_size == 'G': + # residue update with lookup corr (hr) + corr_fwd_3 = resize(corr_fwd_3, scale_factor=2.0) + corr_bwd_3 = resize(corr_bwd_3, scale_factor=2.0) + delta_f0_2_, delta_flow_fwd_3 = self.update3_high(f0_2, flow_fwd_3, corr_fwd_3) + delta_f1_2_, delta_flow_bwd_3 = self.update3_high(f1_2, flow_bwd_3, corr_bwd_3) + up_f0_2 = f0_2 + delta_f0_2_ + up_f1_2 = f1_2 + delta_f1_2_ + flow_fwd_3 = flow_fwd_3 + delta_flow_fwd_3 + flow_bwd_3 = flow_bwd_3 + delta_flow_bwd_3 + + ######################################### the 2nd decoder ######################################### + flow_fwd_2, flow_bwd_2, f0_1_, f1_1_ = self.decoder2(up_f0_2, up_f1_2, flow_fwd_3, flow_bwd_3) + corr_fwd_2, corr_bwd_2, flow_fwd_2_, flow_bwd_2_ = self._corr_scale_lookup(corr_fn, + coords, flow_fwd_2, flow_bwd_2, + time_step, downsample=4) + + # residue update with lookup corr + f0_1 = torch.cat([f0_1, f0_1_], dim=1) + f1_1 = torch.cat([f1_1, f1_1_], dim=1) + delta_f0_1_, delta_flow_fwd_2 = self.update2_low(f0_1, flow_fwd_2_, corr_fwd_2) + delta_f1_1_, delta_flow_bwd_2 = self.update2_low(f1_1, flow_bwd_2_, corr_bwd_2) + f0_1 = f0_1 + delta_f0_1_ + f1_1 = f1_1 + delta_f1_1_ + flow_fwd_2 = flow_fwd_2 + delta_flow_fwd_2 + flow_bwd_2 = flow_bwd_2 + delta_flow_bwd_2 + if self.model_size == 'G': + # residue update with lookup corr (hr) + corr_fwd_2 = resize(corr_fwd_2, scale_factor=4.0) + corr_bwd_2 = resize(corr_bwd_2, scale_factor=4.0) + delta_f0_1_, delta_flow_fwd_2 = self.update2_high(f0_1, flow_fwd_2, corr_fwd_2) + delta_f1_1_, delta_flow_bwd_2 = self.update2_high(f1_1, flow_bwd_2, corr_bwd_2) + f0_1 = f0_1 + delta_f0_1_ + f1_1 = f1_1 + delta_f1_1_ + flow_fwd_2 = flow_fwd_2 + delta_flow_fwd_2 + flow_bwd_2 = flow_bwd_2 + delta_flow_bwd_2 + + ######################################### the 1st decoder ######################################### + flow_fwd_1, flow_bwd_1, mask_fwd, mask_bwd = self.decoder1(f0_1, f1_1, flow_fwd_2, flow_bwd_2) + + if scale_factor != 1.0: + flow_fwd_1 = resize(flow_fwd_1, scale_factor=(1.0 / scale_factor)) * (1.0 / scale_factor) + flow_bwd_1 = resize(flow_bwd_1, scale_factor=(1.0 / scale_factor)) * (1.0 / scale_factor) + mask_fwd = resize(mask_fwd, scale_factor=(1.0 / scale_factor)) + mask_bwd = resize(mask_bwd, scale_factor=(1.0 / scale_factor)) + + # Merge multiple predictions + # img0_ = img0.repeat(1, self.num_flows, 1, 1).view(b * self.num_flows, h, w) + # img1_ = img1.repeat(1, self.num_flows, 1, 1).view(b * self.num_flows, h, w) + # metric0 = self.get_splat_weight(img0_, img1_, flow_fwd_1_, flow_bwd_1_) + # metric1 = self.get_splat_weight(img1_, img0_, flow_bwd_1_, flow_fwd_1_) + imgt_pred = multi_flow_combine(img0, img1, flow_fwd_1, flow_bwd_1, + mask_fwd, mask_bwd, time_step, mean_) + imgt_pred = torch.clamp(imgt_pred, 0, 1) + imgt_pred = padder.unpad(imgt_pred) + + if eval: + return {'imgt_pred': imgt_pred, } + else: + flow_fwd_1 = flow_fwd_1.reshape(b, self.num_flows, 2, int(h / scale_factor), int(w / scale_factor)) + flow_bwd_1 = flow_bwd_1.reshape(b, self.num_flows, 2, int(h / scale_factor), int(w / scale_factor)) + return { + 'imgt_pred': imgt_pred, + 'flow0_pred': [flow_fwd_1 * 0.5, flow_fwd_2 * 0.5, flow_fwd_3 * 0.5, flow_fwd_4 * 0.5], + 'flow1_pred': [flow_bwd_1 * 0.5, flow_bwd_2 * 0.5, flow_bwd_3 * 0.5, flow_bwd_4 * 0.5], + 'flowfwd': flow_fwd_1[:, 0] * 0.5, + 'flowbwd': flow_bwd_1[:, 0] * 0.5 + } diff --git a/modules/components/amt_splat/__init__.py b/modules/components/amt_splat/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..387563589f3cc4a3dc664a5d8d4f5557b0100996 --- /dev/null +++ b/modules/components/amt_splat/__init__.py @@ -0,0 +1 @@ +from .AMT import Model \ No newline at end of file diff --git a/modules/components/amt_splat/__pycache__/AMT.cpython-310.pyc b/modules/components/amt_splat/__pycache__/AMT.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aed011f055b7bbc66711bc3adce98e74bca4c022 Binary files /dev/null and b/modules/components/amt_splat/__pycache__/AMT.cpython-310.pyc differ diff --git a/modules/components/amt_splat/__pycache__/AMT.cpython-38.pyc b/modules/components/amt_splat/__pycache__/AMT.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4db5ea49fc6e5a0662092e19c213775734a309a0 Binary files /dev/null and b/modules/components/amt_splat/__pycache__/AMT.cpython-38.pyc differ diff --git a/modules/components/amt_splat/__pycache__/AMT.cpython-39.pyc b/modules/components/amt_splat/__pycache__/AMT.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c368ac2b65705606b127ad085297d1d836e1d12a Binary files /dev/null and b/modules/components/amt_splat/__pycache__/AMT.cpython-39.pyc differ diff --git a/modules/components/amt_splat/__pycache__/__init__.cpython-310.pyc b/modules/components/amt_splat/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8003c9a8cd850c1f19045ce7310348297f2fdc3 Binary files /dev/null and b/modules/components/amt_splat/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/amt_splat/__pycache__/__init__.cpython-38.pyc b/modules/components/amt_splat/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..063fcaa74777fbae037cfaa37c6c3b2dfa958ecd Binary files /dev/null and b/modules/components/amt_splat/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/amt_splat/__pycache__/__init__.cpython-39.pyc b/modules/components/amt_splat/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09f38280d45f4ec2f39b609be0ec69b3ab502ee6 Binary files /dev/null and b/modules/components/amt_splat/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/amt_splat/blocks/__init__.py b/modules/components/amt_splat/blocks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-310.pyc b/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..930e07802add7adf04a1fc1fcce451a9b13a6163 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-38.pyc b/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..760407e81bf5cbabc2ce65b0368636e575df1053 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-39.pyc b/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ecc311e6f2e6988fe90fdaec45860fb625691794 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-310.pyc b/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28d83f751f8741ce9c87304719f8b20c2a8a8d9c Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-310.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-38.pyc b/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1c5a005613e7f6c4264168d184d742a1f5903ff Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-38.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-39.pyc b/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f116a54b27fc561c78ea6d61be6d750b5008b3f2 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/feat_enc.cpython-39.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-310.pyc b/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a36e8debc3c385641ef224f5225c3da350d2c45 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-310.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-38.pyc b/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5b206b7967e2793f0ed295f99e8a9aa59c6bf4b Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-38.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-39.pyc b/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b156689bc2ff596c17f7b472f7485a7ca63d695e Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/ifrnet.cpython-39.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-310.pyc b/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78500093ba2546ba9fccd15cba06bf0697b18fee Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-310.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-38.pyc b/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0337bd21e48294a0e6a764a627557cf8663f8da9 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-38.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-39.pyc b/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6337f26f200425cb96af30854e7d8ab747e95e6 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/multi_flow.cpython-39.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/raft.cpython-310.pyc b/modules/components/amt_splat/blocks/__pycache__/raft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f634029672b3beaa291cb4f5613146007dce8075 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/raft.cpython-310.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/raft.cpython-38.pyc b/modules/components/amt_splat/blocks/__pycache__/raft.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eea1523475f17b1833be36faf04da8640f6269f1 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/raft.cpython-38.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/raft.cpython-39.pyc b/modules/components/amt_splat/blocks/__pycache__/raft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da82df689d69012732fdd0de6a4cb7e16181c746 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/raft.cpython-39.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-310.pyc b/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0c9e7e33000012604d1fa66fd7082b4847b5c4 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-38.pyc b/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79b840bd5b569f99422d9b08ad21215000071c3f Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-39.pyc b/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..635f5652b4c58058d3bbe74dd1439bd021949f5f Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/warp.cpython-310.pyc b/modules/components/amt_splat/blocks/__pycache__/warp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6db15db24271074b357e73de46e389c1ca1facfe Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/warp.cpython-310.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/warp.cpython-38.pyc b/modules/components/amt_splat/blocks/__pycache__/warp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf55117e986220c14df4d5eddcf0f87cc2b069b8 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/warp.cpython-38.pyc differ diff --git a/modules/components/amt_splat/blocks/__pycache__/warp.cpython-39.pyc b/modules/components/amt_splat/blocks/__pycache__/warp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28da08621a764d72f294f27844b847debe21a842 Binary files /dev/null and b/modules/components/amt_splat/blocks/__pycache__/warp.cpython-39.pyc differ diff --git a/modules/components/amt_splat/blocks/feat_enc.py b/modules/components/amt_splat/blocks/feat_enc.py new file mode 100644 index 0000000000000000000000000000000000000000..3805bd315422703c19bf6a4d0962ee75002d92aa --- /dev/null +++ b/modules/components/amt_splat/blocks/feat_enc.py @@ -0,0 +1,343 @@ +import torch +import torch.nn as nn + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(72, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class LargeEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(LargeEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(112, stride=2) + self.layer3 = self._make_layer(160, stride=2) + self.layer3_2 = self._make_layer(160, stride=1) + + # output convolution + self.conv2 = nn.Conv2d(self.in_planes, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer3_2(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/modules/components/amt_splat/blocks/ifrnet.py b/modules/components/amt_splat/blocks/ifrnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c553f29e414fed125cac568d2bcadee1b15a7df4 --- /dev/null +++ b/modules/components/amt_splat/blocks/ifrnet.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .warp import warp + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True): + return nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias), + nn.PReLU(out_channels) + ) + + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out + + +class Encoder(nn.Module): + def __init__(self, channels, large=False): + super(Encoder, self).__init__() + self.channels = channels + prev_ch = 3 + for idx, ch in enumerate(channels, 1): + k = 7 if large and idx == 1 else 3 + p = 3 if k == 7 else 1 + self.register_module(f'pyramid{idx}', + nn.Sequential( + convrelu(prev_ch, ch, k, 2, p), + convrelu(ch, ch, 3, 1, 1) + )) + prev_ch = ch + + def forward(self, in_x): + fs = [] + for idx in range(len(self.channels)): + out_x = getattr(self, f'pyramid{idx + 1}')(in_x) + fs.append(out_x) + in_x = out_x + return fs + + +class InitDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch * 2 + 1, in_ch * 2), + ResBlock(in_ch * 2, skip_ch), + nn.ConvTranspose2d(in_ch * 2, out_ch + 4, 4, 2, 1, bias=True) + ) + + def forward(self, f0, f1, embt): + h, w = f0.shape[2:] + embt = embt.repeat(1, 1, h, w) + out = self.convblock(torch.cat([f0, f1, embt], 1)) + flow0, flow1 = torch.chunk(out[:, :4, ...], 2, 1) + ft_ = out[:, 4:, ...] + return flow0, flow1, ft_ + + +class IntermediateDecoder(nn.Module): + def __init__(self, in_ch, out_ch, skip_ch) -> None: + super().__init__() + self.convblock = nn.Sequential( + convrelu(in_ch * 2 + 2, in_ch * 2), + ResBlock(in_ch * 2, skip_ch), + nn.ConvTranspose2d(in_ch * 2, out_ch, 4, 2, 1, bias=True) + ) + self.conv_flow = nn.Conv2d(out_ch, 2, 3, 1, 1) + + def forward(self, f0, f1, flow_fwd, flow_bwd): + f0_warp = warp(f0, flow_bwd) + f1_warp = warp(f1, flow_fwd) + f0_in = torch.cat([f0, f1_warp, flow_fwd], 1) + f1_in = torch.cat([f1, f0_warp, flow_bwd], 1) + out0 = self.convblock(f0_in) + out1 = self.convblock(f1_in) + flow_fwd = 2.0 * resize(flow_fwd, scale_factor=2.0) + self.conv_flow(out0) + flow_bwd = 2.0 * resize(flow_bwd, scale_factor=2.0) + self.conv_flow(out1) + return flow_fwd, flow_bwd, out0, out1 diff --git a/modules/components/amt_splat/blocks/multi_flow.py b/modules/components/amt_splat/blocks/multi_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..a6c6d58632375c32a5ee99b140fbf9da61a7ab35 --- /dev/null +++ b/modules/components/amt_splat/blocks/multi_flow.py @@ -0,0 +1,88 @@ +import torch +import torch.nn as nn +from .warp import warp +from .ifrnet import ( + convrelu, resize, + ResBlock, +) +from .softsplat import _FunctionSoftsplat + + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = _FunctionSoftsplat.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOut = 0 + tenNormalize = 0 + for idx in range(flow_num): + tenOutF, tenNormalizeF = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB, tenNormalizeB = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOut += tenOutF + tenOutB + tenNormalize += tenNormalizeF + tenNormalizeB + + return tenOut / tenNormalize, tenNormalize < 0.00001 + +def multi_flow_combine(img0, img1, flow_fwd, flow_bwd, + mask_fwd, mask_bwd, time_step, mean=None): + ''' + A parallel implementation of multiple flow field warping + comb_block: An nn.Seqential object. + img shape: [b, c, h, w] + flow shape: [b, 2*num_flows, h, w] + mask (opt): + If 'mask' is None, the function conduct a simple average. + img_res (opt): + If 'img_res' is None, the function adds zero instead. + mean (opt): + If 'mean' is None, the function adds zero instead. + ''' + b, c, h, w = flow_fwd.shape + num_flows = c // 2 + fltTime = time_step.repeat(b, num_flows, 1, 1, 1).permute(1, 0, 2, 3, 4) + t0 = fltTime + t1 = 1.0 - fltTime + + flow_fwd = flow_fwd.reshape(b, num_flows, 2, h, w).permute(1, 0, 2, 3, 4).contiguous() + flow_bwd = flow_bwd.reshape(b, num_flows, 2, h, w).permute(1, 0, 2, 3, 4).contiguous() + + mask_fwd = mask_fwd.reshape(b, num_flows, 1, h, w).permute(1, 0, 2, 3, 4) + mask_bwd = mask_bwd.reshape(b, num_flows, 1, h, w).permute(1, 0, 2, 3, 4) + img0_ = torch.stack([img0] * num_flows, 1).reshape(b, num_flows, 3, h, w).permute(1, 0, 2, 3, 4) + img1_ = torch.stack([img1] * num_flows, 1).reshape(b, num_flows, 3, h, w).permute(1, 0, 2, 3, 4) + + imgt_pred, mask = forwarp_mframe_mask(img0_, flow_fwd, t0, img1_, flow_bwd, t1, mask_fwd, mask_bwd) + imgt_pred = imgt_pred + mask * (t1.mean(0) * img0 + t0.mean(0) * img1) + mean + return imgt_pred + + +class MultiFlowDecoder(nn.Module): + def __init__(self, in_ch, skip_ch, num_flows=3): + super(MultiFlowDecoder, self).__init__() + self.num_flows = num_flows + self.convblock = nn.Sequential( + convrelu(in_ch * 2 + 2, in_ch * 2), + ResBlock(in_ch * 2, skip_ch), + nn.UpsamplingBilinear2d(scale_factor=2), + nn.Conv2d(in_ch*2, 3*num_flows, 3, 1, 1, bias=True) + ) + + def forward(self, f0, f1, flow_fwd, flow_bwd): + n = self.num_flows + f0_warp = warp(f0, flow_bwd) + f1_warp = warp(f1, flow_fwd) + out0 = self.convblock(torch.cat([f0, f1_warp, flow_fwd], 1)) + out1 = self.convblock(torch.cat([f1, f0_warp, flow_bwd], 1)) + delta_flow_fwd, mask_fwd = torch.split(out0, [2 * n, n], 1) + delta_flow_bwd, mask_bwd = torch.split(out1, [2 * n, n], 1) + + flow_fwd = delta_flow_fwd + 2.0 * resize(flow_fwd, scale_factor=2.0).repeat(1, self.num_flows, 1, 1) + flow_bwd = delta_flow_bwd + 2.0 * resize(flow_bwd, scale_factor=2.0).repeat(1, self.num_flows, 1, 1) + + return flow_fwd, flow_bwd, mask_fwd, mask_bwd diff --git a/modules/components/amt_splat/blocks/raft.py b/modules/components/amt_splat/blocks/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..14e1c0ea4ebdec57680e6f702e4e63bdd7e21cba --- /dev/null +++ b/modules/components/amt_splat/blocks/raft.py @@ -0,0 +1,207 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def resize(x, scale_factor): + return F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + +def bilinear_sampler(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), + torch.arange(wd, device=device), + indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +class SmallUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, fc_dim, + corr_levels=4, radius=3, scale_factor=None): + super(SmallUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) **2 + self.scale_factor = scale_factor + + self.convc1 = nn.Conv2d(2 * cor_planes, corr_dim, 1, padding=0) + self.convf1 = nn.Conv2d(4, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(corr_dim+flow_dim, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+4+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 4, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + + return delta_net, delta_flow + + +class BasicUpdateBlock(nn.Module): + def __init__(self, cdim, hidden_dim, flow_dim, corr_dim, corr_dim2, + fc_dim, corr_levels=4, radius=3, scale_factor=None, out_num=1): + super(BasicUpdateBlock, self).__init__() + cor_planes = corr_levels * (2 * radius + 1) ** 2 + + self.scale_factor = scale_factor + self.convc1 = nn.Conv2d(cor_planes, corr_dim, 1, padding=0) + self.convc2 = nn.Conv2d(corr_dim, corr_dim2, 3, padding=1) + self.convf1 = nn.Conv2d(2, flow_dim*2, 7, padding=3) + self.convf2 = nn.Conv2d(flow_dim*2, flow_dim, 3, padding=1) + self.conv = nn.Conv2d(flow_dim+corr_dim2, fc_dim, 3, padding=1) + + self.gru = nn.Sequential( + nn.Conv2d(fc_dim+2+cdim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + ) + + self.feat_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, cdim, 3, padding=1), + ) + + self.flow_head = nn.Sequential( + nn.Conv2d(hidden_dim, hidden_dim, 3, padding=1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(hidden_dim, 2*out_num, 3, padding=1), + ) + + self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) + + def forward(self, net, flow, corr): + net = resize(net, 1 / self.scale_factor + ) if self.scale_factor is not None else net + cor = self.lrelu(self.convc1(corr)) + cor = self.lrelu(self.convc2(cor)) + flo = self.lrelu(self.convf1(flow)) + flo = self.lrelu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + inp = self.lrelu(self.conv(cor_flo)) + inp = torch.cat([inp, flow, net], dim=1) + + out = self.gru(inp) + delta_net = self.feat_head(out) + delta_flow = self.flow_head(out) + + if self.scale_factor is not None: + delta_net = resize(delta_net, scale_factor=self.scale_factor) + delta_flow = self.scale_factor * resize(delta_flow, scale_factor=self.scale_factor) + return delta_net, delta_flow + + +class BidirCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + self.corr_pyramid_T = [] + + corr = BidirCorrBlock.corr(fmap1, fmap2) + batch, h1, w1, dim, h2, w2 = corr.shape + corr_T = corr.clone().permute(0, 4, 5, 3, 1, 2) + + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + corr_T = corr_T.reshape(batch*h2*w2, dim, h1, w1) + + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + for _ in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + corr_T = F.avg_pool2d(corr_T, 2, stride=2) + self.corr_pyramid.append(corr) + self.corr_pyramid_T.append(corr_T) + + def __call__(self, coords0, coords1): + r = self.radius + coords0 = coords0.permute(0, 2, 3, 1) + coords1 = coords1.permute(0, 2, 3, 1) + assert coords0.shape == coords1.shape, f"coords0 shape: [{coords0.shape}] is not equal to [{coords1.shape}]" + batch, h1, w1, _ = coords0.shape + + out_pyramid = [] + out_pyramid_T = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + corr_T = self.corr_pyramid_T[i] + + dx = torch.linspace(-r, r, 2*r+1, device=coords0.device) + dy = torch.linspace(-r, r, 2*r+1, device=coords0.device) + delta = torch.stack(torch.meshgrid(dy, dx, indexing='ij'), axis=-1) + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + + centroid_lvl_0 = coords0.reshape(batch*h1*w1, 1, 1, 2) / 2**i + centroid_lvl_1 = coords1.reshape(batch*h1*w1, 1, 1, 2) / 2**i + coords_lvl_0 = centroid_lvl_0 + delta_lvl + coords_lvl_1 = centroid_lvl_1 + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl_0) + corr_T = bilinear_sampler(corr_T, coords_lvl_1) + corr = corr.view(batch, h1, w1, -1) + corr_T = corr_T.view(batch, h1, w1, -1) + out_pyramid.append(corr) + out_pyramid_T.append(corr_T) + + out = torch.cat(out_pyramid, dim=-1) + out_T = torch.cat(out_pyramid_T, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float(), out_T.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) \ No newline at end of file diff --git a/modules/components/amt_splat/blocks/softsplat.py b/modules/components/amt_splat/blocks/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..77967f24cd1eeee56417d1de2c88369d13b883c6 --- /dev/null +++ b/modules/components/amt_splat/blocks/softsplat.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Softsplat_updateOutput = ''' + extern "C" __global__ void kernel_Softsplat_updateOutput( + const int n, + const float* input, + const float* flow, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); + } + } } +''' + +kernel_Softsplat_updateGradInput = ''' + extern "C" __global__ void kernel_Softsplat_updateGradInput( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); + const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); + const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); + const int intX = ( intIndex ) % SIZE_3(gradInput); + + float fltGradInput = 0.0; + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + gradInput[intIndex] = fltGradInput; + } } +''' + +kernel_Softsplat_updateGradFlow = ''' + extern "C" __global__ void kernel_Softsplat_updateGradFlow( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float fltGradFlow = 0.0; + + const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); + const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); + const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); + const int intX = ( intIndex ) % SIZE_3(gradFlow); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = 0.0; + float fltNortheast = 0.0; + float fltSouthwest = 0.0; + float fltSoutheast = 0.0; + + if (intC == 0) { + fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY ); + fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY ); + fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); + fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (-1.0)); + fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); + fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * ((float) (+1.0)); + fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { + float fltInput = VALUE_4(input, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; + } + } + + gradFlow[intIndex] = fltGradFlow; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + + return strKernel + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +class _FunctionSoftsplat(torch.autograd.Function): + @staticmethod + def forward(self, input, flow): + self.save_for_backward(input, flow) + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(input.is_contiguous() == True) + assert(flow.is_contiguous() == True) + + output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) + + if input.is_cuda == True: + n = output.nelement() + cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { + 'input': input, + 'flow': flow, + 'output': output + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + return output + + + @staticmethod + def backward(self, gradOutput): + input, flow = self.saved_tensors + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(gradOutput.is_contiguous() == True) + + gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])\ + if self.needs_input_grad[0] == True else None + gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ])\ + if self.needs_input_grad[1] == True else None + + if input.is_cuda == True: + if gradInput is not None: + n = gradInput.nelement() + cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] + ) + + if gradFlow is not None: + n = gradFlow.nelement() + cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + + return gradInput, gradFlow + + +def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): + assert(tenMetric is None or tenMetric.shape[1] == 1) + assert(strType in ['summation', 'average', 'linear', 'softmax']) + + if strType == 'average': + tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) + + elif strType == 'linear': + tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) + + elif strType == 'softmax': + tenInput = torch.cat([ tenInput * tenMetric.clip(-20, 20).exp(), tenMetric.clip(-20, 20).exp() ], 1) + + + tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) + + if strType != 'summation': + tenNormalize = tenOutput[:, -1:, :, :] + + tenNormalize[tenNormalize == 0.0] = 1.0 + + tenOutput = tenOutput[:, :-1, :, :] / tenNormalize + + return tenOutput + + +class ModuleSoftsplat(torch.nn.Module): + def __init__(self, strType): + super(ModuleSoftsplat, self).__init__() + + self.strType = strType + + def forward(self, tenInput, tenFlow, tenMetric): + return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) + diff --git a/modules/components/amt_splat/blocks/warp.py b/modules/components/amt_splat/blocks/warp.py new file mode 100644 index 0000000000000000000000000000000000000000..89c63449c52bc12b73cc94b29c1d96a305365270 --- /dev/null +++ b/modules/components/amt_splat/blocks/warp.py @@ -0,0 +1,13 @@ +import torch +import torch.nn.functional as F + + +def warp(img, flow): + B, _, H, W = flow.shape + xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1) + yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W) + grid = torch.cat([xx, yy], 1).to(img) + flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1) + grid_ = (grid + flow_).permute(0, 2, 3, 1) + output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True) + return output diff --git a/modules/components/components.py b/modules/components/components.py new file mode 100644 index 0000000000000000000000000000000000000000..dd5b8979f109f74deda698854f63bfb590ea71cc --- /dev/null +++ b/modules/components/components.py @@ -0,0 +1,21 @@ +import copy + + +components = {} + + +def register(name): + def decorator(cls): + components[name] = cls + return cls + return decorator + + +def make_components(model_spec, args=None): + if args is not None: + model_args = copy.deepcopy(model_spec['args']) + model_args.update(args) + else: + model_args = model_spec['args'] + model = components[model_spec['name']](**model_args) + return model diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__init__.py b/modules/components/m2m_flow_former/LatentCostFormer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6a1e4881996d19a08a5c7a1d57edc9f43e92832 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e67b6d413d51f0c3891355940b74f60ad7d0c67a Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fe55eeebd41d8fed0eb1e29c3b6501e5dd99409 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3d788e39b05e57ce769e11309189f6ce2a037ed Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fb8987a732409561894cb7a385c0b951e7bfc74 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36b035bb1ee6c28a817e8bb9a71bbf55416be5d6 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/attention.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bb2df26aa0ead73b0274a68327854949014ab33 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1271e407413b01a1256b7a466ad5bd0fdf7ddb00 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2912509e84c3b03f9db680fa15405f52c32df104 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/cnn.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6811e4853c2175258086082612f84f6fb9240282 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a805bf086aa1d7d405f43ab1339f70ffc5a61ab Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..861b21bb30b49a0f60b7577211719289f5bc7142 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/common.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d2e7c2523b21bd4d608eade4766a3ac27504ffa Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6aa88974c4dd6ea35a1147769a526e1613394b26 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e91677a9d1bd1584daabde415727bbccd4d5adbb Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/convnext.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b3a759174e96f05723b6bd441368c409d6633bb Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92dd5cc2440218b4c63169e3abf0f4cdeb11754c Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d60a9468571787eebf2a9328b1243f75964a514b Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/decoder.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0870b9df6998a8a7b729612e66eafe1cfebe2223 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31697bbddbfc568d6a5dcaead166f8699fc665ef Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cabcadbe43a50e390677bbd97ded4fb71fa7d255 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoder.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..deb1d8a48e09947915319fb3f55af50344abe06a Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eee65ff169b776e5facee5d56be69377320c079a Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f89c3beb4d17af4053a12e9b242f911594065e38 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/encoders.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05b096a14028d3272ab589b75db5fe6775fdaf34 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60441cd00308398394eb08b36adde14760dfc138 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..178f6168a883d5b8348b7231812a6aed0afd08e4 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gma.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11f6c0f95d46c90d7147cc8c4e4f9c597c8d61e3 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e54635c76e280578f6b4877dcd5168c9fbb98316 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2556050524cc2600c5fef30e59864e469b904e37 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/gru.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae411bf93f395535b6748c59fed3f494bcd729f Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7142c7d1c06136ac4e74ad4e920bc406b31777d Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68c55992c9fc37fa881fa278c30e2673e4c42920 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/mlpmixer.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb9717bc8955688632fb9baa8e0530baaeb859fa Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d004898fb850d0ad65482909c24bb4c400157323 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8486f2016b55e8270cacfd723056afc6c8871498 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/position_encoding.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1cf03f3ab71ba7af45965a56ac22432ec75ebf25 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19d746ff05152edf0b88733e51bb07a9c867aa2a Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9a898e9978eb0127bf7f11750702b24e786f323 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/transformer.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a92de102b0a9e179d50bd7cb440b50d317010a4 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63c48c2c9eaaac8a22cef4ac8687330ff450e7c7 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9065b8bf303d94bc11892a8eb33433bf9e4eeb4 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/twins.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-310.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..404a3ecafc6cb70bd651313cb3b4409f7f032324 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-38.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..655fa68f4305fdc8f5087446ddc6748ac6096385 Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-39.pyc b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..26adf0c39d25b4fe3b49d2b05a339016a86ecadc Binary files /dev/null and b/modules/components/m2m_flow_former/LatentCostFormer/__pycache__/utils.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/LatentCostFormer/attention.py b/modules/components/m2m_flow_former/LatentCostFormer/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3b3625ea48fd1e8584c387f1a1a22f236dedb6c3 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/attention.py @@ -0,0 +1,160 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +class BroadMultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(BroadMultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q.squeeze(), 'i (heads d) -> heads i d', heads=self.heads) + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) + + dots = einsum('hid, bhjd -> bhij', Q, K) * self.scale # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, _, _ = K.shape + _, N, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads n d -> b n (heads d)', b=B, n=N) + + return out + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K): + Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) + + dots = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # (b hw) heads 1 pointnum + + return self.attend(dots) + + def forward(self, Q, K, V): + attn = self.attend_with_rpe(Q, K) + B, HW, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + + return out + +# class MultiHeadAttentionRelative_encoder(nn.Module): +# def __init__(self, dim, heads): +# super(MultiHeadAttentionRelative, self).__init__() +# self.dim = dim +# self.heads = heads +# self.scale = (dim/heads) ** -0.5 +# self.attend = nn.Softmax(dim=-1) + +# def attend_with_rpe(self, Q, K, Q_r, K_r): +# """ +# Q: [BH1W1, H3W3, dim] +# K: [BH1W1, H3W3, dim] +# Q_r: [BH1W1, H3W3, H3W3, dim] +# K_r: [BH1W1, H3W3, H3W3, dim] +# """ + +# Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] +# Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + +# # context-context similarity +# c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads H3W3 H3W3] +# # context-position similarity +# c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3] +# # position-context similarity +# p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:]) +# p_c = torch.squeeze(p_c, dim=4) +# p_c = p_c.permute(0, 1, 3, 2) +# dots = c_c + c_p + p_c +# return self.attend(dots) + +# def forward(self, Q, K, V, Q_r, K_r): +# attn = self.attend_with_rpe(Q, K, Q_r, K_r) +# B, HW, _ = Q.shape + +# V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + +# out = einsum('bhij, bhjd -> bhid', attn, V) +# out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + +# return out + +class MultiHeadAttentionRelative(nn.Module): + def __init__(self, dim, heads): + super(MultiHeadAttentionRelative, self).__init__() + self.dim = dim + self.heads = heads + self.scale = (dim/heads) ** -0.5 + self.attend = nn.Softmax(dim=-1) + + def attend_with_rpe(self, Q, K, Q_r, K_r): + """ + Q: [BH1W1, 1, dim] + K: [BH1W1, H3W3, dim] + Q_r: [BH1W1, H3W3, dim] + K_r: [BH1W1, H3W3, dim] + """ + + Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) # [BH1W1, heads, 1, dim] + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + K_r = rearrange(K_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + Q_r = rearrange(Q_r, 'b j (heads d) -> b heads j d', heads=self.heads) # [BH1W1, heads, H3W3, dim] + + # context-context similarity + c_c = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # [(B H1W1) heads 1 H3W3] + # context-position similarity + c_p = einsum('bhid, bhjd -> bhij', Q, K_r) * self.scale # [(B H1W1) heads 1 H3W3] + # position-context similarity + p_c = einsum('bhijd, bhikd -> bhijk', Q_r[:,:,:,None,:], K[:,:,:,None,:]) * self.scale + p_c = torch.squeeze(p_c, dim=4) + p_c = p_c.permute(0, 1, 3, 2) + dots = c_c + c_p + p_c + return self.attend(dots) + + def forward(self, Q, K, V, Q_r, K_r): + attn = self.attend_with_rpe(Q, K, Q_r, K_r) + B, HW, _ = Q.shape + + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + + return out + +def LinearPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1/200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim//4-1, dim//4).to(x.device) + return torch.cat([torch.sin(3.14*x[..., -2:-1]*freq_bands*NORMALIZE_FACOR), torch.cos(3.14*x[..., -2:-1]*freq_bands*NORMALIZE_FACOR), torch.sin(3.14*x[..., -1:]*freq_bands*NORMALIZE_FACOR), torch.cos(3.14*x[..., -1:]*freq_bands*NORMALIZE_FACOR)], dim=-1) + +def ExpPositionEmbeddingSine(x, dim=128, NORMALIZE_FACOR=1/200): + # 200 should be enough for a 8x downsampled image + # assume x to be [_, _, 2] + freq_bands = torch.linspace(0, dim//4-1, dim//4).to(x.device) + return torch.cat([torch.sin(x[..., -2:-1]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.cos(x[..., -2:-1]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.sin(x[..., -1:]*(NORMALIZE_FACOR * 2 ** freq_bands)), torch.cos(x[..., -1:]*(NORMALIZE_FACOR * 2 ** freq_bands))], dim=-1) \ No newline at end of file diff --git a/modules/components/m2m_flow_former/LatentCostFormer/cnn.py b/modules/components/m2m_flow_former/LatentCostFormer/cnn.py new file mode 100644 index 0000000000000000000000000000000000000000..47b184570c3cd771580c72c4009107a580612a3b --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/cnn.py @@ -0,0 +1,577 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +import math +import numpy as np + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, input_dim=3, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + mul = input_dim // 3 + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64 * mul) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64 * mul) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64 * mul) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(input_dim, 64 * mul, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 * mul + self.layer1 = self._make_layer(64 * mul, stride=1) + self.layer2 = self._make_layer(96 * mul, stride=2) + self.layer3 = self._make_layer(128 * mul, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128 * mul, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + +class ConvNets(nn.Module): + def __init__(self, in_dim, out_dim, inter_dim, depth, stride=1): + super(ConvNets, self).__init__() + + self.conv_first = nn.Conv2d(in_dim, inter_dim, kernel_size=3, padding=1, stride=stride) + self.conv_last = nn.Conv2d(inter_dim, out_dim, kernel_size=3, padding=1, stride=stride) + self.relu = nn.ReLU(inplace=True) + self.inter_convs = nn.ModuleList( + [ResidualBlock(inter_dim, inter_dim, norm_fn='none', stride=1) for i in range(depth)]) + + def forward(self, x): + x = self.relu(self.conv_first(x)) + for inter_conv in self.inter_convs: + x = inter_conv(x) + x = self.conv_last(x) + return x + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.motion_feature_dim + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicFuseMotion(nn.Module): + def __init__(self, args): + super(BasicFuseMotion, self).__init__() + cor_planes = args.motion_feature_dim + out_planes = args.query_latent_dim + + self.normf1 = nn.InstanceNorm2d(128) + self.normf2 = nn.InstanceNorm2d(128) + + self.convf1 = nn.Conv2d(2, 128, 3, padding=1) + self.convf2 = nn.Conv2d(128, 128, 3, padding=1) + self.convf3 = nn.Conv2d(128, 64, 3, padding=1) + + s = 1 + self.normc1 = nn.InstanceNorm2d(256*s) + self.normc2 = nn.InstanceNorm2d(256*s) + self.normc3 = nn.InstanceNorm2d(256*s) + + self.convc1 = nn.Conv2d(cor_planes+128, 256*s, 1, padding=0) + self.convc2 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.convc3 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.convc4 = nn.Conv2d(256*s, 256*s, 3, padding=1) + self.conv = nn.Conv2d(256*s + 64, out_planes, 1, padding=0) + + def forward(self, flow, feat, context1=None): + flo = F.relu(self.normf1(self.convf1(flow))) + flo = F.relu(self.normf2(self.convf2(flo))) + flo = self.convf3(flo) + + feat = torch.cat([feat, context1], dim=1) + feat = F.relu(self.normc1(self.convc1(feat))) + feat = F.relu(self.normc2(self.convc2(feat))) + feat = F.relu(self.normc3(self.convc3(feat))) + feat = self.convc4(feat) + + feat = torch.cat([flo, feat], dim=1) + feat = F.relu(self.conv(feat)) + + return feat + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + +class DirectMeanMaskPredictor(nn.Module): + def __init__(self, args): + super(DirectMeanMaskPredictor, self).__init__() + self.flow_head = FlowHead(args.predictor_dim, hidden_dim=256) + self.mask = nn.Sequential( + nn.Conv2d(args.predictor_dim, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, motion_features): + delta_flow = self.flow_head(motion_features) + mask = .25 * self.mask(motion_features) + + return mask, delta_flow + +class BaiscMeanPredictor(nn.Module): + def __init__(self, args, hidden_dim=128): + super(BaiscMeanPredictor, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, latent, flow): + motion_features = self.encoder(flow, latent) + delta_flow = self.flow_head(motion_features) + mask = .25 * self.mask(motion_features) + + return mask, delta_flow + +class BasicRPEEncoder(nn.Module): + def __init__(self, args): + super(BasicRPEEncoder, self).__init__() + self.args = args + dim = args.query_latent_dim + self.encoder = nn.Sequential( + nn.Linear(2, dim // 2), + nn.ReLU(inplace=True), + nn.Linear(dim // 2, dim), + nn.ReLU(inplace=True), + nn.Linear(dim, dim) + ) + + def forward(self, rpe_tokens): + return self.encoder(rpe_tokens) + +from .twins import Block, CrossBlock + +class TwinsSelfAttentionLayer(nn.Module): + def __init__(self, args): + super(TwinsSelfAttentionLayer, self).__init__() + self.args = args + embed_dim = 256 + num_heads = 8 + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = 0. + attn_drop_rate=0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True) + self.global_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward(self, x, tgt, size): + x = self.local_block(x, size) + x = self.global_block(x, size) + + tgt = self.local_block(tgt, size) + tgt = self.global_block(tgt, size) + return x, tgt + +class TwinsCrossAttentionLayer(nn.Module): + def __init__(self, args): + super(TwinsCrossAttentionLayer, self).__init__() + self.args = args + embed_dim = 256 + num_heads = 8 + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = 0. + attn_drop_rate=0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True) + self.global_block = CrossBlock(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward(self, x, tgt, size): + x = self.local_block(x, size) + tgt = self.local_block(tgt, size) + x, tgt = self.global_block(x, tgt, size) + + return x, tgt diff --git a/modules/components/m2m_flow_former/LatentCostFormer/common.py b/modules/components/m2m_flow_former/LatentCostFormer/common.py new file mode 100644 index 0000000000000000000000000000000000000000..1c4a4d1c513bf15af66512242e0645b496c1ccdb --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/common.py @@ -0,0 +1,424 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +from .utils import coords_grid, bilinear_sampler, indexing +from loguru import logger + +import math + +def nerf_encoding(x, L=6, NORMALIZE_FACOR=1/300): + """ + x is of shape [*, 2]. The last dimension are two coordinates (x and y). + """ + freq_bands = 2.** torch.linspace(0, L, L-1).to(x.device) + return torch.cat([x*NORMALIZE_FACOR, torch.sin(3.14*x[..., -2:-1]*freq_bands*NORMALIZE_FACOR), torch.cos(3.14*x[..., -2:-1]*freq_bands*NORMALIZE_FACOR), torch.sin(3.14*x[..., -1:]*freq_bands*NORMALIZE_FACOR), torch.cos(3.14*x[..., -1:]*freq_bands*NORMALIZE_FACOR)], dim=-1) + +def sampler_gaussian(latent, mean, std, image_size, point_num=25): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange(latent, 'b (h w) c -> b c h w', h=H, w=W) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(mean.device) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = F.sigmoid(std.permute(0, 2, 3, 1).reshape(B*HW, 1, 1, 1)) * STD_MAX * delta * 3 # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B*H*W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, '(b h w) r1 r2 c -> b (h w) (r1 r2) c', b=B, h=H, w=W) + sampled_latents = bilinear_sampler(latent, coords) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) + + return sampled_latents, sampled_weights + +def sampler_gaussian_zy(latent, mean, std, image_size, point_num=25, return_deltaXY=False, beta=1): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + latent = rearrange(latent, 'b (h w) c -> b c h w', h=H, w=W) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(mean.device) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = std.permute(0, 2, 3, 1).reshape(B*HW, 1, 1, 1) * delta * 3 # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B*H*W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, '(b h w) r1 r2 c -> b (h w) (r1 r2) c', b=B, h=H, w=W) + sampled_latents = bilinear_sampler(latent, coords) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / beta + + if return_deltaXY: + return sampled_latents, sampled_weights, delta_3sigma + else: + return sampled_latents, sampled_weights + +def sampler_gaussian(latent, mean, std, image_size, point_num=25, return_deltaXY=False): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # std [B, 1, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange(latent, 'b (h w) c -> b c h w', h=H, w=W) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(mean.device) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = F.sigmoid(std.permute(0, 2, 3, 1).reshape(B*HW, 1, 1, 1)) * STD_MAX * delta * 3 # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B*H*W, 1, 1, 2) + coords = centroid + delta_3sigma + + coords = rearrange(coords, '(b h w) r1 r2 c -> b (h w) (r1 r2) c', b=B, h=H, w=W) + sampled_latents = bilinear_sampler(latent, coords) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) + + if return_deltaXY: + return sampled_latents, sampled_weights, delta_3sigma + else: + return sampled_latents, sampled_weights + +def sampler_gaussian_fix(latent, mean, image_size, point_num=49): + # latent [B, H*W, D] + # mean [B, 2, H, W] + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange(latent, 'b (h w) c -> b c h w', h=H, w=W) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5)-1)/2) + + dx = torch.linspace(-radius, radius, 2*radius+1) + dy = torch.linspace(-radius, radius, 2*radius+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(mean.device) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B*H*W, 1, 1, 2) + coords = centroid + delta + + coords = rearrange(coords, '(b h w) r1 r2 c -> b (h w) (r1 r2) c', b=B, h=H, w=W) + sampled_latents = bilinear_sampler(latent, coords) # [B*H*W, dim, point_num**0.5, point_num**0.5] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return sampled_latents, sampled_weights + +def sampler_gaussian_fix_pyramid(latent, feat_pyramid, scale_weight, mean, image_size, point_num=25): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # scale weight [B, H*W, layer_num] + + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange(latent, 'b (h w) c -> b c h w', h=H, w=W) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5)-1)/2) + + dx = torch.linspace(-radius, radius, 2*radius+1) + dy = torch.linspace(-radius, radius, 2*radius+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(mean.device) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + for i in range(len(feat_pyramid)): + centroid = mean.reshape(B*H*W, 1, 1, 2) + coords = (centroid + delta) / 2**i + coords = rearrange(coords, '(b h w) r1 r2 c -> b (h w) (r1 r2) c', b=B, h=H, w=W) + sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords)) + + sampled_latents = torch.stack(sampled_latents, dim=1) # [B, layer_num, dim, H*W, point_num] + sampled_latents = sampled_latents.permute(0, 3, 4, 2, 1) # [B, H*W, point_num, dim, layer_num] + scale_weight = F.softmax(scale_weight, dim=2) # [B, H*W, layer_num] + vis_out = scale_weight + scale_weight = torch.unsqueeze(torch.unsqueeze(scale_weight, dim=2), dim=2) # [B, HW, 1, 1, layer_num] + + weighted_latent = torch.sum(sampled_latents*scale_weight, dim=-1) # [B, H*W, point_num, dim] + + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights, vis_out + +def sampler_gaussian_pyramid(latent, feat_pyramid, scale_weight, mean, std, image_size, point_num=25): + # latent [B, H*W, D] + # mean [B, 2, H, W] + # scale weight [B, H*W, layer_num] + + H, W = image_size + B, HW, D = latent.shape + STD_MAX = 20 + latent = rearrange(latent, 'b (h w) c -> b c h w', h=H, w=W) # latent = latent.view(B, H, W, D).permute(0, 3, 1, 2) + mean = mean.permute(0, 2, 3, 1) # [B, H, W, 2] + + radius = int((int(point_num**0.5)-1)/2) + + dx = torch.linspace(-1, 1, int(point_num**0.5)) + dy = torch.linspace(-1, 1, int(point_num**0.5)) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(mean.device) # [B*H*W, point_num**0.5, point_num**0.5, 2] + delta_3sigma = std.permute(0, 2, 3, 1).reshape(B*HW, 1, 1, 1) * delta * 3 # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + for i in range(len(feat_pyramid)): + centroid = mean.reshape(B*H*W, 1, 1, 2) + coords = (centroid + delta_3sigma) / 2**i + coords = rearrange(coords, '(b h w) r1 r2 c -> b (h w) (r1 r2) c', b=B, h=H, w=W) + sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords)) + + sampled_latents = torch.stack(sampled_latents, dim=1) # [B, layer_num, dim, H*W, point_num] + sampled_latents = sampled_latents.permute(0, 3, 4, 2, 1) # [B, H*W, point_num, dim, layer_num] + scale_weight = F.softmax(scale_weight, dim=2) # [B, H*W, layer_num] + vis_out = scale_weight + scale_weight = torch.unsqueeze(torch.unsqueeze(scale_weight, dim=2), dim=2) # [B, HW, 1, 1, layer_num] + + weighted_latent = torch.sum(sampled_latents*scale_weight, dim=-1) # [B, H*W, point_num, dim] + + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights, vis_out + +def sampler_gaussian_fix_MH(latent, mean, image_size, point_num=25): + """different heads have different mean""" + # latent [B, H*W, D] + # mean [B, 2, H, W, heands] + + H, W = image_size + B, HW, D = latent.shape + _, _, _, _, HEADS = mean.shape + STD_MAX = 20 + latent = rearrange(latent, 'b (h w) c -> b c h w', h=H, w=W) + mean = mean.permute(0, 2, 3, 4, 1) # [B, H, W, heads, 2] + + radius = int((int(point_num**0.5)-1)/2) + + dx = torch.linspace(-radius, radius, 2*radius+1) + dy = torch.linspace(-radius, radius, 2*radius+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(mean.device).repeat(HEADS, 1, 1, 1) # [HEADS, point_num**0.5, point_num**0.5, 2] + + centroid = mean.reshape(B*H*W, HEADS, 1, 1, 2) + coords = centroid + delta + coords = rearrange(coords, '(b h w) H r1 r2 c -> b (h w H) (r1 r2) c', b=B, h=H, w=W, H=HEADS) + sampled_latents = bilinear_sampler(latent, coords) # [B, dim, H*W*HEADS, pointnum] + sampled_latents = sampled_latents.permute(0, 2, 3, 1) # [B, H*W*HEADS, pointnum, dim] + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + return sampled_latents, sampled_weights + +def sampler_gaussian_fix_pyramid_MH(latent, feat_pyramid, scale_head_weight, mean, image_size, point_num=25): + # latent [B, H*W, D] + # mean [B, 2, H, W, heands] + # scale_head weight [B, H*W, layer_num*heads] + + H, W = image_size + B, HW, D = latent.shape + _, _, _, _, HEADS = mean.shape + + latent = rearrange(latent, 'b (h w) c -> b c h w', h=H, w=W) + mean = mean.permute(0, 2, 3, 4, 1) # [B, H, W, heads, 2] + + radius = int((int(point_num**0.5)-1)/2) + + dx = torch.linspace(-radius, radius, 2*radius+1) + dy = torch.linspace(-radius, radius, 2*radius+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(mean.device) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + sampled_latents = [] + centroid = mean.reshape(B*H*W, HEADS, 1, 1, 2) + for i in range(len(feat_pyramid)): + coords = (centroid ) / 2**i + delta + coords = rearrange(coords, '(b h w) H r1 r2 c -> b (h w H) (r1 r2) c', b=B, h=H, w=W, H=HEADS) + sampled_latents.append(bilinear_sampler(feat_pyramid[i], coords)) # [B, dim, H*W*HEADS, point_num] + + sampled_latents = torch.stack(sampled_latents, dim=1) # [B, layer_num, dim, H*W*HEADS, point_num] + sampled_latents = sampled_latents.permute(0, 3, 4, 2, 1) # [B, H*W*HEADS, point_num, dim, layer_num] + + scale_head_weight = scale_head_weight.reshape(B, H*W*HEADS, -1) + scale_head_weight = F.softmax(scale_head_weight, dim=2) # [B, H*W*HEADS, layer_num] + scale_head_weight = torch.unsqueeze(torch.unsqueeze(scale_head_weight, dim=2), dim=2) # [B, H*W*HEADS, 1, 1, layer_num] + + weighted_latent = torch.sum(sampled_latents*scale_head_weight, dim=-1) # [B, H*W*HEADS, point_num, dim] + + sampled_weights = -(torch.sum(delta.pow(2), dim=-1)) / point_num # smooth term + + return weighted_latent, sampled_weights + +def sampler(feat, center, window_size): + # feat [B, C, H, W] + # center [B, 2, H, W] + center = center.permute(0, 2, 3, 1) # [B, H, W, 2] + B, H, W, C = center.shape + + radius = window_size // 2 + dx = torch.linspace(-radius, radius, 2*radius+1) + dy = torch.linspace(-radius, radius, 2*radius+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(center.device) # [B*H*W, window_size, point_num**0.5, 2] + + center = center.reshape(B*H*W, 1, 1, 2) + coords = center + delta + + coords = rearrange(coords, '(b h w) r1 r2 c -> b (h w) (r1 r2) c', b=B, h=H, w=W) + sampled_latents = bilinear_sampler(feat, coords) # [B*H*W, dim, window_size, window_size] + # sampled_latents = sampled_latents.permute(0, 2, 3, 1) + + return sampled_latents + +def retrieve_tokens(feat, center, window_size, sampler): + # feat [B, C, H, W] + # center [B, 2, H, W] + radius = window_size // 2 + dx = torch.linspace(-radius, radius, 2*radius+1) + dy = torch.linspace(-radius, radius, 2*radius+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(center.device) # [B*H*W, point_num**0.5, point_num**0.5, 2] + + B, H, W, C = center.shape + centroid = center.reshape(B*H*W, 1, 1, 2) + coords = centroid + delta + + coords = rearrange(coords, '(b h w) r1 r2 c -> b (h w) (r1 r2) c', b=B, h=H, w=W) + if sampler == 'nn': + sampled_latents = indexing(feat, coords) + elif sampler == 'bilinear': + sampled_latents = bilinear_sampler(feat, coords) + else: + raise ValueError("invalid sampler") + # [B, dim, H*W, point_num] + + return sampled_latents + +def pyramid_retrieve_tokens(feat_pyramid, center, image_size, window_sizes, sampler='bilinear'): + center = center.permute(0, 2, 3, 1) # [B, H, W, 2] + sampled_latents_pyramid = [] + for idx in range(len(window_sizes)): + sampled_latents_pyramid.append( + retrieve_tokens( + feat_pyramid[idx], + center, + window_sizes[idx], + sampler + )) + center = center / 2 + + return torch.cat(sampled_latents_pyramid, dim=-1) + +class FeedForward(nn.Module): + def __init__(self, dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + x = self.net(x) + return x + +class MLP(nn.Module): + def __init__(self, in_dim=22, out_dim=1, innter_dim=96, depth=5): + super().__init__() + self.FC1 = nn.Linear(in_dim, innter_dim) + self.FC_out = nn.Linear(innter_dim, out_dim) + self.relu = torch.nn.LeakyReLU(0.2) + self.FC_inter = nn.ModuleList( + [nn.Linear(innter_dim, innter_dim) for i in range(depth)]) + + def forward(self, x): + x = self.FC1(x) + x = self.relu(x) + for inter_fc in self.FC_inter: + x = inter_fc(x) + x = self.relu(x) + x = self.FC_out(x) + return x + +class MultiHeadAttention(nn.Module): + def __init__(self, dim, heads, num_kv_tokens, cfg, rpe_bias=None, use_rpe=False): + super(MultiHeadAttention, self).__init__() + self.dim = dim + self.heads = heads + self.num_kv_tokens = num_kv_tokens + self.scale = (dim/heads) ** -0.5 + self.rpe = cfg.rpe + self.attend = nn.Softmax(dim=-1) + self.use_rpe = use_rpe + + if use_rpe: + if rpe_bias is None: + if self.rpe == 'element-wise': + self.rpe_bias = nn.Parameter(torch.zeros(heads, self.num_kv_tokens, dim // heads)) + elif self.rpe == 'head-wise': + self.rpe_bias = nn.Parameter(torch.zeros(1, heads, 1, self.num_kv_tokens)) + elif self.rpe == 'token-wise': + self.rpe_bias = nn.Parameter(torch.zeros(1, 1, 1, self.num_kv_tokens)) # 81 is point_num + elif self.rpe == 'implicit': + pass + # self.implicit_pe_fn = MLP(in_dim=22, out_dim=self.dim, innter_dim=int(self.dim//2.4), depth=2) + # raise ValueError('Implicit Encoding Not Implemented') + elif self.rpe == 'element-wise-value': + self.rpe_bias = nn.Parameter(torch.zeros(heads, self.num_kv_tokens, dim // heads)) + self.rpe_value = nn.Parameter(torch.randn(self.num_kv_tokens, dim)) + else: + raise ValueError('Not Implemented') + else: + self.rpe_bias = rpe_bias + + def attend_with_rpe(self, Q, K, rpe_bias): + Q = rearrange(Q, 'b i (heads d) -> b heads i d', heads=self.heads) + K = rearrange(K, 'b j (heads d) -> b heads j d', heads=self.heads) + + dots = einsum('bhid, bhjd -> bhij', Q, K) * self.scale # (b hw) heads 1 pointnum + if self.use_rpe: + if self.rpe == 'element-wise': + rpe_bias_weight = einsum('bhid, hjd -> bhij', Q, rpe_bias) * self.scale # (b hw) heads 1 pointnum + dots = dots + rpe_bias_weight + elif self.rpe == 'implicit': + pass + rpe_bias_weight = einsum('bhid, bhjd -> bhij', Q, rpe_bias) * self.scale # (b hw) heads 1 pointnum + dots = dots + rpe_bias_weight + elif self.rpe == 'head-wise' or self.rpe == 'token-wise': + dots = dots + rpe_bias + + return self.attend(dots), dots + + def forward(self, Q, K, V, rpe_bias = None): + if self.use_rpe: + if rpe_bias is None or self.rpe =='element-wise': + rpe_bias = self.rpe_bias + else: + rpe_bias = rearrange(rpe_bias, 'b hw pn (heads d) -> (b hw) heads pn d', heads=self.heads) + attn, dots = self.attend_with_rpe(Q, K, rpe_bias) + else: + attn, dots = self.attend_with_rpe(Q, K, None) + B, HW, _ = Q.shape + + if V is not None: + V = rearrange(V, 'b j (heads d) -> b heads j d', heads=self.heads) + + out = einsum('bhij, bhjd -> bhid', attn, V) + out = rearrange(out, 'b heads hw d -> b hw (heads d)', b=B, hw=HW) + else: + out = None + + # dots = torch.squeeze(dots, 2) + # dots = rearrange(dots, '(b hw) heads d -> b hw (heads d)', b=B, hw=HW) + + return out, dots diff --git a/modules/components/m2m_flow_former/LatentCostFormer/convnext.py b/modules/components/m2m_flow_former/LatentCostFormer/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..5114ee8191d9da78a3fac2195787c941c6f560aa --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/convnext.py @@ -0,0 +1,87 @@ +from turtle import forward +import torch +from torch import nn +import torch.nn.functional as F +import numpy as np + +class ConvNextLayer(nn.Module): + def __init__(self, dim, depth=4): + super().__init__() + self.net = nn.Sequential( + *[ConvNextBlock(dim=dim) for j in range(depth)] + ) + + def forward(self, x): + return self.net(x) + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + +class ConvNextBlock(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + def __init__(self, dim, layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + # self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + # print(f"conv next layer") + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + x + return x + + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x \ No newline at end of file diff --git a/modules/components/m2m_flow_former/LatentCostFormer/decoder.py b/modules/components/m2m_flow_former/LatentCostFormer/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..883ddc1f470467737a549946f780bbced423f4e1 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/decoder.py @@ -0,0 +1,260 @@ +import loguru +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +from .utils import coords_grid, bilinear_sampler, upflow8 +from .attention import MultiHeadAttention, LinearPositionEmbeddingSine, ExpPositionEmbeddingSine +from typing import Optional, Tuple + +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from .gru import BasicUpdateBlock, GMAUpdateBlock +from .gma import Attention + +def initialize_flow(img): + """ Flow is represented as difference between two means flow = mean1 - mean0""" + N, C, H, W = img.shape + mean = coords_grid(N, H, W).to(img.device) + mean_init = coords_grid(N, H, W).to(img.device) + + # optical flow computed as difference: flow = mean1 - mean0 + return mean, mean_init + +class CrossAttentionLayer(nn.Module): + # def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + def __init__(self, qk_dim, v_dim, query_token_dim, tgt_token_dim, add_flow_token=True, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0., pe='linear'): + super(CrossAttentionLayer, self).__init__() + + head_dim = qk_dim // num_heads + self.scale = head_dim ** -0.5 + self.query_token_dim = query_token_dim + self.pe = pe + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = MultiHeadAttention(qk_dim, num_heads) + self.q, self.k, self.v = nn.Linear(query_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, v_dim, bias=True) + + self.proj = nn.Linear(v_dim*2, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout) + ) + self.add_flow_token = add_flow_token + self.dim = qk_dim + def forward(self, query, key, value, memory, query_coord, patch_size, size_h3w3): + """ + query_coord [B, 2, H1, W1] + """ + B, _, H1, W1 = query_coord.shape + + if key is None and value is None: + key = self.k(memory) + value = self.v(memory) + + # [B, 2, H1, W1] -> [BH1W1, 1, 2] + query_coord = query_coord.contiguous() + query_coord = query_coord.view(B, 2, -1).permute(0, 2, 1)[:,:,None,:].contiguous().view(B*H1*W1, 1, 2) + if self.pe == 'linear': + query_coord_enc = LinearPositionEmbeddingSine(query_coord, dim=self.dim) + elif self.pe == 'exp': + query_coord_enc = ExpPositionEmbeddingSine(query_coord, dim=self.dim) + + short_cut = query + query = self.norm1(query) + + if self.add_flow_token: + q = self.q(query+query_coord_enc) + else: + q = self.q(query_coord_enc) + k, v = key, value + + x = self.multi_head_attn(q, k, v) + + x = self.proj(torch.cat([x, short_cut],dim=2)) + x = short_cut + self.proj_drop(x) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x, k, v + +class MemoryDecoderLayer(nn.Module): + def __init__(self, dim, cfg): + super(MemoryDecoderLayer, self).__init__() + self.cfg = cfg + self.patch_size = cfg.patch_size # for converting coords into H2', W2' space + + query_token_dim, tgt_token_dim = cfg.query_latent_dim, cfg.cost_latent_dim + qk_dim, v_dim = query_token_dim, query_token_dim + self.cross_attend = CrossAttentionLayer(qk_dim, v_dim, query_token_dim, tgt_token_dim, add_flow_token=cfg.add_flow_token, dropout=cfg.dropout) + + def forward(self, query, key, value, memory, coords1, size, size_h3w3): + """ + x: [B*H1*W1, 1, C] + memory: [B*H1*W1, H2'*W2', C] + coords1 [B, 2, H2, W2] + size: B, C, H1, W1 + 1. Note that here coords0 and coords1 are in H2, W2 space. + Should first convert it into H2', W2' space. + 2. We assume the upper-left point to be [0, 0], instead of letting center of upper-left patch to be [0, 0] + """ + x_global, k, v = self.cross_attend(query, key, value, memory, coords1, self.patch_size, size_h3w3) + B, C, H1, W1 = size + C = self.cfg.query_latent_dim + x_global = x_global.view(B, H1, W1, C).permute(0, 3, 1, 2) + return x_global, k, v + +class ReverseCostExtractor(nn.Module): + def __init__(self, cfg): + super(ReverseCostExtractor, self).__init__() + self.cfg = cfg + + def forward(self, cost_maps, coords0, coords1): + """ + cost_maps - B*H1*W1, cost_heads_num, H2, W2 + coords - B, 2, H1, W1 + """ + BH1W1, heads, H2, W2 = cost_maps.shape + B, _, H1, W1 = coords1.shape + + assert (H1 == H2) and (W1 == W2) + assert BH1W1 == B*H1*W1 + + cost_maps = cost_maps.reshape(B, H1* W1*heads, H2, W2) + coords = coords1.permute(0, 2, 3, 1) + corr = bilinear_sampler(cost_maps, coords) # [B, H1*W1*heads, H2, W2] + corr = rearrange(corr, 'b (h1 w1 heads) h2 w2 -> (b h2 w2) heads h1 w1', b=B, heads=heads, h1=H1, w1=W1, h2=H2, w2=W2) + + r = 4 + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords0.device) + centroid = coords0.permute(0, 2, 3, 1).reshape(BH1W1, 1, 1, 2) + delta = delta.view(1, 2*r+1, 2*r+1, 2) + coords = centroid + delta + corr = bilinear_sampler(corr, coords) + corr = corr.view(B, H1, W1, -1).permute(0, 3, 1, 2) + return corr + +class MemoryDecoder(nn.Module): + def __init__(self, cfg): + super(MemoryDecoder, self).__init__() + dim = self.dim = cfg.query_latent_dim + self.cfg = cfg + + self.flow_token_encoder = nn.Sequential( + nn.Conv2d(81*cfg.cost_heads_num, dim, 1, 1), + nn.GELU(), + nn.Conv2d(dim, dim, 1, 1) + ) + self.proj = nn.Conv2d(256, 256, 1) + self.depth = cfg.decoder_depth + self.decoder_layer = MemoryDecoderLayer(dim, cfg) + + if self.cfg.gma: + self.update_block = GMAUpdateBlock(self.cfg, hidden_dim=128) + self.att = Attention(args=self.cfg, dim=128, heads=1, max_pos_size=160, dim_head=128) + else: + self.update_block = BasicUpdateBlock(self.cfg, hidden_dim=128) + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + def encode_flow_token(self, cost_maps, coords): + """ + cost_maps - B*H1*W1, cost_heads_num, H2, W2 + coords - B, 2, H1, W1 + """ + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + r = 4 + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid = coords.reshape(batch*h1*w1, 1, 1, 2) + delta = delta.view(1, 2*r+1, 2*r+1, 2) + coords = centroid + delta + corr = bilinear_sampler(cost_maps, coords) + corr = corr.view(batch, h1, w1, -1).permute(0, 3, 1, 2) + return corr + + def forward(self, cost_memory, context, data={}, flow_init=None): + """ + memory: [B*H1*W1, H2'*W2', C] + context: [B, D, H1, W1] + """ + cost_maps = data['cost_maps'] + coords0, coords1 = initialize_flow(context) + + if flow_init is not None: + #print("[Using warm start]") + coords1 = coords1 + flow_init + + #flow = coords1 + + flow_predictions = [] + + context = self.proj(context) + net, inp = torch.split(context, [128, 128], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + if self.cfg.gma: + attention = self.att(inp) + + size = net.shape + key, value = None, None + + for idx in range(self.depth): + # coords1 = coords1.detach() + + cost_forward = self.encode_flow_token(cost_maps, coords1) + #cost_backward = self.reverse_cost_extractor(cost_maps, coords0, coords1) + + query = self.flow_token_encoder(cost_forward) + query = query.permute(0, 2, 3, 1).contiguous().view(size[0]*size[2]*size[3], 1, self.dim) + cost_global, key, value = self.decoder_layer(query, key, value, cost_memory, coords1, size, data['H3W3']) + if self.cfg.only_global: + corr = cost_global + else: + corr = torch.cat([cost_global, cost_forward], dim=1) + + flow = coords1 - coords0 + + if self.cfg.gma: + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow, attention) + else: + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # flow = delta_flow + coords1 = coords1 + delta_flow + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + flow_predictions.append(flow_up) + + if self.training: + return flow_predictions + else: + return flow_predictions[-1:] diff --git a/modules/components/m2m_flow_former/LatentCostFormer/encoder.py b/modules/components/m2m_flow_former/LatentCostFormer/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a3faaec065e14746ead1aa0449150bd54749fc99 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/encoder.py @@ -0,0 +1,368 @@ +import loguru +import torch +import math +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum +import numpy as np + +from einops.layers.torch import Rearrange +from einops import rearrange + +from .utils import coords_grid, bilinear_sampler, upflow8 +from .attention import BroadMultiHeadAttention, MultiHeadAttention, LinearPositionEmbeddingSine, ExpPositionEmbeddingSine +from .encoders import twins_svt_large +from typing import Optional, Tuple +from .twins import Size_, PosConv +from .cnn import TwinsSelfAttentionLayer, TwinsCrossAttentionLayer, BasicEncoder +from .mlpmixer import MLPMixerLayer +from .convnext import ConvNextLayer +import time + +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ + +class PatchEmbed(nn.Module): + def __init__(self, patch_size=16, in_chans=1, embed_dim=64, pe='linear'): + super().__init__() + self.patch_size = patch_size + self.dim = embed_dim + self.pe = pe + + # assert patch_size == 8 + if patch_size == 8: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim//4, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim//4, embed_dim//2, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim//2, embed_dim, kernel_size=6, stride=2, padding=2), + ) + elif patch_size == 4: + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim//4, kernel_size=6, stride=2, padding=2), + nn.ReLU(), + nn.Conv2d(embed_dim//4, embed_dim, kernel_size=6, stride=2, padding=2), + ) + else: + print(f"patch size = {patch_size} is unacceptable.") + + self.ffn_with_coord = nn.Sequential( + nn.Conv2d(embed_dim*2, embed_dim*2, kernel_size=1), + nn.ReLU(), + nn.Conv2d(embed_dim*2, embed_dim*2, kernel_size=1) + ) + self.norm = nn.LayerNorm(embed_dim*2) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape # C == 1 + + pad_l = pad_t = 0 + pad_r = (self.patch_size - W % self.patch_size) % self.patch_size + pad_b = (self.patch_size - H % self.patch_size) % self.patch_size + x = F.pad(x, (pad_l, pad_r, pad_t, pad_b)) + + x = self.proj(x) + out_size = x.shape[2:] + + patch_coord = coords_grid(B, out_size[0], out_size[1]).to(x.device) * self.patch_size + self.patch_size/2 # in feature coordinate space + patch_coord = patch_coord.view(B, 2, -1).permute(0, 2, 1) + if self.pe == 'linear': + patch_coord_enc = LinearPositionEmbeddingSine(patch_coord, dim=self.dim) + elif self.pe == 'exp': + patch_coord_enc = ExpPositionEmbeddingSine(patch_coord, dim=self.dim) + patch_coord_enc = patch_coord_enc.permute(0, 2, 1).view(B, -1, out_size[0], out_size[1]) + + x_pe = torch.cat([x, patch_coord_enc], dim=1) + x = self.ffn_with_coord(x_pe) + x = self.norm(x.flatten(2).transpose(1, 2)) + + return x, out_size + +from .twins import Block, CrossBlock + +class GroupVerticalSelfAttentionLayer(nn.Module): + def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(GroupVerticalSelfAttentionLayer, self).__init__() + self.cfg = cfg + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + embed_dim = dim + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = dropout + attn_drop_rate=0. + + self.block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True, vert_c_dim=cfg.vert_c_dim, groupattention=True, cfg=self.cfg) + + def forward(self, x, size, context=None): + x = self.block(x, size, context) + + return x + +class VerticalSelfAttentionLayer(nn.Module): + def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(VerticalSelfAttentionLayer, self).__init__() + self.cfg = cfg + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + embed_dim = dim + mlp_ratio = 4 + ws = 7 + sr_ratio = 4 + dpr = 0. + drop_rate = dropout + attn_drop_rate=0. + + self.local_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=ws, with_rpe=True, vert_c_dim=cfg.vert_c_dim) + self.global_block = Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr, sr_ratio=sr_ratio, ws=1, with_rpe=True, vert_c_dim=cfg.vert_c_dim) + + def forward(self, x, size, context=None): + x = self.local_block(x, size, context) + x = self.global_block(x, size, context) + + return x + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + +class SelfAttentionLayer(nn.Module): + def __init__(self, dim, cfg, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(SelfAttentionLayer, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.multi_head_attn = MultiHeadAttention(dim, num_heads) + self.q, self.k, self.v = nn.Linear(dim, dim, bias=True), nn.Linear(dim, dim, bias=True), nn.Linear(dim, dim, bias=True) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(dim, dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = x + x = self.norm1(x) + + q, k, v = self.q(x), self.k(x), self.v(x) + + x = self.multi_head_attn(q, k, v) + + x = self.proj(x) + x = short_cut + self.proj_drop(x) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + def compute_params(self): + num = 0 + for param in self.parameters(): + num += np.prod(param.size()) + + return num + + +class CrossAttentionLayer(nn.Module): + def __init__(self, qk_dim, v_dim, query_token_dim, tgt_token_dim, num_heads=8, attn_drop=0., proj_drop=0., drop_path=0., dropout=0.): + super(CrossAttentionLayer, self).__init__() + assert qk_dim % num_heads == 0, f"dim {qk_dim} should be divided by num_heads {num_heads}." + assert v_dim % num_heads == 0, f"dim {v_dim} should be divided by num_heads {num_heads}." + """ + Query Token: [N, C] -> [N, qk_dim] (Q) + Target Token: [M, D] -> [M, qk_dim] (K), [M, v_dim] (V) + """ + self.num_heads = num_heads + head_dim = qk_dim // num_heads + self.scale = head_dim ** -0.5 + + self.norm1 = nn.LayerNorm(query_token_dim) + self.norm2 = nn.LayerNorm(query_token_dim) + self.multi_head_attn = BroadMultiHeadAttention(qk_dim, num_heads) + self.q, self.k, self.v = nn.Linear(query_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, qk_dim, bias=True), nn.Linear(tgt_token_dim, v_dim, bias=True) + + self.proj = nn.Linear(v_dim, query_token_dim) + self.proj_drop = nn.Dropout(proj_drop) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.ffn = nn.Sequential( + nn.Linear(query_token_dim, query_token_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(query_token_dim, query_token_dim), + nn.Dropout(dropout) + ) + + def forward(self, query, tgt_token): + """ + x: [BH1W1, H3W3, D] + """ + short_cut = query + query = self.norm1(query) + + q, k, v = self.q(query), self.k(tgt_token), self.v(tgt_token) + + x = self.multi_head_attn(q, k, v) + + x = short_cut + self.proj_drop(self.proj(x)) + + x = x + self.drop_path(self.ffn(self.norm2(x))) + + return x + + +class CostPerceiverEncoder(nn.Module): + def __init__(self, cfg): + super(CostPerceiverEncoder, self).__init__() + self.cfg = cfg + self.patch_size = cfg.patch_size + self.patch_embed = PatchEmbed(in_chans=self.cfg.cost_heads_num, patch_size=self.patch_size, embed_dim=cfg.cost_latent_input_dim, pe=cfg.pe) + + self.depth = cfg.encoder_depth + + self.latent_tokens = nn.Parameter(torch.randn(1, cfg.cost_latent_token_num, cfg.cost_latent_dim)) + + query_token_dim, tgt_token_dim = cfg.cost_latent_dim, cfg.cost_latent_input_dim*2 + qk_dim, v_dim = query_token_dim, query_token_dim + self.input_layer = CrossAttentionLayer(qk_dim, v_dim, query_token_dim, tgt_token_dim, dropout=cfg.dropout) + + if cfg.use_mlp: + self.encoder_layers = nn.ModuleList([MLPMixerLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) for idx in range(self.depth)]) + else: + self.encoder_layers = nn.ModuleList([SelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) for idx in range(self.depth)]) + + if self.cfg.vertical_conv: + self.vertical_encoder_layers = nn.ModuleList([ConvNextLayer(cfg.cost_latent_dim) for idx in range(self.depth)]) + else: + self.vertical_encoder_layers = nn.ModuleList([VerticalSelfAttentionLayer(cfg.cost_latent_dim, cfg, dropout=cfg.dropout) for idx in range(self.depth)]) + self.cost_scale_aug = None + if ('cost_scale_aug' in cfg.keys()): + self.cost_scale_aug = cfg.cost_scale_aug + print("[Using cost_scale_aug: {}]".format(self.cost_scale_aug)) + + + + def forward(self, cost_volume, data, context=None): + B, heads, H1, W1, H2, W2 = cost_volume.shape + cost_maps = cost_volume.permute(0, 2, 3, 1, 4, 5).contiguous().view(B*H1*W1, self.cfg.cost_heads_num, H2, W2) + data['cost_maps'] = cost_maps + + if self.cost_scale_aug is not None: + scale_factor = torch.FloatTensor(B*H1*W1, self.cfg.cost_heads_num, H2, W2).uniform_(self.cost_scale_aug[0], self.cost_scale_aug[1]).cuda() + cost_maps = cost_maps * scale_factor + + x, size = self.patch_embed(cost_maps) # B*H1*W1, size[0]*size[1], C + data['H3W3'] = size + H3, W3 = size + + x = self.input_layer(self.latent_tokens, x) + + short_cut = x + + for idx, layer in enumerate(self.encoder_layers): + x = layer(x) + if self.cfg.vertical_conv: + # B, H1*W1, K, D -> B, K, D, H1*W1 -> B*K, D, H1, W1 + x = x.view(B, H1*W1, self.cfg.cost_latent_token_num, -1).permute(0, 3, 1, 2).reshape(B*self.cfg.cost_latent_token_num, -1, H1, W1) + x = self.vertical_encoder_layers[idx](x) + # B*K, D, H1, W1 -> B, K, D, H1*W1 -> B, H1*W1, K, D + x = x.view(B, self.cfg.cost_latent_token_num, -1, H1*W1).permute(0, 2, 3, 1).reshape(B*H1*W1, self.cfg.cost_latent_token_num, -1) + else: + x = x.view(B, H1*W1, self.cfg.cost_latent_token_num, -1).permute(0, 2, 1, 3).reshape(B*self.cfg.cost_latent_token_num, H1*W1, -1) + x = self.vertical_encoder_layers[idx](x, (H1, W1), context) + x = x.view(B, self.cfg.cost_latent_token_num, H1*W1, -1).permute(0, 2, 1, 3).reshape(B*H1*W1, self.cfg.cost_latent_token_num, -1) + + if self.cfg.cost_encoder_res is True: + x = x + short_cut + #print("~~~~") + return x + +class MemoryEncoder(nn.Module): + def __init__(self, cfg): + super(MemoryEncoder, self).__init__() + self.cfg = cfg + + if cfg.fnet == 'twins': + self.feat_encoder = twins_svt_large(pretrained=False) + elif cfg.fnet == 'basicencoder': + self.feat_encoder = BasicEncoder(output_dim=256, norm_fn='instance') + else: + exit() + self.channel_convertor = nn.Conv2d(cfg.encoder_latent_dim, cfg.encoder_latent_dim, 1, padding=0, bias=False) + self.cost_perceiver_encoder = CostPerceiverEncoder(cfg) + + def corr(self, fmap1, fmap2): + + batch, dim, ht, wd = fmap1.shape + fmap1 = rearrange(fmap1, 'b (heads d) h w -> b heads (h w) d', heads=self.cfg.cost_heads_num) + fmap2 = rearrange(fmap2, 'b (heads d) h w -> b heads (h w) d', heads=self.cfg.cost_heads_num) + corr = einsum('bhid, bhjd -> bhij', fmap1, fmap2) + corr = corr.permute(0, 2, 1, 3).view(batch*ht*wd, self.cfg.cost_heads_num, ht, wd) + #corr = self.norm(self.relu(corr)) + corr = corr.view(batch, ht*wd, self.cfg.cost_heads_num, ht*wd).permute(0, 2, 1, 3) + corr = corr.view(batch, self.cfg.cost_heads_num, ht, wd, ht, wd) + + return corr + + def forward(self, img1, img2, data, context=None): + # The original implementation + # feat_s = self.feat_encoder(img1) + # feat_t = self.feat_encoder(img2) + # feat_s = self.channel_convertor(feat_s) + # feat_t = self.channel_convertor(feat_t) + + imgs = torch.cat([img1, img2], dim=0) + feats = self.feat_encoder(imgs) + feats = self.channel_convertor(feats) + B = feats.shape[0] // 2 + + feat_s = feats[:B] + feat_t = feats[B:] + + B, C, H, W = feat_s.shape + size = (H, W) + + if self.cfg.feat_cross_attn: + feat_s = feat_s.flatten(2).transpose(1, 2) + feat_t = feat_t.flatten(2).transpose(1, 2) + + for layer in self.layers: + feat_s, feat_t = layer(feat_s, feat_t, size) + + feat_s = feat_s.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + feat_t = feat_t.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + cost_volume = self.corr(feat_s, feat_t) + x = self.cost_perceiver_encoder(cost_volume, data, context) + + return x \ No newline at end of file diff --git a/modules/components/m2m_flow_former/LatentCostFormer/encoders.py b/modules/components/m2m_flow_former/LatentCostFormer/encoders.py new file mode 100644 index 0000000000000000000000000000000000000000..0513a3473d3b5b738ca048c3f19cbe6e2b28c327 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/encoders.py @@ -0,0 +1,94 @@ +import torch +import torch.nn as nn +import timm +import numpy as np + +class twins_svt_large(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model('twins_svt_large', pretrained=pretrained) + + del self.svt.head + del self.svt.patch_embeds[2] + del self.svt.patch_embeds[2] + del self.svt.blocks[2] + del self.svt.blocks[2] + del self.svt.pos_block[2] + del self.svt.pos_block[2] + del self.svt.norm.bias + del self.svt.norm.weight + + def forward(self, x, data=None, layer=2): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j==0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + if i == layer-1: + break + + return x + + def compute_params(self, layer=2): + num = 0 + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + for param in embed.parameters(): + num += np.prod(param.size()) + + for param in drop.parameters(): + num += np.prod(param.size()) + + for param in blocks.parameters(): + num += np.prod(param.size()) + + for param in pos_blk.parameters(): + num += np.prod(param.size()) + + if i == layer-1: + break + + for param in self.svt.head.parameters(): + num += np.prod(param.size()) + + return num + +class twins_svt_large_context(nn.Module): + def __init__(self, pretrained=True): + super().__init__() + self.svt = timm.create_model('twins_svt_large_context', pretrained=pretrained) + + def forward(self, x, data=None, layer=2): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.svt.patch_embeds, self.svt.pos_drops, self.svt.blocks, self.svt.pos_block)): + + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j==0: + x = pos_blk(x, size) + if i < len(self.svt.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + + if i == layer-1: + break + + return x + + +if __name__ == "__main__": + m = twins_svt_large() + input = torch.randn(2, 3, 400, 800) + out = m.extract_feature(input) + print(out.shape) diff --git a/modules/components/m2m_flow_former/LatentCostFormer/gma.py b/modules/components/m2m_flow_former/LatentCostFormer/gma.py new file mode 100644 index 0000000000000000000000000000000000000000..cee57cece03b68d95c30fcfc3e4420e9ce5823d2 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/gma.py @@ -0,0 +1,123 @@ +import torch +from torch import nn, einsum +from einops import rearrange + + +class RelPosEmb(nn.Module): + def __init__( + self, + max_pos_size, + dim_head + ): + super().__init__() + self.rel_height = nn.Embedding(2 * max_pos_size - 1, dim_head) + self.rel_width = nn.Embedding(2 * max_pos_size - 1, dim_head) + + deltas = torch.arange(max_pos_size).view(1, -1) - torch.arange(max_pos_size).view(-1, 1) + rel_ind = deltas + max_pos_size - 1 + self.register_buffer('rel_ind', rel_ind) + + def forward(self, q): + batch, heads, h, w, c = q.shape + height_emb = self.rel_height(self.rel_ind[:h, :h].reshape(-1)) + width_emb = self.rel_width(self.rel_ind[:w, :w].reshape(-1)) + + height_emb = rearrange(height_emb, '(x u) d -> x u () d', x=h) + width_emb = rearrange(width_emb, '(y v) d -> y () v d', y=w) + + height_score = einsum('b h x y d, x u v d -> b h x y u v', q, height_emb) + width_score = einsum('b h x y d, y u v d -> b h x y u v', q, width_emb) + + return height_score + width_score + + +class Attention(nn.Module): + def __init__( + self, + *, + args, + dim, + max_pos_size = 100, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_qk = nn.Conv2d(dim, inner_dim * 2, 1, bias=False) + + # self.pos_emb = RelPosEmb(max_pos_size, dim_head) + + def forward(self, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + q, k = self.to_qk(fmap).chunk(2, dim=1) + + q, k = map(lambda t: rearrange(t, 'b (h d) x y -> b h x y d', h=heads), (q, k)) + q = self.scale * q + + # if self.args.position_only: + # sim = self.pos_emb(q) + + # elif self.args.position_and_content: + # sim_content = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + # sim_pos = self.pos_emb(q) + # sim = sim_content + sim_pos + + # else: + sim = einsum('b h x y d, b h u v d -> b h x y u v', q, k) + + sim = rearrange(sim, 'b h x y u v -> b h (x y) (u v)') + attn = sim.softmax(dim=-1) + + return attn + + +class Aggregate(nn.Module): + def __init__( + self, + args, + dim, + heads = 4, + dim_head = 128, + ): + super().__init__() + self.args = args + self.heads = heads + self.scale = dim_head ** -0.5 + inner_dim = heads * dim_head + + self.to_v = nn.Conv2d(dim, inner_dim, 1, bias=False) + + self.gamma = nn.Parameter(torch.zeros(1)) + + if dim != inner_dim: + self.project = nn.Conv2d(inner_dim, dim, 1, bias=False) + else: + self.project = None + + def forward(self, attn, fmap): + heads, b, c, h, w = self.heads, *fmap.shape + + v = self.to_v(fmap) + v = rearrange(v, 'b (h d) x y -> b h (x y) d', h=heads) + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) + + if self.project is not None: + out = self.project(out) + + out = fmap + self.gamma * out + + return out + + +if __name__ == "__main__": + att = Attention(dim=128, heads=1) + fmap = torch.randn(2, 128, 40, 90) + out = att(fmap) + + print(out.shape) \ No newline at end of file diff --git a/modules/components/m2m_flow_former/LatentCostFormer/gru.py b/modules/components/m2m_flow_former/LatentCostFormer/gru.py new file mode 100644 index 0000000000000000000000000000000000000000..92802c76e91471573551164ff60fbd94c26b4424 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/gru.py @@ -0,0 +1,137 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + if args.only_global: + print("[Decoding with only global cost]") + cor_planes = args.query_latent_dim + else: + cor_planes = 81+args.query_latent_dim + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + +from .gma import Aggregate +class GMAUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128): + super().__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + self.aggregator = Aggregate(args=self.args, dim=128, dim_head=128, heads=1) + + def forward(self, net, inp, corr, flow, attention): + motion_features = self.encoder(flow, corr) + motion_features_global = self.aggregator(attention, motion_features) + inp_cat = torch.cat([inp, motion_features, motion_features_global], dim=1) + + # Attentional update + net = self.gru(net, inp_cat) + + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow \ No newline at end of file diff --git a/modules/components/m2m_flow_former/LatentCostFormer/mlpmixer.py b/modules/components/m2m_flow_former/LatentCostFormer/mlpmixer.py new file mode 100644 index 0000000000000000000000000000000000000000..c31038bb3fffb8ccd93982905f9dd937492fed7c --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/mlpmixer.py @@ -0,0 +1,50 @@ +from torch import nn +from einops.layers.torch import Rearrange, Reduce +from functools import partial +import numpy as np + +class PreNormResidual(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = nn.LayerNorm(dim) + + def forward(self, x): + return self.fn(self.norm(x)) + x + +def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear): + return nn.Sequential( + dense(dim, dim * expansion_factor), + nn.GELU(), + nn.Dropout(dropout), + dense(dim * expansion_factor, dim), + nn.Dropout(dropout) + ) + +class MLPMixerLayer(nn.Module): + def __init__(self, dim, cfg, drop_path=0., dropout=0.): + super(MLPMixerLayer, self).__init__() + + # print(f"use mlp mixer layer") + K = cfg.cost_latent_token_num + expansion_factor = cfg.mlp_expansion_factor + chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear + + self.mlpmixer = nn.Sequential( + PreNormResidual(dim, FeedForward(K, expansion_factor, dropout, chan_first)), + PreNormResidual(dim, FeedForward(dim, expansion_factor, dropout, chan_last)), + ) + + def compute_params(self): + num = 0 + for param in self.mlpmixer.parameters(): + num += np.prod(param.size()) + + return num + + def forward(self, x): + """ + x: [BH1W1, K, D] + """ + + return self.mlpmixer(x) diff --git a/modules/components/m2m_flow_former/LatentCostFormer/position_encoding.py b/modules/components/m2m_flow_former/LatentCostFormer/position_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..00220b9fc4d49e3111ca81462bb2b08790496100 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/position_encoding.py @@ -0,0 +1,92 @@ +from loguru import logger +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0)) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] + +class LinearPositionEncoding(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = (torch.ones(max_shape).cumsum(0).float().unsqueeze(0) - 1) / max_shape[0] + x_position = (torch.ones(max_shape).cumsum(1).float().unsqueeze(0) - 1) / max_shape[1] + div_term = torch.arange(0, d_model//2, 2).float() + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term * math.pi) + pe[1::4, :, :] = torch.cos(x_position * div_term * math.pi) + pe[2::4, :, :] = torch.sin(y_position * div_term * math.pi) + pe[3::4, :, :] = torch.cos(y_position * div_term * math.pi) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + # assert x.shape[2] == 80 and x.shape[3] == 80 + + return x + self.pe[:, :, :x.size(2), :x.size(3)] + +class LearnedPositionEncoding(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(80, 80)): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + """ + super().__init__() + + self.pe = nn.Parameter(torch.randn(1, max_shape[0], max_shape[1], d_model)) + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + # assert x.shape[2] == 80 and x.shape[3] == 80 + + return x + self.pe diff --git a/modules/components/m2m_flow_former/LatentCostFormer/transformer.py b/modules/components/m2m_flow_former/LatentCostFormer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1533baeca64d15e56b18c8727af9c18b5f828d89 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/transformer.py @@ -0,0 +1,48 @@ +import loguru +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import einsum + +from einops.layers.torch import Rearrange +from einops import rearrange + +from .utils import coords_grid, bilinear_sampler, upflow8 +from .common import FeedForward, pyramid_retrieve_tokens, sampler, sampler_gaussian_fix, retrieve_tokens, MultiHeadAttention, MLP +from .encoders import twins_svt_large_context, twins_svt_large +from .position_encoding import PositionEncodingSine, LinearPositionEncoding +from .twins import PosConv +from .encoder import MemoryEncoder +from .decoder import MemoryDecoder +from .cnn import BasicEncoder + + +class FlowFormer(nn.Module): + def __init__(self, cfg): + super(FlowFormer, self).__init__() + self.cfg = cfg + + self.memory_encoder = MemoryEncoder(cfg) + self.memory_decoder = MemoryDecoder(cfg) + if cfg.cnet == 'twins': + self.context_encoder = twins_svt_large(pretrained=False) + elif cfg.cnet == 'basicencoder': + self.context_encoder = BasicEncoder(output_dim=256, norm_fn='instance') + + def forward(self, image1, image2, output=None, flow_init=None): + # Following https://github.com/princeton-vl/RAFT/ + image1 = 2 * (image1) - 1.0 + image2 = 2 * (image2) - 1.0 + + data = {} + + if self.cfg.context_concat: + context = self.context_encoder(torch.cat([image1, image2], dim=1)) + else: + context = self.context_encoder(image1) + + cost_memory = self.memory_encoder(image1, image2, data, context) + + flow_predictions = self.memory_decoder(cost_memory, context, data, flow_init=flow_init) + + return flow_predictions diff --git a/modules/components/m2m_flow_former/LatentCostFormer/twins.py b/modules/components/m2m_flow_former/LatentCostFormer/twins.py new file mode 100644 index 0000000000000000000000000000000000000000..d8a8f5f0b4b95d44a87879e939bc9ecaf8056617 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/twins.py @@ -0,0 +1,1004 @@ +""" Twins +A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers` + - https://arxiv.org/pdf/2104.13840.pdf +Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below +""" +# -------------------------------------------------------- +# Twins +# Copyright (c) 2021 Meituan +# Licensed under The Apache 2.0 License [see LICENSE for details] +# Written by Xinjie Li, Xiangxiang Chu +# -------------------------------------------------------- +import math +from copy import deepcopy +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_ +from timm.models.registry import register_model +from timm.models.vision_transformer import Attention +from .attention import MultiHeadAttention, LinearPositionEmbeddingSine +from .utils import coords_grid, bilinear_sampler, upflow8 + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'twins_pcpvt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth', + ), + 'twins_pcpvt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth', + ), + 'twins_pcpvt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth', + ), + 'twins_svt_small': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth', + ), + 'twins_svt_base': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth', + ), + 'twins_svt_large': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth', + ), +} + +Size_ = Tuple[int, int] + +class GroupAttnRPEContext(nn.Module): + """ Latent cost tokens attend to different group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1, cfg=None, vert_c_dim=0): + super(GroupAttnRPEContext, self).__init__() + assert ws != 1 + assert cfg is not None + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + assert cfg.cost_latent_token_num % 5 == 0, "cost_latent_token_num should be divided by 5." + assert vert_c_dim > 0, "vert_c_dim should not be 0" + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.vert_c_dim = vert_c_dim + + self.cfg = cfg + + self.context_proj = nn.Linear(256, vert_c_dim) + self.q = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + C_qk = C+self.vert_c_dim + H, W = size + batch_num = B // 5 + + context = context.repeat(B//context.shape[0], 1, 1, 1) + context = context.view(B, -1, H*W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + padded_N = Hp*Wp + + coords = coords_grid(B, Hp, Wp).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk) + coords_enc = coords_enc.reshape(B, Hp, Wp, C_qk) + + q = self.q(x_qk + coords_enc).reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + q = q.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = self.v(x) + k = self.k(x_qk + coords_enc) + # concate and do shifting operation together + kv = torch.cat([k, v], dim=-1) + kv_up = torch.cat([kv[:batch_num, self.ws:Hp, :, :], kv[:batch_num, Hp-self.ws:Hp, :, :]], dim=1) + kv_down = torch.cat([kv[batch_num:batch_num*2, :self.ws, :, :], kv[batch_num:batch_num*2, :Hp-self.ws, :, :]], dim=1) + kv_left = torch.cat([kv[batch_num*2:batch_num*3, :, self.ws:Wp, :], kv[batch_num*2:batch_num*3, :, Wp-self.ws:Wp, :]], dim=2) + kv_right = torch.cat([kv[batch_num*3:batch_num*4, :, :self.ws, :], kv[batch_num*3:batch_num*4, :, :Wp-self.ws, :]], dim=2) + kv_center = kv[batch_num*4:batch_num*5, :, :, :] + kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0) + k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1) + + k = k.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + k = k.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = v.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + v = v.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class GroupAttnRPE(nn.Module): + """ Latent cost tokens attend to different group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1, cfg=None): + super(GroupAttnRPE, self).__init__() + assert ws != 1 + assert cfg is not None + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + assert cfg.cost_latent_token_num % 5 == 0, "cost_latent_token_num should be divided by 5." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.cfg = cfg + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + H, W = size + batch_num = B // 5 + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + padded_N = Hp*Wp + + coords = coords_grid(B, Hp, Wp).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + coords_enc = coords_enc.reshape(B, Hp, Wp, C) + + q = self.q(x + coords_enc).reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + q = q.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = self.v(x) + k = self.k(x + coords_enc) + # concate and do shifting operation together + kv = torch.cat([k, v], dim=-1) + kv_up = torch.cat([kv[:batch_num, self.ws:Hp, :, :], kv[:batch_num, Hp-self.ws:Hp, :, :]], dim=1) + kv_down = torch.cat([kv[batch_num:batch_num*2, :self.ws, :, :], kv[batch_num:batch_num*2, :Hp-self.ws, :, :]], dim=1) + kv_left = torch.cat([kv[batch_num*2:batch_num*3, :, self.ws:Wp, :], kv[batch_num*2:batch_num*3, :, Wp-self.ws:Wp, :]], dim=2) + kv_right = torch.cat([kv[batch_num*3:batch_num*4, :, :self.ws, :], kv[batch_num*3:batch_num*4, :, :Wp-self.ws, :]], dim=2) + kv_center = kv[batch_num*4:batch_num*5, :, :, :] + kv_shifted = torch.cat([kv_up, kv_down, kv_left, kv_right, kv_center], dim=0) + k, v = torch.split(kv_shifted, [self.dim, self.dim], dim=-1) + + k = k.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + k = k.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + v = v.reshape(B, _h, self.ws, _w, self.ws, self.num_heads, C // self.num_heads).transpose(2, 3) + v = v.reshape(B, _h * _w, self.ws * self.ws, self.num_heads, C // self.num_heads).permute(0, 1, 3, 2, 4) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class LocallyGroupedAttnRPEContext(nn.Module): + """ LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1, vert_c_dim=0): + assert ws != 1 + super(LocallyGroupedAttnRPEContext, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + self.vert_c_dim = vert_c_dim + + self.context_proj = nn.Linear(256, vert_c_dim) + # context are not added to value + self.q = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + C_qk = C+self.vert_c_dim + + context = context.repeat(B//context.shape[0], 1, 1, 1) + context = context.view(B, -1, H*W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + x_qk = x_qk.reshape(B, _h, self.ws, _w, self.ws, C_qk).transpose(2, 3) + + v = self.v(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + + coords = coords_grid(B, self.ws, self.ws).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk).view(B, self.ws, self.ws, C_qk) + # coords_enc: B, ws, ws, C + # x: B, _h, _w, self.ws, self.ws, C + x_qk = x_qk + coords_enc[:, None, None, :, :, :] + + q = self.q(x_qk).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + k = self.k(x_qk).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class GlobalSubSampleAttnRPEContext(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1, vert_c_dim=0): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.vert_c_dim = vert_c_dim + self.context_proj = nn.Linear(256, vert_c_dim) + self.q = nn.Linear(dim+vert_c_dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr_key = nn.Conv2d(dim+vert_c_dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.sr_value = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + C_qk = C + self.vert_c_dim + H, W = size + context = context.repeat(B//context.shape[0], 1, 1, 1) + context = context.view(B, -1, H*W).permute(0, 2, 1) + context = self.context_proj(context) + context = context.view(B, H, W, -1) + x = x.view(B, H, W, C) + x_qk = torch.cat([x, context], dim=-1) + pad_l = pad_t = 0 + pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio + pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + x_qk = F.pad(x_qk, (0, 0, pad_l, pad_r, pad_t, pad_b)) + + _, Hp, Wp, _ = x.shape + padded_size = (Hp, Wp) + padded_N = Hp*Wp + x = x.view(B, -1, C) + x_qk = x_qk.view(B, -1, C_qk) + + coords = coords_grid(B, *padded_size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C_qk) + # coords_enc: B, Hp*Wp, C + # x: B, Hp*Wp, C + q = self.q(x_qk + coords_enc).reshape(B, padded_N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr_key is not None: + x = x.permute(0, 2, 1).reshape(B, C, *padded_size) + x_qk = x_qk.permute(0, 2, 1).reshape(B, C_qk, *padded_size) + x = self.sr_value(x).reshape(B, C, -1).permute(0, 2, 1) + x_qk = self.sr_key(x_qk).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + x_qk = self.norm(x_qk) + + coords = coords_grid(B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(x_qk + coords_enc).reshape(B, (padded_size[0] // self.sr_ratio)*(padded_size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(x).reshape(B, (padded_size[0] // self.sr_ratio)*(padded_size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class LocallyGroupedAttnRPE(nn.Module): + """ LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(LocallyGroupedAttnRPE, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_, context=None): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + v = self.v(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + + coords = coords_grid(B, self.ws, self.ws).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C).view(B, self.ws, self.ws, C) + # coords_enc: B, ws, ws, C + # x: B, _h, _w, self.ws, self.ws, C + x = x + coords_enc[:, None, None, :, :, :] + + q = self.q(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + k = self.k(x).reshape( + B, _h * _w, self.ws * self.ws, 1, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)[0] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class GlobalSubSampleAttnRPE(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_, context=None): + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.sr_ratio - W % self.sr_ratio) % self.sr_ratio + pad_b = (self.sr_ratio - H % self.sr_ratio) % self.sr_ratio + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + padded_size = (Hp, Wp) + padded_N = Hp*Wp + x = x.view(B, -1, C) + + coords = coords_grid(B, *padded_size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, Hp*Wp, C + # x: B, Hp*Wp, C + q = self.q(x + coords_enc).reshape(B, padded_N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *padded_size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + + coords = coords_grid(B, padded_size[0] // self.sr_ratio, padded_size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(x + coords_enc).reshape(B, (padded_size[0] // self.sr_ratio)*(padded_size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(x).reshape(B, (padded_size[0] // self.sr_ratio)*(padded_size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, Hp, Wp, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class CrossGlobalSubSampleAttnRPE(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.k = nn.Linear(dim, dim, bias=True) + self.v = nn.Linear(dim, dim, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + coords = coords_grid(B, *size).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + # coords_enc: B, H*W, C + # x: B, H*W, C + q = self.q(x + coords_enc).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + coords = coords_grid(B, size[0] // self.sr_ratio, size[1] // self.sr_ratio).to(x.device) + coords = coords.view(B, 2, -1).permute(0, 2, 1) * self.sr_ratio + # align the coordinate of local and global + coords_enc = LinearPositionEmbeddingSine(coords, dim=C) + k = self.k(tgt + coords_enc).reshape(B, (size[0] // self.sr_ratio)*(size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v(tgt).reshape(B, (size[0] // self.sr_ratio)*(size[1] // self.sr_ratio), self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class LocallyGroupedAttn(nn.Module): + """ LSA: self attention within a group + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1): + assert ws != 1 + super(LocallyGroupedAttn, self).__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.ws = ws + + def forward(self, x, size: Size_): + # There are two implementations for this function, zero padding or mask. We don't observe obvious difference for + # both. You can choose any one, we recommend forward_padding because it's neat. However, + # the masking implementation is more reasonable and accurate. + B, N, C = x.shape + H, W = size + x = x.view(B, H, W, C) + pad_l = pad_t = 0 + pad_r = (self.ws - W % self.ws) % self.ws + pad_b = (self.ws - H % self.ws) % self.ws + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + _h, _w = Hp // self.ws, Wp // self.ws + x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) + qkv = self.qkv(x).reshape( + B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C) + x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C) + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + x = x.reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + +class GlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + x = x.permute(0, 2, 1).reshape(B, C, *size) + x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1) + x = self.norm(x) + kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class CrossGlobalSubSampleAttn(nn.Module): + """ GSA: using a key to summarize the information for a group to be efficient. + """ + def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1): + super().__init__() + assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}." + + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.q = nn.Linear(dim, dim, bias=True) + self.kv = nn.Linear(dim, dim * 2, bias=True) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.sr_ratio = sr_ratio + if sr_ratio > 1: + self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio) + self.norm = nn.LayerNorm(dim) + else: + self.sr = None + self.norm = None + + def forward(self, x, tgt, size: Size_): + B, N, C = x.shape + q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + if self.sr is not None: + tgt = tgt.permute(0, 2, 1).reshape(B, C, *size) + tgt = self.sr(tgt).reshape(B, C, -1).permute(0, 2, 1) + tgt = self.norm(tgt) + kv = self.kv(tgt).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[1] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + + return x + +class CrossBlock(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None, with_rpe=True): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = CrossGlobalSubSampleAttnRPE(dim, num_heads, attn_drop, drop, sr_ratio) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, src, tgt, size: Size_): + src_shortcut, tgt_shortcut = src, tgt + + src, tgt = self.norm1(src), self.norm1(tgt) + src = src_shortcut + self.drop_path(self.attn(src, tgt, size)) + tgt = tgt_shortcut + self.drop_path(self.attn(tgt, src, size)) + + src = src + self.drop_path(self.mlp(self.norm2(src))) + tgt = tgt + self.drop_path(self.mlp(self.norm2(tgt))) + return src, tgt + +class Block(nn.Module): + + def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None, with_rpe=False, vert_c_dim=0, groupattention=False, cfg=None): + super().__init__() + self.norm1 = norm_layer(dim) + if groupattention: + assert with_rpe, "Not implementing groupattention without rpe" + if vert_c_dim > 0: + self.attn = GroupAttnRPEContext(dim, num_heads, attn_drop, drop, ws, cfg, vert_c_dim) + else: + self.attn = GroupAttnRPE(dim, num_heads, attn_drop, drop, ws, cfg) + elif ws is None: + self.attn = Attention(dim, num_heads, False, None, attn_drop, drop) + elif ws == 1: + if with_rpe: + if vert_c_dim > 0: + self.attn = GlobalSubSampleAttnRPEContext(dim, num_heads, attn_drop, drop, sr_ratio, vert_c_dim) + else: + self.attn = GlobalSubSampleAttnRPE(dim, num_heads, attn_drop, drop, sr_ratio) + else: + self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio) + else: + if with_rpe: + if vert_c_dim > 0: + self.attn = LocallyGroupedAttnRPEContext(dim, num_heads, attn_drop, drop, ws, vert_c_dim) + else: + self.attn = LocallyGroupedAttnRPE(dim, num_heads, attn_drop, drop, ws) + else: + self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, size: Size_, context=None): + x = x + self.drop_path(self.attn(self.norm1(x), size, context)) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), ) + self.stride = stride + + def forward(self, x, size: Size_): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + def no_weight_decay(self): + return ['proj.%d.weight' % i for i in range(4)] + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + + self.img_size = img_size + self.patch_size = patch_size + assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \ + f"img_size {img_size} should be divided by patch_size {patch_size}." + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + self.norm = nn.LayerNorm(embed_dim) + + def forward(self, x) -> Tuple[torch.Tensor, Size_]: + B, C, H, W = x.shape + + x = self.proj(x).flatten(2).transpose(1, 2) + x = self.norm(x) + out_size = (H // self.patch_size[0], W // self.patch_size[1]) + + return x, out_size + + +class Twins(nn.Module): + """ Twins Vision Transfomer (Revisiting Spatial Attention) + Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git + """ + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512), + num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None, + block_cls=Block, init_weight=True): + super().__init__() + self.num_classes = num_classes + self.depths = depths + self.embed_dims = embed_dims + self.num_features = embed_dims[-1] + + img_size = to_2tuple(img_size) + prev_chs = in_chans + self.patch_embeds = nn.ModuleList() + self.pos_drops = nn.ModuleList() + for i in range(len(depths)): + self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i])) + self.pos_drops.append(nn.Dropout(p=drop_rate)) + prev_chs = embed_dims[i] + img_size = tuple(t // patch_size for t in img_size) + patch_size = 2 + + self.blocks = nn.ModuleList() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + cur = 0 + for k in range(len(depths)): + _block = nn.ModuleList([block_cls( + dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate, + attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k], + ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])]) + self.blocks.append(_block) + cur += depths[k] + + self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims]) + + self.norm = norm_layer(self.num_features) + + # classification head + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + # init weights + if init_weight: + self.apply(self._init_weights) + + @torch.jit.ignore + def no_weight_decay(self): + return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()]) + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward_features(self, x): + B = x.shape[0] + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + if i < len(self.depths) - 1: + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + x = self.norm(x) + return x.mean(dim=1) # GAP here + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +# def _create_twins(variant, pretrained=False, **kwargs): +# if kwargs.get('features_only', None): +# raise RuntimeError('features_only not implemented for Vision Transformer models.') + +# model = build_model_with_cfg( +# Twins, variant, pretrained, +# default_cfg=default_cfgs[variant], +# **kwargs) +# return model + + +# @register_model +# def twins_pcpvt_small(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_pcpvt_base(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_pcpvt_large(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], +# depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_small(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_base(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs) + + +# @register_model +# def twins_svt_large(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) +# return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) + +# @register_model +# def twins_svt_large_context(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], +# depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], in_chans=6, init_weight=False, **kwargs) +# return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) +# # def twins_svt_large_context(pretrained=False, **kwargs): +# # model_kwargs = dict( +# # patch_size=4, embed_dims=[128, 256], num_heads=[4, 8], mlp_ratios=[4, 4], +# # depths=[2, 2], wss=[7, 7], sr_ratios=[8, 4], in_chans=6, init_weight=False, **kwargs) +# # return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) diff --git a/modules/components/m2m_flow_former/LatentCostFormer/utils.py b/modules/components/m2m_flow_former/LatentCostFormer/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2f001b224ae7923fd22bf02269a5f259e0e163b9 --- /dev/null +++ b/modules/components/m2m_flow_former/LatentCostFormer/utils.py @@ -0,0 +1,101 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + elif mode == 'kitti400': + self._pad = [0, 0, 0, 400 - self.ht] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def indexing(img, coords, mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + """ + TODO: directly indexing features instead of sampling + """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True, mode='nearest') + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/modules/components/m2m_flow_former/__init__.py b/modules/components/m2m_flow_former/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a321314fce09c8707b01321be1a44b731fdd0c7f --- /dev/null +++ b/modules/components/m2m_flow_former/__init__.py @@ -0,0 +1 @@ +from .m2m import M2MFlowFormer diff --git a/modules/components/m2m_flow_former/__pycache__/__init__.cpython-310.pyc b/modules/components/m2m_flow_former/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..215f95756ab5705a354b49000db0b32cfe2f31bc Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/__init__.cpython-38.pyc b/modules/components/m2m_flow_former/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d34b4d783589d32e84abbe9d4b55e169159a736b Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/__init__.cpython-39.pyc b/modules/components/m2m_flow_former/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88484ee64d11f5584e64496185099587bac6b1fe Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-310.pyc b/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11699ec43e819443d07af025557b5c2188bbc1cb Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-38.pyc b/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..678963e5c7708a5210584110d485f0b935758900 Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-39.pyc b/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0971722af95a700a4dab965ad10b89429e45b931 Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/backwarp.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/cfg.cpython-310.pyc b/modules/components/m2m_flow_former/__pycache__/cfg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0a12ce938f21bbbb2a1c8c8e6810d3c8ad3aaac Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/cfg.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/cfg.cpython-38.pyc b/modules/components/m2m_flow_former/__pycache__/cfg.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74cd66b96f5d1b0982d11c78ba758dbc446e1dc4 Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/cfg.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/cfg.cpython-39.pyc b/modules/components/m2m_flow_former/__pycache__/cfg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88980723de083430aba0cb0432bd89c8d09b28a1 Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/cfg.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/m2m.cpython-310.pyc b/modules/components/m2m_flow_former/__pycache__/m2m.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..075ed2f23cb60da7914ec23b187134332792819e Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/m2m.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/m2m.cpython-38.pyc b/modules/components/m2m_flow_former/__pycache__/m2m.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..115f42cadfe03d30965e43e0abf2484c25385d16 Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/m2m.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/m2m.cpython-39.pyc b/modules/components/m2m_flow_former/__pycache__/m2m.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77b0687ff79f4861d141a0c765d4770201c39d5e Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/m2m.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-310.pyc b/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dc73a8996b70c016d7a2fae4c08ddcc9d6bef45 Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-38.pyc b/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57e95635ee7412377475987cdb743db6f8caf5f9 Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-39.pyc b/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dbb887f3cb71616769b5a2e323b7e30f812b599 Binary files /dev/null and b/modules/components/m2m_flow_former/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/m2m_flow_former/backwarp.py b/modules/components/m2m_flow_former/backwarp.py new file mode 100644 index 0000000000000000000000000000000000000000..e99a0a5c1b658e81536825451b865b39c45bc9c4 --- /dev/null +++ b/modules/components/m2m_flow_former/backwarp.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import torch + + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn:torch.Tensor, tenFlow:torch.Tensor): + if 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3]) not in objBackwarpcache: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) +# end diff --git a/modules/components/m2m_flow_former/cfg.py b/modules/components/m2m_flow_former/cfg.py new file mode 100644 index 0000000000000000000000000000000000000000..5596a98eec708a7e8380d04369097d32aa52ecb4 --- /dev/null +++ b/modules/components/m2m_flow_former/cfg.py @@ -0,0 +1,65 @@ +from yacs.config import CfgNode as CN +_CN = CN() + +_CN.name = 'default' +_CN.suffix ='sintel' +_CN.gamma = 0.85 +_CN.max_flow = 400 +_CN.batch_size = 6 +_CN.sum_freq = 100 +_CN.val_freq = 5000000 +_CN.image_size = [432, 960] +_CN.add_noise = True +_CN.critical_params = [] + +_CN.transformer = 'latentcostformer' +_CN.restore_ckpt = 'checkpoints/things.pth' + +# latentcostformer +_CN.latentcostformer = CN() +_CN.latentcostformer.pe = 'linear' +_CN.latentcostformer.dropout = 0.0 +_CN.latentcostformer.encoder_latent_dim = 256 # in twins, this is 256 +_CN.latentcostformer.query_latent_dim = 64 +_CN.latentcostformer.cost_latent_input_dim = 64 +_CN.latentcostformer.cost_latent_token_num = 8 +_CN.latentcostformer.cost_latent_dim = 128 +_CN.latentcostformer.arc_type = 'transformer' +_CN.latentcostformer.cost_heads_num = 1 +# encoder +_CN.latentcostformer.pretrain = True +_CN.latentcostformer.context_concat = False +_CN.latentcostformer.encoder_depth = 3 +_CN.latentcostformer.feat_cross_attn = False +_CN.latentcostformer.patch_size = 8 +_CN.latentcostformer.patch_embed = 'single' +_CN.latentcostformer.no_pe = False +_CN.latentcostformer.gma = "GMA" +_CN.latentcostformer.kernel_size = 9 +_CN.latentcostformer.rm_res = True +_CN.latentcostformer.vert_c_dim = 64 +_CN.latentcostformer.cost_encoder_res = True +_CN.latentcostformer.cnet = 'twins' +_CN.latentcostformer.fnet = 'twins' +_CN.latentcostformer.no_sc = False +_CN.latentcostformer.only_global = False +_CN.latentcostformer.add_flow_token = True +_CN.latentcostformer.use_mlp = False +_CN.latentcostformer.vertical_conv = False + +# decoder +_CN.latentcostformer.decoder_depth = 12 +_CN.latentcostformer.critical_params = ['cost_heads_num', 'vert_c_dim', 'cnet', 'pretrain' , 'add_flow_token', 'encoder_depth', 'gma', 'cost_encoder_res'] + +### TRAINER +_CN.trainer = CN() +_CN.trainer.scheduler = 'OneCycleLR' +_CN.trainer.optimizer = 'adamw' +_CN.trainer.canonical_lr = 12.5e-5 +_CN.trainer.adamw_decay = 1e-5 +_CN.trainer.clip = 1.0 +_CN.trainer.num_steps = 120000 +_CN.trainer.epsilon = 1e-8 +_CN.trainer.anneal_strategy = 'linear' +def get_cfg(): + return _CN.clone() diff --git a/modules/components/m2m_flow_former/costvol.py b/modules/components/m2m_flow_former/costvol.py new file mode 100644 index 0000000000000000000000000000000000000000..40e1cfb5b95f948321fb4429321dbf3dd48f9288 --- /dev/null +++ b/modules/components/m2m_flow_former/costvol.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +class costvol_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenOne, tenTwo): + tenOut = tenOne.new_empty([tenOne.shape[0], 81, tenOne.shape[2], tenOne.shape[3]]) + + cuda_launch(cuda_kernel('costvol_out', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_out( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue]); + } + } + + tenOut[intOffset] = fltValue / SIZE_1(tenOne); + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOut': tenOut + }))( + grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + self.save_for_backward(tenOne, tenTwo) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenOne, tenTwo = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None + tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenOnegrad is not None: + cuda_launch(cuda_kernel('costvol_onegrad', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_onegrad( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad); + const int intX = ( intIndex ) % SIZE_3(tenOnegrad); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenTwograd is not None: + cuda_launch(cuda_kernel('costvol_twograd', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_twograd( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd); + const int intX = ( intIndex ) % SIZE_3(tenTwograd); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], -tenOutgrad[intOffset] / SIZE_1(tenOne)); + } else { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], +tenOutgrad[intOffset] / SIZE_1(tenOne)); + } + } + } else { + // ... + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenOnegrad, tenTwograd, None, None + # end +# end diff --git a/modules/components/m2m_flow_former/m2m.py b/modules/components/m2m_flow_former/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7332a1a02950e9d4e5e0e342280c25392ec447 --- /dev/null +++ b/modules/components/m2m_flow_former/m2m.py @@ -0,0 +1,418 @@ + +import math +import torch +import typing + +from ..components import register +from .backwarp import * +from .LatentCostFormer.transformer import * +from .softsplat import * +from .cfg import get_cfg + + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + +########################################################## + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOut = 0 + tenNormalize = 0 + for idx in range(flow_num): + tenOutF, tenNormalizeF = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB, tenNormalizeB = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOut += tenOutF + tenOutB + tenNormalize += tenNormalizeF + tenNormalizeB + + return tenOut / tenNormalize, tenNormalize < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + torch.nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + torch.nn.PReLU(out_planes) + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + +class ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, 1, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, -1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0.repeat(1, self.branch, 1, 1), m1.repeat(1, self.branch, 1, 1) + + +@register('m2m_flowformer') +class M2MFlowFormer(torch.nn.Module): + def __init__(self, ratio=2): + super(M2MFlowFormer, self).__init__() + self.branch = 4 + self.ratio = ratio + cfg = get_cfg().latentcostformer + + self.netFlow = FlowFormer(cfg) + checkpoint = torch.load('./modules/components/m2m_flow_former/flowformer++.pth') + checkpoint_mod = {k.replace('module.', ''): checkpoint[k] for k in checkpoint.keys()} + self.netFlow.load_state_dict(checkpoint_mod, strict=False) + + # self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear', + align_corners=False) + flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear', + align_corners=False) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + M_splat = 1 / (1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + + def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs): + if ratio is None: + ratio = self.ratio + + intWidth = img0.shape[3] and img1.shape[3] + intHeight = img0.shape[2] and img1.shape[2] + + intPadr = int(((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16)) + intPadb = int(((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16)) + + img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate') + img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate') + + N_, C_, H_, W_ = img0.shape + + outputs = [] + result_dict = {} + + im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=1.0 / ratio, mode='bilinear', + align_corners=False) + im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=1.0 / ratio, mode='bilinear', + align_corners=False) + + tenFwds = self.netFlow(im0_, im1_) + tenBwds = self.netFlow(im1_, im0_) + + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + result_dict['flowfwd'] = torch.nn.functional.interpolate(tenFwds[-1], scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + result_dict['flowbwd'] = torch.nn.functional.interpolate(tenBwds[-1], scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + + outputs = [] + + for i in range(len(tenFwds)): + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwds[i], tenBwds[i], img0, img1, ratio) + + img0_ = im0_o.repeat(1, self.branch, 1, 1) + img1_ = im1_o.repeat(1, self.branch, 1, 1) + tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = time_step.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + img0_ = img0_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + img1_ = img1_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = self.get_splat_weight(img0_, img1_, tenFwd, tenBwd) * WeiMF + tenPhototwo = self.get_splat_weight(img1_, img0_, tenBwd, tenFwd) * WeiMB + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + img0_ = img0_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + img1_ = img1_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutput, mask = forwarp_mframe_mask(img0_, flow0, t1, img1_, flow1, t0, metric0, metric1) + + tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o) + + output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_ + outputs.append(output[:, :, :intHeight, :intWidth]) + result_dict['imgt_preds'] = outputs + result_dict['imgt_pred'] = outputs[-1] + + return result_dict diff --git a/modules/components/m2m_flow_former/softsplat.py b/modules/components/m2m_flow_former/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..02b25bd9158c3ac5c21db9835977458b0ae9e6c8 --- /dev/null +++ b/modules/components/m2m_flow_former/softsplat.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python + +######################################### +# This implementation is taken from +# https://github.com/sniklaus/softmax-splatting +######################################### + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str): + assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert(tenMetric is None) + if strMode == 'avg': assert(tenMetric is None) + if strMode.split('-')[0] == 'linear': assert(tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert(tenMetric is not None) + + if strMode == 'avg': + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split('-')[0] in ['avg', 'linear', 'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] == 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + + if tenIn.is_cuda == True: + cuda_launch(cuda_kernel('softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + elif tenIn.is_cuda != True: + assert(False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_empty([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_empty([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenIngrad is not None: + cuda_launch(cuda_kernel('softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenFlowgrad is not None: + cuda_launch(cuda_kernel('softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenIngrad, tenFlowgrad + # end +# end diff --git a/modules/components/m2m_pwc/__init__.py b/modules/components/m2m_pwc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52bf0be7340a5db7e9b8c9e383cda0c1506cdf8f --- /dev/null +++ b/modules/components/m2m_pwc/__init__.py @@ -0,0 +1 @@ +from .m2m import M2M_PWC diff --git a/modules/components/m2m_pwc/__pycache__/__init__.cpython-310.pyc b/modules/components/m2m_pwc/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..508722f7a22d2c2894cbdebe7d0f483e9075cf05 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/__init__.cpython-38.pyc b/modules/components/m2m_pwc/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af5bf3285155b8af099e35a6bf4dd8eb505ce2a1 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/__init__.cpython-39.pyc b/modules/components/m2m_pwc/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4edefdee083ff955414b95e15b61952b33b8689e Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/backwarp.cpython-310.pyc b/modules/components/m2m_pwc/__pycache__/backwarp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cb4117c0581de77c09f0e91cbc5266e5f358de6 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/backwarp.cpython-310.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/backwarp.cpython-38.pyc b/modules/components/m2m_pwc/__pycache__/backwarp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..743ccd5b733329d2d0ad4037d310e94d201e21d3 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/backwarp.cpython-38.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/backwarp.cpython-39.pyc b/modules/components/m2m_pwc/__pycache__/backwarp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97aac7ff38da7e409a28d376d4810e07998fbc1d Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/backwarp.cpython-39.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/costvol.cpython-310.pyc b/modules/components/m2m_pwc/__pycache__/costvol.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..500674f62e03672336b6ddae26a17e0439c233c2 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/costvol.cpython-310.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/costvol.cpython-38.pyc b/modules/components/m2m_pwc/__pycache__/costvol.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..615d38570fc2c5d5b55c474903c0fbed013f3fc4 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/costvol.cpython-38.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/costvol.cpython-39.pyc b/modules/components/m2m_pwc/__pycache__/costvol.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81e9054e51e2af58955017f3fa49c206882bc626 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/costvol.cpython-39.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/m2m.cpython-310.pyc b/modules/components/m2m_pwc/__pycache__/m2m.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae195e3ca9b96ea7ad90a43424cc36fb0c5397e Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/m2m.cpython-310.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/m2m.cpython-38.pyc b/modules/components/m2m_pwc/__pycache__/m2m.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d09c371c424c574cdf463e6ddde495a5ce43d3a1 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/m2m.cpython-38.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/m2m.cpython-39.pyc b/modules/components/m2m_pwc/__pycache__/m2m.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13464ef71659ac151adcd95989a21b794356822e Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/m2m.cpython-39.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-310.pyc b/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ca2bca319dc0015574b83b3c570fb3a65d2007d Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-310.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-38.pyc b/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a310ee054d3c875f10311a23fb6fda4f464217f Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-38.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-39.pyc b/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df4d997467cbe5c5767eb1025def06895f0f853d Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/pwcnet.cpython-39.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/softsplat.cpython-310.pyc b/modules/components/m2m_pwc/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa64f654976999e18a8f9f417fce1b3351be56df Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/softsplat.cpython-38.pyc b/modules/components/m2m_pwc/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d848e5b4f156cc139c2db564fd0ce502783cea99 Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/m2m_pwc/__pycache__/softsplat.cpython-39.pyc b/modules/components/m2m_pwc/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da7f9a79a19595208f990b2bb42b4678fd35a84e Binary files /dev/null and b/modules/components/m2m_pwc/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/m2m_pwc/backwarp.py b/modules/components/m2m_pwc/backwarp.py new file mode 100644 index 0000000000000000000000000000000000000000..e99a0a5c1b658e81536825451b865b39c45bc9c4 --- /dev/null +++ b/modules/components/m2m_pwc/backwarp.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import torch + + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn:torch.Tensor, tenFlow:torch.Tensor): + if 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3]) not in objBackwarpcache: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) +# end diff --git a/modules/components/m2m_pwc/costvol.py b/modules/components/m2m_pwc/costvol.py new file mode 100644 index 0000000000000000000000000000000000000000..40e1cfb5b95f948321fb4429321dbf3dd48f9288 --- /dev/null +++ b/modules/components/m2m_pwc/costvol.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +class costvol_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenOne, tenTwo): + tenOut = tenOne.new_empty([tenOne.shape[0], 81, tenOne.shape[2], tenOne.shape[3]]) + + cuda_launch(cuda_kernel('costvol_out', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_out( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue]); + } + } + + tenOut[intOffset] = fltValue / SIZE_1(tenOne); + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOut': tenOut + }))( + grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + self.save_for_backward(tenOne, tenTwo) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenOne, tenTwo = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None + tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenOnegrad is not None: + cuda_launch(cuda_kernel('costvol_onegrad', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_onegrad( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad); + const int intX = ( intIndex ) % SIZE_3(tenOnegrad); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenTwograd is not None: + cuda_launch(cuda_kernel('costvol_twograd', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_twograd( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd); + const int intX = ( intIndex ) % SIZE_3(tenTwograd); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], -tenOutgrad[intOffset] / SIZE_1(tenOne)); + } else { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], +tenOutgrad[intOffset] / SIZE_1(tenOne)); + } + } + } else { + // ... + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenOnegrad, tenTwograd, None, None + # end +# end diff --git a/modules/components/m2m_pwc/m2m.py b/modules/components/m2m_pwc/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..a00c9e74c955918f2cc3837619ec9b74730fd36f --- /dev/null +++ b/modules/components/m2m_pwc/m2m.py @@ -0,0 +1,368 @@ + +import math +import torch +import typing + +from ..components import register +from .backwarp import * +from .pwcnet import * +from .softsplat import * + + +########################################################## + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOut = 0 + tenNormalize = 0 + for idx in range(flow_num): + tenOutF, tenNormalizeF = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB, tenNormalizeB = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOut += tenOutF + tenOutB + tenNormalize += tenNormalizeF + tenNormalizeB + + return tenOut / tenNormalize, tenNormalize < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + torch.nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + torch.nn.PReLU(out_planes) + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + +class ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, 1, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, -1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0.repeat(1, self.branch, 1, 1), m1.repeat(1, self.branch, 1, 1) + + +@register('m2m_pwc') +class M2M_PWC(torch.nn.Module): + def __init__(self, ratio=4): + super(M2M_PWC, self).__init__() + self.branch = 4 + self.ratio = ratio + + self.netFlow = Network() + + self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear', + align_corners=False) + flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear', + align_corners=False) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs): + if ratio is None: + ratio = self.ratio + + intWidth = img0.shape[3] and img1.shape[3] + intHeight = img0.shape[2] and img1.shape[2] + + intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16) + intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16) + + img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate') + img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate') + + N_, C_, H_, W_ = img0.shape + + outputs = [] + result_dict = {} + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + + tenFwd, tenBwd = self.netFlow.bidir(im0_, im1_) + + result_dict['flowfwd'] = torch.nn.functional.interpolate(tenFwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + result_dict['flowbwd'] = torch.nn.functional.interpolate(tenBwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, img0, img1, ratio) + + img0 = im0_o.repeat(1, self.branch, 1, 1) + img1 = im1_o.repeat(1, self.branch, 1, 1) + tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = time_step.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = (1.0 - (WeiMF * (img0 - backwarp(img1, tenFwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + tenPhototwo = (1.0 - (WeiMB * (img1 - backwarp(img0, tenBwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = self.paramAlpha * tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = self.paramAlpha * tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutput, mask = forwarp_mframe_mask(img0, flow0, t1, img1, flow1, t0, metric0, metric1) + + tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o) + + output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_ + result_dict['imgt_pred'] = output[:, :, :intHeight, :intWidth] + + return result_dict diff --git a/modules/components/m2m_pwc/pwcnet.py b/modules/components/m2m_pwc/pwcnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3433cae85e172874fcfcb871a71cdbf04c86025b --- /dev/null +++ b/modules/components/m2m_pwc/pwcnet.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python + +import math +import torch +import typing + +from .backwarp import backwarp +from .costvol import costvol_func + + +########################################################## + + +class Basic(torch.nn.Module): + def __init__(self, strType:str, intChans:typing.List[int], objScratch:typing.Optional[typing.Dict]=None): + super().__init__() + + self.strType = strType + self.netEvenize = None + self.netMain = None + self.netShortcut = None + + intIn = intChans[0] + intOut = intChans[-1] + netMain = [] + intChans = intChans.copy() + fltStride = 1.0 + + for intPart, strPart in enumerate(self.strType.split('+')[0].split('-')): + if strPart.startswith('evenize') == True and intPart == 0: + class Evenize(torch.nn.Module): + def __init__(self, strPad): + super().__init__() + + self.strPad = strPad + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + intPad = [0, 0, 0, 0] + + if tenIn.shape[3] % 2 != 0: intPad[1] = 1 + if tenIn.shape[2] % 2 != 0: intPad[3] = 1 + + if min(intPad) != 0 or max(intPad) != 0: + tenIn = torch.nn.functional.pad(input=tenIn, pad=intPad, mode=self.strPad if self.strPad != 'zeros' else 'constant', value=0.0) + # end + + return tenIn + # end + # end + + strPad = 'zeros' + + if '(' in strPart: + if 'replpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'replicate' + if 'reflpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'reflect' + # end + + self.netEvenize = Evenize(strPad) + + elif strPart.startswith('conv') == True: + intKsize = 3 + intPad = 1 + strPad = 'zeros' + + if '(' in strPart: + intKsize = int(strPart.split('(')[1].split(')')[0].split(',')[0]) + intPad = int(math.floor(0.5 * (intKsize - 1))) + + if 'replpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'replicate' + if 'reflpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'reflect' + # end + + if 'nopad' in self.strType.split('+'): + intPad = 0 + # end + + netMain += [torch.nn.Conv2d(in_channels=intChans[0], out_channels=intChans[1], kernel_size=intKsize, stride=1, padding=intPad, padding_mode=strPad, bias='nobias' not in self.strType.split('+'))] + intChans = intChans[1:] + fltStride *= 1.0 + + elif strPart.startswith('sconv') == True: + intKsize = 3 + intPad = 1 + strPad = 'zeros' + + if '(' in strPart: + intKsize = int(strPart.split('(')[1].split(')')[0].split(',')[0]) + intPad = int(math.floor(0.5 * (intKsize - 1))) + + if 'replpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'replicate' + if 'reflpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'reflect' + # end + + if 'nopad' in self.strType.split('+'): + intPad = 0 + # end + + netMain += [torch.nn.Conv2d(in_channels=intChans[0], out_channels=intChans[1], kernel_size=intKsize, stride=2, padding=intPad, padding_mode=strPad, bias='nobias' not in self.strType.split('+'))] + intChans = intChans[1:] + fltStride *= 2.0 + + elif strPart.startswith('up') == True: + class Up(torch.nn.Module): + def __init__(self, strType): + super().__init__() + + self.strType = strType + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + if self.strType == 'nearest': + return torch.nn.functional.interpolate(input=tenIn, scale_factor=2.0, mode='nearest-exact', align_corners=False) + + elif self.strType == 'bilinear': + return torch.nn.functional.interpolate(input=tenIn, scale_factor=2.0, mode='bilinear', align_corners=False) + + elif self.strType == 'pyramid': + return pyramid(tenIn, None, 'up') + + elif self.strType == 'shuffle': + return torch.nn.functional.pixel_shuffle(tenIn, upscale_factor=2) # https://github.com/pytorch/pytorch/issues/62854 + + # end + + assert(False) # to make torchscript happy + # end + # end + + strType = 'bilinear' + + if '(' in strPart: + if 'nearest' in strPart.split('(')[1].split(')')[0].split(','): strType = 'nearest' + if 'pyramid' in strPart.split('(')[1].split(')')[0].split(','): strType = 'pyramid' + if 'shuffle' in strPart.split('(')[1].split(')')[0].split(','): strType = 'shuffle' + # end + + netMain += [Up(strType)] + fltStride *= 0.5 + + elif strPart.startswith('prelu') == True: + netMain += [torch.nn.PReLU(num_parameters=1, init=float(strPart.split('(')[1].split(')')[0].split(',')[0]))] + fltStride *= 1.0 + + elif True: + assert(False) + + # end + # end + + self.netMain = torch.nn.Sequential(*netMain) + + for strPart in self.strType.split('+')[1:]: + if strPart.startswith('skip') == True: + if intIn == intOut and fltStride == 1.0: + self.netShortcut = torch.nn.Identity() + + elif intIn != intOut and fltStride == 1.0: + self.netShortcut = torch.nn.Conv2d(in_channels=intIn, out_channels=intOut, kernel_size=1, stride=1, padding=0, bias='nobias' not in self.strType.split('+')) + + elif intIn == intOut and fltStride != 1.0: + class Down(torch.nn.Module): + def __init__(self, fltScale): + super().__init__() + + self.fltScale = fltScale + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + return torch.nn.functional.interpolate(input=tenIn, scale_factor=self.fltScale, mode='bilinear', align_corners=False) + # end + # end + + self.netShortcut = Down(1.0 / fltStride) + + elif intIn != intOut and fltStride != 1.0: + class Down(torch.nn.Module): + def __init__(self, fltScale): + super().__init__() + + self.fltScale = fltScale + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + return torch.nn.functional.interpolate(input=tenIn, scale_factor=self.fltScale, mode='bilinear', align_corners=False) + # end + # end + + self.netShortcut = torch.nn.Sequential(Down(1.0 / fltStride), torch.nn.Conv2d(in_channels=intIn, out_channels=intOut, kernel_size=1, stride=1, padding=0, bias='nobias' not in self.strType.split('+'))) + + # end + + elif strPart.startswith('...') == True: + pass + + # end + # end + + assert(len(intChans) == 1) + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + if self.netEvenize is not None: + tenIn = self.netEvenize(tenIn) + # end + + tenOut = self.netMain(tenIn) + + if self.netShortcut is not None: + tenOut = tenOut + self.netShortcut(tenIn) + # end + + return tenOut + # end +# end + + +########################################################## + + +class Network(torch.nn.Module): + def __init__(self): + super().__init__() + + class Extractor(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netOne = Basic('evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)', [3, 32, 32, 32], None) + self.netTwo = Basic('evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)', [32, 32, 32, 32], None) + self.netThr = Basic('evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)', [32, 32, 32, 32], None) + # end + + def forward(self, tenIn): + tenOne = self.netOne(tenIn) + tenTwo = self.netTwo(tenOne) + tenThr = self.netThr(tenTwo) + tenFou = torch.nn.functional.avg_pool2d(input=tenThr, kernel_size=2, stride=2, count_include_pad=False) + tenFiv = torch.nn.functional.avg_pool2d(input=tenFou, kernel_size=2, stride=2, count_include_pad=False) + + return [tenOne, tenTwo, tenThr, tenFou, tenFiv] + # end + # end + + class Decoder(torch.nn.Module): + def __init__(self, intChannels): + super().__init__() + + self.netCostacti = torch.nn.PReLU(num_parameters=1, init=0.25) + self.netMain = Basic('conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)', [intChannels, 128, 128, 96, 64, 32, 2], None) + # end + + def forward(self, tenOne, tenTwo, tenFlow): + if tenFlow is not None: + tenFlow = 2.0 * torch.nn.functional.interpolate(input=tenFlow, scale_factor=2.0, mode='bilinear', align_corners=False) + # end + + tenMain = [] + + if tenFlow is None: + tenMain.append(tenOne) + tenMain.append(self.netCostacti(costvol_func.apply(tenOne, tenTwo))) + + elif tenFlow is not None: + tenMain.append(tenOne) + tenMain.append(self.netCostacti(costvol_func.apply(tenOne, backwarp(tenTwo, tenFlow.detach())))) + tenMain.append(tenFlow) + + # end + + return (tenFlow if tenFlow is not None else 0.0) + self.netMain(torch.cat(tenMain, 1)) + # end + # end + + self.netExtractor = Extractor() + + self.netFiv = Decoder(32 + 81 + 0) + self.netFou = Decoder(32 + 81 + 2) + self.netThr = Decoder(32 + 81 + 2) + self.netTwo = Decoder(32 + 81 + 2) + self.netOne = Decoder(32 + 81 + 2) + + self.load_state_dict(torch.load('./modules/components/m2m_pwc/pwc.pth')) + # end + + def bidir(self, tenOne, tenTwo): + intWidth = tenOne.shape[3] and tenTwo.shape[3] + intHeight = tenOne.shape[2] and tenTwo.shape[2] + + tenOne, tenTwo = list(zip(*[torch.split(tenFeat, [tenOne.shape[0], tenTwo.shape[0]], 0) for tenFeat in self.netExtractor(torch.cat([tenOne, tenTwo], 0))])) + + tenFwd = None + tenFwd = self.netFiv(tenOne[-1], tenTwo[-1], tenFwd) + tenFwd = self.netFou(tenOne[-2], tenTwo[-2], tenFwd) + tenFwd = self.netThr(tenOne[-3], tenTwo[-3], tenFwd) + tenFwd = self.netTwo(tenOne[-4], tenTwo[-4], tenFwd) + tenFwd = self.netOne(tenOne[-5], tenTwo[-5], tenFwd) + + tenBwd = None + tenBwd = self.netFiv(tenTwo[-1], tenOne[-1], tenBwd) + tenBwd = self.netFou(tenTwo[-2], tenOne[-2], tenBwd) + tenBwd = self.netThr(tenTwo[-3], tenOne[-3], tenBwd) + tenBwd = self.netTwo(tenTwo[-4], tenOne[-4], tenBwd) + tenBwd = self.netOne(tenTwo[-5], tenOne[-5], tenBwd) + + return tenFwd, tenBwd + # end +# end diff --git a/modules/components/m2m_pwc/softsplat.py b/modules/components/m2m_pwc/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..02b25bd9158c3ac5c21db9835977458b0ae9e6c8 --- /dev/null +++ b/modules/components/m2m_pwc/softsplat.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python + +######################################### +# This implementation is taken from +# https://github.com/sniklaus/softmax-splatting +######################################### + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str): + assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert(tenMetric is None) + if strMode == 'avg': assert(tenMetric is None) + if strMode.split('-')[0] == 'linear': assert(tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert(tenMetric is not None) + + if strMode == 'avg': + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split('-')[0] in ['avg', 'linear', 'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] == 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + + if tenIn.is_cuda == True: + cuda_launch(cuda_kernel('softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + elif tenIn.is_cuda != True: + assert(False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_empty([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_empty([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenIngrad is not None: + cuda_launch(cuda_kernel('softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenFlowgrad is not None: + cuda_launch(cuda_kernel('softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenIngrad, tenFlowgrad + # end +# end diff --git a/modules/components/m2m_unimatch/__init__.py b/modules/components/m2m_unimatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52bf0be7340a5db7e9b8c9e383cda0c1506cdf8f --- /dev/null +++ b/modules/components/m2m_unimatch/__init__.py @@ -0,0 +1 @@ +from .m2m import M2M_PWC diff --git a/modules/components/m2m_unimatch/__pycache__/__init__.cpython-310.pyc b/modules/components/m2m_unimatch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15284868739b49eff7c4c09d09ba5705bd5434f7 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/__init__.cpython-38.pyc b/modules/components/m2m_unimatch/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cfd64ea12d5be3b7879d55a91e95db953a37374 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/__init__.cpython-39.pyc b/modules/components/m2m_unimatch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d3095d8ab2ae36e0519bb1518f502fa98cd8f64 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-310.pyc b/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a653bbab22af426fa8cb18949933805d01c2ac65 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-38.pyc b/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01b02a04d2850d7a0d9ecf79a059942ea9f34bec Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-39.pyc b/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30e655da93632ded29f20fa83c3664d8a5983926 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/backwarp.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/costvol.cpython-310.pyc b/modules/components/m2m_unimatch/__pycache__/costvol.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6009bc64583ded53ccc4a9b02940604e64f0078 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/costvol.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/costvol.cpython-38.pyc b/modules/components/m2m_unimatch/__pycache__/costvol.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..953070a4bcca514754d61ff7a5634f197a40269a Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/costvol.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/costvol.cpython-39.pyc b/modules/components/m2m_unimatch/__pycache__/costvol.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ab319b626cbf7d6de61631c9197a7058f5e7217 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/costvol.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/m2m.cpython-310.pyc b/modules/components/m2m_unimatch/__pycache__/m2m.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..762e28d99a84221aeec6653dff103e25440a4d3f Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/m2m.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/m2m.cpython-38.pyc b/modules/components/m2m_unimatch/__pycache__/m2m.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..130a036a045b979fe4900e14cd2a1709c826ec48 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/m2m.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/m2m.cpython-39.pyc b/modules/components/m2m_unimatch/__pycache__/m2m.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..063b1ff2a30f811d1a7b254969a6988710907524 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/m2m.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-310.pyc b/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e200a124600d38f29e4a376326cb2b30c025d6a2 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-38.pyc b/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eb8589c9bc30a9a646a9d756110bca25f2f7fe4 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-39.pyc b/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e333120ef4323c2350963d5d8df8f42a8efd93cb Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/pwcnet.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-310.pyc b/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc2affb0d4b5af26701897374f0f85a0157974ff Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-38.pyc b/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2392d6073b14c502a52569078f4a9785b1b1034a Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-39.pyc b/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bd77f955e012299c9349a6710fc633a6f385b70 Binary files /dev/null and b/modules/components/m2m_unimatch/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/backwarp.py b/modules/components/m2m_unimatch/backwarp.py new file mode 100644 index 0000000000000000000000000000000000000000..e99a0a5c1b658e81536825451b865b39c45bc9c4 --- /dev/null +++ b/modules/components/m2m_unimatch/backwarp.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import torch + + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn:torch.Tensor, tenFlow:torch.Tensor): + if 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3]) not in objBackwarpcache: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) +# end diff --git a/modules/components/m2m_unimatch/costvol.py b/modules/components/m2m_unimatch/costvol.py new file mode 100644 index 0000000000000000000000000000000000000000..40e1cfb5b95f948321fb4429321dbf3dd48f9288 --- /dev/null +++ b/modules/components/m2m_unimatch/costvol.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +class costvol_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenOne, tenTwo): + tenOut = tenOne.new_empty([tenOne.shape[0], 81, tenOne.shape[2], tenOne.shape[3]]) + + cuda_launch(cuda_kernel('costvol_out', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_out( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue]); + } + } + + tenOut[intOffset] = fltValue / SIZE_1(tenOne); + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOut': tenOut + }))( + grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + self.save_for_backward(tenOne, tenTwo) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenOne, tenTwo = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None + tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenOnegrad is not None: + cuda_launch(cuda_kernel('costvol_onegrad', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_onegrad( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad); + const int intX = ( intIndex ) % SIZE_3(tenOnegrad); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenTwograd is not None: + cuda_launch(cuda_kernel('costvol_twograd', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_twograd( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd); + const int intX = ( intIndex ) % SIZE_3(tenTwograd); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], -tenOutgrad[intOffset] / SIZE_1(tenOne)); + } else { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], +tenOutgrad[intOffset] / SIZE_1(tenOne)); + } + } + } else { + // ... + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenOnegrad, tenTwograd, None, None + # end +# end diff --git a/modules/components/m2m_unimatch/m2m.py b/modules/components/m2m_unimatch/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..2a0765ef80e07386eb4a6bda641b1c572062652a --- /dev/null +++ b/modules/components/m2m_unimatch/m2m.py @@ -0,0 +1,423 @@ + +import math +import torch +import typing + +from ..components import register +from .backwarp import * +from .pwcnet import * +from .softsplat import * +from .unimatch.unimatch import UniMatch + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + +########################################################## + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOut = 0 + tenNormalize = 0 + for idx in range(flow_num): + tenOutF, tenNormalizeF = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB, tenNormalizeB = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOut += tenOutF + tenOutB + tenNormalize += tenNormalizeF + tenNormalizeB + + return tenOut / tenNormalize, tenNormalize < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + torch.nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + torch.nn.PReLU(out_planes) + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + +class ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, 1, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, -1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0.repeat(1, self.branch, 1, 1), m1.repeat(1, self.branch, 1, 1) + + +@register('m2m_unimatch') +class M2M_PWC(torch.nn.Module): + def __init__(self, ratio=4): + super(M2M_PWC, self).__init__() + self.branch = 4 + self.ratio = ratio + + self.netFlow = UniMatch(num_scales=2, feature_channels=128, upsample_factor=4, + num_head=1, ffn_dim_expansion=4, num_transformer_layers=6, + reg_refine=True, task='flow') + for p in self.netFlow.parameters(): + p.requires_grad = False + + # self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear', + align_corners=False) + flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear', + align_corners=False) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + M_splat = 1 / (1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + + def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs): + if ratio is None: + ratio = self.ratio + + intWidth = img0.shape[3] and img1.shape[3] + intHeight = img0.shape[2] and img1.shape[2] + + intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16) + intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16) + + img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate') + img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate') + + N_, C_, H_, W_ = img0.shape + + outputs = [] + result_dict = {} + + im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=1.0 / ratio, mode='bilinear', + align_corners=False) + im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=1.0 / ratio, mode='bilinear', + align_corners=False) + + flow_preds = self.netFlow(im0_, im1_, 'swin', [2, 8], [-1, 4], [-1, 1], 6, True) + tenFwds, tenBwds = [], [] + for flow_pred in flow_preds: + tenFwd, tenBwd = torch.chunk(flow_pred, 2, dim=0) + tenFwds.append(tenFwd) + tenBwds.append(tenBwd) + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + result_dict['flowfwd'] = torch.nn.functional.interpolate(tenFwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + result_dict['flowbwd'] = torch.nn.functional.interpolate(tenBwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + + for i in range(len(tenFwds)): + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwds[i], tenBwds[i], img0, img1, ratio) + + img0_ = im0_o.repeat(1, self.branch, 1, 1) + img1_ = im1_o.repeat(1, self.branch, 1, 1) + tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = time_step.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + img0_ = img0_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + img1_ = img1_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = self.get_splat_weight(img0_, img1_, tenFwd, tenBwd) * WeiMF + tenPhototwo = self.get_splat_weight(img1_, img0_, tenBwd, tenFwd) * WeiMB + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + img0_ = img0_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + img1_ = img1_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutput, mask = forwarp_mframe_mask(img0_, flow0, t1, img1_, flow1, t0, metric0, metric1) + + tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o) + + output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_ + outputs.append(output[:, :, :intHeight, :intWidth]) + result_dict['imgt_preds'] = outputs + result_dict['imgt_pred'] = outputs[-1] + tenFwds.append(tenFwd.reshape(N_, self.branch, 2, H_, W_)) + tenBwds.append(tenBwd.reshape(N_, self.branch, 2, H_, W_)) + result_dict['flow0_pred'] = tenFwds[::-1] + result_dict['flow1_pred'] = tenBwds[::-1] + + return result_dict + diff --git a/modules/components/m2m_unimatch/pwcnet.py b/modules/components/m2m_unimatch/pwcnet.py new file mode 100644 index 0000000000000000000000000000000000000000..77d2c736e4d5262858b51667931454adb8f23a65 --- /dev/null +++ b/modules/components/m2m_unimatch/pwcnet.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python + +import math +import torch +import typing + +from .backwarp import backwarp +from .costvol import costvol_func + + +########################################################## + + +class Basic(torch.nn.Module): + def __init__(self, strType:str, intChans:typing.List[int], objScratch:typing.Optional[typing.Dict]=None): + super().__init__() + + self.strType = strType + self.netEvenize = None + self.netMain = None + self.netShortcut = None + + intIn = intChans[0] + intOut = intChans[-1] + netMain = [] + intChans = intChans.copy() + fltStride = 1.0 + + for intPart, strPart in enumerate(self.strType.split('+')[0].split('-')): + if strPart.startswith('evenize') == True and intPart == 0: + class Evenize(torch.nn.Module): + def __init__(self, strPad): + super().__init__() + + self.strPad = strPad + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + intPad = [0, 0, 0, 0] + + if tenIn.shape[3] % 2 != 0: intPad[1] = 1 + if tenIn.shape[2] % 2 != 0: intPad[3] = 1 + + if min(intPad) != 0 or max(intPad) != 0: + tenIn = torch.nn.functional.pad(input=tenIn, pad=intPad, mode=self.strPad if self.strPad != 'zeros' else 'constant', value=0.0) + # end + + return tenIn + # end + # end + + strPad = 'zeros' + + if '(' in strPart: + if 'replpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'replicate' + if 'reflpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'reflect' + # end + + self.netEvenize = Evenize(strPad) + + elif strPart.startswith('conv') == True: + intKsize = 3 + intPad = 1 + strPad = 'zeros' + + if '(' in strPart: + intKsize = int(strPart.split('(')[1].split(')')[0].split(',')[0]) + intPad = int(math.floor(0.5 * (intKsize - 1))) + + if 'replpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'replicate' + if 'reflpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'reflect' + # end + + if 'nopad' in self.strType.split('+'): + intPad = 0 + # end + + netMain += [torch.nn.Conv2d(in_channels=intChans[0], out_channels=intChans[1], kernel_size=intKsize, stride=1, padding=intPad, padding_mode=strPad, bias='nobias' not in self.strType.split('+'))] + intChans = intChans[1:] + fltStride *= 1.0 + + elif strPart.startswith('sconv') == True: + intKsize = 3 + intPad = 1 + strPad = 'zeros' + + if '(' in strPart: + intKsize = int(strPart.split('(')[1].split(')')[0].split(',')[0]) + intPad = int(math.floor(0.5 * (intKsize - 1))) + + if 'replpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'replicate' + if 'reflpad' in strPart.split('(')[1].split(')')[0].split(','): strPad = 'reflect' + # end + + if 'nopad' in self.strType.split('+'): + intPad = 0 + # end + + netMain += [torch.nn.Conv2d(in_channels=intChans[0], out_channels=intChans[1], kernel_size=intKsize, stride=2, padding=intPad, padding_mode=strPad, bias='nobias' not in self.strType.split('+'))] + intChans = intChans[1:] + fltStride *= 2.0 + + elif strPart.startswith('up') == True: + class Up(torch.nn.Module): + def __init__(self, strType): + super().__init__() + + self.strType = strType + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + if self.strType == 'nearest': + return torch.nn.functional.interpolate(input=tenIn, scale_factor=2.0, mode='nearest-exact', align_corners=False) + + elif self.strType == 'bilinear': + return torch.nn.functional.interpolate(input=tenIn, scale_factor=2.0, mode='bilinear', align_corners=False) + + elif self.strType == 'pyramid': + return pyramid(tenIn, None, 'up') + + elif self.strType == 'shuffle': + return torch.nn.functional.pixel_shuffle(tenIn, upscale_factor=2) # https://github.com/pytorch/pytorch/issues/62854 + + # end + + assert(False) # to make torchscript happy + # end + # end + + strType = 'bilinear' + + if '(' in strPart: + if 'nearest' in strPart.split('(')[1].split(')')[0].split(','): strType = 'nearest' + if 'pyramid' in strPart.split('(')[1].split(')')[0].split(','): strType = 'pyramid' + if 'shuffle' in strPart.split('(')[1].split(')')[0].split(','): strType = 'shuffle' + # end + + netMain += [Up(strType)] + fltStride *= 0.5 + + elif strPart.startswith('prelu') == True: + netMain += [torch.nn.PReLU(num_parameters=1, init=float(strPart.split('(')[1].split(')')[0].split(',')[0]))] + fltStride *= 1.0 + + elif True: + assert(False) + + # end + # end + + self.netMain = torch.nn.Sequential(*netMain) + + for strPart in self.strType.split('+')[1:]: + if strPart.startswith('skip') == True: + if intIn == intOut and fltStride == 1.0: + self.netShortcut = torch.nn.Identity() + + elif intIn != intOut and fltStride == 1.0: + self.netShortcut = torch.nn.Conv2d(in_channels=intIn, out_channels=intOut, kernel_size=1, stride=1, padding=0, bias='nobias' not in self.strType.split('+')) + + elif intIn == intOut and fltStride != 1.0: + class Down(torch.nn.Module): + def __init__(self, fltScale): + super().__init__() + + self.fltScale = fltScale + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + return torch.nn.functional.interpolate(input=tenIn, scale_factor=self.fltScale, mode='bilinear', align_corners=False) + # end + # end + + self.netShortcut = Down(1.0 / fltStride) + + elif intIn != intOut and fltStride != 1.0: + class Down(torch.nn.Module): + def __init__(self, fltScale): + super().__init__() + + self.fltScale = fltScale + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + return torch.nn.functional.interpolate(input=tenIn, scale_factor=self.fltScale, mode='bilinear', align_corners=False) + # end + # end + + self.netShortcut = torch.nn.Sequential(Down(1.0 / fltStride), torch.nn.Conv2d(in_channels=intIn, out_channels=intOut, kernel_size=1, stride=1, padding=0, bias='nobias' not in self.strType.split('+'))) + + # end + + elif strPart.startswith('...') == True: + pass + + # end + # end + + assert(len(intChans) == 1) + # end + + def forward(self, tenIn:torch.Tensor) -> torch.Tensor: + if self.netEvenize is not None: + tenIn = self.netEvenize(tenIn) + # end + + tenOut = self.netMain(tenIn) + + if self.netShortcut is not None: + tenOut = tenOut + self.netShortcut(tenIn) + # end + + return tenOut + # end +# end + + +########################################################## + + +class Network(torch.nn.Module): + def __init__(self): + super().__init__() + + class Extractor(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netOne = Basic('evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)', [3, 32, 32, 32], None) + self.netTwo = Basic('evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)', [32, 32, 32, 32], None) + self.netThr = Basic('evenize(replpad)-sconv(2)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)', [32, 32, 32, 32], None) + # end + + def forward(self, tenIn): + tenOne = self.netOne(tenIn) + tenTwo = self.netTwo(tenOne) + tenThr = self.netThr(tenTwo) + tenFou = torch.nn.functional.avg_pool2d(input=tenThr, kernel_size=2, stride=2, count_include_pad=False) + tenFiv = torch.nn.functional.avg_pool2d(input=tenFou, kernel_size=2, stride=2, count_include_pad=False) + + return [tenOne, tenTwo, tenThr, tenFou, tenFiv] + # end + # end + + class Decoder(torch.nn.Module): + def __init__(self, intChannels): + super().__init__() + + self.netCostacti = torch.nn.PReLU(num_parameters=1, init=0.25) + self.netMain = Basic('conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)-prelu(0.25)-conv(3,replpad)', [intChannels, 128, 128, 96, 64, 32, 2], None) + # end + + def forward(self, tenOne, tenTwo, tenFlow): + if tenFlow is not None: + tenFlow = 2.0 * torch.nn.functional.interpolate(input=tenFlow, scale_factor=2.0, mode='bilinear', align_corners=False) + # end + + tenMain = [] + + if tenFlow is None: + tenMain.append(tenOne) + tenMain.append(self.netCostacti(costvol_func.apply(tenOne, tenTwo))) + + elif tenFlow is not None: + tenMain.append(tenOne) + tenMain.append(self.netCostacti(costvol_func.apply(tenOne, backwarp(tenTwo, tenFlow.detach())))) + tenMain.append(tenFlow) + + # end + + return (tenFlow if tenFlow is not None else 0.0) + self.netMain(torch.cat(tenMain, 1)) + # end + # end + + self.netExtractor = Extractor() + + self.netFiv = Decoder(32 + 81 + 0) + self.netFou = Decoder(32 + 81 + 2) + self.netThr = Decoder(32 + 81 + 2) + self.netTwo = Decoder(32 + 81 + 2) + self.netOne = Decoder(32 + 81 + 2) + + self.load_state_dict(torch.load('./modules/components/m2m_pwc/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth')) + # end + + def bidir(self, tenOne, tenTwo): + intWidth = tenOne.shape[3] and tenTwo.shape[3] + intHeight = tenOne.shape[2] and tenTwo.shape[2] + + tenOne, tenTwo = list(zip(*[torch.split(tenFeat, [tenOne.shape[0], tenTwo.shape[0]], 0) for tenFeat in self.netExtractor(torch.cat([tenOne, tenTwo], 0))])) + + tenFwd = None + tenFwd = self.netFiv(tenOne[-1], tenTwo[-1], tenFwd) + tenFwd = self.netFou(tenOne[-2], tenTwo[-2], tenFwd) + tenFwd = self.netThr(tenOne[-3], tenTwo[-3], tenFwd) + tenFwd = self.netTwo(tenOne[-4], tenTwo[-4], tenFwd) + tenFwd = self.netOne(tenOne[-5], tenTwo[-5], tenFwd) + + tenBwd = None + tenBwd = self.netFiv(tenTwo[-1], tenOne[-1], tenBwd) + tenBwd = self.netFou(tenTwo[-2], tenOne[-2], tenBwd) + tenBwd = self.netThr(tenTwo[-3], tenOne[-3], tenBwd) + tenBwd = self.netTwo(tenTwo[-4], tenOne[-4], tenBwd) + tenBwd = self.netOne(tenTwo[-5], tenOne[-5], tenBwd) + + return tenFwd, tenBwd + # end +# end diff --git a/modules/components/m2m_unimatch/softsplat.py b/modules/components/m2m_unimatch/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..02b25bd9158c3ac5c21db9835977458b0ae9e6c8 --- /dev/null +++ b/modules/components/m2m_unimatch/softsplat.py @@ -0,0 +1,534 @@ +#!/usr/bin/env python + +######################################### +# This implementation is taken from +# https://github.com/sniklaus/softmax-splatting +######################################### + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +def softsplat(tenIn:torch.Tensor, tenFlow:torch.Tensor, tenMetric:torch.Tensor, strMode:str): + assert(strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert(tenMetric is None) + if strMode == 'avg': assert(tenMetric is None) + if strMode.split('-')[0] == 'linear': assert(tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert(tenMetric is not None) + + if strMode == 'avg': + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split('-')[0] in ['avg', 'linear', 'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] == 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + + if tenIn.is_cuda == True: + cuda_launch(cuda_kernel('softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + elif tenIn.is_cuda != True: + assert(False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_empty([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_empty([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenIngrad is not None: + cuda_launch(cuda_kernel('softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), tenIngrad.data_ptr(), None], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenFlowgrad is not None: + cuda_launch(cuda_kernel('softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), None, tenFlowgrad.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenIngrad, tenFlowgrad + # end +# end diff --git a/modules/components/m2m_unimatch/unimatch/__init__.py b/modules/components/m2m_unimatch/unimatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92b40a7a1b2f79674cfe771258a0e4df71a17b3e Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d66fba82ce5fc739673c7e64fdbb2139c8bdcc93 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec0c4364b8b9bc01268316f66799f4bfa62cb4cc Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45227ed364aa92a618f65b3e4348a932a8c3cdd9 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77cd56b27c7b25d67833c4fe9f93f421e6b42cf7 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbf500efe459d762fd6cb3d2bf4b68d66aeb4e40 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/attention.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fb32a11deb53a7daf31bef92969a612f438aee6 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23966a92d87eaa25f73b8d2f4b860294bda2325a Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..147c589790fead624ef34b1dc2afc2d7c3ebc5a2 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/backbone.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..252559bbcc7cabef4765fc0089a94d29225d2161 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ded90f680178eca52b363a7373a0be436a4a663c Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7060296fd53fb91c906e9df3dc49bc353c56993 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/geometry.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24c6288b0da90409870ef65683d99cab7064b5ac Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99a9bc892da7658c0c130bba761a93f455f367c6 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87e31bc8b0c837318fb41ec74dffe812c42db37c Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/matching.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4ddc95d18720b2b4d4d103e4474035d314199e1 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe70da0fdfdd425bf89b591d6b5bf40bbc25d456 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7afbed2eef6597c1896e6764ca5184c3deb3d2f9 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/position.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dad449704d9aaf6b29500b3b375ad2157d3dc71 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74dc09c8c049e1a1f21714d277f20cb5a4e3fd2a Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90ed9cdefe8fac6a63eb292758c16d3fa1400427 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/reg_refine.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb578c7fdd43e7df2e3af0fffe0659bb37f85635 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ccb66c551dfeaa83340d2f1a5cae82832acaff97 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10e1b21ec73fc1187bd29ecae1904603a22f0453 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/transformer.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aec44f17d6a39a17beb9fbddd2861969a0ee04e4 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e00a2bb9a00b4342c3dd27fcb630d83f5700a1ef Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e67fbf8a632919539373f7af94f956b6770297 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/trident_conv.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..810c7c2ce551f315977c17aad5b6e3f8d10ff328 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a2d5b7f016899e48f75511c3269e4de8ef7ff6c Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a20c6312c81e523bac7663a0102cbd0f71199526 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/unimatch.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-310.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47f07f62e9ceda5f8159478640c0c835cba1d781 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-310.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-38.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bb032e821a146d5651ea732ed96e8064c376656 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-38.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-39.pyc b/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16e561ac120ae951b8bab3e83342710a8e8b80f8 Binary files /dev/null and b/modules/components/m2m_unimatch/unimatch/__pycache__/utils.cpython-39.pyc differ diff --git a/modules/components/m2m_unimatch/unimatch/attention.py b/modules/components/m2m_unimatch/unimatch/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..92a3c878afe541753022ba85c43b5b2e86e4d254 --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/attention.py @@ -0,0 +1,253 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def single_head_full_attention_1d(q, k, v, + h=None, + w=None, + ): + # q, k, v: [B, L, C] + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W] + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C] + + return out + + +def single_head_split_window_attention(q, k, v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, channel_last=True) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +def single_head_split_window_attention_1d(q, k, v, + relative_position_bias=None, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # q, k, v: [B, L, C] + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * h + + window_size_w = w // num_splits + + q = q.view(b * h, w, c) # [B*H, W, C] + k = k.view(b * h, w, c) + v = v.view(b * h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=-shift_size_w, dims=1) + k = torch.roll(k, shifts=-shift_size_w, dims=1) + v = torch.roll(v, shifts=-shift_size_w, dims=1) + + q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C] + k = split_feature_1d(k, num_splits=num_splits) + v = split_feature_1d(v, num_splits=num_splits) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*H*K, W/K, W/K] + + if with_shift: + # attn_mask: [K, W/K, W/K] + scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K] + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C] + + out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=shift_size_w, dims=2) + + out = out.view(b, -1, c) + + return out + + +class SelfAttnPropagation(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__(self, in_channels, + **kwargs, + ): + super(SelfAttnPropagation, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn(feature0, flow, + local_window_radius=local_window_radius) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] + key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn(self, feature0, flow, + local_window_radius=1, + ): + assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + value_channel = flow.size(1) + + feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) + ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) + + feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, + padding=local_window_radius) # [B, C*(2R+1)^2), H*W] + + feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( + 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold(flow, kernel_size=kernel_size, + padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + + flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute( + 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, flow_window).view(b, h, w, value_channel + ).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + return out diff --git a/modules/components/m2m_unimatch/unimatch/backbone.py b/modules/components/m2m_unimatch/unimatch/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..a30942eca9cad56e75252c3026dca95bf1021df7 --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/backbone.py @@ -0,0 +1,117 @@ +import torch.nn as nn + +from .trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, stride=stride, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__(self, output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer(feature_dims[2], stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/modules/components/m2m_unimatch/unimatch/geometry.py b/modules/components/m2m_unimatch/unimatch/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..775a95783aeee66a44e6290525de94909af648df --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/geometry.py @@ -0,0 +1,195 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +def back_project(depth, intrinsics): + # Back project 2D pixel coords to 3D points + # depth: [B, H, W] + # intrinsics: [B, 3, 3] + b, h, w = depth.shape + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + + intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] + + points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] + + return points + + +def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): + # Transform 3D points from reference camera to target camera + # points_ref: [B, 3, H, W] + # extrinsics_ref: [B, 4, 4] + # extrinsics_tgt: [B, 4, 4] + # extrinsics_rel: [B, 4, 4], relative pose transform + b, _, h, w = points_ref.shape + + if extrinsics_rel is None: + extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] + + points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], + points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W] + + points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] + + return points_tgt + + +def reproject(points_tgt, intrinsics, return_mask=False): + # reproject to target view + # points_tgt: [B, 3, H, W] + # intrinsics: [B, 3, 3] + + b, _, h, w = points_tgt.shape + + proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] + + X = proj_points[:, 0] + Y = proj_points[:, 1] + Z = proj_points[:, 2].clamp(min=1e-3) + + pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale + + if return_mask: + # valid mask in pixel space + mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & ( + pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W] + + return pixel_coords, mask + + return pixel_coords + + +def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + # Compute reprojection sample coords + points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] + points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) + + if return_mask: + reproj_coords, mask = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords, mask + + reproj_coords = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords + + +def compute_flow_with_depth_pose(depth_ref, intrinsics, + extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + b, h, w = depth_ref.shape + coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] + + if return_mask: + reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + rigid_flow = reproj_coords - coords_init + + return rigid_flow, mask + + reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + + rigid_flow = reproj_coords - coords_init + + return rigid_flow diff --git a/modules/components/m2m_unimatch/unimatch/matching.py b/modules/components/m2m_unimatch/unimatch/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..595437f2307202ab36d7c2ee3dfa0ab44e4dc830 --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/matching.py @@ -0,0 +1,279 @@ +import torch +import torch.nn.functional as F + +from .geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax(feature0, feature1, + pred_bidir_flow=False, + ): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax(feature0, feature1, local_radius, + padding_mode='zeros', + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob + + +def local_correlation_with_flow(feature0, feature1, + flow, + local_radius, + padding_mode='zeros', + dilation=1, + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2] + + # flow can be zero when using features after transformer + if not isinstance(flow, float): + sample_coords = sample_coords + flow.view( + b, 2, -1).permute(0, 2, 1).unsqueeze(-2) # [B, H*W, (2R+1)^2, 2] + else: + assert flow == 0. + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W] + + return corr + + +def global_correlation_softmax_stereo(feature0, feature1, + ): + # global correlation on horizontal direction + b, c, h, w = feature0.shape + + x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W] + + feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C] + feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W] + + correlation = torch.matmul(feature0, feature1) / (c ** 0.5) # [B, H, W, W] + + # mask subsequent positions to make disparity positive + mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W] + valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W] + + correlation[~valid_mask] = -1e9 + + prob = F.softmax(correlation, dim=-1) # [B, H, W, W] + + correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W] + + # NOTE: unlike flow, disparity is typically positive + disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W] + + return disparity.unsqueeze(1), prob # feature resolution + + +def local_correlation_softmax_stereo(feature0, feature1, local_radius, + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2] + + local_h = 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(0, 0, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode='zeros', align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)] + feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)] + + correspondence = torch.matmul(prob.unsqueeze(-2), + sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + flow = correspondence - coords_init # flow at feature resolution + match_prob = prob + + flow_x = -flow[:, :1] # [B, 1, H, W] + + return flow_x, match_prob + + +def correlation_softmax_depth(feature0, feature1, + intrinsics, + pose, + depth_candidates, + depth_from_argmax=False, + pred_bidir_depth=False, + ): + b, c, h, w = feature0.size() + assert depth_candidates.dim() == 4 # [B, D, H, W] + scale_factor = c ** 0.5 + + if pred_bidir_depth: + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + intrinsics = intrinsics.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + depth_candidates = depth_candidates.repeat(2, 1, 1, 1) + + # depth candidates are actually inverse depth + warped_feature1 = warp_with_pose_depth_candidates(feature1, intrinsics, pose, + 1. / depth_candidates, + ) # [B, C, D, H, W] + + correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W] + + match_prob = F.softmax(correlation, dim=1) # [B, D, H, W] + + # for cross-task transfer (flow -> depth), extract depth with argmax at test time + if depth_from_argmax: + index = torch.argmax(match_prob, dim=1, keepdim=True) + depth = torch.gather(depth_candidates, dim=1, index=index) + else: + depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W] + + return depth, match_prob + + +def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth, + clamp_min_depth=1e-3, + ): + """ + feature1: [B, C, H, W] + intrinsics: [B, 3, 3] + pose: [B, 4, 4] + depth: [B, D, H, W] + """ + + assert intrinsics.size(1) == intrinsics.size(2) == 3 + assert pose.size(1) == pose.size(2) == 4 + assert depth.dim() == 4 + + b, d, h, w = depth.size() + c = feature1.size(1) + + with torch.no_grad(): + # pixel coordinates + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + # back project to 3D and transform viewpoint + points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] + points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( + 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W] + points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] + # reproject to 2D image plane + points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W] + pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W] + + # normalize to [-1, 1] + x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] + + # sample features + warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear', + padding_mode='zeros', + align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W] + + return warped_feature diff --git a/modules/components/m2m_unimatch/unimatch/position.py b/modules/components/m2m_unimatch/unimatch/position.py new file mode 100644 index 0000000000000000000000000000000000000000..14a6da436c818b7c2784e92dba66f7947d34b7ce --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/position.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import torch +import torch.nn as nn +import math + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/modules/components/m2m_unimatch/unimatch/reg_refine.py b/modules/components/m2m_unimatch/unimatch/reg_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..47f83da1c5dcd476069e841d045db04998be3604 --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/reg_refine.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256, + out_dim=2, + ): + super(FlowHead, self).__init__() + + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv2(self.relu(self.conv1(x))) + + return out + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128, + kernel_size=5, + ): + padding = (kernel_size - 1) // 2 + + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + + self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, corr_channels=324, + flow_channels=2, + ): + super(BasicMotionEncoder, self).__init__() + + self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicUpdateBlock(nn.Module): + def __init__(self, corr_channels=324, + hidden_dim=128, + context_dim=128, + downsample_factor=8, + flow_dim=2, + bilinear_up=False, + ): + super(BasicUpdateBlock, self).__init__() + + self.encoder = BasicMotionEncoder(corr_channels=corr_channels, + flow_channels=flow_dim, + ) + + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim) + + self.flow_head = FlowHead(hidden_dim, hidden_dim=256, + out_dim=flow_dim, + ) + + if bilinear_up: + self.mask = None + else: + self.mask = nn.Sequential( + nn.Conv2d(hidden_dim, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0)) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + if self.mask is not None: + mask = self.mask(net) + else: + mask = None + + return net, mask, delta_flow diff --git a/modules/components/m2m_unimatch/unimatch/transformer.py b/modules/components/m2m_unimatch/unimatch/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4878e23a64f6609b1bf10740b0a794d8da836c31 --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/transformer.py @@ -0,0 +1,294 @@ +import torch +import torch.nn as nn + +from .attention import (single_head_full_attention, single_head_split_window_attention, + single_head_full_attention_1d, single_head_split_window_attention_1d) +from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d + + +class TransformerLayer(nn.Module): + def __init__(self, + d_model=128, + nhead=1, + no_ffn=False, + ffn_dim_expansion=4, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.no_ffn = no_ffn + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + shifted_window_attn_mask_1d=None, + attn_type='swin', + with_shift=False, + attn_num_splits=None, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # for stereo: 2d attn in self-attn, 1d attn in cross-attn + is_self_attn = (query - key).abs().max() < 1e-6 + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if attn_type == 'swin' and attn_num_splits > 1: # self, cross-attn: both swin 2d + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + + elif attn_type == 'self_swin2d_cross_1d': # self-attn: swin 2d, cross-attn: full 1d + if self.nhead > 1: + raise NotImplementedError + else: + if is_self_attn: + if attn_num_splits > 1: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + # full 2d attn + message = single_head_full_attention(query, key, value) # [N, L, C] + + else: + # cross attn 1d + message = single_head_full_attention_1d(query, key, value, + h=height, + w=width, + ) + + elif attn_type == 'self_swin2d_cross_swin1d': # self-attn: swin 2d, cross-attn: swin 1d + if self.nhead > 1: + raise NotImplementedError + else: + if is_self_attn: + if attn_num_splits > 1: + # self attn shift window + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + # full 2d attn + message = single_head_full_attention(query, key, value) # [N, L, C] + else: + if attn_num_splits > 1: + assert shifted_window_attn_mask_1d is not None + # cross attn 1d shift + message = single_head_split_window_attention_1d(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask_1d, + ) + else: + message = single_head_full_attention_1d(query, key, value, + h=height, + w=width, + ) + + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self attention + cross attention + FFN""" + + def __init__(self, + d_model=128, + nhead=1, + ffn_dim_expansion=4, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer(d_model=d_model, + nhead=nhead, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + ) + + self.cross_attn_ffn = TransformerLayer(d_model=d_model, + nhead=nhead, + ffn_dim_expansion=ffn_dim_expansion, + ) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + shifted_window_attn_mask_1d=None, + attn_type='swin', + with_shift=False, + attn_num_splits=None, + ): + # source, target: [B, L, C] + + # self attention + source = self.self_attn(source, source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_type=attn_type, + with_shift=with_shift, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn(source, target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=shifted_window_attn_mask_1d, + attn_type=attn_type, + with_shift=with_shift, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__(self, + num_layers=6, + d_model=128, + nhead=1, + ffn_dim_expansion=4, + ): + super(FeatureTransformer, self).__init__() + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList([ + TransformerBlock(d_model=d_model, + nhead=nhead, + ffn_dim_expansion=ffn_dim_expansion, + ) + for i in range(num_layers)]) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, feature1, + attn_type='swin', + attn_num_splits=None, + **kwargs, + ): + + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + # 2d attention + if 'swin' in attn_type and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # 1d attention + if 'swin1d' in attn_type and attn_num_splits > 1: + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d( + input_w=w, + window_size_w=window_size_w, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K, W/K, W/K] + else: + shifted_window_attn_mask_1d = None + + # concat feature0 and feature1 in batch dimension to compute in parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for i, layer in enumerate(self.layers): + concat0 = layer(concat0, concat1, + height=h, + width=w, + attn_type=attn_type, + with_shift='swin' in attn_type and attn_num_splits > 1 and i % 2 == 1, + attn_num_splits=attn_num_splits, + shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=shifted_window_attn_mask_1d, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + + return feature0, feature1 diff --git a/modules/components/m2m_unimatch/unimatch/trident_conv.py b/modules/components/m2m_unimatch/unimatch/trident_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..29a2a73e964a88b68bc095772d9c3cc443e3e0fe --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/trident_conv.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = dilation + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], + self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs diff --git a/modules/components/m2m_unimatch/unimatch/unimatch.py b/modules/components/m2m_unimatch/unimatch/unimatch.py new file mode 100644 index 0000000000000000000000000000000000000000..6a0977b87d65e5d7a757b2350a760fc879690394 --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/unimatch.py @@ -0,0 +1,369 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import CNNEncoder +from .transformer import FeatureTransformer +from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow, + global_correlation_softmax_stereo, local_correlation_softmax_stereo, + correlation_softmax_depth) +from .attention import SelfAttnPropagation +from .geometry import flow_warp, compute_flow_with_depth_pose +from .reg_refine import BasicUpdateBlock +from .utils import normalize_img, feature_add_position, upsample_flow_with_mask + + +class UniMatch(nn.Module): + def __init__(self, + num_scales=1, + feature_channels=128, + upsample_factor=8, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + reg_refine=False, # optional local regression refinement + task='flow', + ): + super(UniMatch, self).__init__() + + self.feature_channels = feature_channels + self.num_scales = num_scales + self.upsample_factor = upsample_factor + self.reg_refine = reg_refine + + # CNN + self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) + + # Transformer + self.transformer = FeatureTransformer(num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # propagation with self-attn + self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels) + + if not self.reg_refine or task == 'depth': + # convex upsampling simiar to RAFT + # concat feature0 and low res flow as input + self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0)) + # thus far, all the learnable parameters are task-agnostic + + if reg_refine: + # optional task-specific local regression refinement + self.refine_proj = nn.Conv2d(128, 256, 1) + self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2, + downsample_factor=upsample_factor, + flow_dim=2 if task == 'flow' else 1, + bilinear_up=task == 'depth', + ) + + self.load_state_dict(torch.load('./modules/components/m2m_unimatch/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth')['model']) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, + is_depth=False): + if bilinear: + multiplier = 1 if is_depth else upsample_factor + up_flow = F.interpolate(flow, scale_factor=upsample_factor, + mode='bilinear', align_corners=True) * multiplier + else: + concat = torch.cat((flow, feature), dim=1) + mask = self.upsampler(concat) + up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor, + is_depth=is_depth) + + return up_flow + + def forward(self, img0, img1, + attn_type=None, + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + num_reg_refine=1, + pred_bidir_flow=False, + task='flow', + intrinsics=None, + pose=None, # relative pose transform + min_depth=1. / 0.5, # inverse depth range + max_depth=1. / 10, + num_depth_candidates=64, + depth_from_argmax=False, + pred_bidir_depth=False, + **kwargs, + ): + + if pred_bidir_flow: + assert task == 'flow' + + if task == 'depth': + assert self.num_scales == 1 # multi-scale depth model is not supported yet + + results_dict = {} + flow_preds = [] + + if task == 'flow': + # stereo and depth tasks have normalized img in dataloader + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # list of features, resolution low to high + feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features + + flow = None + + if task != 'depth': + assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales + else: + assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1 + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + + feature0_ori, feature1_ori = feature0, feature1 + + upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) + + if task == 'depth': + # scale intrinsics + intrinsics_curr = intrinsics.clone() + intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor + + if scale_idx > 0: + assert task != 'depth' # not supported for multi-scale depth model + flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 + + if flow is not None: + assert task != 'depth' + flow = flow.detach() + + if task == 'stereo': + # construct flow vector for disparity + # flow here is actually disparity + zeros = torch.zeros_like(flow) # [B, 1, H, W] + # NOTE: reverse disp, disparity is positive + displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] + feature1 = flow_warp(feature1, displace) # [B, C, H, W] + elif task == 'flow': + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + else: + raise NotImplementedError + + attn_splits = attn_splits_list[scale_idx] + if task != 'depth': + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) + + # Transformer + feature0, feature1 = self.transformer(feature0, feature1, + attn_type=attn_type, + attn_num_splits=attn_splits, + ) + + # correlation and softmax + if task == 'depth': + # first generate depth candidates + b, _, h, w = feature0.size() + depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0) + depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h, + w) # [B, D, H, W] + + flow_pred = correlation_softmax_depth(feature0, feature1, + intrinsics_curr, + pose, + depth_candidates=depth_candidates, + depth_from_argmax=depth_from_argmax, + pred_bidir_depth=pred_bidir_depth, + )[0] + + else: + if corr_radius == -1: # global matching + if task == 'flow': + flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] + elif task == 'stereo': + flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0] + else: + raise NotImplementedError + else: # local matching + if task == 'flow': + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] + elif task == 'stereo': + flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0] + else: + raise NotImplementedError + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + if task == 'stereo': + flow = flow.clamp(min=0) # positive disparity + + # upsample to the original resolution for supervison at training time only + if self.training: + flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_bilinear) + + # flow propagation with self-attn + if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0: + feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation + + flow = self.feature_flow_attn(feature0, flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius, + ) + + # bilinear exclude the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_up) + + if scale_idx == self.num_scales - 1: + if not self.reg_refine: + # upsample to the original image resolution + + if task == 'stereo': + flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + flow_up_pad = self.upsample_flow(flow_pad, feature0) + flow_up = -flow_up_pad[:, :1] # [B, 1, H, W] + elif task == 'depth': + depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + depth_up_pad = self.upsample_flow(depth_pad, feature0, + is_depth=True).clamp(min=min_depth, max=max_depth) + flow_up = depth_up_pad[:, :1] # [B, 1, H, W] + else: + flow_up = self.upsample_flow(flow, feature0) + + flow_preds.append(flow_up) + else: + # task-specific local regression refinement + # supervise current flow + if self.training: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_up) + + assert num_reg_refine > 0 + for refine_iter_idx in range(num_reg_refine): + flow = flow.detach() + + if task == 'stereo': + zeros = torch.zeros_like(flow) # [B, 1, H, W] + # NOTE: reverse disp, disparity is positive + displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=displace, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + elif task == 'depth': + if pred_bidir_depth and refine_iter_idx == 0: + intrinsics_curr = intrinsics_curr.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + + feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori), + dim=0), torch.cat((feature1_ori, + feature0_ori), dim=0) + + flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1), + intrinsics_curr, + extrinsics_rel=pose, + ) + + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=flow_from_depth, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + + else: + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=flow, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + + proj = self.refine_proj(feature0) + + net, inp = torch.chunk(proj, chunks=2, dim=1) + + net = torch.tanh(net) + inp = torch.relu(inp) + + net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(), + ) + + if task == 'depth': + flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth) + else: + flow = flow + residual_flow + + if task == 'stereo': + flow = flow.clamp(min=0) # positive + + if self.training or refine_iter_idx == num_reg_refine - 1: + if task == 'depth': + if refine_iter_idx < num_reg_refine - 1: + # bilinear upsampling + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=True) + else: + # last one convex upsampling + # NOTE: clamp depth due to the zero padding in the unfold in the convex upsampling + # pad depth to 2 channels as flow + depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + depth_up_pad = self.upsample_flow(depth_pad, feature0, + is_depth=True).clamp(min=min_depth, + max=max_depth) + flow_up = depth_up_pad[:, :1] # [B, 1, H, W] + + else: + flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor, + is_depth=task == 'depth') + + flow_preds.append(flow_up) + + if task == 'stereo': + for i in range(len(flow_preds)): + flow_preds[i] = flow_preds[i].squeeze(1) # [B, H, W] + + # convert inverse depth to depth + if task == 'depth': + for i in range(len(flow_preds)): + flow_preds[i] = 1. / flow_preds[i].squeeze(1) # [B, H, W] + + results_dict.update({'flow_preds': flow_preds}) + + return flow_preds diff --git a/modules/components/m2m_unimatch/unimatch/utils.py b/modules/components/m2m_unimatch/unimatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7104da0da173030a721c5cbbf1388bfc4d4fab80 --- /dev/null +++ b/modules/components/m2m_unimatch/unimatch/utils.py @@ -0,0 +1,216 @@ +import torch +import torch.nn.functional as F +from .position import PositionEmbeddingSine + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 - mean) / std + img1 = (img1 - mean) / std + + return img0, img1 + + +def split_feature(feature, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c + ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits + ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits(splits, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( + new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( + new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] + + return merge + + +def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, + shift_size_h, shift_size_w, device=torch.device('cuda')): + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 + + +def upsample_flow_with_mask(flow, up_mask, upsample_factor, + is_depth=False): + # convex upsampling following raft + + mask = up_mask + b, flow_channel, h, w = flow.shape + mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + multiplier = 1 if is_depth else upsample_factor + up_flow = F.unfold(multiplier * flow, [3, 3], padding=1) + up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h, + upsample_factor * w) # [B, 2, K*H, K*W] + + return up_flow + + +def split_feature_1d(feature, + num_splits=2, + ): + # feature: [B, W, C] + b, w, c = feature.size() + assert w % num_splits == 0 + + b_new = b * num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, w // num_splits, c + ).view(b_new, w_new, c) # [B*K, W/K, C] + + return feature + + +def merge_splits_1d(splits, + h, + num_splits=2, + ): + b, w, c = splits.size() + new_b = b // num_splits // h + + splits = splits.view(new_b, h, num_splits, w, c) + merge = splits.view( + new_b, h, num_splits * w, c) # [B, H, W, C] + + return merge + + +def window_partition_1d(x, window_size_w): + """ + Args: + x: (B, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, C) + """ + B, W, C = x.shape + x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C) + return x + + +def generate_shift_window_attn_mask_1d(input_w, window_size_w, + shift_size_w, device=torch.device('cuda')): + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1 + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for w in w_slices: + img_mask[:, w, :] = cnt + cnt += 1 + + mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1 + mask_windows = mask_windows.view(-1, window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask diff --git a/modules/components/upr_basic/__init__.py b/modules/components/upr_basic/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b880afa46deb6b026b889b6d34b5dd9ebc2fee2b --- /dev/null +++ b/modules/components/upr_basic/__init__.py @@ -0,0 +1 @@ +from .upr import Model \ No newline at end of file diff --git a/modules/components/upr_basic/__pycache__/__init__.cpython-310.pyc b/modules/components/upr_basic/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8c7c8c3a5bd60e11cacd44267ceb967a63b7fe6 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/upr_basic/__pycache__/__init__.cpython-38.pyc b/modules/components/upr_basic/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b324a4a47d5359cd83300be1ccdf67149bc6cce5 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/upr_basic/__pycache__/__init__.cpython-39.pyc b/modules/components/upr_basic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1497999ae50deb5aaa53e0824ab306e20f6e9bd1 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/upr_basic/__pycache__/correlation.cpython-310.pyc b/modules/components/upr_basic/__pycache__/correlation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c28a469b1f20188a056dac919523b9194ac0f20 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/correlation.cpython-310.pyc differ diff --git a/modules/components/upr_basic/__pycache__/correlation.cpython-38.pyc b/modules/components/upr_basic/__pycache__/correlation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d9d0ab34f1f9ba36d90b748a44f93b8f5d0f910 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/correlation.cpython-38.pyc differ diff --git a/modules/components/upr_basic/__pycache__/correlation.cpython-39.pyc b/modules/components/upr_basic/__pycache__/correlation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47dac9045ece06a3be1f1e3fcd6d9b6c5a188533 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/correlation.cpython-39.pyc differ diff --git a/modules/components/upr_basic/__pycache__/frequency_enhance.cpython-38.pyc b/modules/components/upr_basic/__pycache__/frequency_enhance.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60afebcefb1bdc265c3f00cee80b5cc6ecb428c4 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/frequency_enhance.cpython-38.pyc differ diff --git a/modules/components/upr_basic/__pycache__/softsplat.cpython-310.pyc b/modules/components/upr_basic/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59c911f488a6cd3511342cf271b7b6693d053fd9 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/upr_basic/__pycache__/softsplat.cpython-38.pyc b/modules/components/upr_basic/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10497679f887a420149a7687a8213a3a3e0e269e Binary files /dev/null and b/modules/components/upr_basic/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/upr_basic/__pycache__/softsplat.cpython-39.pyc b/modules/components/upr_basic/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6f452f66bd90115ebfa9c6f9e9878806930244c Binary files /dev/null and b/modules/components/upr_basic/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/upr_basic/__pycache__/upr.cpython-310.pyc b/modules/components/upr_basic/__pycache__/upr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ce7797bb42b998551b98f919748cf31f3bf3d0d Binary files /dev/null and b/modules/components/upr_basic/__pycache__/upr.cpython-310.pyc differ diff --git a/modules/components/upr_basic/__pycache__/upr.cpython-38.pyc b/modules/components/upr_basic/__pycache__/upr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b1c450ee4c644a01514f07a6e3528fe8a6a073f Binary files /dev/null and b/modules/components/upr_basic/__pycache__/upr.cpython-38.pyc differ diff --git a/modules/components/upr_basic/__pycache__/upr.cpython-39.pyc b/modules/components/upr_basic/__pycache__/upr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32d597e3ca415f8d1fc9cd92c5dc5267e6174b30 Binary files /dev/null and b/modules/components/upr_basic/__pycache__/upr.cpython-39.pyc differ diff --git a/modules/components/upr_basic/correlation.py b/modules/components/upr_basic/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c97e3e80f79dd141f01578763090bc96d2a787 --- /dev/null +++ b/modules/components/upr_basic/correlation.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert(first.is_contiguous() == True) + assert(second.is_contiguous() == True) + + output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, first.data_ptr(), rbot0.data_ptr() ] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, second.data_ptr(), rbot1.data_ptr() ] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), + block=tuple([ 32, 1, 1 ]), + shared_mem=first.shape[1] * 4, + args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert(gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + # end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + # end +# end \ No newline at end of file diff --git a/modules/components/upr_basic/softsplat.py b/modules/components/upr_basic/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..8967303376941351da0453ecc1ea61163180dcd3 --- /dev/null +++ b/modules/components/upr_basic/softsplat.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Softsplat_updateOutput = ''' + extern "C" __global__ void kernel_Softsplat_updateOutput( + const int n, + const float* input, + const float* flow, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); + } + } } +''' + +kernel_Softsplat_updateGradInput = ''' + extern "C" __global__ void kernel_Softsplat_updateGradInput( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); + const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); + const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); + const int intX = ( intIndex ) % SIZE_3(gradInput); + + float fltGradInput = 0.0; + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + gradInput[intIndex] = fltGradInput; + } } +''' + +kernel_Softsplat_updateGradFlow = ''' + extern "C" __global__ void kernel_Softsplat_updateGradFlow( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float fltGradFlow = 0.0; + + const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); + const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); + const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); + const int intX = ( intIndex ) % SIZE_3(gradFlow); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = 0.0; + float fltNortheast = 0.0; + float fltSouthwest = 0.0; + float fltSoutheast = 0.0; + + if (intC == 0) { + fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY ); + fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY ); + fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); + fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (-1.0)); + fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); + fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * ((float) (+1.0)); + fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { + float fltInput = VALUE_4(input, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; + } + } + + gradFlow[intIndex] = fltGradFlow; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + + return strKernel + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +class _FunctionSoftsplat(torch.autograd.Function): + @staticmethod + def forward(self, input, flow): + self.save_for_backward(input, flow) + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(input.is_contiguous() == True) + assert(flow.is_contiguous() == True) + + output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) + + if input.is_cuda == True: + n = output.nelement() + cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { + 'input': input, + 'flow': flow, + 'output': output + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + return output + + + @staticmethod + def backward(self, gradOutput): + input, flow = self.saved_tensors + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(gradOutput.is_contiguous() == True) + + gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])\ + if self.needs_input_grad[0] == True else None + gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ])\ + if self.needs_input_grad[1] == True else None + + if input.is_cuda == True: + if gradInput is not None: + n = gradInput.nelement() + cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] + ) + + if gradFlow is not None: + n = gradFlow.nelement() + cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + + return gradInput, gradFlow + + +def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): + assert(tenMetric is None or tenMetric.shape[1] == 1) + assert(strType in ['summation', 'average', 'linear', 'softmax']) + + if strType == 'average': + tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) + + elif strType == 'linear': + tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) + + elif strType == 'softmax': + tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1) + + + tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) + + if strType != 'summation': + tenNormalize = tenOutput[:, -1:, :, :] + + tenNormalize[tenNormalize == 0.0] = 1.0 + + tenOutput = tenOutput[:, :-1, :, :] / tenNormalize + + return tenOutput + + +class ModuleSoftsplat(torch.nn.Module): + def __init__(self, strType): + super(ModuleSoftsplat, self).__init__() + + self.strType = strType + + def forward(self, tenInput, tenFlow, tenMetric): + return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) diff --git a/modules/components/upr_basic/upr.py b/modules/components/upr_basic/upr.py new file mode 100644 index 0000000000000000000000000000000000000000..74d0580dd4faf219967401d96559a7a0fd082442 --- /dev/null +++ b/modules/components/upr_basic/upr.py @@ -0,0 +1,431 @@ +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn + +from ..components import register + +import modules.components.upr_basic.softsplat as softsplat +import modules.components.upr_basic.correlation as correlation +from utils.padder import InputPadder + + +#**************************************************************************************************# +# => Feature Pyramid +#**************************************************************************************************# +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + + + +#**************************************************************************************************# +# => Motion Estimation +#**************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + def __init__(self): + super(MotionEstimator, self).__init__() + # (4*2 + 1) ** 2 + 128 * 2 + 128 + 4 = 469 + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=469, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer6 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1)) + + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn=correlation.FunctionCorrelation + feat0 = softsplat.FunctionSoftsplat( + tenInput=feat0, tenFlow=last_flow[:, :2]*0.25*0.5, + tenMetric=None, strType='average') + feat1 = softsplat.FunctionSoftsplat( + tenInput=feat1, tenFlow=last_flow[:, 2:]*0.25*0.5, + tenMetric=None, strType='average') + + volume = F.leaky_relu( + input=corr_fn(tenFirst=feat0, tenSecond=feat1), + negative_slope=0.1, inplace=False) + input_feat = torch.cat([volume, feat0, feat1, last_feat, last_flow], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow = self.conv_layer6(feat) + + return flow, feat + + + + +#**************************************************************************************************# +# => Frame Synthesis +#**************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self): + super(SynthesisNetwork, self).__init__() + input_channels = 9+4+6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=5, kernel_size=3, + stride=1, padding=1) + + + def get_warped_representations(self, bi_flow, c0, c1, + i0=None, i1=None, time_step=0.5): + flow_0t = bi_flow[:, :2] * time_step + flow_1t = bi_flow[:, 2:4] * (1 - time_step) + warped_c0 = softsplat.FunctionSoftsplat( + tenInput=c0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_c1 = softsplat.FunctionSoftsplat( + tenInput=c1, tenFlow=flow_1t, + tenMetric=None, strType='average') + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=i0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=i1, tenFlow=flow_1t, + tenMetric=None, strType='average') + flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t + + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, + time_step=0.5): + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1, + time_step=time_step) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], + time_step=time_step) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], + time_step=time_step) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask0 = torch.sigmoid(refine[:, 3:4]) + refine_mask1 = torch.sigmoid(refine[:, 4:5]) + merged_img = (warped_img0 * refine_mask0 * (1 - time_step) + \ + warped_img1 * refine_mask1 * time_step) + merged_img = merged_img / (refine_mask0 * (1 - time_step) + \ + refine_mask1 * time_step) + interp_img = merged_img + refine_res + interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict['refine_mask0'] = refine_mask0 + extra_dict['refine_mask1'] = refine_mask1 + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict['c0_pyr'] = c0_pyr + extra_dict['c1_pyr'] = c1_pyr + extra_dict['s0'] = s0 + extra_dict['s1'] = s1 + extra_dict['s2'] = s2 + + return interp_img, extra_dict + + + +#**************************************************************************************************# +# => Unified model +#**************************************************************************************************# +@register('upr_basic') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, *args, **kwargs): + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR_basic (REAL)@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + super(Model, self).__init__() + self.pyr_level = pyr_level + self.nr_lvl_skipped = nr_lvl_skipped + self.feat_pyramid = FeatPyramid() + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork() + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_step=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear", align_corners=False) * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_0t = ori_resolution_flow[:, :2] * time_step + flow_1t = ori_resolution_flow[:, 2:4] * (1 - time_step) + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=img0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=img1, tenFlow=flow_1t, + tenMetric=None, strType='average') + last_interp = warped_img0 * (1 - time_step) \ + + warped_img1 * time_step + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_step=time_step) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, + pyr_level=None, nr_lvl_skipped=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + bi_flows = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else\ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level+2)), W //(2 ** (level+2))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level+2)), W // (2 ** (level+2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H // 4, W // 4)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + bi_flows.append( + padder.unpad(F.interpolate(input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False))) + interp_imgs.append(padder.unpad(interp_img)) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + bi_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + result_dict = { + "imgt_preds": interp_imgs, 'imgt_pred': interp_imgs[-1].contiguous(),"bi_flows": bi_flows, + "flowfwd": bi_flows[-1][:,:2], "flowbwd": bi_flows[-1][:,2:] + } + return result_dict, extra_dict + + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net/__init__.py b/modules/components/upr_net/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebb5b7f654e4c78b9fc6b4fd0f1f5c80dea4981 --- /dev/null +++ b/modules/components/upr_net/__init__.py @@ -0,0 +1 @@ +from .upr import Model diff --git a/modules/components/upr_net/__pycache__/__init__.cpython-310.pyc b/modules/components/upr_net/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2fc33b130c83cbf3482065711fa4b64ccd6a27e Binary files /dev/null and b/modules/components/upr_net/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/upr_net/__pycache__/__init__.cpython-38.pyc b/modules/components/upr_net/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5057589d5959f9a826338e691d74b0e46d100cc Binary files /dev/null and b/modules/components/upr_net/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/upr_net/__pycache__/__init__.cpython-39.pyc b/modules/components/upr_net/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62334431b469ca4e52392da2fc5267c2de37648f Binary files /dev/null and b/modules/components/upr_net/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/upr_net/__pycache__/backwarp.cpython-310.pyc b/modules/components/upr_net/__pycache__/backwarp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0c241246f61e7662c364f90cd383bcb7e6cfa57 Binary files /dev/null and b/modules/components/upr_net/__pycache__/backwarp.cpython-310.pyc differ diff --git a/modules/components/upr_net/__pycache__/backwarp.cpython-38.pyc b/modules/components/upr_net/__pycache__/backwarp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50fc2d369045af0b3c52386dd156a8fed7d339d6 Binary files /dev/null and b/modules/components/upr_net/__pycache__/backwarp.cpython-38.pyc differ diff --git a/modules/components/upr_net/__pycache__/backwarp.cpython-39.pyc b/modules/components/upr_net/__pycache__/backwarp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f070c15b4d65341057f55e5ca4c0fcfa706f6899 Binary files /dev/null and b/modules/components/upr_net/__pycache__/backwarp.cpython-39.pyc differ diff --git a/modules/components/upr_net/__pycache__/correlation.cpython-310.pyc b/modules/components/upr_net/__pycache__/correlation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0999b3f78516700c6f7433a50a7bba7f5108017d Binary files /dev/null and b/modules/components/upr_net/__pycache__/correlation.cpython-310.pyc differ diff --git a/modules/components/upr_net/__pycache__/correlation.cpython-38.pyc b/modules/components/upr_net/__pycache__/correlation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10d74375da4eb8e5609cac77ffb1b45ef3eac541 Binary files /dev/null and b/modules/components/upr_net/__pycache__/correlation.cpython-38.pyc differ diff --git a/modules/components/upr_net/__pycache__/correlation.cpython-39.pyc b/modules/components/upr_net/__pycache__/correlation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49dbc75b999c5f97e1ae02dcee30114f277637b7 Binary files /dev/null and b/modules/components/upr_net/__pycache__/correlation.cpython-39.pyc differ diff --git a/modules/components/upr_net/__pycache__/m2m.cpython-310.pyc b/modules/components/upr_net/__pycache__/m2m.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dd89a439433ae850e0fa446beb3ca8f92452bf1 Binary files /dev/null and b/modules/components/upr_net/__pycache__/m2m.cpython-310.pyc differ diff --git a/modules/components/upr_net/__pycache__/m2m.cpython-38.pyc b/modules/components/upr_net/__pycache__/m2m.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc8d10e66472bbf5af3ddfb1646e8fbe7fcbad49 Binary files /dev/null and b/modules/components/upr_net/__pycache__/m2m.cpython-38.pyc differ diff --git a/modules/components/upr_net/__pycache__/m2m.cpython-39.pyc b/modules/components/upr_net/__pycache__/m2m.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0468953e6dbb87641ec89627f729ef296f36c276 Binary files /dev/null and b/modules/components/upr_net/__pycache__/m2m.cpython-39.pyc differ diff --git a/modules/components/upr_net/__pycache__/softsplat.cpython-310.pyc b/modules/components/upr_net/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b98ff6ced223a3fdc1a3e34f859a267cb6f64b74 Binary files /dev/null and b/modules/components/upr_net/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/upr_net/__pycache__/softsplat.cpython-38.pyc b/modules/components/upr_net/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8b575f04edee61e35b9b11af16baaf01e778442 Binary files /dev/null and b/modules/components/upr_net/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/upr_net/__pycache__/softsplat.cpython-39.pyc b/modules/components/upr_net/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56a83c501c17fa7ffa3e7260a32a0ffd10268d7f Binary files /dev/null and b/modules/components/upr_net/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/upr_net/__pycache__/upr.cpython-310.pyc b/modules/components/upr_net/__pycache__/upr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75b99de5d89f1f866a5728499764c75abe8024a5 Binary files /dev/null and b/modules/components/upr_net/__pycache__/upr.cpython-310.pyc differ diff --git a/modules/components/upr_net/__pycache__/upr.cpython-38.pyc b/modules/components/upr_net/__pycache__/upr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a72f0514c62f716d7fcc26e46a3dd64044b4f8f Binary files /dev/null and b/modules/components/upr_net/__pycache__/upr.cpython-38.pyc differ diff --git a/modules/components/upr_net/__pycache__/upr.cpython-39.pyc b/modules/components/upr_net/__pycache__/upr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79521372b00a91b99c4d9f931a401b28db4260fe Binary files /dev/null and b/modules/components/upr_net/__pycache__/upr.cpython-39.pyc differ diff --git a/modules/components/upr_net/backwarp.py b/modules/components/upr_net/backwarp.py new file mode 100644 index 0000000000000000000000000000000000000000..e99a0a5c1b658e81536825451b865b39c45bc9c4 --- /dev/null +++ b/modules/components/upr_net/backwarp.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import torch + + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn:torch.Tensor, tenFlow:torch.Tensor): + if 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3]) not in objBackwarpcache: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) +# end diff --git a/modules/components/upr_net/correlation.py b/modules/components/upr_net/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1c92e2ef7dd885f25b30a3b2e4ed25c6a3889e --- /dev/null +++ b/modules/components/upr_net/correlation.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg]) + ')' for intArg in range(intArgs)] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel + + +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + rbot1 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert (first.is_contiguous() == True) + assert (second.is_contiguous() == True) + + output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([int((n + 16 - 1) / 16), first.shape[1], first.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, first.data_ptr(), rbot0.data_ptr()] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([int((n + 16 - 1) / 16), second.shape[1], second.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([output.shape[3], output.shape[2], output.shape[0]]), + block=tuple([32, 1, 1]), + shared_mem=first.shape[1] * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert (gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', + cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), + gradFirst.data_ptr(), None] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', + cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, + gradSecond.data_ptr()] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + + +# end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + + +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) +# end +# end \ No newline at end of file diff --git a/modules/components/upr_net/m2m.py b/modules/components/upr_net/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..f536207982e94a86dc28b8599c557c84b5effb69 --- /dev/null +++ b/modules/components/upr_net/m2m.py @@ -0,0 +1,407 @@ + +import math +import torch +import torch.nn as nn +import typing + +from ..components import register +from .backwarp import * +from .softsplat import _FunctionSoftsplat + + +########################################################## + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = _FunctionSoftsplat.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOutF, tenOutB = 0, 0 + tenNormalizeF, tenNormalizeB = 0, 0 + for idx in range(flow_num): + tenOutF_, tenNormalizeF_ = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB_, tenNormalizeB_ = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOutF += tenOutF_ + tenOutB += tenOutB_ + tenNormalizeF += tenNormalizeF_ + tenNormalizeB += tenNormalizeB_ + + return tenOutF / tenNormalizeF, tenNormalizeF < 0.00001, tenOutB / tenNormalizeB, tenNormalizeB < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + torch.nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + torch.nn.PReLU(out_planes) + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + +class ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, self.branch, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, -1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0, m1 + + +@register('m2m_pwc') +class M2M_PWC(torch.nn.Module): + def __init__(self, ratio=4): + super(M2M_PWC, self).__init__() + self.branch = 4 + self.ratio = ratio + + self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear', + align_corners=False) + flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear', + align_corners=False) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs): + if ratio is None: + ratio = self.ratio + + intWidth = img0.shape[3] and img1.shape[3] + intHeight = img0.shape[2] and img1.shape[2] + + intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16) + intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16) + + img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate') + img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate') + + N_, C_, H_, W_ = img0.shape + + outputs = [] + result_dict = {} + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + + tenFwd, tenBwd = self.netFlow.bidir(im0_, im1_) + + result_dict['flowfwd'] = torch.nn.functional.interpolate(tenFwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + result_dict['flowbwd'] = torch.nn.functional.interpolate(tenBwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, img0, img1, ratio) + + img0 = im0_o.repeat(1, self.branch, 1, 1) + img1 = im1_o.repeat(1, self.branch, 1, 1) + tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = time_step.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = (1.0 - (WeiMF * (img0 - backwarp(img1, tenFwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + tenPhototwo = (1.0 - (WeiMB * (img1 - backwarp(img0, tenBwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = self.paramAlpha * tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = self.paramAlpha * tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutput, mask = forwarp_mframe_mask(img0, flow0, t1, img1, flow1, t0, metric0, metric1) + + tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o) + + output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_ + result_dict['imgt_pred'] = output[:, :, :intHeight, :intWidth] + + return result_dict + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out \ No newline at end of file diff --git a/modules/components/upr_net/softsplat.py b/modules/components/upr_net/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..77967f24cd1eeee56417d1de2c88369d13b883c6 --- /dev/null +++ b/modules/components/upr_net/softsplat.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Softsplat_updateOutput = ''' + extern "C" __global__ void kernel_Softsplat_updateOutput( + const int n, + const float* input, + const float* flow, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); + } + } } +''' + +kernel_Softsplat_updateGradInput = ''' + extern "C" __global__ void kernel_Softsplat_updateGradInput( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); + const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); + const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); + const int intX = ( intIndex ) % SIZE_3(gradInput); + + float fltGradInput = 0.0; + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + gradInput[intIndex] = fltGradInput; + } } +''' + +kernel_Softsplat_updateGradFlow = ''' + extern "C" __global__ void kernel_Softsplat_updateGradFlow( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float fltGradFlow = 0.0; + + const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); + const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); + const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); + const int intX = ( intIndex ) % SIZE_3(gradFlow); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = 0.0; + float fltNortheast = 0.0; + float fltSouthwest = 0.0; + float fltSoutheast = 0.0; + + if (intC == 0) { + fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY ); + fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY ); + fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); + fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (-1.0)); + fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); + fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * ((float) (+1.0)); + fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { + float fltInput = VALUE_4(input, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; + } + } + + gradFlow[intIndex] = fltGradFlow; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + + return strKernel + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +class _FunctionSoftsplat(torch.autograd.Function): + @staticmethod + def forward(self, input, flow): + self.save_for_backward(input, flow) + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(input.is_contiguous() == True) + assert(flow.is_contiguous() == True) + + output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) + + if input.is_cuda == True: + n = output.nelement() + cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { + 'input': input, + 'flow': flow, + 'output': output + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + return output + + + @staticmethod + def backward(self, gradOutput): + input, flow = self.saved_tensors + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(gradOutput.is_contiguous() == True) + + gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])\ + if self.needs_input_grad[0] == True else None + gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ])\ + if self.needs_input_grad[1] == True else None + + if input.is_cuda == True: + if gradInput is not None: + n = gradInput.nelement() + cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] + ) + + if gradFlow is not None: + n = gradFlow.nelement() + cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + + return gradInput, gradFlow + + +def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): + assert(tenMetric is None or tenMetric.shape[1] == 1) + assert(strType in ['summation', 'average', 'linear', 'softmax']) + + if strType == 'average': + tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) + + elif strType == 'linear': + tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) + + elif strType == 'softmax': + tenInput = torch.cat([ tenInput * tenMetric.clip(-20, 20).exp(), tenMetric.clip(-20, 20).exp() ], 1) + + + tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) + + if strType != 'summation': + tenNormalize = tenOutput[:, -1:, :, :] + + tenNormalize[tenNormalize == 0.0] = 1.0 + + tenOutput = tenOutput[:, :-1, :, :] / tenNormalize + + return tenOutput + + +class ModuleSoftsplat(torch.nn.Module): + def __init__(self, strType): + super(ModuleSoftsplat, self).__init__() + + self.strType = strType + + def forward(self, tenInput, tenFlow, tenMetric): + return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) + diff --git a/modules/components/upr_net/upr.py b/modules/components/upr_net/upr.py new file mode 100644 index 0000000000000000000000000000000000000000..6536f400b0598278e9f8050a9e90bab4f65e0d40 --- /dev/null +++ b/modules/components/upr_net/upr.py @@ -0,0 +1,583 @@ +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn + +import modules.components.upr_net.correlation as correlation +import modules.components.upr_net.softsplat as softsplat +from modules.components.upr_net.m2m import * +from modules.components.upr_net.backwarp import backwarp + +from ..components import register + +from utils.padder import InputPadder + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 +gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) +gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + + +def gaussian(x): + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # (4*2 + 1) ** 2 + 128 * 2 + 128 + 4 = 469 + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=469, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer6 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1)) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0 = softsplat.FunctionSoftsplat( + tenInput=feat0, tenFlow=last_flow[:, :2] * 0.5 * 0.24, + tenMetric=None, strType='average') + feat1 = softsplat.FunctionSoftsplat( + tenInput=feat1, tenFlow=last_flow[:, 2:] * 0.5 * 0.24, + tenMetric=None, strType='average') + + volume = F.leaky_relu( + input=corr_fn(tenFirst=feat0, tenSecond=feat1), + negative_slope=0.1, inplace=False) + input_feat = torch.cat([volume, feat0, feat1, last_feat, last_flow], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow = self.conv_layer6(feat) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average', branch=1): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=5, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + self.branch = branch + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1): + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + if self.branch > 1: + # self.MRN = MotionRefineNet(self.branch) + self.convblock = nn.Sequential( + nn.Conv2d(32 * 2 + 2, 32, 3, 1, 1), + nn.ReLU(), + ResBlock(32, 16), + nn.Conv2d(32, 3 * self.branch, 3, 1, 1, bias=False) + ) + if self.splat_mode == 'softmax' or branch > 1: + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax' or self.branch > 1: + M_splat = 1 / (1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_0t = bi_flow[:, :2] * time_period + flow_1t = bi_flow[:, 2:4] * (1 - time_period) + warped_c0 = softsplat.FunctionSoftsplat( + tenInput=c0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_c1 = softsplat.FunctionSoftsplat( + tenInput=c1, tenFlow=flow_1t, + tenMetric=None, strType='average') + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=i0, tenFlow=flow_0t, + tenMetric=m_splat_0, strType=self.splat_mode) + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=i1, tenFlow=flow_1t, + tenMetric=m_splat_1, strType=self.splat_mode) + flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + if multi_flow: + tenFwd = bi_flow_pyr[0][:, :2] + tenBwd = bi_flow_pyr[0][:, 2:4] + # tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, i0, i1) + c0_warp = backwarp(c0_pyr[0], tenBwd) + c1_warp = backwarp(c1_pyr[0], tenFwd) + out0 = self.convblock(torch.cat([c0_pyr[0], c1_warp, tenFwd], 1)) + out1 = self.convblock(torch.cat([c1_pyr[0], c0_warp, tenBwd], 1)) + delta_flow_fwd, WeiMF = torch.split(out0, [2 * self.branch, self.branch], 1) + delta_flow_bwd, WeiMB = torch.split(out1, [2 * self.branch, self.branch], 1) + + tenFwd = delta_flow_fwd + tenFwd.repeat(1, self.branch, 1, 1) + tenBwd = delta_flow_bwd + tenBwd.repeat(1, self.branch, 1, 1) + N_, _, H_, W_ = i0.shape + + i0_ = i0.repeat(1, self.branch, 1, 1) + i1_ = i1.repeat(1, self.branch, 1, 1) + + fltTime = time_period.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.view(N_, self.branch, 1, H_, W_).reshape(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.view(N_, self.branch, 1, H_, W_).reshape(N_ * self.branch, 1, H_, W_) + + i0_ = i0_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + i1_ = i1_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = self.get_splat_weight(i0_, i1_, tenFwd, tenBwd) * WeiMF + tenPhototwo = self.get_splat_weight(i1_, i0_, tenBwd, tenFwd) * WeiMB + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + i0_ = i0_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + i1_ = i1_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + flow0, flow1 = flow0.contiguous(), flow1.contiguous() + + tenOutputF, maskF, tenOutputB, maskB = forwarp_mframe_mask(i0_, flow0, t0, i1_, flow1, t1, metric0, metric1) + + warped_img0 = tenOutputF + maskF * i0 + warped_img1 = tenOutputB + maskB * i1 + warped_c0, warped_c1 = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, + time_period=time_period) + flow_0t = bi_flow_pyr[0][:, :2] * time_period + flow_1t = bi_flow_pyr[0][:, 2:4] * (1 - time_period) + flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) + else: + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask0 = torch.sigmoid(refine[:, 3:4]) + refine_mask1 = torch.sigmoid(refine[:, 4:5]) + merged_img = (warped_img0 * refine_mask0 * (1 - time_period) + \ + warped_img1 * refine_mask1 * time_period) + merged_img = merged_img / (refine_mask0 * (1 - time_period) + \ + refine_mask1 * time_period) + interp_img = merged_img + refine_res + interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + if multi_flow: + extra_dict['tenFwd'] = tenFwd.view(N_, self.branch, 2, H_, W_) + extra_dict['tenBwd'] = tenBwd.view(N_, self.branch, 2, H_, W_) + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average', branch=1): + super(Model, self).__init__() + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode, branch) + self.splat_mode = splat_mode + self.branch = branch + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False, multi_flow=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear", align_corners=False) * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_0t = ori_resolution_flow[:, :2] * time_period + flow_1t = ori_resolution_flow[:, 2:4] * (1 - time_period) + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=img0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=img1, tenFlow=flow_1t, + tenMetric=None, strType='average') + last_interp = warped_img0 * (1 - time_period) \ + + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period, multi_flow=multi_flow) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, + pyr_level=None, nr_lvl_skipped=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + padder = InputPadder(img0.shape, divisor=int(4 * 2**pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # with torch.set_grad_enabled(False): + # tenStats = [img0, img1] + # tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + # tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + # tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + # + # img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + # img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H // 4, W // 4)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me, multi_flow=(self.branch > 1 and level == 0)) + if level == 0 and self.branch > 1: + flow0_pred.append(extra_dict['tenFwd']) + flow1_pred.append(extra_dict['tenBwd']) + elif level == 0 and self.branch == 1: + flow0_pred.append( + F.interpolate(input=flow[:, :2], scale_factor=4.0, + mode="bilinear", align_corners=False).unsqueeze(1) * 4 * 0.5) + flow1_pred.append( + F.interpolate(input=flow[:, 2:], scale_factor=4.0, + mode="bilinear", align_corners=False).unsqueeze(1) * 4 * 0.5) + else: + flow0_pred.append( + padder.unpad(F.interpolate(input=flow[:, :2], scale_factor=4.0, + mode="bilinear", align_corners=False) * 4 * 0.5)) + flow1_pred.append( + padder.unpad(F.interpolate(input=flow[:, 2:], scale_factor=4.0, + mode="bilinear", align_corners=False) * 4 * 0.5)) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level))) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + + interp_img = padder.unpad(interp_img) + + return {"imgt_preds": interp_imgs[-2:], "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_img, "flowfwd": flow0_pred[-1][:, 0], "flowbwd": flow1_pred[-1][:, 0]} + + +if __name__ == "__main__": + pass diff --git a/modules/components/upr_net_freq/__init__.py b/modules/components/upr_net_freq/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa2abb174b2bd27fee1d699b476309babe0ad7d3 --- /dev/null +++ b/modules/components/upr_net_freq/__init__.py @@ -0,0 +1 @@ +from .upr_freq import Model diff --git a/modules/components/upr_net_freq/__pycache__/__init__.cpython-310.pyc b/modules/components/upr_net_freq/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f542722b19aee42911ce7da4aa07d4357c1281c0 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/__init__.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99b1210207344882b8b99cb4ae14f0d15b1cb931 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/__init__.cpython-39.pyc b/modules/components/upr_net_freq/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbd4e47ade480b5d0937924a233109fb1fd91580 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/backwarp.cpython-310.pyc b/modules/components/upr_net_freq/__pycache__/backwarp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..131603438ae8f934fc07b4c50234f00ab3c83ef7 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/backwarp.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/backwarp.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/backwarp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e697df8550657361d5406276e4151ec34157d8a Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/backwarp.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/backwarp.cpython-39.pyc b/modules/components/upr_net_freq/__pycache__/backwarp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fad15680a3b6b6a969fb2665e2be9195e9e13472 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/backwarp.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/correlation.cpython-310.pyc b/modules/components/upr_net_freq/__pycache__/correlation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f164da92f5f62b2737f4d833104bfaaea2546c58 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/correlation.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/correlation.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/correlation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..926c113f70d036fc686dad25ebcd37dcf5165cc5 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/correlation.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/correlation.cpython-39.pyc b/modules/components/upr_net_freq/__pycache__/correlation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b3b8069ee847a28dba6e089bd4a6259beb84b24 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/correlation.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/costvol.cpython-310.pyc b/modules/components/upr_net_freq/__pycache__/costvol.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c2f8a7679fda19a541be451885d615b9bbe1795 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/costvol.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/costvol.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/costvol.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5db2fea012876b34074dc2f12cc3a293c53e0e9 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/costvol.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/costvol.cpython-39.pyc b/modules/components/upr_net_freq/__pycache__/costvol.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b283f7c20d2c25fcdf8b40551b42548a4ddd82eb Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/costvol.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-310.pyc b/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0e50aad1c9da3de4744898032ecc1ce4ff30e02 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b827944795a8cf50016c08b5cec634f4d4024502 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-39.pyc b/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..571cc1ce64b6bafd6b157309cbe3ebc5e823ad56 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/frequency_enhance.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/frequency_enhance_001.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/frequency_enhance_001.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13e19733bb8bfe0e98dc0ad1079a3c713db3f2fd Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/frequency_enhance_001.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/frequency_enhance_002.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/frequency_enhance_002.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f77e55fc356dad25b08b5b3354aed38a71c3422 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/frequency_enhance_002.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/m2m.cpython-310.pyc b/modules/components/upr_net_freq/__pycache__/m2m.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37df48e69a2bcb73a072e574106d0dc84c98ae61 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/m2m.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/m2m.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/m2m.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d00a692fc70a5b65e3c2433b322db008c5762111 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/m2m.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/m2m.cpython-39.pyc b/modules/components/upr_net_freq/__pycache__/m2m.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..284d385c55a2dc2e7ed86a325db90b6f4ca4f759 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/m2m.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/softsplat.cpython-310.pyc b/modules/components/upr_net_freq/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b713284bd5343b44fad220d2704e2c09091f9b65 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/softsplat.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b804b327f5683ec1f2d919100f33dc762a2e6be Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/softsplat.cpython-39.pyc b/modules/components/upr_net_freq/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3262aa10c97bea02135e099ec0db3c8e37aba880 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-310.pyc b/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9322356600bbeb3524311f36a0443169d64211f4 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec9a6b918427fc59ddb70ef827218bdd2fea2c05 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-39.pyc b/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79fefa610300f010df1083f22f6ff62d2cd124e4 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/upr_freq.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/upr_freq001.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/upr_freq001.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2caadc6e8c68e2bca06d33324b35740dbfa0c6be Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/upr_freq001.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/upr_freq_002.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/upr_freq_002.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2213f679081f4d0dd434ae16faa2c081c9eb470 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/upr_freq_002.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/upr_freq_005.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/upr_freq_005.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2bb9b4699ea06c3c946aceaf4aa7fd3cf89fd8a8 Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/upr_freq_005.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/__pycache__/upr_freq_temp.cpython-38.pyc b/modules/components/upr_net_freq/__pycache__/upr_freq_temp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17aab31b13b5a0c4e37ba853a94ad0214853549b Binary files /dev/null and b/modules/components/upr_net_freq/__pycache__/upr_freq_temp.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq/backwarp.py b/modules/components/upr_net_freq/backwarp.py new file mode 100644 index 0000000000000000000000000000000000000000..729a1db8e0117bd49526929e5953cf6e70fd204a --- /dev/null +++ b/modules/components/upr_net_freq/backwarp.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import torch + + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn:torch.Tensor, tenFlow:torch.Tensor): + if 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3]) not in objBackwarpcache: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) +# end \ No newline at end of file diff --git a/modules/components/upr_net_freq/correlation.py b/modules/components/upr_net_freq/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1c92e2ef7dd885f25b30a3b2e4ed25c6a3889e --- /dev/null +++ b/modules/components/upr_net_freq/correlation.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg]) + ')' for intArg in range(intArgs)] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel + + +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + rbot1 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert (first.is_contiguous() == True) + assert (second.is_contiguous() == True) + + output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([int((n + 16 - 1) / 16), first.shape[1], first.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, first.data_ptr(), rbot0.data_ptr()] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([int((n + 16 - 1) / 16), second.shape[1], second.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([output.shape[3], output.shape[2], output.shape[0]]), + block=tuple([32, 1, 1]), + shared_mem=first.shape[1] * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert (gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', + cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), + gradFirst.data_ptr(), None] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', + cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, + gradSecond.data_ptr()] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + + +# end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + + +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) +# end +# end \ No newline at end of file diff --git a/modules/components/upr_net_freq/costvol.py b/modules/components/upr_net_freq/costvol.py new file mode 100644 index 0000000000000000000000000000000000000000..6c93e4db22d00bf73c8b1fc06a297a85a16ee352 --- /dev/null +++ b/modules/components/upr_net_freq/costvol.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +class costvol_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenOne, tenTwo): + tenOut = tenOne.new_empty([tenOne.shape[0], 81, tenOne.shape[2], tenOne.shape[3]]) + + cuda_launch(cuda_kernel('costvol_out', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_out( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue]); + } + } + + tenOut[intOffset] = fltValue / SIZE_1(tenOne); + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOut': tenOut + }))( + grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + self.save_for_backward(tenOne, tenTwo) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenOne, tenTwo = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None + tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenOnegrad is not None: + cuda_launch(cuda_kernel('costvol_onegrad', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_onegrad( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad); + const int intX = ( intIndex ) % SIZE_3(tenOnegrad); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenTwograd is not None: + cuda_launch(cuda_kernel('costvol_twograd', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_twograd( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd); + const int intX = ( intIndex ) % SIZE_3(tenTwograd); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], -tenOutgrad[intOffset] / SIZE_1(tenOne)); + } else { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], +tenOutgrad[intOffset] / SIZE_1(tenOne)); + } + } + } else { + // ... + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenOnegrad, tenTwograd, None, None + # end +# end \ No newline at end of file diff --git a/modules/components/upr_net_freq/frequency_enhance.py b/modules/components/upr_net_freq/frequency_enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..6d007987fb327b642cd1c4bc6fb6fdc441405e0a --- /dev/null +++ b/modules/components/upr_net_freq/frequency_enhance.py @@ -0,0 +1,131 @@ +# frequency_enhance_002.py + +import math +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F + +class ReshapeLayerNorm(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super(ReshapeLayerNorm, self).__init__() + + self.dim = dim + self.norm = norm_layer(dim) + + def forward(self, x): + B, C, H, W = x.size() + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H) + return x + +class ChannelSelfAttention(nn.Module): + def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0): + super(ChannelSelfAttention, self).__init__() + self.dim = dim + self.num_head = num_head + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True) + + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Conv2d(dim, dim, 1) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q,k,v, sp=None): + B, C, H, W = q.size() + + q,k,v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q,k,v]) # [B, L, C/L, HW] + + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2,-1) # [B, L, C/L, C/L] + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v # [B, L, C/L, HW] + + # head merge + x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H) # [B, C, H, W] + x = self.proj_drop(self.proj(x)) # [B, C, H, W] + + return x + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0): + super(FeedForward, self).__init__() + + self.dim = dim + self.hidden_ratio = hidden_ratio + + self.hidden = nn.Conv2d(32, int(32*hidden_ratio), 1, bias=bias) + self.drop1 = nn.Dropout(drop) + self.out = nn.Conv2d(int(32*hidden_ratio), dim, 1, bias=bias) + self.drop2 = nn.Dropout(drop) + self.act = act_layer() + + def forward(self, x): + return self.drop2(self.out(self.drop1(self.act(self.hidden(x))))) + +class FrequencyEnhancementTransformer(nn.Module): + def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, fftshift=False, *args, **kwargs): + super(FrequencyEnhancementTransformer, self).__init__() + self.c_dim = c_dim + self.feat_dim = feat_dim + self.num_head = num_head + self.hidden_ratio = hidden_ratio + self.fftshift = fftshift + + self.c_proj = nn.Sequential(nn.Conv2d(in_channels=c_dim*2+4, out_channels=c_dim*2+4, kernel_size=3, stride=1, padding=1, groups=c_dim*2+4), + nn.Conv2d(in_channels=c_dim*2+4, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + self.feat_proj = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim), + nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + self.phase_dwconv = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=32), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + + self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.attn = ChannelSelfAttention(32, num_head) + self.norm1 = ReshapeLayerNorm(32) + + self.ffn = FeedForward(feat_dim, hidden_ratio) + self.norm2 = ReshapeLayerNorm(feat_dim) + + def dft(self, x): + fft = torch.fft.fft2(x, dim=(2,3), norm='ortho') + fft = torch.fft.fftshift(fft, dim=(2,3)) if self.fftshift else fft + amplitude = torch.abs(fft) + phase = torch.angle(fft) + return amplitude, phase + + def forward(self, c0, c1, feat, flow, *args, **kwargs): + B,D,H,W = feat.size() + amp_, pha_ = self.dft(feat) + + c = self.c_proj(torch.cat([c0,c1,flow], dim=1)) # [B, 32, H, W] + feat = self.feat_proj(feat) # [B, 32, H, W] + + amp_c, pha_c = self.dft(c) # [B, 32, H, W] + amp_f, pha_f = self.dft(feat) # [B, 32, H, W] + + amp_q = self.q_proj(amp_c) # [B, 32, H, W] + amp_k = self.k_proj(amp_f) # [B, 32, H, W] + amp_v = self.v_proj(amp_f) # [B, 32, H, W] + amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v)) + amp_f # [B, 32, H, W] + amp = self.norm2(self.ffn(amp_attn)) + amp_ # [B, D, H, W] + + pha_local = self.norm1(self.phase_dwconv(pha_c)) + pha_f + pha = self.norm2(self.ffn(pha_local)) + pha_ + + real = amp * torch.cos(pha) + imag = amp * torch.sin(pha) + output = torch.fft.ifft2(torch.complex(real, imag)) + output = torch.abs(output) + + return output, (amp, pha) \ No newline at end of file diff --git a/modules/components/upr_net_freq/frequency_enhance_001.py b/modules/components/upr_net_freq/frequency_enhance_001.py new file mode 100644 index 0000000000000000000000000000000000000000..a6d7940218dd552bb490a8528ee3f5b392b917ee --- /dev/null +++ b/modules/components/upr_net_freq/frequency_enhance_001.py @@ -0,0 +1,128 @@ +# frequency_enhance_001.py + +import math +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F + +class ReshapeLayerNorm(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super(ReshapeLayerNorm, self).__init__() + + self.dim = dim + self.norm = norm_layer(dim) + + def forward(self, x): + B, C, H, W = x.size() + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H) + return x + +class ChannelSelfAttention(nn.Module): + def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0): + super(ChannelSelfAttention, self).__init__() + self.dim = dim + self.num_head = num_head + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True) + + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Conv2d(dim, dim, 1) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q,k,v, sp=None): + B, C, H, W = q.size() + + q,k,v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q,k,v]) # [B, L, C/L, HW] + + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2,-1) # [B, L, C/L, C/L] + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v # [B, L, C/L, HW] + + # head merge + x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H) # [B, C, H, W] + x = self.proj_drop(self.proj(x)) # [B, C, H, W] + + return x + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0): + super(FeedForward, self).__init__() + + self.dim = dim + self.hidden_ratio = hidden_ratio + + self.hidden = nn.Conv2d(32, int(32*hidden_ratio), 1, bias=bias) + self.drop1 = nn.Dropout(drop) + self.out = nn.Conv2d(int(32*hidden_ratio), dim, 1, bias=bias) + self.drop2 = nn.Dropout(drop) + self.act = act_layer() + + def forward(self, x): + return self.drop2(self.out(self.drop1(self.act(self.hidden(x))))) + +class FrequencyEnhancementTransformer(nn.Module): + def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, *args, **kwargs): + super(FrequencyEnhancementTransformer, self).__init__() + self.c_dim = c_dim + self.feat_dim = feat_dim + self.num_head = num_head + self.hidden_ratio = hidden_ratio + + self.c_proj = nn.Sequential(nn.Conv2d(in_channels=c_dim*2, out_channels=c_dim*2, kernel_size=3, stride=1, padding=1, groups=c_dim*2), + nn.Conv2d(in_channels=c_dim*2, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + self.feat_proj = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim), + nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + + self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.attn = ChannelSelfAttention(32, num_head) + self.norm1 = ReshapeLayerNorm(32) + + self.ffn = FeedForward(feat_dim, hidden_ratio) + self.norm2 = ReshapeLayerNorm(feat_dim) + + def dft(self, x): + fft = torch.fft.fft2(x, dim=(2,3), norm='ortho') +# amplitude = torch.abs(fft) +# phase = torch.angle(fft) + return fft.real, fft.imag + + def forward(self, c0, c1, feat, *args, **kwargs): + B,D,H,W = feat.size() + real_, imag_ = self.dft(feat) + + c = self.c_proj(torch.cat([c0,c1], dim=1)) # [B, 32, H, W] + feat = self.feat_proj(feat) # [B, 32, H, W] + + real_c, imag_c = self.dft(c) # [B, 32, H, W] + real_f, imag_f = self.dft(feat) # [B, 32, H, W] + + real_q = self.q_proj(real_c) # [B, 32, H, W] + real_k = self.k_proj(real_f) # [B, 32, H, W] + real_v = self.v_proj(real_f) # [B, 32, H, W] + real_attn = self.norm1(self.attn(real_q, real_k, real_v)) + real_f # [B, 32, H, W] + real = self.norm2(self.ffn(real_attn)) + real_ # [B, D, H, W] + + imag_q = self.q_proj(imag_c) + imag_k = self.k_proj(imag_f) + imag_v = self.v_proj(imag_f) + imag_attn = self.norm1(self.attn(imag_q, imag_k, imag_v)) + imag_f + imag = self.norm2(self.ffn(imag_attn)) + imag_ + + out = torch.complex(real, imag) + output = torch.fft.ifft2(out) + output = torch.abs(output) + + return output \ No newline at end of file diff --git a/modules/components/upr_net_freq/frequency_enhance_005.py b/modules/components/upr_net_freq/frequency_enhance_005.py new file mode 100644 index 0000000000000000000000000000000000000000..bee1f2e004a0b5c9cf8ecd75ea06e5e449fedfed --- /dev/null +++ b/modules/components/upr_net_freq/frequency_enhance_005.py @@ -0,0 +1,169 @@ +# frequency_enhance_005.py (freq002+synthesis EncConv reduce + Asym.FreqDec) + +import math +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F + +class ReshapeLayerNorm(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super(ReshapeLayerNorm, self).__init__() + + self.dim = dim + self.norm = norm_layer(dim) + + def forward(self, x): + B, C, H, W = x.size() + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H) + return x + +class ChannelSelfAttention(nn.Module): + def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0): + super(ChannelSelfAttention, self).__init__() + self.dim = dim + self.num_head = num_head + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True) + + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Conv2d(dim, dim, 1) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q,k,v, sp=None): + B, C, H, W = q.size() + + q,k,v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q,k,v]) # [B, L, C/L, HW] + + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2,-1) # [B, L, C/L, C/L] + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v # [B, L, C/L, HW] + + # head merge + x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H) # [B, C, H, W] + x = self.proj_drop(self.proj(x)) # [B, C, H, W] + + return x + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0): + super(FeedForward, self).__init__() + + self.dim = dim + self.hidden_ratio = hidden_ratio + + self.hidden = nn.Conv2d(32, int(32*hidden_ratio), 1, bias=bias) + self.drop1 = nn.Dropout(drop) + self.out = nn.Conv2d(int(32*hidden_ratio), dim, 1, bias=bias) + self.drop2 = nn.Dropout(drop) + self.act = act_layer() + + def forward(self, x): + return self.drop2(self.out(self.drop1(self.act(self.hidden(x))))) + +def dft(x, fftshift=False): + fft = torch.fft.fft2(x, dim=(2,3), norm='ortho') + fft = torch.fft.fftshift(fft, dim=(2,3)) if fftshift else fft + amplitude = torch.abs(fft) + phase = torch.angle(fft) + return amplitude, phase + +class FrequencyEnhancementTransformer(nn.Module): + def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, last=False, fftshift=False, *args, **kwargs): + super(FrequencyEnhancementTransformer, self).__init__() + self.c_dim = c_dim + self.feat_dim = feat_dim + self.num_head = num_head + self.hidden_ratio = hidden_ratio + self.last = last + self.fftshift = fftshift + + self.c_proj = nn.Sequential(nn.Conv2d(in_channels=c_dim*2+4, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + self.feat_proj = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + if last: + self.phase_dwconv = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1) + + self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.attn = ChannelSelfAttention(32, num_head) + self.norm1 = ReshapeLayerNorm(32) + + self.ffn = FeedForward(feat_dim, hidden_ratio) + self.norm2 = ReshapeLayerNorm(feat_dim) + + def forward(self, c0, c1, feat, flow, *args, **kwargs): + B,D,H,W = feat.size() + amp_, pha_ = dft(feat, self.fftshift) + pha = None if self.last else pha_ + + c = self.c_proj(torch.cat([c0,c1,flow], dim=1)) # [B, 32, H, W] + feat = self.feat_proj(feat) # [B, 32, H, W] + + amp_c, pha_c = dft(c, self.fftshift) # [B, 32, H, W] + amp_f, pha_f = dft(feat, self.fftshift) # [B, 32, H, W] + + amp_q = self.q_proj(amp_c) # [B, 32, H, W] + amp_k = self.k_proj(amp_f) # [B, 32, H, W] + amp_v = self.v_proj(amp_f) # [B, 32, H, W] + amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v)) + amp_f # [B, 32, H, W] + amp = self.norm2(self.ffn(amp_attn)) + amp_ # [B, D, H, W] + + if self.last: + pha_local = self.norm1(self.phase_dwconv(pha_c)) + pha_f + pha = self.norm2(self.ffn(pha_local)) + pha_ + + real = amp * torch.cos(pha) + imag = amp * torch.sin(pha) + output = torch.fft.ifft2(torch.complex(real, imag)) + output = torch.abs(output) + + return output + +class FrequencyEnhancementDecoder(nn.Module): + def __init__(self, concat_dim, dim, fftshift, *args, **kwargs): + super(FrequencyEnhancementDecoder, self).__init__() + self.concat_dim = concat_dim + self.dim = dim + self.fftshift = fftshift + + self.act = nn.LeakyReLU() + + self.in_dwconv = nn.Conv2d(concat_dim, concat_dim, 3, 1, 1, groups=concat_dim) + self.in_pwconv = nn.Conv2d(concat_dim, dim, 1, 1) + + self.amp_conv = nn.Conv2d(dim, dim, 3, 1, 1) + self.pha_conv = nn.Conv2d(dim, dim, 3, 1, 1) + + self.out_conv = nn.Conv2d(dim, dim, 3, 1, 1) + + def forward(self, enc_feats, warped_feats, flow): + _,_,H0,W0 = enc_feats[0].size() + for i, feat in enumerate(enc_feats[1:]): + enc_feats[i+1] = F.pixel_shuffle(feat, H0//feat.size(2)) + for i, feat in enumerate(warped_feats[2:]): + warped_feats[i+2] = F.pixel_shuffle(feat, H0//feat.size(2)) + + x = torch.cat(enc_feats+warped_feats+[flow], dim=1) + x = self.act(self.in_pwconv(self.in_dwconv(x))) + + amp, pha = dft(x, self.fftshift) + amp = self.amp_conv(amp) + amp + pha = self.pha_conv(pha) + pha + + real = amp * torch.cos(pha) + imag = amp * torch.sin(pha) + output = torch.fft.ifft2(torch.complex(real, imag)) + x + output = self.act(self.out_conv(torch.abs(output))) + + return output \ No newline at end of file diff --git a/modules/components/upr_net_freq/frequency_enhance_006.py b/modules/components/upr_net_freq/frequency_enhance_006.py new file mode 100644 index 0000000000000000000000000000000000000000..5491b991be0fd9d5d90e5d0c11499311853495e0 --- /dev/null +++ b/modules/components/upr_net_freq/frequency_enhance_006.py @@ -0,0 +1,193 @@ +# frequency_enhance_006.py (freq005+c_proj,feat_proj dwconv์žฌ์ถ”๊ฐ€+decoder residual in_conv, out_conv 2๊ฐœ์”ฉ ์ถ”๊ฐ€) + +import math +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F + +class ReshapeLayerNorm(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super(ReshapeLayerNorm, self).__init__() + + self.dim = dim + self.norm = norm_layer(dim) + + def forward(self, x): + B, C, H, W = x.size() + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H) + return x + +class ChannelSelfAttention(nn.Module): + def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0): + super(ChannelSelfAttention, self).__init__() + self.dim = dim + self.num_head = num_head + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True) + + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Conv2d(dim, dim, 1) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q,k,v, sp=None): + B, C, H, W = q.size() + + q,k,v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q,k,v]) # [B, L, C/L, HW] + + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2,-1) # [B, L, C/L, C/L] + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v # [B, L, C/L, HW] + + # head merge + x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H) # [B, C, H, W] + x = self.proj_drop(self.proj(x)) # [B, C, H, W] + + return x + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0): + super(FeedForward, self).__init__() + + self.dim = dim + self.hidden_ratio = hidden_ratio + + self.hidden = nn.Conv2d(32, int(32*hidden_ratio), 1, bias=bias) + self.drop1 = nn.Dropout(drop) + self.out = nn.Conv2d(int(32*hidden_ratio), dim, 1, bias=bias) + self.drop2 = nn.Dropout(drop) + self.act = act_layer() + + def forward(self, x): + return self.drop2(self.out(self.drop1(self.act(self.hidden(x))))) + +def dft(x, fftshift=False): + fft = torch.fft.fft2(x, dim=(2,3), norm='ortho') + fft = torch.fft.fftshift(fft, dim=(2,3)) if fftshift else fft + amplitude = torch.abs(fft) + phase = torch.angle(fft) + return amplitude, phase + +class FrequencyEnhancementTransformer(nn.Module): + def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, last=False, fftshift=False, *args, **kwargs): + super(FrequencyEnhancementTransformer, self).__init__() + self.c_dim = c_dim + self.feat_dim = feat_dim + self.num_head = num_head + self.hidden_ratio = hidden_ratio + self.last = last + self.fftshift = fftshift + + self.c_proj = nn.Sequential(nn.Conv2d(in_channels=c_dim*2+4, out_channels=c_dim*2+4, kernel_size=3, stride=1, padding=1, groups=c_dim*2+4), + nn.Conv2d(in_channels=c_dim*2+4, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + self.feat_proj = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim), + nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + if last: + self.phase_conv = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1) + + self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.attn = ChannelSelfAttention(32, num_head) + self.norm1 = ReshapeLayerNorm(32) + + self.ffn = FeedForward(feat_dim, hidden_ratio) + self.norm2 = ReshapeLayerNorm(feat_dim) + + def forward(self, c0, c1, feat, flow, *args, **kwargs): + B,D,H,W = feat.size() + feat_ = feat + amp_, pha_ = dft(feat, self.fftshift) + pha = None if self.last else pha_ + + c = self.c_proj(torch.cat([c0,c1,flow], dim=1)) # [B, 32, H, W] + feat = self.feat_proj(feat) # [B, 32, H, W] + + amp_c, pha_c = dft(c, self.fftshift) # [B, 32, H, W] + amp_f, pha_f = dft(feat, self.fftshift) # [B, 32, H, W] + + amp_q = self.q_proj(amp_c) # [B, 32, H, W] + amp_k = self.k_proj(amp_f) # [B, 32, H, W] + amp_v = self.v_proj(amp_f) # [B, 32, H, W] + amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v)) + amp_f # [B, 32, H, W] + amp = self.norm2(self.ffn(amp_attn)) + amp_ # [B, D, H, W] + + if self.last: + pha_local = self.norm1(self.phase_conv(pha_c)) + pha_f + pha = self.norm2(self.ffn(pha_local)) + pha_ + + real = amp * torch.cos(pha) + imag = amp * torch.sin(pha) + output = torch.fft.ifft2(torch.complex(real, imag)) + output = torch.abs(output) + feat_ + + return output + +class FrequencyEnhancementDecoder(nn.Module): + def __init__(self, concat_dim, dim, fftshift, *args, **kwargs): + super(FrequencyEnhancementDecoder, self).__init__() + self.concat_dim = concat_dim + self.dim = dim + self.fftshift = fftshift + + self.act = nn.LeakyReLU() + + self.in_conv1 = nn.Sequential(nn.Conv2d(concat_dim, concat_dim, 3, 1, 1, groups=concat_dim), + nn.Conv2d(concat_dim, dim, 1, 1), + nn.LeakyReLU()) + self.in_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + self.in_conv3 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + + self.amp_conv = nn.Conv2d(dim, dim, 3, 1, 1) + self.pha_conv = nn.Conv2d(dim, dim, 3, 1, 1) + + self.out_conv1 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + self.out_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + self.out_conv3 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + + def forward(self, enc_feats, warped_feats, flow): + _,_,H0,W0 = enc_feats[0].size() + for i, feat in enumerate(enc_feats[1:]): + enc_feats[i+1] = F.pixel_shuffle(feat, H0//feat.size(2)) + for i, feat in enumerate(warped_feats[2:]): + warped_feats[i+2] = F.pixel_shuffle(feat, H0//feat.size(2)) + + x = torch.cat(enc_feats+warped_feats+[flow], dim=1) + x = self.in_conv1(x) + x = self.in_conv2(x) + x + x = self.in_conv3(x) + x + + amp, pha = dft(x, self.fftshift) + amp = self.amp_conv(amp) + amp + pha = self.pha_conv(pha) + pha + + real = amp * torch.cos(pha) + imag = amp * torch.sin(pha) + output = torch.fft.ifft2(torch.complex(real, imag)) + output = torch.abs(output) + x + + output = self.out_conv1(output) + output + output = self.out_conv2(output) + output + output = self.out_conv3(output) + output + + return output \ No newline at end of file diff --git a/modules/components/upr_net_freq/m2m.py b/modules/components/upr_net_freq/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..f536207982e94a86dc28b8599c557c84b5effb69 --- /dev/null +++ b/modules/components/upr_net_freq/m2m.py @@ -0,0 +1,407 @@ + +import math +import torch +import torch.nn as nn +import typing + +from ..components import register +from .backwarp import * +from .softsplat import _FunctionSoftsplat + + +########################################################## + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = _FunctionSoftsplat.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOutF, tenOutB = 0, 0 + tenNormalizeF, tenNormalizeB = 0, 0 + for idx in range(flow_num): + tenOutF_, tenNormalizeF_ = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB_, tenNormalizeB_ = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOutF += tenOutF_ + tenOutB += tenOutB_ + tenNormalizeF += tenNormalizeF_ + tenNormalizeB += tenNormalizeB_ + + return tenOutF / tenNormalizeF, tenNormalizeF < 0.00001, tenOutB / tenNormalizeB, tenNormalizeB < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + torch.nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + torch.nn.PReLU(out_planes) + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + +class ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, self.branch, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, -1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0, m1 + + +@register('m2m_pwc') +class M2M_PWC(torch.nn.Module): + def __init__(self, ratio=4): + super(M2M_PWC, self).__init__() + self.branch = 4 + self.ratio = ratio + + self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear', + align_corners=False) + flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear', + align_corners=False) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs): + if ratio is None: + ratio = self.ratio + + intWidth = img0.shape[3] and img1.shape[3] + intHeight = img0.shape[2] and img1.shape[2] + + intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16) + intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16) + + img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate') + img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate') + + N_, C_, H_, W_ = img0.shape + + outputs = [] + result_dict = {} + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + + tenFwd, tenBwd = self.netFlow.bidir(im0_, im1_) + + result_dict['flowfwd'] = torch.nn.functional.interpolate(tenFwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + result_dict['flowbwd'] = torch.nn.functional.interpolate(tenBwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, img0, img1, ratio) + + img0 = im0_o.repeat(1, self.branch, 1, 1) + img1 = im1_o.repeat(1, self.branch, 1, 1) + tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = time_step.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = (1.0 - (WeiMF * (img0 - backwarp(img1, tenFwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + tenPhototwo = (1.0 - (WeiMB * (img1 - backwarp(img0, tenBwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = self.paramAlpha * tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = self.paramAlpha * tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutput, mask = forwarp_mframe_mask(img0, flow0, t1, img1, flow1, t0, metric0, metric1) + + tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o) + + output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_ + result_dict['imgt_pred'] = output[:, :, :intHeight, :intWidth] + + return result_dict + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out \ No newline at end of file diff --git a/modules/components/upr_net_freq/softsplat.py b/modules/components/upr_net_freq/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4b3fe227283b5ecb256b8ed2aa7b0846a4ccd2 --- /dev/null +++ b/modules/components/upr_net_freq/softsplat.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Softsplat_updateOutput = ''' + extern "C" __global__ void kernel_Softsplat_updateOutput( + const int n, + const float* input, + const float* flow, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); + } + } } +''' + +kernel_Softsplat_updateGradInput = ''' + extern "C" __global__ void kernel_Softsplat_updateGradInput( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); + const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); + const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); + const int intX = ( intIndex ) % SIZE_3(gradInput); + + float fltGradInput = 0.0; + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + gradInput[intIndex] = fltGradInput; + } } +''' + +kernel_Softsplat_updateGradFlow = ''' + extern "C" __global__ void kernel_Softsplat_updateGradFlow( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float fltGradFlow = 0.0; + + const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); + const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); + const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); + const int intX = ( intIndex ) % SIZE_3(gradFlow); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = 0.0; + float fltNortheast = 0.0; + float fltSouthwest = 0.0; + float fltSoutheast = 0.0; + + if (intC == 0) { + fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY ); + fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY ); + fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); + fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (-1.0)); + fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); + fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * ((float) (+1.0)); + fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { + float fltInput = VALUE_4(input, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; + } + } + + gradFlow[intIndex] = fltGradFlow; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + + return strKernel + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +class _FunctionSoftsplat(torch.autograd.Function): + @staticmethod + def forward(self, input, flow): + self.save_for_backward(input, flow) + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(input.is_contiguous() == True) + assert(flow.is_contiguous() == True) + + output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) + + if input.is_cuda == True: + n = output.nelement() + cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { + 'input': input, + 'flow': flow, + 'output': output + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + return output + + + @staticmethod + def backward(self, gradOutput): + input, flow = self.saved_tensors + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(gradOutput.is_contiguous() == True) + + gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])\ + if self.needs_input_grad[0] == True else None + gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ])\ + if self.needs_input_grad[1] == True else None + + if input.is_cuda == True: + if gradInput is not None: + n = gradInput.nelement() + cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] + ) + + if gradFlow is not None: + n = gradFlow.nelement() + cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + + return gradInput, gradFlow + + +def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): + assert(tenMetric is None or tenMetric.shape[1] == 1) + assert(strType in ['summation', 'average', 'linear', 'softmax']) + + if strType == 'average': + tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) + + elif strType == 'linear': + tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) + + elif strType == 'softmax': + tenInput = torch.cat([ tenInput * tenMetric.clip(-20, 20).exp(), tenMetric.clip(-20, 20).exp() ], 1) + + + tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) + + if strType != 'summation': + tenNormalize = tenOutput[:, -1:, :, :] + + tenNormalize[tenNormalize == 0.0] = 1.0 + + tenOutput = tenOutput[:, :-1, :, :] / tenNormalize + + return tenOutput + + +class ModuleSoftsplat(torch.nn.Module): + def __init__(self, strType): + super(ModuleSoftsplat, self).__init__() + + self.strType = strType + + def forward(self, tenInput, tenFlow, tenMetric): + return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) diff --git a/modules/components/upr_net_freq/upr_freq.py b/modules/components/upr_net_freq/upr_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..41518d89f4ccd19c05c72331943d795f9612eb15 --- /dev/null +++ b/modules/components/upr_net_freq/upr_freq.py @@ -0,0 +1,567 @@ +# upr_freq_002.py (freq001 freq002) +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn +import torchvision.transforms.v2.functional as TF + +import modules.components.upr_net_freq.correlation as correlation +import modules.components.upr_net_freq.softsplat as softsplat +from modules.components.upr_net_freq.m2m import * +from modules.components.upr_net_freq.backwarp import backwarp +from .costvol import costvol_func +from ..components import register +from modules.components.upr_net_freq.frequency_enhance import FrequencyEnhancementTransformer + +from utils.padder import InputPadder +from utils.vos.model.network import STCN +from utils.vos.model.inference_core import InferenceCore + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 = 704 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=704, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + self.upsampler = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 16 * 9, 1, padding=0) + ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def upsample(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 4, 4, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(4 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 4, 4 * H, 4 * W) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + mask = self.upsampler(feat) * .25 + flow = self.upsample(flow, mask) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average', fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.freq_enhance1 = FrequencyEnhancementTransformer( + c_dim=64, feat_dim=128, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.freq_enhance2 = FrequencyEnhancementTransformer( + c_dim=128, feat_dim=256, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) # [B, 64,h,w] + ss0, mm0 = self.freq_enhance0(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + s0 = ss0 + s0 + + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) # [B, 128,h/2,w/2] + ss1, mm1 = self.freq_enhance1(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + s1 = ss1 + s1 + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) # [B, 256,h/4,w/4] + ss2, mm2 = self.freq_enhance2(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + s2 = ss2 + s2 + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + extra_dict['s0'] = s0 + extra_dict['s1'] = s1 + extra_dict['s2'] = s2 + extra_dict['ss0'] = ss0 + extra_dict['ss1'] = ss1 + extra_dict['ss2'] = ss2 + extra_dict['mm0'] = mm0 + extra_dict['mm1'] = mm1 + extra_dict['mm2'] = mm2 + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_freq') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average', fftshift=False): + super(Model, self).__init__() + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back freq@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode, fftshift) + self.splat_mode = splat_mode + self.fftshift = fftshift + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + last_flow = F.interpolate( + input=last_flow, scale_factor=0.25, + mode="bilinear") * 0.25 + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = flow + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear") * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, seg0=None, segt=None, seg1=None, + pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level)), W // (2 ** (level))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H, W)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + flow0_pred.append( + padder.unpad(flow[:, :2])) + flow1_pred.append( + padder.unpad(flow[:, 2:])) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + c0_pyr = [padder.unpad(cc) for cc in extra_dict["c0_pyr"]] + c1_pyr = [padder.unpad(cc) for cc in extra_dict["c1_pyr"]] + syn_pyr = [padder.unpad(cc) for cc in extra_dict["syn_pyr"]] + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + result_dict = { + "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, 'c0_pyr': c0_pyr, 'c1_pyr': c1_pyr, 'syn_pyr': syn_pyr + } + + return result_dict, extra_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_freq/upr_freq_003.py b/modules/components/upr_net_freq/upr_freq_003.py new file mode 100644 index 0000000000000000000000000000000000000000..8601ae8037f20e68e9dbd4f17a01669ec948839c --- /dev/null +++ b/modules/components/upr_net_freq/upr_freq_003.py @@ -0,0 +1,560 @@ +# upr_freq003.py + +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn +import torchvision.transforms.v2.functional as TF + +import modules.components.upr_net_freq.correlation as correlation +import modules.components.upr_net_freq.softsplat as softsplat +from modules.components.upr_net_freq.m2m import * +from modules.components.upr_net_freq.backwarp import backwarp +from .costvol import costvol_func +from ..components import register +from modules.components.upr_net_freq.frequency_enhance import FrequencyEnhancementTransformer + +from utils.padder import InputPadder +from utils.vos.model.network import STCN +from utils.vos.model.inference_core import InferenceCore + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 + 1 = 705 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=705, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + self.upsampler = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 16 * 9, 1, padding=0) + ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def upsample(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 4, 4, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(4 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 4, 4 * H, 4 * W) + + def forward(self, feat0, feat1, last_feat, last_flow, distance): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo, distance], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + mask = self.upsampler(feat) * .25 + flow = self.upsample(flow, mask) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average', fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.freq_enhance1 = FrequencyEnhancementTransformer( + c_dim=64, feat_dim=128, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.freq_enhance2 = FrequencyEnhancementTransformer( + c_dim=128, feat_dim=256, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) # [B, 64,h,w] + s0 = self.freq_enhance0(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + s0 + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) # [B, 128,h/2,w/2] + s1 = self.freq_enhance1(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + s1 + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) # [B, 256,h/4,w/4] + s2 = self.freq_enhance2(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + s2 + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_freq') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average', fftshift=False): + super(Model, self).__init__() + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back freq@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode, fftshift) + self.splat_mode = splat_mode + self.fftshift = fftshift + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False, distance=None): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + last_flow = F.interpolate( + input=last_flow, scale_factor=0.25, + mode="bilinear") * 0.25 + distance_ = F.interpolate( + input=distance, scale_factor=0.25, + mode="bilinear") + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow, distance_) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = flow + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear") * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, distance=None, seg0=None, segt=None, seg1=None, + pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + cur_distance = F.interpolate(input=distance, scale_factor=scale_factor, + mode='bilinear', align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + cur_distance = distance + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level)), W // (2 ** (level))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H, W)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me, distance=cur_distance) + flow0_pred.append( + padder.unpad(flow[:, :2])) + flow1_pred.append( + padder.unpad(flow[:, 2:])) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + c0_pyr = [padder.unpad(cc) for cc in extra_dict["c0_pyr"]] + c1_pyr = [padder.unpad(cc) for cc in extra_dict["c1_pyr"]] + syn_pyr = [padder.unpad(cc) for cc in extra_dict["syn_pyr"]] + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + result_dict = { + "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, 'c0_pyr': c0_pyr, 'c1_pyr': c1_pyr, 'syn_pyr': syn_pyr + } + + return result_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_freq/upr_freq_004.py b/modules/components/upr_net_freq/upr_freq_004.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9a22b7d5b46d540abfc4455e4a85a5c82d9f89 --- /dev/null +++ b/modules/components/upr_net_freq/upr_freq_004.py @@ -0,0 +1,561 @@ +# upr_freq004.py + +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn +import torchvision.transforms.v2.functional as TF + +import modules.components.upr_net_freq.correlation as correlation +import modules.components.upr_net_freq.softsplat as softsplat +from modules.components.upr_net_freq.m2m import * +from modules.components.upr_net_freq.backwarp import backwarp +from .costvol import costvol_func +from ..components import register +from modules.components.upr_net_freq.frequency_enhance import FrequencyEnhancementTransformer + +from utils.padder import InputPadder +from utils.vos.model.network import STCN +from utils.vos.model.inference_core import InferenceCore + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=4, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img, distance): + distance = distance.clamp(0,1) + C0 = self.conv_stage0(torch.cat([img, distance], dim=1)) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 + 1 = 705 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=705, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + self.upsampler = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 16 * 9, 1, padding=0) + ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def upsample(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 4, 4, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(4 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 4, 4 * H, 4 * W) + + def forward(self, feat0, feat1, last_feat, last_flow, distance): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo, distance], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + mask = self.upsampler(feat) * .25 + flow = self.upsample(flow, mask) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average', fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.freq_enhance1 = FrequencyEnhancementTransformer( + c_dim=64, feat_dim=128, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.freq_enhance2 = FrequencyEnhancementTransformer( + c_dim=128, feat_dim=256, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) # [B, 64,h,w] + s0 = self.freq_enhance0(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + s0 + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) # [B, 128,h/2,w/2] + s1 = self.freq_enhance1(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + s1 + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) # [B, 256,h/4,w/4] + s2 = self.freq_enhance2(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + s2 + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_freq') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average', fftshift=False): + super(Model, self).__init__() + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back freq004@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode, fftshift) + self.splat_mode = splat_mode + self.fftshift = fftshift + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False, distance=None): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0, distance) + feat1_pyr = self.feat_pyramid(img1, distance) + + # bi-directional flow estimation + if not skip_me: + last_flow = F.interpolate( + input=last_flow, scale_factor=0.25, + mode="bilinear") * 0.25 + distance_ = F.interpolate( + input=distance, scale_factor=0.25, + mode="bilinear") + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow, distance_) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = flow + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear") * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, distance=None, seg0=None, segt=None, seg1=None, + pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + cur_distance = F.interpolate(input=distance, scale_factor=scale_factor, + mode='bilinear', align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + cur_distance = distance + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level)), W // (2 ** (level))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H, W)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me, distance=cur_distance) + flow0_pred.append( + padder.unpad(flow[:, :2])) + flow1_pred.append( + padder.unpad(flow[:, 2:])) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + c0_pyr = [padder.unpad(cc) for cc in extra_dict["c0_pyr"]] + c1_pyr = [padder.unpad(cc) for cc in extra_dict["c1_pyr"]] + syn_pyr = [padder.unpad(cc) for cc in extra_dict["syn_pyr"]] + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + result_dict = { + "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, 'c0_pyr': c0_pyr, 'c1_pyr': c1_pyr, 'syn_pyr': syn_pyr + } + + return result_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_freq/upr_freq_005.py b/modules/components/upr_net_freq/upr_freq_005.py new file mode 100644 index 0000000000000000000000000000000000000000..458e9b91d7375686dbd510e43955f2c8fbd69a51 --- /dev/null +++ b/modules/components/upr_net_freq/upr_freq_005.py @@ -0,0 +1,538 @@ +# upr_freq_005.py (freq002+synthesis EncFreqs + Asym.FreqDec) +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn +import torchvision.transforms.v2.functional as TF + +import modules.components.upr_net_freq.correlation as correlation +import modules.components.upr_net_freq.softsplat as softsplat +from modules.components.upr_net_freq.m2m import * +from modules.components.upr_net_freq.backwarp import backwarp +from .costvol import costvol_func +from ..components import register +from modules.components.upr_net_freq.frequency_enhance import FrequencyEnhancementTransformer, FrequencyEnhancementDecoder + +from utils.padder import InputPadder +from utils.vos.model.network import STCN +from utils.vos.model.inference_core import InferenceCore + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 = 704 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=704, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + self.upsampler = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 16 * 9, 1, padding=0) + ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def upsample(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 4, 4, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(4 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 4, 4 * H, 4 * W) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + mask = self.upsampler(feat) * .25 + flow = self.upsample(flow, mask) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average', enc_depths=[1,2,4], fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder0 = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=input_channels, + kernel_size=3, stride=1, padding=1, groups=input_channels), + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=1, stride=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = nn.ModuleList() + for d in range(enc_depths[0]): + self.freq_enhance0.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=2., + last=False if d!=enc_depths[0]-1 else True, fftshift=fftshift)) + self.encoder1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=64, + kernel_size=2, stride=2, padding=0), + nn.PReLU(num_parameters=64)) + self.freq_enhance1 = nn.ModuleList() + for d in range(enc_depths[1]): + self.freq_enhance1.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=64, feat_dim=64, num_head=4, hidden_ratio=2., + last=False if d!=enc_depths[1]-1 else True, fftshift=fftshift)) + self.encoder2 = nn.Sequential( + nn.Conv2d(in_channels=64 + 64 + 64, out_channels=64, + kernel_size=2, stride=2, padding=0), + nn.PReLU(num_parameters=64)) + self.freq_enhance2 = nn.ModuleList() + for d in range(enc_depths[2]): + self.freq_enhance2.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=128, feat_dim=64, num_head=4, hidden_ratio=2., + last=False if d!=enc_depths[2]-1 else True, fftshift=fftshift)) + + # s0 + s1` + s2` + warp_c00 + warp_c10 + warp_c10` + warp_c11` + warp_c02` + warp_c12` + flow + # 64 + 16 + 4 + 32 + 32 + 16 + 16 + 8 + 8 + 4 = 200 + self.freq_decoder = FrequencyEnhancementDecoder(concat_dim=200, dim=64, fftshift=fftshift) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c00, warped_c10, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder0(input_feat) # [B, 64,h,w] + for block in self.freq_enhance0: + s0 = block(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + s0 + + s1 = self.encoder1(torch.cat((s0, warped_c00, warped_c10), 1)) # [B, 128,h/2,w/2] + for block in self.freq_enhance1: + s1 = block(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + s1 + warped_c01, warped_c11 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + + s2 = self.encoder2(torch.cat((s1, warped_c01, warped_c11), 1)) # [B, 256,h/4,w/4] + for block in self.freq_enhance2: + s2 = block(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + s2 + warped_c02, warped_c12 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.freq_decoder(enc_feats=[s0,s1,s2], + warped_feats=[warped_c00,warped_c10, warped_c01,warped_c11, warped_c02,warped_c12], + flow=bi_flow_pyr[0]) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_freq') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average', fftshift=False): + super(Model, self).__init__() + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back freq005@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode, [1,2,2], fftshift) + self.splat_mode = splat_mode + self.fftshift = fftshift + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + last_flow = F.interpolate( + input=last_flow, scale_factor=0.25, + mode="bilinear") * 0.25 + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = flow + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear") * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, seg0=None, segt=None, seg1=None, + pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level)), W // (2 ** (level))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H, W)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + flow0_pred.append( + padder.unpad(flow[:, :2])) + flow1_pred.append( + padder.unpad(flow[:, 2:])) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + c0_pyr = [padder.unpad(cc) for cc in extra_dict["c0_pyr"]] + c1_pyr = [padder.unpad(cc) for cc in extra_dict["c1_pyr"]] + syn_pyr = [padder.unpad(cc) for cc in extra_dict["syn_pyr"]] + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + result_dict = { + "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, 'c0_pyr': c0_pyr, 'c1_pyr': c1_pyr, 'syn_pyr': syn_pyr + } + + return result_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_freq/upr_freq_006.py b/modules/components/upr_net_freq/upr_freq_006.py new file mode 100644 index 0000000000000000000000000000000000000000..9aff8960dcc52b5b981eb892c63364d786ddef66 --- /dev/null +++ b/modules/components/upr_net_freq/upr_freq_006.py @@ -0,0 +1,538 @@ +# upr_freq_006.py (freq005+FET c_proj,feat_proj dwconv์žฌ์ถ”๊ฐ€) +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn +import torchvision.transforms.v2.functional as TF + +import modules.components.upr_net_freq.correlation as correlation +import modules.components.upr_net_freq.softsplat as softsplat +from modules.components.upr_net_freq.m2m import * +from modules.components.upr_net_freq.backwarp import backwarp +from .costvol import costvol_func +from ..components import register +from modules.components.upr_net_freq.frequency_enhance import FrequencyEnhancementTransformer, FrequencyEnhancementDecoder + +from utils.padder import InputPadder +from utils.vos.model.network import STCN +from utils.vos.model.inference_core import InferenceCore + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 = 704 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=704, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + self.upsampler = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 16 * 9, 1, padding=0) + ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def upsample(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 4, 4, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(4 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 4, 4 * H, 4 * W) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + mask = self.upsampler(feat) * .25 + flow = self.upsample(flow, mask) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average', enc_depths=[1,2,4], fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder0 = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=input_channels, + kernel_size=3, stride=1, padding=1, groups=input_channels), + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=1, stride=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = nn.ModuleList() + for d in range(enc_depths[0]): + self.freq_enhance0.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=2., + last=False if d!=enc_depths[0]-1 else True, fftshift=fftshift)) + self.encoder1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=64, + kernel_size=2, stride=2, padding=0), + nn.PReLU(num_parameters=64)) + self.freq_enhance1 = nn.ModuleList() + for d in range(enc_depths[1]): + self.freq_enhance1.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=64, feat_dim=64, num_head=4, hidden_ratio=2., + last=False if d!=enc_depths[1]-1 else True, fftshift=fftshift)) + self.encoder2 = nn.Sequential( + nn.Conv2d(in_channels=64 + 64 + 64, out_channels=64, + kernel_size=2, stride=2, padding=0), + nn.PReLU(num_parameters=64)) + self.freq_enhance2 = nn.ModuleList() + for d in range(enc_depths[2]): + self.freq_enhance2.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=128, feat_dim=64, num_head=4, hidden_ratio=2., + last=False if d!=enc_depths[2]-1 else True, fftshift=fftshift)) + + # s0 + s1` + s2` + warp_c00 + warp_c10 + warp_c10` + warp_c11` + warp_c02` + warp_c12` + flow + # 64 + 16 + 4 + 32 + 32 + 16 + 16 + 8 + 8 + 4 = 200 + self.freq_decoder = FrequencyEnhancementDecoder(concat_dim=200, dim=64, fftshift=fftshift) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c00, warped_c10, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder0(input_feat) # [B, 64,h,w] + for block in self.freq_enhance0: + s0 = block(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + + s1 = self.encoder1(torch.cat((s0, warped_c00, warped_c10), 1)) # [B, 128,h/2,w/2] + for block in self.freq_enhance1: + s1 = block(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + warped_c01, warped_c11 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + + s2 = self.encoder2(torch.cat((s1, warped_c01, warped_c11), 1)) # [B, 256,h/4,w/4] + for block in self.freq_enhance2: + s2 = block(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + warped_c02, warped_c12 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.freq_decoder(enc_feats=[s0,s1,s2], + warped_feats=[warped_c00,warped_c10, warped_c01,warped_c11, warped_c02,warped_c12], + flow=bi_flow_pyr[0]) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_freq') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average', fftshift=False): + super(Model, self).__init__() + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back freq006@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode, [1,2,2], fftshift) + self.splat_mode = splat_mode + self.fftshift = fftshift + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + last_flow = F.interpolate( + input=last_flow, scale_factor=0.25, + mode="bilinear") * 0.25 + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = flow + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear") * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, seg0=None, segt=None, seg1=None, + pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level)), W // (2 ** (level))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H, W)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + flow0_pred.append( + padder.unpad(flow[:, :2])) + flow1_pred.append( + padder.unpad(flow[:, 2:])) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + c0_pyr = [padder.unpad(cc) for cc in extra_dict["c0_pyr"]] + c1_pyr = [padder.unpad(cc) for cc in extra_dict["c1_pyr"]] + syn_pyr = [padder.unpad(cc) for cc in extra_dict["syn_pyr"]] + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + result_dict = { + "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, 'c0_pyr': c0_pyr, 'c1_pyr': c1_pyr, 'syn_pyr': syn_pyr + } + + return result_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_freq/upr_freq_temp.py b/modules/components/upr_net_freq/upr_freq_temp.py new file mode 100644 index 0000000000000000000000000000000000000000..cb99422b1b668883a3c27e11599ff2a8523d76c8 --- /dev/null +++ b/modules/components/upr_net_freq/upr_freq_temp.py @@ -0,0 +1,552 @@ +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn +import torchvision.transforms.v2.functional as TF + +import modules.components.upr_net_freq.correlation as correlation +import modules.components.upr_net_freq.softsplat as softsplat +from modules.components.upr_net_freq.m2m import * +from modules.components.upr_net_freq.backwarp import backwarp +from .costvol import costvol_func +from ..components import register +from modules.components.upr_net_freq.frequency_enhance_001 import FrequencyEnhancementTransformer + +from utils.padder import InputPadder +from utils.vos.model.network import STCN +from utils.vos.model.inference_core import InferenceCore + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 = 704 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=704, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + self.upsampler = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 16 * 9, 1, padding=0) + ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def upsample(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 4, 4, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(4 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 4, 4 * H, 4 * W) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + mask = self.upsampler(feat) * .25 + flow = self.upsample(flow, mask) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average', fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.freq_enhance1 = FrequencyEnhancementTransformer( + c_dim=64, feat_dim=128, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.freq_enhance2 = FrequencyEnhancementTransformer( + c_dim=128, feat_dim=256, num_head=4, hidden_ratio=4., fftshift=fftshift) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) # [B, 64,h,w] + s0 = self.freq_enhance0(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + s0 + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) # [B, 128,h/2,w/2] + s1 = self.freq_enhance1(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + s1 + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) # [B, 256,h/4,w/4] + s2 = self.freq_enhance2(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + s2 + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_freq') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average', fftshift=False): + super(Model, self).__init__() + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back freq@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode, fftshift) + self.splat_mode = splat_mode + self.fftshift = fftshift + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + last_flow = F.interpolate( + input=last_flow, scale_factor=0.25, + mode="bilinear") * 0.25 + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = flow + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear") * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, seg0=None, segt=None, seg1=None, + pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level)), W // (2 ** (level))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H, W)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + flow0_pred.append( + padder.unpad(flow[:, :2])) + flow1_pred.append( + padder.unpad(flow[:, 2:])) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + c0_pyr = [padder.unpad(cc) for cc in extra_dict["c0_pyr"]] + c1_pyr = [padder.unpad(cc) for cc in extra_dict["c1_pyr"]] + syn_pyr = [padder.unpad(cc) for cc in extra_dict["syn_pyr"]] + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + result_dict = { + "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, 'c0_pyr': c0_pyr, 'c1_pyr': c1_pyr, 'syn_pyr': syn_pyr + } + + return result_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_freq2/__init__.py b/modules/components/upr_net_freq2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c00aa62f5ba5701d7e4fe6bacf9da91fe4da4400 --- /dev/null +++ b/modules/components/upr_net_freq2/__init__.py @@ -0,0 +1 @@ +from .upr_freq import Model \ No newline at end of file diff --git a/modules/components/upr_net_freq2/__pycache__/__init__.cpython-310.pyc b/modules/components/upr_net_freq2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abe7db31bc5d3d1ee78416bfe627d2657218ddf1 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/__init__.cpython-38.pyc b/modules/components/upr_net_freq2/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..910989add87549a65d67adc5f0f397fb97cf769e Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/__init__.cpython-39.pyc b/modules/components/upr_net_freq2/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe1ca5e967aa5b0fccef64723d151140be7270ad Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/correlation.cpython-310.pyc b/modules/components/upr_net_freq2/__pycache__/correlation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a0c4f8d7d4a35c94c1cc3a7e37386939f8c1e9b Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/correlation.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/correlation.cpython-38.pyc b/modules/components/upr_net_freq2/__pycache__/correlation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..201ac2f78b0b867c2d88e442972bbde4567e1c19 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/correlation.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/correlation.cpython-39.pyc b/modules/components/upr_net_freq2/__pycache__/correlation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d32adfcd26be7ed6586391bb3de164121d0425b2 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/correlation.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-310.pyc b/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..465ec02ddbbc04a24ce6d62d960d4238e5dd4f06 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-38.pyc b/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2156b26b09b1ef423eaff1ba31f5e3e1870555e7 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-39.pyc b/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f0f77021c7422ac37a0af34a7c5df8a529cc4a9 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/frequency_enhance.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-310.pyc b/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ae3932f6703cd3bccd8ef55e19cf1ee67e5f53a Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-38.pyc b/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e14d703c8a4ff0e5a6ef78e5574f0b323f88f97 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-39.pyc b/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e1005d8551367ebc8077d5570e7693bdb41d7b3 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/upr.cpython-310.pyc b/modules/components/upr_net_freq2/__pycache__/upr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1ec11a9a43054acb1fdec1d789c6b06513e7cc Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/upr.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/upr.cpython-38.pyc b/modules/components/upr_net_freq2/__pycache__/upr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96725d0c8ecf1c61092482a28da621da3af2eac0 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/upr.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-310.pyc b/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf4d40cd87e021dab4a05bf308cf6a079b4136c4 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-310.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-38.pyc b/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bec8cc68cf3d2da942083b050ed066ee7e7d7437 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-38.pyc differ diff --git a/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-39.pyc b/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c907e183757b18003e954c4d3e9f2aa5d4ff440 Binary files /dev/null and b/modules/components/upr_net_freq2/__pycache__/upr_freq.cpython-39.pyc differ diff --git a/modules/components/upr_net_freq2/correlation.py b/modules/components/upr_net_freq2/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c97e3e80f79dd141f01578763090bc96d2a787 --- /dev/null +++ b/modules/components/upr_net_freq2/correlation.py @@ -0,0 +1,397 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert(first.is_contiguous() == True) + assert(second.is_contiguous() == True) + + output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, first.data_ptr(), rbot0.data_ptr() ] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ n, second.data_ptr(), rbot1.data_ptr() ] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), + block=tuple([ 32, 1, 1 ]), + shared_mem=first.shape[1] * 4, + args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert(gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + # end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + # end +# end \ No newline at end of file diff --git a/modules/components/upr_net_freq2/frequency_enhance.py b/modules/components/upr_net_freq2/frequency_enhance.py new file mode 100644 index 0000000000000000000000000000000000000000..ff28456974865ca8f980c79dbe3759f37688df0b --- /dev/null +++ b/modules/components/upr_net_freq2/frequency_enhance.py @@ -0,0 +1,187 @@ +# frequency_enhance_008.py (FET output residual connection) + +import math +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +import time + +class ReshapeLayerNorm(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super(ReshapeLayerNorm, self).__init__() + + self.dim = dim + self.norm = norm_layer(dim) + + def forward(self, x): + B, C, H, W = x.size() + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H) + return x + +class ChannelSelfAttention(nn.Module): + def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0): + super(ChannelSelfAttention, self).__init__() + self.dim = dim + self.num_head = num_head + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True) + + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Conv2d(dim, dim, 1) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q,k,v, sp=None): + B, C, H, W = q.size() + + q,k,v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q,k,v]) # [B, L, C/L, HW] + + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2,-1) # [B, L, C/L, C/L] + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v # [B, L, C/L, HW] + + # head merge + x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H) # [B, C, H, W] + x = self.proj_drop(self.proj(x)) # [B, C, H, W] + + return x + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0): + super(FeedForward, self).__init__() + + self.dim = dim + self.hidden_ratio = hidden_ratio + + self.hidden = nn.Conv2d(dim, int(dim*hidden_ratio), 1, bias=bias) + self.drop1 = nn.Dropout(drop) + self.out = nn.Conv2d(int(dim*hidden_ratio), dim, 1, bias=bias) + self.drop2 = nn.Dropout(drop) + self.act = act_layer() + + def forward(self, x): + return self.drop2(self.out(self.drop1(self.act(self.hidden(x))))) + +def dft(x, fftshift=False): + fft = torch.fft.fft2(x, dim=(2,3), norm='ortho') + fft = torch.fft.fftshift(fft, dim=(2,3)) if fftshift else fft + amplitude = torch.abs(fft) + phase = torch.angle(fft) + return amplitude, phase + +def idft(amplitude, phase): + real = amplitude * torch.cos(phase) + imag = amplitude * torch.sin(phase) + out = torch.fft.ifft2(torch.complex(real, imag), dim=(2,3), norm='ortho') + out = torch.abs(out) + return out + +class FrequencyEnhancementTransformer(nn.Module): + def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, fftshift=False, *args, **kwargs): + super(FrequencyEnhancementTransformer, self).__init__() + self.c_dim = c_dim + self.feat_dim = feat_dim + self.num_head = num_head + self.hidden_ratio = hidden_ratio + self.fftshift = fftshift + + self.c_conv = nn.Sequential(nn.Conv2d(in_channels=c_dim*2+4, out_channels=c_dim*2+4, kernel_size=3, stride=1, padding=1, groups=c_dim*2+4), + nn.Conv2d(in_channels=c_dim*2+4, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + self.feat_conv = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim), + nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + + self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.attn = ChannelSelfAttention(32, num_head) + self.norm1 = ReshapeLayerNorm(32) + + self.ffn = FeedForward(32, hidden_ratio) + self.norm2 = ReshapeLayerNorm(32) + + self.phase_conv = nn.Sequential(nn.Conv2d(in_channels=32+32, out_channels=32+32, kernel_size=3, stride=1, padding=1, groups=2)) + + self.out_conv = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=32), + nn.Conv2d(in_channels=32, out_channels=feat_dim, kernel_size=1, stride=1), + nn.LeakyReLU()) + + def forward(self, c0, c1, feat, flow, *args, **kwargs): + B,D,H,W = feat.size() + + c = self.c_conv(torch.cat([c0,c1,flow], dim=1)) # [B, 32, H, W] + feat_ = self.feat_conv(feat) # [B, 32, H, W] + + amp_c, pha_c = dft(c, self.fftshift) # [B, 32, H, W] + amp_f, pha_f = dft(feat_, self.fftshift) # [B, 32, H, W] + + amp_q = self.q_proj(amp_c) # [B, 32, H, W] + amp_k = self.k_proj(amp_f) # [B, 32, H, W] + amp_v = self.v_proj(amp_f) # [B, 32, H, W] + amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v)) # [B, 32, H, W],,, Amplitude is always positive (i.e., >0) + amp = self.ffn(amp_attn+amp_f) # [B, 32, H, W] + + pha = self.phase_conv(torch.cat([pha_c,pha_f], dim=1)) # [B, 32, H, W] + pha = pha[:,:32] + pha[:,32:] + + out = idft(amp, pha) # [B, 32, H, W] + out = self.out_conv(out) # [B, D, H, W] + + return out, (c, feat_) + +class FrequencyEnhancementDecoder(nn.Module): + def __init__(self, concat_dim, dim, fftshift, *args, **kwargs): + super(FrequencyEnhancementDecoder, self).__init__() + self.concat_dim = concat_dim + self.dim = dim + self.fftshift = fftshift + + self.act = nn.LeakyReLU() + + self.in_conv1 = nn.Sequential(nn.Conv2d(concat_dim, concat_dim, 3, 1, 1, groups=concat_dim), + nn.Conv2d(concat_dim, dim, 1, 1), + nn.LeakyReLU()) + self.in_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + + self.amp_conv = nn.Conv2d(dim, dim, 3, 1, 1) + self.pha_conv = nn.Conv2d(dim, dim, 3, 1, 1) + + self.out_conv1 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + self.out_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + + def forward(self, enc_feats, warped_feats, flow): + _,_,H0,W0 = enc_feats[0].size() + for i, feat in enumerate(enc_feats[1:]): + enc_feats[i+1] = F.pixel_shuffle(feat, H0//feat.size(2)) + for i, feat in enumerate(warped_feats[2:]): + warped_feats[i+2] = F.pixel_shuffle(feat, H0//feat.size(2)) + + x = torch.cat(enc_feats+warped_feats+[flow], dim=1) + x = self.in_conv1(x) + x = self.in_conv2(x) + + amp, pha = dft(x, self.fftshift) + amp = self.amp_conv(amp) + pha = self.pha_conv(pha) + + out = idft(amp, pha) + + out = self.out_conv1(out) + out = self.out_conv2(out) + + return out \ No newline at end of file diff --git a/modules/components/upr_net_freq2/frequency_enhance_006.py b/modules/components/upr_net_freq2/frequency_enhance_006.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf30b6e6bca8ff43d3cf5a382eed696d606ef6b --- /dev/null +++ b/modules/components/upr_net_freq2/frequency_enhance_006.py @@ -0,0 +1,186 @@ +# frequency_enhance_006.py (์—ฌ๋Ÿฌ๊ฐ€์ง€ ๋ณ€๊ฒฝ) + +import math +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +import time + +class ReshapeLayerNorm(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super(ReshapeLayerNorm, self).__init__() + + self.dim = dim + self.norm = norm_layer(dim) + + def forward(self, x): + B, C, H, W = x.size() + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H) + return x + +class ChannelSelfAttention(nn.Module): + def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0): + super(ChannelSelfAttention, self).__init__() + self.dim = dim + self.num_head = num_head + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True) + + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Conv2d(dim, dim, 1) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q,k,v, sp=None): + B, C, H, W = q.size() + + q,k,v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q,k,v]) # [B, L, C/L, HW] + + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2,-1) # [B, L, C/L, C/L] + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v # [B, L, C/L, HW] + + # head merge + x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H) # [B, C, H, W] + x = self.proj_drop(self.proj(x)) # [B, C, H, W] + + return x + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0): + super(FeedForward, self).__init__() + + self.dim = dim + self.hidden_ratio = hidden_ratio + + self.hidden = nn.Conv2d(dim, int(dim*hidden_ratio), 1, bias=bias) + self.drop1 = nn.Dropout(drop) + self.out = nn.Conv2d(int(dim*hidden_ratio), dim, 1, bias=bias) + self.drop2 = nn.Dropout(drop) + self.act = act_layer() + + def forward(self, x): + return self.drop2(self.out(self.drop1(self.act(self.hidden(x))))) + +def dft(x, fftshift=False): + fft = torch.fft.fft2(x, dim=(2,3), norm='ortho') + fft = torch.fft.fftshift(fft, dim=(2,3)) if fftshift else fft + amplitude = torch.abs(fft) + phase = torch.angle(fft) + return amplitude, phase + +def idft(amplitude, phase): + real = amplitude * torch.cos(phase) + imag = amplitude * torch.sin(phase) + out = torch.fft.ifft2(torch.complex(real, imag), dim=(2,3), norm='ortho') + out = torch.abs(out) + return out + +class FrequencyEnhancementTransformer(nn.Module): + def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, fftshift=False, *args, **kwargs): + super(FrequencyEnhancementTransformer, self).__init__() + self.c_dim = c_dim + self.feat_dim = feat_dim + self.num_head = num_head + self.hidden_ratio = hidden_ratio + self.fftshift = fftshift + + self.c_conv = nn.Sequential(nn.Conv2d(in_channels=c_dim*2+4, out_channels=c_dim*2+4, kernel_size=3, stride=1, padding=1, groups=c_dim*2+4), + nn.Conv2d(in_channels=c_dim*2+4, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + self.feat_conv = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim), + nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + + self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.attn = ChannelSelfAttention(32, num_head) + self.norm1 = ReshapeLayerNorm(32) + + self.ffn = FeedForward(32, hidden_ratio) + self.norm2 = ReshapeLayerNorm(32) + + self.phase_conv = nn.Sequential(nn.Conv2d(in_channels=32+32, out_channels=32, kernel_size=3, stride=1, padding=1)) + + self.out_conv = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=32), + nn.Conv2d(in_channels=32, out_channels=feat_dim, kernel_size=1, stride=1), + nn.LeakyReLU()) + + def forward(self, c0, c1, feat, flow, *args, **kwargs): + B,D,H,W = feat.size() + + c = self.c_conv(torch.cat([c0,c1,flow], dim=1)) # [B, 32, H, W] + feat_ = self.feat_conv(feat) # [B, 32, H, W] + + amp_c, pha_c = dft(c, self.fftshift) # [B, 32, H, W] + amp_f, pha_f = dft(feat_, self.fftshift) # [B, 32, H, W] + + amp_q = self.q_proj(amp_c) # [B, 32, H, W] + amp_k = self.k_proj(amp_c) # [B, 32, H, W] + amp_v = self.v_proj(amp_f) # [B, 32, H, W] + amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v)) # [B, 32, H, W] + amp = self.norm2(self.ffn(amp_attn)) # [B, 32, H, W] + + pha = self.phase_conv(torch.cat([pha_c,pha_f], dim=1)) # [B, 32, H, W] + + out = idft(amp, pha) # [B, 32, H, W] + out = self.out_conv(out) # [B, D, H, W] + + return out + +class FrequencyEnhancementDecoder(nn.Module): + def __init__(self, concat_dim, dim, fftshift, *args, **kwargs): + super(FrequencyEnhancementDecoder, self).__init__() + self.concat_dim = concat_dim + self.dim = dim + self.fftshift = fftshift + + self.act = nn.LeakyReLU() + + self.in_conv1 = nn.Sequential(nn.Conv2d(concat_dim, concat_dim, 3, 1, 1, groups=concat_dim), + nn.Conv2d(concat_dim, dim, 1, 1), + nn.LeakyReLU()) + self.in_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + + self.amp_conv = nn.Conv2d(dim, dim, 3, 1, 1) + self.pha_conv = nn.Conv2d(dim, dim, 3, 1, 1) + + self.out_conv1 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + self.out_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + + def forward(self, enc_feats, warped_feats, flow): + _,_,H0,W0 = enc_feats[0].size() + for i, feat in enumerate(enc_feats[1:]): + enc_feats[i+1] = F.pixel_shuffle(feat, H0//feat.size(2)) + for i, feat in enumerate(warped_feats[2:]): + warped_feats[i+2] = F.pixel_shuffle(feat, H0//feat.size(2)) + + x = torch.cat(enc_feats+warped_feats+[flow], dim=1) + x = self.in_conv1(x) + x = self.in_conv2(x) + x + + amp, pha = dft(x, self.fftshift) + amp = self.amp_conv(amp) + amp + pha = self.pha_conv(pha) + pha + + out = idft(amp, pha) + x + + out = self.out_conv1(out) + out + out = self.out_conv2(out) + out + + return out \ No newline at end of file diff --git a/modules/components/upr_net_freq2/frequency_enhance_007.py b/modules/components/upr_net_freq2/frequency_enhance_007.py new file mode 100644 index 0000000000000000000000000000000000000000..1a9797a3a4ed09296baf909c6e56382caaf78f8c --- /dev/null +++ b/modules/components/upr_net_freq2/frequency_enhance_007.py @@ -0,0 +1,186 @@ +# frequency_enhance_007.py (FET value residual connection) + +import math +import torch +import torch.nn as nn +from einops import rearrange +import torch.nn.functional as F +import time + +class ReshapeLayerNorm(nn.Module): + def __init__(self, dim, norm_layer=nn.LayerNorm): + super(ReshapeLayerNorm, self).__init__() + + self.dim = dim + self.norm = norm_layer(dim) + + def forward(self, x): + B, C, H, W = x.size() + x = rearrange(x, 'b c h w -> b (h w) c') + x = self.norm(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=H) + return x + +class ChannelSelfAttention(nn.Module): + def __init__(self, dim, num_head, attn_drop=0.0, proj_drop=0.0): + super(ChannelSelfAttention, self).__init__() + self.dim = dim + self.num_head = num_head + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_head, 1, 1))), requires_grad=True) + + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Conv2d(dim, dim, 1) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, q,k,v, sp=None): + B, C, H, W = q.size() + + q,k,v = map(lambda x: rearrange(x, 'b (l c) h w -> b l c (h w)', l=self.num_head), [q,k,v]) # [B, L, C/L, HW] + + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(2,-1) # [B, L, C/L, C/L] + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = attn @ v + v # [B, L, C/L, HW] + + # head merge + x = rearrange(x, 'b l c (h w) -> b (l c) h w', h=H) # [B, C, H, W] + x = self.proj_drop(self.proj(x)) # [B, C, H, W] + + return x + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_ratio, act_layer=nn.GELU, bias=True, drop=0.0): + super(FeedForward, self).__init__() + + self.dim = dim + self.hidden_ratio = hidden_ratio + + self.hidden = nn.Conv2d(dim, int(dim*hidden_ratio), 1, bias=bias) + self.drop1 = nn.Dropout(drop) + self.out = nn.Conv2d(int(dim*hidden_ratio), dim, 1, bias=bias) + self.drop2 = nn.Dropout(drop) + self.act = act_layer() + + def forward(self, x): + return self.drop2(self.out(self.drop1(self.act(self.hidden(x))))) + +def dft(x, fftshift=False): + fft = torch.fft.fft2(x, dim=(2,3), norm='ortho') + fft = torch.fft.fftshift(fft, dim=(2,3)) if fftshift else fft + amplitude = torch.abs(fft) + phase = torch.angle(fft) + return amplitude, phase + +def idft(amplitude, phase): + real = amplitude * torch.cos(phase) + imag = amplitude * torch.sin(phase) + out = torch.fft.ifft2(torch.complex(real, imag), dim=(2,3), norm='ortho') + out = torch.abs(out) + return out + +class FrequencyEnhancementTransformer(nn.Module): + def __init__(self, c_dim, feat_dim, num_head, hidden_ratio, fftshift=False, *args, **kwargs): + super(FrequencyEnhancementTransformer, self).__init__() + self.c_dim = c_dim + self.feat_dim = feat_dim + self.num_head = num_head + self.hidden_ratio = hidden_ratio + self.fftshift = fftshift + + self.c_conv = nn.Sequential(nn.Conv2d(in_channels=c_dim*2+4, out_channels=c_dim*2+4, kernel_size=3, stride=1, padding=1, groups=c_dim*2+4), + nn.Conv2d(in_channels=c_dim*2+4, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + self.feat_conv = nn.Sequential(nn.Conv2d(in_channels=feat_dim, out_channels=feat_dim, kernel_size=3, stride=1, padding=1, groups=feat_dim), + nn.Conv2d(in_channels=feat_dim, out_channels=32, kernel_size=1, stride=1), + nn.LeakyReLU()) + + self.q_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.k_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.v_proj = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=1, stride=1) + self.attn = ChannelSelfAttention(32, num_head) + self.norm1 = ReshapeLayerNorm(32) + + self.ffn = FeedForward(32, hidden_ratio) + self.norm2 = ReshapeLayerNorm(32) + + self.phase_conv = nn.Sequential(nn.Conv2d(in_channels=32+32, out_channels=32, kernel_size=3, stride=1, padding=1)) + + self.out_conv = nn.Sequential(nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, groups=32), + nn.Conv2d(in_channels=32, out_channels=feat_dim, kernel_size=1, stride=1), + nn.LeakyReLU()) + + def forward(self, c0, c1, feat, flow, *args, **kwargs): + B,D,H,W = feat.size() + + c = self.c_conv(torch.cat([c0,c1,flow], dim=1)) # [B, 32, H, W] + feat_ = self.feat_conv(feat) # [B, 32, H, W] + + amp_c, pha_c = dft(c, self.fftshift) # [B, 32, H, W] + amp_f, pha_f = dft(feat_, self.fftshift) # [B, 32, H, W] + + amp_q = self.q_proj(amp_c) # [B, 32, H, W] + amp_k = self.k_proj(amp_c) # [B, 32, H, W] + amp_v = self.v_proj(amp_f) # [B, 32, H, W] + amp_attn = self.norm1(self.attn(amp_q, amp_k, amp_v)) # [B, 32, H, W] + amp = self.norm2(self.ffn(amp_attn)) # [B, 32, H, W] + + pha = self.phase_conv(torch.cat([pha_c,pha_f], dim=1)) # [B, 32, H, W] + + out = idft(amp, pha) # [B, 32, H, W] + out = self.out_conv(out) # [B, D, H, W] + + return out + +class FrequencyEnhancementDecoder(nn.Module): + def __init__(self, concat_dim, dim, fftshift, *args, **kwargs): + super(FrequencyEnhancementDecoder, self).__init__() + self.concat_dim = concat_dim + self.dim = dim + self.fftshift = fftshift + + self.act = nn.LeakyReLU() + + self.in_conv1 = nn.Sequential(nn.Conv2d(concat_dim, concat_dim, 3, 1, 1, groups=concat_dim), + nn.Conv2d(concat_dim, dim, 1, 1), + nn.LeakyReLU()) + self.in_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + + self.amp_conv = nn.Conv2d(dim, dim, 3, 1, 1) + self.pha_conv = nn.Conv2d(dim, dim, 3, 1, 1) + + self.out_conv1 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + self.out_conv2 = nn.Sequential(nn.Conv2d(dim, dim, 3, 1, 1, groups=dim), + nn.Conv2d(dim, dim, 1, 1), + nn.LeakyReLU()) + + def forward(self, enc_feats, warped_feats, flow): + _,_,H0,W0 = enc_feats[0].size() + for i, feat in enumerate(enc_feats[1:]): + enc_feats[i+1] = F.pixel_shuffle(feat, H0//feat.size(2)) + for i, feat in enumerate(warped_feats[2:]): + warped_feats[i+2] = F.pixel_shuffle(feat, H0//feat.size(2)) + + x = torch.cat(enc_feats+warped_feats+[flow], dim=1) + x = self.in_conv1(x) + x = self.in_conv2(x) + x + + amp, pha = dft(x, self.fftshift) + amp = self.amp_conv(amp) + amp + pha = self.pha_conv(pha) + pha + + out = idft(amp, pha) + x + + out = self.out_conv1(out) + out + out = self.out_conv2(out) + out + + return out \ No newline at end of file diff --git a/modules/components/upr_net_freq2/softsplat.py b/modules/components/upr_net_freq2/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..8967303376941351da0453ecc1ea61163180dcd3 --- /dev/null +++ b/modules/components/upr_net_freq2/softsplat.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Softsplat_updateOutput = ''' + extern "C" __global__ void kernel_Softsplat_updateOutput( + const int n, + const float* input, + const float* flow, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); + } + } } +''' + +kernel_Softsplat_updateGradInput = ''' + extern "C" __global__ void kernel_Softsplat_updateGradInput( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); + const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); + const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); + const int intX = ( intIndex ) % SIZE_3(gradInput); + + float fltGradInput = 0.0; + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + gradInput[intIndex] = fltGradInput; + } } +''' + +kernel_Softsplat_updateGradFlow = ''' + extern "C" __global__ void kernel_Softsplat_updateGradFlow( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float fltGradFlow = 0.0; + + const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); + const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); + const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); + const int intX = ( intIndex ) % SIZE_3(gradFlow); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = 0.0; + float fltNortheast = 0.0; + float fltSouthwest = 0.0; + float fltSoutheast = 0.0; + + if (intC == 0) { + fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY ); + fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY ); + fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); + fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (-1.0)); + fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); + fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * ((float) (+1.0)); + fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { + float fltInput = VALUE_4(input, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; + } + } + + gradFlow[intIndex] = fltGradFlow; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + + return strKernel + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +class _FunctionSoftsplat(torch.autograd.Function): + @staticmethod + def forward(self, input, flow): + self.save_for_backward(input, flow) + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(input.is_contiguous() == True) + assert(flow.is_contiguous() == True) + + output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) + + if input.is_cuda == True: + n = output.nelement() + cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { + 'input': input, + 'flow': flow, + 'output': output + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + return output + + + @staticmethod + def backward(self, gradOutput): + input, flow = self.saved_tensors + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(gradOutput.is_contiguous() == True) + + gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])\ + if self.needs_input_grad[0] == True else None + gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ])\ + if self.needs_input_grad[1] == True else None + + if input.is_cuda == True: + if gradInput is not None: + n = gradInput.nelement() + cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] + ) + + if gradFlow is not None: + n = gradFlow.nelement() + cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + + return gradInput, gradFlow + + +def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): + assert(tenMetric is None or tenMetric.shape[1] == 1) + assert(strType in ['summation', 'average', 'linear', 'softmax']) + + if strType == 'average': + tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) + + elif strType == 'linear': + tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) + + elif strType == 'softmax': + tenInput = torch.cat([ tenInput * tenMetric.exp(), tenMetric.exp() ], 1) + + + tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) + + if strType != 'summation': + tenNormalize = tenOutput[:, -1:, :, :] + + tenNormalize[tenNormalize == 0.0] = 1.0 + + tenOutput = tenOutput[:, :-1, :, :] / tenNormalize + + return tenOutput + + +class ModuleSoftsplat(torch.nn.Module): + def __init__(self, strType): + super(ModuleSoftsplat, self).__init__() + + self.strType = strType + + def forward(self, tenInput, tenFlow, tenMetric): + return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) diff --git a/modules/components/upr_net_freq2/upr_freq.py b/modules/components/upr_net_freq2/upr_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..cf5b532816fc1aa4bbbac9d08b3d7919eefe1e54 --- /dev/null +++ b/modules/components/upr_net_freq2/upr_freq.py @@ -0,0 +1,434 @@ +# upr_freq_008.py + +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn + +from ..components import register + +import modules.components.upr_net_freq2.softsplat as softsplat +import modules.components.upr_net_freq2.correlation as correlation +from utils.padder import InputPadder + +from modules.components.upr_net_freq2.frequency_enhance import FrequencyEnhancementTransformer, FrequencyEnhancementDecoder + + +#**************************************************************************************************# +# => Feature Pyramid +#**************************************************************************************************# +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + + + +#**************************************************************************************************# +# => Motion Estimation +#**************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + def __init__(self): + super(MotionEstimator, self).__init__() + # (4*2 + 1) ** 2 + 128 * 2 + 128 + 4 = 469 + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=469, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer6 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1)) + + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn=correlation.FunctionCorrelation + feat0 = softsplat.FunctionSoftsplat( + tenInput=feat0, tenFlow=last_flow[:, :2]*0.25*0.5, + tenMetric=None, strType='average') + feat1 = softsplat.FunctionSoftsplat( + tenInput=feat1, tenFlow=last_flow[:, 2:]*0.25*0.5, + tenMetric=None, strType='average') + + volume = F.leaky_relu( + input=corr_fn(tenFirst=feat0, tenSecond=feat1), + negative_slope=0.1, inplace=False) + input_feat = torch.cat([volume, feat0, feat1, last_feat, last_flow], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow = self.conv_layer6(feat) + + return flow, feat + + + + +#**************************************************************************************************# +# => Frame Synthesis +#**************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, enc_depths=[1,1,1], fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder0 = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=input_channels, + kernel_size=3, stride=1, padding=1, groups=input_channels), + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=1, stride=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = nn.ModuleList() + for d in range(enc_depths[0]): + self.freq_enhance0.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=2., fftshift=fftshift)) + self.encoder1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128)) + self.freq_enhance1 = nn.ModuleList() + for d in range(enc_depths[1]): + self.freq_enhance1.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=64, feat_dim=128, num_head=4, hidden_ratio=2., fftshift=fftshift)) + self.encoder2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256)) + self.freq_enhance2 = nn.ModuleList() + for d in range(enc_depths[2]): + self.freq_enhance2.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=128, feat_dim=256, num_head=4, hidden_ratio=2., fftshift=fftshift)) + + # s0 + s1` + s2` + warp_c00 + warp_c10 + warp_c10` + warp_c11` + warp_c02` + warp_c12` + flow + # 64 + 32 + 16 + 32 + 32 + 16 + 16 + 8 + 8 + 4 = 228 + self.freq_decoder = FrequencyEnhancementDecoder(concat_dim=228, dim=64, fftshift=fftshift) + self.pred = nn.Conv2d(in_channels=64, out_channels=5, kernel_size=3, stride=1, padding=1) + + def get_warped_representations(self, bi_flow, c0, c1, + i0=None, i1=None, time_step=0.5): + flow_0t = bi_flow[:, :2] * time_step + flow_1t = bi_flow[:, 2:4] * (1 - time_step) + warped_c0 = softsplat.FunctionSoftsplat( + tenInput=c0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_c1 = softsplat.FunctionSoftsplat( + tenInput=c1, tenFlow=flow_1t, + tenMetric=None, strType='average') + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=i0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=i1, tenFlow=flow_1t, + tenMetric=None, strType='average') + flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t + + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, + time_step=0.5): + warped_img0, warped_img1, warped_c00, warped_c10, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1, + time_step=time_step) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = ss0 = self.encoder0(input_feat) # [B, 64,h,w] + for block in self.freq_enhance0: + s0, mm0 = block(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + print('s0', s0.size(), s0.min(), s0.max(), s0.mean()) + + s1 = ss1 = self.encoder1(torch.cat((s0, warped_c00, warped_c10), 1)) # [B, 128,h/2,w/2] + warped_c01, warped_c11 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], + time_step=time_step) + for block in self.freq_enhance1: + s1, mm1 = block(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + print('s1', s1.size(), s1.min(), s1.max(), s1.mean()) + + s2 = ss2 = self.encoder2(torch.cat((s1, warped_c01, warped_c11), 1)) # [B, 256,h/4,w/4] + warped_c02, warped_c12 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], + time_step=time_step) + for block in self.freq_enhance2: + s2, mm2 = block(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + print('s2', s2.size(), s2.min(), s2.max(), s2.mean()) + + x = self.freq_decoder(enc_feats=[s0,s1,s2], + warped_feats=[warped_c00,warped_c10, warped_c01,warped_c11, warped_c02,warped_c12], + flow=bi_flow_pyr[0]) + + print('dec_out', x.size(), x.min(), x.max(), x.mean()) + print() + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask0 = torch.sigmoid(refine[:, 3:4]) + refine_mask1 = torch.sigmoid(refine[:, 4:5]) + merged_img = (warped_img0 * refine_mask0 * (1 - time_step) + \ + warped_img1 * refine_mask1 * time_step) + merged_img = merged_img / (refine_mask0 * (1 - time_step) + \ + refine_mask1 * time_step) + interp_img = merged_img + refine_res + interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask0"] = refine_mask0 + extra_dict["refine_mask1"] = refine_mask1 + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict['c0_pyr'] = c0_pyr + extra_dict['c1_pyr'] = c1_pyr + extra_dict["mm0"] = mm0 + extra_dict["mm1"] = mm1 + extra_dict["mm2"] = mm2 + extra_dict["s0"] = s0 + extra_dict["s1"] = s1 + extra_dict["s2"] = s2 + extra_dict["ss0"] = ss0 + extra_dict["ss1"] = ss1 + extra_dict["ss2"] = ss2 + + return interp_img, extra_dict + +#**************************************************************************************************# +# => Unified model +#**************************************************************************************************# +@register('upr_net_freq2') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, fftshift=False, *args, **kwargs): + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR + freq2(008)@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + super(Model, self).__init__() + self.pyr_level = pyr_level + self.nr_lvl_skipped = nr_lvl_skipped + self.feat_pyramid = FeatPyramid() + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork([1,1,1], fftshift) + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_step=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear", align_corners=False) * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_0t = ori_resolution_flow[:, :2] * time_step + flow_1t = ori_resolution_flow[:, 2:4] * (1 - time_step) + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=img0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=img1, tenFlow=flow_1t, + tenMetric=None, strType='average') + last_interp = warped_img0 * (1 - time_step) \ + + warped_img1 * time_step + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_step=time_step) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, + pyr_level=None, nr_lvl_skipped=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + bi_flows = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else\ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level+2)), W //(2 ** (level+2))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level+2)), W // (2 ** (level+2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H // 4, W // 4)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + bi_flows.append( + padder.unpad(F.interpolate(input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False))) + interp_imgs.append(padder.unpad(interp_img)) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + bi_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + result_dict = { + "imgt_preds": interp_imgs, 'imgt_pred': interp_imgs[-1].contiguous(),"bi_flows": bi_flows, + "flowfwd": bi_flows[-1][:,:2], "flowbwd": bi_flows[-1][:,2:] + } + return result_dict, extra_dict + + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_freq2/upr_freq_006.py b/modules/components/upr_net_freq2/upr_freq_006.py new file mode 100644 index 0000000000000000000000000000000000000000..32a7be62921845ac49e098b93660cd5c0aad4eef --- /dev/null +++ b/modules/components/upr_net_freq2/upr_freq_006.py @@ -0,0 +1,420 @@ +# upr_freq_006.py + +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn + +from ..components import register + +import modules.components.upr_net_freq2.softsplat as softsplat +import modules.components.upr_net_freq2.correlation as correlation +from utils.padder import InputPadder + +from modules.components.upr_net_freq2.frequency_enhance import FrequencyEnhancementTransformer, FrequencyEnhancementDecoder + + +#**************************************************************************************************# +# => Feature Pyramid +#**************************************************************************************************# +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + + + +#**************************************************************************************************# +# => Motion Estimation +#**************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + def __init__(self): + super(MotionEstimator, self).__init__() + # (4*2 + 1) ** 2 + 128 * 2 + 128 + 4 = 469 + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=469, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer6 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1)) + + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn=correlation.FunctionCorrelation + feat0 = softsplat.FunctionSoftsplat( + tenInput=feat0, tenFlow=last_flow[:, :2]*0.25*0.5, + tenMetric=None, strType='average') + feat1 = softsplat.FunctionSoftsplat( + tenInput=feat1, tenFlow=last_flow[:, 2:]*0.25*0.5, + tenMetric=None, strType='average') + + volume = F.leaky_relu( + input=corr_fn(tenFirst=feat0, tenSecond=feat1), + negative_slope=0.1, inplace=False) + input_feat = torch.cat([volume, feat0, feat1, last_feat, last_flow], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow = self.conv_layer6(feat) + + return flow, feat + + + + +#**************************************************************************************************# +# => Frame Synthesis +#**************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, enc_depths=[1,1,1], fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder0 = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=input_channels, + kernel_size=3, stride=1, padding=1, groups=input_channels), + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=1, stride=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = nn.ModuleList() + for d in range(enc_depths[0]): + self.freq_enhance0.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=2., fftshift=fftshift)) + self.encoder1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=64, + kernel_size=2, stride=2, padding=0), + nn.PReLU(num_parameters=64)) + self.freq_enhance1 = nn.ModuleList() + for d in range(enc_depths[1]): + self.freq_enhance1.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=64, feat_dim=64, num_head=4, hidden_ratio=2., fftshift=fftshift)) + self.encoder2 = nn.Sequential( + nn.Conv2d(in_channels=64 + 64 + 64, out_channels=64, + kernel_size=2, stride=2, padding=0), + nn.PReLU(num_parameters=64)) + self.freq_enhance2 = nn.ModuleList() + for d in range(enc_depths[2]): + self.freq_enhance2.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=128, feat_dim=64, num_head=4, hidden_ratio=2., fftshift=fftshift)) + + # s0 + s1` + s2` + warp_c00 + warp_c10 + warp_c10` + warp_c11` + warp_c02` + warp_c12` + flow + # 64 + 16 + 4 + 32 + 32 + 16 + 16 + 8 + 8 + 4 = 200 + self.freq_decoder = FrequencyEnhancementDecoder(concat_dim=200, dim=64, fftshift=fftshift) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + + + def get_warped_representations(self, bi_flow, c0, c1, + i0=None, i1=None, time_step=0.5): + flow_0t = bi_flow[:, :2] * time_step + flow_1t = bi_flow[:, 2:4] * (1 - time_step) + warped_c0 = softsplat.FunctionSoftsplat( + tenInput=c0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_c1 = softsplat.FunctionSoftsplat( + tenInput=c1, tenFlow=flow_1t, + tenMetric=None, strType='average') + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=i0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=i1, tenFlow=flow_1t, + tenMetric=None, strType='average') + flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t + + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, + time_step=0.5): + warped_img0, warped_img1, warped_c00, warped_c10, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1, + time_step=time_step) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder0(input_feat) # [B, 64,h,w] + for block in self.freq_enhance0: + s0 = block(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + + s1 = self.encoder1(torch.cat((s0, warped_c00, warped_c10), 1)) # [B, 128,h/2,w/2] + warped_c01, warped_c11 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], + time_step=time_step) + for block in self.freq_enhance1: + s1 = block(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + + s2 = self.encoder2(torch.cat((s1, warped_c01, warped_c11), 1)) # [B, 256,h/4,w/4] + warped_c02, warped_c12 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], + time_step=time_step) + for block in self.freq_enhance2: + s2 = block(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + + x = self.freq_decoder(enc_feats=[s0,s1,s2], + warped_feats=[warped_c00,warped_c10, warped_c01,warped_c11, warped_c02,warped_c12], + flow=bi_flow_pyr[0]) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + + return interp_img, extra_dict + + + +#**************************************************************************************************# +# => Unified model +#**************************************************************************************************# +@register('upr_net_freq2') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, fftshift=False, *args, **kwargs): + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR + freq2(006)@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + super(Model, self).__init__() + self.pyr_level = pyr_level + self.nr_lvl_skipped = nr_lvl_skipped + self.feat_pyramid = FeatPyramid() + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork([1,1,1], fftshift) + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_step=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear", align_corners=False) * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_0t = ori_resolution_flow[:, :2] * time_step + flow_1t = ori_resolution_flow[:, 2:4] * (1 - time_step) + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=img0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=img1, tenFlow=flow_1t, + tenMetric=None, strType='average') + last_interp = warped_img0 * (1 - time_step) \ + + warped_img1 * time_step + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_step=time_step) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, + pyr_level=None, nr_lvl_skipped=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + bi_flows = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else\ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level+2)), W //(2 ** (level+2))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level+2)), W // (2 ** (level+2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H // 4, W // 4)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + + flow, feat, interp_img, _ = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + bi_flows.append( + padder.unpad(F.interpolate(input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False))) + interp_imgs.append(padder.unpad(interp_img)) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + bi_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + result_dict = { + "imgt_preds": interp_imgs, 'imgt_pred': interp_imgs[-1].contiguous(),"bi_flows": bi_flows, + "flowfwd": bi_flows[-1][:,:2], "flowbwd": bi_flows[-1][:,2:] + } + return result_dict + + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_freq2/upr_freq_007.py b/modules/components/upr_net_freq2/upr_freq_007.py new file mode 100644 index 0000000000000000000000000000000000000000..20c211185f5c597b11f57bcd86327e1cfd3a30c1 --- /dev/null +++ b/modules/components/upr_net_freq2/upr_freq_007.py @@ -0,0 +1,420 @@ +# upr_freq_007.py + +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn + +from ..components import register + +import modules.components.upr_net_freq2.softsplat as softsplat +import modules.components.upr_net_freq2.correlation as correlation +from utils.padder import InputPadder + +from modules.components.upr_net_freq2.frequency_enhance import FrequencyEnhancementTransformer, FrequencyEnhancementDecoder + + +#**************************************************************************************************# +# => Feature Pyramid +#**************************************************************************************************# +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + + + +#**************************************************************************************************# +# => Motion Estimation +#**************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + def __init__(self): + super(MotionEstimator, self).__init__() + # (4*2 + 1) ** 2 + 128 * 2 + 128 + 4 = 469 + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=469, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer6 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1)) + + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn=correlation.FunctionCorrelation + feat0 = softsplat.FunctionSoftsplat( + tenInput=feat0, tenFlow=last_flow[:, :2]*0.25*0.5, + tenMetric=None, strType='average') + feat1 = softsplat.FunctionSoftsplat( + tenInput=feat1, tenFlow=last_flow[:, 2:]*0.25*0.5, + tenMetric=None, strType='average') + + volume = F.leaky_relu( + input=corr_fn(tenFirst=feat0, tenSecond=feat1), + negative_slope=0.1, inplace=False) + input_feat = torch.cat([volume, feat0, feat1, last_feat, last_flow], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow = self.conv_layer6(feat) + + return flow, feat + + + + +#**************************************************************************************************# +# => Frame Synthesis +#**************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, enc_depths=[1,1,1], fftshift=False): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder0 = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=input_channels, + kernel_size=3, stride=1, padding=1, groups=input_channels), + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=1, stride=1), + nn.PReLU(num_parameters=64)) + self.freq_enhance0 = nn.ModuleList() + for d in range(enc_depths[0]): + self.freq_enhance0.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=32, feat_dim=64, num_head=4, hidden_ratio=2., fftshift=fftshift)) + self.encoder1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=64, + kernel_size=2, stride=2, padding=0), + nn.PReLU(num_parameters=64)) + self.freq_enhance1 = nn.ModuleList() + for d in range(enc_depths[1]): + self.freq_enhance1.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=64, feat_dim=64, num_head=4, hidden_ratio=2., fftshift=fftshift)) + self.encoder2 = nn.Sequential( + nn.Conv2d(in_channels=64 + 64 + 64, out_channels=64, + kernel_size=2, stride=2, padding=0), + nn.PReLU(num_parameters=64)) + self.freq_enhance2 = nn.ModuleList() + for d in range(enc_depths[2]): + self.freq_enhance2.add_module(f'block{d}', + FrequencyEnhancementTransformer( + c_dim=128, feat_dim=64, num_head=4, hidden_ratio=2., fftshift=fftshift)) + + # s0 + s1` + s2` + warp_c00 + warp_c10 + warp_c10` + warp_c11` + warp_c02` + warp_c12` + flow + # 64 + 16 + 4 + 32 + 32 + 16 + 16 + 8 + 8 + 4 = 200 + self.freq_decoder = FrequencyEnhancementDecoder(concat_dim=200, dim=64, fftshift=fftshift) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + + + def get_warped_representations(self, bi_flow, c0, c1, + i0=None, i1=None, time_step=0.5): + flow_0t = bi_flow[:, :2] * time_step + flow_1t = bi_flow[:, 2:4] * (1 - time_step) + warped_c0 = softsplat.FunctionSoftsplat( + tenInput=c0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_c1 = softsplat.FunctionSoftsplat( + tenInput=c1, tenFlow=flow_1t, + tenMetric=None, strType='average') + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=i0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=i1, tenFlow=flow_1t, + tenMetric=None, strType='average') + flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t + + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, + time_step=0.5): + warped_img0, warped_img1, warped_c00, warped_c10, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], i0, i1, + time_step=time_step) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder0(input_feat) # [B, 64,h,w] + for block in self.freq_enhance0: + s0 = block(c0_pyr[0], c1_pyr[0], s0, bi_flow_pyr[0]) + + s1 = self.encoder1(torch.cat((s0, warped_c00, warped_c10), 1)) # [B, 128,h/2,w/2] + warped_c01, warped_c11 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], + time_step=time_step) + for block in self.freq_enhance1: + s1 = block(c0_pyr[1], c1_pyr[1], s1, bi_flow_pyr[1]) + + s2 = self.encoder2(torch.cat((s1, warped_c01, warped_c11), 1)) # [B, 256,h/4,w/4] + warped_c02, warped_c12 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], + time_step=time_step) + for block in self.freq_enhance2: + s2 = block(c0_pyr[2], c1_pyr[2], s2, bi_flow_pyr[2]) + + x = self.freq_decoder(enc_feats=[s0,s1,s2], + warped_feats=[warped_c00,warped_c10, warped_c01,warped_c11, warped_c02,warped_c12], + flow=bi_flow_pyr[0]) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + + return interp_img, extra_dict + + + +#**************************************************************************************************# +# => Unified model +#**************************************************************************************************# +@register('upr_net_freq2') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, fftshift=False, *args, **kwargs): + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR + freq2(007)@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + super(Model, self).__init__() + self.pyr_level = pyr_level + self.nr_lvl_skipped = nr_lvl_skipped + self.feat_pyramid = FeatPyramid() + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork([1,1,1], fftshift) + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_step=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear", align_corners=False) * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_0t = ori_resolution_flow[:, :2] * time_step + flow_1t = ori_resolution_flow[:, 2:4] * (1 - time_step) + warped_img0 = softsplat.FunctionSoftsplat( + tenInput=img0, tenFlow=flow_0t, + tenMetric=None, strType='average') + warped_img1 = softsplat.FunctionSoftsplat( + tenInput=img1, tenFlow=flow_1t, + tenMetric=None, strType='average') + last_interp = warped_img0 * (1 - time_step) \ + + warped_img1 * time_step + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_step=time_step) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, + pyr_level=None, nr_lvl_skipped=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + bi_flows = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else\ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level+2)), W //(2 ** (level+2))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level+2)), W // (2 ** (level+2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H // 4, W // 4)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + + flow, feat, interp_img, _ = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + bi_flows.append( + padder.unpad(F.interpolate(input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False))) + interp_imgs.append(padder.unpad(interp_img)) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + bi_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + result_dict = { + "imgt_preds": interp_imgs, 'imgt_pred': interp_imgs[-1].contiguous(),"bi_flows": bi_flows, + "flowfwd": bi_flows[-1][:,:2], "flowbwd": bi_flows[-1][:,2:] + } + return result_dict + + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_gan/__pycache__/__init__.cpython-310.pyc b/modules/components/upr_net_gan/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..043b343941d3ab3e4c4882ccbc1ea62f1289cd2a Binary files /dev/null and b/modules/components/upr_net_gan/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/upr_net_gan/__pycache__/costvol.cpython-310.pyc b/modules/components/upr_net_gan/__pycache__/costvol.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cf49bd397896cb5eb06273cb19929ce703b6efe Binary files /dev/null and b/modules/components/upr_net_gan/__pycache__/costvol.cpython-310.pyc differ diff --git a/modules/components/upr_net_gan/__pycache__/discriminator.cpython-310.pyc b/modules/components/upr_net_gan/__pycache__/discriminator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcaede1900abcc4eca1ebd6617790b4cb3b323e9 Binary files /dev/null and b/modules/components/upr_net_gan/__pycache__/discriminator.cpython-310.pyc differ diff --git a/modules/components/upr_net_gan/__pycache__/upr.cpython-310.pyc b/modules/components/upr_net_gan/__pycache__/upr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcc40a1492e9209f1f7b41805cd1377d1e16c4bc Binary files /dev/null and b/modules/components/upr_net_gan/__pycache__/upr.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod/__init__.py b/modules/components/upr_net_mod/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ebb5b7f654e4c78b9fc6b4fd0f1f5c80dea4981 --- /dev/null +++ b/modules/components/upr_net_mod/__init__.py @@ -0,0 +1 @@ +from .upr import Model diff --git a/modules/components/upr_net_mod/__pycache__/__init__.cpython-310.pyc b/modules/components/upr_net_mod/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b17e18bc82c6088916cc27d405becab574ef818 Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod/__pycache__/__init__.cpython-38.pyc b/modules/components/upr_net_mod/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eef71ab074be94b1b7ac595c29ae83bcde2fc05 Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod/__pycache__/__init__.cpython-39.pyc b/modules/components/upr_net_mod/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55da2c2dd6f66e1cc9e84dc34e393b8b7df59d57 Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod/__pycache__/costvol.cpython-310.pyc b/modules/components/upr_net_mod/__pycache__/costvol.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4e8483be42706c61e5d08ead23274a1e9f45aae Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/costvol.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod/__pycache__/costvol.cpython-38.pyc b/modules/components/upr_net_mod/__pycache__/costvol.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6472ef2cbc77aa1ac5ba870494123ae959713a75 Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/costvol.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod/__pycache__/costvol.cpython-39.pyc b/modules/components/upr_net_mod/__pycache__/costvol.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53bbda3d93912c5d928b2d4d22eed1e87199493b Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/costvol.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod/__pycache__/upr.cpython-310.pyc b/modules/components/upr_net_mod/__pycache__/upr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3daace32e5315bfdaa813b5a1ba8e115de85f595 Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/upr.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod/__pycache__/upr.cpython-38.pyc b/modules/components/upr_net_mod/__pycache__/upr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..633646cb67ba651fcd06056126f41973af933636 Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/upr.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod/__pycache__/upr.cpython-39.pyc b/modules/components/upr_net_mod/__pycache__/upr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0af428dcb6168822989ddb82fa811650d4cc1b0c Binary files /dev/null and b/modules/components/upr_net_mod/__pycache__/upr.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod/backwarp.py b/modules/components/upr_net_mod/backwarp.py new file mode 100644 index 0000000000000000000000000000000000000000..e99a0a5c1b658e81536825451b865b39c45bc9c4 --- /dev/null +++ b/modules/components/upr_net_mod/backwarp.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import torch + + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn:torch.Tensor, tenFlow:torch.Tensor): + if 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3]) not in objBackwarpcache: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) +# end diff --git a/modules/components/upr_net_mod/correlation.py b/modules/components/upr_net_mod/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1c92e2ef7dd885f25b30a3b2e4ed25c6a3889e --- /dev/null +++ b/modules/components/upr_net_mod/correlation.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg]) + ')' for intArg in range(intArgs)] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel + + +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + rbot1 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert (first.is_contiguous() == True) + assert (second.is_contiguous() == True) + + output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([int((n + 16 - 1) / 16), first.shape[1], first.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, first.data_ptr(), rbot0.data_ptr()] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([int((n + 16 - 1) / 16), second.shape[1], second.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([output.shape[3], output.shape[2], output.shape[0]]), + block=tuple([32, 1, 1]), + shared_mem=first.shape[1] * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert (gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', + cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), + gradFirst.data_ptr(), None] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', + cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, + gradSecond.data_ptr()] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + + +# end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + + +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) +# end +# end \ No newline at end of file diff --git a/modules/components/upr_net_mod/costvol.py b/modules/components/upr_net_mod/costvol.py new file mode 100644 index 0000000000000000000000000000000000000000..40e1cfb5b95f948321fb4429321dbf3dd48f9288 --- /dev/null +++ b/modules/components/upr_net_mod/costvol.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +class costvol_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenOne, tenTwo): + tenOut = tenOne.new_empty([tenOne.shape[0], 81, tenOne.shape[2], tenOne.shape[3]]) + + cuda_launch(cuda_kernel('costvol_out', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_out( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue]); + } + } + + tenOut[intOffset] = fltValue / SIZE_1(tenOne); + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOut': tenOut + }))( + grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + self.save_for_backward(tenOne, tenTwo) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenOne, tenTwo = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None + tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenOnegrad is not None: + cuda_launch(cuda_kernel('costvol_onegrad', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_onegrad( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad); + const int intX = ( intIndex ) % SIZE_3(tenOnegrad); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenTwograd is not None: + cuda_launch(cuda_kernel('costvol_twograd', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_twograd( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd); + const int intX = ( intIndex ) % SIZE_3(tenTwograd); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], -tenOutgrad[intOffset] / SIZE_1(tenOne)); + } else { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], +tenOutgrad[intOffset] / SIZE_1(tenOne)); + } + } + } else { + // ... + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenOnegrad, tenTwograd, None, None + # end +# end diff --git a/modules/components/upr_net_mod/m2m.py b/modules/components/upr_net_mod/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..f536207982e94a86dc28b8599c557c84b5effb69 --- /dev/null +++ b/modules/components/upr_net_mod/m2m.py @@ -0,0 +1,407 @@ + +import math +import torch +import torch.nn as nn +import typing + +from ..components import register +from .backwarp import * +from .softsplat import _FunctionSoftsplat + + +########################################################## + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = _FunctionSoftsplat.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOutF, tenOutB = 0, 0 + tenNormalizeF, tenNormalizeB = 0, 0 + for idx in range(flow_num): + tenOutF_, tenNormalizeF_ = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB_, tenNormalizeB_ = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOutF += tenOutF_ + tenOutB += tenOutB_ + tenNormalizeF += tenNormalizeF_ + tenNormalizeB += tenNormalizeB_ + + return tenOutF / tenNormalizeF, tenNormalizeF < 0.00001, tenOutB / tenNormalizeB, tenNormalizeB < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + torch.nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + torch.nn.PReLU(out_planes) + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + +class ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, self.branch, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, -1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0, m1 + + +@register('m2m_pwc') +class M2M_PWC(torch.nn.Module): + def __init__(self, ratio=4): + super(M2M_PWC, self).__init__() + self.branch = 4 + self.ratio = ratio + + self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear', + align_corners=False) + flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear', + align_corners=False) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs): + if ratio is None: + ratio = self.ratio + + intWidth = img0.shape[3] and img1.shape[3] + intHeight = img0.shape[2] and img1.shape[2] + + intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16) + intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16) + + img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate') + img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate') + + N_, C_, H_, W_ = img0.shape + + outputs = [] + result_dict = {} + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + + tenFwd, tenBwd = self.netFlow.bidir(im0_, im1_) + + result_dict['flowfwd'] = torch.nn.functional.interpolate(tenFwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + result_dict['flowbwd'] = torch.nn.functional.interpolate(tenBwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, img0, img1, ratio) + + img0 = im0_o.repeat(1, self.branch, 1, 1) + img1 = im1_o.repeat(1, self.branch, 1, 1) + tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = time_step.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = (1.0 - (WeiMF * (img0 - backwarp(img1, tenFwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + tenPhototwo = (1.0 - (WeiMB * (img1 - backwarp(img0, tenBwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = self.paramAlpha * tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = self.paramAlpha * tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutput, mask = forwarp_mframe_mask(img0, flow0, t1, img1, flow1, t0, metric0, metric1) + + tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o) + + output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_ + result_dict['imgt_pred'] = output[:, :, :intHeight, :intWidth] + + return result_dict + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out \ No newline at end of file diff --git a/modules/components/upr_net_mod/softsplat.py b/modules/components/upr_net_mod/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..77967f24cd1eeee56417d1de2c88369d13b883c6 --- /dev/null +++ b/modules/components/upr_net_mod/softsplat.py @@ -0,0 +1,360 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Softsplat_updateOutput = ''' + extern "C" __global__ void kernel_Softsplat_updateOutput( + const int n, + const float* input, + const float* flow, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); + } + } } +''' + +kernel_Softsplat_updateGradInput = ''' + extern "C" __global__ void kernel_Softsplat_updateGradInput( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); + const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); + const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); + const int intX = ( intIndex ) % SIZE_3(gradInput); + + float fltGradInput = 0.0; + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + gradInput[intIndex] = fltGradInput; + } } +''' + +kernel_Softsplat_updateGradFlow = ''' + extern "C" __global__ void kernel_Softsplat_updateGradFlow( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float fltGradFlow = 0.0; + + const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); + const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); + const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); + const int intX = ( intIndex ) % SIZE_3(gradFlow); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = 0.0; + float fltNortheast = 0.0; + float fltSouthwest = 0.0; + float fltSoutheast = 0.0; + + if (intC == 0) { + fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY ); + fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY ); + fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); + fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (-1.0)); + fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); + fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * ((float) (+1.0)); + fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { + float fltInput = VALUE_4(input, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; + } + } + + gradFlow[intIndex] = fltGradFlow; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + + return strKernel + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +class _FunctionSoftsplat(torch.autograd.Function): + @staticmethod + def forward(self, input, flow): + self.save_for_backward(input, flow) + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(input.is_contiguous() == True) + assert(flow.is_contiguous() == True) + + output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) + + if input.is_cuda == True: + n = output.nelement() + cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { + 'input': input, + 'flow': flow, + 'output': output + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + return output + + + @staticmethod + def backward(self, gradOutput): + input, flow = self.saved_tensors + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(gradOutput.is_contiguous() == True) + + gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])\ + if self.needs_input_grad[0] == True else None + gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ])\ + if self.needs_input_grad[1] == True else None + + if input.is_cuda == True: + if gradInput is not None: + n = gradInput.nelement() + cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] + ) + + if gradFlow is not None: + n = gradFlow.nelement() + cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + + return gradInput, gradFlow + + +def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): + assert(tenMetric is None or tenMetric.shape[1] == 1) + assert(strType in ['summation', 'average', 'linear', 'softmax']) + + if strType == 'average': + tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) + + elif strType == 'linear': + tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) + + elif strType == 'softmax': + tenInput = torch.cat([ tenInput * tenMetric.clip(-20, 20).exp(), tenMetric.clip(-20, 20).exp() ], 1) + + + tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) + + if strType != 'summation': + tenNormalize = tenOutput[:, -1:, :, :] + + tenNormalize[tenNormalize == 0.0] = 1.0 + + tenOutput = tenOutput[:, :-1, :, :] / tenNormalize + + return tenOutput + + +class ModuleSoftsplat(torch.nn.Module): + def __init__(self, strType): + super(ModuleSoftsplat, self).__init__() + + self.strType = strType + + def forward(self, tenInput, tenFlow, tenMetric): + return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) + diff --git a/modules/components/upr_net_mod/upr.py b/modules/components/upr_net_mod/upr.py new file mode 100644 index 0000000000000000000000000000000000000000..6882b31dac42b35c67d0a46397798ccbf07f96c4 --- /dev/null +++ b/modules/components/upr_net_mod/upr.py @@ -0,0 +1,509 @@ +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn + +import modules.components.upr_net.correlation as correlation +import modules.components.upr_net.softsplat as softsplat +from modules.components.upr_net.m2m import * +from modules.components.upr_net.backwarp import backwarp +from .costvol import costvol_func +from ..components import register + +from utils.padder import InputPadder + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 = 704 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=704, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average'): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_mod') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average'): + super(Model, self).__init__() + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode) + self.splat_mode = splat_mode + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) * 4 + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear", align_corners=False) * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, + pyr_level=None, nr_lvl_skipped=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H // 4, W // 4)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + flow0_pred.append( + padder.unpad(F.interpolate(input=flow[:, :2], scale_factor=4.0, + mode="bilinear", align_corners=False)) * 4) + flow1_pred.append( + padder.unpad(F.interpolate(input=flow[:, 2:], scale_factor=4.0, + mode="bilinear", align_corners=False)) * 4) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + + return {"imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1], "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, + } + + +if __name__ == "__main__": + pass diff --git a/modules/components/upr_net_mod2/__init__.py b/modules/components/upr_net_mod2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a025cb0254df7c206a75e55134d826fda008a3e4 --- /dev/null +++ b/modules/components/upr_net_mod2/__init__.py @@ -0,0 +1,3 @@ +from .upr_exp43 import Model +# from .upr_exp45 import Model +from .upr import Model \ No newline at end of file diff --git a/modules/components/upr_net_mod2/__pycache__/__init__.cpython-310.pyc b/modules/components/upr_net_mod2/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e337582bac944b03a080d68c0d55311616c3d15c Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/__init__.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..569cf8e414312d65ed6688c66d59d9787a92b036 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/__init__.cpython-39.pyc b/modules/components/upr_net_mod2/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac552b7099455a4ad6c5f77b72a6ebfec660afc6 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-310.pyc b/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f356de612feed84d01de62147036b1446869f06d Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44e5feabf4caf9b5511c23b32a4208bcc851eef2 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-39.pyc b/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8389ca05f9c243661e794340b10b36a7432f86d7 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/backwarp.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/correlation.cpython-310.pyc b/modules/components/upr_net_mod2/__pycache__/correlation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..071de2ef2d86de901c86d8e3d3416a037be2868f Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/correlation.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/correlation.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/correlation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cabd5a019edd27f3e3e1bf573368b83f811b74c Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/correlation.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/correlation.cpython-39.pyc b/modules/components/upr_net_mod2/__pycache__/correlation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3df73f3af5dca568d9a1ce997fcc7fa11caa89e5 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/correlation.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/costvol.cpython-310.pyc b/modules/components/upr_net_mod2/__pycache__/costvol.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa2e7fd19c9577b10cecf718d2bcb2b7be8ece69 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/costvol.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/costvol.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/costvol.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3de82133eacd2582cfb0b9ba9b3e33ba3f52ae29 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/costvol.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/costvol.cpython-39.pyc b/modules/components/upr_net_mod2/__pycache__/costvol.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cf215ac2777a8fba1dc964acb20bf2cb65a22df Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/costvol.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/m2m.cpython-310.pyc b/modules/components/upr_net_mod2/__pycache__/m2m.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..760965a177a13ba9adf32c4694b33f27ebaa8ab9 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/m2m.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/m2m.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/m2m.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b91dccbad4572f84350e05e535c98e3bbab0de6 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/m2m.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/m2m.cpython-39.pyc b/modules/components/upr_net_mod2/__pycache__/m2m.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c5dcb26b297757b90923d15fc8d6e7ecc5fdf65 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/m2m.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-310.pyc b/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02bbc369ff098cdc0ddd860e52f5d63f39ae91ea Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c79a248f2fb5fd007c364a63a5bd5401a3cc8949 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-39.pyc b/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f6230c5b96e7124623aedf1fb785d1bc78de6ca Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/upr.cpython-310.pyc b/modules/components/upr_net_mod2/__pycache__/upr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c73f6d0d740fbf7acb4c21c427a620e6435e020 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/upr.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/upr.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/upr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f9e2e912664d3b9c64fe1b0fc1b64387ce15c59 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/upr.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/upr.cpython-39.pyc b/modules/components/upr_net_mod2/__pycache__/upr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54248d3fd7da2e50e51070dd32b9f886365c2afa Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/upr.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-310.pyc b/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9a1d5321d7f3d32fff590d22a7c225340eb500f Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-310.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d42e067ffcade1d86173d017146f0b12f756418b Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-39.pyc b/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68c0e344557c15b7b8362fd1133329ad6c7169aa Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/upr_exp43.cpython-39.pyc differ diff --git a/modules/components/upr_net_mod2/__pycache__/upr_exp45.cpython-38.pyc b/modules/components/upr_net_mod2/__pycache__/upr_exp45.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff55665eca5948a3508406b9995e549064eb8d58 Binary files /dev/null and b/modules/components/upr_net_mod2/__pycache__/upr_exp45.cpython-38.pyc differ diff --git a/modules/components/upr_net_mod2/backwarp.py b/modules/components/upr_net_mod2/backwarp.py new file mode 100644 index 0000000000000000000000000000000000000000..729a1db8e0117bd49526929e5953cf6e70fd204a --- /dev/null +++ b/modules/components/upr_net_mod2/backwarp.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import torch + + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn:torch.Tensor, tenFlow:torch.Tensor): + if 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3]) not in objBackwarpcache: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) +# end \ No newline at end of file diff --git a/modules/components/upr_net_mod2/correlation.py b/modules/components/upr_net_mod2/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1c92e2ef7dd885f25b30a3b2e4ed25c6a3889e --- /dev/null +++ b/modules/components/upr_net_mod2/correlation.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg]) + ')' for intArg in range(intArgs)] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel + + +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + rbot1 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert (first.is_contiguous() == True) + assert (second.is_contiguous() == True) + + output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([int((n + 16 - 1) / 16), first.shape[1], first.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, first.data_ptr(), rbot0.data_ptr()] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([int((n + 16 - 1) / 16), second.shape[1], second.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([output.shape[3], output.shape[2], output.shape[0]]), + block=tuple([32, 1, 1]), + shared_mem=first.shape[1] * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert (gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', + cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), + gradFirst.data_ptr(), None] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', + cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, + gradSecond.data_ptr()] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + + +# end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + + +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) +# end +# end \ No newline at end of file diff --git a/modules/components/upr_net_mod2/costvol.py b/modules/components/upr_net_mod2/costvol.py new file mode 100644 index 0000000000000000000000000000000000000000..6c93e4db22d00bf73c8b1fc06a297a85a16ee352 --- /dev/null +++ b/modules/components/upr_net_mod2/costvol.py @@ -0,0 +1,424 @@ +#!/usr/bin/env python + +import collections +import cupy +import os +import re +import torch +import typing + + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn:int): + return cupy.int32(intIn) +# end + + +def cuda_float32(fltIn:float): + return cupy.float32(fltIn) +# end + + +def cuda_kernel(strFunction:str, strKernel:str, objVariables:typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert(False) + + elif True: + print(strVariable, type(objValue)) + assert(False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert(intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey:str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function(objCudacache[strKey]['strFunction']) +# end + + +########################################################## + + +class costvol_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenOne, tenTwo): + tenOut = tenOne.new_empty([tenOne.shape[0], 81, tenOne.shape[2], tenOne.shape[3]]) + + cuda_launch(cuda_kernel('costvol_out', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_out( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_0(tenOut); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOut, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + {{type}} fltValue = 0.0f; + + if ((intOy >= 0) && (intOy < SIZE_2(tenOut)) && (intOx >= 0) && (intOx < SIZE_3(tenOut))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx)); + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltValue += abs(fltOne[intValue]); + } + } + + tenOut[intOffset] = fltValue / SIZE_1(tenOne); + intOffset += SIZE_2(tenOut) * SIZE_3(tenOut); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOut': tenOut + }))( + grid=tuple([int(((tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.shape[0] * tenOut.shape[2] * tenOut.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + self.save_for_backward(tenOne, tenTwo) + + return tenOut + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenOne, tenTwo = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); assert(tenOutgrad.is_cuda == True) + + tenOnegrad = tenOne.new_zeros([tenOne.shape[0], tenOne.shape[1], tenOne.shape[2], tenOne.shape[3]]) if self.needs_input_grad[0] == True else None + tenTwograd = tenTwo.new_zeros([tenTwo.shape[0], tenTwo.shape[1], tenTwo.shape[2], tenTwo.shape[3]]) if self.needs_input_grad[1] == True else None + + if tenOnegrad is not None: + cuda_launch(cuda_kernel('costvol_onegrad', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_onegrad( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOnegrad) / SIZE_2(tenOnegrad) ) % SIZE_0(tenOnegrad); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenOnegrad) ) % SIZE_2(tenOnegrad); + const int intX = ( intIndex ) % SIZE_3(tenOnegrad); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } else { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] >= 0.0f) { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += +tenOutgrad[intOffset] / SIZE_1(tenOne); + } else { + tenOnegrad[OFFSET_4(tenOnegrad, intN, intValue, intY, intX)] += -tenOutgrad[intOffset] / SIZE_1(tenOne); + } + } + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOnegrad.shape[0] * tenOnegrad.shape[2] * tenOnegrad.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenTwograd is not None: + cuda_launch(cuda_kernel('costvol_twograd', ''' + extern "C" __global__ void __launch_bounds__(512) costvol_twograd( + const int n, + const {{type}}* __restrict__ tenOne, + const {{type}}* __restrict__ tenTwo, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenOnegrad, + {{type}}* __restrict__ tenTwograd + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenTwograd) / SIZE_2(tenTwograd) ) % SIZE_0(tenTwograd); + const int intC = -1; + const int intY = ( intIndex / SIZE_3(tenTwograd) ) % SIZE_2(tenTwograd); + const int intX = ( intIndex ) % SIZE_3(tenTwograd); + + {{type}} fltOne[{{intChans}}]; + + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + fltOne[intValue] = VALUE_4(tenOne, intN, intValue, intY, intX); + } + + int intOffset = OFFSET_4(tenOutgrad, intN, 0, intY, intX); + + for (int intOy = intY - 4; intOy <= intY + 4; intOy += 1) { + for (int intOx = intX - 4; intOx <= intX + 4; intOx += 1) { + if ((intOy >= 0) && (intOy < SIZE_2(tenOutgrad)) && (intOx >= 0) && (intOx < SIZE_3(tenOutgrad))) { + for (int intValue = 0; intValue < SIZE_1(tenOne); intValue += 1) { + if (fltOne[intValue] - VALUE_4(tenTwo, intN, intValue, intOy, intOx) >= 0.0f) { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], -tenOutgrad[intOffset] / SIZE_1(tenOne)); + } else { + atomicAdd(&tenTwograd[OFFSET_4(tenTwograd, intN, intValue, intOy, intOx)], +tenOutgrad[intOffset] / SIZE_1(tenOne)); + } + } + } else { + // ... + } + + intOffset += SIZE_2(tenOutgrad) * SIZE_3(tenOutgrad); + } + } + } } + ''', { + 'intChans': tenOne.shape[1], + 'tenOne': tenOne, + 'tenTwo': tenTwo, + 'tenOutgrad': tenOutgrad, + 'tenOnegrad': tenOnegrad, + 'tenTwograd': tenTwograd + }))( + grid=tuple([int(((tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]) + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenTwograd.shape[0] * tenTwograd.shape[2] * tenTwograd.shape[3]), tenOne.data_ptr(), tenTwo.data_ptr(), tenOutgrad.data_ptr(), tenOnegrad.data_ptr(), tenTwograd.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenOnegrad, tenTwograd, None, None + # end +# end \ No newline at end of file diff --git a/modules/components/upr_net_mod2/m2m.py b/modules/components/upr_net_mod2/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..f536207982e94a86dc28b8599c557c84b5effb69 --- /dev/null +++ b/modules/components/upr_net_mod2/m2m.py @@ -0,0 +1,407 @@ + +import math +import torch +import torch.nn as nn +import typing + +from ..components import register +from .backwarp import * +from .softsplat import _FunctionSoftsplat + + +########################################################## + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = _FunctionSoftsplat.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOutF, tenOutB = 0, 0 + tenNormalizeF, tenNormalizeB = 0, 0 + for idx in range(flow_num): + tenOutF_, tenNormalizeF_ = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB_, tenNormalizeB_ = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOutF += tenOutF_ + tenOutB += tenOutB_ + tenNormalizeF += tenNormalizeF_ + tenNormalizeB += tenNormalizeB_ + + return tenOutF / tenNormalizeF, tenNormalizeF < 0.00001, tenOutB / tenNormalizeB, tenNormalizeB < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + torch.nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + torch.nn.PReLU(out_planes) + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + +class ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, self.branch, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, -1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0, m1 + + +@register('m2m_pwc') +class M2M_PWC(torch.nn.Module): + def __init__(self, ratio=4): + super(M2M_PWC, self).__init__() + self.branch = 4 + self.ratio = ratio + + self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear', + align_corners=False) + flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear', + align_corners=False) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs): + if ratio is None: + ratio = self.ratio + + intWidth = img0.shape[3] and img1.shape[3] + intHeight = img0.shape[2] and img1.shape[2] + + intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16) + intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16) + + img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate') + img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate') + + N_, C_, H_, W_ = img0.shape + + outputs = [] + result_dict = {} + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + + tenFwd, tenBwd = self.netFlow.bidir(im0_, im1_) + + result_dict['flowfwd'] = torch.nn.functional.interpolate(tenFwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + result_dict['flowbwd'] = torch.nn.functional.interpolate(tenBwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, img0, img1, ratio) + + img0 = im0_o.repeat(1, self.branch, 1, 1) + img1 = im1_o.repeat(1, self.branch, 1, 1) + tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = time_step.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = (1.0 - (WeiMF * (img0 - backwarp(img1, tenFwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + tenPhototwo = (1.0 - (WeiMB * (img1 - backwarp(img0, tenBwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = self.paramAlpha * tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = self.paramAlpha * tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutput, mask = forwarp_mframe_mask(img0, flow0, t1, img1, flow1, t0, metric0, metric1) + + tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o) + + output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_ + result_dict['imgt_pred'] = output[:, :, :intHeight, :intWidth] + + return result_dict + +class ResBlock(nn.Module): + def __init__(self, in_channels, side_channels, bias=True): + super(ResBlock, self).__init__() + self.side_channels = side_channels + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv3 = nn.Sequential( + nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(in_channels) + ) + self.conv4 = nn.Sequential( + nn.Conv2d(side_channels, side_channels, kernel_size=3, stride=1, padding=1, bias=bias), + nn.PReLU(side_channels) + ) + self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias) + self.prelu = nn.PReLU(in_channels) + + def forward(self, x): + out = self.conv1(x) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv2(side_feat) + out = self.conv3(torch.cat([res_feat, side_feat], 1)) + + res_feat = out[:, :-self.side_channels, ...] + side_feat = out[:, -self.side_channels:, :, :] + side_feat = self.conv4(side_feat) + out = self.conv5(torch.cat([res_feat, side_feat], 1)) + + out = self.prelu(x + out) + return out \ No newline at end of file diff --git a/modules/components/upr_net_mod2/softsplat.py b/modules/components/upr_net_mod2/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..3c4b3fe227283b5ecb256b8ed2aa7b0846a4ccd2 --- /dev/null +++ b/modules/components/upr_net_mod2/softsplat.py @@ -0,0 +1,359 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Softsplat_updateOutput = ''' + extern "C" __global__ void kernel_Softsplat_updateOutput( + const int n, + const float* input, + const float* flow, + float* output + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output); + const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output); + const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output); + const int intX = ( intIndex ) % SIZE_3(output); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest); + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast); + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest); + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) { + atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast); + } + } } +''' + +kernel_Softsplat_updateGradInput = ''' + extern "C" __global__ void kernel_Softsplat_updateGradInput( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput); + const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput); + const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput); + const int intX = ( intIndex ) % SIZE_3(gradInput); + + float fltGradInput = 0.0; + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (intSoutheastY) - fltOutputY ); + float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY ); + float fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * (fltOutputY - (float) (intNortheastY)); + float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY)); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + gradInput[intIndex] = fltGradInput; + } } +''' + +kernel_Softsplat_updateGradFlow = ''' + extern "C" __global__ void kernel_Softsplat_updateGradFlow( + const int n, + const float* input, + const float* flow, + const float* gradOutput, + float* gradInput, + float* gradFlow + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + float fltGradFlow = 0.0; + + const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow); + const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow); + const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow); + const int intX = ( intIndex ) % SIZE_3(gradFlow); + + float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX); + float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX); + + int intNorthwestX = (int) (floor(fltOutputX)); + int intNorthwestY = (int) (floor(fltOutputY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + float fltNorthwest = 0.0; + float fltNortheast = 0.0; + float fltSouthwest = 0.0; + float fltSoutheast = 0.0; + + if (intC == 0) { + fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY ); + fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY ); + fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY)); + fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = ((float) (intSoutheastX) - fltOutputX ) * ((float) (-1.0)); + fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0)); + fltSouthwest = ((float) (intNortheastX) - fltOutputX ) * ((float) (+1.0)); + fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) { + float fltInput = VALUE_4(input, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) { + fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast; + } + } + + gradFlow[intIndex] = fltGradFlow; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), '(' + str.join('+', strIndex) + ')') + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')')\ + .strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + + return strKernel + + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +class _FunctionSoftsplat(torch.autograd.Function): + @staticmethod + def forward(self, input, flow): + self.save_for_backward(input, flow) + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(input.is_contiguous() == True) + assert(flow.is_contiguous() == True) + + output = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ]) + + if input.is_cuda == True: + n = output.nelement() + cupy_launch('kernel_Softsplat_updateOutput', cupy_kernel('kernel_Softsplat_updateOutput', { + 'input': input, + 'flow': flow, + 'output': output + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), output.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + return output + + + @staticmethod + def backward(self, gradOutput): + input, flow = self.saved_tensors + + intSamples = input.shape[0] + intInputDepth, intInputHeight, intInputWidth = input.shape[1], input.shape[2], input.shape[3] + intFlowDepth, intFlowHeight, intFlowWidth = flow.shape[1], flow.shape[2], flow.shape[3] + + assert(intFlowDepth == 2) + assert(intInputHeight == intFlowHeight) + assert(intInputWidth == intFlowWidth) + + assert(gradOutput.is_contiguous() == True) + + gradInput = input.new_zeros([ intSamples, intInputDepth, intInputHeight, intInputWidth ])\ + if self.needs_input_grad[0] == True else None + gradFlow = input.new_zeros([ intSamples, intFlowDepth, intFlowHeight, intFlowWidth ])\ + if self.needs_input_grad[1] == True else None + + if input.is_cuda == True: + if gradInput is not None: + n = gradInput.nelement() + cupy_launch('kernel_Softsplat_updateGradInput', cupy_kernel('kernel_Softsplat_updateGradInput', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None ] + ) + + if gradFlow is not None: + n = gradFlow.nelement() + cupy_launch('kernel_Softsplat_updateGradFlow', cupy_kernel('kernel_Softsplat_updateGradFlow', { + 'input': input, + 'flow': flow, + 'gradOutput': gradOutput, + 'gradInput': gradInput, + 'gradFlow': gradFlow + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ n, input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr() ] + ) + + elif input.is_cuda == False: + raise NotImplementedError() + + + return gradInput, gradFlow + + +def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType): + assert(tenMetric is None or tenMetric.shape[1] == 1) + assert(strType in ['summation', 'average', 'linear', 'softmax']) + + if strType == 'average': + tenInput = torch.cat([ tenInput, tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3]) ], 1) + + elif strType == 'linear': + tenInput = torch.cat([ tenInput * tenMetric, tenMetric ], 1) + + elif strType == 'softmax': + tenInput = torch.cat([ tenInput * tenMetric.clip(-20, 20).exp(), tenMetric.clip(-20, 20).exp() ], 1) + + + tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow) + + if strType != 'summation': + tenNormalize = tenOutput[:, -1:, :, :] + + tenNormalize[tenNormalize == 0.0] = 1.0 + + tenOutput = tenOutput[:, :-1, :, :] / tenNormalize + + return tenOutput + + +class ModuleSoftsplat(torch.nn.Module): + def __init__(self, strType): + super(ModuleSoftsplat, self).__init__() + + self.strType = strType + + def forward(self, tenInput, tenFlow, tenMetric): + return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) diff --git a/modules/components/upr_net_mod2/upr.py b/modules/components/upr_net_mod2/upr.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a0537bd33560e0833ec51e47b5b76ea68e28a4 --- /dev/null +++ b/modules/components/upr_net_mod2/upr.py @@ -0,0 +1,541 @@ +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn +import torchvision.transforms.v2.functional as TF + +import modules.components.upr_net_mod2.correlation as correlation +import modules.components.upr_net_mod2.softsplat as softsplat +from modules.components.upr_net_mod2.m2m import * +from modules.components.upr_net_mod2.backwarp import backwarp +from .costvol import costvol_func +from ..components import register + +from utils.padder import InputPadder +from utils.vos.model.network import STCN +from utils.vos.model.inference_core import InferenceCore + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 = 704 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=704, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + self.upsampler = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 16 * 9, 1, padding=0) + ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def upsample(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 4, 4, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(4 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 4, 4 * H, 4 * W) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + mask = self.upsampler(feat) * .25 + flow = self.upsample(flow, mask) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average'): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) # [B, 64,h,w] + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) # [B, 128,h/2,w/2] + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) # [B, 256,h/4,w/4] + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + extra_dict["c0_pyr"] = c0_pyr + extra_dict["c1_pyr"] = c1_pyr + extra_dict["syn_pyr"] = [s0,s1,s2] + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_mod2') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average'): + super(Model, self).__init__() + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back exp45@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode) + self.splat_mode = splat_mode + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + last_flow = F.interpolate( + input=last_flow, scale_factor=0.25, + mode="bilinear") * 0.25 + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = flow + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear") * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, seg0=None, segt=None, seg1=None, + pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level)), W // (2 ** (level))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H, W)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + flow0_pred.append( + padder.unpad(flow[:, :2])) + flow1_pred.append( + padder.unpad(flow[:, 2:])) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + c0_pyr = [padder.unpad(cc) for cc in extra_dict["c0_pyr"]] + c1_pyr = [padder.unpad(cc) for cc in extra_dict["c1_pyr"]] + syn_pyr = [padder.unpad(cc) for cc in extra_dict["syn_pyr"]] + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + result_dict = { + "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, 'c0_pyr': c0_pyr, 'c1_pyr': c1_pyr, 'syn_pyr': syn_pyr + } + + return result_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_mod2/upr_exp43.py b/modules/components/upr_net_mod2/upr_exp43.py new file mode 100644 index 0000000000000000000000000000000000000000..de4c224f142ceeb03776699afe98558af369e9e4 --- /dev/null +++ b/modules/components/upr_net_mod2/upr_exp43.py @@ -0,0 +1,535 @@ +import torch +import math +import numpy +import torch.nn.functional as F +import torch.nn as nn +import torchvision.transforms.v2.functional as TF + +import modules.components.upr_net_mod2.correlation as correlation +import modules.components.upr_net_mod2.softsplat as softsplat +from modules.components.upr_net_mod2.m2m import * +from modules.components.upr_net_mod2.backwarp import backwarp +from .costvol import costvol_func +from ..components import register + +from utils.padder import InputPadder +from utils.vos.model.network import STCN +from utils.vos.model.inference_core import InferenceCore + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +def gaussian(x): + gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 + gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) + gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1)) + self.conv_stage1 = nn.Sequential( + nn.InstanceNorm2d(num_features=32), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), ) + self.conv_stage2 = nn.Sequential( + nn.InstanceNorm2d(num_features=64), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.InstanceNorm2d(num_features=128), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # 64 + 256 + 128 * 2 + 128 = 704 + self.conv_flow = nn.Sequential( + nn.Conv2d(4, 128, 7, padding=3), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(128, 64, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + self.conv_corr = nn.Sequential( + nn.Conv2d(81, 64, 1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(64, 128, 3, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + ) + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=704, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1)) + self.conv_layer6 = nn.Sequential( + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1, bias=False)) + + self.upsampler = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 16 * 9, 1, padding=0) + ) + + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu') + # elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + # if m.weight is not None: + # nn.init.constant_(m.weight, 1) + # if m.bias is not None: + # nn.init.constant_(m.bias, 0) + + def upsample(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 4, 4, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(4 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 4, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 4, 4 * H, 4 * W) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0_warp = backwarp(feat0, last_flow[:, :2]) + feat1_warp = backwarp(feat1, last_flow[:, 2:]) + + volume0 = F.leaky_relu( + input=costvol_func.apply(feat0_warp, feat1_warp), + negative_slope=0.1, inplace=False) + volume1 = F.leaky_relu( + input=costvol_func.apply(feat1_warp, feat0_warp), + negative_slope=0.1, inplace=False) + corr0 = self.conv_corr(volume0) + corr1 = self.conv_corr(volume1) + flo = self.conv_flow(last_flow) + input_feat = torch.cat([corr0, corr1, feat0_warp, feat1_warp, last_feat, flo], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow_res = self.conv_layer6(feat) + flow = last_flow + flow_res + mask = self.upsampler(feat) * .25 + flow = self.upsample(flow, mask) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, splat_mode='average'): + super(SynthesisNetwork, self).__init__() + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=4, kernel_size=3, + stride=1, padding=1) + self.splat_mode = splat_mode + + if self.splat_mode == 'softmax': + # New params for splatting mask generation + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + if self.splat_mode == 'softmax': + M_splat = 1 / ( + 1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01).detach()) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10).detach()) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01).detach()) + return M_splat * self.alpha + else: + return None + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, i0=None, i1=None, time_period=0.5): + flow_t0 = bi_flow[:, :2] * time_period * 2 + flow_t1 = bi_flow[:, 2:4] * (1 - time_period) * 2 + warped_c0 = backwarp(c0, flow_t0) + warped_c1 = backwarp(c1, flow_t1) + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = backwarp(i0, flow_t0) + warped_img1 = backwarp(i1, flow_t1) + scaler = torch.Tensor([i0.shape[3], i0.shape[2]]).view(1, 2, 1, 1).cuda() + flow_t0_t1 = torch.cat((flow_t0 / scaler, flow_t1 / scaler), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_t0_t1 + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], None, None, + time_period=time_period) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], None, None, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask = torch.sigmoid(refine[:, 3:]) + merged_img = (warped_img0 * refine_mask + + warped_img1 * (1 - refine_mask)) + interp_img = merged_img + refine_res + # interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["refine_mask"] = refine_mask + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_mod2') +class Model(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, splat_mode='average'): + super(Model, self).__init__() + print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@UPR-back exp43@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@') + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(splat_mode) + self.splat_mode = splat_mode + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + last_flow = F.interpolate( + input=last_flow, scale_factor=0.25, + mode="nearest") * 0.25 + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = flow + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="nearest") * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_t0 = ori_resolution_flow[:, :2] * time_period * 2 + flow_t1 = ori_resolution_flow[:, 2:4] * (1 - time_period) * 2 + warped_img0 = backwarp(img0, flow_t0) + warped_img1 = backwarp(img1, flow_t1) + last_interp = warped_img0 * (1 - time_period) + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, seg0=None, segt=None, seg1=None, + pyr_level=None, nr_lvl_skipped=None, imgt=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + padder = InputPadder(img0.shape, divisor=int(4 * 2 ** pyr_level)) + img0, img1 = padder.pad(img0, img1) + N, _, H, W = img0.shape + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level)), W // (2 ** (level))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H, W)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me) + flow0_pred.append( + padder.unpad(flow[:, :2])) + flow1_pred.append( + padder.unpad(flow[:, 2:])) + interp_imgs.append(padder.unpad(F.interpolate(interp_img, scale_factor=2 ** level)) * tenStd_ + tenMean_) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + refine_res = padder.unpad(extra_dict["refine_res"]) + refine_mask = padder.unpad(extra_dict["refine_mask"]) + warped_img0 = padder.unpad(extra_dict["warped_img0"]) * tenStd_ + tenMean_ + warped_img1 = padder.unpad(extra_dict["warped_img1"]) * tenStd_ + tenMean_ + merged_img = padder.unpad(extra_dict["merged_img"]) * tenStd_ + tenMean_ + result_dict = { + "imgt_preds": interp_imgs, "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_imgs[-1].contiguous(), "flowfwd": flow0_pred[-1], "flowbwd": flow1_pred[-1], + 'refine_res': refine_res, 'refine_mask': refine_mask, 'warped_img0': warped_img0, + 'warped_img1': warped_img1, 'merged_img': merged_img, + } + + return result_dict + + +if __name__ == "__main__": + pass \ No newline at end of file diff --git a/modules/components/upr_net_multi_flow/__init__.py b/modules/components/upr_net_multi_flow/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..563193494ca84698d0212bcdf1e6131124b6dfa5 --- /dev/null +++ b/modules/components/upr_net_multi_flow/__init__.py @@ -0,0 +1 @@ +from .upr import * diff --git a/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-310.pyc b/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..011cac6eed5be966c9692ff164aa6fc2dfa6c151 Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-38.pyc b/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..300e8dd07de030680a9da47b5903ddd562bf8c5a Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-39.pyc b/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..390a7c716933955717197e8f56f558bdadbecbbd Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-310.pyc b/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f2e66701de55e196bf3687c972979702331bedd Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-310.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-38.pyc b/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0899c641b7af279913aa4767b5b9fb5aa44a3ecc Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-38.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-39.pyc b/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a345c57c20a3bdc887afc1723162fe75c9a645d2 Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/backwarp.cpython-39.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-310.pyc b/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a441a2494ec415ec2e1545e4a0796775964319d Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-310.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-38.pyc b/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49c72035e340dc8c7c31e44ef47b2b2a1b3d593d Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-38.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-39.pyc b/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc02d0dfa050c1b2ce8c1ba25199062524840275 Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/correlation.cpython-39.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-310.pyc b/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3529a46b1942bb57b0c0750fd27d8656da56a0e3 Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-310.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-38.pyc b/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e68b370eda6b6e59f7b7a2cca4fe7cef3443e3bf Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-38.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-39.pyc b/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2eac9e0b7fa28b70c508f5ca950f00ecddad77e Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/m2m.cpython-39.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-310.pyc b/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce8fc9b1b8d04026a5c5b42960dc3eea5bae156e Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-310.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-38.pyc b/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b120d361e9ba37a6c73ab0503b5c4ebde8218a89 Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-38.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-39.pyc b/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ea69abdcf0dc5d5c71f5e46e62df3935af3c080 Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/softsplat.cpython-39.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-310.pyc b/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..471b3db882b081c83635accb6af21b3107cd4567 Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-310.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-38.pyc b/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9043e0e449bf259f7b38a92616e27bf74a3a411e Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-38.pyc differ diff --git a/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-39.pyc b/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2c0be63a1b33f7e7a6545707f005d00d3ec22c9 Binary files /dev/null and b/modules/components/upr_net_multi_flow/__pycache__/upr.cpython-39.pyc differ diff --git a/modules/components/upr_net_multi_flow/backwarp.py b/modules/components/upr_net_multi_flow/backwarp.py new file mode 100644 index 0000000000000000000000000000000000000000..e99a0a5c1b658e81536825451b865b39c45bc9c4 --- /dev/null +++ b/modules/components/upr_net_multi_flow/backwarp.py @@ -0,0 +1,29 @@ +#!/usr/bin/env python + +import torch + + +########################################################## + + +objBackwarpcache = {} + + +def backwarp(tenIn:torch.Tensor, tenFlow:torch.Tensor): + if 'grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3]) not in objBackwarpcache: + tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] = torch.cat([tenHor, tenVer], 1) + # end + + if tenFlow.shape[3] == tenFlow.shape[2]: + tenFlow = tenFlow * (2.0 / ((tenFlow.shape[3] and tenFlow.shape[2]) - 1.0)) + + elif tenFlow.shape[3] != tenFlow.shape[2]: + tenFlow = tenFlow * torch.tensor(data=[2.0 / (tenFlow.shape[3] - 1.0), 2.0 / (tenFlow.shape[2] - 1.0)], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 2, 1, 1) + + # end + + return torch.nn.functional.grid_sample(input=tenIn, grid=(objBackwarpcache['grid' + str(tenFlow.dtype) + str(tenFlow.device) + str(tenFlow.shape[2]) + str(tenFlow.shape[3])] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True) +# end diff --git a/modules/components/upr_net_multi_flow/correlation.py b/modules/components/upr_net_multi_flow/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..1d1c92e2ef7dd885f25b30a3b2e4ed25c6a3889e --- /dev/null +++ b/modules/components/upr_net_multi_flow/correlation.py @@ -0,0 +1,415 @@ +#!/usr/bin/env python + +import torch + +import cupy +import re + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + + if (intIndex >= n) { + return; + } + + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + + __syncthreads(); + + int intPaddedY = (intIndex / SIZE_3(input)) + 4; + int intPaddedX = (intIndex % SIZE_3(input)) + 4; + int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX; + + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = blockIdx.x + 4; + int y1 = blockIdx.y + 4; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = top_channel % 9 - 4; + int s2p = top_channel / 9 - 4; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradFirst = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradFirst( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradFirst); // channels + int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos + int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4) + + // Same here: + int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4) + int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4) + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + // Get rbot1 data: + int s2o = o; + int s2p = p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradFirst); + const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4); + gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradSecond = ''' + #define ROUND_OFF 50000 + + extern "C" __global__ void kernel_Correlation_updateGradSecond( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradFirst, + float* gradSecond + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradSecond); // channels + int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos + int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = round_off; + + float sum = 0; + for (int p = -4; p <= 4; p++) { + for (int o = -4; o <= 4; o++) { + int s2o = o; + int s2p = p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o) + + // Same here: + int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o) + int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p) + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+4) * 9 + (o+4); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradSecond); + const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4); + gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems; + } } +''' + + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction] + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg])) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = ['((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg]) + ')' for intArg in range(intArgs)] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel + + +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) + + +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, first, second): + rbot0 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + rbot1 = first.new_zeros([first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1]]) + + self.save_for_backward(first, second, rbot0, rbot1) + + assert (first.is_contiguous() == True) + assert (second.is_contiguous() == True) + + output = first.new_zeros([first.shape[0], 81, first.shape[2], first.shape[3]]) + + if first.is_cuda == True: + n = first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': first, + 'output': rbot0 + }))( + grid=tuple([int((n + 16 - 1) / 16), first.shape[1], first.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, first.data_ptr(), rbot0.data_ptr()] + ) + + n = second.shape[2] * second.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'input': second, + 'output': rbot1 + }))( + grid=tuple([int((n + 16 - 1) / 16), second.shape[1], second.shape[0]]), + block=tuple([16, 1, 1]), + args=[n, second.data_ptr(), rbot1.data_ptr()] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([output.shape[3], output.shape[2], output.shape[0]]), + block=tuple([32, 1, 1]), + shared_mem=first.shape[1] * 4, + args=[n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr()] + ) + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return output + + # end + + @staticmethod + def backward(self, gradOutput): + first, second, rbot0, rbot1 = self.saved_tensors + + assert (gradOutput.is_contiguous() == True) + + gradFirst = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[0] == True else None + gradSecond = first.new_zeros([first.shape[0], first.shape[1], first.shape[2], first.shape[3]]) if \ + self.needs_input_grad[1] == True else None + + if first.is_cuda == True: + if gradFirst is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradFirst', + cupy_kernel('kernel_Correlation_updateGradFirst', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': gradFirst, + 'gradSecond': None + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), + gradFirst.data_ptr(), None] + ) + # end + # end + + if gradSecond is not None: + for intSample in range(first.shape[0]): + n = first.shape[1] * first.shape[2] * first.shape[3] + cupy_launch('kernel_Correlation_updateGradSecond', + cupy_kernel('kernel_Correlation_updateGradSecond', { + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradFirst': None, + 'gradSecond': gradSecond + }))( + grid=tuple([int((n + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, + gradSecond.data_ptr()] + ) + # end + # end + + elif first.is_cuda == False: + raise NotImplementedError() + + # end + + return gradFirst, gradSecond + + +# end +# end + +def FunctionCorrelation(tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) + + +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super(ModuleCorrelation, self).__init__() + + # end + + def forward(self, tenFirst, tenSecond): + return _FunctionCorrelation.apply(tenFirst, tenSecond) +# end +# end \ No newline at end of file diff --git a/modules/components/upr_net_multi_flow/m2m.py b/modules/components/upr_net_multi_flow/m2m.py new file mode 100644 index 0000000000000000000000000000000000000000..b213a8c64e1abf70d9a75bb25721d424067dfdfb --- /dev/null +++ b/modules/components/upr_net_multi_flow/m2m.py @@ -0,0 +1,367 @@ + +import math +import torch +import typing + +from ..components import register +from .backwarp import * +from .softsplat import * + + +########################################################## + +def forwarp_mframe_mask(tenIn1, tenFlow1, t1, tenIn2, tenFlow2, t2, tenMetric1=None, tenMetric2=None): + def one_fdir(tenIn, tenFlow, td, tenMetric): + tenIn = torch.cat([tenIn * td * (tenMetric).clip(-20.0, 20.0).exp(), td * (tenMetric).clip(-20.0, 20.0).exp()], + 1) + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + return tenOut[:, :-1, :, :], tenOut[:, -1:, :, :] + 0.0000001 + + flow_num = tenFlow1.shape[0] + tenOutF, tenOutB = 0, 0 + tenNormalizeF, tenNormalizeB = 0, 0 + for idx in range(flow_num): + tenOutF_, tenNormalizeF_ = one_fdir(tenIn1[idx], tenFlow1[idx], t1[idx], tenMetric1[idx]) + tenOutB_, tenNormalizeB_ = one_fdir(tenIn2[idx], tenFlow2[idx], t2[idx], tenMetric2[idx]) + + tenOutF += tenOutF_ + tenOutB += tenOutB_ + tenNormalizeF += tenNormalizeF_ + tenNormalizeB += tenNormalizeB_ + + return tenOutF / tenNormalizeF, tenNormalizeF < 0.00001, tenOutB / tenNormalizeB, tenNormalizeB < 0.00001 + + +################################################################### + +c = 16 + + +def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): + return torch.nn.Sequential( + torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, bias=True), + torch.nn.PReLU(out_planes) + ) + + +def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1): + return torch.nn.Sequential( + torch.torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, + kernel_size=kernel_size, stride=stride, padding=padding, bias=True), + torch.nn.PReLU(out_planes) + ) + + +class Conv2(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2, self).__init__() + self.conv1 = conv(in_planes, out_planes, 3, stride, 1) + self.conv2 = conv(out_planes, out_planes, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + return x + + +class Conv2n(torch.nn.Module): + def __init__(self, in_planes, out_planes, stride=2): + super(Conv2n, self).__init__() + self.conv1 = conv(in_planes, in_planes, 3, stride, 1) + self.conv2 = conv(in_planes, in_planes, 3, 1, 1) + self.conv3 = conv(in_planes, in_planes, 1, 1, 0) + self.conv4 = conv(in_planes, out_planes, 1, 1, 0) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + return x + + +##################################################### + +class ImgPyramid(torch.nn.Module): + def __init__(self): + super(ImgPyramid, self).__init__() + self.conv1 = Conv2(3, c) + self.conv2 = Conv2(c, 2 * c) + self.conv3 = Conv2(2 * c, 4 * c) + self.conv4 = Conv2(4 * c, 8 * c) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.conv2(x1) + x3 = self.conv3(x2) + x4 = self.conv4(x3) + return [x1, x2, x3, x4] + + +class EncDec(torch.nn.Module): + def __init__(self, branch): + super(EncDec, self).__init__() + self.branch = branch + + self.down0 = Conv2(8, 2 * c) + self.down1 = Conv2(6 * c, 4 * c) + self.down2 = Conv2(12 * c, 8 * c) + self.down3 = Conv2(24 * c, 16 * c) + + self.up0 = deconv(48 * c, 8 * c) + self.up1 = deconv(16 * c, 4 * c) + self.up2 = deconv(8 * c, 2 * c) + self.up3 = deconv(4 * c, c) + self.conv = torch.nn.Conv2d(c, 2 * self.branch, 3, 1, 1) + + self.conv_m = torch.nn.Conv2d(c, 1, 3, 1, 1) + + # For Channel dimennsion + self.conv_C = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d(1), + torch.nn.Conv2d(16 * c, 16 * 16 * c, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Height dimennsion + self.conv_H = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((None, 1)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + # For Width dimennsion + self.conv_W = torch.nn.Sequential( + torch.nn.AdaptiveAvgPool2d((1, None)), + torch.nn.Conv2d(16 * c, 16, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0), bias=True), + torch.nn.Sigmoid() + ) + + self.sigmoid = torch.nn.Sigmoid() + + def forward(self, flow0, flow1, im0, im1, c0, c1): + N_, C_, H_, W_ = im0.shape + + wim1 = backwarp(im1, flow0) + wim0 = backwarp(im0, flow1) + s0_0 = self.down0(torch.cat((flow0, im0, wim1), 1)) + s1_0 = self.down0(torch.cat((flow1, im1, wim0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_0, c0[0]), 1), flow1) + wf1 = backwarp(torch.cat((s1_0, c1[0]), 1), flow0) + + s0_1 = self.down1(torch.cat((s0_0, c0[0], wf1), 1)) + s1_1 = self.down1(torch.cat((s1_0, c1[0], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_1, c0[1]), 1), flow1) + wf1 = backwarp(torch.cat((s1_1, c1[1]), 1), flow0) + + s0_2 = self.down2(torch.cat((s0_1, c0[1], wf1), 1)) + s1_2 = self.down2(torch.cat((s1_1, c1[1], wf0), 1)) + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_2, c0[2]), 1), flow1) + wf1 = backwarp(torch.cat((s1_2, c1[2]), 1), flow0) + + s0_3 = self.down3(torch.cat((s0_2, c0[2], wf1), 1)) + s1_3 = self.down3(torch.cat((s1_2, c1[2], wf0), 1)) + + ######################################################################################### + + s0_3_c = self.conv_C(s0_3) + s0_3_c = s0_3_c.view(N_, 16, -1, 1, 1) + + s0_3_h = self.conv_H(s0_3) + s0_3_h = s0_3_h.view(N_, 16, 1, -1, 1) + + s0_3_w = self.conv_W(s0_3) + s0_3_w = s0_3_w.view(N_, 16, 1, 1, -1) + + cube0 = (s0_3_c * s0_3_h * s0_3_w).mean(1) + + s0_3 = s0_3 * cube0 + + s1_3_c = self.conv_C(s1_3) + s1_3_c = s1_3_c.view(N_, 16, -1, 1, 1) + + s1_3_h = self.conv_H(s1_3) + s1_3_h = s1_3_h.view(N_, 16, 1, -1, 1) + + s1_3_w = self.conv_W(s1_3) + s1_3_w = s1_3_w.view(N_, 16, 1, 1, -1) + + cube1 = (s1_3_c * s1_3_h * s1_3_w).mean(1) + + s1_3 = s1_3 * cube1 + + ######################################################################################### + flow0 = torch.nn.functional.interpolate(flow0, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + flow1 = torch.nn.functional.interpolate(flow1, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5 + + wf0 = backwarp(torch.cat((s0_3, c0[3]), 1), flow1) + wf1 = backwarp(torch.cat((s1_3, c1[3]), 1), flow0) + + x0 = self.up0(torch.cat((s0_3, c0[3], wf1), 1)) + x1 = self.up0(torch.cat((s1_3, c1[3], wf0), 1)) + + x0 = self.up1(torch.cat((s0_2, x0), 1)) + x1 = self.up1(torch.cat((s1_2, x1), 1)) + + x0 = self.up2(torch.cat((s0_1, x0), 1)) + x1 = self.up2(torch.cat((s1_1, x1), 1)) + + x0 = self.up3(torch.cat((s0_0, x0), 1)) + x1 = self.up3(torch.cat((s1_0, x1), 1)) + + m0 = self.sigmoid(self.conv_m(x0)) * 0.8 + 0.1 + m1 = self.sigmoid(self.conv_m(x1)) * 0.8 + 0.1 + + x0 = self.conv(x0) + x1 = self.conv(x1) + + return x0, x1, m0.repeat(1, self.branch, 1, 1), m1.repeat(1, self.branch, 1, 1) + + +@register('m2m_pwc') +class M2M_PWC(torch.nn.Module): + def __init__(self, ratio=4): + super(M2M_PWC, self).__init__() + self.branch = 4 + self.ratio = ratio + + self.paramAlpha = torch.nn.Parameter(10.0 * torch.ones(1, 1, 1, 1)) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1, ratio): + flow0 = ratio * torch.nn.functional.interpolate(input=flow0, scale_factor=ratio, mode='bilinear', + align_corners=False) + flow1 = ratio * torch.nn.functional.interpolate(input=flow1, scale_factor=ratio, mode='bilinear', + align_corners=False) + + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + def forward(self, img0, img1, time_step=[0.5], ratio=None, **kwargs): + if ratio is None: + ratio = self.ratio + + intWidth = img0.shape[3] and img1.shape[3] + intHeight = img0.shape[2] and img1.shape[2] + + intPadr = ((ratio * 16) - (intWidth % (ratio * 16))) % (ratio * 16) + intPadb = ((ratio * 16) - (intHeight % (ratio * 16))) % (ratio * 16) + + img0 = torch.nn.functional.pad(input=img0, pad=[0, intPadr, 0, intPadb], mode='replicate') + img1 = torch.nn.functional.pad(input=img1, pad=[0, intPadr, 0, intPadb], mode='replicate') + + N_, C_, H_, W_ = img0.shape + + outputs = [] + result_dict = {} + with torch.set_grad_enabled(False): + tenStats = [img0, img1] + tenMean_ = sum([tenIn.mean([1, 2, 3], True) for tenIn in tenStats]) / len(tenStats) + tenStd_ = (sum([tenIn.std([1, 2, 3], False, True).square() + ( + tenMean_ - tenIn.mean([1, 2, 3], True)).square() for tenIn in tenStats]) / len(tenStats)).sqrt() + + im0_o = (img0 - tenMean_) / (tenStd_ + 0.0000001) + im1_o = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + img0 = (img0 - tenMean_) / (tenStd_ + 0.0000001) + img1 = (img1 - tenMean_) / (tenStd_ + 0.0000001) + + im0_ = torch.nn.functional.interpolate(input=img0, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + im1_ = torch.nn.functional.interpolate(input=img1, scale_factor=2.0 / ratio, mode='bilinear', + align_corners=False) + + tenFwd, tenBwd = self.netFlow.bidir(im0_, im1_) + + result_dict['flowfwd'] = torch.nn.functional.interpolate(tenFwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + result_dict['flowbwd'] = torch.nn.functional.interpolate(tenBwd, scale_factor=ratio, mode='bilinear', align_corners=False)[:, :, + :intHeight, :intWidth].clone().detach() * ratio + + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, img0, img1, ratio) + + img0 = im0_o.repeat(1, self.branch, 1, 1) + img1 = im1_o.repeat(1, self.branch, 1, 1) + tenStd = tenStd_.repeat(1, self.branch, 1, 1) + tenMean = tenMean_.repeat(1, self.branch, 1, 1) + fltTime = time_step.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + tenStd = tenStd.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + tenMean = tenMean.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = (1.0 - (WeiMF * (img0 - backwarp(img1, tenFwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + tenPhototwo = (1.0 - (WeiMB * (img1 - backwarp(img0, tenBwd).detach()).abs().mean([1], True))).clip( + 0.001, None).square() + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = self.paramAlpha * tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = self.paramAlpha * tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + img0 = img0.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + img1 = img1.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutput, mask = forwarp_mframe_mask(img0, flow0, t1, img1, flow1, t0, metric0, metric1) + + tenOutput = tenOutput + mask * (t1.mean(0) * im0_o + t0.mean(0) * im1_o) + + output = (tenOutput * (tenStd_ + 0.0000001)) + tenMean_ + result_dict['imgt_pred'] = output[:, :, :intHeight, :intWidth] + + return result_dict diff --git a/modules/components/upr_net_multi_flow/softsplat.py b/modules/components/upr_net_multi_flow/softsplat.py new file mode 100644 index 0000000000000000000000000000000000000000..76237ad31f9837842710abf278ab772140fcd5f2 --- /dev/null +++ b/modules/components/upr_net_multi_flow/softsplat.py @@ -0,0 +1,558 @@ +#!/usr/bin/env python + +######################################### +# This implementation is taken from +# https://github.com/sniklaus/softmax-splatting +######################################### + +import collections +import cupy +import os +import re +import torch +import typing + +########################################################## + + +objCudacache = {} + + +def cuda_int32(intIn: int): + return cupy.int32(intIn) + + +# end + + +def cuda_float32(fltIn: float): + return cupy.float32(fltIn) + + +# end + + +def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): + if 'device' not in objCudacache: + objCudacache['device'] = torch.cuda.get_device_name() + # end + + strKey = strFunction + + for strVariable in objVariables: + objValue = objVariables[strVariable] + + strKey += strVariable + + if objValue is None: + continue + + elif type(objValue) == int: + strKey += str(objValue) + + elif type(objValue) == float: + strKey += str(objValue) + + elif type(objValue) == bool: + strKey += str(objValue) + + elif type(objValue) == str: + strKey += objValue + + elif type(objValue) == torch.Tensor: + strKey += str(objValue.dtype) + strKey += str(objValue.shape) + strKey += str(objValue.stride()) + + elif True: + print(strVariable, type(objValue)) + assert (False) + + # end + # end + + strKey += objCudacache['device'] + + if strKey not in objCudacache: + for strVariable in objVariables: + objValue = objVariables[strVariable] + + if objValue is None: + continue + + elif type(objValue) == int: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == float: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == bool: + strKernel = strKernel.replace('{{' + strVariable + '}}', str(objValue)) + + elif type(objValue) == str: + strKernel = strKernel.replace('{{' + strVariable + '}}', objValue) + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.uint8: + strKernel = strKernel.replace('{{type}}', 'unsigned char') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float16: + strKernel = strKernel.replace('{{type}}', 'half') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float32: + strKernel = strKernel.replace('{{type}}', 'float') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.float64: + strKernel = strKernel.replace('{{type}}', 'double') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int32: + strKernel = strKernel.replace('{{type}}', 'int') + + elif type(objValue) == torch.Tensor and objValue.dtype == torch.int64: + strKernel = strKernel.replace('{{type}}', 'long') + + elif type(objValue) == torch.Tensor: + print(strVariable, objValue.dtype) + assert (False) + + elif True: + print(strVariable, type(objValue)) + assert (False) + + # end + # end + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str( + intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(OFFSET_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert (intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[ + intArg].item()) + ')') + # end + + strKernel = strKernel.replace('OFFSET_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', + '(' + str.join('+', strIndex) + ')') + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()', strKernel) + + if objMatch is None: + break + # end + + intStart = objMatch.span()[1] + intStop = objMatch.span()[1] + intParentheses = 1 + + while True: + intParentheses += 1 if strKernel[intStop] == '(' else 0 + intParentheses -= 1 if strKernel[intStop] == ')' else 0 + + if intParentheses == 0: + break + # end + + intStop += 1 + # end + + intArgs = int(objMatch.group(2)) + strArgs = strKernel[intStart:intStop].split(',') + + assert (intArgs == len(strArgs) - 1) + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + + strIndex = [] + + for intArg in range(intArgs): + strIndex.append('((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str( + intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[ + intArg].item()) + ')') + # end + + strKernel = strKernel.replace('VALUE_' + str(intArgs) + '(' + strKernel[intStart:intStop] + ')', + strTensor + '[' + str.join('+', strIndex) + ']') + # end + + objCudacache[strKey] = { + 'strFunction': strFunction, + 'strKernel': strKernel + } + # end + + return strKey + + +# end + + +@cupy.memoize(for_each_device=True) +def cuda_launch(strKey: str): + if 'CUDA_HOME' not in os.environ: + os.environ['CUDA_HOME'] = '/usr/local/cuda/' + # end + + return cupy.cuda.compile_with_cache(objCudacache[strKey]['strKernel'], tuple( + ['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include'])).get_function( + objCudacache[strKey]['strFunction']) + + +# end + + +########################################################## + + +def softsplat(tenIn: torch.Tensor, tenFlow: torch.Tensor, tenMetric: torch.Tensor, strMode: str): + assert (strMode.split('-')[0] in ['sum', 'avg', 'linear', 'soft']) + + if strMode == 'sum': assert (tenMetric is None) + if strMode == 'avg': assert (tenMetric is None) + if strMode.split('-')[0] == 'linear': assert (tenMetric is not None) + if strMode.split('-')[0] == 'soft': assert (tenMetric is not None) + + if strMode == 'avg': + tenIn = torch.cat([tenIn, tenIn.new_ones([tenIn.shape[0], 1, tenIn.shape[2], tenIn.shape[3]])], 1) + + elif strMode.split('-')[0] == 'linear': + tenIn = torch.cat([tenIn * tenMetric, tenMetric], 1) + + elif strMode.split('-')[0] == 'soft': + tenIn = torch.cat([tenIn * tenMetric.exp(), tenMetric.exp()], 1) + + # end + + tenOut = softsplat_func.apply(tenIn, tenFlow) + + if strMode.split('-')[0] in ['avg', 'linear', 'soft']: + tenNormalize = tenOut[:, -1:, :, :] + + if len(strMode.split('-')) == 1: + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'addeps': + tenNormalize = tenNormalize + 0.0000001 + + elif strMode.split('-')[1] == 'zeroeps': + tenNormalize[tenNormalize == 0.0] = 1.0 + + elif strMode.split('-')[1] == 'clipeps': + tenNormalize = tenNormalize.clip(0.0000001, None) + + # end + + tenOut = tenOut[:, :-1, :, :] / tenNormalize + # end + + return tenOut + + +# end + + +class softsplat_func(torch.autograd.Function): + @staticmethod + @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32) + def forward(self, tenIn, tenFlow): + tenOut = tenIn.new_zeros([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) + + if tenIn.is_cuda == True: + cuda_launch(cuda_kernel('softsplat_out', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_out( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + {{type}}* __restrict__ tenOut + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) / SIZE_1(tenOut) ) % SIZE_0(tenOut); + const int intC = ( intIndex / SIZE_3(tenOut) / SIZE_2(tenOut) ) % SIZE_1(tenOut); + const int intY = ( intIndex / SIZE_3(tenOut) ) % SIZE_2(tenOut); + const int intX = ( intIndex ) % SIZE_3(tenOut); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + {{type}} fltIn = VALUE_4(tenIn, intN, intC, intY, intX); + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOut)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNorthwestY, intNorthwestX)], fltIn * fltNorthwest); + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOut)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intNortheastY, intNortheastX)], fltIn * fltNortheast); + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOut)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSouthwestY, intSouthwestX)], fltIn * fltSouthwest); + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOut)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOut))) { + atomicAdd(&tenOut[OFFSET_4(tenOut, intN, intC, intSoutheastY, intSoutheastX)], fltIn * fltSoutheast); + } + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOut': tenOut + }))( + grid=tuple([int((tenOut.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenOut.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOut.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + + elif tenIn.is_cuda != True: + assert (False) + + # end + + self.save_for_backward(tenIn, tenFlow) + + return tenOut + + # end + + @staticmethod + @torch.cuda.amp.custom_bwd + def backward(self, tenOutgrad): + tenIn, tenFlow = self.saved_tensors + + tenOutgrad = tenOutgrad.contiguous(); + assert (tenOutgrad.is_cuda == True) + + tenIngrad = tenIn.new_empty([tenIn.shape[0], tenIn.shape[1], tenIn.shape[2], tenIn.shape[3]]) if \ + self.needs_input_grad[0] == True else None + tenFlowgrad = tenFlow.new_empty([tenFlow.shape[0], tenFlow.shape[1], tenFlow.shape[2], tenFlow.shape[3]]) if \ + self.needs_input_grad[1] == True else None + + if tenIngrad is not None: + cuda_launch(cuda_kernel('softsplat_ingrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_ingrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) / SIZE_1(tenIngrad) ) % SIZE_0(tenIngrad); + const int intC = ( intIndex / SIZE_3(tenIngrad) / SIZE_2(tenIngrad) ) % SIZE_1(tenIngrad); + const int intY = ( intIndex / SIZE_3(tenIngrad) ) % SIZE_2(tenIngrad); + const int intX = ( intIndex ) % SIZE_3(tenIngrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltIngrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (intSoutheastY) - fltY); + {{type}} fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (intSouthwestY) - fltY); + {{type}} fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (fltY - ({{type}}) (intNortheastY)); + {{type}} fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (fltY - ({{type}}) (intNorthwestY)); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intNortheastY, intNortheastX) * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltIngrad += VALUE_4(tenOutgrad, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast; + } + + tenIngrad[intIndex] = fltIngrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenIngrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenIngrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), + tenIngrad.data_ptr(), None], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + if tenFlowgrad is not None: + cuda_launch(cuda_kernel('softsplat_flowgrad', ''' + extern "C" __global__ void __launch_bounds__(512) softsplat_flowgrad( + const int n, + const {{type}}* __restrict__ tenIn, + const {{type}}* __restrict__ tenFlow, + const {{type}}* __restrict__ tenOutgrad, + {{type}}* __restrict__ tenIngrad, + {{type}}* __restrict__ tenFlowgrad + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + const int intN = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) / SIZE_1(tenFlowgrad) ) % SIZE_0(tenFlowgrad); + const int intC = ( intIndex / SIZE_3(tenFlowgrad) / SIZE_2(tenFlowgrad) ) % SIZE_1(tenFlowgrad); + const int intY = ( intIndex / SIZE_3(tenFlowgrad) ) % SIZE_2(tenFlowgrad); + const int intX = ( intIndex ) % SIZE_3(tenFlowgrad); + + assert(SIZE_1(tenFlow) == 2); + + {{type}} fltFlowgrad = 0.0f; + + {{type}} fltX = ({{type}}) (intX) + VALUE_4(tenFlow, intN, 0, intY, intX); + {{type}} fltY = ({{type}}) (intY) + VALUE_4(tenFlow, intN, 1, intY, intX); + + if (isfinite(fltX) == false) { return; } + if (isfinite(fltY) == false) { return; } + + int intNorthwestX = (int) (floor(fltX)); + int intNorthwestY = (int) (floor(fltY)); + int intNortheastX = intNorthwestX + 1; + int intNortheastY = intNorthwestY; + int intSouthwestX = intNorthwestX; + int intSouthwestY = intNorthwestY + 1; + int intSoutheastX = intNorthwestX + 1; + int intSoutheastY = intNorthwestY + 1; + + {{type}} fltNorthwest = 0.0f; + {{type}} fltNortheast = 0.0f; + {{type}} fltSouthwest = 0.0f; + {{type}} fltSoutheast = 0.0f; + + if (intC == 0) { + fltNorthwest = (({{type}}) (-1.0f)) * (({{type}}) (intSoutheastY) - fltY); + fltNortheast = (({{type}}) (+1.0f)) * (({{type}}) (intSouthwestY) - fltY); + fltSouthwest = (({{type}}) (-1.0f)) * (fltY - ({{type}}) (intNortheastY)); + fltSoutheast = (({{type}}) (+1.0f)) * (fltY - ({{type}}) (intNorthwestY)); + + } else if (intC == 1) { + fltNorthwest = (({{type}}) (intSoutheastX) - fltX) * (({{type}}) (-1.0f)); + fltNortheast = (fltX - ({{type}}) (intSouthwestX)) * (({{type}}) (-1.0f)); + fltSouthwest = (({{type}}) (intNortheastX) - fltX) * (({{type}}) (+1.0f)); + fltSoutheast = (fltX - ({{type}}) (intNorthwestX)) * (({{type}}) (+1.0f)); + + } + + for (int intChannel = 0; intChannel < SIZE_1(tenOutgrad); intChannel += 1) { + {{type}} fltIn = VALUE_4(tenIn, intN, intChannel, intY, intX); + + if ((intNorthwestX >= 0) && (intNorthwestX < SIZE_3(tenOutgrad)) && (intNorthwestY >= 0) && (intNorthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNorthwestY, intNorthwestX) * fltIn * fltNorthwest; + } + + if ((intNortheastX >= 0) && (intNortheastX < SIZE_3(tenOutgrad)) && (intNortheastY >= 0) && (intNortheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intNortheastY, intNortheastX) * fltIn * fltNortheast; + } + + if ((intSouthwestX >= 0) && (intSouthwestX < SIZE_3(tenOutgrad)) && (intSouthwestY >= 0) && (intSouthwestY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSouthwestY, intSouthwestX) * fltIn * fltSouthwest; + } + + if ((intSoutheastX >= 0) && (intSoutheastX < SIZE_3(tenOutgrad)) && (intSoutheastY >= 0) && (intSoutheastY < SIZE_2(tenOutgrad))) { + fltFlowgrad += VALUE_4(tenOutgrad, intN, intChannel, intSoutheastY, intSoutheastX) * fltIn * fltSoutheast; + } + } + + tenFlowgrad[intIndex] = fltFlowgrad; + } } + ''', { + 'tenIn': tenIn, + 'tenFlow': tenFlow, + 'tenOutgrad': tenOutgrad, + 'tenIngrad': tenIngrad, + 'tenFlowgrad': tenFlowgrad + }))( + grid=tuple([int((tenFlowgrad.nelement() + 512 - 1) / 512), 1, 1]), + block=tuple([512, 1, 1]), + args=[cuda_int32(tenFlowgrad.nelement()), tenIn.data_ptr(), tenFlow.data_ptr(), tenOutgrad.data_ptr(), + None, tenFlowgrad.data_ptr()], + stream=collections.namedtuple('Stream', 'ptr')(torch.cuda.current_stream().cuda_stream) + ) + # end + + return tenIngrad, tenFlowgrad + # end +# end diff --git a/modules/components/upr_net_multi_flow/upr.py b/modules/components/upr_net_multi_flow/upr.py new file mode 100644 index 0000000000000000000000000000000000000000..32a98dc0b338846324b282f81eedee66b7c9ef5f --- /dev/null +++ b/modules/components/upr_net_multi_flow/upr.py @@ -0,0 +1,540 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn + +import modules.components.upr_net_multi_flow.correlation as correlation +from modules.components.upr_net_multi_flow.softsplat import * +from modules.components.upr_net_multi_flow.backwarp import backwarp +from modules.components.upr_net_multi_flow.m2m import * + +from ..components import register + + +def photometric_consistency(img0, img1, flow01): + return (img0 - backwarp(img1, flow01)).abs().sum(dim=1, keepdims=True) + + +def flow_consistency(flow01, flow10): + return (flow01 + backwarp(flow10, flow01)).abs().sum(dim=1, keepdims=True) + + +gaussian_kernel = torch.tensor([[1, 2, 1], + [2, 4, 2], + [1, 2, 1]]) / 16 +gaussian_kernel = gaussian_kernel.repeat(2, 1, 1, 1) +gaussian_kernel = gaussian_kernel.to(torch.cuda.current_device()) + + +def gaussian(x): + x = torch.nn.functional.pad(x, (1, 1, 1, 1), mode='reflect') + out = torch.nn.functional.conv2d(x, gaussian_kernel, groups=x.shape[1]) + # out = TF.gaussian_blur(x, [3, 3], sigma=[2, 2]) + return out + + +def variance_flow(flow): + flow = flow * torch.tensor(data=[2.0 / (flow.shape[3] - 1.0), 2.0 / (flow.shape[2] - 1.0)], dtype=flow.dtype, + device=flow.device).view(1, 2, 1, 1) + return (gaussian(flow ** 2) - gaussian(flow) ** 2 + 1e-4).sqrt().abs().sum(dim=1, keepdim=True) + + +# **************************************************************************************************# +# => Feature Pyramid +# **************************************************************************************************# +class FeatPyramid(nn.Module): + """A 3-level feature pyramid, which by default is shared by the motion + estimator and synthesis network. + """ + + def __init__(self): + super(FeatPyramid, self).__init__() + self.conv_stage0 = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage1 = nn.Sequential( + nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_stage2 = nn.Sequential( + nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, + stride=2, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + + def forward(self, img): + C0 = self.conv_stage0(img) + C1 = self.conv_stage1(C0) + C2 = self.conv_stage2(C1) + return [C0, C1, C2] + + +# **************************************************************************************************# +# => Motion Estimation +# **************************************************************************************************# +class MotionEstimator(nn.Module): + """Bi-directional optical flow estimator + 1) construct partial cost volume with the CNN features from the stage 2 of + the feature pyramid; + 2) estimate bi-directional flows, by feeding cost volume, CNN features for + both warped images, CNN feature and estimated flow from previous iteration. + """ + + def __init__(self): + super(MotionEstimator, self).__init__() + # (4*2 + 1) ** 2 + 128 * 2 + 128 + 4 = 469 + self.conv_layer1 = nn.Sequential( + nn.Conv2d(in_channels=469, out_channels=320, + kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer2 = nn.Sequential( + nn.Conv2d(in_channels=320, out_channels=256, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer3 = nn.Sequential( + nn.Conv2d(in_channels=256, out_channels=224, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer4 = nn.Sequential( + nn.Conv2d(in_channels=224, out_channels=192, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer5 = nn.Sequential( + nn.Conv2d(in_channels=192, out_channels=128, + kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(inplace=False, negative_slope=0.1)) + self.conv_layer6 = nn.Sequential( + nn.Conv2d(in_channels=128, out_channels=4, + kernel_size=3, stride=1, padding=1)) + + def forward(self, feat0, feat1, last_feat, last_flow): + corr_fn = correlation.FunctionCorrelation + feat0 = softsplat( + tenIn=feat0, tenFlow=last_flow[:, :2] * 0.25 * 0.5, + tenMetric=None, strMode='avg') + feat1 = softsplat( + tenIn=feat1, tenFlow=last_flow[:, 2:] * 0.25 * 0.5, + tenMetric=None, strMode='avg') + + volume = F.leaky_relu( + input=corr_fn(tenFirst=feat0, tenSecond=feat1), + negative_slope=0.1, inplace=False) + input_feat = torch.cat([volume, feat0, feat1, last_feat, last_flow], 1) + feat = self.conv_layer1(input_feat) + feat = self.conv_layer2(feat) + feat = self.conv_layer3(feat) + feat = self.conv_layer4(feat) + feat = self.conv_layer5(feat) + flow = self.conv_layer6(feat) + + return flow, feat + + +# **************************************************************************************************# +# => Frame Synthesis +# **************************************************************************************************# +class SynthesisNetwork(nn.Module): + def __init__(self, branch): + super(SynthesisNetwork, self).__init__() + self.branch = branch + input_channels = 9 + 4 + 6 + self.encoder_conv = nn.Sequential( + nn.Conv2d(in_channels=input_channels, out_channels=64, + kernel_size=3, stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.encoder_down1 = nn.Sequential( + nn.Conv2d(in_channels=64 + 32 + 32, out_channels=128, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.encoder_down2 = nn.Sequential( + nn.Conv2d(in_channels=128 + 64 + 64, out_channels=256, + kernel_size=3, stride=2, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256), + nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=256)) + self.decoder_up1 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=256 + 128 + 128, + out_channels=128, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=128), + nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=128)) + self.decoder_up2 = nn.Sequential( + torch.nn.ConvTranspose2d(in_channels=128 + 128, + out_channels=64, kernel_size=4, stride=2, + padding=1, bias=True), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.decoder_conv = nn.Sequential( + nn.Conv2d(in_channels=64 + 64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64), + nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, + stride=1, padding=1), + nn.PReLU(num_parameters=64)) + self.pred = nn.Conv2d(in_channels=64, out_channels=5, kernel_size=3, + stride=1, padding=1) + + class MotionRefineNet(torch.nn.Module): + def __init__(self, branch): + super(MotionRefineNet, self).__init__() + self.branch = branch + self.img_pyramid = ImgPyramid() + self.motion_encdec = EncDec(branch) + + def forward(self, flow0, flow1, im0, im1): + c0 = self.img_pyramid(im0) + c1 = self.img_pyramid(im1) + + flow_res = self.motion_encdec(flow0, flow1, im0, im1, c0, c1) + + flow0 = flow0.repeat(1, self.branch, 1, 1) + flow_res[0] + flow1 = flow1.repeat(1, self.branch, 1, 1) + flow_res[1] + + return flow0, flow1, flow_res[2], flow_res[3] + + self.MRN = MotionRefineNet(self.branch) + + self.alpha = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + # New params for splatting mask generation + self.alpha_splat_photo_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_flow_consistency = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + self.alpha_splat_variation_flow = torch.nn.Parameter(torch.ones(1, 1, 1, 1)) + + def get_splat_weight(self, img0, img1, flow01, flow10): + M_splat = 1 / (1 + self.alpha_splat_photo_consistency * photometric_consistency(img0, img1, flow01)) + \ + 1 / (1 + self.alpha_splat_flow_consistency * flow_consistency(flow01, flow10)) + \ + 1 / (1 + self.alpha_splat_variation_flow * variance_flow(flow01)) + return M_splat * self.alpha + + def get_warped_representations(self, bi_flow, c0, c1, m_splat_0, m_splat_1, + i0=None, i1=None, time_period=0.5): + flow_0t = bi_flow[:, :2] * time_period + flow_1t = bi_flow[:, 2:4] * (1 - time_period) + warped_c0 = softsplat( + tenIn=c0, tenFlow=flow_0t, + tenMetric=m_splat_0, strMode='soft') + warped_c1 = softsplat( + tenIn=c1, tenFlow=flow_1t, + tenMetric=m_splat_1, strMode='soft') + if (i0 is None) and (i1 is None): + return warped_c0, warped_c1 + else: + warped_img0 = softsplat( + tenIn=i0, tenFlow=flow_0t, + tenMetric=m_splat_0, strMode='soft') + warped_img1 = softsplat( + tenIn=i1, tenFlow=flow_1t, + tenMetric=m_splat_1, strMode='soft') + flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) + return warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t + + def forward(self, last_i, i0, i1, c0_pyr, c1_pyr, bi_flow_pyr, + time_period=0.5, multi_flow=False): + m_splat_0_0 = self.get_splat_weight(i0, i1, bi_flow_pyr[0][:, :2], bi_flow_pyr[0][:, 2:4]) + m_splat_1_0 = self.get_splat_weight(i1, i0, bi_flow_pyr[0][:, 2:4], bi_flow_pyr[0][:, :2]) + if multi_flow: + tenFwd = bi_flow_pyr[0][:, :2] + tenBwd = bi_flow_pyr[0][:, 2:4] + tenFwd, tenBwd, WeiMF, WeiMB = self.MRN(tenFwd, tenBwd, i0, i1) + N_, _, H_, W_ = i0.shape + + i0_ = i0.repeat(1, self.branch, 1, 1) + i1_ = i1.repeat(1, self.branch, 1, 1) + + fltTime = time_period.repeat(1, self.branch, 1, 1) + + tenFwd = tenFwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + tenBwd = tenBwd.reshape(N_, self.branch, 2, H_, W_).view(N_ * self.branch, 2, H_, W_) + + WeiMF = WeiMF.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + WeiMB = WeiMB.reshape(N_, self.branch, 1, H_, W_).view(N_ * self.branch, 1, H_, W_) + + i0_ = i0_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + i1_ = i1_.reshape(N_, self.branch, 3, H_, W_).view(N_ * self.branch, 3, H_, W_) + + fltTime = fltTime.reshape(N_, self.branch, 1, 1, 1).view(N_ * self.branch, 1, 1, 1) + + tenPhotoone = self.get_splat_weight(i0_, i1_, tenFwd, tenBwd) * WeiMF + tenPhototwo = self.get_splat_weight(i1_, i0_, tenBwd, tenFwd) * WeiMB + + t0 = fltTime + flow0 = tenFwd * t0 + metric0 = tenPhotoone + + t1 = 1.0 - fltTime + flow1 = tenBwd * t1 + metric1 = tenPhototwo + + flow0 = flow0.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + flow1 = flow1.reshape(N_, self.branch, 2, H_, W_).permute(1, 0, 2, 3, 4) + + metric0 = metric0.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + metric1 = metric1.reshape(N_, self.branch, 1, H_, W_).permute(1, 0, 2, 3, 4) + + i0_ = i0_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + i1_ = i1_.reshape(N_, self.branch, 3, H_, W_).permute(1, 0, 2, 3, 4) + + t0 = t0.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + t1 = t1.reshape(N_, self.branch, 1, 1, 1).permute(1, 0, 2, 3, 4) + + tenOutputF, maskF, tenOutputB, maskB = forwarp_mframe_mask(i0_, flow0, t1, i1_, flow1, t0, metric0, metric1) + + warped_img0 = tenOutputF + maskF * i0 + warped_img1 = tenOutputB + maskB * i1 + warped_c0, warped_c1 = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, + time_period=time_period) + flow_0t = bi_flow_pyr[0][:, :2] * time_period + flow_1t = bi_flow_pyr[0][:, 2:4] * (1 - time_period) + flow_0t_1t = torch.cat((flow_0t, flow_1t), 1) + else: + warped_img0, warped_img1, warped_c0, warped_c1, flow_0t_1t = \ + self.get_warped_representations( + bi_flow_pyr[0], c0_pyr[0], c1_pyr[0], m_splat_0_0, m_splat_1_0, i0, i1, + time_period=time_period) + input_feat = torch.cat( + (last_i, warped_img0, warped_img1, i0, i1, flow_0t_1t), 1) + s0 = self.encoder_conv(input_feat) + s1 = self.encoder_down1(torch.cat((s0, warped_c0, warped_c1), 1)) + m_splat_0_1 = F.interpolate(m_splat_0_0, scale_factor=0.5, mode='bilinear') + m_splat_1_1 = F.interpolate(m_splat_1_0, scale_factor=0.5, mode='bilinear') + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[1], c0_pyr[1], c1_pyr[1], m_splat_0_1, m_splat_1_1, + time_period=time_period) + s2 = self.encoder_down2(torch.cat((s1, warped_c0, warped_c1), 1)) + m_splat_0_2 = F.interpolate(m_splat_0_1, scale_factor=0.5, mode='bilinear') + m_splat_1_2 = F.interpolate(m_splat_1_1, scale_factor=0.5, mode='bilinear') + warped_c0, warped_c1 = self.get_warped_representations( + bi_flow_pyr[2], c0_pyr[2], c1_pyr[2], m_splat_0_2, m_splat_1_2, + time_period=time_period) + + x = self.decoder_up1(torch.cat((s2, warped_c0, warped_c1), 1)) + x = self.decoder_up2(torch.cat((x, s1), 1)) + x = self.decoder_conv(torch.cat((x, s0), 1)) + + # prediction + refine = self.pred(x) + refine_res = torch.sigmoid(refine[:, :3]) * 2 - 1 + refine_mask0 = torch.sigmoid(refine[:, 3:4]) + refine_mask1 = torch.sigmoid(refine[:, 4:5]) + merged_img = (warped_img0 * refine_mask0 * (1 - time_period) + + warped_img1 * refine_mask1 * time_period) + merged_img = merged_img / (refine_mask0 * (1 - time_period) + + refine_mask1 * time_period) + interp_img = merged_img + refine_res + interp_img = torch.clamp(interp_img, 0, 1) + + extra_dict = {} + extra_dict["refine_res"] = refine_res + extra_dict["warped_img0"] = warped_img0 + extra_dict["warped_img1"] = warped_img1 + extra_dict["merged_img"] = merged_img + if multi_flow: + extra_dict['tenFwd'] = tenFwd.view(N_, self.branch, 2, H_, W_) + extra_dict['tenBwd'] = tenBwd.view(N_, self.branch, 2, H_, W_) + + return interp_img, extra_dict + + +# **************************************************************************************************# +# => Unified model +# **************************************************************************************************# +@register('upr_net_multi_flow') +class UPRMultiFlow(nn.Module): + def __init__(self, pyr_level=3, nr_lvl_skipped=0, branch=1): + super(UPRMultiFlow, self).__init__() + self.pyr_level = pyr_level + self.feat_pyramid = FeatPyramid() + self.nr_lvl_skipped = nr_lvl_skipped + self.branch = branch + self.motion_estimator = MotionEstimator() + self.synthesis_network = SynthesisNetwork(self.branch) + + def forward_one_lvl(self, + img0, img1, last_feat, last_flow, last_interp=None, + time_period=0.5, skip_me=False, multi_flow=False): + + # context feature extraction + feat0_pyr = self.feat_pyramid(img0) + feat1_pyr = self.feat_pyramid(img1) + + # bi-directional flow estimation + if not skip_me: + flow, feat = self.motion_estimator( + feat0_pyr[-1], feat1_pyr[-1], + last_feat, last_flow) + else: + flow = last_flow + feat = last_feat + + # frame synthesis + ## optical flow is estimated at 1/4 resolution + ori_resolution_flow = F.interpolate( + input=flow, scale_factor=4.0, + mode="bilinear", align_corners=False) + + ## consturct 3-level flow pyramid for synthesis network + bi_flow_pyr = [] + tmp_flow = ori_resolution_flow + bi_flow_pyr.append(tmp_flow) + for i in range(2): + tmp_flow = F.interpolate( + input=tmp_flow, scale_factor=0.5, + mode="bilinear", align_corners=False) * 0.5 + bi_flow_pyr.append(tmp_flow) + + ## merge warped frames as initial interpolation for frame synthesis + if last_interp is None: + flow_0t = ori_resolution_flow[:, :2] * time_period + flow_1t = ori_resolution_flow[:, 2:4] * (1 - time_period) + warped_img0 = softsplat( + tenIn=img0, tenFlow=flow_0t, + tenMetric=None, strMode='avg') + warped_img1 = softsplat( + tenIn=img1, tenFlow=flow_1t, + tenMetric=None, strMode='avg') + last_interp = warped_img0 * (1 - time_period) \ + + warped_img1 * time_period + + ## do synthesis + interp_img, extra_dict = self.synthesis_network( + last_interp, img0, img1, feat0_pyr, feat1_pyr, bi_flow_pyr, + time_period=time_period, multi_flow=multi_flow) + return flow, feat, interp_img, extra_dict + + def forward(self, img0, img1, time_step, + pyr_level=None, nr_lvl_skipped=None, **kwargs): + + if pyr_level is None: pyr_level = self.pyr_level + if nr_lvl_skipped is None: nr_lvl_skipped = self.nr_lvl_skipped + N, _, H, W = img0.shape + flow0_pred = [] + flow1_pred = [] + interp_imgs = [] + skipped_levels = [] if nr_lvl_skipped == 0 else \ + list(range(pyr_level))[::-1][-nr_lvl_skipped:] + + # The original input resolution corresponds to level 0. + for level in list(range(pyr_level))[::-1]: + if level != 0: + scale_factor = 1 / 2 ** level + img0_this_lvl = F.interpolate( + input=img0, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + img1_this_lvl = F.interpolate( + input=img1, scale_factor=scale_factor, + mode="bilinear", align_corners=False) + else: + img0_this_lvl = img0 + img1_this_lvl = img1 + + # skip motion estimation, directly use up-sampled optical flow + skip_me = False + + # the lowest-resolution pyramid level + if level == pyr_level - 1: + last_flow = torch.zeros( + (N, 4, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_feat = torch.zeros( + (N, 128, H // (2 ** (level + 2)), W // (2 ** (level + 2))) + ).to(img0.device) + last_interp = None + # skip some levels for both motion estimation and frame synthesis + elif level in skipped_levels[:-1]: + continue + # last level (original input resolution), only skip motion estimation + elif (level == 0) and len(skipped_levels) > 0: + if len(skipped_levels) == pyr_level: + last_flow = torch.zeros( + (N, 4, H // 4, W // 4)).to(img0.device) + last_interp = None + else: + resize_factor = 2 ** len(skipped_levels) + last_flow = F.interpolate( + input=flow, scale_factor=resize_factor, + mode="bilinear", align_corners=False) * resize_factor + last_interp = F.interpolate( + input=interp_img, scale_factor=resize_factor, + mode="bilinear", align_corners=False) + skip_me = True + # last level (original input resolution), motion estimation + frame + # synthesis + else: + last_flow = F.interpolate(input=flow, scale_factor=2.0, + mode="bilinear", align_corners=False) * 2 + last_feat = F.interpolate(input=feat, scale_factor=2.0, + mode="bilinear", align_corners=False) + last_interp = F.interpolate( + input=interp_img, scale_factor=2.0, + mode="bilinear", align_corners=False) + + flow, feat, interp_img, extra_dict = self.forward_one_lvl( + img0_this_lvl, img1_this_lvl, + last_feat, last_flow, last_interp, + time_step, skip_me=skip_me, multi_flow=(level == 0)) + if level != 0: + flow0_pred.append( + F.interpolate(input=flow[:, :2], scale_factor=4.0 * 2 ** level, + mode="bilinear", align_corners=False)) + flow1_pred.append( + F.interpolate(input=flow[:, 2:], scale_factor=4.0 * 2 ** level, + mode="bilinear", align_corners=False)) + else: + flow0_pred.append(extra_dict['tenFwd']) + flow1_pred.append(extra_dict['tenBwd']) + interp_imgs.append(F.interpolate(interp_img, scale_factor=2 ** level)) + + # directly up-sample estimated flow to full resolution with bi-linear + # interpolation + + return {"imgt_preds": interp_imgs[-2:], "flow0_pred": flow0_pred[::-1], "flow1_pred": flow1_pred[::-1], + 'imgt_pred': interp_img, "flowfwd": flow0_pred[-1][:, 0], "flowbwd": flow1_pred[-1][:, 0]} + + +if __name__ == "__main__": + pass diff --git a/modules/loss.py b/modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..7765c3000d7b638b9cb3c9c9da5c91a82393af95 --- /dev/null +++ b/modules/loss.py @@ -0,0 +1,614 @@ +#!/usr/bin/env python + +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torchvision.models as models +import cv2 +import numpy + +from modules.components.m2m_unimatch.unimatch.unimatch import UniMatch +from modules.components.m2m_flow_former.LatentCostFormer.transformer import * +from modules.components.m2m_flow_former.cfg import get_cfg + + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +losses = {} + + +def register(name): + def decorator(cls): + losses[name] = cls + return cls + return decorator + + +def make_loss_dict(loss_cfgs): + loss_dict = dict() + + def make_loss(loss_spec): + loss = losses[loss_spec['name']](**loss_spec['args']) + return loss + + for loss_cfg in loss_cfgs: + loss_dict[loss_cfg['name']] = make_loss(loss_cfg) + + return loss_dict + +@register('frequency') +class Frequency(nn.Module): + def __init__(self, weight): + super(Frequency, self).__init__() + self.weight = weight + + def forward(self, imgt, imgt_pred, **kwargs): + fft_pred = torch.fft.fft2(imgt_pred) + amp_pred = torch.abs(fft_pred) + pha_pred = torch.angle(fft_pred) + + fft_gt = torch.fft.fft2(imgt) + amp_gt = torch.abs(fft_gt) + pha_gt = torch.angle(fft_gt) + + amp_loss = F.l1_loss(input=amp_pred, target=amp_gt, reduction='mean') + pha_loss = F.l1_loss(input=pha_pred, target=pha_gt, reduction='mean') + + return (amp_loss + pha_loss) * self.weight + +@register('bi_frequency') +class BidirectionalFrequency(nn.Module): + def __init__(self, weight): + super(BidirectionalFrequency, self).__init__() + self.weight = weight + + def get_amp_pha(self, img): + fft = torch.fft.fft2(img) + amplitude = torch.abs(fft) + phase = torch.angle(fft) + return amplitude, phase + + def forward(self, img0, img1, imgt, imgt_pred, **kwargs): + amp0, pha0 = self.get_amp_pha(img0) + amp1, pha1 = self.get_amp_pha(img1) + ampt, phat = self.get_amp_pha(imgt) + ampt_pred, phat_pred = self.get_amp_pha(imgt_pred) + + amp_loss0 = F.l1_loss(torch.abs(amp0-ampt), torch.abs(amp0-ampt_pred)) + amp_loss1 = F.l1_loss(torch.abs(amp1-ampt), torch.abs(amp1-ampt_pred)) + pha_loss0 = F.l1_loss(torch.abs(pha0-phat), torch.abs(pha0-phat_pred)) + pha_loss1 = F.l1_loss(torch.abs(pha1-phat), torch.abs(pha1-phat_pred)) + + return (amp_loss0 + amp_loss1 + pha_loss0 + pha_loss1) * self.weight + +@register('l1') +class L1(nn.Module): + def __init__(self): + super(L1, self).__init__() + + # end + + def forward(self, img0, img1): + return F.l1_loss(input=img0, target=img1, reduction='mean') + # end + + +# end + +@register('charbonnier') +class Charbonnier(nn.Module): + def __init__(self, weight): + super(Charbonnier, self).__init__() + self.weight = weight + + # end + + def forward(self, imgt, imgt_pred, **kwargs): + return (((imgt - imgt_pred) ** 2 + 1e-6) ** 0.5).mean() * self.weight + # end + + +# end + +@register('multiple_charbonnier') +class MultipleCharbonnier(nn.Module): + def __init__(self, weight, gamma, **kwargs): + super().__init__() + self.weight = weight + self.gamma = gamma + self.charbonnier = Charbonnier(1) + + def forward(self, imgt_preds, imgt, **kwargs): + loss_charbonnier = torch.Tensor([0]).cuda() + for i in range(len(imgt_preds)): + i_weight = self.gamma ** (len(imgt_preds) - i - 1) + loss_charbonnier += self.charbonnier(imgt_preds[i], imgt) * i_weight + return loss_charbonnier * self.weight + + +@register('ternary') +class Ternary(nn.Module): + def __init__(self, weight): + super(Ternary, self).__init__() + patch_size = 7 + out_channels = patch_size * patch_size + self.w = np.eye(out_channels).reshape( + (patch_size, patch_size, 1, out_channels)) + self.w = np.transpose(self.w, (3, 2, 0, 1)) + self.w = torch.tensor(self.w).float().to(device) + self.weight = weight + + # end + + def transform(self, img): + patches = F.conv2d(img, self.w, padding=3, bias=None) + transf = patches - img + transf_norm = transf / torch.sqrt(0.81 + transf ** 2) + return transf_norm + + # end + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + # end + + def hamming(self, t1, t2): + dist = (t1 - t2) ** 2 + dist_norm = torch.mean(dist / (0.1 + dist), 1, True) + return dist_norm + + # end + + def valid_mask(self, t, padding): + n, _, h, w = t.size() + inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t) + mask = F.pad(inner, [padding] * 4) + return mask + + # end + + def forward(self, imgt, imgt_pred, **kwargs): + imgt = self.transform(self.rgb2gray(imgt)) + imgt_pred = self.transform(self.rgb2gray(imgt_pred)) + return (self.hamming(imgt, imgt_pred) * self.valid_mask(imgt, 1)).mean() * self.weight + # end + + +# end + +@register('multiple_ternary') +class MultipleTernary(nn.Module): + def __init__(self, weight, gamma, **kwargs): + super().__init__() + self.weight = weight + self.gamma = gamma + self.ternary = Ternary(1) + + def forward(self, imgt_preds, imgt, **kwargs): + loss_ter = torch.Tensor([0]).cuda() + for i in range(len(imgt_preds)): + i_weight = self.gamma ** (len(imgt_preds) - i - 1) + loss_ter += self.ternary(imgt_preds[i], imgt) * i_weight + return loss_ter * self.weight + + +@register('sobel') +class SOBEL(nn.Module): + def __init__(self): + super(SOBEL, self).__init__() + self.kernelX = torch.tensor([ + [1, 0, -1], + [2, 0, -2], + [1, 0, -1], + ]).float() + self.kernelY = self.kernelX.clone().T + self.kernelX = self.kernelX.unsqueeze(0).unsqueeze(0).to(device) + self.kernelY = self.kernelY.unsqueeze(0).unsqueeze(0).to(device) + + # end + + def forward(self, pred, gt): + N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3] + img_stack = torch.cat( + [pred.reshape(N * C, 1, H, W), gt.reshape(N * C, 1, H, W)], 0) + sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1) + sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1) + pred_X, gt_X = sobel_stack_x[:N * C], sobel_stack_x[N * C:] + pred_Y, gt_Y = sobel_stack_y[:N * C], sobel_stack_y[N * C:] + + L1X, L1Y = torch.abs(pred_X - gt_X), torch.abs(pred_Y - gt_Y) + loss = (L1X + L1Y) + return loss + # end + + +# end + + +class MeanShift(nn.Conv2d): + def __init__(self, data_mean, data_std, data_range=1, norm=True): + c = len(data_mean) + super(MeanShift, self).__init__(c, c, kernel_size=1) + std = torch.Tensor(data_std) + self.weight.data = torch.eye(c).view(c, c, 1, 1) + if norm: + self.weight.data.div_(std.view(c, 1, 1, 1)) + self.bias.data = -1 * data_range * torch.Tensor(data_mean) + self.bias.data.div_(std) + else: + self.weight.data.mul_(std.view(c, 1, 1, 1)) + self.bias.data = data_range * torch.Tensor(data_mean) + # end + self.requires_grad = False + # end + + +# end + +@register('vgg') +class VGGPerceptualLoss(nn.Module): + def __init__(self, weight=1): + super(VGGPerceptualLoss, self).__init__() + blocks = [] + pretrained = True + self.weight = weight + self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features + self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda() + for param in self.parameters(): + param.requires_grad = False + # end + + # end + + def forward(self, imgt, imgt_pred, **kwargs): + imgt = self.normalize(imgt) + imgt_pred = self.normalize(imgt_pred) + indices = [2, 7, 12, 21, 30] + weights = [1.0 / 2.6, 1.0 / 4.8, 1.0 / 3.7, 1.0 / 5.6, 10 / 1.5] + k = 0 + loss = 0 + for i in range(indices[-1]): + imgt = self.vgg_pretrained_features[i](imgt) + imgt_pred = self.vgg_pretrained_features[i](imgt_pred) + if (i + 1) in indices: + loss += weights[k] * (imgt - imgt_pred.detach()).abs().mean() * 0.1 + k += 1 + # end + # end + return loss * self.weight + # end +# end + + +@register('ada_charbonnier') +class AdaCharbonnierLoss(nn.Module): + def __init__(self, weight) -> None: + super().__init__() + self.weight = weight + + def forward(self, imgt_pred, imgt, weight, **kwargs): + alpha = weight / 2 + epsilon = 10 ** (-(10 * weight - 1) / 3) + + diff = imgt_pred - imgt + loss = ((diff ** 2 + epsilon ** 2) ** alpha).mean() + return loss + + +@register('multiple_flow') +class MultipleFlowLoss(nn.Module): + def __init__(self, weight, beta=0.3) -> None: + super().__init__() + self.weight = weight + self.beta = beta + self.ada_cb_loss = AdaCharbonnierLoss(1.0) + + def forward(self, flow0_pred, flow1_pred, flowt0, flowt1, **kwargs): + robust_weight0 = self.get_mutli_flow_robust_weight(flow0_pred[0], flowt0) + robust_weight1 = self.get_mutli_flow_robust_weight(flow1_pred[0], flowt1) + loss = 0 + h, w = flowt0.shape[-2:] + for lvl in range(0, len(flow0_pred)): + h_lvl, w_lvl = flow0_pred[lvl].shape[-2:] + scale_factor = h / h_lvl + loss = loss + self.ada_cb_loss(**{ + 'imgt_pred': self.resize(flow0_pred[lvl], scale_factor), + 'imgt': flowt0, + 'weight': robust_weight0 + }) + loss = loss + self.ada_cb_loss(**{ + 'imgt_pred': self.resize(flow1_pred[lvl], scale_factor), + 'imgt': flowt1, + 'weight': robust_weight1 + }) + return loss * self.weight + + def resize(self, x, scale_factor): + return scale_factor * F.interpolate(x, scale_factor=scale_factor, mode="bilinear", align_corners=False) + + def get_mutli_flow_robust_weight(self, flow_pred, flow_gt): + dims = flow_pred.shape + if len(dims) == 5: + b, num_flows, c, h, w = dims + else: + b, c, h, w = dims + num_flows = 1 + flow_pred = flow_pred.view(b, num_flows, c, h, w) + flow_gt = flow_gt.repeat(1, num_flows, 1, 1).view(b, num_flows, c, h, w) + epe = ((flow_pred.detach() - flow_gt) ** 2).sum(dim=2, keepdim=True).max(1)[0] ** 0.5 + # robust_weight = torch.exp(-self.beta * epe) + robust_weight = torch.ones_like(epe) + return robust_weight + + +@register('lap') +class LapLoss(torch.nn.Module): + @staticmethod + def gauss_kernel(size=5, channels=3): + kernel = torch.tensor([[1., 4., 6., 4., 1], + [4., 16., 24., 16., 4.], + [6., 24., 36., 24., 6.], + [4., 16., 24., 16., 4.], + [1., 4., 6., 4., 1.]]) + kernel /= 256. + kernel = kernel.repeat(channels, 1, 1, 1) + kernel = kernel.to(device) + return kernel + + + @staticmethod + def laplacian_pyramid(img, kernel, max_levels=3): + def downsample(x): + return x[:, :, ::2, ::2] + + def upsample(x): + cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3) + cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3]) + cc = cc.permute(0,1,3,2) + cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2).to(device)], dim=3) + cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2) + x_up = cc.permute(0,1,3,2) + return conv_gauss(x_up, 4*LapLoss.gauss_kernel(channels=x.shape[1])) + + def conv_gauss(img, kernel): + img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect') + out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1]) + return out + + current = img + pyr = [] + for level in range(max_levels): + filtered = conv_gauss(current, kernel) + down = downsample(filtered) + up = upsample(down) + diff = current-up + pyr.append(diff) + current = down + return pyr + + def __init__(self, max_levels=5, channels=3): + super(LapLoss, self).__init__() + self.max_levels = max_levels + self.gauss_kernel = LapLoss.gauss_kernel(channels=channels) + + def forward(self, imgt_pred, imgt): + pyr_pred = LapLoss.laplacian_pyramid( + img=imgt_pred, kernel=self.gauss_kernel, max_levels=self.max_levels) + pyr_target = LapLoss.laplacian_pyramid( + img=imgt, kernel=self.gauss_kernel, max_levels=self.max_levels) + return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_pred, pyr_target)) + + +@register('vos') +class VOSLoss(nn.Module): + def __init__(self, weight): + super(VOSLoss, self).__init__() + self.weight = weight + + def forward(self, segt, segt_f_binary, segt_b_binary, **kwargs): + # segt = torch.cat([segt < 0.5, segt > 0.5], dim=1).float() + loss = F.binary_cross_entropy(segt_f_binary, segt) + F.binary_cross_entropy(segt_b_binary, segt) + F.binary_cross_entropy(segt_b_binary, segt_f_binary) + return loss * self.weight + + +@register('texture_consistency') +class TCLoss(nn.Module): + def __init__(self, weight): + super(TCLoss, self).__init__() + self.weight = weight + + def rgb2gray(self, rgb): + r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :] + gray = 0.2989 * r + 0.5870 * g + 0.1140 * b + return gray + + def forward(self, imgt_pred, imgt, **kwargs): + b, c, h, w = imgt_pred.shape + imgt_g = self.rgb2gray(imgt) + imgt_pred_g = self.rgb2gray(imgt_pred) + imgt_patched = F.unfold(imgt_g, [3, 3], padding=1).view(b, 9, h, w) + census_imgt = ((imgt_patched - imgt_g) < 0).to(torch.float32) + imgt_pred_patched = F.unfold(imgt_pred_g, [3, 3], padding=1).view(b, 9, h, w) + census_imgt_pred = ((imgt_pred_patched - imgt_pred_g) < 0).to(torch.float32).view(b, 9, 1, h, w) + census_imgt_unfold = F.unfold(census_imgt, [5, 5], padding=2).view(b, 9, 25, h, w) + diff = (census_imgt_unfold - census_imgt_pred).abs().sum(dim=1) + valid_mask = torch.argmax(diff, dim=1, keepdim=True).view(b, 1, 1, h, w) + imgt_patched = F.unfold(imgt, [3, 3], padding=1).view(b, c * 9, h, w) + imgt_masked = torch.take_along_dim( + F.unfold(imgt_patched, kernel_size=[5, 5], padding=2).view(b, c * 9, 25, h, w), valid_mask, 2) + imgt_pred_patched = F.unfold(imgt, [3, 3], padding=1).view(b, c * 9, 1, h, w) + + loss = F.l1_loss(imgt_masked, imgt_pred_patched) + return loss * self.weight + + +@register('flow_consistency') +class FCLoss(nn.Module): + def __init__(self, weight): + super(FCLoss, self).__init__() + self.weight = weight + # self.of_model = UniMatch(2, 128, 4, 1, 4, 6, True) + + cfg = get_cfg().latentcostformer + + self.of_model = FlowFormer(cfg) + checkpoint = torch.load('./modules/components/m2m_flow_former/flowformer++.pth') + checkpoint_mod = {k.replace('module.', ''): checkpoint[k] for k in checkpoint.keys()} + self.of_model.load_state_dict(checkpoint_mod, strict=False) + self.of_model.to(device) + self.of_model.eval() + for p in self.of_model.parameters(True): + p.requires_grad = False + + def forward(self, imgt_pred, img0, img1, flowt0, flowt1, **kwargs): + self.of_model.eval() + # flowt0_pred = self.of_model(imgt_pred, img0, 'swin', [2, 8], [-1, 4], [-1, 1], 6)[-1] + # flowt1_pred = self.of_model(imgt_pred, img1, 'swin', [2, 8], [-1, 4], [-1, 1], 6)[-1] + flowt0_pred = self.of_model(imgt_pred, img0)[-1] + flowt1_pred = self.of_model(imgt_pred, img1)[-1] + return ((flowt0_pred - flowt0).abs().mean() + (flowt1_pred - flowt1).abs().mean()) * self.weight, flowt0_pred + + +def census_transform(img, kernel_size=3): + """ + Calculates the census transform of an image of shape [N x C x H x W] with batch size N, number of channels C, + height H and width W. If C > 1, the census transform is applied independently on each channel. + :param img: input image as torch.Tensor of shape [H x C x H x W] + :return: census transform of img + """ + assert len(img.size()) == 4 + if kernel_size != 3: + raise NotImplementedError + + n, c, h, w = img.size() + + census = torch.zeros((n, c, h - 2, w - 2), dtype=torch.uint8, device=img.device) + + cp = img[:, :, 1:h - 1, 1:w - 1] + offsets = [(u, v) for v in range(3) for u in range(3) if not u == 1 == v] + + # do the pixel comparisons + for u, v in offsets: + census = (census << 1) | (img[:, :, v:v + h - 2, u:u + w - 2] >= cp).byte() + + return torch.nn.functional.pad(census.float() / 255, (1, 1, 1, 1), mode='reflect') + + +class CensusTransform(torch.nn.Module): + """ + Calculates the census transform of an image of shape [N x C x H x W] with batch size N, number of channels C, + height H and width W. If C > 1, the census transform is applied independently on each channel. + :param img: input image as torch.Tensor of shape [H x C x H x W] + :return: census transform of img + """ + def __init__(self, kernel_size=3): + super().__init__() + self._kernel_size = kernel_size + + def forward(self, x): + x = census_transform(x, self._kernel_size) + return x + + +@register('texture_consistency_original') +class PatchMatching(nn.Module): + def __init__(self, weight, kSize=3, nsize=7, scale=4, alpha=1): + super(PatchMatching, self).__init__() + self.scale = scale + self.kSize = kSize + self.nsize = nsize + self.alpha = alpha + self.weight = weight + + self.ct = CensusTransform() + + def _unfold(self, data, with_unfold=False): + + if self.scale != 1: + data = torch.nn.functional.interpolate(data, scale_factor=1.0 / self.scale, mode='bicubic', + align_corners=False) + pad = self.kSize // 2 + + data_pad = torch.nn.functional.pad(data, (pad, pad, pad, pad), mode='reflect') + d1 = torch.nn.functional.unfold(data_pad, kernel_size=self.kSize) # .permute(0,2,1) + if not with_unfold: + return d1.permute(0, 2, 1).unsqueeze(-2) + else: + b, c, h, w = data.size() + # print('d1',d1.shape,data.shape) + d1 = d1.view(b, -1, h, w) + c1 = d1.size()[1] + pad = self.nsize // 2 + d1_pad = torch.nn.functional.pad(d1, (pad, pad, pad, pad), mode='reflect') + d1_pad_unflod = torch.nn.functional.unfold(d1_pad, kernel_size=self.nsize) # .permute(0,2,1) + d1_pad_unflod = d1_pad_unflod.view(b, c1, -1, h * w).permute(0, 3, 2, 1) + # print(d1_pad_unflod.shape) + return d1_pad_unflod + + def _match(self, pred, ref_d0, ref_d1): + # b + b, n, c = pred.size() + print('--', pred.shape) + pred_2 = (pred ** 2).sum(-1).view(b, n, -1) + ref_d0_2 = (ref_d0 ** 2).sum(-1).view(b, -1, n) + ref_d1_2 = (ref_d1 ** 2).sum(-1).view(b, -1, n) + # gt_2 = (gt**2).sum(-1).view(b,-1,n) + + error_d0 = pred_2 + ref_d0_2 - 2.0 * torch.matmul(pred, ref_d0.permute(0, 2, 1)) + error_d1 = pred_2 + ref_d1_2 - 2.0 * torch.matmul(pred, ref_d1.permute(0, 2, 1)) + + score_d0 = torch.exp(self.alpha * error_d0) + score_d1 = torch.exp(self.alpha * error_d1) + # print('score_d0',score_d0.shape,score_d1.shape) + + weight, ind = torch.min(score_d0, dim=2) + index_d0 = ind.unsqueeze(-1).expand([-1, -1, c]) + print(ref_d0.shape, index_d0.shape) + matched_d0 = torch.gather(ref_d0, dim=1, index=index_d0) + + weight, ind = torch.min(score_d1, dim=2) + index_d1 = ind.unsqueeze(-1).expand([-1, -1, c]) + matched_d1 = torch.gather(ref_d1, dim=1, index=index_d1) + # print('matched_d1',matched_d1.shape) + + # error_gt_d0 = gt_2 + ref_d0_2 - 2.0 * torch.matmul(ref_d0,gt.permute(0,2,1)) + # score_gt_d0 = torch.exp(self.alpha * error_gt_d0) + # weight,ind = torch.min(score_gt_d0,dim=2) + # index_d0 = ind.unsqueeze(-1).expand([-1,-1,c]) + # matched_d0 = torch.gather(ref_d0,dim=1,index=index_d0) + + loss = ((pred - matched_d0) ** 2).mean() + ((pred - matched_d1) ** 2).mean() + return loss + + # error_d1 = pred_2 + ref_d0_2 - 2.0 * torch.matmul(pred,ref_d0.permute(0,2,1)) + + def forward(self, imgt_pred, imgt, **kwarps): + + pred_ct = self.ct(imgt_pred) + gt_ct = self.ct(imgt) + + pred_ct = self._unfold(pred_ct) + gt_ct = self._unfold(gt_ct, with_unfold=True) + + + pred_ct = pred_ct.repeat(1, 1, self.nsize ** 2, 1) + + dis_I_ct = ((pred_ct - gt_ct) ** 2).sum(-1) + weight, ind = torch.min(dis_I_ct, dim=2) + index_d = ind.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, self.nsize ** 2 * 2, 3 * self.kSize ** 2) + + imgt_pred = self._unfold(imgt_pred) + imgt = self._unfold(imgt, with_unfold=True) + + imgt_pred = imgt_pred.repeat(1, 1, self.nsize ** 2, 1) + + matched_d = torch.gather(imgt, dim=2, index=index_d) + + # print(pred.shape,matched_d.shape) + + loss = ((imgt_pred[:, :, 0] - matched_d[:, :, 0]) ** 2) * 0.5 + + return loss.mean() * self.weight \ No newline at end of file diff --git a/modules/lr_scheduler.py b/modules/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..07355f48e599b47016479e2283ede433fae42bd0 --- /dev/null +++ b/modules/lr_scheduler.py @@ -0,0 +1,12 @@ +from torch.optim.lr_scheduler import * + + +def make_lr_scheduler(optimizer, lr_scheduler_spec): + lr_scheduler = { + 'step_lr': StepLR, + 'one_cycle_lr': OneCycleLR, + 'cosine_lr': CosineAnnealingLR, + 'constant_lr': ConstantLR, + + }[lr_scheduler_spec['name']](optimizer, **lr_scheduler_spec['args']) + return lr_scheduler diff --git a/modules/models/__init__.py b/modules/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cd94067fc801115aa24e2a4d47ebffff924061 --- /dev/null +++ b/modules/models/__init__.py @@ -0,0 +1,15 @@ +from .models import * +from .base_model import * +from .m2m_pwc import * +from .amt import * +from .amt_flowformer import * +from .amt_bilateral import * +from .amt_splat import * +from .upr_basic import * +from .upr_net import * +from .upr_net_mod import * +from .upr_net_mod2 import * +from .upr_net_freq import * +from .upr_net_freq2 import * +from .m2m_flowformer import * +from .upr_net_multi_flow import * \ No newline at end of file diff --git a/modules/models/__pycache__/__init__.cpython-310.pyc b/modules/models/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a43d625402a784205b75a47c343aa9661d189dca Binary files /dev/null and b/modules/models/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/models/__pycache__/__init__.cpython-38.pyc b/modules/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..272518a2f57bceb253838df9bcdec10c799e6092 Binary files /dev/null and b/modules/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/models/__pycache__/__init__.cpython-39.pyc b/modules/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66267714712e4d23adb893cfb1a92d5fa5918761 Binary files /dev/null and b/modules/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/models/__pycache__/amt.cpython-310.pyc b/modules/models/__pycache__/amt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e3918adf8ae70487066f7084386b1e1f993f3e5 Binary files /dev/null and b/modules/models/__pycache__/amt.cpython-310.pyc differ diff --git a/modules/models/__pycache__/amt.cpython-38.pyc b/modules/models/__pycache__/amt.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..673b2e729a3f4cfd199fed5ec8f1e9025c2ff338 Binary files /dev/null and b/modules/models/__pycache__/amt.cpython-38.pyc differ diff --git a/modules/models/__pycache__/amt.cpython-39.pyc b/modules/models/__pycache__/amt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b10524da95e2dbce2f1ed0d23bbfaa333309045a Binary files /dev/null and b/modules/models/__pycache__/amt.cpython-39.pyc differ diff --git a/modules/models/__pycache__/amt_bilateral.cpython-310.pyc b/modules/models/__pycache__/amt_bilateral.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31987143ad06db44216b86c1b8fc999326f2150a Binary files /dev/null and b/modules/models/__pycache__/amt_bilateral.cpython-310.pyc differ diff --git a/modules/models/__pycache__/amt_bilateral.cpython-38.pyc b/modules/models/__pycache__/amt_bilateral.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a24bc9b8a72a41429d977afba60f5c6d3bd8c08 Binary files /dev/null and b/modules/models/__pycache__/amt_bilateral.cpython-38.pyc differ diff --git a/modules/models/__pycache__/amt_bilateral.cpython-39.pyc b/modules/models/__pycache__/amt_bilateral.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd13d37ed8ce98022088de188c5ecc549fdcf0b1 Binary files /dev/null and b/modules/models/__pycache__/amt_bilateral.cpython-39.pyc differ diff --git a/modules/models/__pycache__/amt_flowformer.cpython-310.pyc b/modules/models/__pycache__/amt_flowformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c1caf47803058ce2fd3979f917e662018c5e7c5 Binary files /dev/null and b/modules/models/__pycache__/amt_flowformer.cpython-310.pyc differ diff --git a/modules/models/__pycache__/amt_flowformer.cpython-38.pyc b/modules/models/__pycache__/amt_flowformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1ff881c2ac1d69889d630fd2256ba1083353269 Binary files /dev/null and b/modules/models/__pycache__/amt_flowformer.cpython-38.pyc differ diff --git a/modules/models/__pycache__/amt_flowformer.cpython-39.pyc b/modules/models/__pycache__/amt_flowformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e6718db3c56b267e3b405482a4bce5ccbdf031e Binary files /dev/null and b/modules/models/__pycache__/amt_flowformer.cpython-39.pyc differ diff --git a/modules/models/__pycache__/amt_splat.cpython-310.pyc b/modules/models/__pycache__/amt_splat.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e1846d7ca33e537af5bfd675c7195f8bdf604bc Binary files /dev/null and b/modules/models/__pycache__/amt_splat.cpython-310.pyc differ diff --git a/modules/models/__pycache__/amt_splat.cpython-38.pyc b/modules/models/__pycache__/amt_splat.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40abfd531d525c991e2978d195764de245951bb6 Binary files /dev/null and b/modules/models/__pycache__/amt_splat.cpython-38.pyc differ diff --git a/modules/models/__pycache__/amt_splat.cpython-39.pyc b/modules/models/__pycache__/amt_splat.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a61eade2f46ee88f323d436d36ff70e9eb3dc03 Binary files /dev/null and b/modules/models/__pycache__/amt_splat.cpython-39.pyc differ diff --git a/modules/models/__pycache__/base_model.cpython-310.pyc b/modules/models/__pycache__/base_model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3658071b52114001631ca713bd7b31064137ae12 Binary files /dev/null and b/modules/models/__pycache__/base_model.cpython-310.pyc differ diff --git a/modules/models/__pycache__/base_model.cpython-38.pyc b/modules/models/__pycache__/base_model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e09679c7969c85c2749824251fb6a93c449e662d Binary files /dev/null and b/modules/models/__pycache__/base_model.cpython-38.pyc differ diff --git a/modules/models/__pycache__/base_model.cpython-39.pyc b/modules/models/__pycache__/base_model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15445545fa814cb38b1f1da73dbe72e4648831ce Binary files /dev/null and b/modules/models/__pycache__/base_model.cpython-39.pyc differ diff --git a/modules/models/__pycache__/inference_video.cpython-310.pyc b/modules/models/__pycache__/inference_video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f4b5f66e403873476e7f08b104b265e928e6f97 Binary files /dev/null and b/modules/models/__pycache__/inference_video.cpython-310.pyc differ diff --git a/modules/models/__pycache__/inference_video.cpython-38.pyc b/modules/models/__pycache__/inference_video.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..582420ca90c3c6fbd6b80d26188232907b050751 Binary files /dev/null and b/modules/models/__pycache__/inference_video.cpython-38.pyc differ diff --git a/modules/models/__pycache__/inference_video.cpython-39.pyc b/modules/models/__pycache__/inference_video.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e20d39aa0942fc54b2bb73928b910c1daab34eb Binary files /dev/null and b/modules/models/__pycache__/inference_video.cpython-39.pyc differ diff --git a/modules/models/__pycache__/m2m_flowformer.cpython-310.pyc b/modules/models/__pycache__/m2m_flowformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18721fe06319235a159c286a127df6df64bd6062 Binary files /dev/null and b/modules/models/__pycache__/m2m_flowformer.cpython-310.pyc differ diff --git a/modules/models/__pycache__/m2m_flowformer.cpython-38.pyc b/modules/models/__pycache__/m2m_flowformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b14d54f27e4af3e271867043dfb85cc2aed66e1 Binary files /dev/null and b/modules/models/__pycache__/m2m_flowformer.cpython-38.pyc differ diff --git a/modules/models/__pycache__/m2m_flowformer.cpython-39.pyc b/modules/models/__pycache__/m2m_flowformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4540607671a482825cfd930cb6087e3d5694d611 Binary files /dev/null and b/modules/models/__pycache__/m2m_flowformer.cpython-39.pyc differ diff --git a/modules/models/__pycache__/m2m_pwc.cpython-310.pyc b/modules/models/__pycache__/m2m_pwc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a943302356251539a7a5a5280ddf3b1324695f57 Binary files /dev/null and b/modules/models/__pycache__/m2m_pwc.cpython-310.pyc differ diff --git a/modules/models/__pycache__/m2m_pwc.cpython-38.pyc b/modules/models/__pycache__/m2m_pwc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9d636dfb8a9069bf68fc7ca5ed5a3225ab0c732 Binary files /dev/null and b/modules/models/__pycache__/m2m_pwc.cpython-38.pyc differ diff --git a/modules/models/__pycache__/m2m_pwc.cpython-39.pyc b/modules/models/__pycache__/m2m_pwc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b109315e17e37f4fb1c5f6e151bc95b479eb9cc3 Binary files /dev/null and b/modules/models/__pycache__/m2m_pwc.cpython-39.pyc differ diff --git a/modules/models/__pycache__/models.cpython-310.pyc b/modules/models/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6ff0e39420327a2531f87a2db03b1b13e4afc94 Binary files /dev/null and b/modules/models/__pycache__/models.cpython-310.pyc differ diff --git a/modules/models/__pycache__/models.cpython-38.pyc b/modules/models/__pycache__/models.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421b88054df748d1c1e357707c174504a04dab76 Binary files /dev/null and b/modules/models/__pycache__/models.cpython-38.pyc differ diff --git a/modules/models/__pycache__/models.cpython-39.pyc b/modules/models/__pycache__/models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63733b5c0323da197d83bd3f2eb6fd255405abe2 Binary files /dev/null and b/modules/models/__pycache__/models.cpython-39.pyc differ diff --git a/modules/models/__pycache__/upr_basic.cpython-310.pyc b/modules/models/__pycache__/upr_basic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99df6bf6aeb34dd4aa2306a1070d50da5212f460 Binary files /dev/null and b/modules/models/__pycache__/upr_basic.cpython-310.pyc differ diff --git a/modules/models/__pycache__/upr_basic.cpython-38.pyc b/modules/models/__pycache__/upr_basic.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0fccb0c5b6153c9e337fe2716d930bd62d493f6 Binary files /dev/null and b/modules/models/__pycache__/upr_basic.cpython-38.pyc differ diff --git a/modules/models/__pycache__/upr_basic.cpython-39.pyc b/modules/models/__pycache__/upr_basic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a2f1f957f4f71af27597516f12aae1e3ee1ee9c Binary files /dev/null and b/modules/models/__pycache__/upr_basic.cpython-39.pyc differ diff --git a/modules/models/__pycache__/upr_net.cpython-310.pyc b/modules/models/__pycache__/upr_net.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..371eaca9bb4f0700792a1d5308b82d9b63600871 Binary files /dev/null and b/modules/models/__pycache__/upr_net.cpython-310.pyc differ diff --git a/modules/models/__pycache__/upr_net.cpython-38.pyc b/modules/models/__pycache__/upr_net.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d435c288c29dd66cf8a6e75cb64be91e5f96dd44 Binary files /dev/null and b/modules/models/__pycache__/upr_net.cpython-38.pyc differ diff --git a/modules/models/__pycache__/upr_net.cpython-39.pyc b/modules/models/__pycache__/upr_net.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b52319884abad2a2ce972d8974f6638315a285b Binary files /dev/null and b/modules/models/__pycache__/upr_net.cpython-39.pyc differ diff --git a/modules/models/__pycache__/upr_net_freq.cpython-310.pyc b/modules/models/__pycache__/upr_net_freq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a39a5166d37c1aa1e19d89f3746768193c6a5438 Binary files /dev/null and b/modules/models/__pycache__/upr_net_freq.cpython-310.pyc differ diff --git a/modules/models/__pycache__/upr_net_freq.cpython-38.pyc b/modules/models/__pycache__/upr_net_freq.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7bbfd10c277847747887b00b0b081a79edd6b94 Binary files /dev/null and b/modules/models/__pycache__/upr_net_freq.cpython-38.pyc differ diff --git a/modules/models/__pycache__/upr_net_freq.cpython-39.pyc b/modules/models/__pycache__/upr_net_freq.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ebbe2fbe217fb2a068cf3009cd7111b5447e823 Binary files /dev/null and b/modules/models/__pycache__/upr_net_freq.cpython-39.pyc differ diff --git a/modules/models/__pycache__/upr_net_freq2.cpython-310.pyc b/modules/models/__pycache__/upr_net_freq2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6caedf52d5e901260bc9afea055d20626d0139 Binary files /dev/null and b/modules/models/__pycache__/upr_net_freq2.cpython-310.pyc differ diff --git a/modules/models/__pycache__/upr_net_freq2.cpython-38.pyc b/modules/models/__pycache__/upr_net_freq2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd933cab586a5106149a26f7712f54abba9554cd Binary files /dev/null and b/modules/models/__pycache__/upr_net_freq2.cpython-38.pyc differ diff --git a/modules/models/__pycache__/upr_net_freq2.cpython-39.pyc b/modules/models/__pycache__/upr_net_freq2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa56e2b589c75ceb1aafd4235248922b0fe6ba1 Binary files /dev/null and b/modules/models/__pycache__/upr_net_freq2.cpython-39.pyc differ diff --git a/modules/models/__pycache__/upr_net_gan.cpython-310.pyc b/modules/models/__pycache__/upr_net_gan.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..106504364624bb20e6eb6ee09b4cb80b7ca8db1b Binary files /dev/null and b/modules/models/__pycache__/upr_net_gan.cpython-310.pyc differ diff --git a/modules/models/__pycache__/upr_net_mod.cpython-310.pyc b/modules/models/__pycache__/upr_net_mod.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef10dfe33af7ddd1a114be7a1732f8bac1094674 Binary files /dev/null and b/modules/models/__pycache__/upr_net_mod.cpython-310.pyc differ diff --git a/modules/models/__pycache__/upr_net_mod.cpython-38.pyc b/modules/models/__pycache__/upr_net_mod.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d333c8616db8f574e64ef0844bdd5c34ff34d477 Binary files /dev/null and b/modules/models/__pycache__/upr_net_mod.cpython-38.pyc differ diff --git a/modules/models/__pycache__/upr_net_mod.cpython-39.pyc b/modules/models/__pycache__/upr_net_mod.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3b51577a9f8a35d23d8c2b900fb8cbfaba98141 Binary files /dev/null and b/modules/models/__pycache__/upr_net_mod.cpython-39.pyc differ diff --git a/modules/models/__pycache__/upr_net_mod2.cpython-310.pyc b/modules/models/__pycache__/upr_net_mod2.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd81a133dc21bc0b0893fce9f522c40744bb96e1 Binary files /dev/null and b/modules/models/__pycache__/upr_net_mod2.cpython-310.pyc differ diff --git a/modules/models/__pycache__/upr_net_mod2.cpython-38.pyc b/modules/models/__pycache__/upr_net_mod2.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d3d11764c0ec3eb5d0fc228f5649f0a03f11b64 Binary files /dev/null and b/modules/models/__pycache__/upr_net_mod2.cpython-38.pyc differ diff --git a/modules/models/__pycache__/upr_net_mod2.cpython-39.pyc b/modules/models/__pycache__/upr_net_mod2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4078ec8f3ba10a6e562e52c911ac461347a449a6 Binary files /dev/null and b/modules/models/__pycache__/upr_net_mod2.cpython-39.pyc differ diff --git a/modules/models/__pycache__/upr_net_multi_flow.cpython-310.pyc b/modules/models/__pycache__/upr_net_multi_flow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..112736009dd5b920aa2b03f2259c7c6fb2ae5a60 Binary files /dev/null and b/modules/models/__pycache__/upr_net_multi_flow.cpython-310.pyc differ diff --git a/modules/models/__pycache__/upr_net_multi_flow.cpython-38.pyc b/modules/models/__pycache__/upr_net_multi_flow.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6080ea89f059e87351806ca3dbfe878799716ea Binary files /dev/null and b/modules/models/__pycache__/upr_net_multi_flow.cpython-38.pyc differ diff --git a/modules/models/__pycache__/upr_net_multi_flow.cpython-39.pyc b/modules/models/__pycache__/upr_net_multi_flow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad8c28c16cb141f2b7e1e01481097f0b1e65d150 Binary files /dev/null and b/modules/models/__pycache__/upr_net_multi_flow.cpython-39.pyc differ diff --git a/modules/models/amt.py b/modules/models/amt.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ab670d3402df47b7bf4e9ba0fb1c511ad236b0 --- /dev/null +++ b/modules/models/amt.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('amt') +class AMT(BaseModel): + def __init__(self, cfg): + super(AMT, self).__init__(cfg) diff --git a/modules/models/amt_bilateral.py b/modules/models/amt_bilateral.py new file mode 100644 index 0000000000000000000000000000000000000000..e3c13def0a49e4d1c7f849965e4891920e788373 --- /dev/null +++ b/modules/models/amt_bilateral.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('amt_bilateral') +class AMT(BaseModel): + def __init__(self, cfg): + super(AMT, self).__init__(cfg) diff --git a/modules/models/amt_flowformer.py b/modules/models/amt_flowformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fa753e92232688fd829f0f72fe30b6f539e6f825 --- /dev/null +++ b/modules/models/amt_flowformer.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('amt_flowformer') +class AMT(BaseModel): + def __init__(self, cfg): + super(AMT, self).__init__(cfg) diff --git a/modules/models/amt_splat.py b/modules/models/amt_splat.py new file mode 100644 index 0000000000000000000000000000000000000000..d40e24de0a36a0226733c55df326b1fe7afc04d6 --- /dev/null +++ b/modules/models/amt_splat.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('amt_splat') +class AMT(BaseModel): + def __init__(self, cfg): + super(AMT, self).__init__(cfg) diff --git a/modules/models/base_model.py b/modules/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..a41ec4f7c7cf8e3c8854621b60997e0ba247295a --- /dev/null +++ b/modules/models/base_model.py @@ -0,0 +1,361 @@ +import logging +import os +import shutil +import math +import sys +import numpy as np +from tensorboardX import SummaryWriter +from tqdm import tqdm +from typing import Iterable +from pathlib import Path +from time import time +import datetime +import wandb +import cv2 + +import torch +from torchvision.utils import save_image +from torchvision.transforms import functional as TF + +import utils.misc +from modules.components import make_components +import utils.misc as misc +from utils.plot import plot_samples_per_epoch, plot_val_samples +from utils.metrics import calculate_batch_psnr, calculate_batch_ssim +from utils.flowvis import flow2img +from utils.padder import InputPadder +from modules.loss import make_loss_dict +from modules.lr_scheduler import make_lr_scheduler +from modules.optimizer import make_optimizer +from modules.models import make, register +from modules.models.inference_video import inference_demo +from modules.models.unimatch.unimatch import UniMatch + + +@register('base_model') +class BaseModel: + def __init__(self, cfgs): + self.cfgs = cfgs + self.device = torch.cuda.current_device() + + self.current_iteration = 0 + self.current_epoch = 0 + self.model = make_components(self.cfgs['model']) + self.loss_dict = make_loss_dict(cfgs['loss']) + + self.logger = logging.getLogger(self.cfgs['model']['name']) + self.move_components_to_device(cfgs['mode']) + self.model_without_ddp = self.model + if cfgs['distributed']: + self.model = torch.nn.parallel.DistributedDataParallel(self.model, device_ids=[cfgs['gpu']]) + self.model_without_ddp = self.model.module + # self.model = torch.compile(self.model) + self.optimizer = make_optimizer(self.model_without_ddp.parameters(), self.cfgs['optimizer']) + self.lr_scheduler = make_lr_scheduler(self.optimizer, cfgs['lr_scheduler']) + # if self.cfgs['enable_wandb']: + # wandb.watch(self.model_without_ddp, log="all", log_freq=100) + print(f'Total params: {self.count_parameters()}') + +# self.flow_extractor = UniMatch(feature_channels=128, +# num_scales=2, +# upsample_factor=8//2, +# num_head=1, +# ffn_dim_expansion=4, +# num_transformer_layers=6, +# reg_refine=True, +# task='flow') +# fe_sd = torch.load('./pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth')['model'] +# print(self.flow_extractor.load_state_dict(fe_sd)) +# for n,p in self.flow_extractor.named_parameters(): +# p.requires_grad = False +# self.flow_extractor = self.flow_extractor.to(self.device) + + def load_checkpoint(self, file_path): + """ + Load checkpoint + """ + checkpoint = torch.load(file_path, map_location="cpu") + + self.current_epoch = checkpoint['epoch'] + self.current_iteration = checkpoint['iteration'] + self.model_without_ddp.load_state_dict(checkpoint['model']) + self.optimizer.load_state_dict(checkpoint['optimizer']) + self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + self.logger.info('Chekpoint loaded successfully from {} at epoch: {} and iteration: {}'.format( + file_path, checkpoint['epoch'], checkpoint['iteration'])) + self.move_components_to_device(self.cfgs['mode']) + return self.current_epoch + + def load_pretrained(self, file_path): + """ + Load checkpoint + """ + checkpoint = torch.load(file_path, map_location="cpu") +# for key in list(checkpoint.keys()): +# checkpoint[key.replace('module.', '')] = checkpoint.pop(key) + for key in list(checkpoint.keys()): + checkpoint['module.'+key] = checkpoint.pop(key) + if 'state_dict' in checkpoint.keys(): + self.model.load_state_dict(checkpoint['state_dict']) + else: + self.model.load_state_dict(checkpoint) + self.logger.info('Pretrained model loaded successfully from {} '.format( + file_path)) + self.move_components_to_device(self.cfgs['mode']) + return self.current_epoch + + def save_checkpoint(self, file_name, is_best=0): + """ + Save checkpoint + """ + state = { + 'epoch': self.current_epoch, # because epoch is used for loading then this must be added + 1 + 'iteration': self.current_iteration, + 'model': self.model_without_ddp.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'lr_scheduler': self.lr_scheduler.state_dict() + } + + misc.save_on_master(state, os.path.join(self.cfgs['checkpoint_dir'], file_name)) + + if is_best and misc.is_main_process(): + shutil.copyfile(os.path.join(self.cfgs['checkpoint_dir'], file_name), + os.path.join(self.cfgs['checkpoint_dir'], 'model_best.pth')) + + def adjust_learning_rate(self, epoch): + """ + Adjust learning rate every epoch + """ + self.lr_scheduler.step() + + def train_one_epoch(self, train_loader: Iterable, epoch: int, max_norm: float = 0): + """ + Training step for each mini-batch + """ + self.current_epoch = epoch + self._reset_metric() + + self.model.train() + + header = 'Epoch: [{}]'.format(epoch) + print_freq = 100 + for input_dict in self.metric_logger.log_every(train_loader, print_freq, header): + input_dict = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in input_dict.items()} + result_dict, extra_dict = self.model(**input_dict) + imgt_pred = result_dict['imgt_pred'] + loss = torch.Tensor([0]).to(self.device) + losses = dict() + for k, v in self.loss_dict.items(): + losses[k] = v(**result_dict, **input_dict) + loss += losses[k] + + imgt_pred = torch.clamp(imgt_pred, 0, 1) + self.optimizer.zero_grad() + loss.backward() + if 'gradient_clip' in self.cfgs: + torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfgs['gradient_clip']) + self.optimizer.step() + self.lr_scheduler.step() + + self.metric_logger.update(loss=loss, **losses) + self.metric_logger.update(lr=self.optimizer.param_groups[0]["lr"]) + if misc.is_main_process() and self.current_iteration % print_freq == 0: + nsample = 4 + img0_p, img1_p, gt_p, imgt_pred_p = input_dict['img0'][:nsample].detach(), input_dict['img1'][:nsample].detach(), \ + input_dict['imgt'][:nsample].detach(), imgt_pred[:nsample].detach() + overlapped_img = img0_p * 0.5 + img1_p * 0.5 + + flowfwd = flow2img(result_dict['flowfwd'][:nsample].detach()) + if self.cfgs['train_dataset']['args']['flow'] != 'none': + flowfwd_gt = flow2img(input_dict['flowt0'][:nsample]) +# figure = torch.stack([overlapped_img, imgt_pred_p, flowfwd, gt_p]) + figure = torch.stack([overlapped_img, imgt_pred_p, flowfwd]) +# figure = torch.stack( +# [overlapped_img, imgt_pred_p, flowfwd, flowfwd_gt, gt_p]) + else: + figure = torch.stack( + [overlapped_img, imgt_pred_p, flowfwd, gt_p]) + figure = torch.transpose(figure, 0, 1).reshape(-1, 3, self.cfgs['train_dataset']['args']['patch_size'], + self.cfgs['train_dataset']['args']['patch_size']) + image = plot_samples_per_epoch(figure, os.path.join(self.cfgs['output_dir'], "imgs_train"), + self.current_epoch, self.current_iteration, nsample) + self.summary_writer.add_scalar("Train/loss", loss, self.current_iteration) + for k, v in losses.items(): + self.summary_writer.add_scalar(f'Train/loss_{k}', v, self.current_iteration) + self.summary_writer.add_scalar("Train/LR", self.lr_scheduler.get_last_lr(), self.current_iteration) + # self.summary_writer.add_image("Train/image", image, self.current_iteration) + if self.cfgs['enable_wandb']: + wandb.log({"loss": loss}, step=self.current_iteration) + for k, v in losses.items(): + wandb.log({f'loss_{k}': v}, step=self.current_iteration) + wandb.log({"lr": torch.Tensor(self.lr_scheduler.get_last_lr())}, + step=self.current_iteration) + if self.current_iteration % (print_freq * 10) == 0: + wandb.log({"Image": wandb.Image(image)}, step=self.current_iteration) + + self.current_iteration += 1 + + # gather the stats from all processes + self.metric_logger.synchronize_between_processes() + self.current_epoch += 1 + if utils.misc.is_main_process(): + self.logger.info(f"Averaged training stats: {self.metric_logger}") + + @torch.no_grad() + def validate(self, val_loader): + """ + Validation step for each mini-batch + """ + self.model.eval() + + self.metric_logger = misc.MetricLogger(delimiter=" ") + self.metric_logger.add_meter('psnr', misc.SmoothedValue(window_size=1, fmt='{value:.2f}')) + self.metric_logger.add_meter('ssim', misc.SmoothedValue(window_size=1, fmt='{value:.2f}')) + header = 'Test:' + psnr_dict = {} + + print_freq = 10 + + for input_dict in self.metric_logger.log_every(val_loader, print_freq, header): + input_dict = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in input_dict.items()} + img0 = input_dict['img0'] + imgt = input_dict['imgt'] + img1 = input_dict['img1'] + result_dict, extra_dict = self.model(**input_dict) + + scene_names = input_dict['scene_name'] + + imgt_pred = result_dict['imgt_pred'] + +# folder = os.path.join('../datasets/Vimeo90K/asdf/', scene_names[0]) +# os.makedirs(folder, exist_ok=True) +# cv2.imwrite(os.path.join(folder, 'im2_pred.png'), (imgt_pred[0].clamp(0,1).cpu().detach()*255).permute(1,2,0).numpy().astype(np.uint8)[:,:,::-1]) +# torch.save(result_dict['flowfwd'][0].cpu().detach(), os.path.join(folder, 'flow_fwd.flo')) +# torch.save(result_dict['flowbwd'][0].cpu().detach(), os.path.join(folder, 'flow_bwd.flo')) + + psnr, psnr_list = calculate_batch_psnr(imgt, imgt_pred) + ssim, bs = calculate_batch_ssim(imgt, imgt_pred) + self.metric_logger.update(psnr={'value': psnr, 'n': len(psnr_list)}, + ssim={'value': ssim, 'n': len(psnr_list)}) + if (self.current_epoch!=0) and ((self.current_epoch % self.cfgs['vis_every'] == 0) or (self.cfgs['mode'] != 'train' and self.cfgs['test_dataset']['save_imgs'])): + for i in range(len(scene_names)): + psnr_dict[scene_names[i]] = float(psnr_list[i]) + if self.cfgs['mode'] == "test": + scene_path = os.path.join(self.cfgs['output_dir'], "imgs_test", + f"{self.cfgs['test_dataset']['name']}_{self.cfgs['test_dataset']['args']['split']}", + scene_names[i]) + else: + scene_path = os.path.join(self.cfgs['output_dir'], "imgs_val", + f"{self.cfgs['test_dataset']['name']}_{self.cfgs['test_dataset']['args']['split']}", + scene_names[i]) + Path(scene_path).mkdir(exist_ok=True, parents=True) + save_image(img0[i], os.path.join(scene_path, "img0.png")) + save_image(imgt_pred[i], os.path.join(scene_path, "imgt_pred.png")) + save_image(imgt[i], os.path.join(scene_path, "imgt.png")) + save_image(img1[i], os.path.join(scene_path, "img1.png")) + save_image((img1[i] + img0[i]) / 2, os.path.join(scene_path, "overlayedd.png")) + save_image(flow2img(result_dict['flowfwd'])[i], os.path.join(scene_path, "flow_fwd.png")) + save_image(flow2img(result_dict['flowbwd'])[i], os.path.join(scene_path, "flow_bwd.png")) + # save_image(flow2img(result_dict['flow0_pred'][1])[i], os.path.join(scene_path, "flow_fwd_2.png")) + # save_image(flow2img(result_dict['flow1_pred'][1])[i], os.path.join(scene_path, "flow_bwd_2.png")) + + # gather the stats from all processes + # self.metric_logger.synchronize_between_processes() + self.logger.info(f"Averaged validate stats:{self.metric_logger.print_avg()}") + if (self.current_epoch!=0) and ((self.current_epoch % self.cfgs['vis_every'] == 0) or (self.cfgs['mode'] != 'train' and self.cfgs['test_dataset']['save_imgs'])): + psnr_str = [] + psnr_dict = sorted(psnr_dict.items(), key=lambda item: item[1]) + for key, val in psnr_dict: + psnr_str.append("{}: {}".format(key, val)) + psnr_str = "\n".join(psnr_str) + if self.cfgs['mode'] == "test": + outdir = os.path.join(self.cfgs['output_dir'], "imgs_test", + f"{self.cfgs['test_dataset']['name']}_{self.cfgs['test_dataset']['args']['split']}") + else: + outdir = os.path.join(self.cfgs['output_dir'], "imgs_val", + f"{self.cfgs['test_dataset']['name']}_{self.cfgs['test_dataset']['args']['split']}") + with open(os.path.join(outdir, "results.txt"), "w") as f: + f.write(psnr_str) + if misc.is_main_process() and self.cfgs['mode'] == 'train': + self.summary_writer.add_scalar("Val/psnr", self.metric_logger.psnr.global_avg, self.current_epoch) + self.summary_writer.add_scalar("Val/ssim", self.metric_logger.ssim.global_avg, self.current_epoch) + if self.cfgs['enable_wandb']: + wandb.log({'val_psnr': self.metric_logger.psnr.global_avg, 'val_ssim': self.metric_logger.ssim.global_avg}, + step=self.current_iteration) + return self.metric_logger.psnr.global_avg + + @torch.no_grad() + def demo(self, video_dir): + start_time = time() + for video_name in os.listdir(video_dir): + # video_name = "Awesome_Again_Stakes_2019.mkv" + video_path = os.path.join(video_dir, video_name) + out_path = os.path.join(self.cfgs['output_dir'], 'demo', video_name.split(".")[0]) + inference_demo(self.model, 2, video_path, out_path) + total_time_str = str(datetime.timedelta(seconds=int(time() - start_time))) + print("Total time: {}".format(total_time_str)) + + def init_training_logger(self): + """ + Initialize training logger specific for each model + """ + if misc.is_main_process(): + self.summary_writer = SummaryWriter(log_dir=self.cfgs['summary_dir'], comment='m2mpwc') + Path(os.path.join(self.cfgs['output_dir'], 'imgs_train')).mkdir(parents=True, exist_ok=True) + Path(os.path.join(self.cfgs['output_dir'], 'imgs_val')).mkdir(parents=True, exist_ok=True) + self._reset_metric() + + def init_validation_logger(self): + """ + Initialize validation logger specific for each model + """ + if misc.is_main_process(): + self.summary_writer = SummaryWriter(log_dir=self.cfgs['summary_dir'], comment='m2mpwc') + Path(os.path.join(self.cfgs['output_dir'], 'imgs_val')).mkdir(parents=True, exist_ok=True) + self._reset_metric() + + def init_testing_logger(self): + """ + Initialize testing logger specific for each model + """ + if misc.is_main_process(): + self.summary_writer = SummaryWriter(log_dir=self.cfgs['summary_dir'], comment='m2mpwc') + Path(os.path.join(self.cfgs['output_dir'], 'imgs_test')).mkdir(parents=True, exist_ok=True) + self._reset_metric() + + def init_demo_logger(self): + """ + Initialize testing logger specific for each model + """ + if misc.is_main_process(): + self.summary_writer = SummaryWriter(log_dir=self.cfgs['summary_dir'], comment='m2mpwc') + Path(os.path.join(self.cfgs['output_dir'], 'demo')).mkdir(parents=True, exist_ok=True) + self._reset_metric() + + def finalize_training(self): + if misc.is_main_process(): + self.summary_writer.close() + + def move_components_to_device(self, mode): + """ + Move components to device + """ + self.model.to(self.device) + for _, v in self.loss_dict.items(): + v.to(self.device) + self.logger.info('Model: {}'.format(self.model)) + + def _reset_metric(self): + """ + Metric related to average meter + """ + self.metric_logger = misc.MetricLogger(delimiter=" ") + self.metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}')) + self.metric_logger.add_meter('loss', misc.SmoothedValue(window_size=20)) + + def count_parameters(self): + """ + Return the number of parameters for the model + """ + model_number = sum(p.numel() for p in self.model_without_ddp.parameters() if p.requires_grad) + return model_number diff --git a/modules/models/inference_video.py b/modules/models/inference_video.py new file mode 100644 index 0000000000000000000000000000000000000000..cd619245ba6dc843608631d89150ffadd6cf175e --- /dev/null +++ b/modules/models/inference_video.py @@ -0,0 +1,119 @@ +import glob +import numpy +import os +import cv2 +import math +import PIL.Image +import torch +import torch.nn.functional as F +import tqdm +import argparse +from moviepy.editor import VideoFileClip +import sys + +from torchvision.utils import save_image +from utils.flowvis import flow2img +from utils.padder import InputPadder + + +########################################################## + +########################################################## +def inference_demo(model, ratio, video_path, out_path): + videogen = [] + is_video = video_path.endswith(".mkv") or video_path.endswith(".webm") or video_path.endswith( + ".mp4") or video_path.endswith(".avi") + if is_video: + clip = VideoFileClip(video_path) + videogen = clip.iter_frames() + ratio = 2 + fps = clip.fps + # if fps == 23 or fps == 25: + # fps = 24 + # if fps == 29 or fps == 31: + # fps = 30 + # if fps == 59: + # fps = 60 + # ratio = 120 // fps + # if fps == 60: + # ratio = 120 // 24 + else: + for f in os.listdir(video_path): + if 'png' or 'jpg' in f: + videogen.append(f) + videogen.sort(key=lambda x: int(x[:-4])) + + if not os.path.exists(out_path): + os.mkdir(out_path) + if not os.path.exists(out_path + "_flow"): + os.mkdir(out_path + '_flow') + + img0 = None + idx = 0 + name_idx = 0 + time_range = torch.arange(1, ratio).view(ratio - 1, 1, 1, 1).cuda() / ratio + for curfile_name in videogen: + if not is_video: + curframe = os.path.join(video_path, curfile_name) + img4_np = cv2.imread(curframe)[:, :, ::-1] + else: + img4_np = curfile_name + img4 = (torch.tensor(img4_np.transpose(2, 0, 1).copy()).float() / 255.0).unsqueeze(0).cuda() + if img0 is None: + img0 = img4 + cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx), + (img0[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1]) + _, _, h, w = img0.shape + if h >= 2160: + scale_factor = 0.25 + pyr_level = 8 + nr_lvl_skipped = 4 + elif h >= 1080: + scale_factor = 0.5 + pyr_level = 7 + nr_lvl_skipped = 0 + else: + scale_factor = 1 + pyr_level = 5 + nr_lvl_skipped = 0 + idx += 1 + name_idx += 1 + continue + # if is_video: + # if fps == 60: + # if idx % 5 != 0 and idx % 5 != 3: + # idx += 1 + # continue + # img0_ = F.interpolate(img0, scale_factor=pre_down, mode='bilinear') + # img4_ = F.interpolate(img4, scale_factor=pre_down, mode='bilinear') + results_dict = model(img0=img0, img1=img4, time_step=time_range, scale_factor=scale_factor, + ratio=(1 / scale_factor), pyr_level=pyr_level, nr_lvl_skipped=nr_lvl_skipped) + imgt_pred = results_dict['imgt_pred'] + imgt_pred = torch.clip(imgt_pred, 0, 1) + save_image(flow2img(results_dict['flowfwd']), + os.path.join(out_path + '_flow', "{:0>7d}ff.png".format(name_idx - 1))) + save_image(flow2img(results_dict['flowbwd']), + os.path.join(out_path + '_flow', "{:0>7d}bb.png".format(name_idx - 1))) + if "flowfwd_pre" in results_dict.keys(): + save_image(flow2img(results_dict['flowfwd_pre']), + os.path.join(out_path + '_flow', "pre_{:0>7d}ff.png".format(name_idx - 1))) + save_image(results_dict['refine_res'], os.path.join(out_path, "refine_res.png")) + save_image(results_dict['refine_mask'], os.path.join(out_path, "refine_mask.png")) + save_image(results_dict['warped_img0'], os.path.join(out_path, "warped_img0.png")) + save_image(results_dict['warped_img1'], os.path.join(out_path, "warped_img1.png")) + save_image(results_dict['merged_img'], os.path.join(out_path, "merged_img.png")) + + img_pred = imgt_pred + # img_pred = F.interpolate(img_pred, scale_factor=1 // pre_down, mode='bilinear') + cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx), + (img_pred[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1]) + name_idx += 1 + # img4 = F.interpolate(img4, scale_factor=1 // pre_down, mode='bilinear') + cv2.imwrite(out_path + '/{:0>7d}.png'.format(name_idx), + (img4[0] * 255).byte().cpu().numpy().transpose(1, 2, 0)[:, :, ::-1]) + name_idx += 1 + idx += 1 + img0 = img4 + if is_video: + os.system( + f'ffmpeg -framerate {fps * 2} -pattern_type glob -i "{out_path}/*.png" -c:v libx265 -qp 8 -pix_fmt yuv420p {out_path}_{fps * 2}fps.mp4') diff --git a/modules/models/m2m_flowformer.py b/modules/models/m2m_flowformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9e94ef2fe8f0e99929493a9bdfd2b7cb9a5f7274 --- /dev/null +++ b/modules/models/m2m_flowformer.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('m2m_flowformer') +class M2MFlowFormer(BaseModel): + def __init__(self, cfg): + super(M2MFlowFormer, self).__init__(cfg) diff --git a/modules/models/m2m_pwc.py b/modules/models/m2m_pwc.py new file mode 100644 index 0000000000000000000000000000000000000000..1ad24f4d4c55a57e7aa44702a6ca1a3249aa22ff --- /dev/null +++ b/modules/models/m2m_pwc.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('m2m_pwc') +class M2MPWC(BaseModel): + def __init__(self, cfg): + super(M2MPWC, self).__init__(cfg) diff --git a/modules/models/models.py b/modules/models/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d81f3ad209c137a49426593a7b4867f376ba6ccb --- /dev/null +++ b/modules/models/models.py @@ -0,0 +1,16 @@ +import copy + + +models = {} + + +def register(name): + def decorator(cls): + models[name] = cls + return cls + return decorator + + +def make(cfgs): + model = models[cfgs['model']['name']](cfgs) + return model diff --git a/modules/models/unimatch/__init__.py b/modules/models/unimatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/models/unimatch/__pycache__/__init__.cpython-310.pyc b/modules/models/unimatch/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06bcaca29167e2a9732b0450410c7860a395ad5c Binary files /dev/null and b/modules/models/unimatch/__pycache__/__init__.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/__init__.cpython-38.pyc b/modules/models/unimatch/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7916a81e33ab71f3b0904136732218a4a867f7fb Binary files /dev/null and b/modules/models/unimatch/__pycache__/__init__.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/__init__.cpython-39.pyc b/modules/models/unimatch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51c9b6e671877e2f57133d211d1a4482dfc5774e Binary files /dev/null and b/modules/models/unimatch/__pycache__/__init__.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/attention.cpython-310.pyc b/modules/models/unimatch/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd1dfb4e67564666510eb3eac53df5e7ffb8e42f Binary files /dev/null and b/modules/models/unimatch/__pycache__/attention.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/attention.cpython-38.pyc b/modules/models/unimatch/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cbc38f3993c71d8a72041a8c7634228cdec8882 Binary files /dev/null and b/modules/models/unimatch/__pycache__/attention.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/attention.cpython-39.pyc b/modules/models/unimatch/__pycache__/attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee6862d251f5f6fae06ecc139f279f317c1d51de Binary files /dev/null and b/modules/models/unimatch/__pycache__/attention.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/backbone.cpython-310.pyc b/modules/models/unimatch/__pycache__/backbone.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fabe0181f655d8a0223fc4784288a61e3a39ea9 Binary files /dev/null and b/modules/models/unimatch/__pycache__/backbone.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/backbone.cpython-38.pyc b/modules/models/unimatch/__pycache__/backbone.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efa591c50e8a5f6aaafe0ad64b582d53d580e185 Binary files /dev/null and b/modules/models/unimatch/__pycache__/backbone.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/backbone.cpython-39.pyc b/modules/models/unimatch/__pycache__/backbone.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13670acd08519d305321b14e25c6892dcbb30dee Binary files /dev/null and b/modules/models/unimatch/__pycache__/backbone.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/geometry.cpython-310.pyc b/modules/models/unimatch/__pycache__/geometry.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c669fbc9a1c0be14457217808fc04fd58606a7a3 Binary files /dev/null and b/modules/models/unimatch/__pycache__/geometry.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/geometry.cpython-38.pyc b/modules/models/unimatch/__pycache__/geometry.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..800473af90bcd968c2a3fa3ded10c273ec551b2c Binary files /dev/null and b/modules/models/unimatch/__pycache__/geometry.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/geometry.cpython-39.pyc b/modules/models/unimatch/__pycache__/geometry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..415bb4fb023d27692943634e32075e3077a1b61e Binary files /dev/null and b/modules/models/unimatch/__pycache__/geometry.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/matching.cpython-310.pyc b/modules/models/unimatch/__pycache__/matching.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b61217210bea08f4986ed8a3bf4ee66c0c2e7e1e Binary files /dev/null and b/modules/models/unimatch/__pycache__/matching.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/matching.cpython-38.pyc b/modules/models/unimatch/__pycache__/matching.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9f40605c7d484841a5bbb365ab1b52ee5c8ccb0 Binary files /dev/null and b/modules/models/unimatch/__pycache__/matching.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/matching.cpython-39.pyc b/modules/models/unimatch/__pycache__/matching.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..948d76bb1310965de20403b609e55e50b55611b6 Binary files /dev/null and b/modules/models/unimatch/__pycache__/matching.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/position.cpython-310.pyc b/modules/models/unimatch/__pycache__/position.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d6d6d3ee0a4c691f43f695567b68357c035e6c9 Binary files /dev/null and b/modules/models/unimatch/__pycache__/position.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/position.cpython-38.pyc b/modules/models/unimatch/__pycache__/position.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e15dad9e0117a940a4c21102e9ec84cb2104adf Binary files /dev/null and b/modules/models/unimatch/__pycache__/position.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/position.cpython-39.pyc b/modules/models/unimatch/__pycache__/position.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c53672d831c2ebb4ccd17277e66cdc1708060548 Binary files /dev/null and b/modules/models/unimatch/__pycache__/position.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/reg_refine.cpython-310.pyc b/modules/models/unimatch/__pycache__/reg_refine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c9c2d914bab4ac41e2aea876e5377bc1fd734cf Binary files /dev/null and b/modules/models/unimatch/__pycache__/reg_refine.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/reg_refine.cpython-38.pyc b/modules/models/unimatch/__pycache__/reg_refine.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6ef87dae5a4acc74cc5866bbbcddf6497002693 Binary files /dev/null and b/modules/models/unimatch/__pycache__/reg_refine.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/reg_refine.cpython-39.pyc b/modules/models/unimatch/__pycache__/reg_refine.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84a1dddc7739351be0ed33efd2076f5f585fabf9 Binary files /dev/null and b/modules/models/unimatch/__pycache__/reg_refine.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/transformer.cpython-310.pyc b/modules/models/unimatch/__pycache__/transformer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83a39606d8fed1ab73aec0dd9afefff9de7e77d9 Binary files /dev/null and b/modules/models/unimatch/__pycache__/transformer.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/transformer.cpython-38.pyc b/modules/models/unimatch/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2ee319d4766ed5e770781e931282a23628eade8 Binary files /dev/null and b/modules/models/unimatch/__pycache__/transformer.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/transformer.cpython-39.pyc b/modules/models/unimatch/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..085cec918466db4133ccb2d870a0675e004b4a54 Binary files /dev/null and b/modules/models/unimatch/__pycache__/transformer.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/trident_conv.cpython-310.pyc b/modules/models/unimatch/__pycache__/trident_conv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fad6c0afc3756288f5b921e17caeafcc70c4da2 Binary files /dev/null and b/modules/models/unimatch/__pycache__/trident_conv.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/trident_conv.cpython-38.pyc b/modules/models/unimatch/__pycache__/trident_conv.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f273c7ab878481e9c1fb5a198439de45a774d54 Binary files /dev/null and b/modules/models/unimatch/__pycache__/trident_conv.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/trident_conv.cpython-39.pyc b/modules/models/unimatch/__pycache__/trident_conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d9d77f62e5f9a4488f2f1bf5fc787bae3f9ab2 Binary files /dev/null and b/modules/models/unimatch/__pycache__/trident_conv.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/unimatch.cpython-310.pyc b/modules/models/unimatch/__pycache__/unimatch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..005dd3c26e4c841f3a6298fb5c2dcc99f2ca7a06 Binary files /dev/null and b/modules/models/unimatch/__pycache__/unimatch.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/unimatch.cpython-38.pyc b/modules/models/unimatch/__pycache__/unimatch.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d299b832b4820b1ab1305ebdb2bb7980e9fc1d41 Binary files /dev/null and b/modules/models/unimatch/__pycache__/unimatch.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/unimatch.cpython-39.pyc b/modules/models/unimatch/__pycache__/unimatch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28c35ed4f37ec5d4c38d907cdce9cd6a0749d07f Binary files /dev/null and b/modules/models/unimatch/__pycache__/unimatch.cpython-39.pyc differ diff --git a/modules/models/unimatch/__pycache__/utils.cpython-310.pyc b/modules/models/unimatch/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8362ce2c878b116abc9231b6a9ee9a356fa99b6 Binary files /dev/null and b/modules/models/unimatch/__pycache__/utils.cpython-310.pyc differ diff --git a/modules/models/unimatch/__pycache__/utils.cpython-38.pyc b/modules/models/unimatch/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..436edc2057d85420fb998c7d9b4908f84f9386a9 Binary files /dev/null and b/modules/models/unimatch/__pycache__/utils.cpython-38.pyc differ diff --git a/modules/models/unimatch/__pycache__/utils.cpython-39.pyc b/modules/models/unimatch/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..696e7158b55ea2b7acbcb4cff0a40b260002a569 Binary files /dev/null and b/modules/models/unimatch/__pycache__/utils.cpython-39.pyc differ diff --git a/modules/models/unimatch/attention.py b/modules/models/unimatch/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..92a3c878afe541753022ba85c43b5b2e86e4d254 --- /dev/null +++ b/modules/models/unimatch/attention.py @@ -0,0 +1,253 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def single_head_full_attention_1d(q, k, v, + h=None, + w=None, + ): + # q, k, v: [B, L, C] + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W] + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C] + + return out + + +def single_head_split_window_attention(q, k, v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, channel_last=True) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +def single_head_split_window_attention_1d(q, k, v, + relative_position_bias=None, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # q, k, v: [B, L, C] + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * h + + window_size_w = w // num_splits + + q = q.view(b * h, w, c) # [B*H, W, C] + k = k.view(b * h, w, c) + v = v.view(b * h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=-shift_size_w, dims=1) + k = torch.roll(k, shifts=-shift_size_w, dims=1) + v = torch.roll(v, shifts=-shift_size_w, dims=1) + + q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C] + k = split_feature_1d(k, num_splits=num_splits) + v = split_feature_1d(v, num_splits=num_splits) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*H*K, W/K, W/K] + + if with_shift: + # attn_mask: [K, W/K, W/K] + scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K] + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C] + + out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=shift_size_w, dims=2) + + out = out.view(b, -1, c) + + return out + + +class SelfAttnPropagation(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__(self, in_channels, + **kwargs, + ): + super(SelfAttnPropagation, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn(feature0, flow, + local_window_radius=local_window_radius) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] + key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn(self, feature0, flow, + local_window_radius=1, + ): + assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + value_channel = flow.size(1) + + feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) + ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) + + feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, + padding=local_window_radius) # [B, C*(2R+1)^2), H*W] + + feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( + 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold(flow, kernel_size=kernel_size, + padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + + flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute( + 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, flow_window).view(b, h, w, value_channel + ).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + return out diff --git a/modules/models/unimatch/backbone.py b/modules/models/unimatch/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..a30942eca9cad56e75252c3026dca95bf1021df7 --- /dev/null +++ b/modules/models/unimatch/backbone.py @@ -0,0 +1,117 @@ +import torch.nn as nn + +from .trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, stride=stride, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__(self, output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer(feature_dims[2], stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/modules/models/unimatch/geometry.py b/modules/models/unimatch/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..775a95783aeee66a44e6290525de94909af648df --- /dev/null +++ b/modules/models/unimatch/geometry.py @@ -0,0 +1,195 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +def back_project(depth, intrinsics): + # Back project 2D pixel coords to 3D points + # depth: [B, H, W] + # intrinsics: [B, 3, 3] + b, h, w = depth.shape + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + + intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] + + points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] + + return points + + +def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): + # Transform 3D points from reference camera to target camera + # points_ref: [B, 3, H, W] + # extrinsics_ref: [B, 4, 4] + # extrinsics_tgt: [B, 4, 4] + # extrinsics_rel: [B, 4, 4], relative pose transform + b, _, h, w = points_ref.shape + + if extrinsics_rel is None: + extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] + + points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], + points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W] + + points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] + + return points_tgt + + +def reproject(points_tgt, intrinsics, return_mask=False): + # reproject to target view + # points_tgt: [B, 3, H, W] + # intrinsics: [B, 3, 3] + + b, _, h, w = points_tgt.shape + + proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] + + X = proj_points[:, 0] + Y = proj_points[:, 1] + Z = proj_points[:, 2].clamp(min=1e-3) + + pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale + + if return_mask: + # valid mask in pixel space + mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & ( + pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W] + + return pixel_coords, mask + + return pixel_coords + + +def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + # Compute reprojection sample coords + points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] + points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) + + if return_mask: + reproj_coords, mask = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords, mask + + reproj_coords = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords + + +def compute_flow_with_depth_pose(depth_ref, intrinsics, + extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + b, h, w = depth_ref.shape + coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] + + if return_mask: + reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + rigid_flow = reproj_coords - coords_init + + return rigid_flow, mask + + reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + + rigid_flow = reproj_coords - coords_init + + return rigid_flow diff --git a/modules/models/unimatch/matching.py b/modules/models/unimatch/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..595437f2307202ab36d7c2ee3dfa0ab44e4dc830 --- /dev/null +++ b/modules/models/unimatch/matching.py @@ -0,0 +1,279 @@ +import torch +import torch.nn.functional as F + +from .geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax(feature0, feature1, + pred_bidir_flow=False, + ): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax(feature0, feature1, local_radius, + padding_mode='zeros', + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob + + +def local_correlation_with_flow(feature0, feature1, + flow, + local_radius, + padding_mode='zeros', + dilation=1, + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2] + + # flow can be zero when using features after transformer + if not isinstance(flow, float): + sample_coords = sample_coords + flow.view( + b, 2, -1).permute(0, 2, 1).unsqueeze(-2) # [B, H*W, (2R+1)^2, 2] + else: + assert flow == 0. + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W] + + return corr + + +def global_correlation_softmax_stereo(feature0, feature1, + ): + # global correlation on horizontal direction + b, c, h, w = feature0.shape + + x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W] + + feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C] + feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W] + + correlation = torch.matmul(feature0, feature1) / (c ** 0.5) # [B, H, W, W] + + # mask subsequent positions to make disparity positive + mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W] + valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W] + + correlation[~valid_mask] = -1e9 + + prob = F.softmax(correlation, dim=-1) # [B, H, W, W] + + correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W] + + # NOTE: unlike flow, disparity is typically positive + disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W] + + return disparity.unsqueeze(1), prob # feature resolution + + +def local_correlation_softmax_stereo(feature0, feature1, local_radius, + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2] + + local_h = 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(0, 0, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode='zeros', align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)] + feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)] + + correspondence = torch.matmul(prob.unsqueeze(-2), + sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + flow = correspondence - coords_init # flow at feature resolution + match_prob = prob + + flow_x = -flow[:, :1] # [B, 1, H, W] + + return flow_x, match_prob + + +def correlation_softmax_depth(feature0, feature1, + intrinsics, + pose, + depth_candidates, + depth_from_argmax=False, + pred_bidir_depth=False, + ): + b, c, h, w = feature0.size() + assert depth_candidates.dim() == 4 # [B, D, H, W] + scale_factor = c ** 0.5 + + if pred_bidir_depth: + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + intrinsics = intrinsics.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + depth_candidates = depth_candidates.repeat(2, 1, 1, 1) + + # depth candidates are actually inverse depth + warped_feature1 = warp_with_pose_depth_candidates(feature1, intrinsics, pose, + 1. / depth_candidates, + ) # [B, C, D, H, W] + + correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W] + + match_prob = F.softmax(correlation, dim=1) # [B, D, H, W] + + # for cross-task transfer (flow -> depth), extract depth with argmax at test time + if depth_from_argmax: + index = torch.argmax(match_prob, dim=1, keepdim=True) + depth = torch.gather(depth_candidates, dim=1, index=index) + else: + depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W] + + return depth, match_prob + + +def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth, + clamp_min_depth=1e-3, + ): + """ + feature1: [B, C, H, W] + intrinsics: [B, 3, 3] + pose: [B, 4, 4] + depth: [B, D, H, W] + """ + + assert intrinsics.size(1) == intrinsics.size(2) == 3 + assert pose.size(1) == pose.size(2) == 4 + assert depth.dim() == 4 + + b, d, h, w = depth.size() + c = feature1.size(1) + + with torch.no_grad(): + # pixel coordinates + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + # back project to 3D and transform viewpoint + points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] + points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( + 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W] + points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] + # reproject to 2D image plane + points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W] + pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W] + + # normalize to [-1, 1] + x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] + + # sample features + warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear', + padding_mode='zeros', + align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W] + + return warped_feature diff --git a/modules/models/unimatch/position.py b/modules/models/unimatch/position.py new file mode 100644 index 0000000000000000000000000000000000000000..14a6da436c818b7c2784e92dba66f7947d34b7ce --- /dev/null +++ b/modules/models/unimatch/position.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import torch +import torch.nn as nn +import math + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/modules/models/unimatch/reg_refine.py b/modules/models/unimatch/reg_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..47f83da1c5dcd476069e841d045db04998be3604 --- /dev/null +++ b/modules/models/unimatch/reg_refine.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256, + out_dim=2, + ): + super(FlowHead, self).__init__() + + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv2(self.relu(self.conv1(x))) + + return out + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128, + kernel_size=5, + ): + padding = (kernel_size - 1) // 2 + + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + + self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, corr_channels=324, + flow_channels=2, + ): + super(BasicMotionEncoder, self).__init__() + + self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicUpdateBlock(nn.Module): + def __init__(self, corr_channels=324, + hidden_dim=128, + context_dim=128, + downsample_factor=8, + flow_dim=2, + bilinear_up=False, + ): + super(BasicUpdateBlock, self).__init__() + + self.encoder = BasicMotionEncoder(corr_channels=corr_channels, + flow_channels=flow_dim, + ) + + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim) + + self.flow_head = FlowHead(hidden_dim, hidden_dim=256, + out_dim=flow_dim, + ) + + if bilinear_up: + self.mask = None + else: + self.mask = nn.Sequential( + nn.Conv2d(hidden_dim, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0)) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + if self.mask is not None: + mask = self.mask(net) + else: + mask = None + + return net, mask, delta_flow diff --git a/modules/models/unimatch/transformer.py b/modules/models/unimatch/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4878e23a64f6609b1bf10740b0a794d8da836c31 --- /dev/null +++ b/modules/models/unimatch/transformer.py @@ -0,0 +1,294 @@ +import torch +import torch.nn as nn + +from .attention import (single_head_full_attention, single_head_split_window_attention, + single_head_full_attention_1d, single_head_split_window_attention_1d) +from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d + + +class TransformerLayer(nn.Module): + def __init__(self, + d_model=128, + nhead=1, + no_ffn=False, + ffn_dim_expansion=4, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.no_ffn = no_ffn + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + shifted_window_attn_mask_1d=None, + attn_type='swin', + with_shift=False, + attn_num_splits=None, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # for stereo: 2d attn in self-attn, 1d attn in cross-attn + is_self_attn = (query - key).abs().max() < 1e-6 + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if attn_type == 'swin' and attn_num_splits > 1: # self, cross-attn: both swin 2d + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + + elif attn_type == 'self_swin2d_cross_1d': # self-attn: swin 2d, cross-attn: full 1d + if self.nhead > 1: + raise NotImplementedError + else: + if is_self_attn: + if attn_num_splits > 1: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + # full 2d attn + message = single_head_full_attention(query, key, value) # [N, L, C] + + else: + # cross attn 1d + message = single_head_full_attention_1d(query, key, value, + h=height, + w=width, + ) + + elif attn_type == 'self_swin2d_cross_swin1d': # self-attn: swin 2d, cross-attn: swin 1d + if self.nhead > 1: + raise NotImplementedError + else: + if is_self_attn: + if attn_num_splits > 1: + # self attn shift window + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + # full 2d attn + message = single_head_full_attention(query, key, value) # [N, L, C] + else: + if attn_num_splits > 1: + assert shifted_window_attn_mask_1d is not None + # cross attn 1d shift + message = single_head_split_window_attention_1d(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask_1d, + ) + else: + message = single_head_full_attention_1d(query, key, value, + h=height, + w=width, + ) + + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self attention + cross attention + FFN""" + + def __init__(self, + d_model=128, + nhead=1, + ffn_dim_expansion=4, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer(d_model=d_model, + nhead=nhead, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + ) + + self.cross_attn_ffn = TransformerLayer(d_model=d_model, + nhead=nhead, + ffn_dim_expansion=ffn_dim_expansion, + ) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + shifted_window_attn_mask_1d=None, + attn_type='swin', + with_shift=False, + attn_num_splits=None, + ): + # source, target: [B, L, C] + + # self attention + source = self.self_attn(source, source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_type=attn_type, + with_shift=with_shift, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn(source, target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=shifted_window_attn_mask_1d, + attn_type=attn_type, + with_shift=with_shift, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__(self, + num_layers=6, + d_model=128, + nhead=1, + ffn_dim_expansion=4, + ): + super(FeatureTransformer, self).__init__() + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList([ + TransformerBlock(d_model=d_model, + nhead=nhead, + ffn_dim_expansion=ffn_dim_expansion, + ) + for i in range(num_layers)]) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, feature1, + attn_type='swin', + attn_num_splits=None, + **kwargs, + ): + + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + # 2d attention + if 'swin' in attn_type and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # 1d attention + if 'swin1d' in attn_type and attn_num_splits > 1: + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d( + input_w=w, + window_size_w=window_size_w, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K, W/K, W/K] + else: + shifted_window_attn_mask_1d = None + + # concat feature0 and feature1 in batch dimension to compute in parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for i, layer in enumerate(self.layers): + concat0 = layer(concat0, concat1, + height=h, + width=w, + attn_type=attn_type, + with_shift='swin' in attn_type and attn_num_splits > 1 and i % 2 == 1, + attn_num_splits=attn_num_splits, + shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=shifted_window_attn_mask_1d, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + + return feature0, feature1 diff --git a/modules/models/unimatch/trident_conv.py b/modules/models/unimatch/trident_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..29a2a73e964a88b68bc095772d9c3cc443e3e0fe --- /dev/null +++ b/modules/models/unimatch/trident_conv.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = dilation + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], + self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs diff --git a/modules/models/unimatch/unimatch.py b/modules/models/unimatch/unimatch.py new file mode 100644 index 0000000000000000000000000000000000000000..d813103a4bb67bfeba73571915bea50cf1ff64ec --- /dev/null +++ b/modules/models/unimatch/unimatch.py @@ -0,0 +1,381 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import CNNEncoder +from .transformer import FeatureTransformer +from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow, + global_correlation_softmax_stereo, local_correlation_softmax_stereo, + correlation_softmax_depth) +from .attention import SelfAttnPropagation +from .geometry import flow_warp, compute_flow_with_depth_pose +from .reg_refine import BasicUpdateBlock +from .utils import normalize_img, feature_add_position, upsample_flow_with_mask + + +class UniMatch(nn.Module): + def __init__(self, + num_scales=1, + feature_channels=128, + upsample_factor=8, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + reg_refine=False, # optional local regression refinement + task='flow', + ): + super(UniMatch, self).__init__() + + self.feature_channels = feature_channels + self.num_scales = num_scales + self.upsample_factor = upsample_factor + self.reg_refine = reg_refine + + # CNN + self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) + + # Transformer + self.transformer = FeatureTransformer(num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # propagation with self-attn + self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels) + + if not self.reg_refine or task == 'depth': + # convex upsampling simiar to RAFT + # concat feature0 and low res flow as input + self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0)) + # thus far, all the learnable parameters are task-agnostic + + if reg_refine: + # optional task-specific local regression refinement + self.refine_proj = nn.Conv2d(128, 256, 1) + self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2, + downsample_factor=upsample_factor, + flow_dim=2 if task == 'flow' else 1, + bilinear_up=task == 'depth', + ) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, + is_depth=False): + if bilinear: + multiplier = 1 if is_depth else upsample_factor + up_flow = F.interpolate(flow, scale_factor=upsample_factor, + mode='bilinear', align_corners=True) * multiplier + else: + concat = torch.cat((flow, feature), dim=1) + mask = self.upsampler(concat) + up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor, + is_depth=is_depth) + + return up_flow + + def forward(self, img0, img1, + attn_type=None, + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + num_reg_refine=1, + pred_bidir_flow=False, + task='flow', + intrinsics=None, + pose=None, # relative pose transform + min_depth=1. / 0.5, # inverse depth range + max_depth=1. / 10, + num_depth_candidates=64, + depth_from_argmax=False, + pred_bidir_depth=False, + first_scaling=None, + **kwargs, + ): + + if 0.0 <= img0.max() <= 1.0: + img0 = img0*255 + img1 = img1*255 + + if first_scaling is not None: + img0 = F.interpolate(img0, scale_factor=1/first_scaling, mode='bilinear') + img1 = F.interpolate(img1, scale_factor=1/first_scaling, mode='bilinear') + + + if pred_bidir_flow: + assert task == 'flow' + + if task == 'depth': + assert self.num_scales == 1 # multi-scale depth model is not supported yet + + results_dict = {} + flow_preds = [] + + if task == 'flow': + # stereo and depth tasks have normalized img in dataloader + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # list of features, resolution low to high ### CNN Features + feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features + + flow = None + + if task != 'depth': + assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales + else: + assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1 + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + + feature0_ori, feature1_ori = feature0, feature1 + + upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) + + if task == 'depth': + # scale intrinsics + intrinsics_curr = intrinsics.clone() + intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor + + if scale_idx > 0: + assert task != 'depth' # not supported for multi-scale depth model + flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 + + if flow is not None: + assert task != 'depth' + flow = flow.detach() + + if task == 'stereo': + # construct flow vector for disparity + # flow here is actually disparity + zeros = torch.zeros_like(flow) # [B, 1, H, W] + # NOTE: reverse disp, disparity is positive + displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] + feature1 = flow_warp(feature1, displace) # [B, C, H, W] + elif task == 'flow': + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + else: + raise NotImplementedError + + attn_splits = attn_splits_list[scale_idx] + if task != 'depth': + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) + + # Transformer + feature0, feature1 = self.transformer(feature0, feature1, + attn_type=attn_type, + attn_num_splits=attn_splits, + ) + + # correlation and softmax + if task == 'depth': + # first generate depth candidates + b, _, h, w = feature0.size() + depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0) + depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h, + w) # [B, D, H, W] + + flow_pred = correlation_softmax_depth(feature0, feature1, + intrinsics_curr, + pose, + depth_candidates=depth_candidates, + depth_from_argmax=depth_from_argmax, + pred_bidir_depth=pred_bidir_depth, + )[0] + + else: + if corr_radius == -1: # global matching + if task == 'flow': + flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] + elif task == 'stereo': + flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0] + else: + raise NotImplementedError + else: # local matching + if task == 'flow': + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] + elif task == 'stereo': + flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0] + else: + raise NotImplementedError + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + if task == 'stereo': + flow = flow.clamp(min=0) # positive disparity + + # upsample to the original resolution for supervison at training time only + if self.training: + flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_bilinear) + + # flow propagation with self-attn + if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0: + feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation + + flow = self.feature_flow_attn(feature0, flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius, + ) + + # bilinear exclude the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_up) + + if scale_idx == self.num_scales - 1: + if not self.reg_refine: + # upsample to the original image resolution + + if task == 'stereo': + flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + flow_up_pad = self.upsample_flow(flow_pad, feature0) + flow_up = -flow_up_pad[:, :1] # [B, 1, H, W] + elif task == 'depth': + depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + depth_up_pad = self.upsample_flow(depth_pad, feature0, + is_depth=True).clamp(min=min_depth, max=max_depth) + flow_up = depth_up_pad[:, :1] # [B, 1, H, W] + else: + flow_up = self.upsample_flow(flow, feature0) + + flow_preds.append(flow_up) + else: + # task-specific local regression refinement + # supervise current flow + if self.training: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_up) + + assert num_reg_refine > 0 + for refine_iter_idx in range(num_reg_refine): + flow = flow.detach() + + if task == 'stereo': + zeros = torch.zeros_like(flow) # [B, 1, H, W] + # NOTE: reverse disp, disparity is positive + displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=displace, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + elif task == 'depth': + if pred_bidir_depth and refine_iter_idx == 0: + intrinsics_curr = intrinsics_curr.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + + feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori), + dim=0), torch.cat((feature1_ori, + feature0_ori), dim=0) + + flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1), + intrinsics_curr, + extrinsics_rel=pose, + ) + + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=flow_from_depth, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + + else: + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=flow, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + + proj = self.refine_proj(feature0) + + net, inp = torch.chunk(proj, chunks=2, dim=1) + + net = torch.tanh(net) + inp = torch.relu(inp) + + net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(), + ) + + if task == 'depth': + flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth) + else: + flow = flow + residual_flow + + if task == 'stereo': + flow = flow.clamp(min=0) # positive + + if self.training or refine_iter_idx == num_reg_refine - 1: + if task == 'depth': + if refine_iter_idx < num_reg_refine - 1: + # bilinear upsampling + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=True) + else: + # last one convex upsampling + # NOTE: clamp depth due to the zero padding in the unfold in the convex upsampling + # pad depth to 2 channels as flow + depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + depth_up_pad = self.upsample_flow(depth_pad, feature0, + is_depth=True).clamp(min=min_depth, + max=max_depth) + flow_up = depth_up_pad[:, :1] # [B, 1, H, W] + + else: + flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor, + is_depth=task == 'depth') + + flow_preds.append(flow_up) + + if first_scaling is not None: + for i in range(len(flow_preds)): + flow_preds[i] = F.interpolate(flow_preds[i], scale_factor=first_scaling, mode='bilinear') + + if task == 'stereo': + for i in range(len(flow_preds)): + flow_preds[i] = flow_preds[i].squeeze(1) # [B, H, W] + + # convert inverse depth to depth + if task == 'depth': + for i in range(len(flow_preds)): + flow_preds[i] = 1. / flow_preds[i].squeeze(1) # [B, H, W] + + results_dict.update({'flow_preds': flow_preds}) + + return results_dict diff --git a/modules/models/unimatch/unimatch_inference.py b/modules/models/unimatch/unimatch_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f7a2a9c77a4c4ea642202e15be156c7acfbefd8f --- /dev/null +++ b/modules/models/unimatch/unimatch_inference.py @@ -0,0 +1,39 @@ +import torch +from unimatch import UniMatch +from torchvision.transforms import functional as TF + +# define +# img0 +# img1 + +device = 'cuda:0' + +flow_extractor = UniMatch(feature_channels=128, + num_scales=2, + upsample_factor=8//2, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + reg_refine=True, + task='flow') +fe_sd = torch.load('./pretrained/gmflow-scale2-regrefine6-mixdata-train320x576-4e7b215d.pth')['model'] +print(flow_extractor.load_state_dict(fe_sd)) +for n,p in flow_extractor.named_parameters(): + p.requires_grad = False +flow_extractor = flow_extractor.to(device) + +unimatch_multiple = 128 +_,_,Huni,Wuni = img0.size() +padw = unimatch_multiple - (Wuni%unimatch_multiple) if Wuni%unimatch_multiple!=0 else 0 +padh = unimatch_multiple - (Huni%unimatch_multiple) if Huni%unimatch_multiple!=0 else 0 +img0_pad = TF.pad(img0, (0,0,padw,padh), padding_mode='symmetric').to(device) +img1_pad = TF.pad(img1, (0,0,padw,padh), padding_mode='symmetric').to(device) +with torch.no_grad(): + flow = flow_extractor(img0_pad, img1_pad, + attn_type='swin', + attn_splits_list=[2,8], + corr_radius_list=[-1,4], + prop_radius_list=[-1,1], + num_reg_refine=6, + first_scaling=4, + task='flow')['flow_preds'][-1][:,:,:Huni,:Wuni] # [B, 2, H, W] \ No newline at end of file diff --git a/modules/models/unimatch/utils.py b/modules/models/unimatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7ae824ccbed6701aea4150870210f90a8ad798ba --- /dev/null +++ b/modules/models/unimatch/utils.py @@ -0,0 +1,216 @@ +import torch +import torch.nn.functional as F +from .position import PositionEmbeddingSine + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 / 255. - mean) / std + img1 = (img1 / 255. - mean) / std + + return img0, img1 + + +def split_feature(feature, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c + ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0, f'!!!!!!{h} {w} {num_splits} {h % num_splits} {w % num_splits}!!!!!!' + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits + ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits(splits, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( + new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( + new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] + + return merge + + +def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, + shift_size_h, shift_size_w, device=torch.device('cuda')): + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 + + +def upsample_flow_with_mask(flow, up_mask, upsample_factor, + is_depth=False): + # convex upsampling following raft + + mask = up_mask + b, flow_channel, h, w = flow.shape + mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + multiplier = 1 if is_depth else upsample_factor + up_flow = F.unfold(multiplier * flow, [3, 3], padding=1) + up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h, + upsample_factor * w) # [B, 2, K*H, K*W] + + return up_flow + + +def split_feature_1d(feature, + num_splits=2, + ): + # feature: [B, W, C] + b, w, c = feature.size() + assert w % num_splits == 0 + + b_new = b * num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, w // num_splits, c + ).view(b_new, w_new, c) # [B*K, W/K, C] + + return feature + + +def merge_splits_1d(splits, + h, + num_splits=2, + ): + b, w, c = splits.size() + new_b = b // num_splits // h + + splits = splits.view(new_b, h, num_splits, w, c) + merge = splits.view( + new_b, h, num_splits * w, c) # [B, H, W, C] + + return merge + + +def window_partition_1d(x, window_size_w): + """ + Args: + x: (B, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, C) + """ + B, W, C = x.shape + x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C) + return x + + +def generate_shift_window_attn_mask_1d(input_w, window_size_w, + shift_size_w, device=torch.device('cuda')): + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1 + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for w in w_slices: + img_mask[:, w, :] = cnt + cnt += 1 + + mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1 + mask_windows = mask_windows.view(-1, window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask diff --git a/modules/models/upr_basic.py b/modules/models/upr_basic.py new file mode 100644 index 0000000000000000000000000000000000000000..24f33159afc14482c8fd0d8990b5bc98be71fca8 --- /dev/null +++ b/modules/models/upr_basic.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('upr_basic') +class UPRNet(BaseModel): + def __init__(self, cfg): + super(UPRNet, self).__init__(cfg) diff --git a/modules/models/upr_net.py b/modules/models/upr_net.py new file mode 100644 index 0000000000000000000000000000000000000000..2749a4ee5e6ccb1063eac3506d4eb6805b1c0d4a --- /dev/null +++ b/modules/models/upr_net.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('upr_net') +class UPRNet(BaseModel): + def __init__(self, cfg): + super(UPRNet, self).__init__(cfg) diff --git a/modules/models/upr_net_freq.py b/modules/models/upr_net_freq.py new file mode 100644 index 0000000000000000000000000000000000000000..6621fc5ce9f2e82bd080e1e2ab5a5e970bf766d1 --- /dev/null +++ b/modules/models/upr_net_freq.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('upr_net_freq') +class UPRNet(BaseModel): + def __init__(self, cfg): + super(UPRNet, self).__init__(cfg) diff --git a/modules/models/upr_net_freq2.py b/modules/models/upr_net_freq2.py new file mode 100644 index 0000000000000000000000000000000000000000..b6703bc60d050f17b920b0e876e2e4abfc13d64a --- /dev/null +++ b/modules/models/upr_net_freq2.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('upr_net_freq2') +class UPRNet(BaseModel): + def __init__(self, cfg): + super(UPRNet, self).__init__(cfg) diff --git a/modules/models/upr_net_mod.py b/modules/models/upr_net_mod.py new file mode 100644 index 0000000000000000000000000000000000000000..9773eaebbf9584cbcf47d34a39239e30ae6557dd --- /dev/null +++ b/modules/models/upr_net_mod.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('upr_net_mod') +class UPRNet(BaseModel): + def __init__(self, cfg): + super(UPRNet, self).__init__(cfg) diff --git a/modules/models/upr_net_mod2.py b/modules/models/upr_net_mod2.py new file mode 100644 index 0000000000000000000000000000000000000000..b191571a618a996ebd3d0a72e5167de45c12fbc5 --- /dev/null +++ b/modules/models/upr_net_mod2.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('upr_net_mod2') +class UPRNet(BaseModel): + def __init__(self, cfg): + super(UPRNet, self).__init__(cfg) \ No newline at end of file diff --git a/modules/models/upr_net_multi_flow.py b/modules/models/upr_net_multi_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..b75d2263b83bf45f070bfd6d10af4fcfa637aacd --- /dev/null +++ b/modules/models/upr_net_multi_flow.py @@ -0,0 +1,8 @@ +from modules.models.base_model import BaseModel +from modules.models import register + + +@register('upr_net_multi_flow') +class UPRNetMultiFlow(BaseModel): + def __init__(self, cfg): + super(UPRNetMultiFlow, self).__init__(cfg) diff --git a/modules/optimizer.py b/modules/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..108c29c79e9a1bcc86d9e2136135f44b7ef569b4 --- /dev/null +++ b/modules/optimizer.py @@ -0,0 +1,10 @@ +from torch.optim import * + + +def make_optimizer(params, optimizer_spec): + optimizer = { + 'sgd': SGD, + 'adam': Adam, + 'adamW': AdamW + }[optimizer_spec['name']](params, **optimizer_spec['args']) + return optimizer diff --git a/prepare_extra_training_dataset.py b/prepare_extra_training_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a0da11c3e998d6b64744c1d3b5658404053a5acb --- /dev/null +++ b/prepare_extra_training_dataset.py @@ -0,0 +1,109 @@ +# [START COMMAND] +# python3 -m prepare_extra_training_dataset \ +# --video_path ../datasets/video1.mov \ +# --frame_save_root ../datasets/frames1 \ +# --tri_save_root ../datasets/frames1_triplet/sequences/0001 \ +# --dup_info ../datasets/dup_video1.txt \ +# --flow_save_root ../datasets/frames1_unimatch_flow/sequences \ +# --total_gpus 1 --start_cuda_index 1 --cuda_index 1 --first_scaling 2 + +# python3 -m prepare_extra_training_dataset \ +# --video_path ../datasets/video2.mov \ +# --frame_save_root ../datasets/frames2 \ +# --tri_save_root ../datasets/frames2_triplet/sequences/0001 \ +# --dup_info ../datasets/dup_video2.txt \ +# --flow_save_root ../datasets/frames2_unimatch_flow/sequences \ +# --total_gpus 1 --start_cuda_index 1 --cuda_index 1 --first_scaling 2 + +# python3 -m prepare_extra_training_dataset \ +# --video_path ../datasets/video3.mov \ +# --frame_save_root ../datasets/frames3 \ +# --tri_save_root ../datasets/frames3_triplet/sequences/0001 \ +# --dup_info ../datasets/dup_video3.txt \ +# --flow_save_root ../datasets/frames3_unimatch_flow/sequences \ +# --total_gpus 1 --start_cuda_index 1 --cuda_index 1 --first_scaling 2 + +# python3 -m prepare_extra_training_dataset \ +# --video_path ../datasets/video4.mov \ +# --frame_save_root ../datasets/frames4 \ +# --tri_save_root ../datasets/frames4_triplet/sequences/0001 \ +# --dup_info ../datasets/dup_video4.txt \ +# --flow_save_root ../datasets/frames4_unimatch_flow/sequences \ +# --total_gpus 1 --start_cuda_index 1 --cuda_index 1 --first_scaling 2 + +import argparse + +import os +import cv2 +import glob +from tqdm import tqdm + +def main(): + parser = argparse.ArgumentParser(description="Preparing extra training dataset for UPR-Net-back inference.") + parser.add_argument('--video_path', type=str, default='../datasets/video4.mov', help="video file path") + parser.add_argument('--frame_save_root', type=str, default='../datasets/frames4', help="root to save frames") + parser.add_argument('--tri_save_root', type=str, default='../datasets/frames4_triplet/sequences', help="root to save triplets") + parser.add_argument('--dup_info', type=str, default='../datasets/dup_video4.txt', help="duplicated frames information") + parser.add_argument('--flow_save_root', type=str, default='../datasets/frames4_unimatch_flow/sequences', help="root to save UniMatch optical flows") + + # UniMatch parameters + parser.add_argument('--total_gpus', type=int, default=2, help="number of CUDA GPUs to use") + parser.add_argument("--start_cuda_index", type=int, default=0, help="starting CUDA GPU index") + parser.add_argument("--cuda_index", type=int, default=0, help="CUDA GPU index") + parser.add_argument("--first_scaling", type=int, default=1, help="downsizing ratio before computing flow") + + args = parser.parse_args() + + os.makedirs(args.frame_save_root, exist_ok=True) + video = cv2.VideoCapture(args.video_path) + num_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT)) + + with open(args.dup_info, 'r') as f: + dup_list = [int(line.strip()) for line in f.readlines() if len(line.strip())>0] + + print('SAVE Frames') + file_num = 0 + for index in tqdm(range(num_frame)): + _, frame = video.read() + if index in dup_list: continue + newfile = os.path.join(args.frame_save_root, str(file_num).zfill(4)+'.png') + cv2.imwrite(newfile, frame) + file_num += 1 + + file_list = sorted(glob.glob(os.path.join(args.frame_save_root, '*.png'))) + + print('SAVE Triplets') + for i, file in enumerate(tqdm(file_list[:-1])): + if i==0: continue + if i==1: + prv = cv2.imread(file_list[i-1]) + cur = cv2.imread(file) + prv_name = file_list[i-1] + cur_name = file_list[i] + nxt = cv2.imread(file_list[i+1]) + nxt_name = file_list[i+1] + + # SAVE + newfolder = os.path.join(args.tri_save_root, str(i).zfill(4)) + os.makedirs(newfolder, exist_ok=True) + cv2.imwrite(newfolder+'/im1.png', prv) + cv2.imwrite(newfolder+'/im2.png', cur) + cv2.imwrite(newfolder+'/im3.png', nxt) + temp = '/'.join(newfolder.split('/')[-2:]) + with open(os.path.join(args.tri_save_root, '..', '..', 'tri_trainlist.txt'), 'w' if i==1 else 'a') as f: + f.writelines(f'{temp}\n') + + prv = cur.copy() + cur = nxt.copy() + prv_name = cur_name + cur_name = nxt_name + + cwd = os.getcwd() + os.chdir('../unimatch_inference') + cmd = f'python3 -m unimatch_inference --total_gpus {args.total_gpus} --start_cuda_index {args.start_cuda_index} --cuda_index {args.cuda_index} \ + --root {os.path.join(args.tri_save_root, "..")} --save_root {args.flow_save_root} --first_scaling {args.first_scaling}' + os.system(cmd) + os.chdir(cwd) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..248cd9130e0fb28285b822911599be59babb54da --- /dev/null +++ b/requirements.txt @@ -0,0 +1,232 @@ +absl-py==2.0.0 +aggregate==0.13.0 +aiofiles==23.2.1 +aiohttp==3.8.6 +aiosignal==1.3.1 +altair==5.2.0 +annotated-types==0.6.0 +anyio==4.2.0 +appdirs==1.4.4 +argon2-cffi==23.1.0 +argon2-cffi-bindings==21.2.0 +arrow==1.3.0 +asttokens==2.4.0 +async-lru==2.0.4 +async-timeout==4.0.3 +attrs==23.1.0 +Babel==2.14.0 +backcall==0.2.0 +beautifulsoup4==4.12.2 +bleach==6.1.0 +cachetools==5.3.1 +certifi==2023.7.22 +cffi==1.16.0 +charset-normalizer==3.3.0 +click==8.1.7 +colorama==0.4.6 +coloredlogs==15.0.1 +comm==0.1.4 +contourpy==1.1.1 +cycler==0.12.1 +debugpy==1.8.0 +decorator==4.4.2 +defusedxml==0.7.1 +docker-pycreds==0.4.0 +easydict==1.11 +einops==0.7.0 +exceptiongroup==1.1.3 +executing==2.0.0 +fastapi==0.110.0 +fastjsonschema==2.19.1 +fastrlock +ffmpy==0.3.2 +filelock==3.12.4 +flatbuffers==23.5.26 +fonttools==4.43.1 +fqdn==1.5.1 +frozenlist==1.4.0 +fsspec==2023.9.2 +gdown==5.1.0 +gitdb==4.0.10 +GitPython==3.1.37 +google-auth==2.23.2 +google-auth-oauthlib==1.0.0 +gradio==4.19.2 +gradio_client==0.10.1 +grpcio==1.59.0 +h11==0.14.0 +httpcore==1.0.4 +httpx==0.27.0 +huggingface-hub==0.21.3 +humanfriendly==10.0 +idna==3.4 +imageio==2.31.5 +imageio-ffmpeg==0.4.9 +importlib-metadata==6.8.0 +importlib-resources==6.1.0 +ipykernel==6.25.2 +ipython==8.16.1 +ipywidgets==8.1.1 +isoduration==20.11.0 +jedi==0.19.1 +Jinja2==3.1.2 +joblib==1.3.2 +json5==0.9.14 +jsonpointer==2.4 +jsonschema==4.20.0 +jsonschema-specifications==2023.12.1 +jupyter==1.0.0 +jupyter-console==6.6.3 +jupyter-events==0.9.0 +jupyter-lsp==2.2.1 +jupyter_client==8.3.1 +jupyter_core==5.3.2 +jupyter_server==2.12.2 +jupyter_server_terminals==0.5.1 +jupyterlab==4.0.10 +jupyterlab-widgets==3.0.9 +jupyterlab_pygments==0.3.0 +jupyterlab_server==2.25.2 +kiwisolver==1.4.5 +lazy_loader==0.3 +lightning-utilities==0.9.0 +loguru==0.7.2 +Markdown==3.5 +markdown-it-py==3.0.0 +MarkupSafe==2.1.3 +matplotlib==3.8.0 +matplotlib-inline==0.1.6 +mdurl==0.1.2 +mistune==3.0.2 +moviepy==1.0.3 +mpmath==1.3.0 +multidict==6.0.4 +nbclient==0.9.0 +nbconvert==7.14.0 +nbformat==5.9.2 +nest-asyncio==1.5.8 +networkx==3.1 +notebook==7.0.6 +notebook_shim==0.2.3 +numpy +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==8.9.2.26 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.18.1 +nvidia-nvjitlink-cu12==12.2.140 +nvidia-nvtx-cu12==12.1.105 +oauthlib==3.2.2 +onnx==1.15.0 +onnxruntime==1.16.3 +opencv-python==4.8.1.78 +orjson==3.9.15 +overrides==7.4.0 +packaging==23.2 +pandas==2.1.1 +pandocfilters==1.5.0 +parso==0.8.3 +pathtools==0.1.2 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==10.0.1 +platformdirs==3.11.0 +proglog==0.1.10 +progressbar2==4.2.0 +prometheus-client==0.19.0 +prompt-toolkit==3.0.39 +protobuf==4.24.4 +psutil==5.9.5 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pyasn1==0.5.0 +pyasn1-modules==0.3.0 +pycocotools==2.0.7 +pycparser==2.21 +pydantic==2.6.3 +pydantic_core==2.16.3 +pydub==0.25.1 +Pygments==2.16.1 +pyparsing==3.1.1 +PySocks==1.7.1 +python-dateutil==2.8.2 +python-json-logger==2.0.7 +python-multipart==0.0.9 +python-utils==3.8.1 +pytorch-lightning==2.0.9.post0 +pytz==2023.3.post1 +PyYAML==6.0.1 +pyzmq==25.1.1 +qtconsole==5.5.1 +QtPy==2.4.1 +referencing==0.32.1 +requests==2.31.0 +requests-oauthlib==1.3.1 +rfc3339-validator==0.1.4 +rfc3986-validator==0.1.1 +rich==13.7.1 +rpds-py==0.16.2 +rsa==4.9 +ruff==0.3.0 +safetensors==0.4.0 +scikit-image==0.22.0 +scikit-learn==1.3.1 +scikit-video==1.1.11 +scipy==1.11.3 +segment-anything +semantic-version==2.10.0 +Send2Trash==1.8.2 +sentry-sdk==1.31.0 +setproctitle==1.3.3 +shellingham==1.5.4 +six==1.16.0 +sly==0.5 +smmap==5.0.1 +sniffio==1.3.0 +soupsieve==2.5 +stack-data==0.6.3 +starlette==0.36.3 +sympy==1.12 +tensorboard==2.14.1 +tensorboard-data-server==0.7.1 +tensorboardX==2.6.2.2 +terminado==0.18.0 +threadpoolctl==3.2.0 +tifffile==2023.9.26 +timm==0.9.7 +tinycss2==1.2.1 +titlecase==2.4.1 +tomli==2.0.1 +tomlkit==0.12.0 +toolz==0.12.1 +torch==2.1.0 +torchmetrics==1.2.0 +torchvision==0.16.0 +tornado==6.3.3 +tqdm==4.66.1 +traitlets==5.11.2 +triton==2.1.0 +typer==0.9.0 +types-python-dateutil==2.8.19.20240106 +typing_extensions==4.8.0 +tzdata==2023.3 +uri-template==1.3.0 +urllib3==2.0.6 +uvicorn==0.27.1 +wandb==0.15.12 +wcwidth==0.2.8 +webcolors==1.13 +webencodings==0.5.1 +websocket-client==1.7.0 +websockets==11.0.3 +Werkzeug==3.0.0 +widgetsnbextension==4.0.9 +yacs==0.1.8 +yarl==1.9.2 +zipp==3.17.0 +cupy-cuda11x \ No newline at end of file diff --git a/trainer/__init__.py b/trainer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0a8a4a91724515aee0aecd8217cfe16ee5ec80 --- /dev/null +++ b/trainer/__init__.py @@ -0,0 +1 @@ +from .trainer import * diff --git a/trainer/__pycache__/__init__.cpython-310.pyc b/trainer/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45c813905d903d413daa0c5ed5feee3b9cac6344 Binary files /dev/null and b/trainer/__pycache__/__init__.cpython-310.pyc differ diff --git a/trainer/__pycache__/__init__.cpython-38.pyc b/trainer/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a745b029335c82caac99eed075d664dbe9701fbd Binary files /dev/null and b/trainer/__pycache__/__init__.cpython-38.pyc differ diff --git a/trainer/__pycache__/trainer.cpython-310.pyc b/trainer/__pycache__/trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bfaf6bfb02f272c0d58095525fdca822619c606 Binary files /dev/null and b/trainer/__pycache__/trainer.cpython-310.pyc differ diff --git a/trainer/__pycache__/trainer.cpython-38.pyc b/trainer/__pycache__/trainer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..474095d744321800c63c1d4d54164c6fa3d9c8c0 Binary files /dev/null and b/trainer/__pycache__/trainer.cpython-38.pyc differ diff --git a/trainer/trainer.py b/trainer/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..836432bc96426ed3d48ce30ddae0f40ccc88b1c2 --- /dev/null +++ b/trainer/trainer.py @@ -0,0 +1,187 @@ +import torch +import logging +import os +import os.path as osp +import time +import datetime +import random +import wandb +import yaml +import json +import numpy as np + +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DistributedSampler, DataLoader + +import utils +from utils.misc import print_cuda_statistics, is_main_process, get_rank, get_world_size +import datasets +import modules.models as models + + +class Trainer(object): + """ + Wrapper for training, more related to engineering than research code + """ + + def __init__(self, cfgs): + self.rank = get_rank() + self.cfgs = cfgs + self.is_master = (self.rank == 0) + self.is_train = False + + env = cfgs['env'] + self.tot_gpus = get_world_size() + self.distributed = (get_world_size() > 1) + + # Setup log, tensorboard, wandb + if self.is_master: + logger = utils.misc.set_save_dir(cfgs['log_dir'], cfgs["run_description"], replace=False) + with open(osp.join(cfgs['cfg_dir'], f'cfg_{cfgs["run_description"]}.yaml'), 'w') as f: + yaml.dump(cfgs, f, sort_keys=False) + + self.log = logger.info + + self.enable_tb = True + + if env['wandb_upload']: + self.enable_wandb = True + self.cfgs['enable_wandb'] = True + with open('wandb.yaml', 'r') as f: + wandb_cfg = yaml.load(f, Loader=yaml.FullLoader) + os.environ['WANDB_DIR'] = env['save_dir'] + os.environ['WANDB_NAME'] = env['exp_name'] + os.environ['WANDB_API_KEY'] = wandb_cfg['api_key'] + if 'resume' in self.cfgs: + with open(os.path.join(env['save_dir'], 'wandb', 'wandb-resume.json')) as f: + run_id = json.load(f)['run_id'] + wandb.init(id=run_id, resume="allow") + else: + wandb.init(project=wandb_cfg['project'], entity=wandb_cfg['entity'], config=cfgs, name=env['exp_name'], resume='allow') + else: + self.enable_wandb = False + self.cfgs['enable_wandb'] = False + else: + self.log = lambda *args, **kwargs: None + self.enable_tb = False + self.enable_wandb = False + + self.make_datasets() + self.model = models.make(cfgs) + self.start_epoch = 0 + self.end_epoch = self.cfgs['max_epoch'] + if 'resume' in self.cfgs: + self.start_epoch = self.model.load_checkpoint(self.cfgs['resume']) + if 'pretrained' in self.cfgs: + self.start_epoch = self.model.load_pretrained(self.cfgs['pretrained']) + + def make_datasets(self): + """ + By default, train dataset performs shuffle and drop_last. + Distributed sampler will extend the dataset with a prefix to make the length divisible by tot_gpus, samplers should be stored in .dist_samplers. + + Cfg example: + + train/test_dataset: + name: + args: + loader: {batch_size: , num_workers: } + """ + cfgs = self.cfgs + self.dist_samplers = [] + + def make_distributed_loader(dataset, batch_size, num_workers, shuffle=False, drop_last=False): + sampler = DistributedSampler(dataset, shuffle=shuffle) if self.distributed else None + loader = DataLoader( + dataset, + batch_size // self.tot_gpus, + drop_last=drop_last, + sampler=sampler, + shuffle=(shuffle and (sampler is None)), + num_workers=num_workers // self.tot_gpus, + pin_memory=True) + return loader, sampler + + if cfgs.get('train_dataset') is not None: + train_dataset = datasets.make(cfgs['train_dataset']) + self.log(f'Train dataset: len={len(train_dataset)}') + l = cfgs['train_dataset']['loader'] + self.train_loader, train_sampler = make_distributed_loader( + train_dataset, l['batch_size'], l['num_workers'], shuffle=True, drop_last=True) + self.dist_samplers.append(train_sampler) + self.cfgs['lr_scheduler']['args']['total_steps'] = len(self.train_loader) * self.cfgs['max_epoch'] + + if cfgs.get('test_dataset') is not None and self.is_master: + test_dataset = datasets.make(cfgs['test_dataset']) + self.log(f'Test dataset: len={len(test_dataset)}') + l = cfgs['test_dataset']['loader'] + self.test_loader = DataLoader(test_dataset, l['batch_size'], drop_last=False, shuffle=False, num_workers=l['num_workers'], pin_memory=True) +# self.test_loader, test_sampler = make_distributed_loader( +# test_dataset, l['batch_size'], l['num_workers'], shuffle=False, drop_last=False) +# self.dist_samplers.append(test_sampler) + if cfgs.get('demo_dataset') is not None: + self.demo_root = self.cfgs['demo_dataset']['args']['root_path'] + + def train(self): + print("Start training") + start_time = time.time() + self.is_train = True + self.model.init_training_logger() + self.best_performance = 0 + # torch.backends.cudnn.benchmark = True +# if is_main_process(): +# epoch = -1 +# performance = self.validate() +# if performance > self.best_performance: +# self.best_performance = performance +# self.model.save_checkpoint('model_{}.pth'.format(epoch + 1), is_best=1) +# self.log( +# "best performance achieved at epoch {} with performance of {}".format(epoch, +# self.best_performance)) + print(f'@@@@@@@@@@@@@@@@@@@@@@@{self.start_epoch} {self.end_epoch}@@@@@@@@@@@@@@@@@@@@@@@@@@@') + for epoch in range(self.start_epoch, self.end_epoch): + if self.cfgs['distributed']: + self.train_loader.batch_sampler.sampler.set_epoch(epoch) + + random.seed(self.cfgs['seed'] + epoch) + np.random.seed(self.cfgs['seed'] + epoch) + torch.random.manual_seed(self.cfgs['seed'] + epoch) + torch.manual_seed(self.cfgs['seed'] + epoch) + torch.cuda.manual_seed_all(self.cfgs['seed'] + epoch) + + self.model.train_one_epoch(self.train_loader, epoch) + +# if ((epoch + 1) % self.cfgs['validate_every']) == 0: +# if is_main_process(): +# performance = self.validate() +# if performance > self.best_performance: +# self.best_performance = performance +# self.model.save_checkpoint('model_{}.pth'.format(epoch + 1), is_best=1) +# self.log( +# "best performance achieved at epoch {} with performance of {}".format(epoch, +# self.best_performance)) + + if ((epoch + 1) % self.cfgs['save_every']) == 0 and is_main_process(): + self.model.save_checkpoint('model_{}.pth'.format(epoch + 1)) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('Training time {}'.format(total_time_str)) + self.finalize_training() + + def validate(self): + # return performance to save the best model, if there is no performance measure e.g. GAN then just return 0 + if not self.is_train: # if mode == validation only + self.model.init_validation_logger() + return self.model.validate(self.test_loader) + + def test(self): + self.model.init_testing_logger() + self.model.validate(self.test_loader) + + def demo(self): + self.model.init_demo_logger() + self.model.demo(self.demo_root) + + def finalize_training(self): + self.model.finalize_training() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45d3b66b8857cbbd8e8bce38dc2117ada1bc536c --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1 @@ +from .misc import * \ No newline at end of file diff --git a/utils/__pycache__/__init__.cpython-310.pyc b/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c95109c3e58ccbee4e5a3ac8143f36a446714d21 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94c29534547f02a00da17a93899fa34a9260db60 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/__init__.cpython-39.pyc b/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..644ec79e0ae5a201d159763d7a8499f488088217 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/__pycache__/experiment.cpython-310.pyc b/utils/__pycache__/experiment.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cab7926e7aa1582f8a9a696312bea95b87548a0 Binary files /dev/null and b/utils/__pycache__/experiment.cpython-310.pyc differ diff --git a/utils/__pycache__/experiment.cpython-38.pyc b/utils/__pycache__/experiment.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef522d6562de674e25f116e447dc6d35c81a259e Binary files /dev/null and b/utils/__pycache__/experiment.cpython-38.pyc differ diff --git a/utils/__pycache__/flow_visualization.cpython-310.pyc b/utils/__pycache__/flow_visualization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..484db11b9fe284e97bab51a633b5b332eb2f3f25 Binary files /dev/null and b/utils/__pycache__/flow_visualization.cpython-310.pyc differ diff --git a/utils/__pycache__/flow_visualization.cpython-38.pyc b/utils/__pycache__/flow_visualization.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c0f79f71e6ebc69b829beb116f8755e7d06b0fe Binary files /dev/null and b/utils/__pycache__/flow_visualization.cpython-38.pyc differ diff --git a/utils/__pycache__/flowvis.cpython-310.pyc b/utils/__pycache__/flowvis.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13544877df450dd83a7a6534a71b343a00a25aff Binary files /dev/null and b/utils/__pycache__/flowvis.cpython-310.pyc differ diff --git a/utils/__pycache__/flowvis.cpython-38.pyc b/utils/__pycache__/flowvis.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4daba362966c45b2722d016e218a59baec6a43d6 Binary files /dev/null and b/utils/__pycache__/flowvis.cpython-38.pyc differ diff --git a/utils/__pycache__/flowvis.cpython-39.pyc b/utils/__pycache__/flowvis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b3ae279c45e45ed7d42df026efcefd336c6a7c5 Binary files /dev/null and b/utils/__pycache__/flowvis.cpython-39.pyc differ diff --git a/utils/__pycache__/metrics.cpython-310.pyc b/utils/__pycache__/metrics.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da5e25e2f21942ae39dc437cdf207b2ab89ee542 Binary files /dev/null and b/utils/__pycache__/metrics.cpython-310.pyc differ diff --git a/utils/__pycache__/metrics.cpython-38.pyc b/utils/__pycache__/metrics.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e03708c03e7d07ea88a95e655c4fa9667675a54d Binary files /dev/null and b/utils/__pycache__/metrics.cpython-38.pyc differ diff --git a/utils/__pycache__/metrics.cpython-39.pyc b/utils/__pycache__/metrics.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dda058a52dcc76753091f1e0beae02e2a2892937 Binary files /dev/null and b/utils/__pycache__/metrics.cpython-39.pyc differ diff --git a/utils/__pycache__/misc.cpython-310.pyc b/utils/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e88ed6e1250d87108dbebed01de93233829075d Binary files /dev/null and b/utils/__pycache__/misc.cpython-310.pyc differ diff --git a/utils/__pycache__/misc.cpython-38.pyc b/utils/__pycache__/misc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5998871a452a29f69b0b8fb0674e6754b70f6af4 Binary files /dev/null and b/utils/__pycache__/misc.cpython-38.pyc differ diff --git a/utils/__pycache__/misc.cpython-39.pyc b/utils/__pycache__/misc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36e007953f0187fa46a8a93144acce6680d36733 Binary files /dev/null and b/utils/__pycache__/misc.cpython-39.pyc differ diff --git a/utils/__pycache__/padder.cpython-310.pyc b/utils/__pycache__/padder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e5f95ea9c980d05301314309d4969e2293d7fe Binary files /dev/null and b/utils/__pycache__/padder.cpython-310.pyc differ diff --git a/utils/__pycache__/padder.cpython-38.pyc b/utils/__pycache__/padder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3cc78361643bf759e8d00fbf993e43e064b9c3a Binary files /dev/null and b/utils/__pycache__/padder.cpython-38.pyc differ diff --git a/utils/__pycache__/padder.cpython-39.pyc b/utils/__pycache__/padder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8562dd8bc5919368067035e919575c1430dd9517 Binary files /dev/null and b/utils/__pycache__/padder.cpython-39.pyc differ diff --git a/utils/__pycache__/plot.cpython-310.pyc b/utils/__pycache__/plot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beb9f60ce746f471ff8e1cf3a4a83a8f9cfd87d0 Binary files /dev/null and b/utils/__pycache__/plot.cpython-310.pyc differ diff --git a/utils/__pycache__/plot.cpython-38.pyc b/utils/__pycache__/plot.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd505e0b19e02a6efad03ff22bec7cd4e33853c6 Binary files /dev/null and b/utils/__pycache__/plot.cpython-38.pyc differ diff --git a/utils/__pycache__/plot.cpython-39.pyc b/utils/__pycache__/plot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9ec629eb6f8bdfec61e0c7ef7bc867315fd15f2 Binary files /dev/null and b/utils/__pycache__/plot.cpython-39.pyc differ diff --git a/utils/download_xiph.py b/utils/download_xiph.py new file mode 100644 index 0000000000000000000000000000000000000000..cc283e26486d7743e2d33499a726ce0e76ef9be4 --- /dev/null +++ b/utils/download_xiph.py @@ -0,0 +1,32 @@ +import os +import os.path as osp +import glob + + +if __name__ == "__main__": + root = "../data/xiph" + + if not os.path.exists(root): + os.mkdir(root) + ############################################# Prepare Dataset ############################################# + download_links = [ + 'https://media.xiph.org/video/derf/ElFuente/Netflix_BoxingPractice_4096x2160_60fps_10bit_420.y4m', + 'https://media.xiph.org/video/derf/ElFuente/Netflix_Crosswalk_4096x2160_60fps_10bit_420.y4m', + 'https://media.xiph.org/video/derf/Chimera/Netflix_DrivingPOV_4096x2160_60fps_10bit_420.y4m', + 'https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket_4096x2160_60fps_10bit_420.y4m', + 'https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket2_4096x2160_60fps_10bit_420.y4m', + 'https://media.xiph.org/video/derf/ElFuente/Netflix_RitualDance_4096x2160_60fps_10bit_420.y4m', + 'https://media.xiph.org/video/derf/ElFuente/Netflix_SquareAndTimelapse_4096x2160_60fps_10bit_420.y4m', + 'https://media.xiph.org/video/derf/ElFuente/Netflix_Tango_4096x2160_60fps_10bit_420.y4m', + ] + file_list = ['BoxingPractice', 'Crosswalk', 'DrivingPOV', 'FoodMarket', 'FoodMarket2', 'RitualDance', + 'SquareAndTimelapse', 'Tango'] + + for file_name, link in zip(file_list, download_links): + data_dir = osp.join(root, file_name) + if osp.exists(data_dir) is False: + os.makedirs(data_dir) + if len(glob.glob(f'{data_dir}/*.png')) < 100: + ffmpeg_path = "~/anaconda3/bin/ffmpeg" + os.system(f'{ffmpeg_path} -i {link} -pix_fmt rgb24 -vframes 100 {data_dir}/%03d.png') + # ############################################### Prepare End ############################################### diff --git a/utils/experiment.py b/utils/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..1279c103fcbd92bf74ae75f9b5536267660d705a --- /dev/null +++ b/utils/experiment.py @@ -0,0 +1,76 @@ +""" +Experiment related stuffs +Act as a bridge between main and utils (logging, init directory, etc) +""" +from pathlib import Path +import os +import random +import numpy as np +import cupyx.distributed + +import torch.distributed as dist +import torch + + +def init_experiment(cfgs): + """ + in: + cfgs: arguments such as hyperparameters and other + out: + --cfgs + procedure to initialize experiment consisting of: + - parse config file as a json dictionary + - initialize logging + - create dictionary to save everything + """ + + assert 'exp_name' in cfgs + + cfgs['summary_dir'] = os.path.join(cfgs['env']['save_dir'], "summaries") + cfgs['checkpoint_dir'] = os.path.join(cfgs['env']['save_dir'], "checkpoints") + cfgs['output_dir'] = os.path.join(cfgs['env']['save_dir'], "output") + cfgs['log_dir'] = os.path.join(cfgs['env']['save_dir'], "logs") + cfgs['cfg_dir'] = os.path.join(cfgs['env']['save_dir'], "cfgs") + mode = cfgs["mode"] + dataset = cfgs[f"{mode}_dataset"]['name'] + split = cfgs[f"{mode}_dataset"]['args']['split'] + cfgs['run_description'] = f'{mode}_{dataset}_{split}' + + Path(cfgs['summary_dir']).mkdir(parents=True, exist_ok=True) + Path(cfgs['checkpoint_dir']).mkdir(parents=True, exist_ok=True) + Path(cfgs['output_dir']).mkdir(parents=True, exist_ok=True) + Path(cfgs['log_dir']).mkdir(parents=True, exist_ok=True) + Path(cfgs['cfg_dir']).mkdir(parents=True, exist_ok=True) + + +def init_deterministic(random_seed=7): + random.seed(random_seed) + np.random.seed(random_seed) + torch.random.manual_seed(random_seed) + torch.manual_seed(random_seed) + torch.cuda.manual_seed_all(random_seed) + torch.backends.cudnn.benchmark = True + + +def init_distributed_mode(cfgs): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + cfgs['rank'] = int(os.environ["RANK"]) + cfgs['world_size'] = int(os.environ['WORLD_SIZE']) + cfgs['gpu'] = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + cfgs['rank'] = int(os.environ['SLURM_PROCID']) + cfgs['gpu'] = cfgs['rank'] % torch.cuda.device_count() + else: + print('Not using distributed mode') + cfgs['distributed'] = False + return + + cfgs['distributed'] = True + torch.cuda.set_device(cfgs['gpu']) + cfgs['dist_backend'] = 'nccl' + print('| distributed init (rank {}): {}'.format( + cfgs['rank'], cfgs['dist_url']), flush=True) + dist.init_process_group(backend=cfgs['dist_backend'], init_method=cfgs['dist_url'], + world_size=cfgs['world_size'], rank=cfgs['rank']) +# cupyx.distributed.NCCLBackend(n_devices=cfgs['world_size'], rank=cfgs['rank']) + dist.barrier() diff --git a/utils/flow_generation/__init__.py b/utils/flow_generation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/flow_generation/gen_flow.py b/utils/flow_generation/gen_flow.py new file mode 100644 index 0000000000000000000000000000000000000000..fed7676d91fe62bce8fa66293cea601a5fa700bb --- /dev/null +++ b/utils/flow_generation/gen_flow.py @@ -0,0 +1,277 @@ +import os +import sys +import torch +import argparse +import numpy as np +import os.path as osp +import re +from imageio import imread, imwrite +import torch.nn.functional as F + +sys.path.append('.') +from utils.flow_generation.liteflownet.run import estimate + + + +def read(file): + if file.endswith('.float3'): return readFloat(file) + elif file.endswith('.flo'): return readFlow(file) + elif file.endswith('.ppm'): return readImage(file) + elif file.endswith('.pgm'): return readImage(file) + elif file.endswith('.png'): return readImage(file) + elif file.endswith('.jpg'): return readImage(file) + elif file.endswith('.pfm'): return readPFM(file)[0] + else: raise Exception('don\'t know how to read %s' % file) + + +def write(file, data): + if file.endswith('.float3'): return writeFloat(file, data) + elif file.endswith('.flo'): return writeFlow(file, data) + elif file.endswith('.ppm'): return writeImage(file, data) + elif file.endswith('.pgm'): return writeImage(file, data) + elif file.endswith('.png'): return writeImage(file, data) + elif file.endswith('.jpg'): return writeImage(file, data) + elif file.endswith('.pfm'): return writePFM(file, data) + else: raise Exception('don\'t know how to write %s' % file) + + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == 'PF': + color = True + elif header.decode("ascii") == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + endian = '<' + scale = -scale + else: + endian = '>' + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + + +def writePFM(file, image, scale=1): + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: + color = True + elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: + color = False + else: + raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write('PF\n' if color else 'Pf\n'.encode()) + file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write('%f\n'.encode() % scale) + + image.tofile(file) + + +def readFlow(name): + if name.endswith('.pfm') or name.endswith('.PFM'): + return readPFM(name)[0][:,:,0:2] + + f = open(name, 'rb') + + header = f.read(4) + if header.decode("utf-8") != 'PIEH': + raise Exception('Flow file header does not contain PIEH') + + width = np.fromfile(f, np.int32, 1).squeeze() + height = np.fromfile(f, np.int32, 1).squeeze() + + flow = np.fromfile(f, np.float32, width * height * 2).reshape((height, width, 2)) + + return flow.astype(np.float32) + + +def readImage(name): + if name.endswith('.pfm') or name.endswith('.PFM'): + data = readPFM(name)[0] + if len(data.shape)==3: + return data[:,:,0:3] + else: + return data + return imread(name) + + +def writeImage(name, data): + if name.endswith('.pfm') or name.endswith('.PFM'): + return writePFM(name, data, 1) + return imwrite(name, data) + + +def writeFlow(name, flow): + f = open(name, 'wb') + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + flow = flow.astype(np.float32) + flow.tofile(f) + + +def readFloat(name): + f = open(name, 'rb') + + if(f.readline().decode("utf-8")) != 'float\n': + raise Exception('float file %s did not contain keyword' % name) + + dim = int(f.readline()) + + dims = [] + count = 1 + for i in range(0, dim): + d = int(f.readline()) + dims.append(d) + count *= d + + dims = list(reversed(dims)) + + data = np.fromfile(f, np.float32, count).reshape(dims) + if dim > 2: + data = np.transpose(data, (2, 1, 0)) + data = np.transpose(data, (1, 0, 2)) + + return data + + +def writeFloat(name, data): + f = open(name, 'wb') + + dim=len(data.shape) + if dim>3: + raise Exception('bad float file dimension: %d' % dim) + + f.write(('float\n').encode('ascii')) + f.write(('%d\n' % dim).encode('ascii')) + + if dim == 1: + f.write(('%d\n' % data.shape[0]).encode('ascii')) + else: + f.write(('%d\n' % data.shape[1]).encode('ascii')) + f.write(('%d\n' % data.shape[0]).encode('ascii')) + for i in range(2, dim): + f.write(('%d\n' % data.shape[i]).encode('ascii')) + + data = data.astype(np.float32) + if dim==2: + data.tofile(f) + + else: + np.transpose(data, (2, 0, 1)).tofile(f) + + +def check_dim_and_resize(tensor_list): + shape_list = [] + for t in tensor_list: + shape_list.append(t.shape[2:]) + + if len(set(shape_list)) > 1: + desired_shape = shape_list[0] + print(f'Inconsistent size of input video frames. All frames will be resized to {desired_shape}') + + resize_tensor_list = [] + for t in tensor_list: + resize_tensor_list.append(torch.nn.functional.interpolate(t, size=tuple(desired_shape), mode='bilinear')) + + tensor_list = resize_tensor_list + + return tensor_list + +parser = argparse.ArgumentParser( + prog = 'AMT', + description = 'Flow generation', + ) +parser.add_argument('-r', '--root', default='../data/vimeo_triplet') +args = parser.parse_args() + +vimeo90k_dir = args.root +vimeo90k_sequences_dir = osp.join(vimeo90k_dir, 'sequences') +vimeo90k_flow_dir = osp.join(vimeo90k_dir, 'flow') + +def pred_flow(img1, img2): + img1 = torch.from_numpy(img1).float().permute(2, 0, 1) / 255.0 + img2 = torch.from_numpy(img2).float().permute(2, 0, 1) / 255.0 + + flow = estimate(img1, img2) + + flow = flow.permute(1, 2, 0).cpu().numpy() + return flow + +print('Built Flow Path') +if not osp.exists(vimeo90k_flow_dir): + os.makedirs(vimeo90k_flow_dir) + +for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)): + vimeo90k_sequences_path_dir = osp.join(vimeo90k_sequences_dir, sequences_path) + vimeo90k_flow_path_dir = osp.join(vimeo90k_flow_dir, sequences_path) + if not osp.exists(vimeo90k_flow_path_dir): + os.mkdir(vimeo90k_flow_path_dir) + + for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)): + vimeo90k_flow_id_dir = osp.join(vimeo90k_flow_path_dir, sequences_id) + if not osp.exists(vimeo90k_flow_id_dir): + os.mkdir(vimeo90k_flow_id_dir) + +for sequences_path in sorted(os.listdir(vimeo90k_sequences_dir)): + vimeo90k_sequences_path_dir = os.path.join(vimeo90k_sequences_dir, sequences_path) + vimeo90k_flow_path_dir = os.path.join(vimeo90k_flow_dir, sequences_path) + + for sequences_id in sorted(os.listdir(vimeo90k_sequences_path_dir)): + vimeo90k_sequences_id_dir = os.path.join(vimeo90k_sequences_path_dir, sequences_id) + vimeo90k_flow_id_dir = os.path.join(vimeo90k_flow_path_dir, sequences_id) + + img0_path = vimeo90k_sequences_id_dir + '/im1.png' + imgt_path = vimeo90k_sequences_id_dir + '/im2.png' + img1_path = vimeo90k_sequences_id_dir + '/im3.png' + flow_t0_path = vimeo90k_flow_id_dir + '/flow_t0.flo' + flow_t1_path = vimeo90k_flow_id_dir + '/flow_t1.flo' + + img0 = read(img0_path) + imgt = read(imgt_path) + img1 = read(img1_path) + + flow_t0 = pred_flow(imgt, img0) + flow_t1 = pred_flow(imgt, img1) + + write(flow_t0_path, flow_t0) + write(flow_t1_path, flow_t1) + + print('Written Sequences {}'.format(sequences_path)) + + diff --git a/utils/flow_generation/liteflownet/README.md b/utils/flow_generation/liteflownet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9511ad984f0209048ad912250b611f8e0459668b --- /dev/null +++ b/utils/flow_generation/liteflownet/README.md @@ -0,0 +1,45 @@ +# pytorch-liteflownet +This is a personal reimplementation of LiteFlowNet [1] using PyTorch. Should you be making use of this work, please cite the paper accordingly. Also, make sure to adhere to the licensing terms of the authors. Should you be making use of this particular implementation, please acknowledge it appropriately [2]. + +Paper + +For the original Caffe version of this work, please see: https://github.com/twhui/LiteFlowNet +
+Other optical flow implementations from me: [pytorch-pwc](https://github.com/sniklaus/pytorch-pwc), [pytorch-unflow](https://github.com/sniklaus/pytorch-unflow), [pytorch-spynet](https://github.com/sniklaus/pytorch-spynet) + +## setup +The correlation layer is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided [binary packages](https://docs.cupy.dev/en/stable/install.html#installing-cupy) as outlined in the CuPy repository. If you would like to use Docker, you can take a look at [this](https://github.com/sniklaus/pytorch-liteflownet/pull/43) pull request to get started. + +## usage +To run it on your own pair of images, use the following command. You can choose between three models, please make sure to see their paper / the code for more details. + +``` +python run.py --model default --one ./images/one.png --two ./images/two.png --out ./out.flo +``` + +I am afraid that I cannot guarantee that this reimplementation is correct. However, it produced results pretty much identical to the implementation of the original authors in the examples that I tried. There are some numerical deviations that stem from differences in the `DownsampleLayer` of Caffe and the `torch.nn.functional.interpolate` function of PyTorch. Please feel free to contribute to this repository by submitting issues and pull requests. + +## comparison +

Comparison

+ +## license +As stated in the licensing terms of the authors of the paper, their material is provided for research purposes only. Please make sure to further consult their licensing terms. + +## references +``` +[1] @inproceedings{Hui_CVPR_2018, + author = {Tak-Wai Hui and Xiaoou Tang and Chen Change Loy}, + title = {{LiteFlowNet}: A Lightweight Convolutional Neural Network for Optical Flow Estimation}, + booktitle = {IEEE Conference on Computer Vision and Pattern Recognition}, + year = {2018} + } +``` + +``` +[2] @misc{pytorch-liteflownet, + author = {Simon Niklaus}, + title = {A Reimplementation of {LiteFlowNet} Using {PyTorch}}, + year = {2019}, + howpublished = {\url{https://github.com/sniklaus/pytorch-liteflownet}} + } +``` \ No newline at end of file diff --git a/utils/flow_generation/liteflownet/__init__.py b/utils/flow_generation/liteflownet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/flow_generation/liteflownet/correlation/README.md b/utils/flow_generation/liteflownet/correlation/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e80f923bfa484ff505366c30f66fa88da0bfd566 --- /dev/null +++ b/utils/flow_generation/liteflownet/correlation/README.md @@ -0,0 +1 @@ +This is an adaptation of the FlowNet2 implementation in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the licensing terms of the original authors. Should you be making use or modify this particular implementation, please acknowledge it appropriately. \ No newline at end of file diff --git a/utils/flow_generation/liteflownet/correlation/correlation.py b/utils/flow_generation/liteflownet/correlation/correlation.py new file mode 100644 index 0000000000000000000000000000000000000000..212af7103a8bffd024cf7e8e43c4a96997157f53 --- /dev/null +++ b/utils/flow_generation/liteflownet/correlation/correlation.py @@ -0,0 +1,396 @@ +#!/usr/bin/env python + +import cupy +import math +import re +import torch + +kernel_Correlation_rearrange = ''' + extern "C" __global__ void kernel_Correlation_rearrange( + const int n, + const float* input, + float* output + ) { + int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; + if (intIndex >= n) { + return; + } + int intSample = blockIdx.z; + int intChannel = blockIdx.y; + float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex]; + __syncthreads(); + int intPaddedY = (intIndex / SIZE_3(input)) + 3*{{intStride}}; + int intPaddedX = (intIndex % SIZE_3(input)) + 3*{{intStride}}; + int intRearrange = ((SIZE_3(input) + 6*{{intStride}}) * intPaddedY) + intPaddedX; + output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue; + } +''' + +kernel_Correlation_updateOutput = ''' + extern "C" __global__ void kernel_Correlation_updateOutput( + const int n, + const float* rbot0, + const float* rbot1, + float* top + ) { + extern __shared__ char patch_data_char[]; + + float *patch_data = (float *)patch_data_char; + + // First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1 + int x1 = (blockIdx.x + 3) * {{intStride}}; + int y1 = (blockIdx.y + 3) * {{intStride}}; + int item = blockIdx.z; + int ch_off = threadIdx.x; + + // Load 3D patch into shared shared memory + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch; + int idxPatchData = ji_off + ch; + patch_data[idxPatchData] = rbot0[idx1]; + } + } + } + + __syncthreads(); + + __shared__ float sum[32]; + + // Compute correlation + for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) { + sum[ch_off] = 0; + + int s2o = (top_channel % 7 - 3) * {{intStride}}; + int s2p = (top_channel / 7 - 3) * {{intStride}}; + + for (int j = 0; j < 1; j++) { // HEIGHT + for (int i = 0; i < 1; i++) { // WIDTH + int ji_off = (j + i) * SIZE_3(rbot0); + for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS + int x2 = x1 + s2o; + int y2 = y1 + s2p; + + int idxPatchData = ji_off + ch; + int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch; + + sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2]; + } + } + } + + __syncthreads(); + + if (ch_off == 0) { + float total_sum = 0; + for (int idx = 0; idx < 32; idx++) { + total_sum += sum[idx]; + } + const int sumelems = SIZE_3(rbot0); + const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x; + top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems; + } + } + } +''' + +kernel_Correlation_updateGradOne = ''' + #define ROUND_OFF 50000 + extern "C" __global__ void kernel_Correlation_updateGradOne( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradOne, + float* gradTwo + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradOne); // channels + int l = (intIndex / SIZE_1(gradOne)) % SIZE_3(gradOne) + 3*{{intStride}}; // w-pos + int m = (intIndex / SIZE_1(gradOne) / SIZE_3(gradOne)) % SIZE_2(gradOne) + 3*{{intStride}}; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = {{intStride}} * round_off; + + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} + int ymin = (m - 3*{{intStride}} + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}}) / {{intStride}} + + // Same here: + int xmax = (l - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}}) / {{intStride}} + int ymax = (m - 3*{{intStride}} + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}}) / {{intStride}} + + float sum = 0; + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + for (int p = -3; p <= 3; p++) { + for (int o = -3; o <= 3; o++) { + // Get rbot1 data: + int s2o = {{intStride}} * o; + int s2p = {{intStride}} * p; + int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n; + float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+3) * 7 + (o+3); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot1tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradOne); + const int bot0index = ((n * SIZE_2(gradOne)) + (m-3*{{intStride}})) * SIZE_3(gradOne) + (l-3*{{intStride}}); + gradOne[bot0index + intSample*SIZE_1(gradOne)*SIZE_2(gradOne)*SIZE_3(gradOne)] = sum / (float)sumelems; + } } +''' + +kernel_Correlation_updateGradTwo = ''' + #define ROUND_OFF 50000 + extern "C" __global__ void kernel_Correlation_updateGradTwo( + const int n, + const int intSample, + const float* rbot0, + const float* rbot1, + const float* gradOutput, + float* gradOne, + float* gradTwo + ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) { + int n = intIndex % SIZE_1(gradTwo); // channels + int l = (intIndex / SIZE_1(gradTwo)) % SIZE_3(gradTwo) + 3*{{intStride}}; // w-pos + int m = (intIndex / SIZE_1(gradTwo) / SIZE_3(gradTwo)) % SIZE_2(gradTwo) + 3*{{intStride}}; // h-pos + + // round_off is a trick to enable integer division with ceil, even for negative numbers + // We use a large offset, for the inner part not to become negative. + const int round_off = ROUND_OFF; + const int round_off_s1 = {{intStride}} * round_off; + + float sum = 0; + for (int p = -3; p <= 3; p++) { + for (int o = -3; o <= 3; o++) { + int s2o = {{intStride}} * o; + int s2p = {{intStride}} * p; + + //Get X,Y ranges and clamp + // We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior: + int xmin = (l - 3*{{intStride}} - s2o + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} + int ymin = (m - 3*{{intStride}} - s2p + round_off_s1 - 1) / {{intStride}} + 1 - round_off; // ceil (l - 3*{{intStride}} - s2o) / {{intStride}} + + // Same here: + int xmax = (l - 3*{{intStride}} - s2o + round_off_s1) / {{intStride}} - round_off; // floor (l - 3*{{intStride}} - s2o) / {{intStride}} + int ymax = (m - 3*{{intStride}} - s2p + round_off_s1) / {{intStride}} - round_off; // floor (m - 3*{{intStride}} - s2p) / {{intStride}} + + if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) { + xmin = max(0,xmin); + xmax = min(SIZE_3(gradOutput)-1,xmax); + + ymin = max(0,ymin); + ymax = min(SIZE_2(gradOutput)-1,ymax); + + // Get rbot0 data: + int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n; + float bot0tmp = rbot0[idxbot0]; // rbot1[l+s2o,m+s2p,n] + + // Index offset for gradOutput in following loops: + int op = (p+3) * 7 + (o+3); // index[o,p] + int idxopoffset = (intSample * SIZE_1(gradOutput) + op); + + for (int y = ymin; y <= ymax; y++) { + for (int x = xmin; x <= xmax; x++) { + int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p] + sum += gradOutput[idxgradOutput] * bot0tmp; + } + } + } + } + } + const int sumelems = SIZE_1(gradTwo); + const int bot1index = ((n * SIZE_2(gradTwo)) + (m-3*{{intStride}})) * SIZE_3(gradTwo) + (l-3*{{intStride}}); + gradTwo[bot1index + intSample*SIZE_1(gradTwo)*SIZE_2(gradTwo)*SIZE_3(gradTwo)] = sum / (float)sumelems; + } } +''' + +def cupy_kernel(strFunction, objVariables): + strKernel = globals()[strFunction].replace('{{intStride}}', str(objVariables['intStride'])) + + while True: + objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel) + + if objMatch is None: + break + # end + + intArg = int(objMatch.group(2)) + + strTensor = objMatch.group(4) + intSizes = objVariables[strTensor].size() + + strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item())) + # end + + while True: + objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel) + + if objMatch is None: + break + # end + + intArgs = int(objMatch.group(2)) + strArgs = objMatch.group(4).split(',') + + strTensor = strArgs[0] + intStrides = objVariables[strTensor].stride() + strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ] + + strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']') + # end + + return strKernel +# end + +@cupy.memoize(for_each_device=True) +def cupy_launch(strFunction, strKernel): + return cupy.cuda.compile_with_cache(strKernel).get_function(strFunction) +# end + +class _FunctionCorrelation(torch.autograd.Function): + @staticmethod + def forward(self, one, two, intStride): + rbot0 = one.new_zeros([ one.shape[0], one.shape[2] + (6 * intStride), one.shape[3] + (6 * intStride), one.shape[1] ]) + rbot1 = one.new_zeros([ one.shape[0], one.shape[2] + (6 * intStride), one.shape[3] + (6 * intStride), one.shape[1] ]) + + self.intStride = intStride + + one = one.contiguous(); assert(one.is_cuda == True) + two = two.contiguous(); assert(two.is_cuda == True) + + output = one.new_zeros([ one.shape[0], 49, int(math.ceil(one.shape[2] / intStride)), int(math.ceil(one.shape[3] / intStride)) ]) + + if one.is_cuda == True: + n = one.shape[2] * one.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'intStride': self.intStride, + 'input': one, + 'output': rbot0 + }))( + grid=tuple([ int((n + 16 - 1) / 16), one.shape[1], one.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ cupy.int32(n), one.data_ptr(), rbot0.data_ptr() ] + ) + + n = two.shape[2] * two.shape[3] + cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', { + 'intStride': self.intStride, + 'input': two, + 'output': rbot1 + }))( + grid=tuple([ int((n + 16 - 1) / 16), two.shape[1], two.shape[0] ]), + block=tuple([ 16, 1, 1 ]), + args=[ cupy.int32(n), two.data_ptr(), rbot1.data_ptr() ] + ) + + n = output.shape[1] * output.shape[2] * output.shape[3] + cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', { + 'intStride': self.intStride, + 'rbot0': rbot0, + 'rbot1': rbot1, + 'top': output + }))( + grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]), + block=tuple([ 32, 1, 1 ]), + shared_mem=one.shape[1] * 4, + args=[ cupy.int32(n), rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ] + ) + + elif one.is_cuda == False: + raise NotImplementedError() + + # end + + self.save_for_backward(one, two, rbot0, rbot1) + + return output + # end + + @staticmethod + def backward(self, gradOutput): + one, two, rbot0, rbot1 = self.saved_tensors + + gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True) + + gradOne = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[0] == True else None + gradTwo = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[1] == True else None + + if one.is_cuda == True: + if gradOne is not None: + for intSample in range(one.shape[0]): + n = one.shape[1] * one.shape[2] * one.shape[3] + cupy_launch('kernel_Correlation_updateGradOne', cupy_kernel('kernel_Correlation_updateGradOne', { + 'intStride': self.intStride, + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradOne': gradOne, + 'gradTwo': None + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradOne.data_ptr(), None ] + ) + # end + # end + + if gradTwo is not None: + for intSample in range(one.shape[0]): + n = one.shape[1] * one.shape[2] * one.shape[3] + cupy_launch('kernel_Correlation_updateGradTwo', cupy_kernel('kernel_Correlation_updateGradTwo', { + 'intStride': self.intStride, + 'rbot0': rbot0, + 'rbot1': rbot1, + 'gradOutput': gradOutput, + 'gradOne': None, + 'gradTwo': gradTwo + }))( + grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]), + block=tuple([ 512, 1, 1 ]), + args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradTwo.data_ptr() ] + ) + # end + # end + + elif one.is_cuda == False: + raise NotImplementedError() + + # end + + return gradOne, gradTwo, None + # end +# end + +def FunctionCorrelation(tenOne, tenTwo, intStride): + return _FunctionCorrelation.apply(tenOne, tenTwo, intStride) +# end + +class ModuleCorrelation(torch.nn.Module): + def __init__(self): + super().__init__() + # end + + def forward(self, tenOne, tenTwo, intStride): + return _FunctionCorrelation.apply(tenOne, tenTwo, intStride) + # end +# end \ No newline at end of file diff --git a/utils/flow_generation/liteflownet/run.py b/utils/flow_generation/liteflownet/run.py new file mode 100644 index 0000000000000000000000000000000000000000..1957621f3bd9cae2651f8767466f5c1542df3299 --- /dev/null +++ b/utils/flow_generation/liteflownet/run.py @@ -0,0 +1,385 @@ +#!/usr/bin/env python + +import getopt +import math +import numpy +import PIL +import PIL.Image +import sys +import torch + +try: + from .correlation import correlation # the custom cost volume layer +except: + sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python +# end + +########################################################## + +assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0 + +torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance + +torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance + +########################################################## + +arguments_strModel = 'default' # 'default', or 'kitti', or 'sintel' +arguments_strOne = './images/one.png' +arguments_strTwo = './images/two.png' +arguments_strOut = './out.flo' + +for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]: + if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use + if strOption == '--one' and strArgument != '': arguments_strOne = strArgument # path to the first frame + if strOption == '--two' and strArgument != '': arguments_strTwo = strArgument # path to the second frame + if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored +# end + +########################################################## + +backwarp_tenGrid = {} + +def backwarp(tenInput, tenFlow): + if str(tenFlow.shape) not in backwarp_tenGrid: + tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1) + tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3]) + + backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda() + # end + + tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1) + + return torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False) +# end + +########################################################## + +class Network(torch.nn.Module): + def __init__(self): + super().__init__() + + class Features(torch.nn.Module): + def __init__(self): + super().__init__() + + self.netOne = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=7, stride=1, padding=3), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netTwo = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netThr = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFou = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netFiv = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + self.netSix = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=128, out_channels=192, kernel_size=3, stride=2, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + # end + + def forward(self, tenInput): + tenOne = self.netOne(tenInput) + tenTwo = self.netTwo(tenOne) + tenThr = self.netThr(tenTwo) + tenFou = self.netFou(tenThr) + tenFiv = self.netFiv(tenFou) + tenSix = self.netSix(tenFiv) + + return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ] + # end + # end + + class Matching(torch.nn.Module): + def __init__(self, intLevel): + super().__init__() + + self.fltBackwarp = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] + + if intLevel != 2: + self.netFeat = torch.nn.Sequential() + + elif intLevel == 2: + self.netFeat = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + # end + + if intLevel == 6: + self.netUpflow = None + + elif intLevel != 6: + self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1, bias=False, groups=2) + + # end + + if intLevel >= 4: + self.netUpcorr = None + + elif intLevel < 4: + self.netUpcorr = torch.nn.ConvTranspose2d(in_channels=49, out_channels=49, kernel_size=4, stride=2, padding=1, bias=False, groups=49) + + # end + + self.netMain = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=49, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) + ) + # end + + def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow): + tenFeaturesOne = self.netFeat(tenFeaturesOne) + tenFeaturesTwo = self.netFeat(tenFeaturesTwo) + + if tenFlow is not None: + tenFlow = self.netUpflow(tenFlow) + # end + + if tenFlow is not None: + tenFeaturesTwo = backwarp(tenInput=tenFeaturesTwo, tenFlow=tenFlow * self.fltBackwarp) + # end + + if self.netUpcorr is None: + tenCorrelation = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenOne=tenFeaturesOne, tenTwo=tenFeaturesTwo, intStride=1), negative_slope=0.1, inplace=False) + + elif self.netUpcorr is not None: + tenCorrelation = self.netUpcorr(torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenOne=tenFeaturesOne, tenTwo=tenFeaturesTwo, intStride=2), negative_slope=0.1, inplace=False)) + + # end + + return (tenFlow if tenFlow is not None else 0.0) + self.netMain(tenCorrelation) + # end + # end + + class Subpixel(torch.nn.Module): + def __init__(self, intLevel): + super().__init__() + + self.fltBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] + + if intLevel != 2: + self.netFeat = torch.nn.Sequential() + + elif intLevel == 2: + self.netFeat = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=1, stride=1, padding=0), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + # end + + self.netMain = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=[ 0, 0, 130, 130, 194, 258, 386 ][intLevel], out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) + ) + # end + + def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow): + tenFeaturesOne = self.netFeat(tenFeaturesOne) + tenFeaturesTwo = self.netFeat(tenFeaturesTwo) + + if tenFlow is not None: + tenFeaturesTwo = backwarp(tenInput=tenFeaturesTwo, tenFlow=tenFlow * self.fltBackward) + # end + + return (tenFlow if tenFlow is not None else 0.0) + self.netMain(torch.cat([ tenFeaturesOne, tenFeaturesTwo, tenFlow ], 1)) + # end + # end + + class Regularization(torch.nn.Module): + def __init__(self, intLevel): + super().__init__() + + self.fltBackward = [ 0.0, 0.0, 10.0, 5.0, 2.5, 1.25, 0.625 ][intLevel] + + self.intUnfold = [ 0, 0, 7, 5, 5, 3, 3 ][intLevel] + + if intLevel >= 5: + self.netFeat = torch.nn.Sequential() + + elif intLevel < 5: + self.netFeat = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=[ 0, 0, 32, 64, 96, 128, 192 ][intLevel], out_channels=128, kernel_size=1, stride=1, padding=0), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + # end + + self.netMain = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=[ 0, 0, 131, 131, 131, 131, 195 ][intLevel], out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), + torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1), + torch.nn.LeakyReLU(inplace=False, negative_slope=0.1) + ) + + if intLevel >= 5: + self.netDist = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=[ 0, 0, 7, 5, 5, 3, 3 ][intLevel], stride=1, padding=[ 0, 0, 3, 2, 2, 1, 1 ][intLevel]) + ) + + elif intLevel < 5: + self.netDist = torch.nn.Sequential( + torch.nn.Conv2d(in_channels=32, out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=([ 0, 0, 7, 5, 5, 3, 3 ][intLevel], 1), stride=1, padding=([ 0, 0, 3, 2, 2, 1, 1 ][intLevel], 0)), + torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], kernel_size=(1, [ 0, 0, 7, 5, 5, 3, 3 ][intLevel]), stride=1, padding=(0, [ 0, 0, 3, 2, 2, 1, 1 ][intLevel])) + ) + + # end + + self.netScaleX = torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=1, kernel_size=1, stride=1, padding=0) + self.netScaleY = torch.nn.Conv2d(in_channels=[ 0, 0, 49, 25, 25, 9, 9 ][intLevel], out_channels=1, kernel_size=1, stride=1, padding=0) + # eny + + def forward(self, tenOne, tenTwo, tenFeaturesOne, tenFeaturesTwo, tenFlow): + tenDifference = ((tenOne - backwarp(tenInput=tenTwo, tenFlow=tenFlow * self.fltBackward)) ** 2).sum(1, True).sqrt().detach() + + tenDist = self.netDist(self.netMain(torch.cat([ tenDifference, tenFlow - tenFlow.view(tenFlow.shape[0], 2, -1).mean(2, True).view(tenFlow.shape[0], 2, 1, 1), self.netFeat(tenFeaturesOne) ], 1))) + tenDist = (tenDist ** 2).neg() + tenDist = (tenDist - tenDist.max(1, True)[0]).exp() + + tenDivisor = tenDist.sum(1, True).reciprocal() + + tenScaleX = self.netScaleX(tenDist * torch.nn.functional.unfold(input=tenFlow[:, 0:1, :, :], kernel_size=self.intUnfold, stride=1, padding=int((self.intUnfold - 1) / 2)).view_as(tenDist)) * tenDivisor + tenScaleY = self.netScaleY(tenDist * torch.nn.functional.unfold(input=tenFlow[:, 1:2, :, :], kernel_size=self.intUnfold, stride=1, padding=int((self.intUnfold - 1) / 2)).view_as(tenDist)) * tenDivisor + + return torch.cat([ tenScaleX, tenScaleY ], 1) + # end + # end + + self.netFeatures = Features() + self.netMatching = torch.nn.ModuleList([ Matching(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) + self.netSubpixel = torch.nn.ModuleList([ Subpixel(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) + self.netRegularization = torch.nn.ModuleList([ Regularization(intLevel) for intLevel in [ 2, 3, 4, 5, 6 ] ]) + + self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-liteflownet/network-' + arguments_strModel + '.pytorch').items() }) + # self.load_state_dict(torch.load('./liteflownet/network-default.pth')) + # end + + def forward(self, tenOne, tenTwo): + tenOne[:, 0, :, :] = tenOne[:, 0, :, :] - 0.411618 + tenOne[:, 1, :, :] = tenOne[:, 1, :, :] - 0.434631 + tenOne[:, 2, :, :] = tenOne[:, 2, :, :] - 0.454253 + + tenTwo[:, 0, :, :] = tenTwo[:, 0, :, :] - 0.410782 + tenTwo[:, 1, :, :] = tenTwo[:, 1, :, :] - 0.433645 + tenTwo[:, 2, :, :] = tenTwo[:, 2, :, :] - 0.452793 + + tenFeaturesOne = self.netFeatures(tenOne) + tenFeaturesTwo = self.netFeatures(tenTwo) + + tenOne = [ tenOne ] + tenTwo = [ tenTwo ] + + for intLevel in [ 1, 2, 3, 4, 5 ]: + tenOne.append(torch.nn.functional.interpolate(input=tenOne[-1], size=(tenFeaturesOne[intLevel].shape[2], tenFeaturesOne[intLevel].shape[3]), mode='bilinear', align_corners=False)) + tenTwo.append(torch.nn.functional.interpolate(input=tenTwo[-1], size=(tenFeaturesTwo[intLevel].shape[2], tenFeaturesTwo[intLevel].shape[3]), mode='bilinear', align_corners=False)) + # end + + tenFlow = None + + for intLevel in [ -1, -2, -3, -4, -5 ]: + tenFlow = self.netMatching[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow) + tenFlow = self.netSubpixel[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow) + tenFlow = self.netRegularization[intLevel](tenOne[intLevel], tenTwo[intLevel], tenFeaturesOne[intLevel], tenFeaturesTwo[intLevel], tenFlow) + # end + + return tenFlow * 20.0 + # end +# end + +netNetwork = None + +########################################################## + +def estimate(tenOne, tenTwo): + global netNetwork + + if netNetwork is None: + netNetwork = Network().cuda().eval() + # end + + assert(tenOne.shape[1] == tenTwo.shape[1]) + assert(tenOne.shape[2] == tenTwo.shape[2]) + + intWidth = tenOne.shape[2] + intHeight = tenOne.shape[1] + + # assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue + # assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue + + tenPreprocessedOne = tenOne.cuda().view(1, 3, intHeight, intWidth) + tenPreprocessedTwo = tenTwo.cuda().view(1, 3, intHeight, intWidth) + + intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 32.0) * 32.0)) + intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 32.0) * 32.0)) + + tenPreprocessedOne = torch.nn.functional.interpolate(input=tenPreprocessedOne, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) + tenPreprocessedTwo = torch.nn.functional.interpolate(input=tenPreprocessedTwo, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False) + + tenFlow = torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedOne, tenPreprocessedTwo), size=(intHeight, intWidth), mode='bilinear', align_corners=False) + + tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth) + tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight) + + return tenFlow[0, :, :, :].cpu() +# end + +########################################################## + +if __name__ == '__main__': + tenOne = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strOne))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) + tenTwo = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strTwo))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0))) + + tenOutput = estimate(tenOne, tenTwo) + + objOutput = open(arguments_strOut, 'wb') + + numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput) + numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput) + numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput) + + objOutput.close() +# end \ No newline at end of file diff --git a/utils/flow_visualization.py b/utils/flow_visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..a02266417c51da767c5d9a670bb6d9728220cba4 --- /dev/null +++ b/utils/flow_visualization.py @@ -0,0 +1,275 @@ +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from PIL import Image + + +def make_colorwheel(): + ''' + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + ''' + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col:col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col:col + GC, 1] = 255 + colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col:col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col:col + BM, 2] = 255 + colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col:col + MR, 0] = 255 + return colorwheel + + +def flow_compute_color(u, v, convert_to_bgr=False): + ''' + Applies the flow color wheel to (possibly clipped) flow components u and v. + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + :param u: np.ndarray, input horizontal flow + :param v: np.ndarray, input vertical flow + :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB + :return: + ''' + + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u) / np.pi + + fk = (a + 1) / 2 * (ncols - 1) + 1 + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 1 + f = fk - k0 + + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range? + + # Note the 2-i => BGR instead of RGB + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + + return flow_image + + +def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): + ''' + Expects a two dimensional flow image of shape [H,W,2] + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + :param flow_uv: np.ndarray of shape [H,W,2] + :param clip_flow: float, maximum clipping value for flow + :return: + ''' + + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + + return flow_compute_color(u, v, convert_to_bgr) + + +UNKNOWN_FLOW_THRESH = 1e7 +SMALLFLOW = 0.0 +LARGEFLOW = 1e8 + + +def make_color_wheel(): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = np.zeros([ncols, 3]) + + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) + col += RY + + # YG + colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) + colorwheel[col:col + YG, 1] = 255 + col += YG + + # GC + colorwheel[col:col + GC, 1] = 255 + colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) + col += GC + + # CB + colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) + colorwheel[col:col + CB, 2] = 255 + col += CB + + # BM + colorwheel[col:col + BM, 2] = 255 + colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) + col += + BM + + # MR + colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) + colorwheel[col:col + MR, 0] = 255 + + return colorwheel + + +def compute_color(u, v): + """ + compute optical flow color map + :param u: optical flow horizontal map + :param v: optical flow vertical map + :return: optical flow in color code + """ + [h, w] = u.shape + img = np.zeros([h, w, 3]) + nanIdx = np.isnan(u) | np.isnan(v) + u[nanIdx] = 0 + v[nanIdx] = 0 + + colorwheel = make_color_wheel() + ncols = np.size(colorwheel, 0) + + rad = np.sqrt(u ** 2 + v ** 2) + + a = np.arctan2(-v, -u) / np.pi + + fk = (a + 1) / 2 * (ncols - 1) + 1 + + k0 = np.floor(fk).astype(int) + + k1 = k0 + 1 + k1[k1 == ncols + 1] = 1 + f = fk - k0 + + for i in range(0, np.size(colorwheel, 1)): + tmp = colorwheel[:, i] + col0 = tmp[k0 - 1] / 255 + col1 = tmp[k1 - 1] / 255 + col = (1 - f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + notidx = np.logical_not(idx) + + col[notidx] *= 0.75 + img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) + + return img + + +# from https://github.com/gengshan-y/VCN +def flow_to_image(flow): + """ + Convert flow into middlebury color code image + :param flow: optical flow map + :return: optical flow image in middlebury color + """ + u = flow[:, :, 0] + v = flow[:, :, 1] + + maxu = -999. + maxv = -999. + minu = 999. + minv = 999. + + idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) + u[idxUnknow] = 0 + v[idxUnknow] = 0 + + maxu = max(maxu, np.max(u)) + minu = min(minu, np.min(u)) + + maxv = max(maxv, np.max(v)) + minv = min(minv, np.min(v)) + + rad = np.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, np.max(rad)) + + u = u / (maxrad + np.finfo(float).eps) + v = v / (maxrad + np.finfo(float).eps) + + img = compute_color(u, v) + + idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) + img[idx] = 0 + + return np.uint8(img) \ No newline at end of file diff --git a/utils/flowvis.py b/utils/flowvis.py new file mode 100644 index 0000000000000000000000000000000000000000..2b044abbb0f93492ecf02f0e60af9a51926b54e0 --- /dev/null +++ b/utils/flowvis.py @@ -0,0 +1,129 @@ +import torch + + +def make_color_wheel(): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = torch.zeros([3, ncols]) + + col = 0 + + # RY + colorwheel[0, 0:RY] = 255 + colorwheel[1, 0:RY] = torch.floor(255 * torch.arange(0, RY) / RY) + col += RY + + # YG + colorwheel[0, col:col + YG] = 255 - torch.floor(255 * torch.arange(0, YG) / YG) + colorwheel[1, col:col + YG] = 255 + col += YG + + # GC + colorwheel[1, col:col + GC] = 255 + colorwheel[2, col:col + GC] = torch.floor(255 * torch.arange(0, GC) / GC) + col += GC + + # CB + colorwheel[1, col:col + CB] = 255 - torch.floor(255 * torch.arange(0, CB) / CB) + colorwheel[2, col:col + CB] = 255 + col += CB + + # BM + colorwheel[2, col:col + BM] = 255 + colorwheel[0, col:col + BM] = torch.floor(255 * torch.arange(0, BM) / BM) + col += + BM + + # MR + colorwheel[2, col:col + MR] = 255 - torch.floor(255 * torch.arange(0, MR) / MR) + colorwheel[0, col:col + MR] = 255 + + return colorwheel + + +colorwheel = make_color_wheel().cuda() + + +def flow2img(flow_data: torch.Tensor): + """ + convert optical flow into color image + :param flow_data: + :return: color image + """ + # print(flow_data.shape) + # print(type(flow_data)) + u = flow_data[:, 0:1, :, :] + v = flow_data[:, 1:2, :, :] + + UNKNOW_FLOW_THRESHOLD = 1e7 + pr1 = torch.abs(u) > UNKNOW_FLOW_THRESHOLD + pr2 = torch.abs(v) > UNKNOW_FLOW_THRESHOLD + idx_unknown = (pr1 | pr2) + u[idx_unknown] = 0 + v[idx_unknown] = 0 + idx_unknown = idx_unknown.repeat(1, 3, 1, 1) + + rad = torch.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, torch.max(rad).item()) + u = u / maxrad + torch.finfo(float).eps + v = v / maxrad + torch.finfo(float).eps + + img = compute_color(u, v) + + img[idx_unknown] = 0 + + return img / 255. + + +def compute_color(u, v): + """ + compute optical flow color map + :param u: horizontal optical flow + :param v: vertical optical flow + :return: + """ + + B, _, H, W = u.shape + img = torch.zeros((B, 3, H, W), device=torch.device('cuda')) + + NAN_idx = torch.isnan(u) | torch.isnan(v) + u[NAN_idx] = v[NAN_idx] = 0 + ncols = colorwheel.shape[1] + + rad = torch.sqrt(u ** 2 + v ** 2) + + a = torch.arctan2(-v, -u) / torch.pi + + fk = (a + 1) / 2 * (ncols - 1) + 1 + + k0 = torch.floor(fk).to(int) + + k1 = k0 + 1 + k1[k1 == ncols + 1] = 1 + f = fk - k0 + + for i in range(0, colorwheel.shape[0]): + tmp = colorwheel[i, :] + col0 = tmp[k0 - 1] / 255 + col1 = tmp[k1 - 1] / 255 + col = (1 - f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + notidx = torch.logical_not(idx) + + col[notidx] *= 0.75 + img[:, i:i+1, :, :] = torch.floor(255 * col * (~NAN_idx)).to(torch.uint8) + + return img + diff --git a/utils/metrics.py b/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..6ee8c4b4bb7a6403c8a9f69236892ca677ac30b5 --- /dev/null +++ b/utils/metrics.py @@ -0,0 +1,49 @@ +import numpy as np + +import torch + +from skimage.metrics import peak_signal_noise_ratio, structural_similarity + + +def calculate_batch_psnr(gt_tensor, output_tensor, mode='avg'): + # both parameters are in the form of tensor of size: BS, C, H, W + + if mode == 'avg': + gt_np = gt_tensor.cpu().numpy().astype(np.float32) + output_np = output_tensor.cpu().numpy().astype(np.float32) + + bs = gt_np.shape[0] + psnr_list = [] + psnr = 0 + for i in range(bs): + gt_im = gt_np[i, :, :, :] + output_im = output_np[i, :, :, :] + + gt_im = gt_im.transpose((1, 2, 0)) + output_im = output_im.transpose((1, 2, 0)) + + psnr_list.append(peak_signal_noise_ratio(gt_im, output_im, data_range=1.)) + psnr += peak_signal_noise_ratio(gt_im, output_im, data_range=1.) + return float(psnr / bs), psnr_list + else: + raise NotImplementedError + + +def calculate_batch_ssim(gt_tensor, output_tensor, mode='avg'): + if mode == 'avg': + gt_np = gt_tensor.cpu().numpy().astype(np.float32) + output_np = output_tensor.cpu().numpy().astype(np.float32) + + bs = gt_np.shape[0] + ssim = 0 + for i in range(bs): + gt_im = gt_np[i, :, :, :] + output_im = output_np[i, :, :, :] + gt_im = gt_im.transpose((1, 2, 0)) + output_im = output_im.transpose((1, 2, 0)) + + ssim += structural_similarity(gt_im, output_im, data_range=1., multichannel=True, channel_axis=2) + + return float(ssim / bs), bs + else: + raise NotImplementedError diff --git a/utils/misc.py b/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..7eb813a4273e0c5c8e8061d4ce0bfd7a2231aec7 --- /dev/null +++ b/utils/misc.py @@ -0,0 +1,357 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Misc functions, including distributed helpers. + +Mostly copy-paste from torchvision references. +""" +import logging +import sys +import os +import time +import datetime +from subprocess import call +import subprocess +from collections import defaultdict, deque +import pickle +from packaging import version +from typing import Optional, List + +import torch +import torchvision +import torch.distributed as dist +from torch import Tensor + +if version.parse(torchvision.__version__) < version.parse('0.7'): + from torchvision.ops import _new_empty_tensor + from torchvision.ops.misc import _output_size + + +def print_cuda_statistics(): + logger = logging.getLogger("System") + logger.info('__Python VERSION: {}'.format(sys.version)) + logger.info('__pyTorch VERSION: {}'.format(torch.__version__)) + logger.info('__CUDA VERSION') + # call(["nvcc", "--version"]) + logger.info('__CUDNN VERSION: {}'.format(torch.backends.cudnn.version())) + logger.info('__Number CUDA Devices: {}'.format(torch.cuda.device_count())) + logger.info('__Devices') + call(["nvidia-smi", "--format=csv", + "--query-gpu=index,name,driver_version,memory.total,memory.used,memory.free"]) + logger.info('Active CUDA Device: GPU {}'.format(torch.cuda.current_device())) + logger.info('Available devices {}'.format(torch.cuda.device_count())) + logger.info('Current cuda device {}'.format(torch.cuda.current_device())) + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.detach().cpu().item() + if isinstance(v, float): + v = {'value': v, 'n': 1} + # assert isinstance(v, dict) and isinstance(v['value'], float) and isinstance(v['n'], int) + self.meters[k].update(**v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def print_avg(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter.global_avg)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() + sha = 'N/A' + diff = "clean" + branch = 'N/A' + try: + sha = _run(['git', 'rev-parse', 'HEAD']) + subprocess.check_output(['git', 'diff'], cwd=cwd) + diff = _run(['git', 'diff-index', 'HEAD']) + diff = "has uncommited changes" if diff else "clean" + branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + + +def set_logger(file_path): + logger = logging.getLogger() + logger.setLevel('INFO') + stream_handler = logging.StreamHandler() + file_handler = logging.FileHandler(file_path, 'w') + formatter = logging.Formatter('[%(asctime)s] %(message)s', '%m-%d %H:%M:%S') + for handler in [stream_handler, file_handler]: + handler.setFormatter(formatter) + handler.setLevel('INFO') + logger.addHandler(handler) + return logger + + +def set_save_dir(save_dir, log_name, replace=True): + logger = set_logger(os.path.join(save_dir, f'log_{log_name}.txt')) + return logger diff --git a/utils/padder.py b/utils/padder.py new file mode 100644 index 0000000000000000000000000000000000000000..bd41cb633eb5f500d12d5122c6f1b0738099a15b --- /dev/null +++ b/utils/padder.py @@ -0,0 +1,28 @@ +import torch.nn.functional as F + + +class InputPadder: + """ Pads images such that dimensions are divisible by divisor """ + + def __init__(self, dims, divisor=16): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // divisor) + 1) * divisor - self.ht) % divisor + pad_wd = (((self.wd // divisor) + 1) * divisor - self.wd) % divisor + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] + + def pad(self, *inputs): + if len(inputs) == 1: + return F.pad(inputs[0], self._pad, mode='replicate') + else: + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self, *inputs): + if len(inputs) == 1: + return self._unpad(inputs[0]) + else: + return [self._unpad(x) for x in inputs] + + def _unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] diff --git a/utils/plot.py b/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..b908af10d78c942485b682664981548fc15ae8d2 --- /dev/null +++ b/utils/plot.py @@ -0,0 +1,42 @@ +""" +Plotting utilities to visualize training logs. +""" + +import imageio +import os +import torchvision.utils as v_utils + + +def plot_samples_per_epoch(gen_batch, output_dir, epoch, iteration, nsample): + """ + Plot and save output samples per epoch + """ + fname = "samples_epoch_{:d}_{:d}.jpg".format(epoch, iteration) + fpath = os.path.join(output_dir, fname) + nrow = gen_batch.shape[0] // nsample + + image = v_utils.make_grid(gen_batch, nrow=nrow, padding=2, normalize=True) + v_utils.save_image(image, fpath) + return image + + +def plot_val_samples(gen_batch, output_dir, fname, nrow): + """ + Plot and dsave output samples for validations + """ + fpath = os.path.join(output_dir, fname) + image = v_utils.make_grid(gen_batch, nrow=nrow, padding=2, normalize=True) + v_utils.save_image(image, fpath) + return image + + +def plot_image(img, output_dir, fname): + """ + img in tensor format + """ + + fpath = os.path.join(output_dir, fname) + + v_utils.save_image(img, fpath, nrow=4, padding=2, normalize=True) + return imageio.imread(fpath) + diff --git a/utils/vos/__init__.py b/utils/vos/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/vos/__pycache__/__init__.cpython-310.pyc b/utils/vos/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91465f2e829bbf8f241ad9432cdefeab0eb0a6cd Binary files /dev/null and b/utils/vos/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/vos/__pycache__/__init__.cpython-38.pyc b/utils/vos/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2d781139b5031c8f44a89cedf0cd07884dc87034 Binary files /dev/null and b/utils/vos/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/vos/__pycache__/__init__.cpython-39.pyc b/utils/vos/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a11f56d828896fafb68e6cbc85a99ae0987afa08 Binary files /dev/null and b/utils/vos/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/vos/inference_core.py b/utils/vos/inference_core.py new file mode 100644 index 0000000000000000000000000000000000000000..4ab617b9fcaaa78fbaca81eb3a9025e262551f57 --- /dev/null +++ b/utils/vos/inference_core.py @@ -0,0 +1,80 @@ +import torch + +from inference_memory_bank import MemoryBank +from model.eval_network import STCN +from model.aggregate import aggregate + +from util.tensor_util import pad_divide_by + + +class InferenceCore: + def __init__(self, prop_net:STCN, images, num_objects, top_k=20, mem_every=5, include_last=False): + self.prop_net = prop_net + self.mem_every = mem_every + self.include_last = include_last + + # True dimensions + t = images.shape[1] + h, w = images.shape[-2:] + + # Pad each side to multiple of 16 + images, self.pad = pad_divide_by(images, 16) + # Padded dimensions + nh, nw = images.shape[-2:] + + self.images = images + self.device = 'cuda' + + self.k = num_objects + + # Background included, not always consistent (i.e. sum up to 1) + self.prob = torch.zeros((self.k+1, t, 1, nh, nw), dtype=torch.float32, device=self.device) + self.prob[0] = 1e-7 + + self.t, self.h, self.w = t, h, w + self.nh, self.nw = nh, nw + self.kh = self.nh//16 + self.kw = self.nw//16 + + self.mem_bank = MemoryBank(k=self.k, top_k=top_k) + + def encode_key(self, idx): + result = self.prop_net.encode_key(self.images[:,idx].cuda()) + return result + + def do_pass(self, key_k, key_v, idx, end_idx): + self.mem_bank.add_memory(key_k, key_v) + closest_ti = end_idx + + # Note that we never reach closest_ti, just the frame before it + this_range = range(idx+1, closest_ti) + end = closest_ti - 1 + + for ti in this_range: + k16, qv16, qf16, qf8, qf4 = self.encode_key(ti) + out_mask = self.prop_net.segment_with_query(self.mem_bank, qf8, qf4, k16, qv16) + + out_mask = aggregate(out_mask, keep_bg=True) + self.prob[:,ti] = out_mask + + if ti != end: + is_mem_frame = ((ti % self.mem_every) == 0) + if self.include_last or is_mem_frame: + prev_value = self.prop_net.encode_value(self.images[:,ti].cuda(), qf16, out_mask[1:]) + prev_key = k16.unsqueeze(2) + self.mem_bank.add_memory(prev_key, prev_value, is_temp=not is_mem_frame) + + return closest_ti + + def interact(self, mask, frame_idx, end_idx): + mask, _ = pad_divide_by(mask.cuda(), 16) + + self.prob[:, frame_idx] = aggregate(mask, keep_bg=True) + + # KV pair for the interacting frame + key_k, _, qf16, _, _ = self.encode_key(frame_idx) + key_v = self.prop_net.encode_value(self.images[:,frame_idx].cuda(), qf16, self.prob[1:,frame_idx].cuda()) + key_k = key_k.unsqueeze(2) + + # Propagate + self.do_pass(key_k, key_v, frame_idx, end_idx) diff --git a/utils/vos/inference_memory_bank.py b/utils/vos/inference_memory_bank.py new file mode 100644 index 0000000000000000000000000000000000000000..fe872673d9ddd0f0324f4702164879c2254c9851 --- /dev/null +++ b/utils/vos/inference_memory_bank.py @@ -0,0 +1,86 @@ +import math +import torch + + +def softmax_w_top(x, top): + values, indices = torch.topk(x, k=top, dim=1) + x_exp = values.exp_() + + x_exp /= torch.sum(x_exp, dim=1, keepdim=True) + # The types should be the same already + # some people report an error here so an additional guard is added + x.zero_().scatter_(1, indices, x_exp.type(x.dtype)) # B * THW * HW + + return x + + +class MemoryBank: + def __init__(self, k, top_k=20): + self.top_k = top_k + + self.CK = None + self.CV = None + + self.mem_k = None + self.mem_v = None + + self.num_objects = k + + def _global_matching(self, mk, qk): + # NE means number of elements -- typically T*H*W + B, CK, NE = mk.shape + + # See supplementary material + a_sq = mk.pow(2).sum(1).unsqueeze(2) + ab = mk.transpose(1, 2) @ qk + + affinity = (2*ab-a_sq) / math.sqrt(CK) # B, NE, HW + affinity = softmax_w_top(affinity, top=self.top_k) # B, NE, HW + + return affinity + + def _readout(self, affinity, mv): + return torch.bmm(mv, affinity) + + def match_memory(self, qk): + k = self.num_objects + _, _, h, w = qk.shape + + qk = qk.flatten(start_dim=2) + + if self.temp_k is not None: + mk = torch.cat([self.mem_k, self.temp_k], 2) + mv = torch.cat([self.mem_v, self.temp_v], 2) + else: + mk = self.mem_k + mv = self.mem_v + + affinity = self._global_matching(mk, qk) + + # One affinity for all + readout_mem = self._readout(affinity.expand(k,-1,-1), mv) + + return readout_mem.view(k, self.CV, h, w) + + def add_memory(self, key, value, is_temp=False): + # Temp is for "last frame" + # Not always used + # But can always be flushed + self.temp_k = None + self.temp_v = None + key = key.flatten(start_dim=2) + value = value.flatten(start_dim=2) + + if self.mem_k is None: + # First frame, just shove it in + self.mem_k = key + self.mem_v = value + self.CK = key.shape[1] + self.CV = value.shape[1] + else: + if is_temp: + self.temp_k = key + self.temp_v = value + else: + self.mem_k = torch.cat([self.mem_k, key], 2) + self.mem_v = torch.cat([self.mem_v, value], 2) \ No newline at end of file diff --git a/utils/vos/model/__init__.py b/utils/vos/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..67c15ad18fcc446c933ab62cbd304f02257c80ce --- /dev/null +++ b/utils/vos/model/__init__.py @@ -0,0 +1,8 @@ +import utils.vos.model.aggregate +import utils.vos.model.cbam +import utils.vos.model.eval_network +import utils.vos.model.losses +import utils.vos.model.model +import utils.vos.model.mod_resnet +import utils.vos.model.modules +import utils.vos.model.network diff --git a/utils/vos/model/__pycache__/__init__.cpython-310.pyc b/utils/vos/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6466dba10301bf9fa8f1a1922643aee1fbe87b6 Binary files /dev/null and b/utils/vos/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/__init__.cpython-38.pyc b/utils/vos/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0a76d8f749baeddbb7eed5967a05d513041054f Binary files /dev/null and b/utils/vos/model/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/__init__.cpython-39.pyc b/utils/vos/model/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ed8a9e667009c7a6c6a1f384cbd5ca329d4f538 Binary files /dev/null and b/utils/vos/model/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/aggregate.cpython-310.pyc b/utils/vos/model/__pycache__/aggregate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31b5a18c18d62482306f36c225ba6afb8cc5e821 Binary files /dev/null and b/utils/vos/model/__pycache__/aggregate.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/aggregate.cpython-38.pyc b/utils/vos/model/__pycache__/aggregate.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9551a0d6a531afd96d193cb791667be5cee92a5 Binary files /dev/null and b/utils/vos/model/__pycache__/aggregate.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/aggregate.cpython-39.pyc b/utils/vos/model/__pycache__/aggregate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f82c0c436068e7bacadeea7010cec6a19f16f006 Binary files /dev/null and b/utils/vos/model/__pycache__/aggregate.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/cbam.cpython-310.pyc b/utils/vos/model/__pycache__/cbam.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a57c01dd066ce05594cf61a01c97ceab67c0ca0f Binary files /dev/null and b/utils/vos/model/__pycache__/cbam.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/cbam.cpython-38.pyc b/utils/vos/model/__pycache__/cbam.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72bd3511cd6c6155b9b0f661791774a813365a3b Binary files /dev/null and b/utils/vos/model/__pycache__/cbam.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/cbam.cpython-39.pyc b/utils/vos/model/__pycache__/cbam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cc756ae56d0cb4a39b2182285dbca9772bd0575 Binary files /dev/null and b/utils/vos/model/__pycache__/cbam.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/eval_network.cpython-310.pyc b/utils/vos/model/__pycache__/eval_network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e74443a900a96d5212bfa3001469ea3bb13d6702 Binary files /dev/null and b/utils/vos/model/__pycache__/eval_network.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/eval_network.cpython-38.pyc b/utils/vos/model/__pycache__/eval_network.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71a393b480713cd9045cb866654b4fb89ffa5b84 Binary files /dev/null and b/utils/vos/model/__pycache__/eval_network.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/eval_network.cpython-39.pyc b/utils/vos/model/__pycache__/eval_network.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca5c8245f6a28ea4e4f17f36c2b05fbb36c0e45a Binary files /dev/null and b/utils/vos/model/__pycache__/eval_network.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/inference_core.cpython-310.pyc b/utils/vos/model/__pycache__/inference_core.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e2d1ab5e402ff181cb01a5c06fcf1ada8e3ac8e Binary files /dev/null and b/utils/vos/model/__pycache__/inference_core.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/inference_core.cpython-38.pyc b/utils/vos/model/__pycache__/inference_core.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..397730c1e20f7462924b3d9ca38c3878c76442a5 Binary files /dev/null and b/utils/vos/model/__pycache__/inference_core.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/inference_core.cpython-39.pyc b/utils/vos/model/__pycache__/inference_core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec9c320d784d4ab2b94b522be12eeba00c672223 Binary files /dev/null and b/utils/vos/model/__pycache__/inference_core.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/inference_memory_bank.cpython-310.pyc b/utils/vos/model/__pycache__/inference_memory_bank.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e52ddae0d21a06b75665d42e9c72b451ab1cb82 Binary files /dev/null and b/utils/vos/model/__pycache__/inference_memory_bank.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/inference_memory_bank.cpython-38.pyc b/utils/vos/model/__pycache__/inference_memory_bank.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbf8e510790315377bd6179edb31ce14674f5a5c Binary files /dev/null and b/utils/vos/model/__pycache__/inference_memory_bank.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/inference_memory_bank.cpython-39.pyc b/utils/vos/model/__pycache__/inference_memory_bank.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8427311c86f016b488cd63b96b495fc69f698b0c Binary files /dev/null and b/utils/vos/model/__pycache__/inference_memory_bank.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/losses.cpython-310.pyc b/utils/vos/model/__pycache__/losses.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f40b7ba8c1fd0630988e4550631b28411da7f943 Binary files /dev/null and b/utils/vos/model/__pycache__/losses.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/losses.cpython-38.pyc b/utils/vos/model/__pycache__/losses.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eadba9bf96803249552a41ad44ba4aac6d9faabe Binary files /dev/null and b/utils/vos/model/__pycache__/losses.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/losses.cpython-39.pyc b/utils/vos/model/__pycache__/losses.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f408f79615c75701b8728928dbce15f27fd440 Binary files /dev/null and b/utils/vos/model/__pycache__/losses.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/mod_resnet.cpython-310.pyc b/utils/vos/model/__pycache__/mod_resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb66a3acd15759212822bc1401a7b4eb387e849c Binary files /dev/null and b/utils/vos/model/__pycache__/mod_resnet.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/mod_resnet.cpython-38.pyc b/utils/vos/model/__pycache__/mod_resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88baa250f5f295d9ff84efe70c1dfc9e4bb2c9d8 Binary files /dev/null and b/utils/vos/model/__pycache__/mod_resnet.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/mod_resnet.cpython-39.pyc b/utils/vos/model/__pycache__/mod_resnet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..427fce41d7b48ca35e9c13fd690c7e5948e515a4 Binary files /dev/null and b/utils/vos/model/__pycache__/mod_resnet.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/model.cpython-310.pyc b/utils/vos/model/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5202bc4f306fbcd3d8763f32a24a47aaf793487 Binary files /dev/null and b/utils/vos/model/__pycache__/model.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/model.cpython-38.pyc b/utils/vos/model/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d66a5e396213a146fb60dd5a3a8e2ea8afe9b2d1 Binary files /dev/null and b/utils/vos/model/__pycache__/model.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/model.cpython-39.pyc b/utils/vos/model/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..234558f47929d3cc80a387a681e5498bb1ee0fcb Binary files /dev/null and b/utils/vos/model/__pycache__/model.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/modules.cpython-310.pyc b/utils/vos/model/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0d91fbda5fe59e92ad7ec491789dc0b8216bd92 Binary files /dev/null and b/utils/vos/model/__pycache__/modules.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/modules.cpython-38.pyc b/utils/vos/model/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65245c123cb229f512e210572c6feacaaaf727b7 Binary files /dev/null and b/utils/vos/model/__pycache__/modules.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/modules.cpython-39.pyc b/utils/vos/model/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0917845a1c79bea8b35169417654a53f5d9a5d81 Binary files /dev/null and b/utils/vos/model/__pycache__/modules.cpython-39.pyc differ diff --git a/utils/vos/model/__pycache__/network.cpython-310.pyc b/utils/vos/model/__pycache__/network.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5220924b9886564c8316375cb7e63d7490b3981 Binary files /dev/null and b/utils/vos/model/__pycache__/network.cpython-310.pyc differ diff --git a/utils/vos/model/__pycache__/network.cpython-38.pyc b/utils/vos/model/__pycache__/network.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9847e197419489f0f2061ba3b964d11d661a11fb Binary files /dev/null and b/utils/vos/model/__pycache__/network.cpython-38.pyc differ diff --git a/utils/vos/model/__pycache__/network.cpython-39.pyc b/utils/vos/model/__pycache__/network.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9e4e8884c8549d8f683e4d9726effff604c2e7f Binary files /dev/null and b/utils/vos/model/__pycache__/network.cpython-39.pyc differ diff --git a/utils/vos/model/aggregate.py b/utils/vos/model/aggregate.py new file mode 100644 index 0000000000000000000000000000000000000000..aeef98f229e8450c3478162d3ca51067bd88a1b9 --- /dev/null +++ b/utils/vos/model/aggregate.py @@ -0,0 +1,17 @@ +import torch +import torch.nn.functional as F + + +# Soft aggregation from STM +def aggregate(prob, keep_bg=False): + k = prob.shape + new_prob = torch.cat([ + torch.prod(1-prob, dim=0, keepdim=True), + prob + ], 0).clamp(1e-7, 1-1e-7) + logits = torch.log((new_prob /(1-new_prob))) + + if keep_bg: + return F.softmax(logits, dim=0) + else: + return F.softmax(logits, dim=0)[1:] \ No newline at end of file diff --git a/utils/vos/model/cbam.py b/utils/vos/model/cbam.py new file mode 100644 index 0000000000000000000000000000000000000000..6423358429e2843b1f36ceb2bc1a485ea72b8eb4 --- /dev/null +++ b/utils/vos/model/cbam.py @@ -0,0 +1,77 @@ +# Modified from https://github.com/Jongchan/attention-module/blob/master/MODELS/cbam.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +class BasicConv(nn.Module): + def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True): + super(BasicConv, self).__init__() + self.out_channels = out_planes + self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias) + + def forward(self, x): + x = self.conv(x) + return x + +class Flatten(nn.Module): + def forward(self, x): + return x.view(x.size(0), -1) + +class ChannelGate(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']): + super(ChannelGate, self).__init__() + self.gate_channels = gate_channels + self.mlp = nn.Sequential( + Flatten(), + nn.Linear(gate_channels, gate_channels // reduction_ratio), + nn.ReLU(), + nn.Linear(gate_channels // reduction_ratio, gate_channels) + ) + self.pool_types = pool_types + def forward(self, x): + channel_att_sum = None + for pool_type in self.pool_types: + if pool_type=='avg': + avg_pool = F.avg_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( avg_pool ) + elif pool_type=='max': + max_pool = F.max_pool2d( x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))) + channel_att_raw = self.mlp( max_pool ) + + if channel_att_sum is None: + channel_att_sum = channel_att_raw + else: + channel_att_sum = channel_att_sum + channel_att_raw + + scale = torch.sigmoid( channel_att_sum ).unsqueeze(2).unsqueeze(3).expand_as(x) + return x * scale + +class ChannelPool(nn.Module): + def forward(self, x): + return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 ) + +class SpatialGate(nn.Module): + def __init__(self): + super(SpatialGate, self).__init__() + kernel_size = 7 + self.compress = ChannelPool() + self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2) + def forward(self, x): + x_compress = self.compress(x) + x_out = self.spatial(x_compress) + scale = torch.sigmoid(x_out) # broadcasting + return x * scale + +class CBAM(nn.Module): + def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False): + super(CBAM, self).__init__() + self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types) + self.no_spatial=no_spatial + if not no_spatial: + self.SpatialGate = SpatialGate() + def forward(self, x): + x_out = self.ChannelGate(x) + if not self.no_spatial: + x_out = self.SpatialGate(x_out) + return x_out diff --git a/utils/vos/model/eval_network.py b/utils/vos/model/eval_network.py new file mode 100644 index 0000000000000000000000000000000000000000..0b4429c3b19a0c5b6a975e7fccc09aa66c0fba1c --- /dev/null +++ b/utils/vos/model/eval_network.py @@ -0,0 +1,65 @@ +""" +eval_network.py - Evaluation version of the network +The logic is basically the same +but with top-k and some implementation optimization + +The trailing number of a variable usually denote the stride +e.g. f16 -> encoded features with stride 16 +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modules import * +from .network import Decoder + + +class STCN(nn.Module): + def __init__(self): + super().__init__() + self.key_encoder = KeyEncoder() + self.value_encoder = ValueEncoder() + + # Projection from f16 feature space to key space + self.key_proj = KeyProjection(1024, keydim=64) + + # Compress f16 a bit to use in decoding later on + self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1) + + self.decoder = Decoder() + + def encode_value(self, frame, kf16, masks): + k, _, h, w = masks.shape + + # Extract memory key/value for a frame with multiple masks + frame = frame.view(1, 3, h, w).repeat(k, 1, 1, 1) + # Compute the "others" mask + if k != 1: + others = torch.cat([ + torch.sum( + masks[[j for j in range(k) if i!=j]] + , dim=0, keepdim=True) + for i in range(k)], 0) + else: + others = torch.zeros_like(masks) + + f16 = self.value_encoder(frame, kf16.repeat(k,1,1,1), masks, others) + + return f16.unsqueeze(2) + + def encode_key(self, frame): + f16, f8, f4 = self.key_encoder(frame) + k16 = self.key_proj(f16) + f16_thin = self.key_comp(f16) + + return k16, f16_thin, f16, f8, f4 + + def segment_with_query(self, mem_bank, qf8, qf4, qk16, qv16): + k = mem_bank.num_objects + + readout_mem = mem_bank.match_memory(qk16) + qv16 = qv16.expand(k, -1, -1, -1) + qv16 = torch.cat([readout_mem, qv16], 1) + + return torch.sigmoid(self.decoder(qv16, qf8, qf4)) diff --git a/utils/vos/model/inference_core.py b/utils/vos/model/inference_core.py new file mode 100644 index 0000000000000000000000000000000000000000..3e2ddffcf54b959ee0166ecf5b099fbc4939d26c --- /dev/null +++ b/utils/vos/model/inference_core.py @@ -0,0 +1,79 @@ +import torch + +from .inference_memory_bank import MemoryBank +from .eval_network import STCN +from .aggregate import aggregate + +from ..util.tensor_util import pad_divide_by + +class InferenceCore: + def __init__(self, prop_net:STCN, images, num_objects, top_k=20, mem_every=5, include_last=False): + self.prop_net = prop_net + self.mem_every = mem_every + self.include_last = include_last + + # True dimensions + t = images.shape[1] + h, w = images.shape[-2:] + + # Pad each side to multiple of 16 + images, self.pad = pad_divide_by(images, 16) + # Padded dimensions + nh, nw = images.shape[-2:] + + self.images = images + self.device = 'cuda' + + self.k = num_objects + + # Background included, not always consistent (i.e. sum up to 1) + self.prob = torch.zeros((self.k+1, t, 1, nh, nw), dtype=torch.float32, device=self.device) + self.prob[0] = 1e-7 + + self.t, self.h, self.w = t, h, w + self.nh, self.nw = nh, nw + self.kh = self.nh//16 + self.kw = self.nw//16 + + self.mem_bank = MemoryBank(k=self.k, top_k=top_k) + + def encode_key(self, idx): + result = self.prop_net.encode_key(self.images[:,idx].cuda()) + return result + + def do_pass(self, key_k, key_v, idx, end_idx): + self.mem_bank.add_memory(key_k, key_v) + closest_ti = end_idx + + # Note that we never reach closest_ti, just the frame before it + this_range = range(idx+1, closest_ti) + end = closest_ti - 1 + + for ti in this_range: + k16, qv16, qf16, qf8, qf4 = self.encode_key(ti) + out_mask = self.prop_net.segment_with_query(self.mem_bank, qf8, qf4, k16, qv16) + + out_mask = aggregate(out_mask, keep_bg=True) + self.prob[:,ti] = out_mask + + if ti != end: + is_mem_frame = ((ti % self.mem_every) == 0) + if self.include_last or is_mem_frame: + prev_value = self.prop_net.encode_value(self.images[:,ti].cuda(), qf16, out_mask[1:]) + prev_key = k16.unsqueeze(2) + self.mem_bank.add_memory(prev_key, prev_value, is_temp=not is_mem_frame) + + return closest_ti + + def interact(self, mask, frame_idx, end_idx): + mask, _ = pad_divide_by(mask.cuda(), 16) + + self.prob[:, frame_idx] = aggregate(mask, keep_bg=True) + + # KV pair for the interacting frame + key_k, _, qf16, _, _ = self.encode_key(frame_idx) + key_v = self.prop_net.encode_value(self.images[:,frame_idx].cuda(), qf16, self.prob[1:,frame_idx].cuda()) + key_k = key_k.unsqueeze(2) + + # Propagate + self.do_pass(key_k, key_v, frame_idx, end_idx) diff --git a/utils/vos/model/inference_memory_bank.py b/utils/vos/model/inference_memory_bank.py new file mode 100644 index 0000000000000000000000000000000000000000..fe872673d9ddd0f0324f4702164879c2254c9851 --- /dev/null +++ b/utils/vos/model/inference_memory_bank.py @@ -0,0 +1,86 @@ +import math +import torch + + +def softmax_w_top(x, top): + values, indices = torch.topk(x, k=top, dim=1) + x_exp = values.exp_() + + x_exp /= torch.sum(x_exp, dim=1, keepdim=True) + # The types should be the same already + # some people report an error here so an additional guard is added + x.zero_().scatter_(1, indices, x_exp.type(x.dtype)) # B * THW * HW + + return x + + +class MemoryBank: + def __init__(self, k, top_k=20): + self.top_k = top_k + + self.CK = None + self.CV = None + + self.mem_k = None + self.mem_v = None + + self.num_objects = k + + def _global_matching(self, mk, qk): + # NE means number of elements -- typically T*H*W + B, CK, NE = mk.shape + + # See supplementary material + a_sq = mk.pow(2).sum(1).unsqueeze(2) + ab = mk.transpose(1, 2) @ qk + + affinity = (2*ab-a_sq) / math.sqrt(CK) # B, NE, HW + affinity = softmax_w_top(affinity, top=self.top_k) # B, NE, HW + + return affinity + + def _readout(self, affinity, mv): + return torch.bmm(mv, affinity) + + def match_memory(self, qk): + k = self.num_objects + _, _, h, w = qk.shape + + qk = qk.flatten(start_dim=2) + + if self.temp_k is not None: + mk = torch.cat([self.mem_k, self.temp_k], 2) + mv = torch.cat([self.mem_v, self.temp_v], 2) + else: + mk = self.mem_k + mv = self.mem_v + + affinity = self._global_matching(mk, qk) + + # One affinity for all + readout_mem = self._readout(affinity.expand(k,-1,-1), mv) + + return readout_mem.view(k, self.CV, h, w) + + def add_memory(self, key, value, is_temp=False): + # Temp is for "last frame" + # Not always used + # But can always be flushed + self.temp_k = None + self.temp_v = None + key = key.flatten(start_dim=2) + value = value.flatten(start_dim=2) + + if self.mem_k is None: + # First frame, just shove it in + self.mem_k = key + self.mem_v = value + self.CK = key.shape[1] + self.CV = value.shape[1] + else: + if is_temp: + self.temp_k = key + self.temp_v = value + else: + self.mem_k = torch.cat([self.mem_k, key], 2) + self.mem_v = torch.cat([self.mem_v, value], 2) \ No newline at end of file diff --git a/utils/vos/model/losses.py b/utils/vos/model/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..0b24249e95d46cd6807b0ecfc5963913e8bdd5d0 --- /dev/null +++ b/utils/vos/model/losses.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from ..util.tensor_util import compute_tensor_iu + +from collections import defaultdict + + +def get_iou_hook(values): + return 'iou/iou', (values['hide_iou/i']+1)/(values['hide_iou/u']+1) + +def get_sec_iou_hook(values): + return 'iou/sec_iou', (values['hide_iou/sec_i']+1)/(values['hide_iou/sec_u']+1) + +iou_hooks_so = [ + get_iou_hook, +] + +iou_hooks_mo = [ + get_iou_hook, + get_sec_iou_hook, +] + + +# https://stackoverflow.com/questions/63735255/how-do-i-compute-bootstrapped-cross-entropy-loss-in-pytorch +class BootstrappedCE(nn.Module): + def __init__(self, start_warm=20000, end_warm=70000, top_p=0.15): + super().__init__() + + self.start_warm = start_warm + self.end_warm = end_warm + self.top_p = top_p + + def forward(self, input, target, it): + if it < self.start_warm: + return F.cross_entropy(input, target), 1.0 + + raw_loss = F.cross_entropy(input, target, reduction='none').view(-1) + num_pixels = raw_loss.numel() + + if it > self.end_warm: + this_p = self.top_p + else: + this_p = self.top_p + (1-self.top_p)*((self.end_warm-it)/(self.end_warm-self.start_warm)) + loss, _ = torch.topk(raw_loss, int(num_pixels * this_p), sorted=False) + return loss.mean(), this_p + + +class LossComputer: + def __init__(self, para): + super().__init__() + self.para = para + self.bce = BootstrappedCE() + + def compute(self, data, it): + losses = defaultdict(int) + + b, s, _, _, _ = data['gt'].shape + selector = data.get('selector', None) + + for i in range(1, s): + # Have to do it in a for-loop like this since not every entry has the second object + # Well it's not a lot of iterations anyway + for j in range(b): + if selector is not None and selector[j][1] > 0.5: + loss, p = self.bce(data['logits_%d'%i][j:j+1], data['cls_gt'][j:j+1,i], it) + else: + loss, p = self.bce(data['logits_%d'%i][j:j+1,:2], data['cls_gt'][j:j+1,i], it) + + losses['loss_%d'%i] += loss / b + losses['p'] += p / b / (s-1) + + losses['total_loss'] += losses['loss_%d'%i] + + new_total_i, new_total_u = compute_tensor_iu(data['mask_%d'%i]>0.5, data['gt'][:,i]>0.5) + losses['hide_iou/i'] += new_total_i + losses['hide_iou/u'] += new_total_u + + if selector is not None: + new_total_i, new_total_u = compute_tensor_iu(data['sec_mask_%d'%i]>0.5, data['sec_gt'][:,i]>0.5) + losses['hide_iou/sec_i'] += new_total_i + losses['hide_iou/sec_u'] += new_total_u + + return losses diff --git a/utils/vos/model/mod_resnet.py b/utils/vos/model/mod_resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..dde0a1f7295fd58a90db97c100c8cd263f438e9f --- /dev/null +++ b/utils/vos/model/mod_resnet.py @@ -0,0 +1,167 @@ +""" +mod_resnet.py - A modified ResNet structure +We append extra channels to the first conv by some network surgery +""" + +from collections import OrderedDict +import math + +import torch +import torch.nn as nn +from torch.utils import model_zoo + + +def load_weights_sequential(target, source_state, extra_chan=1): + + new_dict = OrderedDict() + + for k1, v1 in target.state_dict().items(): + if not 'num_batches_tracked' in k1: + if k1 in source_state: + tar_v = source_state[k1] + + if v1.shape != tar_v.shape: + # Init the new segmentation channel with zeros + # print(v1.shape, tar_v.shape) + c, _, w, h = v1.shape + pads = torch.zeros((c,extra_chan,w,h), device=tar_v.device) + nn.init.orthogonal_(pads) + tar_v = torch.cat([tar_v, pads], 1) + + new_dict[k1] = tar_v + + target.load_state_dict(new_dict, strict=False) + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, dilation=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, dilation=dilation) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride=stride, dilation=dilation) + self.bn1 = nn.BatchNorm2d(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes, stride=1, dilation=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, dilation=dilation, + padding=dilation) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1) + self.bn3 = nn.BatchNorm2d(planes * 4) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__(self, block, layers=(3, 4, 23, 3), extra_chan=1): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3+extra_chan, 64, kernel_size=7, stride=2, padding=3) + self.bn1 = nn.BatchNorm2d(64) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. / n)) + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1, dilation=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride), + nn.BatchNorm2d(planes * block.expansion), + ) + + layers = [block(self.inplanes, planes, stride, downsample)] + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, dilation=dilation)) + + return nn.Sequential(*layers) + +def resnet18(pretrained=True, extra_chan=0): + model = ResNet(BasicBlock, [2, 2, 2, 2], extra_chan) + if pretrained: + load_weights_sequential(model, model_zoo.load_url(model_urls['resnet18']), extra_chan) + return model + +def resnet50(pretrained=True, extra_chan=0): + model = ResNet(Bottleneck, [3, 4, 6, 3], extra_chan) + if pretrained: + load_weights_sequential(model, model_zoo.load_url(model_urls['resnet50']), extra_chan) + return model + diff --git a/utils/vos/model/model.py b/utils/vos/model/model.py new file mode 100644 index 0000000000000000000000000000000000000000..129df26fa1d17d48311c19f3a6a14a5b979ea530 --- /dev/null +++ b/utils/vos/model/model.py @@ -0,0 +1,246 @@ +""" +model.py - warpper and utility functions for network training +Compute loss, back-prop, update parameters, logging, etc. +""" + + +import os +import time +import torch +import torch.nn as nn +import torch.optim as optim + +from .network import STCN +from .losses import LossComputer, iou_hooks_mo, iou_hooks_so +from ..util.log_integrator import Integrator + + +class STCNModel: + def __init__(self, para, logger=None, save_path=None, local_rank=0, world_size=1): + self.para = para + self.single_object = para['single_object'] + self.local_rank = local_rank + + self.STCN = nn.parallel.DistributedDataParallel( + STCN(self.single_object).cuda(), + device_ids=[local_rank], output_device=local_rank, broadcast_buffers=False) + + # Setup logger when local_rank=0 + self.logger = logger + self.save_path = save_path + if logger is not None: + self.last_time = time.time() + self.train_integrator = Integrator(self.logger, distributed=True, local_rank=local_rank, world_size=world_size) + if self.single_object: + self.train_integrator.add_hook(iou_hooks_so) + else: + self.train_integrator.add_hook(iou_hooks_mo) + self.loss_computer = LossComputer(para) + + self.train() + self.optimizer = optim.Adam(filter( + lambda p: p.requires_grad, self.STCN.parameters()), lr=para['lr'], weight_decay=1e-7) + self.scheduler = optim.lr_scheduler.MultiStepLR(self.optimizer, para['steps'], para['gamma']) + if para['amp']: + self.scaler = torch.cuda.amp.GradScaler() + + # Logging info + self.report_interval = 100 + self.save_im_interval = 800 + self.save_model_interval = 50000 + if para['debug']: + self.report_interval = self.save_im_interval = 1 + + def do_pass(self, data, it=0): + # No need to store the gradient outside training + torch.set_grad_enabled(self._is_train) + + for k, v in data.items(): + if type(v) != list and type(v) != dict and type(v) != int: + data[k] = v.cuda(non_blocking=True) + + out = {} + Fs = data['rgb'] + Ms = data['gt'] + + with torch.cuda.amp.autocast(enabled=self.para['amp']): + # key features never change, compute once + k16, kf16_thin, kf16, kf8, kf4 = self.STCN('encode_key', Fs) + + if self.single_object: + ref_v = self.STCN('encode_value', Fs[:,0], kf16[:,0], Ms[:,0]) + + # Segment frame 1 with frame 0 + prev_logits, prev_mask = self.STCN('segment', + k16[:,:,1], kf16_thin[:,1], kf8[:,1], kf4[:,1], + k16[:,:,0:1], ref_v) + prev_v = self.STCN('encode_value', Fs[:,1], kf16[:,1], prev_mask) + + values = torch.cat([ref_v, prev_v], 2) + + del ref_v + + # Segment frame 2 with frame 0 and 1 + this_logits, this_mask = self.STCN('segment', + k16[:,:,2], kf16_thin[:,2], kf8[:,2], kf4[:,2], + k16[:,:,0:2], values) + + out['mask_1'] = prev_mask + out['mask_2'] = this_mask + out['logits_1'] = prev_logits + out['logits_2'] = this_logits + else: + sec_Ms = data['sec_gt'] + selector = data['selector'] + + ref_v1 = self.STCN('encode_value', Fs[:,0], kf16[:,0], Ms[:,0], sec_Ms[:,0]) + ref_v2 = self.STCN('encode_value', Fs[:,0], kf16[:,0], sec_Ms[:,0], Ms[:,0]) + ref_v = torch.stack([ref_v1, ref_v2], 1) + + # Segment frame 1 with frame 0 + prev_logits, prev_mask = self.STCN('segment', + k16[:,:,1], kf16_thin[:,1], kf8[:,1], kf4[:,1], + k16[:,:,0:1], ref_v, selector) + + prev_v1 = self.STCN('encode_value', Fs[:,1], kf16[:,1], prev_mask[:,0:1], prev_mask[:,1:2]) + prev_v2 = self.STCN('encode_value', Fs[:,1], kf16[:,1], prev_mask[:,1:2], prev_mask[:,0:1]) + prev_v = torch.stack([prev_v1, prev_v2], 1) + values = torch.cat([ref_v, prev_v], 3) + + del ref_v + + # Segment frame 2 with frame 0 and 1 + this_logits, this_mask = self.STCN('segment', + k16[:,:,2], kf16_thin[:,2], kf8[:,2], kf4[:,2], + k16[:,:,0:2], values, selector) + + out['mask_1'] = prev_mask[:,0:1] + out['mask_2'] = this_mask[:,0:1] + out['sec_mask_1'] = prev_mask[:,1:2] + out['sec_mask_2'] = this_mask[:,1:2] + + out['logits_1'] = prev_logits + out['logits_2'] = this_logits + + if self._do_log or self._is_train: + losses = self.loss_computer.compute({**data, **out}, it) + + # Logging + if self._do_log: + self.integrator.add_dict(losses) + if self._is_train: + if it % self.save_im_interval == 0 and it != 0: + if self.logger is not None: + images = {**data, **out} + size = (384, 384) + + if self._is_train: + if (it) % self.report_interval == 0 and it != 0: + if self.logger is not None: + self.logger.log_scalar('train/lr', self.scheduler.get_last_lr()[0], it) + self.logger.log_metrics('train', 'time', (time.time()-self.last_time)/self.report_interval, it) + self.last_time = time.time() + self.train_integrator.finalize('train', it) + self.train_integrator.reset_except_hooks() + + if it % self.save_model_interval == 0 and it != 0: + if self.logger is not None: + self.save(it) + + # Backward pass + # This should be done outside autocast + # but I trained it like this and it worked fine + # so I am keeping it this way for reference + self.optimizer.zero_grad(set_to_none=True) + if self.para['amp']: + self.scaler.scale(losses['total_loss']).backward() + self.scaler.step(self.optimizer) + self.scaler.update() + else: + losses['total_loss'].backward() + self.optimizer.step() + self.scheduler.step() + + def save(self, it): + if self.save_path is None: + print('Saving has been disabled.') + return + + os.makedirs(os.path.dirname(self.save_path), exist_ok=True) + model_path = self.save_path + ('_%s.pth' % it) + torch.save(self.STCN.module.state_dict(), model_path) + print('Model saved to %s.' % model_path) + + self.save_checkpoint(it) + + def save_checkpoint(self, it): + if self.save_path is None: + print('Saving has been disabled.') + return + + os.makedirs(os.path.dirname(self.save_path), exist_ok=True) + checkpoint_path = self.save_path + '_checkpoint.pth' + checkpoint = { + 'it': it, + 'network': self.STCN.module.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'scheduler': self.scheduler.state_dict()} + torch.save(checkpoint, checkpoint_path) + + print('Checkpoint saved to %s.' % checkpoint_path) + + def load_model(self, path): + # This method loads everything and should be used to resume training + map_location = 'cuda:%d' % self.local_rank + checkpoint = torch.load(path, map_location={'cuda:0': map_location}) + + it = checkpoint['it'] + network = checkpoint['network'] + optimizer = checkpoint['optimizer'] + scheduler = checkpoint['scheduler'] + + map_location = 'cuda:%d' % self.local_rank + self.STCN.module.load_state_dict(network) + self.optimizer.load_state_dict(optimizer) + self.scheduler.load_state_dict(scheduler) + + print('Model loaded.') + + return it + + def load_network(self, path): + # This method loads only the network weight and should be used to load a pretrained model + map_location = 'cuda:%d' % self.local_rank + src_dict = torch.load(path, map_location={'cuda:0': map_location}) + + # Maps SO weight (without other_mask) to MO weight (with other_mask) + for k in list(src_dict.keys()): + if k == 'value_encoder.conv1.weight': + if src_dict[k].shape[1] == 4: + pads = torch.zeros((64,1,7,7), device=src_dict[k].device) + nn.init.orthogonal_(pads) + src_dict[k] = torch.cat([src_dict[k], pads], 1) + + self.STCN.module.load_state_dict(src_dict) + print('Network weight loaded:', path) + + def train(self): + self._is_train = True + self._do_log = True + self.integrator = self.train_integrator + # Shall be in eval() mode to freeze BN parameters + self.STCN.eval() + return self + + def val(self): + self._is_train = False + self._do_log = True + self.STCN.eval() + return self + + def test(self): + self._is_train = False + self._do_log = False + self.STCN.eval() + return self + diff --git a/utils/vos/model/modules.py b/utils/vos/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..1fb2a668802fde8b234cfef602a30d129154fde3 --- /dev/null +++ b/utils/vos/model/modules.py @@ -0,0 +1,174 @@ +""" +modules.py - This file stores the rathering boring network blocks. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +import utils.vos.model.mod_resnet as mod_resnet +import utils.vos.model.cbam as cbam + + +class ResBlock(nn.Module): + def __init__(self, indim, outdim=None): + super(ResBlock, self).__init__() + if outdim == None: + outdim = indim + if indim == outdim: + self.downsample = None + else: + self.downsample = nn.Conv2d(indim, outdim, kernel_size=3, padding=1) + + self.conv1 = nn.Conv2d(indim, outdim, kernel_size=3, padding=1) + self.conv2 = nn.Conv2d(outdim, outdim, kernel_size=3, padding=1) + + def forward(self, x): + r = self.conv1(F.relu(x)) + r = self.conv2(F.relu(r)) + + if self.downsample is not None: + x = self.downsample(x) + + return x + r + + +class FeatureFusionBlock(nn.Module): + def __init__(self, indim, outdim): + super().__init__() + + self.block1 = ResBlock(indim, outdim) + self.attention = cbam.CBAM(outdim) + self.block2 = ResBlock(outdim, outdim) + + def forward(self, x, f16): + x = torch.cat([x, f16], 1) + x = self.block1(x) + r = self.attention(x) + x = self.block2(x + r) + + return x + + +# Single object version, used only in static image pretraining +# This will be loaded and modified into the multiple objects version later (in stage 1/2/3) +# See model.py (load_network) for the modification procedure +class ValueEncoderSO(nn.Module): + def __init__(self): + super().__init__() + + resnet = mod_resnet.resnet18(pretrained=True, extra_chan=1) + self.conv1 = resnet.conv1 + self.bn1 = resnet.bn1 + self.relu = resnet.relu # 1/2, 64 + self.maxpool = resnet.maxpool + + self.layer1 = resnet.layer1 # 1/4, 64 + self.layer2 = resnet.layer2 # 1/8, 128 + self.layer3 = resnet.layer3 # 1/16, 256 + + self.fuser = FeatureFusionBlock(1024 + 256, 512) + + def forward(self, image, key_f16, mask): + # key_f16 is the feature from the key encoder + + f = torch.cat([image, mask], 1) + + x = self.conv1(f) + x = self.bn1(x) + x = self.relu(x) # 1/2, 64 + x = self.maxpool(x) # 1/4, 64 + x = self.layer1(x) # 1/4, 64 + x = self.layer2(x) # 1/8, 128 + x = self.layer3(x) # 1/16, 256 + + x = self.fuser(x, key_f16) + + return x + + +# Multiple objects version, used in other times +class ValueEncoder(nn.Module): + def __init__(self): + super().__init__() + + resnet = mod_resnet.resnet18(pretrained=True, extra_chan=2) + self.conv1 = resnet.conv1 + self.bn1 = resnet.bn1 + self.relu = resnet.relu # 1/2, 64 + self.maxpool = resnet.maxpool + + self.layer1 = resnet.layer1 # 1/4, 64 + self.layer2 = resnet.layer2 # 1/8, 128 + self.layer3 = resnet.layer3 # 1/16, 256 + + self.fuser = FeatureFusionBlock(1024 + 256, 512) + + def forward(self, image, key_f16, mask, other_masks): + # key_f16 is the feature from the key encoder + + f = torch.cat([image, mask, other_masks], 1) + + x = self.conv1(f) + x = self.bn1(x) + x = self.relu(x) # 1/2, 64 + x = self.maxpool(x) # 1/4, 64 + x = self.layer1(x) # 1/4, 64 + x = self.layer2(x) # 1/8, 128 + x = self.layer3(x) # 1/16, 256 + + x = self.fuser(x, key_f16) + + return x + + +class KeyEncoder(nn.Module): + def __init__(self): + super().__init__() + resnet = models.resnet50(pretrained=True) + self.conv1 = resnet.conv1 + self.bn1 = resnet.bn1 + self.relu = resnet.relu # 1/2, 64 + self.maxpool = resnet.maxpool + + self.res2 = resnet.layer1 # 1/4, 256 + self.layer2 = resnet.layer2 # 1/8, 512 + self.layer3 = resnet.layer3 # 1/16, 1024 + + def forward(self, f): + x = self.conv1(f) + x = self.bn1(x) + x = self.relu(x) # 1/2, 64 + x = self.maxpool(x) # 1/4, 64 + f4 = self.res2(x) # 1/4, 256 + f8 = self.layer2(f4) # 1/8, 512 + f16 = self.layer3(f8) # 1/16, 1024 + + return f16, f8, f4 + + +class UpsampleBlock(nn.Module): + def __init__(self, skip_c, up_c, out_c, scale_factor=2): + super().__init__() + self.skip_conv = nn.Conv2d(skip_c, up_c, kernel_size=3, padding=1) + self.out_conv = ResBlock(up_c, out_c) + self.scale_factor = scale_factor + + def forward(self, skip_f, up_f): + x = self.skip_conv(skip_f) + x = x + F.interpolate(up_f, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + x = self.out_conv(x) + return x + + +class KeyProjection(nn.Module): + def __init__(self, indim, keydim): + super().__init__() + self.key_proj = nn.Conv2d(indim, keydim, kernel_size=3, padding=1) + + nn.init.orthogonal_(self.key_proj.weight.data) + nn.init.zeros_(self.key_proj.bias.data) + + def forward(self, x): + return self.key_proj(x) diff --git a/utils/vos/model/network.py b/utils/vos/model/network.py new file mode 100644 index 0000000000000000000000000000000000000000..820703030c40e244863c2f5131e6c4de59da166a --- /dev/null +++ b/utils/vos/model/network.py @@ -0,0 +1,156 @@ +""" +network.py - The core of the neural network +Defines the structure and memory operations +Modifed from STM: https://github.com/seoungwugoh/STM + +The trailing number of a variable usually denote the stride +e.g. f16 -> encoded features with stride 16 +""" + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .modules import * + + +class Decoder(nn.Module): + def __init__(self): + super().__init__() + self.compress = ResBlock(1024, 512) + self.up_16_8 = UpsampleBlock(512, 512, 256) # 1/16 -> 1/8 + self.up_8_4 = UpsampleBlock(256, 256, 256) # 1/8 -> 1/4 + + self.pred = nn.Conv2d(256, 1, kernel_size=(3,3), padding=(1,1), stride=1) + + def forward(self, f16, f8, f4): + x = self.compress(f16) + x = self.up_16_8(f8, x) + x = self.up_8_4(f4, x) + + x = self.pred(F.relu(x)) + + x = F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False) + return x + + +class MemoryReader(nn.Module): + def __init__(self): + super().__init__() + + def get_affinity(self, mk, qk): + B, CK, T, H, W = mk.shape + mk = mk.flatten(start_dim=2) + qk = qk.flatten(start_dim=2) + + # See supplementary material + a_sq = mk.pow(2).sum(1).unsqueeze(2) + ab = mk.transpose(1, 2) @ qk + + affinity = (2*ab-a_sq) / math.sqrt(CK) # B, THW, HW + + # softmax operation; aligned the evaluation style + maxes = torch.max(affinity, dim=1, keepdim=True)[0] + x_exp = torch.exp(affinity - maxes) + x_exp_sum = torch.sum(x_exp, dim=1, keepdim=True) + affinity = x_exp / x_exp_sum + + return affinity + + def readout(self, affinity, mv, qv): + B, CV, T, H, W = mv.shape + + mo = mv.view(B, CV, T*H*W) + mem = torch.bmm(mo, affinity) # Weighted-sum B, CV, HW + mem = mem.view(B, CV, H, W) + + mem_out = torch.cat([mem, qv], dim=1) + + return mem_out + + +class STCN(nn.Module): + def __init__(self, single_object): + super(STCN, self).__init__() + self.single_object = single_object + + self.key_encoder = KeyEncoder() + if single_object: + self.value_encoder = ValueEncoderSO() + else: + self.value_encoder = ValueEncoder() + + # Projection from f16 feature space to key space + self.key_proj = KeyProjection(1024, keydim=64) + + # Compress f16 a bit to use in decoding later on + self.key_comp = nn.Conv2d(1024, 512, kernel_size=3, padding=1) + + self.memory = MemoryReader() + self.decoder = Decoder() + + def aggregate(self, prob): + new_prob = torch.cat([ + torch.prod(1-prob, dim=1, keepdim=True), + prob + ], 1).clamp(1e-7, 1-1e-7) + logits = torch.log((new_prob /(1-new_prob))) + return logits + + def encode_key(self, frame): + # input: b*t*c*h*w + b, t = frame.shape[:2] + + f16, f8, f4 = self.key_encoder(frame.flatten(start_dim=0, end_dim=1)) + k16 = self.key_proj(f16) + f16_thin = self.key_comp(f16) + + # B*C*T*H*W + k16 = k16.view(b, t, *k16.shape[-3:]).transpose(1, 2).contiguous() + + # B*T*C*H*W + f16_thin = f16_thin.view(b, t, *f16_thin.shape[-3:]) + f16 = f16.view(b, t, *f16.shape[-3:]) + f8 = f8.view(b, t, *f8.shape[-3:]) + f4 = f4.view(b, t, *f4.shape[-3:]) + + return k16, f16_thin, f16, f8, f4 + + def encode_value(self, frame, kf16, mask, other_mask=None): + # Extract memory key/value for a frame + if self.single_object: + f16 = self.value_encoder(frame, kf16, mask) + else: + f16 = self.value_encoder(frame, kf16, mask, other_mask) + return f16.unsqueeze(2) # B*512*T*H*W + + def segment(self, qk16, qv16, qf8, qf4, mk16, mv16, selector=None): + # q - query, m - memory + # qv16 is f16_thin above + affinity = self.memory.get_affinity(mk16, qk16) + + if self.single_object: + logits = self.decoder(self.memory.readout(affinity, mv16, qv16), qf8, qf4) + prob = torch.sigmoid(logits) + else: + logits = self.decoder(self.memory.readout(affinity, mv16, qv16), qf8, qf4) + prob = torch.sigmoid(logits) + + logits = self.aggregate(prob) + prob = F.softmax(logits, dim=1)[:, 1:] + + return logits, prob + + def forward(self, mode, *args, **kwargs): + if mode == 'encode_key': + return self.encode_key(*args, **kwargs) + elif mode == 'encode_value': + return self.encode_value(*args, **kwargs) + elif mode == 'segment': + return self.segment(*args, **kwargs) + else: + raise NotImplementedError + + diff --git a/utils/vos/util/__init__.py b/utils/vos/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/vos/util/__pycache__/__init__.cpython-310.pyc b/utils/vos/util/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcfaf5b9df47dbfdb57c90a828db6f8def1a8855 Binary files /dev/null and b/utils/vos/util/__pycache__/__init__.cpython-310.pyc differ diff --git a/utils/vos/util/__pycache__/__init__.cpython-38.pyc b/utils/vos/util/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..521cc4dc7c1c9fc4340f2c47d59b3fa9b3a2f39d Binary files /dev/null and b/utils/vos/util/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/vos/util/__pycache__/__init__.cpython-39.pyc b/utils/vos/util/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b68a772fc54d3d262fda78bcd4d4bd8375026fe2 Binary files /dev/null and b/utils/vos/util/__pycache__/__init__.cpython-39.pyc differ diff --git a/utils/vos/util/__pycache__/log_integrator.cpython-310.pyc b/utils/vos/util/__pycache__/log_integrator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71e235f4c85edc23238b6da114d281ef2fde23b7 Binary files /dev/null and b/utils/vos/util/__pycache__/log_integrator.cpython-310.pyc differ diff --git a/utils/vos/util/__pycache__/log_integrator.cpython-38.pyc b/utils/vos/util/__pycache__/log_integrator.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51023fcd3cff416eba59913b38d99f0fb02b05e9 Binary files /dev/null and b/utils/vos/util/__pycache__/log_integrator.cpython-38.pyc differ diff --git a/utils/vos/util/__pycache__/log_integrator.cpython-39.pyc b/utils/vos/util/__pycache__/log_integrator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c1b1b8f75ae64e56f32a526603183c363d1fcad Binary files /dev/null and b/utils/vos/util/__pycache__/log_integrator.cpython-39.pyc differ diff --git a/utils/vos/util/__pycache__/tensor_util.cpython-310.pyc b/utils/vos/util/__pycache__/tensor_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ca165a071b624672f33418fa89f996baa6d632b Binary files /dev/null and b/utils/vos/util/__pycache__/tensor_util.cpython-310.pyc differ diff --git a/utils/vos/util/__pycache__/tensor_util.cpython-38.pyc b/utils/vos/util/__pycache__/tensor_util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e5d135ab8bfd99a9ef60981839115231cd2a9da Binary files /dev/null and b/utils/vos/util/__pycache__/tensor_util.cpython-38.pyc differ diff --git a/utils/vos/util/__pycache__/tensor_util.cpython-39.pyc b/utils/vos/util/__pycache__/tensor_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e55bcfe2c673dc0827ee4c981185181d30ab7d9 Binary files /dev/null and b/utils/vos/util/__pycache__/tensor_util.cpython-39.pyc differ diff --git a/utils/vos/util/davis_subset.txt b/utils/vos/util/davis_subset.txt new file mode 100644 index 0000000000000000000000000000000000000000..875c2409d2cc4cfc4491ebf7703cb432b26678d8 --- /dev/null +++ b/utils/vos/util/davis_subset.txt @@ -0,0 +1,60 @@ +bear +bmx-bumps +boat +boxing-fisheye +breakdance-flare +bus +car-turn +cat-girl +classic-car +color-run +crossing +dance-jump +dancing +disc-jockey +dog-agility +dog-gooses +dogs-scale +drift-turn +drone +elephant +flamingo +hike +hockey +horsejump-low +kid-football +kite-walk +koala +lady-running +lindy-hop +longboard +lucia +mallard-fly +mallard-water +miami-surf +motocross-bumps +motorbike +night-race +paragliding +planes-water +rallye +rhino +rollerblade +schoolgirls +scooter-board +scooter-gray +sheep +skate-park +snowboard +soccerball +stroller +stunt +surf +swing +tennis +tractor-sand +train +tuk-tuk +upside-down +varanus-cage +walking \ No newline at end of file diff --git a/utils/vos/util/hyper_para.py b/utils/vos/util/hyper_para.py new file mode 100644 index 0000000000000000000000000000000000000000..dc21a63bf6b9c98cedc24acb392222dbaa844b31 --- /dev/null +++ b/utils/vos/util/hyper_para.py @@ -0,0 +1,91 @@ +from argparse import ArgumentParser + + +def none_or_default(x, default): + return x if x is not None else default + +class HyperParameters(): + def parse(self, unknown_arg_ok=False): + parser = ArgumentParser() + + # Enable torch.backends.cudnn.benchmark -- Faster in some cases, test in your own environment + parser.add_argument('--benchmark', action='store_true') + parser.add_argument('--no_amp', action='store_true') + + # Data parameters + parser.add_argument('--static_root', help='Static training data root', default='../static') + parser.add_argument('--bl_root', help='Blender training data root', default='../BL30K') + parser.add_argument('--yv_root', help='YouTubeVOS data root', default='../YouTube') + parser.add_argument('--davis_root', help='DAVIS data root', default='../DAVIS') + + parser.add_argument('--stage', help='Training stage (0-static images, 1-Blender dataset, 2-DAVIS+YouTubeVOS (300K), 3-DAVIS+YouTubeVOS (150K))', type=int, default=0) + parser.add_argument('--num_workers', help='Number of datalaoder workers per process', type=int, default=8) + + # Generic learning parameters + parser.add_argument('-b', '--batch_size', help='Default is dependent on the training stage, see below', default=None, type=int) + parser.add_argument('-i', '--iterations', help='Default is dependent on the training stage, see below', default=None, type=int) + parser.add_argument('--steps', help='Default is dependent on the training stage, see below', nargs="*", default=None, type=int) + + parser.add_argument('--lr', help='Initial learning rate', type=float) + parser.add_argument('--gamma', help='LR := LR*gamma at every decay step', default=0.1, type=float) + + # Loading + parser.add_argument('--load_network', help='Path to pretrained network weight only') + parser.add_argument('--load_model', help='Path to the model file, including network, optimizer and such') + + # Logging information + parser.add_argument('--id', help='Experiment UNIQUE id, use NULL to disable logging to tensorboard', default='NULL') + parser.add_argument('--debug', help='Debug mode which logs information more often', action='store_true') + + # Multiprocessing parameters, not set by users + parser.add_argument('--local_rank', default=0, type=int, help='Local rank of this process') + + if unknown_arg_ok: + args, _ = parser.parse_known_args() + self.args = vars(args) + else: + self.args = vars(parser.parse_args()) + + self.args['amp'] = not self.args['no_amp'] + + # Stage-dependent hyperparameters + # Assign default if not given + if self.args['stage'] == 0: + # Static image pretraining + self.args['lr'] = none_or_default(self.args['lr'], 1e-5) + self.args['batch_size'] = none_or_default(self.args['batch_size'], 8) + self.args['iterations'] = none_or_default(self.args['iterations'], 300000) + self.args['steps'] = none_or_default(self.args['steps'], [150000]) + self.args['single_object'] = True + elif self.args['stage'] == 1: + # BL30K pretraining + self.args['lr'] = none_or_default(self.args['lr'], 1e-5) + self.args['batch_size'] = none_or_default(self.args['batch_size'], 4) + self.args['iterations'] = none_or_default(self.args['iterations'], 500000) + self.args['steps'] = none_or_default(self.args['steps'], [400000]) + self.args['single_object'] = False + elif self.args['stage'] == 2: + # 300K main training for after BL30K + self.args['lr'] = none_or_default(self.args['lr'], 1e-5) + self.args['batch_size'] = none_or_default(self.args['batch_size'], 4) + self.args['iterations'] = none_or_default(self.args['iterations'], 300000) + self.args['steps'] = none_or_default(self.args['steps'], [250000]) + self.args['single_object'] = False + elif self.args['stage'] == 3: + # 150K main training for after static image pretraining + self.args['lr'] = none_or_default(self.args['lr'], 1e-5) + self.args['batch_size'] = none_or_default(self.args['batch_size'], 4) + self.args['iterations'] = none_or_default(self.args['iterations'], 150000) + self.args['steps'] = none_or_default(self.args['steps'], [125000]) + self.args['single_object'] = False + else: + raise NotImplementedError + + def __getitem__(self, key): + return self.args[key] + + def __setitem__(self, key, value): + self.args[key] = value + + def __str__(self): + return str(self.args) diff --git a/utils/vos/util/image_saver.py b/utils/vos/util/image_saver.py new file mode 100644 index 0000000000000000000000000000000000000000..57276878b6381d311466d31293e71ac86712c829 --- /dev/null +++ b/utils/vos/util/image_saver.py @@ -0,0 +1,137 @@ +import cv2 +import numpy as np + +import torch +from dataset.range_transform import inv_im_trans +from collections import defaultdict + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + +def tensor_to_np_float(image): + image_np = image.numpy().astype('float32') + return image_np + +def detach_to_cpu(x): + return x.detach().cpu() + +def transpose_np(x): + return np.transpose(x, [1,2,0]) + +def tensor_to_gray_im(x): + x = detach_to_cpu(x) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + +def tensor_to_im(x): + x = detach_to_cpu(x) + x = inv_im_trans(x).clamp(0, 1) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + +def tensor_to_seg(x): + x = detach_to_cpu(x) + x = inv_seg_trans(x).clamp(0, 1) + x = tensor_to_numpy(x) + x = transpose_np(x) + return x + +# Predefined key <-> caption dict +key_captions = { + 'im': 'Image', + 'gt': 'GT', +} + +""" +Return an image array with captions +keys in dictionary will be used as caption if not provided +values should contain lists of cv2 images +""" +def get_image_array(images, grid_shape, captions={}): + h, w = grid_shape + cate_counts = len(images) + rows_counts = len(next(iter(images.values()))) + + font = cv2.FONT_HERSHEY_SIMPLEX + + output_image = np.zeros([w*cate_counts, h*(rows_counts+1), 3], dtype=np.uint8) + col_cnt = 0 + for k, v in images.items(): + + # Default as key value itself + caption = captions.get(k, k) + + # Handles new line character + dy = 40 + for i, line in enumerate(caption.split('\n')): + cv2.putText(output_image, line, (10, col_cnt*w+100+i*dy), + font, 0.8, (255,255,255), 2, cv2.LINE_AA) + + # Put images + for row_cnt, img in enumerate(v): + im_shape = img.shape + if len(im_shape) == 2: + img = img[..., np.newaxis] + + img = (img * 255).astype('uint8') + + output_image[(col_cnt+0)*w:(col_cnt+1)*w, + (row_cnt+1)*h:(row_cnt+2)*h, :] = img + + col_cnt += 1 + + return output_image + +def base_transform(im, size): + im = tensor_to_np_float(im) + if len(im.shape) == 3: + im = im.transpose((1, 2, 0)) + else: + im = im[:, :, None] + + # Resize + if im.shape[1] != size: + im = cv2.resize(im, size, interpolation=cv2.INTER_NEAREST) + + return im.clip(0, 1) + +def im_transform(im, size): + return base_transform(inv_im_trans(detach_to_cpu(im)), size=size) + +def mask_transform(mask, size): + return base_transform(detach_to_cpu(mask), size=size) + +def out_transform(mask, size): + return base_transform(detach_to_cpu(torch.sigmoid(mask)), size=size) + +def pool_pairs(images, size, so): + req_images = defaultdict(list) + + b, s, _, _, _ = images['gt'].shape + + # limit number of images to save disk space + b = max(2, b) + + GT_name = 'GT' + for b_idx in range(b): + GT_name += ' %s\n' % images['info']['name'][b_idx] + + for b_idx in range(b): + for s_idx in range(s): + req_images['RGB'].append(im_transform(images['rgb'][b_idx,s_idx], size)) + if s_idx == 0: + req_images['Mask'].append(np.zeros((size[1], size[0], 3))) + if not so: + req_images['Mask 2'].append(np.zeros((size[1], size[0], 3))) + else: + req_images['Mask'].append(mask_transform(images['mask_%d'%s_idx][b_idx], size)) + if not so: + req_images['Mask 2'].append(mask_transform(images['sec_mask_%d'%s_idx][b_idx], size)) + req_images[GT_name].append(mask_transform(images['gt'][b_idx,s_idx], size)) + if not so: + req_images[GT_name + '_2'].append(mask_transform(images['sec_gt'][b_idx,s_idx], size)) + + return get_image_array(req_images, size, key_captions) \ No newline at end of file diff --git a/utils/vos/util/load_subset.py b/utils/vos/util/load_subset.py new file mode 100644 index 0000000000000000000000000000000000000000..3191f4fef05cec04a11eafdfa42b34b98a35549e --- /dev/null +++ b/utils/vos/util/load_subset.py @@ -0,0 +1,16 @@ +""" +load_subset.py - Presents a subset of data +DAVIS - only the training set +YouTubeVOS - I manually filtered some erroneous ones out but I haven't checked all +""" + + +def load_sub_davis(path='util/davis_subset.txt'): + with open(path, mode='r') as f: + subset = set(f.read().splitlines()) + return subset + +def load_sub_yv(path='util/yv_subset.txt'): + with open(path, mode='r') as f: + subset = set(f.read().splitlines()) + return subset diff --git a/utils/vos/util/log_integrator.py b/utils/vos/util/log_integrator.py new file mode 100644 index 0000000000000000000000000000000000000000..e4b26d53de98b16e145090bcddf2041a3f2d1394 --- /dev/null +++ b/utils/vos/util/log_integrator.py @@ -0,0 +1,80 @@ +""" +Integrate numerical values for some iterations +Typically used for loss computation / logging to tensorboard +Call finalize and create a new Integrator when you want to display/log +""" + +import torch + + +class Integrator: + def __init__(self, logger, distributed=True, local_rank=0, world_size=1): + self.values = {} + self.counts = {} + self.hooks = [] # List is used here to maintain insertion order + + self.logger = logger + + self.distributed = distributed + self.local_rank = local_rank + self.world_size = world_size + + def add_tensor(self, key, tensor): + if key not in self.values: + self.counts[key] = 1 + if type(tensor) == float or type(tensor) == int: + self.values[key] = tensor + else: + self.values[key] = tensor.mean().item() + else: + self.counts[key] += 1 + if type(tensor) == float or type(tensor) == int: + self.values[key] += tensor + else: + self.values[key] += tensor.mean().item() + + def add_dict(self, tensor_dict): + for k, v in tensor_dict.items(): + self.add_tensor(k, v) + + def add_hook(self, hook): + """ + Adds a custom hook, i.e. compute new metrics using values in the dict + The hook takes the dict as argument, and returns a (k, v) tuple + e.g. for computing IoU + """ + if type(hook) == list: + self.hooks.extend(hook) + else: + self.hooks.append(hook) + + def reset_except_hooks(self): + self.values = {} + self.counts = {} + + # Average and output the metrics + def finalize(self, prefix, it, f=None): + + for hook in self.hooks: + k, v = hook(self.values) + self.add_tensor(k, v) + + for k, v in self.values.items(): + + if k[:4] == 'hide': + continue + + avg = v / self.counts[k] + + if self.distributed: + # Inplace operation + avg = torch.tensor(avg).cuda() + torch.distributed.reduce(avg, dst=0) + + if self.local_rank == 0: + avg = (avg/self.world_size).cpu().item() + self.logger.log_metrics(prefix, k, avg, it, f) + else: + # Simple does it + self.logger.log_metrics(prefix, k, avg, it, f) + diff --git a/utils/vos/util/logger.py b/utils/vos/util/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..08e409df0e2e850aa984161f174029f6b396330f --- /dev/null +++ b/utils/vos/util/logger.py @@ -0,0 +1,103 @@ +""" +Dumps things to tensorboard and console +""" + +import os +import warnings +import git + +import torchvision.transforms as transforms +from torch.utils.tensorboard import SummaryWriter + + +def tensor_to_numpy(image): + image_np = (image.numpy() * 255).astype('uint8') + return image_np + +def detach_to_cpu(x): + return x.detach().cpu() + +def fix_width_trunc(x): + return ('{:.9s}'.format('{:0.9f}'.format(x))) + +class TensorboardLogger: + def __init__(self, short_id, id): + self.short_id = short_id + if self.short_id == 'NULL': + self.short_id = 'DEBUG' + + if id is None: + self.no_log = True + warnings.warn('Logging has been disbaled.') + else: + self.no_log = False + + self.inv_im_trans = transforms.Normalize( + mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], + std=[1/0.229, 1/0.224, 1/0.225]) + + self.inv_seg_trans = transforms.Normalize( + mean=[-0.5/0.5], + std=[1/0.5]) + + log_path = os.path.join('.', 'log', '%s' % id) + self.logger = SummaryWriter(log_path) + + repo = git.Repo(".") + self.log_string('git', str(repo.active_branch) + ' ' + str(repo.head.commit.hexsha)) + + def log_scalar(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + self.logger.add_scalar(tag, x, step) + + def log_metrics(self, l1_tag, l2_tag, val, step, f=None): + tag = l1_tag + '/' + l2_tag + text = '{:s} - It {:6d} [{:5s}] [{:13}]: {:s}'.format(self.short_id, step, l1_tag.upper(), l2_tag, fix_width_trunc(val)) + print(text) + if f is not None: + f.write(text + '\n') + f.flush() + self.log_scalar(tag, val, step) + + def log_im(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = self.inv_im_trans(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_cv2(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = x.transpose((2, 0, 1)) + self.logger.add_image(tag, x, step) + + def log_seg(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = self.inv_seg_trans(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_gray(self, tag, x, step): + if self.no_log: + warnings.warn('Logging has been disabled.') + return + x = detach_to_cpu(x) + x = tensor_to_numpy(x) + self.logger.add_image(tag, x, step) + + def log_string(self, tag, x): + print(tag, x) + if self.no_log: + warnings.warn('Logging has been disabled.') + return + self.logger.add_text(tag, x) + \ No newline at end of file diff --git a/utils/vos/util/tensor_util.py b/utils/vos/util/tensor_util.py new file mode 100644 index 0000000000000000000000000000000000000000..30d3acbae98c4cba7defa40a86c6ca112db837c1 --- /dev/null +++ b/utils/vos/util/tensor_util.py @@ -0,0 +1,41 @@ +import torch.nn.functional as F + +def compute_tensor_iu(seg, gt): + intersection = (seg & gt).float().sum() + union = (seg | gt).float().sum() + + return intersection, union + +def compute_tensor_iou(seg, gt): + intersection, union = compute_tensor_iu(seg, gt) + iou = (intersection + 1e-6) / (union + 1e-6) + + return iou + +# STM +def pad_divide_by(in_img, d, in_size=None): + if in_size is None: + h, w = in_img.shape[-2:] + else: + h, w = in_size + + if h % d > 0: + new_h = h + d - h % d + else: + new_h = h + if w % d > 0: + new_w = w + d - w % d + else: + new_w = w + lh, uh = int((new_h-h) / 2), int(new_h-h) - int((new_h-h) / 2) + lw, uw = int((new_w-w) / 2), int(new_w-w) - int((new_w-w) / 2) + pad_array = (int(lw), int(uw), int(lh), int(uh)) + out = F.pad(in_img, pad_array) + return out, pad_array + +def unpad(img, pad): + if pad[2]+pad[3] > 0: + img = img[:,:,pad[2]:-pad[3],:] + if pad[0]+pad[1] > 0: + img = img[:,:,:,pad[0]:-pad[1]] + return img \ No newline at end of file diff --git a/utils/vos/util/yv_subset.txt b/utils/vos/util/yv_subset.txt new file mode 100644 index 0000000000000000000000000000000000000000..a26e50a7b8e6233bf17c542b540765cd8a1c5716 --- /dev/null +++ b/utils/vos/util/yv_subset.txt @@ -0,0 +1,3464 @@ +003234408d +0043f083b5 +0044fa5fba +005a527edd +0065b171f9 +00917dcfc4 +00a23ccf53 +00ad5016a4 +01082ae388 +011ac0a06f +013099c098 +0155498c85 +01694ad9c8 +017ac35701 +01b80e8e1a +01baa5a4e1 +01c3111683 +01c4cb5ffe +01c76f0a82 +01c783268c +01ed275c6e +01ff60d1fa +020cd28cd2 +02264db755 +0248626d9a +02668dbffa +0274193026 +02d28375aa +02f3a5c4df +031ccc99b1 +0321b18c10 +0348a45bca +0355e92655 +0358b938c1 +0368107cf1 +0379ddf557 +038b2cc71d +038c15a5dd +03a06cc98a +03a63e187f +03c95b4dae +03e2b57b0e +04194e1248 +0444918a5f +04460a7a52 +04474174a4 +0450095513 +045f00aed2 +04667fabaa +04735c5030 +04990d1915 +04d62d9d98 +04f21da964 +04fbad476e +04fe256562 +0503bf89c9 +0536c9eed0 +054acb238f +05579ca250 +056c200404 +05774f3a2c +058a7592c8 +05a0a513df +05a569d8aa +05aa652648 +05d7715782 +05e0b0f28f +05fdbbdd7a +05ffcfed85 +0630391881 +06840b2bbe +068f7dce6f +0693719753 +06ce2b51fb +06e224798e +06ee361788 +06fbb3fa2c +0700264286 +070c918ca7 +07129e14a4 +07177017e9 +07238ffc58 +07353b2a89 +0738493cbf +075926c651 +075c701292 +0762ea9a30 +07652ee4af +076f206928 +077d32af19 +079049275c +07913cdda7 +07a11a35e8 +07ac33b6df +07b6e8fda8 +07c62c3d11 +07cc1c7d74 +080196ef01 +081207976e +081ae4fa44 +081d8250cb +082900c5d4 +0860df21e2 +0866d4c5e3 +0891ac2eb6 +08931bc458 +08aa2705d5 +08c8450db7 +08d50b926c +08e1e4de15 +08e48c1a48 +08f561c65e +08feb87790 +09049f6fe3 +092e4ff450 +09338adea8 +093c335ccc +0970d28339 +0974a213dc +097b471ed8 +0990941758 +09a348f4fa +09a6841288 +09c5bad17b +09c9ce80c7 +09ff54fef4 +0a23765d15 +0a275e7f12 +0a2f2bd294 +0a7a2514aa +0a7b27fde9 +0a8c467cc3 +0ac8c560ae +0b1627e896 +0b285c47f6 +0b34ec1d55 +0b5b5e8e5a +0b68535614 +0b6f9105fc +0b7dbfa3cb +0b9cea51ca +0b9d012be8 +0bcfc4177d +0bd37b23c1 +0bd864064c +0c11c6bf7b +0c26bc77ac +0c3a04798c +0c44a9d545 +0c817cc390 +0ca839ee9a +0cd7ac0ac0 +0ce06e0121 +0cfe974a89 +0d2fcc0dcd +0d3aad05d2 +0d40b015f4 +0d97fba242 +0d9cc80d7e +0dab85b6d3 +0db5c427a5 +0dbaf284f1 +0de4923598 +0df28a9101 +0e04f636c4 +0e05f0e232 +0e0930474b +0e27472bea +0e30020549 +0e621feb6c +0e803c7d73 +0e9ebe4e3c +0e9f2785ec +0ea68d418b +0eb403a222 +0ee92053d6 +0eefca067f +0f17fa6fcb +0f1ac8e9a3 +0f202e9852 +0f2ab8b1ff +0f51a78756 +0f5fbe16b0 +0f6072077b +0f6b69b2f4 +0f6c2163de +0f74ec5599 +0f9683715b +0fa7b59356 +0fb173695b +0fc958cde2 +0fe7b1a621 +0ffcdb491c +101caff7d4 +1022fe8417 +1032e80b37 +103f501680 +104e64565f +104f1ab997 +106242403f +10b31f5431 +10eced835e +110d26fa3a +1122c1d16a +1145b49a5f +11485838c2 +114e7676ec +1157472b95 +115ee1072c +1171141012 +117757b4b8 +1178932d2f +117cc76bda +1180cbf814 +1187bbd0e3 +1197e44b26 +119cf20728 +119dd54871 +11a0c3b724 +11a6ba8c94 +11c722a456 +11cbcb0b4d +11ccf5e99d +11ce6f452e +11e53de6f2 +11feabe596 +120cb9514d +12156b25b3 +122896672d +1232b2f1d4 +1233ac8596 +1239c87234 +1250423f7c +1257a1bc67 +125d1b19dd +126d203967 +1295e19071 +12ad198c54 +12bddb2bcb +12ec9b93ee +12eebedc35 +132852e094 +1329409f2a +13325cfa14 +134d06dbf9 +135625b53d +13870016f9 +13960b3c84 +13adaad9d9 +13ae097e20 +13e3070469 +13f6a8c20d +1416925cf2 +142d2621f5 +145d5d7c03 +145fdc3ac5 +1471274fa7 +14a6b5a139 +14c21cea0d +14dae0dc93 +14f9bd22b5 +14fd28ae99 +15097d5d4e +150ea711f2 +1514e3563f +152aaa3a9e +152b7d3bd7 +15617297cc +15abbe0c52 +15d1fb3de5 +15f67b0fab +161eb59aad +16288ea47f +164410ce62 +165c3c8cd4 +165c42b41b +165ec9e22b +1669502269 +16763cccbb +16adde065e +16af445362 +16afd538ad +16c3fa4d5d +16d1d65c27 +16e8599e94 +16fe9fb444 +1705796b02 +1724db7671 +17418e81ea +175169edbb +17622326fd +17656bae77 +17b0d94172 +17c220e4f6 +17c7bcd146 +17cb4afe89 +17cd79a434 +17d18604c3 +17d8ca1a37 +17e33f4330 +17f7a6d805 +180abc8378 +183ba3d652 +185bf64702 +18913cc690 +1892651815 +189ac8208a +189b44e92c +18ac264b76 +18b245ab49 +18b5cebc34 +18bad52083 +18bb5144d5 +18c6f205c5 +1903f9ea15 +1917b209f2 +191e74c01d +19367bb94e +193ffaa217 +19696b67d3 +197f3ab6f3 +1981e763cc +198afe39ae +19a6e62b9b +19b60d5335 +19c00c11f9 +19e061eb88 +19e8bc6178 +19ee80dac6 +1a25a9170a +1a359a6c1a +1a3e87c566 +1a5fe06b00 +1a6c0fbd1e +1a6f3b5a4b +1a8afbad92 +1a8bdc5842 +1a95752aca +1a9c131cb7 +1aa3da3ee3 +1ab27ec7ea +1abf16d21d +1acd0f993b +1ad202e499 +1af8d2395d +1afd39a1fa +1b2d31306f +1b3fa67f0e +1b43fa74b4 +1b73ea9fc2 +1b7e8bb255 +1b8680f8cd +1b883843c0 +1b8898785b +1b88ba1aa4 +1b96a498e5 +1bbc4c274f +1bd87fe9ab +1c4090c75b +1c41934f84 +1c72b04b56 +1c87955a3a +1c9f9eb792 +1ca240fede +1ca5673803 +1cada35274 +1cb44b920d +1cd10e62be +1d3087d5e5 +1d3685150a +1d6ff083aa +1d746352a6 +1da256d146 +1da4e956b1 +1daf812218 +1dba687bce +1dce57d05d +1de4a9e537 +1dec5446c8 +1dfbe6f586 +1e1a18c45a +1e1e42529d +1e4be70796 +1eb60959c8 +1ec8b2566b +1ecdc2941c +1ee0ac70ff +1ef8e17def +1f1a2a9fc0 +1f1beb8daa +1f2609ee13 +1f3876f8d0 +1f4ec0563d +1f64955634 +1f7d31b5b2 +1f8014b7fd +1f9c7d10f1 +1fa350df76 +1fc9538993 +1fe2f0ec59 +2000c02f9d +20142b2f05 +201a8d75e5 +2023b3ee4f +202b767bbc +203594a418 +2038987336 +2039c3aecb +204a90d81f +207bc6cf01 +208833d1d1 +20c6d8b362 +20e3e52e0a +2117fa0c14 +211bc5d102 +2120d9c3c3 +2125235a49 +21386f5978 +2142af8795 +215dfc0f73 +217bae91e5 +217c0d44e4 +219057c87b +21d0edbf81 +21df87ad76 +21f1d089f5 +21f4019116 +222597030f +222904eb5b +223a0e0657 +223bd973ab +22472f7395 +224e7c833e +225aba51d9 +2261d421ea +2263a8782b +2268cb1ffd +2268e93b0a +2293c99f3f +22a1141970 +22b13084b2 +22d9f5ab0c +22f02efe3a +232c09b75b +2350d71b4b +2376440551 +2383d8aafd +238b84e67f +238d4b86f6 +238d947c6b +23993ce90d +23b0c8a9ab +23b3beafcc +23d80299fe +23f404a9fc +240118e58a +2431dec2fd +24440e0ac7 +2457274dbc +2465bf515d +246b142c4d +247d729e36 +2481ceafeb +24866b4e6a +2489d78320 +24ab0b83e8 +24b0868d92 +24b5207cd9 +24ddf05c03 +250116161c +256ad2e3fc +256bd83d5e +256dcc8ab8 +2589956baa +258b3b33c6 +25ad437e29 +25ae395636 +25c750c6db +25d2c3fe5d +25dc80db7c +25f97e926f +26011bc28b +260846ffbe +260dd9ad33 +267964ee57 +2680861931 +268ac7d3fc +26b895d91e +26bc786d4f +26ddd2ef12 +26de3d18ca +26f7784762 +2703e52a6a +270ed80c12 +2719b742ab +272f4163d0 +27303333e1 +27659fa7d6 +279214115d +27a5f92a9c +27cf2af1f3 +27f0d5f8a2 +28075f33c1 +281629cb41 +282b0d51f5 +282fcab00b +28449fa0dc +28475208ca +285580b7c4 +285b69e223 +288c117201 +28a8eb9623 +28bf9c3cf3 +28c6b8f86a +28c972dacd +28d9fa6016 +28e392de91 +28f4a45190 +298c844fc9 +29a0356a2b +29d779f9e3 +29dde5f12b +29de7b6579 +29e630bdd0 +29f2332d30 +2a18873352 +2a3824ff31 +2a559dd27f +2a5c09acbd +2a63eb1524 +2a6a30a4ea +2a6d9099d1 +2a821394e3 +2a8c5b1342 +2abc8d66d2 +2ac9ef904a +2b08f37364 +2b351bfd7d +2b659a49d7 +2b69ee5c26 +2b6c30bbbd +2b88561cf2 +2b8b14954e +2ba621c750 +2bab50f9a7 +2bb00c2434 +2bbde474ef +2bdd82fb86 +2be06fb855 +2bf545c2f5 +2bffe4cf9a +2c04b887b7 +2c05209105 +2c0ad8cf39 +2c11fedca8 +2c1a94ebfb +2c1e8c8e2f +2c29fabcf1 +2c2c076c01 +2c3ea7ee7d +2c41fa0648 +2c44bb6d1c +2c54cfbb78 +2c5537eddf +2c6e63b7de +2cb10c6a7e +2cbcd5ccd1 +2cc5d9c5f6 +2cd01cf915 +2cdbf5f0a7 +2ce660f123 +2cf114677e +2d01eef98e +2d03593bdc +2d183ac8c4 +2d33ad3935 +2d3991d83e +2d4333577b +2d4d015c64 +2d8f5e5025 +2d900bdb8e +2d9a1a1d49 +2db0576a5c +2dc0838721 +2dcc417f82 +2df005b843 +2df356de14 +2e00393d96 +2e03b8127a +2e0f886168 +2e2bf37e6d +2e42410932 +2ea78f46e4 +2ebb017a26 +2ee2edba2a +2efb07554a +2f17e4fc1e +2f2c65c2f3 +2f2d9b33be +2f309c206b +2f53822e88 +2f53998171 +2f5b0c89b1 +2f680909e6 +2f710f66bd +2f724132b9 +2f7e3517ae +2f96f5fc6f +2f97d9fecb +2fbfa431ec +2fc9520b53 +2fcd9f4c62 +2feb30f208 +2ff7f5744f +30085a2cc6 +30176e3615 +301f72ee11 +3026bb2f61 +30318465dc +3054ca937d +306121e726 +3064ad91e8 +307444a47f +307bbb7409 +30a20194ab +30c35c64a4 +30dbdb2cd6 +30fc77d72f +310021b58b +3113140ee8 +3150b2ee57 +31539918c4 +318dfe2ce2 +3193da4835 +319f725ad9 +31bbd0d793 +322505c47f +322b237865 +322da43910 +3245e049fb +324c4c38f6 +324e35111a +3252398f09 +327dc4cabf +328d918c7d +3290c0de97 +3299ae3116 +32a7cd687b +33098cedb4 +3332334ac4 +334cb835ac +3355e056eb +33639a2847 +3373891cdc +337975816b +33e29d7e91 +34046fe4f2 +3424f58959 +34370a710f +343bc6a65a +3450382ef7 +3454303a08 +346aacf439 +346e92ff37 +34a5ece7dd +34b109755a +34d1b37101 +34dd2c70a7 +34efa703df +34fbee00a6 +3504df2fda +35195a56a1 +351c822748 +351cfd6bc5 +3543d8334c +35573455c7 +35637a827f +357a710863 +358bf16f9e +35ab34cc34 +35c6235b8d +35d01a438a +3605019d3b +3609bc3f88 +360e25da17 +36299c687c +362c5bc56e +3649228783 +365b0501ea +365f459863 +369893f3ad +369c9977e1 +369dde050a +36c7dac02f +36d5b1493b +36f5cc68fd +3735480d18 +374b479880 +375a49d38f +375a5c0e09 +376bda9651 +377db65f60 +37c19d1087 +37d4ae24fc +37ddce7f8b +37e10d33af +37e45c6247 +37fa0001e8 +3802d458c0 +382caa3cb4 +383bb93111 +388843df90 +38924f4a7f +38b00f93d7 +38c197c10e +38c9c3d801 +38eb2bf67f +38fe9b3ed1 +390352cced +390c51b987 +390ca6f1d6 +392bc0f8a1 +392ecb43bd +3935291688 +3935e63b41 +394454fa9c +394638fc8b +39545e20b7 +397abeae8f +3988074b88 +398f5d5f19 +39bc49a28c +39befd99fb +39c3c7bf55 +39d584b09f +39f6f6ffb1 +3a079fb484 +3a0d3a81b7 +3a1d55d22b +3a20a7583e +3a2c1f66e5 +3a33f4d225 +3a3bf84b13 +3a4565e5ec +3a4e32ed5e +3a7ad86ce0 +3a7bdde9b8 +3a98867cbe +3aa3f1c9e8 +3aa7fce8b6 +3aa876887d +3ab807ded6 +3ab9b1a85a +3adac8d7da +3ae1a4016f +3ae2deaec2 +3ae81609d6 +3af847e62f +3b23792b84 +3b3b0af2ee +3b512dad74 +3b6c7988f6 +3b6e983b5b +3b74a0fc20 +3b7a50b80d +3b96d3492f +3b9ad0c5a9 +3b9ba0894a +3bb4e10ed7 +3bd9a9b515 +3beef45388 +3c019c0a24 +3c090704aa +3c2784fc0d +3c47ab95f8 +3c4db32d74 +3c5ff93faf +3c700f073e +3c713cbf2f +3c8320669c +3c90d225ee +3cadbcc404 +3cb9be84a5 +3cc37fd487 +3cc6f90cb2 +3cd5e035ef +3cdf03531b +3cdf828f59 +3d254b0bca +3d5aeac5ba +3d690473e1 +3d69fed2fb +3d8997aeb6 +3db0d6b07e +3db1ddb8cf +3db907ac77 +3dcbc0635b +3dd48ed55f +3de4ac4ec4 +3decd63d88 +3e04a6be11 +3e108fb65a +3e1448b01c +3e16c19634 +3e2845307e +3e38336da5 +3e3a819865 +3e3e4be915 +3e680622d7 +3e7d2aeb07 +3e7d8f363d +3e91f10205 +3ea4c49bbe +3eb39d11ab +3ec273c8d5 +3ed3f91271 +3ee062a2fd +3eede9782c +3ef2fa99cb +3efc6e9892 +3f0b0dfddd +3f0c860359 +3f18728586 +3f3b15f083 +3f45a470ad +3f4f3bc803 +3fd96c5267 +3fea675fab +3fee8cbc9f +3fff16d112 +401888b36c +4019231330 +402316532d +402680df52 +404d02e0c0 +40709263a8 +4083cfbe15 +40a96c5cb1 +40b8e50f82 +40f4026bf5 +4100b57a3a +41059fdd0b +41124e36de +4122aba5f9 +413bab0f0d +4164faee0b +418035eec9 +4182d51532 +418bb97e10 +41a34c20e7 +41dab05200 +41ff6d5e2a +420caf0859 +42264230ba +425a0c96e0 +42da96b87c +42eb5a5b0f +42f17cd14d +42f5c61c49 +42ffdcdee9 +432f9884f9 +43326d9940 +4350f3ab60 +4399ffade3 +43a6c21f37 +43b5555faa +43d63b752a +4416bdd6ac +4444753edd +444aa274e7 +444d4e0596 +446b8b5f7a +4478f694bb +44b1da0d87 +44b4dad8c9 +44b5ece1b9 +44d239b24e +44eaf8f51e +44f4f57099 +44f7422af2 +450787ac97 +4523656564 +4536c882e5 +453b65daa4 +454f227427 +45636d806a +456fb9362e +457e717a14 +45a89f35e1 +45bf0e947d +45c36a9eab +45d9fc1357 +45f8128b97 +4607f6c03c +46146dfd39 +4620e66b1e +4625f3f2d3 +462b22f263 +4634736113 +463c0f4fdd +46565a75f8 +46630b55ae +466839cb37 +466ba4ae0c +4680236c9d +46bf4e8709 +46e18e42f1 +46f5093c59 +47269e0499 +472da1c484 +47354fab09 +4743bb84a7 +474a796272 +4783d2ab87 +479cad5da3 +479f5d7ef6 +47a05fbd1d +4804ee2767 +4810c3fbca +482fb439c2 +48375af288 +484ab44de4 +485f3944cd +4867b84887 +486a8ac57e +486e69c5bd +48812cf33e +4894b3b9ea +48bd66517d +48d83b48a4 +49058178b8 +4918d10ff0 +4932911f80 +49405b7900 +49972c2d14 +499bf07002 +49b16e9377 +49c104258e +49c879f82d +49e7326789 +49ec3e406a +49fbf0c98a +4a0255c865 +4a088fe99a +4a341402d0 +4a3471bdf5 +4a4b50571c +4a50f3d2e9 +4a6e3faaa1 +4a7191f08a +4a86fcfc30 +4a885fa3ef +4a8af115de +4aa2e0f865 +4aa9d6527f +4abb74bb52 +4ae13de1cd +4af8cb323f +4b02c272b3 +4b19c529fb +4b2974eff4 +4b3154c159 +4b54d2587f +4b556740ff +4b67aa9ef6 +4b97cc7b8d +4baa1ed4aa +4bc8c676bb +4beaea4dbe +4bf5763d24 +4bffa92b67 +4c25dfa8ec +4c397b6fd4 +4c51e75d66 +4c7710908f +4c9b5017be +4ca2ffc361 +4cad2e93bc +4cd427b535 +4cd9a4b1ef +4cdfe3c2b2 +4cef87b649 +4cf208e9b3 +4cf5bc3e60 +4cfdd73249 +4cff5c9e42 +4d26d41091 +4d5c23c554 +4d67c59727 +4d983cad9f +4da0d00b55 +4daa179861 +4dadd57153 +4db117e6c5 +4de4ce4dea +4dfaee19e5 +4dfdd7fab0 +4e3f346aa5 +4e49c2a9c7 +4e4e06a749 +4e70279712 +4e72856cc7 +4e752f8075 +4e7a28907f +4e824b9247 +4e82b1df57 +4e87a639bc +4ea77bfd15 +4eb6fc23a2 +4ec9da329e +4efb9a0720 +4f062fbc63 +4f35be0e0b +4f37e86797 +4f414dd6e7 +4f424abded +4f470cc3ae +4f601d255a +4f7386a1ab +4f824d3dcd +4f827b0751 +4f8db33a13 +4fa160f8a3 +4fa9c30a45 +4facd8f0e8 +4fca07ad01 +4fded94004 +4fdfef4dea +4feb3ac01f +4fffec8479 +500c835a86 +50168342bf +50243cffdc +5031d5a036 +504dd9c0fd +50568fbcfb +5069c7c5b3 +508189ac91 +50b6b3d4b7 +50c6f4fe3e +50cce40173 +50efbe152f +50f290b95d +5104aa1fea +5110dc72c0 +511e8ecd7f +513aada14e +5158d6e985 +5161e1fa57 +51794ddd58 +517d276725 +51a597ee04 +51b37b6d97 +51b5dc30a0 +51e85b347b +51eea1fdac +51eef778af +51f384721c +521cfadcb4 +52355da42f +5247d4b160 +524b470fd0 +524cee1534 +5252195e8a +5255c9ca97 +525928f46f +526df007a7 +529b12de78 +52c7a3d653 +52c8ec0373 +52d225ed52 +52ee406d9e +52ff1ccd4a +53143511e8 +5316d11eb7 +53253f2362 +534a560609 +5352c4a70e +536096501f +536b17bcea +5380eaabff +5390a43a54 +53af427bb2 +53bf5964ce +53c30110b5 +53cad8e44a +53d9c45013 +53e274f1b5 +53e32d21ea +540850e1c7 +540cb31cfe +541c4da30f +541d7935d7 +545468262b +5458647306 +54657855cd +547b3fb23b +5497dc3712 +549c56f1d4 +54a4260bb1 +54b98b8d5e +54e1054b0f +54e8867b83 +54ebe34f6e +5519b4ad13 +551acbffd5 +55341f42da +5566ab97e1 +556c79bbf2 +5589637cc4 +558aa072f0 +559824b6f6 +55c1764e90 +55eda6c77e +562d173565 +5665c024cb +566cef4959 +5675d78833 +5678a91bd8 +567a2b4bd0 +569c282890 +56cc449917 +56e71f3e07 +56f09b9d92 +56fc0e8cf9 +571ca79c71 +57243657cf +57246af7d1 +57427393e9 +574b682c19 +578f211b86 +5790ac295d +579393912d +57a344ab1a +57bd3bcda4 +57bfb7fa4c +57c010175e +57c457cc75 +57c7fc2183 +57d5289a01 +58045fde85 +58163c37cd +582d463e5c +5851739c15 +585dd0f208 +587250f3c3 +589e4cc1de +589f65f5d5 +58a07c17d5 +58adc6d8b6 +58b9bcf656 +58c374917e +58fc75fd42 +5914c30f05 +59323787d5 +5937b08d69 +594065ddd7 +595a0ceea6 +59623ec40b +597ff7ef78 +598935ef05 +598c2ad3b2 +59a6459751 +59b175e138 +59bf0a149f +59d53d1649 +59e3e6fae7 +59fe33e560 +5a13a73fe5 +5a25c22770 +5a4a785006 +5a50640995 +5a75f7a1cf +5a841e59ad +5a91c5ab6d +5ab49d9de0 +5aba1057fe +5abe46ba6d +5ac7c88d0c +5aeb95cc7d +5af15e4fc3 +5afe381ae4 +5b07b4229d +5b1001cc4f +5b1df237d2 +5b263013bf +5b27d19f0b +5b48ae16c5 +5b5babc719 +5baaebdf00 +5bab55cdbe +5bafef6e79 +5bd1f84545 +5bddc3ba25 +5bdf7c20d2 +5bf23bc9d3 +5c01f6171a +5c021681b7 +5c185cff1d +5c42aba280 +5c44bf8ab6 +5c4c574894 +5c52fa4662 +5c6ea7dac3 +5c74315dc2 +5c7668855e +5c83e96778 +5ca36173e4 +5cac477371 +5cb0cb1b2f +5cb0cfb98f +5cb49a19cf +5cbf7dc388 +5d0e07d126 +5d1e24b6e3 +5d663000ff +5da6b2dc5d +5de9b90f24 +5e08de0ed7 +5e1011df9a +5e1ce354fd +5e35512dd7 +5e418b25f9 +5e4849935a +5e4ee19663 +5e886ef78f +5e8d00b974 +5e8d59dc31 +5ed838bd5c +5edda6ee5a +5ede4d2f7a +5ede9767da +5eec4d9fe5 +5eecf07824 +5eef7ed4f4 +5ef5860ac6 +5ef6573a99 +5f1193e72b +5f29ced797 +5f32cf521e +5f51876986 +5f6ebe94a9 +5f6f14977c +5f808d0d2d +5fb8aded6a +5fba90767d +5fd1c7a3df +5fd3da9f68 +5fee2570ae +5ff66140d6 +5ff8b85b53 +600803c0f6 +600be7f53e +6024888af8 +603189a03c +6057307f6e +6061ddbb65 +606c86c455 +60c61cc2e5 +60e51ff1ae +610e38b751 +61344be2f6 +6135e27185 +614afe7975 +614e571886 +614e7078db +619812a1a7 +61b481a78b +61c7172650 +61cf7e40d2 +61d08ef5a1 +61da008958 +61ed178ecb +61f5d1282c +61fd977e49 +621584cffe +625817a927 +625892cf0b +625b89d28a +629995af95 +62a0840bb5 +62ad6e121c +62d6ece152 +62ede7b2da +62f025e1bc +6316faaebc +63281534dc +634058dda0 +6353f09384 +6363c87314 +636e4872e0 +637681cd6b +6376d49f31 +6377809ec2 +63936d7de5 +639bddef11 +63d37e9fd3 +63d90c2bae +63e544a5d6 +63ebbcf874 +63fff40b31 +6406c72e4d +64148128be +6419386729 +643092bc41 +644081b88d +64453cf61d +644bad9729 +6454f548fd +645913b63a +64750b825f +64a43876b7 +64dd6c83e3 +64e05bf46e +64f55f1478 +650b0165e4 +651066ed39 +652b67d960 +653821d680 +6538d00d73 +65866dce22 +6589565c8c +659832db64 +65ab7e1d98 +65b7dda462 +65bd5eb4f5 +65dcf115ab +65e9825801 +65f9afe51c +65ff12bcb5 +666b660284 +6671643f31 +668364b372 +66852243cb +6693a52081 +669b572898 +66e98e78f5 +670f12e88f +674c12c92d +675c27208a +675ed3e1ca +67741db50a +678a2357eb +67b0f4d562 +67cfbff9b1 +67e717d6bd +67ea169a3b +67ea809e0e +681249baa3 +683de643d9 +6846ac20df +6848e012ef +684bcd8812 +684dc1c40c +685a1fa9cf +686dafaac9 +68807d8601 +6893778c77 +6899d2dabe +68a2fad4ab +68cb45fda3 +68cc4a1970 +68dcb40675 +68ea4a8c3d +68f6e7fbf0 +68fa8300b4 +69023db81f +6908ccf557 +691a111e7c +6927723ba5 +692ca0e1a2 +692eb57b63 +69340faa52 +693cbf0c9d +6942f684ad +6944fc833b +69491c0ebf +695b61a2b0 +6979b4d83f +697d4fdb02 +69910460a4 +6997636670 +69a436750b +69aebf7669 +69b8c17047 +69c67f109f +69e0e7b868 +69ea9c09d1 +69f0af42a6 +6a078cdcc7 +6a37a91708 +6a42176f2e +6a48e4aea8 +6a5977be3a +6a5de0535f +6a80d2e2e5 +6a96c8815d +6a986084e2 +6aa8e50445 +6ab9dce449 +6abf0ba6b2 +6acc6049d9 +6adb31756c +6ade215eb0 +6afb7d50e4 +6afd692f1a +6b0b1044fe +6b17c67633 +6b1b6ef28b +6b1e04d00d +6b2261888d +6b25d6528a +6b3a24395c +6b685eb75b +6b79be238c +6b928b7ba6 +6b9c43c25a +6ba99cc41f +6bdab62bcd +6bf2e853b1 +6bf584200f +6bf95df2b9 +6c0949c51c +6c11a5f11f +6c23d89189 +6c4387daf5 +6c4ce479a4 +6c5123e4bc +6c54265f16 +6c56848429 +6c623fac5f +6c81b014e9 +6c99ea7c31 +6c9d29d509 +6c9e3b7d1a +6ca006e283 +6caeb928d6 +6cb2ee722a +6cbfd32c5e +6cc791250b +6cccc985e0 +6d12e30c48 +6d4bf200ad +6d6d2b8843 +6d6eea5682 +6d7a3d0c21 +6d7efa9b9e +6da21f5c91 +6da6adabc0 +6dd2827fbb +6dd36705b9 +6df3637557 +6dfe55e9e5 +6e1a21ba55 +6e2f834767 +6e36e4929a +6e4f460caf +6e618d26b6 +6ead4670f7 +6eaff19b9f +6eb2e1cd9e +6eb30b3b5a +6eca26c202 +6ecad29e52 +6ef0b44654 +6efcfe9275 +6f4789045c +6f49f522ef +6f67d7c4c4 +6f96e91d81 +6fc6fce380 +6fc9b44c00 +6fce7f3226 +6fdf1ca888 +702fd8b729 +70405185d2 +7053e4f41e +707bf4ce41 +7082544248 +708535b72a +7094ac0f60 +70a6b875fa +70c3e97e41 +7106b020ab +711dce6fe2 +7136a4453f +7143fb084f +714d902095 +7151c53b32 +715357be94 +7163b8085f +716df1aa59 +71caded286 +71d2665f35 +71d67b9e19 +71e06dda39 +720b398b9c +720e3fa04c +720e7a5f1e +721bb6f2cb +722803f4f2 +72552a07c9 +726243a205 +72690ef572 +728cda9b65 +728e81c319 +72a810a799 +72acb8cdf6 +72b01281f9 +72cac683e4 +72cadebbce +72cae058a5 +72d8dba870 +72e8d1c1ff +72edc08285 +72f04f1a38 +731b825695 +7320b49b13 +732626383b +732df1eb05 +73329902ab +733798921e +733824d431 +734ea0d7fb +735a7cf7b9 +7367a42892 +7368d5c053 +73c6ae7711 +73e1852735 +73e4e5cc74 +73eac9156b +73f8441a88 +7419e2ab3f +74267f68b9 +7435690c8c +747c44785c +747f1b1f2f +748b2d5c01 +74d4cee0a4 +74ec2b3073 +74ef677020 +750be4c4d8 +75172d4ac8 +75285a7eb1 +75504539c3 +7550949b1d +7551cbd537 +75595b453d +7559b4b0ec +755bd1fbeb +756f76f74d +7570ca7f3c +757a69746e +757cac96c6 +7584129dc3 +75a058dbcd +75b09ce005 +75cae39a8f +75cee6caf0 +75cf58fb2c +75d5c2f32a +75eaf5669d +75f7937438 +75f99bd3b3 +75fa586876 +7613df1f84 +762e1b3487 +76379a3e69 +764271f0f3 +764503c499 +7660005554 +7666351b84 +76693db153 +767856368b +768671f652 +768802b80d +76962c7ed2 +76a75f4eee +76b90809f7 +770a441457 +772a0fa402 +772f2ffc3e +774f6c2175 +77610860e0 +777e58ff3d +77920f1708 +7799df28e7 +779e847a9a +77ba4edc72 +77c834dc43 +77d8aa8691 +77e7f38f4d +77eea6845e +7806308f33 +78254660ea +7828af8bff +784398620a +784d201b12 +78613981ed +78896c6baf +78aff3ebc0 +78c7c03716 +78d3676361 +78e29dd4c3 +78f1a1a54f +79208585cd +792218456c +7923bad550 +794e6fc49f +796e6762ce +797cd21f71 +79921b21c2 +79a5778027 +79bc006280 +79bf95e624 +79d9e00c55 +79e20fc008 +79e9db913e +79f014085e +79fcbb433a +7a13a5dfaa +7a14bc9a36 +7a3c535f70 +7a446a51e9 +7a56e759c5 +7a5f46198d +7a626ec98d +7a802264c4 +7a8b5456ca +7abdff3086 +7aecf9f7ac +7b0fd09c28 +7b18b3db87 +7b39fe7371 +7b49e03d4c +7b5388c9f1 +7b5cf7837f +7b733d31d8 +7b74fd7b98 +7b918ccb8a +7ba3ce3485 +7bb0abc031 +7bb5bb25cd +7bb7dac673 +7bc7761b8c +7bf3820566 +7c03a18ec1 +7c078f211b +7c37d7991a +7c4ec17eff +7c649c2aaf +7c73340ab7 +7c78a2266d +7c88ce3c5b +7ca6843a72 +7cc9258dee +7cec7296ae +7d0ffa68a4 +7d11b4450f +7d1333fcbe +7d18074fef +7d18c8c716 +7d508fb027 +7d55f791f0 +7d74e3c2f6 +7d783f67a9 +7d83a5d854 +7dd409947e +7de45f75e5 +7e0cd25696 +7e1922575c +7e1e3bbcc1 +7e24023274 +7e2f212fd3 +7e6d1cc1f4 +7e7cdcb284 +7e9b6bef69 +7ea5b49283 +7eb2605d96 +7eb26b8485 +7ecd1f0c69 +7f02b3cfe2 +7f1723f0d5 +7f21063c3a +7f3658460e +7f54132e48 +7f559f9d4a +7f5faedf8b +7f838baf2b +7fa5f527e3 +7ff84d66dd +802b45c8c4 +804382b1ad +804c558adb +804f6338a4 +8056117b89 +806b6223ab +8088bda461 +80b790703b +80c4a94706 +80ce2e351b +80db581acd +80e12193df +80e41b608f +80f16b016d +81541b3725 +8175486e6a +8179095000 +8193671178 +81a58d2c6b +81aa1286fb +81dffd30fb +8200245704 +823e7a86e8 +824973babb +824ca5538f +827171a845 +8273a03530 +827cf4f886 +82b865c7dd +82c1517708 +82d15514d6 +82e117b900 +82fec06574 +832b5ef379 +83424c9fbf +8345358fb8 +834b50b31b +835e3b67d7 +836ea92b15 +837c618777 +838eb3bd89 +839381063f +839bc71489 +83a8151377 +83ae88d217 +83ca8bcad0 +83ce590d7f +83d3130ba0 +83d40bcba5 +83daba503a +83de906ec0 +84044f37f3 +84696b5a5e +84752191a3 +847eeeb2e0 +848e7835a0 +84a4b29286 +84a4bf147d +84be115c09 +84d95c4350 +84e0922cf7 +84f0cfc665 +8515f6db22 +851f2f32c1 +852a4d6067 +854c48b02a +857a387c86 +859633d56a +85a4f4a639 +85ab85510c +85b1eda0d9 +85dc1041c6 +85e081f3c7 +85f75187ad +8604bb2b75 +860745b042 +863b4049d7 +8643de22d0 +8647d06439 +864ffce4fe +8662d9441a +8666521b13 +868d6a0685 +869fa45998 +86a40b655d +86a8ae4223 +86b2180703 +86c85d27df +86d3755680 +86e61829a1 +871015806c +871e409c5c +8744b861ce +8749369ba0 +878a299541 +8792c193a0 +8799ab0118 +87d1f7d741 +882b9e4500 +885673ea17 +8859dedf41 +8873ab2806 +887a93b198 +8883e991a9 +8891aa6dfa +8899d8cbcd +88b8274d67 +88d3b80af6 +88ede83da2 +88f345941b +890976d6da +8909bde9ab +8929c7d5d9 +89363acf76 +89379487e0 +8939db6354 +893f658345 +8953138465 +895c96d671 +895cbf96f9 +895e8b29a7 +898fa256c8 +89986c60be +89b874547b +89bdb021d5 +89c802ff9c +89d6336c2b +89ebb27334 +8a27e2407c +8a31f7bca5 +8a4a2fc105 +8a5d6c619c +8a75ad7924 +8aa817e4ed +8aad0591eb +8aca214360 +8ae168c71b +8b0cfbab97 +8b3645d826 +8b3805dbd4 +8b473f0f5d +8b4f6d1186 +8b4fb018b7 +8b518ee936 +8b523bdfd6 +8b52fb5fba +8b91036e5c +8b99a77ac5 +8ba04b1e7b +8ba782192f +8bbeaad78b +8bd1b45776 +8bd7a2dda6 +8bdb091ccf +8be56f165d +8be950d00f +8bf84e7d45 +8bffc4374b +8bfff50747 +8c09867481 +8c0a3251c3 +8c3015cccb +8c469815cf +8c9ccfedc7 +8ca1af9f3c +8ca3f6e6c1 +8ca6a4f60f +8cac6900fe +8cba221a1e +8cbbe62ccd +8d064b29e2 +8d167e7c08 +8d4ab94e1c +8d81f6f899 +8d87897d66 +8dcccd2bd2 +8dcfb878a8 +8dd3ab71b9 +8dda6bf10f +8ddd51ca94 +8dea22c533 +8def5bd3bf +8e1848197c +8e3a83cf2d +8e478e73f3 +8e98ae3c84 +8ea6687ab0 +8eb0d315c1 +8ec10891f9 +8ec3065ec2 +8ecf51a971 +8eddbab9f7 +8ee198467a +8ee2368f40 +8ef595ce82 +8f0a653ad7 +8f1204a732 +8f1600f7f6 +8f16366707 +8f1ce0a411 +8f2e05e814 +8f320d0e09 +8f3b4a84ad +8f3fdad3da +8f5d3622d8 +8f62a2c633 +8f81c9405a +8f8c974d53 +8f918598b6 +8ff61619f6 +9002761b41 +90107941f3 +90118a42ee +902bc16b37 +903e87e0d6 +9041a0f489 +9047bf3222 +9057bfa502 +90617b0954 +9076f4b6db +9077e69b08 +909655b4a6 +909c2eca88 +909dbd1b76 +90bc4a319a +90c7a87887 +90cc785ddd +90d300f09b +9101ea9b1b +9108130458 +911ac9979b +9151cad9b5 +9153762797 +91634ee0c9 +916942666f +9198cfb4ea +919ac864d6 +91b67d58d4 +91bb8df281 +91be106477 +91c33b4290 +91ca7dd9f3 +91d095f869 +91f107082e +920329dd5e +920c959958 +92128fbf4b +9223dacb40 +923137bb7f +9268e1f88a +927647fe08 +9276f5ba47 +92a28cd233 +92b5c1fc6d +92c46be756 +92dabbe3a0 +92e3159361 +92ebab216a +934bdc2893 +9359174efc +935d97dd2f +935feaba1b +93901858ee +939378f6d6 +939bdf742e +93a22bee7e +93da9aeddf +93e2feacce +93e6f1fdf9 +93e811e393 +93e85d8fd3 +93f623d716 +93ff35e801 +94031f12f2 +94091a4873 +94125907e3 +9418653742 +941c870569 +94209c86f0 +9437c715eb +9445c3eca2 +9467c8617c +946d71fb5d +948f3ae6fb +9498baa359 +94a33abeab +94bf1af5e3 +94cf3a8025 +94db712ac8 +94e4b66cff +94e76cbaf6 +950be91db1 +952058e2d0 +952633c37f +952ec313fe +9533fc037c +9574b81269 +9579b73761 +957f7bc48b +958073d2b0 +9582e0eb33 +9584092d0b +95b58b8004 +95bd88da55 +95f74a9959 +962781c601 +962f045bf5 +964ad23b44 +967b90590e +967bffe201 +96825c4714 +968492136a +9684ef9d64 +968c41829e +96a856ef9a +96dfc49961 +96e1a5b4f8 +96e6ff0917 +96fb88e9d7 +96fbe5fc23 +96fc924050 +9715cc83dc +9720eff40f +972c187c0d +97476eb38d +97659ed431 +9773492949 +97756b264f +977bff0d10 +97ab569ff3 +97ba838008 +97d9d008c7 +97e59f09fa +97eb642e56 +98043e2d14 +981ff580cf +983e66cbfc +984f0f1c36 +98595f2bb4 +985c3be474 +9869a12362 +986b5a5e18 +9877af5063 +98911292da +9893a3cf77 +9893d9202d +98a8b06e7f +98ac6f93d9 +98b6974d12 +98ba3c9417 +98c7c00a19 +98d044f206 +98e909f9d1 +98fe7f0410 +990f2742c7 +992bd0779a +994b9b47ba +9955b76bf5 +9966f3adac +997117a654 +999d53d841 +99c04108d3 +99c4277aee +99c6b1acf2 +99dc8bb20b +99fcba71e5 +99fecd4efb +9a02c70ba2 +9a08e7a6f8 +9a2f2c0f86 +9a3254a76e +9a3570a020 +9a39112493 +9a4e9fd399 +9a50af4bfb +9a68631d24 +9a72318dbf +9a767493b7 +9a7fc1548b +9a84ccf6a7 +9a9c0e15b7 +9adf06d89b +9b22b54ee4 +9b473fc8fe +9b4f081782 +9b997664ba +9bc454e109 +9bccfd04de +9bce4583a2 +9bebf1b87f +9bfc50d261 +9c166c86ff +9c293ef4d7 +9c29c047b0 +9c3bc2e2a7 +9c3ce23bd1 +9c404cac0c +9c5180d23a +9c7feca6e4 +9caa49d3ff +9cb2f1b646 +9ce6f765c3 +9cfee34031 +9d01f08ec6 +9d04c280b8 +9d12ceaddc +9d15f8cb3c +9d2101e9bf +9d407c3aeb +9ddefc6165 +9df0b1e298 +9e16f115d8 +9e249b4982 +9e29b1982c +9e493e4773 +9e4c752cd0 +9e4de40671 +9e6319faeb +9e6ddbb52d +9eadcea74f +9ecec5f8ea +9efb47b595 +9f30bfe61e +9f3734c3a4 +9f5b858101 +9f66640cda +9f913803e9 +9f97bc74c8 +9fbad86e20 +9fc2bad316 +9fc5c3af78 +9fcb310255 +9fcc256871 +9fd2fd4d47 +a0071ae316 +a023141022 +a046399a74 +a066e739c1 +a06722ba82 +a07a15dd64 +a07b47f694 +a09c39472e +a0b208fe2e +a0b61c959e +a0bc6c611d +a0e6da5ba2 +a1193d6490 +a14ef483ff +a14f709908 +a15ccc5658 +a16062456f +a174e8d989 +a177c2733c +a17c62e764 +a18ad065fc +a1aaf63216 +a1bb65fb91 +a1bd8e5349 +a1dfdd0cac +a2052e4f6c +a20fd34693 +a21ffe4d81 +a22349e647 +a235d01ec1 +a24f63e8a2 +a2554c9f6d +a263ce8a87 +a29bfc29ec +a2a80072d4 +a2a800ab63 +a2bcd10a33 +a2bdaff3b0 +a2c146ab0d +a2c996e429 +a2dc51ebe8 +a2e6608bfa +a2f2a55f01 +a301869dea +a31fccd2cc +a34f440f33 +a35e0206da +a36bdc4cab +a36e8c79d8 +a378053b20 +a37db3a2b3 +a38950ebc2 +a39a0eb433 +a39c9bca52 +a3a945dc8c +a3b40a0c1e +a3b8588550 +a3c502bec3 +a3f2878017 +a3f4d58010 +a3f51855c3 +a402dc0dfe +a4065a7eda +a412bb2fef +a416b56b53 +a41ec95906 +a43299e362 +a4757bd7af +a48c53c454 +a49dcf9ad5 +a4a506521f +a4ba7753d9 +a4bac06849 +a4f05d681c +a50c10060f +a50eb5a0ea +a5122c6ec6 +a522b1aa79 +a590915345 +a5b5b59139 +a5b77abe43 +a5c2b2c3e1 +a5cd17bb11 +a5da03aef1 +a5dd11de0d +a5ea2b93b6 +a5eaeac80b +a5ec5b0265 +a5f350a87e +a5f472caf4 +a6027a53cf +a61715bb1b +a61cf4389d +a61d9bbd9b +a6470dbbf5 +a64a40f3eb +a653d5c23b +a65bd23cb5 +a66e0b7ad4 +a66fc5053c +a68259572b +a6a810a92c +a6bc36937f +a6c3a374e9 +a6d8a4228d +a6f4e0817f +a71e0481f5 +a7203deb2d +a7392d4438 +a73d3c3902 +a7491f1578 +a74b9ca19c +a77b7a91df +a78195a5f5 +a78758d4ce +a7e6d6c29a +a800d85e88 +a832fa8790 +a83d06410d +a8999af004 +a8f78125b9 +a907b18df1 +a919392446 +a965504e88 +a96b84b8d2 +a973f239cd +a977126596 +a9804f2a08 +a984e56893 +a99738f24c +a99bdd0079 +a9c9c1517e +a9cbf9c41b +a9e42e3c0c +aa07b7c1c0 +aa175e5ec7 +aa1a338630 +aa27d7b868 +aa45f1caaf +aa49e46432 +aa51934e1b +aa6287bb6c +aa6d999971 +aa85278334 +aab33f0e2a +aaba004362 +aade4cf385 +aae78feda4 +aaed233bf3 +aaff16c2db +ab199e8dfb +ab23b78715 +ab2e1b5577 +ab33a18ded +ab45078265 +ab56201494 +ab90f0d24b +abab2e6c20 +abb50c8697 +abbe2d15a0 +abbe73cd21 +abe61a11bb +abeae8ce21 +ac2b431d5f +ac2cb1b9eb +ac31fcd6d0 +ac3d3a126d +ac46bd8087 +ac783ef388 +acb73e4297 +acbf581760 +accafc3531 +acf2c4b745 +acf44293a2 +acf736a27b +acff336758 +ad1fe56886 +ad28f9b9d9 +ad2de9f80e +ad397527b2 +ad3d1cfbcb +ad3fada9d9 +ad4108ee8e +ad54468654 +ad573f7d31 +ad6255bc29 +ad65ebaa07 +ad97cc064a +adabbd1cc4 +adb0b5a270 +adc648f890 +add21ee467 +adfd15ceef +adfdd52eac +ae01cdab63 +ae0b50ff4f +ae13ee3d70 +ae1bcbd423 +ae20d09dea +ae2cecf5f6 +ae3bc4a0ef +ae499c7514 +ae628f2cd4 +ae8545d581 +ae93214fe6 +ae9cd16dbf +aeba9ac967 +aebb242b5c +aed4e0b4c4 +aedd71f125 +aef3e2cb0e +af0b54cee3 +af3de54c7a +af5fd24a36 +af8826d084 +af8ad72057 +afb71e22c5 +afcb331e1f +afe1a35c1e +b01080b5d3 +b05ad0d345 +b0623a6232 +b064dbd4b7 +b06ed37831 +b06f5888e6 +b08dcc490e +b0a68228dc +b0aece727f +b0b0731606 +b0c7f11f9f +b0cca8b830 +b0dd580a89 +b0de66ca08 +b0df7c5c5c +b0f5295608 +b11099eb09 +b132a53086 +b1399fac64 +b13abc0c69 +b1457e3b5e +b15bf4453b +b179c4a82d +b17ee70e8c +b190b1aa65 +b19b3e22c0 +b19c561fab +b1d1cd2e6e +b1d7c03927 +b1d7fe2753 +b1f540a4bd +b1fc9c64e1 +b1fcbb3ced +b220939e93 +b22099b419 +b241e95235 +b2432ae86d +b2456267df +b247940d01 +b24af1c35c +b24f600420 +b24fe36b2a +b258fb0b7d +b26b219919 +b26d9904de +b274456ce1 +b27b28d581 +b2a26bc912 +b2a9c51e1b +b2b0baf470 +b2b2756fe7 +b2ce7699e3 +b2edc76bd2 +b2f6b52100 +b30bf47bcd +b34105a4e9 +b372a82edf +b3779a1962 +b379ab4ff5 +b37a1d69e3 +b37c01396e +b382b09e25 +b3996e4ba5 +b3d9ca2aee +b3dde1e1e9 +b3eb7f05eb +b40b25055c +b41e0f1f19 +b44e32a42b +b4805ae9cd +b4807569a5 +b48efceb3e +b493c25c7f +b4b565aba1 +b4b715a15b +b4d0c90bf4 +b4d84bc371 +b4e5ad97aa +b4eaea9e6b +b50f4b90d5 +b53f675641 +b54278cd43 +b554843889 +b573c0677a +b58d853734 +b5943b18ab +b5a09a83f3 +b5aae1fe25 +b5b9da5364 +b5eb64d419 +b5ebb1d000 +b5f1c0c96a +b5f7fece90 +b6070de1bb +b60a76fe73 +b61f998772 +b62c943664 +b63094ba0c +b64fca8100 +b673e7dcfb +b678b7db00 +b68fc1b217 +b69926d9fa +b6a1df3764 +b6a4859528 +b6b4738b78 +b6b4f847b7 +b6b8d502d4 +b6bb00e366 +b6d65a9eef +b6d79a0845 +b6e9ec577f +b6ec609f7b +b6f92a308d +b70a2c0ab1 +b70a5a0d50 +b70c052f2f +b70d231781 +b72ac6e10b +b7302d8226 +b73867d769 +b751e767f2 +b76df6e059 +b77e5eddef +b7a2c2c83c +b7bcbe6466 +b7c2a469c4 +b7d69da8f0 +b7f31b7c36 +b7f675fb98 +b7fb871660 +b82e5ad1c9 +b841cfb932 +b84b8ae665 +b85b78ac2b +b86c17caa6 +b86e50d82d +b871db031a +b87d56925a +b8aaa59b75 +b8c03d1091 +b8c3210036 +b8e16df00b +b8f34cf72e +b8fb75864e +b9004db86c +b9166cbae9 +b920b256a6 +b938d79dff +b93963f214 +b941aef1a0 +b94d34d14e +b964c57da4 +b96a95bc7a +b96c57d2c7 +b9b6bdde0c +b9bcb3e0f2 +b9d3b92169 +b9dd4b306c +b9f43ef41e +ba1f03c811 +ba3a775d7b +ba3c7f2a31 +ba3fcd417d +ba5e1f4faa +ba795f3089 +ba8a291e6a +ba98512f97 +bac9db04f5 +baedae3442 +baff40d29d +bb04e28695 +bb1b0ee89f +bb1c770fe7 +bb1fc34f99 +bb2d220506 +bb334e5cdb +bb337f9830 +bb721eb9aa +bb87ff58bd +bb89a6b18a +bbaa9a036a +bbb4302dda +bbd31510cf +bbe0256a75 +bc141b9ad5 +bc17ab8a99 +bc318160de +bc3b9ee033 +bc4240b43c +bc4ce49105 +bc4f71372d +bc6b8d6371 +bcaad44ad7 +bcc241b081 +bcc5d8095e +bcd1d39afb +bd0d849da4 +bd0e9ed437 +bd2c94730f +bd321d2be6 +bd3ec46511 +bd5b2e2848 +bd7e02b139 +bd96f9943a +bda224cb25 +bda4a82837 +bdb74e333f +bdccd69dde +bddcc15521 +be116aab29 +be15e18f1e +be1a284edb +be2a367a7b +be376082d0 +be3e3cffbd +be5d1d89a0 +be8b72fe37 +be9b29e08e +bea1f6e62c +bea83281b5 +beb921a4c9 +bec5e9edcd +beeb8a3f92 +bf2232b58d +bf28751739 +bf443804e8 +bf461df850 +bf5374f122 +bf551a6f60 +bf8d0f5ada +bf961167a6 +bfab1ad8f9 +bfcb05d88d +bfd8f6e6c9 +bfd91d0742 +bfe262322f +c013f42ed7 +c01878083f +c01faff1ed +c046fd0edb +c053e35f97 +c079a6482d +c0847b521a +c0a1e06710 +c0e8d4635c +c0e973ad85 +c0f49c6579 +c0f5b222d7 +c10d07c90d +c1268d998c +c130c3fc0c +c14826ad5e +c15b922281 +c16f09cb63 +c18e19d922 +c1c830a735 +c1e8aeea45 +c20a5ccc99 +c20fd5e597 +c219d6f8dc +c2406ae462 +c26f7b5824 +c279e641ee +c27adaeac5 +c2a35c1cda +c2a9903b8b +c2b62567c1 +c2b974ec8c +c2baaff7bf +c2be6900f2 +c304dd44d5 +c307f33da2 +c30a7b62c9 +c3128733ee +c31fa6c598 +c325c8201e +c32d4aa5d1 +c33f28249a +c34365e2d7 +c3457af795 +c34d120a88 +c3509e728d +c35e4fa6c4 +c36240d96f +c3641dfc5a +c37b17a4a9 +c39559ddf6 +c3b0c6e180 +c3b3d82e6c +c3be369fdb +c3bf1e40c2 +c3c760b015 +c3dd38bf98 +c3e4274614 +c3edc48cbd +c41e6587f5 +c4272227b0 +c42917fe82 +c438858117 +c44676563f +c44beb7472 +c45411dacb +c4571bedc8 +c46deb2956 +c479ee052e +c47d551843 +c49f07d46d +c4cc40c1fc +c4f256f5d5 +c4f5b1ddcc +c4ff9b4885 +c52bce43db +c544da6854 +c55784c766 +c557b69fbf +c593a3f7ab +c598faa682 +c5ab1f09c8 +c5b6da8602 +c5b9128d94 +c5e845c6b7 +c5fba7b341 +c60897f093 +c61fe6ed7c +c62188c536 +c64035b2e2 +c69689f177 +c6a12c131f +c6bb6d2d5c +c6c18e860f +c6d9526e0d +c6e55c33f0 +c7030b28bd +c70682c7cc +c70f9be8c5 +c71f30d7b6 +c73c8e747f +c760eeb8b3 +c7637cab0a +c7a1a17308 +c7bf937af5 +c7c2860db3 +c7cef4aee2 +c7ebfc5d57 +c813dcf13c +c82235a49a +c82a7619a1 +c82ecb90cb +c844f03dc7 +c8557963f3 +c89147e6e8 +c8a46ff0c8 +c8ab107dd5 +c8b869a04a +c8c7b306a6 +c8c8b28781 +c8d79e3163 +c8edab0415 +c8f494f416 +c8f6cba9fd +c909ceea97 +c9188f4980 +c922365dd4 +c92c8c3c75 +c937eb0b83 +c94b31b5e5 +c95cd17749 +c96379c03c +c96465ee65 +c965afa713 +c9734b451f +c9862d82dc +c98b6fe013 +c9999b7c48 +c99e92aaf0 +c9b3a8fbda +c9bf64e965 +c9c3cb3797 +c9d1c60cd0 +c9de9c22c4 +ca1828fa54 +ca346f17eb +ca3787d3d3 +ca4b99cbac +ca91c69e3b +ca91e99105 +caa8e97f81 +caac5807f8 +cabba242c2 +cad5a656a9 +cad673e375 +cad8a85930 +cae7b0a02b +cae7ef3184 +caeb6b6cbb +caecf0a5db +cb15312003 +cb2e35d610 +cb35a87504 +cb3f22b0cf +cbb410da64 +cc8728052e +cc892997b8 +cce03c2a9b +cd47a23e31 +cd4dc03dc0 +cd5ae611da +cd603bb9d1 +cd8f49734c +cdc6b1c032 +cdcfe008ad +cdd57027c2 +ce1af99b4b +ce1bc5743a +ce25872021 +ce2776f78f +ce49b1f474 +ce4f0a266f +ce5641b195 +ce6866aa19 +ce712ed3c9 +ce7d1c8117 +ce7dbeaa88 +ce9b015a5e +cea7697b25 +cebbd826cf +cec3415361 +cec41ad4f4 +ced49d26df +ced7705ab2 +cef824a1e1 +cf13f5c95a +cf4376a52d +cf85ab28b5 +cfc2e50b9d +cfcd571fff +cfd9d4ae47 +cfda2dcce5 +cff035928b +cff8191891 +d01608c2a5 +d01a8f1f83 +d021d68bca +d04258ca14 +d0483573dc +d04a90aaff +d05279c0bd +d0696bd5fc +d072fda75b +d0a83bcd9f +d0ab39112e +d0acde820f +d0b4442c71 +d0c65e9e95 +d0fb600c73 +d107a1457c +d123d674c1 +d14d1e9289 +d154e3388e +d177e9878a +d1802f69f8 +d182c4483a +d195d31128 +d200838929 +d205e3cff5 +d247420c4c +d2484bff33 +d26f6ed9b0 +d280fcd1cb +d2857f0faa +d292a50c7f +d295ea2dc7 +d2a58b4fa6 +d2b026739a +d2ebe0890f +d2ede5d862 +d301ca58cc +d3069da8bb +d343d4a77d +d355e634ef +d367fb5253 +d36d16358e +d38bc77e2c +d38d1679e2 +d3932ad4bd +d3987b2930 +d39934abe3 +d3ae1c3f4c +d3b088e593 +d3e6e05e16 +d3eefae7c5 +d3f55f5ab8 +d3f5c309cc +d4034a7fdf +d4193011f3 +d429c67630 +d42c0ff975 +d44a764409 +d44e6acd1d +d45158c175 +d454e8444f +d45f62717e +d48ebdcf74 +d49ab52a25 +d4a607ad81 +d4b063c7db +d4da13e9ba +d4dd1a7d00 +d4f4f7c9c3 +d521aba02e +d535bb1b97 +d53b955f78 +d55cb7a205 +d55f247a45 +d5695544d8 +d5853d9b8b +d5b6c6d94a +d5cae12834 +d5df027f0c +d5ee40e5d0 +d600046f73 +d632fd3510 +d6476cad55 +d65a7bae86 +d664c89912 +d689658f06 +d6917db4be +d69967143e +d699d3d798 +d69f757a3f +d6ac0e065c +d6c02bfda5 +d6c1b5749e +d6e12ef6cc +d6eed152c4 +d6faaaf726 +d704766646 +d708e1350c +d7135cf104 +d7157a9f44 +d719cf9316 +d724134cfd +d73a60a244 +d7411662da +d74875ea7c +d756f5a694 +d7572b7d8a +d763bd6d96 +d7697c8b13 +d7797196b4 +d79c834768 +d7b34e5d73 +d7bb6b37a7 +d7c7e064a6 +d7fbf545b3 +d82a0aa15b +d847e24abd +d8596701b7 +d86101499c +d87069ba86 +d87160957b +d874654b52 +d88a403092 +d8aee40f3f +d8e77a222d +d8eb07c381 +d9010348a1 +d90e3cf281 +d92532c7b2 +d927fae122 +d95707bca8 +d973b31c00 +d991cb471d +d992c69d37 +d99d770820 +d9b63abc11 +d9db6f1983 +d9e52be2d2 +d9edc82650 +da01070697 +da070ea4b7 +da080507b9 +da0e944cc4 +da28d94ff4 +da5d78b9d1 +da6003fc72 +da690fee9f +da6c68708f +da7a816676 +dac361e828 +dac71659b8 +dad980385d +daebc12b77 +db0968cdd3 +db231a7100 +db59282ace +db7f267c3f +dba35b87fd +dbba735a50 +dbca076acd +dbd66dc3ac +dbdc3c292b +dbf4a5b32b +dbfc417d28 +dc1745e0a2 +dc32a44804 +dc34b35e30 +dc504a4f79 +dc704dd647 +dc71bc6918 +dc7771b3be +dcf8c93617 +dd0f4c9fb9 +dd415df125 +dd601f9a3f +dd61d903df +dd77583736 +dd8636bd8b +dd9fe6c6ac +ddb2da4c14 +ddcd450d47 +dde8e67fb4 +ddfc3f04d3 +de2ab79dfa +de2f35b2fd +de30990a51 +de36b216da +de37403340 +de46e4943b +de4ddbccb1 +de5e480f05 +de6a9382ca +de74a601d3 +de827c510d +ded6069f7b +defb71c741 +df01f277f1 +df05214b82 +df0638b0a0 +df11931ffe +df1b0e4620 +df20a8650d +df2bc56d7c +df365282c6 +df39a0d9df +df3c430c24 +df5536cfb9 +df59cfd91d +df5e2152b3 +df741313c9 +df7626172f +df8ad5deb9 +df96aa609a +df9705605c +df9c91c4da +dfc0d3d27a +dfdbf91a99 +e00baaae9b +e0a938c6e7 +e0b2ceee6f +e0bdb5dfae +e0be1f6e17 +e0c478f775 +e0de82caa7 +e0f217dd59 +e0f7208874 +e0fb58395e +e1194c2e9d +e11adcd05d +e128124b9d +e1495354e4 +e1561d6d4b +e158805399 +e16945b951 +e19edcd34b +e1a1544285 +e1ab7957f4 +e1d26d35be +e1e957085b +e1f14510fa +e214b160f4 +e2167379b8 +e21acb20ab +e221105579 +e22ddf8a1b +e22de45950 +e22ffc469b +e23cca5244 +e252f46f0b +e25fa6cf39 +e26e486026 +e275760245 +e27bbedbfe +e29e9868a8 +e2b37ff8af +e2b608d309 +e2bef4da9a +e2c87a6421 +e2ea25542c +e2fb1d6497 +e2fcc99117 +e33c18412a +e348377191 +e352cb59c8 +e36ac982f0 +e391bc981e +e39e3e0a06 +e3bf38265f +e3d5b2cd21 +e3d60e82d5 +e3e3245492 +e3e4134877 +e3f4635e03 +e4004ee048 +e402d1afa5 +e415093d27 +e41ceb5d81 +e424653b78 +e42b6d3dbb +e42d60f0d4 +e436d0ff1e +e43d7ae2c5 +e4428801bc +e44e0b4917 +e470345ede +e48e8b4263 +e4922e3726 +e4936852bb +e495f32c60 +e499228f26 +e4af66e163 +e4b2095f58 +e4d19c8283 +e4d4872dab +e4e2983570 +e4eaa63aab +e4ef0a3a34 +e4f8e5f46e +e4ffb6d0dd +e53e21aa02 +e57f4f668b +e588433c1e +e597442c99 +e5abc0e96b +e5be628030 +e5ce96a55d +e5d6b70a9f +e5fde1574c +e625e1d27b +e6261d2348 +e6267d46bc +e6295f223f +e63463d8c6 +e6387bd1e0 +e653883384 +e65f134e0b +e668ef5664 +e672ccd250 +e674510b20 +e676107765 +e699da0cdf +e6be243065 +e6deab5e0b +e6f065f2b9 +e71629e7b5 +e72a7d7b0b +e72f6104e1 +e75a466eea +e76c55933f +e7784ec8ad +e78922e5e6 +e78d450a9c +e7c6354e77 +e7c8de1fce +e7ea10db28 +e803918710 +e8073a140b +e828dd02db +e845994987 +e8485a2615 +e85c5118a7 +e88b6736e4 +e8962324e3 +e8b3018d36 +e8cee8bf0b +e8d97ebece +e8da49ea6a +e8ed1a3ccf +e8f7904326 +e8f8341dec +e8fa21eb13 +e90c10fc4c +e914b8cac8 +e92b6bfea4 +e92e1b7623 +e93f83e512 +e9422ad240 +e9460b55f9 +e9502628f6 +e950befd5f +e9582bdd1b +e95e5afe0f +e97cfac475 +e98d57d99c +e98eda8978 +e99706b555 +e9bc0760ba +e9d3c78bf3 +e9ec1b7ea8 +ea065cc205 +ea138b6617 +ea16d3fd48 +ea2545d64b +ea286a581c +ea320da917 +ea345f3627 +ea3b94a591 +ea444a37eb +ea4a01216b +ea5672ffa8 +eaa99191cb +eaab4d746c +eac7a59bc1 +ead5d3835a +eaec65cfa7 +eaed1a87be +eb2f821c6f +eb383cb82e +eb6992fe02 +eb6ac20a01 +eb6d7ab39e +eb7921facd +eb8fce51a6 +ebbb90e9f9 +ebbf5c9ee1 +ebc4ec32e6 +ebe56e5ef8 +ec1299aee4 +ec139ff675 +ec193e1a01 +ec28252938 +ec387be051 +ec3d4fac00 +ec4186ce12 +ec579c2f96 +ecae59b782 +ecb33a0448 +ece6bc9e92 +ecfedd4035 +ecfff22fd6 +ed3291c3d6 +ed3cd5308d +ed3e6fc1a5 +ed72ae8825 +ed7455da68 +ed844e879f +ed8f814b2b +ed911a1f63 +ed9ff4f649 +eda8ab984b +edb8878849 +edbfdfe1b4 +edd22c46a2 +edd663afa3 +ede3552eae +edeab61ee0 +ee07583fc0 +ee316eaed6 +ee3f509537 +ee40a1e491 +ee4bf100f1 +ee6f9b01f9 +ee947ed771 +ee9706ac7f +ee9a7840ae +eeb90cb569 +eebf45e5c5 +eeed0c7d73 +ef0061a309 +ef07f1a655 +ef0a8e8f35 +ef232a2aed +ef308ad2e9 +ef44945428 +ef45ce3035 +ef5dde449d +ef5e770988 +ef6359cea3 +ef65268834 +ef6cb5eae0 +ef78972bc2 +ef8cfcfc4f +ef96501dd0 +ef9a2e976b +efb24f950f +efce0c1868 +efe5ac6901 +efe828affa +efea4e0523 +f0268aa627 +f0483250c8 +f04cf99ee6 +f05b189097 +f08928c6d3 +f09d74856f +f0a7607d63 +f0ad38da27 +f0c34e1213 +f0c7f86c29 +f0dfa18ba7 +f0eb3179f7 +f119bab27d +f14409b6a3 +f1489baff4 +f14c18cf6a +f15c607b92 +f1af214222 +f1b77bd309 +f1ba9e1a3e +f1d99239eb +f1dc710cf4 +f1ec5c08fa +f22648fe12 +f22d21f1f1 +f233257395 +f23e95dbe5 +f2445b1572 +f253b3486d +f277c7a6a4 +f2ab2b84d6 +f2b7c9b1f3 +f2b83d5ce5 +f2c276018f +f2cfd94d64 +f2dd6e3add +f2e7653f16 +f2f333ad06 +f2f55d6713 +f2fdb6abec +f305a56d9f +f3085d6570 +f3325c3338 +f3400f1204 +f34497c932 +f34a56525e +f36483c824 +f3704d5663 +f3734c4913 +f38e5aa5b4 +f3986fba44 +f3a0ffc7d9 +f3b24a7d28 +f3e6c35ec3 +f3fc0ea80b +f40a683fbe +f4207ca554 +f4377499c2 +f46184f393 +f46c2d0a6d +f46c364dca +f46f7a0b63 +f46fe141b0 +f470b9aeb0 +f47eb7437f +f48b535719 +f49e4866ac +f4aa882cfd +f4daa3dbd5 +f4dd51ac35 +f507a1b9dc +f51c5ac84b +f52104164b +f54c67b9bb +f5966cadd2 +f5bddf5598 +f5d85cfd17 +f5e2e7d6a0 +f5f051e9b4 +f5f8a93a76 +f6283e8af5 +f635e9568b +f6474735be +f659251be2 +f66981af4e +f6708fa398 +f697fe8e8f +f6adb12c42 +f6c7906ca4 +f6cd0a8016 +f6d6f15ae7 +f6e501892c +f6f59d986f +f6fe8c90a5 +f714160545 +f74c3888d7 +f7782c430e +f7783ae5f2 +f77ab47923 +f788a98327 +f7961ac1f0 +f7a71e7574 +f7a8521432 +f7afbf4947 +f7b7cd5f44 +f7cf4b4a39 +f7d49799ad +f7e0c9bb83 +f7e5b84928 +f7e6bd58be +f7f2a38ac6 +f7f6cb2d6d +f83f19e796 +f85796a921 +f8603c26b2 +f8819b42ec +f891f8eaa1 +f89288d10c +f895ae8cc1 +f8b4ac12f1 +f8c3fb2b01 +f8c8de2764 +f8db369b40 +f8fcb6a78c +f94aafdeef +f95d217b70 +f9681d5103 +f9750192a4 +f9823a32c2 +f991ddb4c2 +f99d535567 +f9ae3d98b7 +f9b6217959 +f9bd1fabf5 +f9c68eaa64 +f9d3e04c4f +f9daf64494 +f9e4cc5a0a +f9ea6b7f31 +f9f3852526 +fa04c615cf +fa08e00a56 +fa4370d74d +fa67744af3 +fa88d48a92 +fa8b904cc9 +fa9526bdf1 +fa9b9d2426 +fad633fbe1 +faf5222dc3 +faff0e15f1 +fb08c64e8c +fb23455a7f +fb2e19fa6e +fb34dfbb77 +fb47fcea1e +fb49738155 +fb4cbc514b +fb4e6062f7 +fb5ba7ad6e +fb63cd1236 +fb81157a07 +fb92abdaeb +fba22a6848 +fbaca0c9df +fbc645f602 +fbd77444cd +fbe53dc8e8 +fbe541dd73 +fbe8488798 +fbfd25174f +fc28cb305e +fc33b1ffd6 +fc6186f0bb +fc918e3a40 +fc96cda9d8 +fc9832eea4 +fcb10d0f81 +fcd20a2509 +fcf637e3ab +fcfd81727f +fd31890379 +fd33551c28 +fd542da05e +fd6789b3fe +fd77828200 +fd7af75f4d +fdb28d0fbb +fdb3d1fb1e +fdb8b04124 +fdc6e3d581 +fdfce7e6fc +fe0f76d41b +fe24b0677d +fe3c02699d +fe58b48235 +fe6a5596b8 +fe6c244f63 +fe7afec086 +fe985d510a +fe9db35d15 +fea8ffcd36 +feb1080388 +fed208bfca +feda5ad1c2 +feec95b386 +ff15a5eff6 +ff204daf4b +ff25f55852 +ff2ada194f +ff2ce142e8 +ff49d36d20 +ff5a1ec4f3 +ff66152b25 +ff692fdc56 +ff773b1a1e +ff97129478 +ffb904207d +ffc43fc345 +fffe5f8df6 diff --git a/vfi-test/.gitattributes b/vfi-test/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/vfi-test/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/vfi-test/README.md b/vfi-test/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d4fbf87f8b99d174285a69b47dcaaacdda31e386 --- /dev/null +++ b/vfi-test/README.md @@ -0,0 +1,13 @@ +--- +title: VfiTest +emoji: ๐Ÿ‘ +colorFrom: yellow +colorTo: blue +sdk: gradio +sdk_version: 4.19.2 +app_file: app.py +pinned: false +license: unknown +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/vfi_inference.py b/vfi_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..5a03e448b76f53c081c2b5098b4c4be51896d614 --- /dev/null +++ b/vfi_inference.py @@ -0,0 +1,158 @@ +# [START_COMMAND] +# python3 -m vfi_inference --cuda_index 0 \ +# --use_video --root ../VFI_Inference/thistest/test_video.mp4 --save_root ../VFI_Inference/thistest/results --source_frame_ext png \ +# --pretrain_path ./pretrained/upr_freq002.pth \ +# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \ +# --make_video --fps 0 --new_video_name test_video_vfi.mp4 + +# python3 -m vfi_inference --cuda_index 0 \ +# --root ../VFI_Inference/thistest/frames --save_root ../VFI_Inference/thistest/results --source_frame_ext png \ +# --pretrain_path ./pretrained/upr_freq002.pth \ +# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \ +# --make_video --fps 0 --new_video_name test_video_vfi.mp4 + +# [FILE SYSTEM] ํ”„๋ ˆ์ž„ ์‹œํ€€์Šค +# args.root ํด๋”์— ํ”„๋ ˆ์ž„ ์‹œํ€€์Šค ๋ชจ์—ฌ ์žˆ์–ด์•ผ ํ•จ +# args.save_root ํด๋”๋Š” args.root ํด๋”์™€ ์ƒ์œ„ํด๋”๊ฐ€ ๋™์ผํ•ด์•ผ ํ•˜๊ณ , args.save_root ํด๋”์— ๊ฒฐ๊ณผ๋ฌผ ํ”„๋ ˆ์ž„ ์‹œํ€€์Šค ์ €์žฅ๋จ + +# [FILE SYSTEM] ๋น„๋””์˜ค +# args.root๋Š” ๋น„๋””์˜ค ๊ฒฝ๋กœ +# ํ”„๋ ˆ์ž„ ์‹œํ€€์Šค๋Š” ์ž๋™์œผ๋กœ ์ €์žฅ +# args.save_root ํด๋”๋Š” args.root ํŒŒ์ผ๊ณผ ์ƒ์œ„ํด๋”๊ฐ€ ๋™์ผํ•ด์•ผ ํ•˜๊ณ , args.save_root ํด๋”์— ๊ฒฐ๊ณผ๋ฌผ ํ”„๋ ˆ์ž„ ์‹œํ€€์Šค ์ €์žฅ๋จ + +import argparse + +import os +import cv2 +import glob +import torch +import numpy as np +from PIL import Image +from tqdm import tqdm +from torchvision.transforms import functional as TF + +from modules.components.upr_net_freq import upr_freq as upr_freq002 +from modules.components.upr_basic import upr as upr_basic + +def parent_folder(path): + return os.path.split(path)[0] + +print('์ธํผ๋Ÿฐ์Šค ์‹œ\n1. utils.pad.py replicate->constant๋กœ ๋ณ€๊ฒฝํ•˜๊ณ \n2. components upr Model ์ตœ์ดˆ์ธํ’‹์—์„œ normalization๊ณผ padding ์œ„์น˜ ๋ฐ”๊ฟจ๋Š”์ง€ ํ™•์ธํ•  ๊ฒƒ (padding์ด ์œ„์— ์žˆ์–ด์•ผ๋จ)') +def main(): + parser = argparse.ArgumentParser('Video Frame Interpolation Inference',add_help=True) + parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index') + + parser.add_argument('--use_video', action='store_true', help='whether using video file') + parser.add_argument('--root', default='', type=str, help='root containing frames [./videoname/frames] (or video [./videoname/videoname.mp4])') + parser.add_argument('--save_root', default='', type=str, help='root to save result frames [./videoname/results_expname]') + parser.add_argument('--source_frame_ext', default='png', type=str, help='source frames extension name') + + parser.add_argument('--pretrain_path', default='', type=str, help='path containing pretrained model') + + parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level') + parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='UPR-Net pyramid skip number') + parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode') + parser.add_argument('--down_scale', default=1, type=int, help='frame down-scaling factor (due to GPU memory issue)') + + parser.add_argument('--make_video', action='store_true', help='whether merging frames and making video file') + parser.add_argument('--fps', default=0, type=int, help='FPS before VFI') + parser.add_argument('--new_video_name', default='newvideo', type=str, help='new video name [new_video_name.mp4]') + + args = parser.parse_args() + assert parent_folder(args.root)==parent_folder(args.save_root),\ + f"the parents of 'root' ({parent_folder(args.root)}) and save_root ({parent_folder(args.save_root)}) should be same!!" + if args.make_video: + assert os.path.splitext(args.new_video_name)[1]!='', f"'new_video_name' ({args.new_video_name}) should have extension name!!" + assert parent_folder(args.new_video_name)=='', f"'new_video_name' should not contain directory path" + if args.use_video: + temp1 = cv2.VideoCapture(args.root) + temp2 = int(temp1.get(cv2.CAP_PROP_FRAME_COUNT)) + assert temp2>0, f"number of frames in video ({args.root}) must be larger than 0!! !!check file name!!" + temp1.release() + del temp1, temp2 + + DEVICE = args.cuda_index + torch.cuda.set_device(DEVICE) + VIDEO_ROOT = args.root if args.use_video else None + FRAME_ROOT = args.root if VIDEO_ROOT is None else parent_folder(VIDEO_ROOT)+'/frames' + SAVE_ROOT = args.save_root + EXT = args.source_frame_ext + SCALE = args.down_scale + + if VIDEO_ROOT is not None: + print('@@@@@@@@@@@@@@@@@@@@Extracting frames from video@@@@@@@@@@@@@@@@@@@@') + os.makedirs(FRAME_ROOT, exist_ok=True) + video = cv2.VideoCapture(VIDEO_ROOT) + this_fps = video.get(cv2.CAP_PROP_FPS) + for index in tqdm(range(int(video.get(cv2.CAP_PROP_FRAME_COUNT)))): + _, frame = video.read() + newfile = os.path.join(FRAME_ROOT, str(index).zfill(4)+f'.{EXT}') + cv2.imwrite(newfile, frame) + video.release() + + model = upr_freq002.Model(pyr_level=args.pyr_level, + nr_lvl_skipped=args.nr_lvl_skipped, + splat_mode=args.splat_mode) + sd = torch.load(args.pretrain_path, map_location='cpu') + sd = sd['model'] if 'model' in sd.keys() else sd + print(model.load_state_dict(sd)) + model = model.to(DEVICE) + + file_list = sorted(glob.glob(os.path.join(FRAME_ROOT, f'*.{EXT}'))) + for i, file in enumerate(file_list): + newfile = os.path.join(FRAME_ROOT, str(i).zfill(4)+f'.{EXT}') + os.rename(file, newfile) + + if args.make_video: + num_frame_before = len(file_list) + fps_before = args.fps if not args.use_video else this_fps + num_frame_after = 2*num_frame_before-1 + fps_after = fps_before*num_frame_after/num_frame_before + print(f'num_frame_before: {num_frame_before}, fps_before: {fps_before:.6f}, time_before: {num_frame_before/fps_before:.6f}') + print(f'num_frame_after: {num_frame_after}, fps_after: {fps_after:.6f}, time_after: {num_frame_after/fps_after:.6f}') + print() + + print('@@@@@@@@@@@@@@@@@@@@Staring VFI@@@@@@@@@@@@@@@@@@@@') + os.makedirs(SAVE_ROOT, exist_ok=True) + for frame_num, file in enumerate(tqdm(file_list)): + img0 = img1 if frame_num!=0 else None + aaa = os.path.join(SAVE_ROOT, str(frame_num*2).zfill(4)+f'.{EXT}') + if EXT not in ['tga', 'TGA']: + img1 = cv2.imread(file) + cv2.imwrite(aaa, img1) + else: + img1 = Image.open(file) + img1.save(aaa) + img1 = np.array(img1)[:,:,[2,1,0]] + H,W,_ = img1.shape + + if SCALE==1: + img1 = (torch.from_numpy(img1[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE) + else: + img1 = (torch.from_numpy(cv2.resize(img1, (W//SCALE,H//SCALE), interpolation=cv2.INTER_CUBIC)[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE) + if img0 is None: continue + + with torch.no_grad(): + result_dict, extra_dict = model(img0, img1, pyr_level=args.pyr_level, nr_lvl_skipped=args.nr_lvl_skipped, time_step=0.5) + out = result_dict['imgt_pred'] + + bbb = os.path.join(SAVE_ROOT, str(2*frame_num-1).zfill(4)+f'.{EXT}') + if EXT not in ['tga', 'TGA']: + if SCALE==1: + out = (out[0].cpu().permute(1,2,0).clamp(0,1).numpy()*255).astype(np.uint8)[:,:,[2,1,0]] + else: + out = cv2.resize((out[0].cpu().permute(1,2,0).clamp(0,1).numpy()*255).astype(np.uint8)[:,:,[2,1,0]], (W,H), interpolation=cv2.INTER_CUBIC) + cv2.imwrite(bbb, out) + else: + if SCALE==1: + out = TF.to_pil_image(out[0].clamp(0,1).cpu()) + else: + out = TF.to_pil_image(TF.resize(out[0].clamp(0,1).cpu(), (H,W), interpolation=TF.InterpolationMode.BICUBIC)) + out.save(bbb) + + if args.make_video: + cmd = f'ffmpeg -framerate {fps_after} -i {SAVE_ROOT}/%04d.{EXT} -c:v libx264 -preset veryslow -crf 10 {parent_folder(SAVE_ROOT)}/{args.new_video_name}' + os.system(cmd) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/vfi_inference_triplet.py b/vfi_inference_triplet.py new file mode 100644 index 0000000000000000000000000000000000000000..9c03edc0e8a5e933cff82611ad99e3cb413971fc --- /dev/null +++ b/vfi_inference_triplet.py @@ -0,0 +1,138 @@ +# [START_COMMAND] +# python3 -m vfi_inference --cuda_index 0 \ +# --use_video --root ../VFI_Inference/thistest/test_video.mp4 --save_root ../VFI_Inference/thistest/results --source_frame_ext png \ +# --pretrain_path ./pretrained/upr_freq002.pth \ +# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 \ +# --make_video --fps 0 --new_video_name test_video_vfi.mp4 + +# python3 -m vfi_inference_triplet --cuda_index 0 \ +# --root ../VFI_Inference/thistriplet_notarget --pretrain_path ./pretrained/upr_freq002.pth \ +# --pyr_level 7 --nr_lvl_skipped 0 --splat_mode average --down_scale 1 + +# [FILE SYSTEM] +# args.root ํด๋” ์•„๋ž˜์— +# ํ•˜์œ„ ํด๋” (๊นŠ์ด๋Š” ์ตœ๋Œ€ 10๊ฐœ) ์•„๋ž˜์— +# triplet 3๊ฐœ ์ด๋ฏธ์ง€ (with GT) ๋˜๋Š” triplet 2๊ฐœ ์ด๋ฏธ์ง€ (without GT) + +import argparse + +import os +import cv2 +import glob +import torch +import datetime +import numpy as np +from PIL import Image +from tqdm import tqdm +from torch.nn import functional as F +from torchvision.transforms import functional as TF +from utils.metrics import calculate_batch_psnr, calculate_batch_ssim + +from modules.components.upr_net_freq import upr_freq as upr_freq002 +from modules.components.upr_basic import upr as upr_basic + +def multiple_pad(image, multiple): + _,_,H,W = image.size() + pad1 = multiple-(H%multiple) if H%multiple!=0 else 0 + pad2 = multiple-(W%multiple) if W%multiple!=0 else 0 + return TF.pad(image, (0,0,pad2,pad1)) + +print('์ธํผ๋Ÿฐ์Šค ์‹œ\n1. utils.pad.py replicate->constant๋กœ ๋ณ€๊ฒฝํ•˜๊ณ \n2. components upr Model ์ตœ์ดˆ์ธํ’‹์—์„œ normalization๊ณผ padding ์œ„์น˜ ๋ฐ”๊ฟจ๋Š”์ง€ ํ™•์ธํ•  ๊ฒƒ (padding์ด ์œ„์— ์žˆ์–ด์•ผ๋จ)') +def main(): + NOW = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') + + parser = argparse.ArgumentParser('Video Frame Interpolation Inference',add_help=True) + parser.add_argument('--cuda_index', default=0, type=int, help='CUDA GPU index') + + parser.add_argument('--exist_gt', action='store_true', help='whether ground-truth existing') + parser.add_argument('--root', default='', type=str, help='root containing frames [./triplet_name]') + parser.add_argument('--pretrain_path', default='', type=str, help='path containing pretrained model') + + parser.add_argument('--pyr_level', default=5, type=int, help='UPR-Net pyramid level') + parser.add_argument('--nr_lvl_skipped', default=0, type=int, help='UPR-Net pyramid skip number') + parser.add_argument('--splat_mode', default='average', type=str, help='UPR-Net warping splat mode') + parser.add_argument('--down_scale', default=1, type=int, help='frame down-scaling factor (due to GPU memory issue)') + + args = parser.parse_args() + assert not args.root.endswith('/'), f"'root' ({args.root}) must not end with '/'" + + DEVICE = args.cuda_index + torch.cuda.set_device(DEVICE) + ROOT = args.root + SAVE_ROOT = f'{ROOT}_{NOW}' + os.makedirs(SAVE_ROOT, exist_ok=True) + SCALE = args.down_scale + + model = upr_freq002.Model(pyr_level=args.pyr_level, + nr_lvl_skipped=args.nr_lvl_skipped, + splat_mode=args.splat_mode) + sd = torch.load(args.pretrain_path, map_location='cpu') + sd = sd['model'] if 'model' in sd.keys() else sd + print(model.load_state_dict(sd)) + model = model.to(DEVICE) + + star = '/*' + temp = [x for i in range(10) for x in glob.glob(f'{ROOT}{star*i}') if os.path.isfile(x)] + folder_list = sorted(set([os.path.split(x)[0] for x in temp])) + if args.exist_gt: + with open(os.path.join(SAVE_ROOT, f'record.txt'), 'w', encoding='utf8') as f: + f.writelines('') + psnr_list = [] + ssim_list = [] + + print('@@@@@@@@@@@@@@@@@@@@Staring VFI@@@@@@@@@@@@@@@@@@@@') + for folder in tqdm(folder_list): + file_list = [] + for ext in ['tif', 'TIF', 'jpg', 'png', 'tga', 'TGA']: + file_list += sorted(glob.glob(os.path.join(folder, f'*.{ext}'))) + cur_ext = os.path.splitext(file_list[0])[1][1:] + if cur_ext in ['tga', 'TGA']: + img_list = [TF.to_tensor(Image.open(file))[:3].unsqueeze(0).to(DEVICE) for file in file_list] + else: + img_list = [(torch.from_numpy(cv2.imread(file)[:,:,[2,1,0]])/255).permute(2,0,1).unsqueeze(0).to(DEVICE) for file in file_list] + + _,_,Hori,Wori = img_list[0].size() +# if Hori*Wori<=2100000: +# SCALE = 1 +# elif Hori*Wori<=2100000*4: +# SCALE = 2 +# else: +# SCALE = 4 + if args.exist_gt: + img_list = [multiple_pad(img, SCALE) if k!=1 else img for k, img in enumerate(img_list)] + img_list = [F.interpolate(img, scale_factor=1/SCALE, mode='bicubic') if k!=1 else img for k, img in enumerate(img_list)] + img0,imgt,img1 = img_list + else: + img_list = [multiple_pad(img, SCALE) for k, img in enumerate(img_list)] + img_list = [F.interpolate(img, scale_factor=1/SCALE, mode='bicubic') for k, img in enumerate(img_list)] + img0,img1 = img_list + + with torch.no_grad(): + result_dict, extra_dict = model(img0, img1, pyr_level=args.pyr_level, nr_lvl_skipped=args.nr_lvl_skipped, time_step=0.5) + out = F.interpolate(result_dict['imgt_pred'], scale_factor=SCALE, mode='bicubic')[:,:,:Hori,:Wori].clamp(0,1) + + if args.exist_gt: + psnr, _ = calculate_batch_psnr(imgt, out) + ssim, _ = calculate_batch_ssim(imgt, out) + psnr_list.append(psnr) + ssim_list.append(ssim) + + filepath, ext = os.path.splitext(file_list[1]) + newfilename = filepath.replace(ROOT, SAVE_ROOT) + newfile = newfilename+'_pred'+ext if args.exist_gt else os.path.join(os.path.split(newfilename)[0], 'im_pred'+ext) + newfolder = os.path.split(newfile)[0] + os.makedirs(newfolder, exist_ok=True) + if cur_ext in ['tga', 'TGA']: + TF.to_pil_image(out[0].cpu()).save(newfile) + else: + cv2.imwrite(newfile, (out[0].cpu().permute(1,2,0)*255).numpy().astype(np.uint8)[:,:,[2,1,0]]) + + if args.exist_gt: + with open(os.path.join(SAVE_ROOT, f'record.txt'), 'a', encoding='utf8') as f: + foldername = '/'.join(folder.split('/')[2:]) + f.writelines(f'{foldername:45}PSNR: {psnr:.4f} SSIM: {ssim:.4f}\n') + if args.exist_gt: + print(f'PSNR: {np.mean(psnr_list):.4f}, SSIM: {np.mean(ssim_list):.6f}') + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/wandb.yaml b/wandb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bece2eceb5349ce87373da887edbfbf54c11383a --- /dev/null +++ b/wandb.yaml @@ -0,0 +1,3 @@ +entity: inshorts00 +project: spocklab +api_key: 06354aa12a73cf3c1c93f0d14d53169818b7b599