Spaces:
Running
on
A10G
Running
on
A10G
init code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +131 -0
- LICENSE +14 -0
- RAFT/__init__.py +2 -0
- RAFT/corr.py +111 -0
- RAFT/datasets.py +235 -0
- RAFT/demo.py +79 -0
- RAFT/extractor.py +267 -0
- RAFT/raft.py +146 -0
- RAFT/update.py +139 -0
- RAFT/utils/__init__.py +2 -0
- RAFT/utils/augmentor.py +246 -0
- RAFT/utils/flow_viz.py +132 -0
- RAFT/utils/flow_viz_pt.py +118 -0
- RAFT/utils/frame_utils.py +137 -0
- RAFT/utils/utils.py +82 -0
- configs/train_flowcomp.json +40 -0
- configs/train_propainter.json +48 -0
- core/dataset.py +232 -0
- core/dist.py +47 -0
- core/loss.py +180 -0
- core/lr_scheduler.py +112 -0
- core/metrics.py +569 -0
- core/prefetch_dataloader.py +125 -0
- core/trainer.py +509 -0
- core/trainer_flow_w_edge.py +380 -0
- core/utils.py +371 -0
- datasets/davis/test.json +1 -0
- datasets/davis/train.json +1 -0
- datasets/youtube-vos/test.json +1 -0
- datasets/youtube-vos/train.json +1 -0
- inference_propainter.py +475 -0
- model/__init__.py +1 -0
- model/canny/canny_filter.py +256 -0
- model/canny/filter.py +288 -0
- model/canny/gaussian.py +116 -0
- model/canny/kernels.py +690 -0
- model/canny/sobel.py +263 -0
- model/misc.py +131 -0
- model/modules/base_module.py +131 -0
- model/modules/deformconv.py +54 -0
- model/modules/flow_comp_raft.py +265 -0
- model/modules/flow_loss_utils.py +142 -0
- model/modules/sparse_transformer.py +344 -0
- model/modules/spectral_norm.py +288 -0
- model/propainter.py +532 -0
- model/recurrent_flow_completion.py +347 -0
- model/vgg_arch.py +157 -0
- requirements.txt +33 -0
- scripts/compute_flow.py +108 -0
- scripts/evaluate_flow_completion.py +197 -0
.gitignore
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.vscode
|
2 |
+
|
3 |
+
# ignored files
|
4 |
+
version.py
|
5 |
+
|
6 |
+
# ignored files with suffix
|
7 |
+
*.html
|
8 |
+
# *.png
|
9 |
+
# *.jpeg
|
10 |
+
# *.jpg
|
11 |
+
# *.gif
|
12 |
+
*.pt
|
13 |
+
*.pth
|
14 |
+
*.dat
|
15 |
+
*.zip
|
16 |
+
|
17 |
+
# template
|
18 |
+
|
19 |
+
# Byte-compiled / optimized / DLL files
|
20 |
+
__pycache__/
|
21 |
+
*.py[cod]
|
22 |
+
*$py.class
|
23 |
+
|
24 |
+
# C extensions
|
25 |
+
*.so
|
26 |
+
|
27 |
+
# Distribution / packaging
|
28 |
+
.Python
|
29 |
+
build/
|
30 |
+
develop-eggs/
|
31 |
+
dist/
|
32 |
+
downloads/
|
33 |
+
eggs/
|
34 |
+
.eggs/
|
35 |
+
lib/
|
36 |
+
lib64/
|
37 |
+
parts/
|
38 |
+
sdist/
|
39 |
+
var/
|
40 |
+
wheels/
|
41 |
+
*.egg-info/
|
42 |
+
.installed.cfg
|
43 |
+
*.egg
|
44 |
+
MANIFEST
|
45 |
+
|
46 |
+
# PyInstaller
|
47 |
+
# Usually these files are written by a python script from a template
|
48 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
49 |
+
*.manifest
|
50 |
+
*.spec
|
51 |
+
|
52 |
+
# Installer logs
|
53 |
+
pip-log.txt
|
54 |
+
pip-delete-this-directory.txt
|
55 |
+
|
56 |
+
# Unit test / coverage reports
|
57 |
+
htmlcov/
|
58 |
+
.tox/
|
59 |
+
.coverage
|
60 |
+
.coverage.*
|
61 |
+
.cache
|
62 |
+
nosetests.xml
|
63 |
+
coverage.xml
|
64 |
+
*.cover
|
65 |
+
.hypothesis/
|
66 |
+
.pytest_cache/
|
67 |
+
|
68 |
+
# Translations
|
69 |
+
*.mo
|
70 |
+
*.pot
|
71 |
+
|
72 |
+
# Django stuff:
|
73 |
+
*.log
|
74 |
+
local_settings.py
|
75 |
+
db.sqlite3
|
76 |
+
|
77 |
+
# Flask stuff:
|
78 |
+
instance/
|
79 |
+
.webassets-cache
|
80 |
+
|
81 |
+
# Scrapy stuff:
|
82 |
+
.scrapy
|
83 |
+
|
84 |
+
# Sphinx documentation
|
85 |
+
docs/_build/
|
86 |
+
|
87 |
+
# PyBuilder
|
88 |
+
target/
|
89 |
+
|
90 |
+
# Jupyter Notebook
|
91 |
+
.ipynb_checkpoints
|
92 |
+
|
93 |
+
# pyenv
|
94 |
+
.python-version
|
95 |
+
|
96 |
+
# celery beat schedule file
|
97 |
+
celerybeat-schedule
|
98 |
+
|
99 |
+
# SageMath parsed files
|
100 |
+
*.sage.py
|
101 |
+
|
102 |
+
# Environments
|
103 |
+
.env
|
104 |
+
.venv
|
105 |
+
env/
|
106 |
+
venv/
|
107 |
+
ENV/
|
108 |
+
env.bak/
|
109 |
+
venv.bak/
|
110 |
+
|
111 |
+
# Spyder project settings
|
112 |
+
.spyderproject
|
113 |
+
.spyproject
|
114 |
+
|
115 |
+
# Rope project settings
|
116 |
+
.ropeproject
|
117 |
+
|
118 |
+
# mkdocs documentation
|
119 |
+
/site
|
120 |
+
|
121 |
+
# mypy
|
122 |
+
.mypy_cache/
|
123 |
+
|
124 |
+
# project
|
125 |
+
experiments_model/
|
126 |
+
unreleased/
|
127 |
+
results_eval/
|
128 |
+
results/
|
129 |
+
*debug*
|
130 |
+
*old*
|
131 |
+
*.sh
|
LICENSE
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# S-Lab License 1.0
|
2 |
+
|
3 |
+
Copyright 2023 S-Lab
|
4 |
+
|
5 |
+
Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
6 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
7 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
8 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.\
|
9 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
10 |
+
4. In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
|
11 |
+
|
12 |
+
|
13 |
+
---
|
14 |
+
For the commercial use of the code, please consult Prof. Chen Change Loy ([email protected])
|
RAFT/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# from .demo import RAFT_infer
|
2 |
+
from .raft import RAFT
|
RAFT/corr.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from .utils.utils import bilinear_sampler, coords_grid
|
4 |
+
|
5 |
+
try:
|
6 |
+
import alt_cuda_corr
|
7 |
+
except:
|
8 |
+
# alt_cuda_corr is not compiled
|
9 |
+
pass
|
10 |
+
|
11 |
+
|
12 |
+
class CorrBlock:
|
13 |
+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
14 |
+
self.num_levels = num_levels
|
15 |
+
self.radius = radius
|
16 |
+
self.corr_pyramid = []
|
17 |
+
|
18 |
+
# all pairs correlation
|
19 |
+
corr = CorrBlock.corr(fmap1, fmap2)
|
20 |
+
|
21 |
+
batch, h1, w1, dim, h2, w2 = corr.shape
|
22 |
+
corr = corr.reshape(batch*h1*w1, dim, h2, w2)
|
23 |
+
|
24 |
+
self.corr_pyramid.append(corr)
|
25 |
+
for i in range(self.num_levels-1):
|
26 |
+
corr = F.avg_pool2d(corr, 2, stride=2)
|
27 |
+
self.corr_pyramid.append(corr)
|
28 |
+
|
29 |
+
def __call__(self, coords):
|
30 |
+
r = self.radius
|
31 |
+
coords = coords.permute(0, 2, 3, 1)
|
32 |
+
batch, h1, w1, _ = coords.shape
|
33 |
+
|
34 |
+
out_pyramid = []
|
35 |
+
for i in range(self.num_levels):
|
36 |
+
corr = self.corr_pyramid[i]
|
37 |
+
dx = torch.linspace(-r, r, 2*r+1)
|
38 |
+
dy = torch.linspace(-r, r, 2*r+1)
|
39 |
+
delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
|
40 |
+
|
41 |
+
centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
|
42 |
+
delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
|
43 |
+
coords_lvl = centroid_lvl + delta_lvl
|
44 |
+
|
45 |
+
corr = bilinear_sampler(corr, coords_lvl)
|
46 |
+
corr = corr.view(batch, h1, w1, -1)
|
47 |
+
out_pyramid.append(corr)
|
48 |
+
|
49 |
+
out = torch.cat(out_pyramid, dim=-1)
|
50 |
+
return out.permute(0, 3, 1, 2).contiguous().float()
|
51 |
+
|
52 |
+
@staticmethod
|
53 |
+
def corr(fmap1, fmap2):
|
54 |
+
batch, dim, ht, wd = fmap1.shape
|
55 |
+
fmap1 = fmap1.view(batch, dim, ht*wd)
|
56 |
+
fmap2 = fmap2.view(batch, dim, ht*wd)
|
57 |
+
|
58 |
+
corr = torch.matmul(fmap1.transpose(1,2), fmap2)
|
59 |
+
corr = corr.view(batch, ht, wd, 1, ht, wd)
|
60 |
+
return corr / torch.sqrt(torch.tensor(dim).float())
|
61 |
+
|
62 |
+
|
63 |
+
class CorrLayer(torch.autograd.Function):
|
64 |
+
@staticmethod
|
65 |
+
def forward(ctx, fmap1, fmap2, coords, r):
|
66 |
+
fmap1 = fmap1.contiguous()
|
67 |
+
fmap2 = fmap2.contiguous()
|
68 |
+
coords = coords.contiguous()
|
69 |
+
ctx.save_for_backward(fmap1, fmap2, coords)
|
70 |
+
ctx.r = r
|
71 |
+
corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
|
72 |
+
return corr
|
73 |
+
|
74 |
+
@staticmethod
|
75 |
+
def backward(ctx, grad_corr):
|
76 |
+
fmap1, fmap2, coords = ctx.saved_tensors
|
77 |
+
grad_corr = grad_corr.contiguous()
|
78 |
+
fmap1_grad, fmap2_grad, coords_grad = \
|
79 |
+
correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
|
80 |
+
return fmap1_grad, fmap2_grad, coords_grad, None
|
81 |
+
|
82 |
+
|
83 |
+
class AlternateCorrBlock:
|
84 |
+
def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
|
85 |
+
self.num_levels = num_levels
|
86 |
+
self.radius = radius
|
87 |
+
|
88 |
+
self.pyramid = [(fmap1, fmap2)]
|
89 |
+
for i in range(self.num_levels):
|
90 |
+
fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
|
91 |
+
fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
|
92 |
+
self.pyramid.append((fmap1, fmap2))
|
93 |
+
|
94 |
+
def __call__(self, coords):
|
95 |
+
|
96 |
+
coords = coords.permute(0, 2, 3, 1)
|
97 |
+
B, H, W, _ = coords.shape
|
98 |
+
|
99 |
+
corr_list = []
|
100 |
+
for i in range(self.num_levels):
|
101 |
+
r = self.radius
|
102 |
+
fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
|
103 |
+
fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
|
104 |
+
|
105 |
+
coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
|
106 |
+
corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
|
107 |
+
corr_list.append(corr.squeeze(1))
|
108 |
+
|
109 |
+
corr = torch.stack(corr_list, dim=1)
|
110 |
+
corr = corr.reshape(B, -1, H, W)
|
111 |
+
return corr / 16.0
|
RAFT/datasets.py
ADDED
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.utils.data as data
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
import os
|
9 |
+
import math
|
10 |
+
import random
|
11 |
+
from glob import glob
|
12 |
+
import os.path as osp
|
13 |
+
|
14 |
+
from utils import frame_utils
|
15 |
+
from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
|
16 |
+
|
17 |
+
|
18 |
+
class FlowDataset(data.Dataset):
|
19 |
+
def __init__(self, aug_params=None, sparse=False):
|
20 |
+
self.augmentor = None
|
21 |
+
self.sparse = sparse
|
22 |
+
if aug_params is not None:
|
23 |
+
if sparse:
|
24 |
+
self.augmentor = SparseFlowAugmentor(**aug_params)
|
25 |
+
else:
|
26 |
+
self.augmentor = FlowAugmentor(**aug_params)
|
27 |
+
|
28 |
+
self.is_test = False
|
29 |
+
self.init_seed = False
|
30 |
+
self.flow_list = []
|
31 |
+
self.image_list = []
|
32 |
+
self.extra_info = []
|
33 |
+
|
34 |
+
def __getitem__(self, index):
|
35 |
+
|
36 |
+
if self.is_test:
|
37 |
+
img1 = frame_utils.read_gen(self.image_list[index][0])
|
38 |
+
img2 = frame_utils.read_gen(self.image_list[index][1])
|
39 |
+
img1 = np.array(img1).astype(np.uint8)[..., :3]
|
40 |
+
img2 = np.array(img2).astype(np.uint8)[..., :3]
|
41 |
+
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
|
42 |
+
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
|
43 |
+
return img1, img2, self.extra_info[index]
|
44 |
+
|
45 |
+
if not self.init_seed:
|
46 |
+
worker_info = torch.utils.data.get_worker_info()
|
47 |
+
if worker_info is not None:
|
48 |
+
torch.manual_seed(worker_info.id)
|
49 |
+
np.random.seed(worker_info.id)
|
50 |
+
random.seed(worker_info.id)
|
51 |
+
self.init_seed = True
|
52 |
+
|
53 |
+
index = index % len(self.image_list)
|
54 |
+
valid = None
|
55 |
+
if self.sparse:
|
56 |
+
flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
|
57 |
+
else:
|
58 |
+
flow = frame_utils.read_gen(self.flow_list[index])
|
59 |
+
|
60 |
+
img1 = frame_utils.read_gen(self.image_list[index][0])
|
61 |
+
img2 = frame_utils.read_gen(self.image_list[index][1])
|
62 |
+
|
63 |
+
flow = np.array(flow).astype(np.float32)
|
64 |
+
img1 = np.array(img1).astype(np.uint8)
|
65 |
+
img2 = np.array(img2).astype(np.uint8)
|
66 |
+
|
67 |
+
# grayscale images
|
68 |
+
if len(img1.shape) == 2:
|
69 |
+
img1 = np.tile(img1[...,None], (1, 1, 3))
|
70 |
+
img2 = np.tile(img2[...,None], (1, 1, 3))
|
71 |
+
else:
|
72 |
+
img1 = img1[..., :3]
|
73 |
+
img2 = img2[..., :3]
|
74 |
+
|
75 |
+
if self.augmentor is not None:
|
76 |
+
if self.sparse:
|
77 |
+
img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
|
78 |
+
else:
|
79 |
+
img1, img2, flow = self.augmentor(img1, img2, flow)
|
80 |
+
|
81 |
+
img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
|
82 |
+
img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
|
83 |
+
flow = torch.from_numpy(flow).permute(2, 0, 1).float()
|
84 |
+
|
85 |
+
if valid is not None:
|
86 |
+
valid = torch.from_numpy(valid)
|
87 |
+
else:
|
88 |
+
valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
|
89 |
+
|
90 |
+
return img1, img2, flow, valid.float()
|
91 |
+
|
92 |
+
|
93 |
+
def __rmul__(self, v):
|
94 |
+
self.flow_list = v * self.flow_list
|
95 |
+
self.image_list = v * self.image_list
|
96 |
+
return self
|
97 |
+
|
98 |
+
def __len__(self):
|
99 |
+
return len(self.image_list)
|
100 |
+
|
101 |
+
|
102 |
+
class MpiSintel(FlowDataset):
|
103 |
+
def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
|
104 |
+
super(MpiSintel, self).__init__(aug_params)
|
105 |
+
flow_root = osp.join(root, split, 'flow')
|
106 |
+
image_root = osp.join(root, split, dstype)
|
107 |
+
|
108 |
+
if split == 'test':
|
109 |
+
self.is_test = True
|
110 |
+
|
111 |
+
for scene in os.listdir(image_root):
|
112 |
+
image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
|
113 |
+
for i in range(len(image_list)-1):
|
114 |
+
self.image_list += [ [image_list[i], image_list[i+1]] ]
|
115 |
+
self.extra_info += [ (scene, i) ] # scene and frame_id
|
116 |
+
|
117 |
+
if split != 'test':
|
118 |
+
self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
|
119 |
+
|
120 |
+
|
121 |
+
class FlyingChairs(FlowDataset):
|
122 |
+
def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
|
123 |
+
super(FlyingChairs, self).__init__(aug_params)
|
124 |
+
|
125 |
+
images = sorted(glob(osp.join(root, '*.ppm')))
|
126 |
+
flows = sorted(glob(osp.join(root, '*.flo')))
|
127 |
+
assert (len(images)//2 == len(flows))
|
128 |
+
|
129 |
+
split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
|
130 |
+
for i in range(len(flows)):
|
131 |
+
xid = split_list[i]
|
132 |
+
if (split=='training' and xid==1) or (split=='validation' and xid==2):
|
133 |
+
self.flow_list += [ flows[i] ]
|
134 |
+
self.image_list += [ [images[2*i], images[2*i+1]] ]
|
135 |
+
|
136 |
+
|
137 |
+
class FlyingThings3D(FlowDataset):
|
138 |
+
def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
|
139 |
+
super(FlyingThings3D, self).__init__(aug_params)
|
140 |
+
|
141 |
+
for cam in ['left']:
|
142 |
+
for direction in ['into_future', 'into_past']:
|
143 |
+
image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
|
144 |
+
image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
|
145 |
+
|
146 |
+
flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
|
147 |
+
flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
|
148 |
+
|
149 |
+
for idir, fdir in zip(image_dirs, flow_dirs):
|
150 |
+
images = sorted(glob(osp.join(idir, '*.png')) )
|
151 |
+
flows = sorted(glob(osp.join(fdir, '*.pfm')) )
|
152 |
+
for i in range(len(flows)-1):
|
153 |
+
if direction == 'into_future':
|
154 |
+
self.image_list += [ [images[i], images[i+1]] ]
|
155 |
+
self.flow_list += [ flows[i] ]
|
156 |
+
elif direction == 'into_past':
|
157 |
+
self.image_list += [ [images[i+1], images[i]] ]
|
158 |
+
self.flow_list += [ flows[i+1] ]
|
159 |
+
|
160 |
+
|
161 |
+
class KITTI(FlowDataset):
|
162 |
+
def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
|
163 |
+
super(KITTI, self).__init__(aug_params, sparse=True)
|
164 |
+
if split == 'testing':
|
165 |
+
self.is_test = True
|
166 |
+
|
167 |
+
root = osp.join(root, split)
|
168 |
+
images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
|
169 |
+
images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
|
170 |
+
|
171 |
+
for img1, img2 in zip(images1, images2):
|
172 |
+
frame_id = img1.split('/')[-1]
|
173 |
+
self.extra_info += [ [frame_id] ]
|
174 |
+
self.image_list += [ [img1, img2] ]
|
175 |
+
|
176 |
+
if split == 'training':
|
177 |
+
self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
|
178 |
+
|
179 |
+
|
180 |
+
class HD1K(FlowDataset):
|
181 |
+
def __init__(self, aug_params=None, root='datasets/HD1k'):
|
182 |
+
super(HD1K, self).__init__(aug_params, sparse=True)
|
183 |
+
|
184 |
+
seq_ix = 0
|
185 |
+
while 1:
|
186 |
+
flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
|
187 |
+
images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
|
188 |
+
|
189 |
+
if len(flows) == 0:
|
190 |
+
break
|
191 |
+
|
192 |
+
for i in range(len(flows)-1):
|
193 |
+
self.flow_list += [flows[i]]
|
194 |
+
self.image_list += [ [images[i], images[i+1]] ]
|
195 |
+
|
196 |
+
seq_ix += 1
|
197 |
+
|
198 |
+
|
199 |
+
def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
|
200 |
+
""" Create the data loader for the corresponding trainign set """
|
201 |
+
|
202 |
+
if args.stage == 'chairs':
|
203 |
+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
|
204 |
+
train_dataset = FlyingChairs(aug_params, split='training')
|
205 |
+
|
206 |
+
elif args.stage == 'things':
|
207 |
+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
|
208 |
+
clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
|
209 |
+
final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
|
210 |
+
train_dataset = clean_dataset + final_dataset
|
211 |
+
|
212 |
+
elif args.stage == 'sintel':
|
213 |
+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
|
214 |
+
things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
|
215 |
+
sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
|
216 |
+
sintel_final = MpiSintel(aug_params, split='training', dstype='final')
|
217 |
+
|
218 |
+
if TRAIN_DS == 'C+T+K+S+H':
|
219 |
+
kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
|
220 |
+
hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
|
221 |
+
train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
|
222 |
+
|
223 |
+
elif TRAIN_DS == 'C+T+K/S':
|
224 |
+
train_dataset = 100*sintel_clean + 100*sintel_final + things
|
225 |
+
|
226 |
+
elif args.stage == 'kitti':
|
227 |
+
aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
|
228 |
+
train_dataset = KITTI(aug_params, split='training')
|
229 |
+
|
230 |
+
train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
|
231 |
+
pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
|
232 |
+
|
233 |
+
print('Training with %d image pairs' % len(train_dataset))
|
234 |
+
return train_loader
|
235 |
+
|
RAFT/demo.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import argparse
|
3 |
+
import os
|
4 |
+
import cv2
|
5 |
+
import glob
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
from .raft import RAFT
|
11 |
+
from .utils import flow_viz
|
12 |
+
from .utils.utils import InputPadder
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
DEVICE = 'cuda'
|
17 |
+
|
18 |
+
def load_image(imfile):
|
19 |
+
img = np.array(Image.open(imfile)).astype(np.uint8)
|
20 |
+
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
21 |
+
return img
|
22 |
+
|
23 |
+
|
24 |
+
def load_image_list(image_files):
|
25 |
+
images = []
|
26 |
+
for imfile in sorted(image_files):
|
27 |
+
images.append(load_image(imfile))
|
28 |
+
|
29 |
+
images = torch.stack(images, dim=0)
|
30 |
+
images = images.to(DEVICE)
|
31 |
+
|
32 |
+
padder = InputPadder(images.shape)
|
33 |
+
return padder.pad(images)[0]
|
34 |
+
|
35 |
+
|
36 |
+
def viz(img, flo):
|
37 |
+
img = img[0].permute(1,2,0).cpu().numpy()
|
38 |
+
flo = flo[0].permute(1,2,0).cpu().numpy()
|
39 |
+
|
40 |
+
# map flow to rgb image
|
41 |
+
flo = flow_viz.flow_to_image(flo)
|
42 |
+
# img_flo = np.concatenate([img, flo], axis=0)
|
43 |
+
img_flo = flo
|
44 |
+
|
45 |
+
cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
|
46 |
+
# cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
|
47 |
+
# cv2.waitKey()
|
48 |
+
|
49 |
+
|
50 |
+
def demo(args):
|
51 |
+
model = torch.nn.DataParallel(RAFT(args))
|
52 |
+
model.load_state_dict(torch.load(args.model))
|
53 |
+
|
54 |
+
model = model.module
|
55 |
+
model.to(DEVICE)
|
56 |
+
model.eval()
|
57 |
+
|
58 |
+
with torch.no_grad():
|
59 |
+
images = glob.glob(os.path.join(args.path, '*.png')) + \
|
60 |
+
glob.glob(os.path.join(args.path, '*.jpg'))
|
61 |
+
|
62 |
+
images = load_image_list(images)
|
63 |
+
for i in range(images.shape[0]-1):
|
64 |
+
image1 = images[i,None]
|
65 |
+
image2 = images[i+1,None]
|
66 |
+
|
67 |
+
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
68 |
+
viz(image1, flow_up)
|
69 |
+
|
70 |
+
|
71 |
+
def RAFT_infer(args):
|
72 |
+
model = torch.nn.DataParallel(RAFT(args))
|
73 |
+
model.load_state_dict(torch.load(args.model))
|
74 |
+
|
75 |
+
model = model.module
|
76 |
+
model.to(DEVICE)
|
77 |
+
model.eval()
|
78 |
+
|
79 |
+
return model
|
RAFT/extractor.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class ResidualBlock(nn.Module):
|
7 |
+
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
8 |
+
super(ResidualBlock, self).__init__()
|
9 |
+
|
10 |
+
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
|
11 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
12 |
+
self.relu = nn.ReLU(inplace=True)
|
13 |
+
|
14 |
+
num_groups = planes // 8
|
15 |
+
|
16 |
+
if norm_fn == 'group':
|
17 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
18 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
19 |
+
if not stride == 1:
|
20 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
21 |
+
|
22 |
+
elif norm_fn == 'batch':
|
23 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
24 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
25 |
+
if not stride == 1:
|
26 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
27 |
+
|
28 |
+
elif norm_fn == 'instance':
|
29 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
30 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
31 |
+
if not stride == 1:
|
32 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
33 |
+
|
34 |
+
elif norm_fn == 'none':
|
35 |
+
self.norm1 = nn.Sequential()
|
36 |
+
self.norm2 = nn.Sequential()
|
37 |
+
if not stride == 1:
|
38 |
+
self.norm3 = nn.Sequential()
|
39 |
+
|
40 |
+
if stride == 1:
|
41 |
+
self.downsample = None
|
42 |
+
|
43 |
+
else:
|
44 |
+
self.downsample = nn.Sequential(
|
45 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
46 |
+
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
y = x
|
50 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
51 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
52 |
+
|
53 |
+
if self.downsample is not None:
|
54 |
+
x = self.downsample(x)
|
55 |
+
|
56 |
+
return self.relu(x+y)
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
class BottleneckBlock(nn.Module):
|
61 |
+
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
62 |
+
super(BottleneckBlock, self).__init__()
|
63 |
+
|
64 |
+
self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
|
65 |
+
self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
|
66 |
+
self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
|
67 |
+
self.relu = nn.ReLU(inplace=True)
|
68 |
+
|
69 |
+
num_groups = planes // 8
|
70 |
+
|
71 |
+
if norm_fn == 'group':
|
72 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
73 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
|
74 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
75 |
+
if not stride == 1:
|
76 |
+
self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
77 |
+
|
78 |
+
elif norm_fn == 'batch':
|
79 |
+
self.norm1 = nn.BatchNorm2d(planes//4)
|
80 |
+
self.norm2 = nn.BatchNorm2d(planes//4)
|
81 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
82 |
+
if not stride == 1:
|
83 |
+
self.norm4 = nn.BatchNorm2d(planes)
|
84 |
+
|
85 |
+
elif norm_fn == 'instance':
|
86 |
+
self.norm1 = nn.InstanceNorm2d(planes//4)
|
87 |
+
self.norm2 = nn.InstanceNorm2d(planes//4)
|
88 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
89 |
+
if not stride == 1:
|
90 |
+
self.norm4 = nn.InstanceNorm2d(planes)
|
91 |
+
|
92 |
+
elif norm_fn == 'none':
|
93 |
+
self.norm1 = nn.Sequential()
|
94 |
+
self.norm2 = nn.Sequential()
|
95 |
+
self.norm3 = nn.Sequential()
|
96 |
+
if not stride == 1:
|
97 |
+
self.norm4 = nn.Sequential()
|
98 |
+
|
99 |
+
if stride == 1:
|
100 |
+
self.downsample = None
|
101 |
+
|
102 |
+
else:
|
103 |
+
self.downsample = nn.Sequential(
|
104 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
|
105 |
+
|
106 |
+
|
107 |
+
def forward(self, x):
|
108 |
+
y = x
|
109 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
110 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
111 |
+
y = self.relu(self.norm3(self.conv3(y)))
|
112 |
+
|
113 |
+
if self.downsample is not None:
|
114 |
+
x = self.downsample(x)
|
115 |
+
|
116 |
+
return self.relu(x+y)
|
117 |
+
|
118 |
+
class BasicEncoder(nn.Module):
|
119 |
+
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
120 |
+
super(BasicEncoder, self).__init__()
|
121 |
+
self.norm_fn = norm_fn
|
122 |
+
|
123 |
+
if self.norm_fn == 'group':
|
124 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
125 |
+
|
126 |
+
elif self.norm_fn == 'batch':
|
127 |
+
self.norm1 = nn.BatchNorm2d(64)
|
128 |
+
|
129 |
+
elif self.norm_fn == 'instance':
|
130 |
+
self.norm1 = nn.InstanceNorm2d(64)
|
131 |
+
|
132 |
+
elif self.norm_fn == 'none':
|
133 |
+
self.norm1 = nn.Sequential()
|
134 |
+
|
135 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
136 |
+
self.relu1 = nn.ReLU(inplace=True)
|
137 |
+
|
138 |
+
self.in_planes = 64
|
139 |
+
self.layer1 = self._make_layer(64, stride=1)
|
140 |
+
self.layer2 = self._make_layer(96, stride=2)
|
141 |
+
self.layer3 = self._make_layer(128, stride=2)
|
142 |
+
|
143 |
+
# output convolution
|
144 |
+
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
|
145 |
+
|
146 |
+
self.dropout = None
|
147 |
+
if dropout > 0:
|
148 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
149 |
+
|
150 |
+
for m in self.modules():
|
151 |
+
if isinstance(m, nn.Conv2d):
|
152 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
153 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
154 |
+
if m.weight is not None:
|
155 |
+
nn.init.constant_(m.weight, 1)
|
156 |
+
if m.bias is not None:
|
157 |
+
nn.init.constant_(m.bias, 0)
|
158 |
+
|
159 |
+
def _make_layer(self, dim, stride=1):
|
160 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
161 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
162 |
+
layers = (layer1, layer2)
|
163 |
+
|
164 |
+
self.in_planes = dim
|
165 |
+
return nn.Sequential(*layers)
|
166 |
+
|
167 |
+
|
168 |
+
def forward(self, x):
|
169 |
+
|
170 |
+
# if input is list, combine batch dimension
|
171 |
+
is_list = isinstance(x, tuple) or isinstance(x, list)
|
172 |
+
if is_list:
|
173 |
+
batch_dim = x[0].shape[0]
|
174 |
+
x = torch.cat(x, dim=0)
|
175 |
+
|
176 |
+
x = self.conv1(x)
|
177 |
+
x = self.norm1(x)
|
178 |
+
x = self.relu1(x)
|
179 |
+
|
180 |
+
x = self.layer1(x)
|
181 |
+
x = self.layer2(x)
|
182 |
+
x = self.layer3(x)
|
183 |
+
|
184 |
+
x = self.conv2(x)
|
185 |
+
|
186 |
+
if self.training and self.dropout is not None:
|
187 |
+
x = self.dropout(x)
|
188 |
+
|
189 |
+
if is_list:
|
190 |
+
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
191 |
+
|
192 |
+
return x
|
193 |
+
|
194 |
+
|
195 |
+
class SmallEncoder(nn.Module):
|
196 |
+
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
197 |
+
super(SmallEncoder, self).__init__()
|
198 |
+
self.norm_fn = norm_fn
|
199 |
+
|
200 |
+
if self.norm_fn == 'group':
|
201 |
+
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
|
202 |
+
|
203 |
+
elif self.norm_fn == 'batch':
|
204 |
+
self.norm1 = nn.BatchNorm2d(32)
|
205 |
+
|
206 |
+
elif self.norm_fn == 'instance':
|
207 |
+
self.norm1 = nn.InstanceNorm2d(32)
|
208 |
+
|
209 |
+
elif self.norm_fn == 'none':
|
210 |
+
self.norm1 = nn.Sequential()
|
211 |
+
|
212 |
+
self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
|
213 |
+
self.relu1 = nn.ReLU(inplace=True)
|
214 |
+
|
215 |
+
self.in_planes = 32
|
216 |
+
self.layer1 = self._make_layer(32, stride=1)
|
217 |
+
self.layer2 = self._make_layer(64, stride=2)
|
218 |
+
self.layer3 = self._make_layer(96, stride=2)
|
219 |
+
|
220 |
+
self.dropout = None
|
221 |
+
if dropout > 0:
|
222 |
+
self.dropout = nn.Dropout2d(p=dropout)
|
223 |
+
|
224 |
+
self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
|
225 |
+
|
226 |
+
for m in self.modules():
|
227 |
+
if isinstance(m, nn.Conv2d):
|
228 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
229 |
+
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
230 |
+
if m.weight is not None:
|
231 |
+
nn.init.constant_(m.weight, 1)
|
232 |
+
if m.bias is not None:
|
233 |
+
nn.init.constant_(m.bias, 0)
|
234 |
+
|
235 |
+
def _make_layer(self, dim, stride=1):
|
236 |
+
layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
237 |
+
layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
|
238 |
+
layers = (layer1, layer2)
|
239 |
+
|
240 |
+
self.in_planes = dim
|
241 |
+
return nn.Sequential(*layers)
|
242 |
+
|
243 |
+
|
244 |
+
def forward(self, x):
|
245 |
+
|
246 |
+
# if input is list, combine batch dimension
|
247 |
+
is_list = isinstance(x, tuple) or isinstance(x, list)
|
248 |
+
if is_list:
|
249 |
+
batch_dim = x[0].shape[0]
|
250 |
+
x = torch.cat(x, dim=0)
|
251 |
+
|
252 |
+
x = self.conv1(x)
|
253 |
+
x = self.norm1(x)
|
254 |
+
x = self.relu1(x)
|
255 |
+
|
256 |
+
x = self.layer1(x)
|
257 |
+
x = self.layer2(x)
|
258 |
+
x = self.layer3(x)
|
259 |
+
x = self.conv2(x)
|
260 |
+
|
261 |
+
if self.training and self.dropout is not None:
|
262 |
+
x = self.dropout(x)
|
263 |
+
|
264 |
+
if is_list:
|
265 |
+
x = torch.split(x, [batch_dim, batch_dim], dim=0)
|
266 |
+
|
267 |
+
return x
|
RAFT/raft.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from .update import BasicUpdateBlock, SmallUpdateBlock
|
7 |
+
from .extractor import BasicEncoder, SmallEncoder
|
8 |
+
from .corr import CorrBlock, AlternateCorrBlock
|
9 |
+
from .utils.utils import bilinear_sampler, coords_grid, upflow8
|
10 |
+
|
11 |
+
try:
|
12 |
+
autocast = torch.cuda.amp.autocast
|
13 |
+
except:
|
14 |
+
# dummy autocast for PyTorch < 1.6
|
15 |
+
class autocast:
|
16 |
+
def __init__(self, enabled):
|
17 |
+
pass
|
18 |
+
def __enter__(self):
|
19 |
+
pass
|
20 |
+
def __exit__(self, *args):
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class RAFT(nn.Module):
|
25 |
+
def __init__(self, args):
|
26 |
+
super(RAFT, self).__init__()
|
27 |
+
self.args = args
|
28 |
+
|
29 |
+
if args.small:
|
30 |
+
self.hidden_dim = hdim = 96
|
31 |
+
self.context_dim = cdim = 64
|
32 |
+
args.corr_levels = 4
|
33 |
+
args.corr_radius = 3
|
34 |
+
|
35 |
+
else:
|
36 |
+
self.hidden_dim = hdim = 128
|
37 |
+
self.context_dim = cdim = 128
|
38 |
+
args.corr_levels = 4
|
39 |
+
args.corr_radius = 4
|
40 |
+
|
41 |
+
if 'dropout' not in args._get_kwargs():
|
42 |
+
args.dropout = 0
|
43 |
+
|
44 |
+
if 'alternate_corr' not in args._get_kwargs():
|
45 |
+
args.alternate_corr = False
|
46 |
+
|
47 |
+
# feature network, context network, and update block
|
48 |
+
if args.small:
|
49 |
+
self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
|
50 |
+
self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
|
51 |
+
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
|
52 |
+
|
53 |
+
else:
|
54 |
+
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
|
55 |
+
self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
|
56 |
+
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
57 |
+
|
58 |
+
|
59 |
+
def freeze_bn(self):
|
60 |
+
for m in self.modules():
|
61 |
+
if isinstance(m, nn.BatchNorm2d):
|
62 |
+
m.eval()
|
63 |
+
|
64 |
+
def initialize_flow(self, img):
|
65 |
+
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
|
66 |
+
N, C, H, W = img.shape
|
67 |
+
coords0 = coords_grid(N, H//8, W//8).to(img.device)
|
68 |
+
coords1 = coords_grid(N, H//8, W//8).to(img.device)
|
69 |
+
|
70 |
+
# optical flow computed as difference: flow = coords1 - coords0
|
71 |
+
return coords0, coords1
|
72 |
+
|
73 |
+
def upsample_flow(self, flow, mask):
|
74 |
+
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
|
75 |
+
N, _, H, W = flow.shape
|
76 |
+
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
77 |
+
mask = torch.softmax(mask, dim=2)
|
78 |
+
|
79 |
+
up_flow = F.unfold(8 * flow, [3,3], padding=1)
|
80 |
+
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
81 |
+
|
82 |
+
up_flow = torch.sum(mask * up_flow, dim=2)
|
83 |
+
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
84 |
+
return up_flow.reshape(N, 2, 8*H, 8*W)
|
85 |
+
|
86 |
+
|
87 |
+
def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True):
|
88 |
+
""" Estimate optical flow between pair of frames """
|
89 |
+
|
90 |
+
# image1 = 2 * (image1 / 255.0) - 1.0
|
91 |
+
# image2 = 2 * (image2 / 255.0) - 1.0
|
92 |
+
|
93 |
+
image1 = image1.contiguous()
|
94 |
+
image2 = image2.contiguous()
|
95 |
+
|
96 |
+
hdim = self.hidden_dim
|
97 |
+
cdim = self.context_dim
|
98 |
+
|
99 |
+
# run the feature network
|
100 |
+
with autocast(enabled=self.args.mixed_precision):
|
101 |
+
fmap1, fmap2 = self.fnet([image1, image2])
|
102 |
+
|
103 |
+
fmap1 = fmap1.float()
|
104 |
+
fmap2 = fmap2.float()
|
105 |
+
|
106 |
+
if self.args.alternate_corr:
|
107 |
+
corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
108 |
+
else:
|
109 |
+
corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
|
110 |
+
|
111 |
+
# run the context network
|
112 |
+
with autocast(enabled=self.args.mixed_precision):
|
113 |
+
cnet = self.cnet(image1)
|
114 |
+
net, inp = torch.split(cnet, [hdim, cdim], dim=1)
|
115 |
+
net = torch.tanh(net)
|
116 |
+
inp = torch.relu(inp)
|
117 |
+
|
118 |
+
coords0, coords1 = self.initialize_flow(image1)
|
119 |
+
|
120 |
+
if flow_init is not None:
|
121 |
+
coords1 = coords1 + flow_init
|
122 |
+
|
123 |
+
flow_predictions = []
|
124 |
+
for itr in range(iters):
|
125 |
+
coords1 = coords1.detach()
|
126 |
+
corr = corr_fn(coords1) # index correlation volume
|
127 |
+
|
128 |
+
flow = coords1 - coords0
|
129 |
+
with autocast(enabled=self.args.mixed_precision):
|
130 |
+
net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
|
131 |
+
|
132 |
+
# F(t+1) = F(t) + \Delta(t)
|
133 |
+
coords1 = coords1 + delta_flow
|
134 |
+
|
135 |
+
# upsample predictions
|
136 |
+
if up_mask is None:
|
137 |
+
flow_up = upflow8(coords1 - coords0)
|
138 |
+
else:
|
139 |
+
flow_up = self.upsample_flow(coords1 - coords0, up_mask)
|
140 |
+
|
141 |
+
flow_predictions.append(flow_up)
|
142 |
+
|
143 |
+
if test_mode:
|
144 |
+
return coords1 - coords0, flow_up
|
145 |
+
|
146 |
+
return flow_predictions
|
RAFT/update.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class FlowHead(nn.Module):
|
7 |
+
def __init__(self, input_dim=128, hidden_dim=256):
|
8 |
+
super(FlowHead, self).__init__()
|
9 |
+
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
10 |
+
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
|
11 |
+
self.relu = nn.ReLU(inplace=True)
|
12 |
+
|
13 |
+
def forward(self, x):
|
14 |
+
return self.conv2(self.relu(self.conv1(x)))
|
15 |
+
|
16 |
+
class ConvGRU(nn.Module):
|
17 |
+
def __init__(self, hidden_dim=128, input_dim=192+128):
|
18 |
+
super(ConvGRU, self).__init__()
|
19 |
+
self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
20 |
+
self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
21 |
+
self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
|
22 |
+
|
23 |
+
def forward(self, h, x):
|
24 |
+
hx = torch.cat([h, x], dim=1)
|
25 |
+
|
26 |
+
z = torch.sigmoid(self.convz(hx))
|
27 |
+
r = torch.sigmoid(self.convr(hx))
|
28 |
+
q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
|
29 |
+
|
30 |
+
h = (1-z) * h + z * q
|
31 |
+
return h
|
32 |
+
|
33 |
+
class SepConvGRU(nn.Module):
|
34 |
+
def __init__(self, hidden_dim=128, input_dim=192+128):
|
35 |
+
super(SepConvGRU, self).__init__()
|
36 |
+
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
37 |
+
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
38 |
+
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
39 |
+
|
40 |
+
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
41 |
+
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
42 |
+
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
43 |
+
|
44 |
+
|
45 |
+
def forward(self, h, x):
|
46 |
+
# horizontal
|
47 |
+
hx = torch.cat([h, x], dim=1)
|
48 |
+
z = torch.sigmoid(self.convz1(hx))
|
49 |
+
r = torch.sigmoid(self.convr1(hx))
|
50 |
+
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
|
51 |
+
h = (1-z) * h + z * q
|
52 |
+
|
53 |
+
# vertical
|
54 |
+
hx = torch.cat([h, x], dim=1)
|
55 |
+
z = torch.sigmoid(self.convz2(hx))
|
56 |
+
r = torch.sigmoid(self.convr2(hx))
|
57 |
+
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
|
58 |
+
h = (1-z) * h + z * q
|
59 |
+
|
60 |
+
return h
|
61 |
+
|
62 |
+
class SmallMotionEncoder(nn.Module):
|
63 |
+
def __init__(self, args):
|
64 |
+
super(SmallMotionEncoder, self).__init__()
|
65 |
+
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
|
66 |
+
self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
|
67 |
+
self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
|
68 |
+
self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
|
69 |
+
self.conv = nn.Conv2d(128, 80, 3, padding=1)
|
70 |
+
|
71 |
+
def forward(self, flow, corr):
|
72 |
+
cor = F.relu(self.convc1(corr))
|
73 |
+
flo = F.relu(self.convf1(flow))
|
74 |
+
flo = F.relu(self.convf2(flo))
|
75 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
76 |
+
out = F.relu(self.conv(cor_flo))
|
77 |
+
return torch.cat([out, flow], dim=1)
|
78 |
+
|
79 |
+
class BasicMotionEncoder(nn.Module):
|
80 |
+
def __init__(self, args):
|
81 |
+
super(BasicMotionEncoder, self).__init__()
|
82 |
+
cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
|
83 |
+
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
|
84 |
+
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
|
85 |
+
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
|
86 |
+
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
|
87 |
+
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
|
88 |
+
|
89 |
+
def forward(self, flow, corr):
|
90 |
+
cor = F.relu(self.convc1(corr))
|
91 |
+
cor = F.relu(self.convc2(cor))
|
92 |
+
flo = F.relu(self.convf1(flow))
|
93 |
+
flo = F.relu(self.convf2(flo))
|
94 |
+
|
95 |
+
cor_flo = torch.cat([cor, flo], dim=1)
|
96 |
+
out = F.relu(self.conv(cor_flo))
|
97 |
+
return torch.cat([out, flow], dim=1)
|
98 |
+
|
99 |
+
class SmallUpdateBlock(nn.Module):
|
100 |
+
def __init__(self, args, hidden_dim=96):
|
101 |
+
super(SmallUpdateBlock, self).__init__()
|
102 |
+
self.encoder = SmallMotionEncoder(args)
|
103 |
+
self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
|
104 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
|
105 |
+
|
106 |
+
def forward(self, net, inp, corr, flow):
|
107 |
+
motion_features = self.encoder(flow, corr)
|
108 |
+
inp = torch.cat([inp, motion_features], dim=1)
|
109 |
+
net = self.gru(net, inp)
|
110 |
+
delta_flow = self.flow_head(net)
|
111 |
+
|
112 |
+
return net, None, delta_flow
|
113 |
+
|
114 |
+
class BasicUpdateBlock(nn.Module):
|
115 |
+
def __init__(self, args, hidden_dim=128, input_dim=128):
|
116 |
+
super(BasicUpdateBlock, self).__init__()
|
117 |
+
self.args = args
|
118 |
+
self.encoder = BasicMotionEncoder(args)
|
119 |
+
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
|
120 |
+
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
121 |
+
|
122 |
+
self.mask = nn.Sequential(
|
123 |
+
nn.Conv2d(128, 256, 3, padding=1),
|
124 |
+
nn.ReLU(inplace=True),
|
125 |
+
nn.Conv2d(256, 64*9, 1, padding=0))
|
126 |
+
|
127 |
+
def forward(self, net, inp, corr, flow, upsample=True):
|
128 |
+
motion_features = self.encoder(flow, corr)
|
129 |
+
inp = torch.cat([inp, motion_features], dim=1)
|
130 |
+
|
131 |
+
net = self.gru(net, inp)
|
132 |
+
delta_flow = self.flow_head(net)
|
133 |
+
|
134 |
+
# scale mask to balence gradients
|
135 |
+
mask = .25 * self.mask(net)
|
136 |
+
return net, mask, delta_flow
|
137 |
+
|
138 |
+
|
139 |
+
|
RAFT/utils/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .flow_viz import flow_to_image
|
2 |
+
from .frame_utils import writeFlow
|
RAFT/utils/augmentor.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import math
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
cv2.setNumThreads(0)
|
8 |
+
cv2.ocl.setUseOpenCL(False)
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torchvision.transforms import ColorJitter
|
12 |
+
import torch.nn.functional as F
|
13 |
+
|
14 |
+
|
15 |
+
class FlowAugmentor:
|
16 |
+
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
|
17 |
+
|
18 |
+
# spatial augmentation params
|
19 |
+
self.crop_size = crop_size
|
20 |
+
self.min_scale = min_scale
|
21 |
+
self.max_scale = max_scale
|
22 |
+
self.spatial_aug_prob = 0.8
|
23 |
+
self.stretch_prob = 0.8
|
24 |
+
self.max_stretch = 0.2
|
25 |
+
|
26 |
+
# flip augmentation params
|
27 |
+
self.do_flip = do_flip
|
28 |
+
self.h_flip_prob = 0.5
|
29 |
+
self.v_flip_prob = 0.1
|
30 |
+
|
31 |
+
# photometric augmentation params
|
32 |
+
self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
|
33 |
+
self.asymmetric_color_aug_prob = 0.2
|
34 |
+
self.eraser_aug_prob = 0.5
|
35 |
+
|
36 |
+
def color_transform(self, img1, img2):
|
37 |
+
""" Photometric augmentation """
|
38 |
+
|
39 |
+
# asymmetric
|
40 |
+
if np.random.rand() < self.asymmetric_color_aug_prob:
|
41 |
+
img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
|
42 |
+
img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
|
43 |
+
|
44 |
+
# symmetric
|
45 |
+
else:
|
46 |
+
image_stack = np.concatenate([img1, img2], axis=0)
|
47 |
+
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
|
48 |
+
img1, img2 = np.split(image_stack, 2, axis=0)
|
49 |
+
|
50 |
+
return img1, img2
|
51 |
+
|
52 |
+
def eraser_transform(self, img1, img2, bounds=[50, 100]):
|
53 |
+
""" Occlusion augmentation """
|
54 |
+
|
55 |
+
ht, wd = img1.shape[:2]
|
56 |
+
if np.random.rand() < self.eraser_aug_prob:
|
57 |
+
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
58 |
+
for _ in range(np.random.randint(1, 3)):
|
59 |
+
x0 = np.random.randint(0, wd)
|
60 |
+
y0 = np.random.randint(0, ht)
|
61 |
+
dx = np.random.randint(bounds[0], bounds[1])
|
62 |
+
dy = np.random.randint(bounds[0], bounds[1])
|
63 |
+
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
|
64 |
+
|
65 |
+
return img1, img2
|
66 |
+
|
67 |
+
def spatial_transform(self, img1, img2, flow):
|
68 |
+
# randomly sample scale
|
69 |
+
ht, wd = img1.shape[:2]
|
70 |
+
min_scale = np.maximum(
|
71 |
+
(self.crop_size[0] + 8) / float(ht),
|
72 |
+
(self.crop_size[1] + 8) / float(wd))
|
73 |
+
|
74 |
+
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
|
75 |
+
scale_x = scale
|
76 |
+
scale_y = scale
|
77 |
+
if np.random.rand() < self.stretch_prob:
|
78 |
+
scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
|
79 |
+
scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
|
80 |
+
|
81 |
+
scale_x = np.clip(scale_x, min_scale, None)
|
82 |
+
scale_y = np.clip(scale_y, min_scale, None)
|
83 |
+
|
84 |
+
if np.random.rand() < self.spatial_aug_prob:
|
85 |
+
# rescale the images
|
86 |
+
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
87 |
+
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
88 |
+
flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
89 |
+
flow = flow * [scale_x, scale_y]
|
90 |
+
|
91 |
+
if self.do_flip:
|
92 |
+
if np.random.rand() < self.h_flip_prob: # h-flip
|
93 |
+
img1 = img1[:, ::-1]
|
94 |
+
img2 = img2[:, ::-1]
|
95 |
+
flow = flow[:, ::-1] * [-1.0, 1.0]
|
96 |
+
|
97 |
+
if np.random.rand() < self.v_flip_prob: # v-flip
|
98 |
+
img1 = img1[::-1, :]
|
99 |
+
img2 = img2[::-1, :]
|
100 |
+
flow = flow[::-1, :] * [1.0, -1.0]
|
101 |
+
|
102 |
+
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
|
103 |
+
x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
|
104 |
+
|
105 |
+
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
106 |
+
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
107 |
+
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
108 |
+
|
109 |
+
return img1, img2, flow
|
110 |
+
|
111 |
+
def __call__(self, img1, img2, flow):
|
112 |
+
img1, img2 = self.color_transform(img1, img2)
|
113 |
+
img1, img2 = self.eraser_transform(img1, img2)
|
114 |
+
img1, img2, flow = self.spatial_transform(img1, img2, flow)
|
115 |
+
|
116 |
+
img1 = np.ascontiguousarray(img1)
|
117 |
+
img2 = np.ascontiguousarray(img2)
|
118 |
+
flow = np.ascontiguousarray(flow)
|
119 |
+
|
120 |
+
return img1, img2, flow
|
121 |
+
|
122 |
+
class SparseFlowAugmentor:
|
123 |
+
def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
|
124 |
+
# spatial augmentation params
|
125 |
+
self.crop_size = crop_size
|
126 |
+
self.min_scale = min_scale
|
127 |
+
self.max_scale = max_scale
|
128 |
+
self.spatial_aug_prob = 0.8
|
129 |
+
self.stretch_prob = 0.8
|
130 |
+
self.max_stretch = 0.2
|
131 |
+
|
132 |
+
# flip augmentation params
|
133 |
+
self.do_flip = do_flip
|
134 |
+
self.h_flip_prob = 0.5
|
135 |
+
self.v_flip_prob = 0.1
|
136 |
+
|
137 |
+
# photometric augmentation params
|
138 |
+
self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
|
139 |
+
self.asymmetric_color_aug_prob = 0.2
|
140 |
+
self.eraser_aug_prob = 0.5
|
141 |
+
|
142 |
+
def color_transform(self, img1, img2):
|
143 |
+
image_stack = np.concatenate([img1, img2], axis=0)
|
144 |
+
image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
|
145 |
+
img1, img2 = np.split(image_stack, 2, axis=0)
|
146 |
+
return img1, img2
|
147 |
+
|
148 |
+
def eraser_transform(self, img1, img2):
|
149 |
+
ht, wd = img1.shape[:2]
|
150 |
+
if np.random.rand() < self.eraser_aug_prob:
|
151 |
+
mean_color = np.mean(img2.reshape(-1, 3), axis=0)
|
152 |
+
for _ in range(np.random.randint(1, 3)):
|
153 |
+
x0 = np.random.randint(0, wd)
|
154 |
+
y0 = np.random.randint(0, ht)
|
155 |
+
dx = np.random.randint(50, 100)
|
156 |
+
dy = np.random.randint(50, 100)
|
157 |
+
img2[y0:y0+dy, x0:x0+dx, :] = mean_color
|
158 |
+
|
159 |
+
return img1, img2
|
160 |
+
|
161 |
+
def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
|
162 |
+
ht, wd = flow.shape[:2]
|
163 |
+
coords = np.meshgrid(np.arange(wd), np.arange(ht))
|
164 |
+
coords = np.stack(coords, axis=-1)
|
165 |
+
|
166 |
+
coords = coords.reshape(-1, 2).astype(np.float32)
|
167 |
+
flow = flow.reshape(-1, 2).astype(np.float32)
|
168 |
+
valid = valid.reshape(-1).astype(np.float32)
|
169 |
+
|
170 |
+
coords0 = coords[valid>=1]
|
171 |
+
flow0 = flow[valid>=1]
|
172 |
+
|
173 |
+
ht1 = int(round(ht * fy))
|
174 |
+
wd1 = int(round(wd * fx))
|
175 |
+
|
176 |
+
coords1 = coords0 * [fx, fy]
|
177 |
+
flow1 = flow0 * [fx, fy]
|
178 |
+
|
179 |
+
xx = np.round(coords1[:,0]).astype(np.int32)
|
180 |
+
yy = np.round(coords1[:,1]).astype(np.int32)
|
181 |
+
|
182 |
+
v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
|
183 |
+
xx = xx[v]
|
184 |
+
yy = yy[v]
|
185 |
+
flow1 = flow1[v]
|
186 |
+
|
187 |
+
flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
|
188 |
+
valid_img = np.zeros([ht1, wd1], dtype=np.int32)
|
189 |
+
|
190 |
+
flow_img[yy, xx] = flow1
|
191 |
+
valid_img[yy, xx] = 1
|
192 |
+
|
193 |
+
return flow_img, valid_img
|
194 |
+
|
195 |
+
def spatial_transform(self, img1, img2, flow, valid):
|
196 |
+
# randomly sample scale
|
197 |
+
|
198 |
+
ht, wd = img1.shape[:2]
|
199 |
+
min_scale = np.maximum(
|
200 |
+
(self.crop_size[0] + 1) / float(ht),
|
201 |
+
(self.crop_size[1] + 1) / float(wd))
|
202 |
+
|
203 |
+
scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
|
204 |
+
scale_x = np.clip(scale, min_scale, None)
|
205 |
+
scale_y = np.clip(scale, min_scale, None)
|
206 |
+
|
207 |
+
if np.random.rand() < self.spatial_aug_prob:
|
208 |
+
# rescale the images
|
209 |
+
img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
210 |
+
img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
|
211 |
+
flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
|
212 |
+
|
213 |
+
if self.do_flip:
|
214 |
+
if np.random.rand() < 0.5: # h-flip
|
215 |
+
img1 = img1[:, ::-1]
|
216 |
+
img2 = img2[:, ::-1]
|
217 |
+
flow = flow[:, ::-1] * [-1.0, 1.0]
|
218 |
+
valid = valid[:, ::-1]
|
219 |
+
|
220 |
+
margin_y = 20
|
221 |
+
margin_x = 50
|
222 |
+
|
223 |
+
y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
|
224 |
+
x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
|
225 |
+
|
226 |
+
y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
|
227 |
+
x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
|
228 |
+
|
229 |
+
img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
230 |
+
img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
231 |
+
flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
232 |
+
valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
|
233 |
+
return img1, img2, flow, valid
|
234 |
+
|
235 |
+
|
236 |
+
def __call__(self, img1, img2, flow, valid):
|
237 |
+
img1, img2 = self.color_transform(img1, img2)
|
238 |
+
img1, img2 = self.eraser_transform(img1, img2)
|
239 |
+
img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
|
240 |
+
|
241 |
+
img1 = np.ascontiguousarray(img1)
|
242 |
+
img2 = np.ascontiguousarray(img2)
|
243 |
+
flow = np.ascontiguousarray(flow)
|
244 |
+
valid = np.ascontiguousarray(valid)
|
245 |
+
|
246 |
+
return img1, img2, flow, valid
|
RAFT/utils/flow_viz.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
|
2 |
+
|
3 |
+
|
4 |
+
# MIT License
|
5 |
+
#
|
6 |
+
# Copyright (c) 2018 Tom Runia
|
7 |
+
#
|
8 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
9 |
+
# of this software and associated documentation files (the "Software"), to deal
|
10 |
+
# in the Software without restriction, including without limitation the rights
|
11 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
12 |
+
# copies of the Software, and to permit persons to whom the Software is
|
13 |
+
# furnished to do so, subject to conditions.
|
14 |
+
#
|
15 |
+
# Author: Tom Runia
|
16 |
+
# Date Created: 2018-08-03
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
|
20 |
+
def make_colorwheel():
|
21 |
+
"""
|
22 |
+
Generates a color wheel for optical flow visualization as presented in:
|
23 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
24 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
|
25 |
+
|
26 |
+
Code follows the original C++ source code of Daniel Scharstein.
|
27 |
+
Code follows the the Matlab source code of Deqing Sun.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
np.ndarray: Color wheel
|
31 |
+
"""
|
32 |
+
|
33 |
+
RY = 15
|
34 |
+
YG = 6
|
35 |
+
GC = 4
|
36 |
+
CB = 11
|
37 |
+
BM = 13
|
38 |
+
MR = 6
|
39 |
+
|
40 |
+
ncols = RY + YG + GC + CB + BM + MR
|
41 |
+
colorwheel = np.zeros((ncols, 3))
|
42 |
+
col = 0
|
43 |
+
|
44 |
+
# RY
|
45 |
+
colorwheel[0:RY, 0] = 255
|
46 |
+
colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
|
47 |
+
col = col+RY
|
48 |
+
# YG
|
49 |
+
colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
|
50 |
+
colorwheel[col:col+YG, 1] = 255
|
51 |
+
col = col+YG
|
52 |
+
# GC
|
53 |
+
colorwheel[col:col+GC, 1] = 255
|
54 |
+
colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
|
55 |
+
col = col+GC
|
56 |
+
# CB
|
57 |
+
colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
|
58 |
+
colorwheel[col:col+CB, 2] = 255
|
59 |
+
col = col+CB
|
60 |
+
# BM
|
61 |
+
colorwheel[col:col+BM, 2] = 255
|
62 |
+
colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
|
63 |
+
col = col+BM
|
64 |
+
# MR
|
65 |
+
colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
|
66 |
+
colorwheel[col:col+MR, 0] = 255
|
67 |
+
return colorwheel
|
68 |
+
|
69 |
+
|
70 |
+
def flow_uv_to_colors(u, v, convert_to_bgr=False):
|
71 |
+
"""
|
72 |
+
Applies the flow color wheel to (possibly clipped) flow components u and v.
|
73 |
+
|
74 |
+
According to the C++ source code of Daniel Scharstein
|
75 |
+
According to the Matlab source code of Deqing Sun
|
76 |
+
|
77 |
+
Args:
|
78 |
+
u (np.ndarray): Input horizontal flow of shape [H,W]
|
79 |
+
v (np.ndarray): Input vertical flow of shape [H,W]
|
80 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
81 |
+
|
82 |
+
Returns:
|
83 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
84 |
+
"""
|
85 |
+
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
|
86 |
+
colorwheel = make_colorwheel() # shape [55x3]
|
87 |
+
ncols = colorwheel.shape[0]
|
88 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
89 |
+
a = np.arctan2(-v, -u)/np.pi
|
90 |
+
fk = (a+1) / 2*(ncols-1)
|
91 |
+
k0 = np.floor(fk).astype(np.int32)
|
92 |
+
k1 = k0 + 1
|
93 |
+
k1[k1 == ncols] = 0
|
94 |
+
f = fk - k0
|
95 |
+
for i in range(colorwheel.shape[1]):
|
96 |
+
tmp = colorwheel[:,i]
|
97 |
+
col0 = tmp[k0] / 255.0
|
98 |
+
col1 = tmp[k1] / 255.0
|
99 |
+
col = (1-f)*col0 + f*col1
|
100 |
+
idx = (rad <= 1)
|
101 |
+
col[idx] = 1 - rad[idx] * (1-col[idx])
|
102 |
+
col[~idx] = col[~idx] * 0.75 # out of range
|
103 |
+
# Note the 2-i => BGR instead of RGB
|
104 |
+
ch_idx = 2-i if convert_to_bgr else i
|
105 |
+
flow_image[:,:,ch_idx] = np.floor(255 * col)
|
106 |
+
return flow_image
|
107 |
+
|
108 |
+
|
109 |
+
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
|
110 |
+
"""
|
111 |
+
Expects a two dimensional flow image of shape.
|
112 |
+
|
113 |
+
Args:
|
114 |
+
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
|
115 |
+
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
|
116 |
+
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
|
117 |
+
|
118 |
+
Returns:
|
119 |
+
np.ndarray: Flow visualization image of shape [H,W,3]
|
120 |
+
"""
|
121 |
+
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
|
122 |
+
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
|
123 |
+
if clip_flow is not None:
|
124 |
+
flow_uv = np.clip(flow_uv, 0, clip_flow)
|
125 |
+
u = flow_uv[:,:,0]
|
126 |
+
v = flow_uv[:,:,1]
|
127 |
+
rad = np.sqrt(np.square(u) + np.square(v))
|
128 |
+
rad_max = np.max(rad)
|
129 |
+
epsilon = 1e-5
|
130 |
+
u = u / (rad_max + epsilon)
|
131 |
+
v = v / (rad_max + epsilon)
|
132 |
+
return flow_uv_to_colors(u, v, convert_to_bgr)
|
RAFT/utils/flow_viz_pt.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
|
2 |
+
import torch
|
3 |
+
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
4 |
+
|
5 |
+
@torch.no_grad()
|
6 |
+
def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
|
7 |
+
|
8 |
+
"""
|
9 |
+
Converts a flow to an RGB image.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
img (Tensor): Image Tensor of dtype uint8 where each color corresponds
|
16 |
+
to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
|
17 |
+
"""
|
18 |
+
|
19 |
+
if flow.dtype != torch.float:
|
20 |
+
raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
|
21 |
+
|
22 |
+
orig_shape = flow.shape
|
23 |
+
if flow.ndim == 3:
|
24 |
+
flow = flow[None] # Add batch dim
|
25 |
+
|
26 |
+
if flow.ndim != 4 or flow.shape[1] != 2:
|
27 |
+
raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
|
28 |
+
|
29 |
+
max_norm = torch.sum(flow**2, dim=1).sqrt().max()
|
30 |
+
epsilon = torch.finfo((flow).dtype).eps
|
31 |
+
normalized_flow = flow / (max_norm + epsilon)
|
32 |
+
img = _normalized_flow_to_image(normalized_flow)
|
33 |
+
|
34 |
+
if len(orig_shape) == 3:
|
35 |
+
img = img[0] # Remove batch dim
|
36 |
+
return img
|
37 |
+
|
38 |
+
@torch.no_grad()
|
39 |
+
def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
|
40 |
+
|
41 |
+
"""
|
42 |
+
Converts a batch of normalized flow to an RGB image.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
|
46 |
+
Returns:
|
47 |
+
img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
|
48 |
+
"""
|
49 |
+
|
50 |
+
N, _, H, W = normalized_flow.shape
|
51 |
+
device = normalized_flow.device
|
52 |
+
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
|
53 |
+
colorwheel = _make_colorwheel().to(device) # shape [55x3]
|
54 |
+
num_cols = colorwheel.shape[0]
|
55 |
+
norm = torch.sum(normalized_flow**2, dim=1).sqrt()
|
56 |
+
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
|
57 |
+
fk = (a + 1) / 2 * (num_cols - 1)
|
58 |
+
k0 = torch.floor(fk).to(torch.long)
|
59 |
+
k1 = k0 + 1
|
60 |
+
k1[k1 == num_cols] = 0
|
61 |
+
f = fk - k0
|
62 |
+
|
63 |
+
for c in range(colorwheel.shape[1]):
|
64 |
+
tmp = colorwheel[:, c]
|
65 |
+
col0 = tmp[k0] / 255.0
|
66 |
+
col1 = tmp[k1] / 255.0
|
67 |
+
col = (1 - f) * col0 + f * col1
|
68 |
+
col = 1 - norm * (1 - col)
|
69 |
+
flow_image[:, c, :, :] = torch.floor(255. * col)
|
70 |
+
return flow_image
|
71 |
+
|
72 |
+
|
73 |
+
@torch.no_grad()
|
74 |
+
def _make_colorwheel() -> torch.Tensor:
|
75 |
+
"""
|
76 |
+
Generates a color wheel for optical flow visualization as presented in:
|
77 |
+
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
|
78 |
+
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
colorwheel (Tensor[55, 3]): Colorwheel Tensor.
|
82 |
+
"""
|
83 |
+
|
84 |
+
RY = 15
|
85 |
+
YG = 6
|
86 |
+
GC = 4
|
87 |
+
CB = 11
|
88 |
+
BM = 13
|
89 |
+
MR = 6
|
90 |
+
|
91 |
+
ncols = RY + YG + GC + CB + BM + MR
|
92 |
+
colorwheel = torch.zeros((ncols, 3))
|
93 |
+
col = 0
|
94 |
+
|
95 |
+
# RY
|
96 |
+
colorwheel[0:RY, 0] = 255
|
97 |
+
colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY)
|
98 |
+
col = col + RY
|
99 |
+
# YG
|
100 |
+
colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG)
|
101 |
+
colorwheel[col : col + YG, 1] = 255
|
102 |
+
col = col + YG
|
103 |
+
# GC
|
104 |
+
colorwheel[col : col + GC, 1] = 255
|
105 |
+
colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC)
|
106 |
+
col = col + GC
|
107 |
+
# CB
|
108 |
+
colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB)
|
109 |
+
colorwheel[col : col + CB, 2] = 255
|
110 |
+
col = col + CB
|
111 |
+
# BM
|
112 |
+
colorwheel[col : col + BM, 2] = 255
|
113 |
+
colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM)
|
114 |
+
col = col + BM
|
115 |
+
# MR
|
116 |
+
colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR)
|
117 |
+
colorwheel[col : col + MR, 0] = 255
|
118 |
+
return colorwheel
|
RAFT/utils/frame_utils.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from PIL import Image
|
3 |
+
from os.path import *
|
4 |
+
import re
|
5 |
+
|
6 |
+
import cv2
|
7 |
+
cv2.setNumThreads(0)
|
8 |
+
cv2.ocl.setUseOpenCL(False)
|
9 |
+
|
10 |
+
TAG_CHAR = np.array([202021.25], np.float32)
|
11 |
+
|
12 |
+
def readFlow(fn):
|
13 |
+
""" Read .flo file in Middlebury format"""
|
14 |
+
# Code adapted from:
|
15 |
+
# http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
|
16 |
+
|
17 |
+
# WARNING: this will work on little-endian architectures (eg Intel x86) only!
|
18 |
+
# print 'fn = %s'%(fn)
|
19 |
+
with open(fn, 'rb') as f:
|
20 |
+
magic = np.fromfile(f, np.float32, count=1)
|
21 |
+
if 202021.25 != magic:
|
22 |
+
print('Magic number incorrect. Invalid .flo file')
|
23 |
+
return None
|
24 |
+
else:
|
25 |
+
w = np.fromfile(f, np.int32, count=1)
|
26 |
+
h = np.fromfile(f, np.int32, count=1)
|
27 |
+
# print 'Reading %d x %d flo file\n' % (w, h)
|
28 |
+
data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
|
29 |
+
# Reshape data into 3D array (columns, rows, bands)
|
30 |
+
# The reshape here is for visualization, the original code is (w,h,2)
|
31 |
+
return np.resize(data, (int(h), int(w), 2))
|
32 |
+
|
33 |
+
def readPFM(file):
|
34 |
+
file = open(file, 'rb')
|
35 |
+
|
36 |
+
color = None
|
37 |
+
width = None
|
38 |
+
height = None
|
39 |
+
scale = None
|
40 |
+
endian = None
|
41 |
+
|
42 |
+
header = file.readline().rstrip()
|
43 |
+
if header == b'PF':
|
44 |
+
color = True
|
45 |
+
elif header == b'Pf':
|
46 |
+
color = False
|
47 |
+
else:
|
48 |
+
raise Exception('Not a PFM file.')
|
49 |
+
|
50 |
+
dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
|
51 |
+
if dim_match:
|
52 |
+
width, height = map(int, dim_match.groups())
|
53 |
+
else:
|
54 |
+
raise Exception('Malformed PFM header.')
|
55 |
+
|
56 |
+
scale = float(file.readline().rstrip())
|
57 |
+
if scale < 0: # little-endian
|
58 |
+
endian = '<'
|
59 |
+
scale = -scale
|
60 |
+
else:
|
61 |
+
endian = '>' # big-endian
|
62 |
+
|
63 |
+
data = np.fromfile(file, endian + 'f')
|
64 |
+
shape = (height, width, 3) if color else (height, width)
|
65 |
+
|
66 |
+
data = np.reshape(data, shape)
|
67 |
+
data = np.flipud(data)
|
68 |
+
return data
|
69 |
+
|
70 |
+
def writeFlow(filename,uv,v=None):
|
71 |
+
""" Write optical flow to file.
|
72 |
+
|
73 |
+
If v is None, uv is assumed to contain both u and v channels,
|
74 |
+
stacked in depth.
|
75 |
+
Original code by Deqing Sun, adapted from Daniel Scharstein.
|
76 |
+
"""
|
77 |
+
nBands = 2
|
78 |
+
|
79 |
+
if v is None:
|
80 |
+
assert(uv.ndim == 3)
|
81 |
+
assert(uv.shape[2] == 2)
|
82 |
+
u = uv[:,:,0]
|
83 |
+
v = uv[:,:,1]
|
84 |
+
else:
|
85 |
+
u = uv
|
86 |
+
|
87 |
+
assert(u.shape == v.shape)
|
88 |
+
height,width = u.shape
|
89 |
+
f = open(filename,'wb')
|
90 |
+
# write the header
|
91 |
+
f.write(TAG_CHAR)
|
92 |
+
np.array(width).astype(np.int32).tofile(f)
|
93 |
+
np.array(height).astype(np.int32).tofile(f)
|
94 |
+
# arrange into matrix form
|
95 |
+
tmp = np.zeros((height, width*nBands))
|
96 |
+
tmp[:,np.arange(width)*2] = u
|
97 |
+
tmp[:,np.arange(width)*2 + 1] = v
|
98 |
+
tmp.astype(np.float32).tofile(f)
|
99 |
+
f.close()
|
100 |
+
|
101 |
+
|
102 |
+
def readFlowKITTI(filename):
|
103 |
+
flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
|
104 |
+
flow = flow[:,:,::-1].astype(np.float32)
|
105 |
+
flow, valid = flow[:, :, :2], flow[:, :, 2]
|
106 |
+
flow = (flow - 2**15) / 64.0
|
107 |
+
return flow, valid
|
108 |
+
|
109 |
+
def readDispKITTI(filename):
|
110 |
+
disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
|
111 |
+
valid = disp > 0.0
|
112 |
+
flow = np.stack([-disp, np.zeros_like(disp)], -1)
|
113 |
+
return flow, valid
|
114 |
+
|
115 |
+
|
116 |
+
def writeFlowKITTI(filename, uv):
|
117 |
+
uv = 64.0 * uv + 2**15
|
118 |
+
valid = np.ones([uv.shape[0], uv.shape[1], 1])
|
119 |
+
uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
|
120 |
+
cv2.imwrite(filename, uv[..., ::-1])
|
121 |
+
|
122 |
+
|
123 |
+
def read_gen(file_name, pil=False):
|
124 |
+
ext = splitext(file_name)[-1]
|
125 |
+
if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
|
126 |
+
return Image.open(file_name)
|
127 |
+
elif ext == '.bin' or ext == '.raw':
|
128 |
+
return np.load(file_name)
|
129 |
+
elif ext == '.flo':
|
130 |
+
return readFlow(file_name).astype(np.float32)
|
131 |
+
elif ext == '.pfm':
|
132 |
+
flow = readPFM(file_name).astype(np.float32)
|
133 |
+
if len(flow.shape) == 2:
|
134 |
+
return flow
|
135 |
+
else:
|
136 |
+
return flow[:, :, :-1]
|
137 |
+
return []
|
RAFT/utils/utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import numpy as np
|
4 |
+
from scipy import interpolate
|
5 |
+
|
6 |
+
|
7 |
+
class InputPadder:
|
8 |
+
""" Pads images such that dimensions are divisible by 8 """
|
9 |
+
def __init__(self, dims, mode='sintel'):
|
10 |
+
self.ht, self.wd = dims[-2:]
|
11 |
+
pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
|
12 |
+
pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
|
13 |
+
if mode == 'sintel':
|
14 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
|
15 |
+
else:
|
16 |
+
self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
|
17 |
+
|
18 |
+
def pad(self, *inputs):
|
19 |
+
return [F.pad(x, self._pad, mode='replicate') for x in inputs]
|
20 |
+
|
21 |
+
def unpad(self,x):
|
22 |
+
ht, wd = x.shape[-2:]
|
23 |
+
c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
|
24 |
+
return x[..., c[0]:c[1], c[2]:c[3]]
|
25 |
+
|
26 |
+
def forward_interpolate(flow):
|
27 |
+
flow = flow.detach().cpu().numpy()
|
28 |
+
dx, dy = flow[0], flow[1]
|
29 |
+
|
30 |
+
ht, wd = dx.shape
|
31 |
+
x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
|
32 |
+
|
33 |
+
x1 = x0 + dx
|
34 |
+
y1 = y0 + dy
|
35 |
+
|
36 |
+
x1 = x1.reshape(-1)
|
37 |
+
y1 = y1.reshape(-1)
|
38 |
+
dx = dx.reshape(-1)
|
39 |
+
dy = dy.reshape(-1)
|
40 |
+
|
41 |
+
valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
|
42 |
+
x1 = x1[valid]
|
43 |
+
y1 = y1[valid]
|
44 |
+
dx = dx[valid]
|
45 |
+
dy = dy[valid]
|
46 |
+
|
47 |
+
flow_x = interpolate.griddata(
|
48 |
+
(x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
|
49 |
+
|
50 |
+
flow_y = interpolate.griddata(
|
51 |
+
(x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
|
52 |
+
|
53 |
+
flow = np.stack([flow_x, flow_y], axis=0)
|
54 |
+
return torch.from_numpy(flow).float()
|
55 |
+
|
56 |
+
|
57 |
+
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
58 |
+
""" Wrapper for grid_sample, uses pixel coordinates """
|
59 |
+
H, W = img.shape[-2:]
|
60 |
+
xgrid, ygrid = coords.split([1,1], dim=-1)
|
61 |
+
xgrid = 2*xgrid/(W-1) - 1
|
62 |
+
ygrid = 2*ygrid/(H-1) - 1
|
63 |
+
|
64 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
65 |
+
img = F.grid_sample(img, grid, align_corners=True)
|
66 |
+
|
67 |
+
if mask:
|
68 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
69 |
+
return img, mask.float()
|
70 |
+
|
71 |
+
return img
|
72 |
+
|
73 |
+
|
74 |
+
def coords_grid(batch, ht, wd):
|
75 |
+
coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
|
76 |
+
coords = torch.stack(coords[::-1], dim=0).float()
|
77 |
+
return coords[None].repeat(batch, 1, 1, 1)
|
78 |
+
|
79 |
+
|
80 |
+
def upflow8(flow, mode='bilinear'):
|
81 |
+
new_size = (8 * flow.shape[2], 8 * flow.shape[3])
|
82 |
+
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
configs/train_flowcomp.json
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"seed": 2023,
|
3 |
+
"save_dir": "experiments_model/",
|
4 |
+
"train_data_loader": {
|
5 |
+
"name": "youtube-vos",
|
6 |
+
"video_root": "your_video_root",
|
7 |
+
"flow_root": "your_flow_root",
|
8 |
+
"w": 432,
|
9 |
+
"h": 240,
|
10 |
+
"num_local_frames": 10,
|
11 |
+
"num_ref_frames": 1,
|
12 |
+
"load_flow": 0
|
13 |
+
},
|
14 |
+
"losses": {
|
15 |
+
"flow_weight": 0.25
|
16 |
+
},
|
17 |
+
"model": {
|
18 |
+
"net": "recurrent_flow_completion"
|
19 |
+
},
|
20 |
+
"trainer": {
|
21 |
+
"version": "trainer_flow_w_edge",
|
22 |
+
"type": "Adam",
|
23 |
+
"beta1": 0,
|
24 |
+
"beta2": 0.99,
|
25 |
+
"lr": 5e-5,
|
26 |
+
"batch_size": 8,
|
27 |
+
"num_workers": 4,
|
28 |
+
"num_prefetch_queue": 4,
|
29 |
+
"log_freq": 100,
|
30 |
+
"save_freq": 5e3,
|
31 |
+
"iterations": 700e3,
|
32 |
+
"scheduler": {
|
33 |
+
"type": "MultiStepLR",
|
34 |
+
"milestones": [
|
35 |
+
300e3, 400e3, 500e3, 600e3
|
36 |
+
],
|
37 |
+
"gamma": 0.2
|
38 |
+
}
|
39 |
+
}
|
40 |
+
}
|
configs/train_propainter.json
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"seed": 2023,
|
3 |
+
"save_dir": "experiments_model/",
|
4 |
+
"train_data_loader": {
|
5 |
+
"name": "youtube-vos",
|
6 |
+
"video_root": "your_video_root",
|
7 |
+
"flow_root": "your_flow_root",
|
8 |
+
"w": 432,
|
9 |
+
"h": 240,
|
10 |
+
"num_local_frames": 10,
|
11 |
+
"num_ref_frames": 6,
|
12 |
+
"load_flow": 0
|
13 |
+
},
|
14 |
+
"losses": {
|
15 |
+
"hole_weight": 1,
|
16 |
+
"valid_weight": 1,
|
17 |
+
"flow_weight": 1,
|
18 |
+
"adversarial_weight": 0.01,
|
19 |
+
"GAN_LOSS": "hinge",
|
20 |
+
"perceptual_weight": 0
|
21 |
+
},
|
22 |
+
"model": {
|
23 |
+
"net": "propainter",
|
24 |
+
"no_dis": 0,
|
25 |
+
"load_d": 1,
|
26 |
+
"interp_mode": "nearest"
|
27 |
+
},
|
28 |
+
"trainer": {
|
29 |
+
"version": "trainer",
|
30 |
+
"type": "Adam",
|
31 |
+
"beta1": 0,
|
32 |
+
"beta2": 0.99,
|
33 |
+
"lr": 1e-4,
|
34 |
+
"batch_size": 8,
|
35 |
+
"num_workers": 8,
|
36 |
+
"num_prefetch_queue": 8,
|
37 |
+
"log_freq": 100,
|
38 |
+
"save_freq": 1e4,
|
39 |
+
"iterations": 700e3,
|
40 |
+
"scheduler": {
|
41 |
+
"type": "MultiStepLR",
|
42 |
+
"milestones": [
|
43 |
+
400e3
|
44 |
+
],
|
45 |
+
"gamma": 0.1
|
46 |
+
}
|
47 |
+
}
|
48 |
+
}
|
core/dataset.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
from PIL import Image
|
7 |
+
import numpy as np
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torchvision.transforms as transforms
|
11 |
+
|
12 |
+
from utils.file_client import FileClient
|
13 |
+
from utils.img_util import imfrombytes
|
14 |
+
from utils.flow_util import resize_flow, flowread
|
15 |
+
from core.utils import (create_random_shape_with_random_motion, Stack,
|
16 |
+
ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip)
|
17 |
+
|
18 |
+
|
19 |
+
class TrainDataset(torch.utils.data.Dataset):
|
20 |
+
def __init__(self, args: dict):
|
21 |
+
self.args = args
|
22 |
+
self.video_root = args['video_root']
|
23 |
+
self.flow_root = args['flow_root']
|
24 |
+
self.num_local_frames = args['num_local_frames']
|
25 |
+
self.num_ref_frames = args['num_ref_frames']
|
26 |
+
self.size = self.w, self.h = (args['w'], args['h'])
|
27 |
+
|
28 |
+
self.load_flow = args['load_flow']
|
29 |
+
if self.load_flow:
|
30 |
+
assert os.path.exists(self.flow_root)
|
31 |
+
|
32 |
+
json_path = os.path.join('./datasets', args['name'], 'train.json')
|
33 |
+
|
34 |
+
with open(json_path, 'r') as f:
|
35 |
+
self.video_train_dict = json.load(f)
|
36 |
+
self.video_names = sorted(list(self.video_train_dict.keys()))
|
37 |
+
|
38 |
+
# self.video_names = sorted(os.listdir(self.video_root))
|
39 |
+
self.video_dict = {}
|
40 |
+
self.frame_dict = {}
|
41 |
+
|
42 |
+
for v in self.video_names:
|
43 |
+
frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
|
44 |
+
v_len = len(frame_list)
|
45 |
+
if v_len > self.num_local_frames + self.num_ref_frames:
|
46 |
+
self.video_dict[v] = v_len
|
47 |
+
self.frame_dict[v] = frame_list
|
48 |
+
|
49 |
+
|
50 |
+
self.video_names = list(self.video_dict.keys()) # update names
|
51 |
+
|
52 |
+
self._to_tensors = transforms.Compose([
|
53 |
+
Stack(),
|
54 |
+
ToTorchFormatTensor(),
|
55 |
+
])
|
56 |
+
self.file_client = FileClient('disk')
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.video_names)
|
60 |
+
|
61 |
+
def _sample_index(self, length, sample_length, num_ref_frame=3):
|
62 |
+
complete_idx_set = list(range(length))
|
63 |
+
pivot = random.randint(0, length - sample_length)
|
64 |
+
local_idx = complete_idx_set[pivot:pivot + sample_length]
|
65 |
+
remain_idx = list(set(complete_idx_set) - set(local_idx))
|
66 |
+
ref_index = sorted(random.sample(remain_idx, num_ref_frame))
|
67 |
+
|
68 |
+
return local_idx + ref_index
|
69 |
+
|
70 |
+
def __getitem__(self, index):
|
71 |
+
video_name = self.video_names[index]
|
72 |
+
# create masks
|
73 |
+
all_masks = create_random_shape_with_random_motion(
|
74 |
+
self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)
|
75 |
+
|
76 |
+
# create sample index
|
77 |
+
selected_index = self._sample_index(self.video_dict[video_name],
|
78 |
+
self.num_local_frames,
|
79 |
+
self.num_ref_frames)
|
80 |
+
|
81 |
+
# read video frames
|
82 |
+
frames = []
|
83 |
+
masks = []
|
84 |
+
flows_f, flows_b = [], []
|
85 |
+
for idx in selected_index:
|
86 |
+
frame_list = self.frame_dict[video_name]
|
87 |
+
img_path = os.path.join(self.video_root, video_name, frame_list[idx])
|
88 |
+
img_bytes = self.file_client.get(img_path, 'img')
|
89 |
+
img = imfrombytes(img_bytes, float32=False)
|
90 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
91 |
+
img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
|
92 |
+
img = Image.fromarray(img)
|
93 |
+
|
94 |
+
frames.append(img)
|
95 |
+
masks.append(all_masks[idx])
|
96 |
+
|
97 |
+
if len(frames) <= self.num_local_frames-1 and self.load_flow:
|
98 |
+
current_n = frame_list[idx][:-4]
|
99 |
+
next_n = frame_list[idx+1][:-4]
|
100 |
+
flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
|
101 |
+
flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
|
102 |
+
flow_f = flowread(flow_f_path, quantize=False)
|
103 |
+
flow_b = flowread(flow_b_path, quantize=False)
|
104 |
+
flow_f = resize_flow(flow_f, self.h, self.w)
|
105 |
+
flow_b = resize_flow(flow_b, self.h, self.w)
|
106 |
+
flows_f.append(flow_f)
|
107 |
+
flows_b.append(flow_b)
|
108 |
+
|
109 |
+
if len(frames) == self.num_local_frames: # random reverse
|
110 |
+
if random.random() < 0.5:
|
111 |
+
frames.reverse()
|
112 |
+
masks.reverse()
|
113 |
+
if self.load_flow:
|
114 |
+
flows_f.reverse()
|
115 |
+
flows_b.reverse()
|
116 |
+
flows_ = flows_f
|
117 |
+
flows_f = flows_b
|
118 |
+
flows_b = flows_
|
119 |
+
|
120 |
+
if self.load_flow:
|
121 |
+
frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b)
|
122 |
+
else:
|
123 |
+
frames = GroupRandomHorizontalFlip()(frames)
|
124 |
+
|
125 |
+
# normalizate, to tensors
|
126 |
+
frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
|
127 |
+
mask_tensors = self._to_tensors(masks)
|
128 |
+
if self.load_flow:
|
129 |
+
flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1
|
130 |
+
flows_b = np.stack(flows_b, axis=-1)
|
131 |
+
flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
|
132 |
+
flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
|
133 |
+
|
134 |
+
# img [-1,1] mask [0,1]
|
135 |
+
if self.load_flow:
|
136 |
+
return frame_tensors, mask_tensors, flows_f, flows_b, video_name
|
137 |
+
else:
|
138 |
+
return frame_tensors, mask_tensors, 'None', 'None', video_name
|
139 |
+
|
140 |
+
|
141 |
+
class TestDataset(torch.utils.data.Dataset):
|
142 |
+
def __init__(self, args):
|
143 |
+
self.args = args
|
144 |
+
self.size = self.w, self.h = args['size']
|
145 |
+
|
146 |
+
self.video_root = args['video_root']
|
147 |
+
self.mask_root = args['mask_root']
|
148 |
+
self.flow_root = args['flow_root']
|
149 |
+
|
150 |
+
self.load_flow = args['load_flow']
|
151 |
+
if self.load_flow:
|
152 |
+
assert os.path.exists(self.flow_root)
|
153 |
+
self.video_names = sorted(os.listdir(self.mask_root))
|
154 |
+
|
155 |
+
self.video_dict = {}
|
156 |
+
self.frame_dict = {}
|
157 |
+
|
158 |
+
for v in self.video_names:
|
159 |
+
frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
|
160 |
+
v_len = len(frame_list)
|
161 |
+
self.video_dict[v] = v_len
|
162 |
+
self.frame_dict[v] = frame_list
|
163 |
+
|
164 |
+
self._to_tensors = transforms.Compose([
|
165 |
+
Stack(),
|
166 |
+
ToTorchFormatTensor(),
|
167 |
+
])
|
168 |
+
self.file_client = FileClient('disk')
|
169 |
+
|
170 |
+
def __len__(self):
|
171 |
+
return len(self.video_names)
|
172 |
+
|
173 |
+
def __getitem__(self, index):
|
174 |
+
video_name = self.video_names[index]
|
175 |
+
selected_index = list(range(self.video_dict[video_name]))
|
176 |
+
|
177 |
+
# read video frames
|
178 |
+
frames = []
|
179 |
+
masks = []
|
180 |
+
flows_f, flows_b = [], []
|
181 |
+
for idx in selected_index:
|
182 |
+
frame_list = self.frame_dict[video_name]
|
183 |
+
frame_path = os.path.join(self.video_root, video_name, frame_list[idx])
|
184 |
+
|
185 |
+
img_bytes = self.file_client.get(frame_path, 'input')
|
186 |
+
img = imfrombytes(img_bytes, float32=False)
|
187 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
188 |
+
img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
|
189 |
+
img = Image.fromarray(img)
|
190 |
+
|
191 |
+
frames.append(img)
|
192 |
+
|
193 |
+
mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png')
|
194 |
+
mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L')
|
195 |
+
|
196 |
+
# origin: 0 indicates missing. now: 1 indicates missing
|
197 |
+
mask = np.asarray(mask)
|
198 |
+
m = np.array(mask > 0).astype(np.uint8)
|
199 |
+
|
200 |
+
m = cv2.dilate(m,
|
201 |
+
cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
|
202 |
+
iterations=4)
|
203 |
+
mask = Image.fromarray(m * 255)
|
204 |
+
masks.append(mask)
|
205 |
+
|
206 |
+
if len(frames) <= len(selected_index)-1 and self.load_flow:
|
207 |
+
current_n = frame_list[idx][:-4]
|
208 |
+
next_n = frame_list[idx+1][:-4]
|
209 |
+
flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
|
210 |
+
flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
|
211 |
+
flow_f = flowread(flow_f_path, quantize=False)
|
212 |
+
flow_b = flowread(flow_b_path, quantize=False)
|
213 |
+
flow_f = resize_flow(flow_f, self.h, self.w)
|
214 |
+
flow_b = resize_flow(flow_b, self.h, self.w)
|
215 |
+
flows_f.append(flow_f)
|
216 |
+
flows_b.append(flow_b)
|
217 |
+
|
218 |
+
# normalizate, to tensors
|
219 |
+
frames_PIL = [np.array(f).astype(np.uint8) for f in frames]
|
220 |
+
frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
|
221 |
+
mask_tensors = self._to_tensors(masks)
|
222 |
+
|
223 |
+
if self.load_flow:
|
224 |
+
flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1
|
225 |
+
flows_b = np.stack(flows_b, axis=-1)
|
226 |
+
flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
|
227 |
+
flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
|
228 |
+
|
229 |
+
if self.load_flow:
|
230 |
+
return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL
|
231 |
+
else:
|
232 |
+
return frame_tensors, mask_tensors, 'None', 'None', video_name
|
core/dist.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
def get_world_size():
|
6 |
+
"""Find OMPI world size without calling mpi functions
|
7 |
+
:rtype: int
|
8 |
+
"""
|
9 |
+
if os.environ.get('PMI_SIZE') is not None:
|
10 |
+
return int(os.environ.get('PMI_SIZE') or 1)
|
11 |
+
elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
|
12 |
+
return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
|
13 |
+
else:
|
14 |
+
return torch.cuda.device_count()
|
15 |
+
|
16 |
+
|
17 |
+
def get_global_rank():
|
18 |
+
"""Find OMPI world rank without calling mpi functions
|
19 |
+
:rtype: int
|
20 |
+
"""
|
21 |
+
if os.environ.get('PMI_RANK') is not None:
|
22 |
+
return int(os.environ.get('PMI_RANK') or 0)
|
23 |
+
elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
|
24 |
+
return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
|
25 |
+
else:
|
26 |
+
return 0
|
27 |
+
|
28 |
+
|
29 |
+
def get_local_rank():
|
30 |
+
"""Find OMPI local rank without calling mpi functions
|
31 |
+
:rtype: int
|
32 |
+
"""
|
33 |
+
if os.environ.get('MPI_LOCALRANKID') is not None:
|
34 |
+
return int(os.environ.get('MPI_LOCALRANKID') or 0)
|
35 |
+
elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
|
36 |
+
return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
|
37 |
+
else:
|
38 |
+
return 0
|
39 |
+
|
40 |
+
|
41 |
+
def get_master_ip():
|
42 |
+
if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
|
43 |
+
return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
|
44 |
+
elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
|
45 |
+
return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
|
46 |
+
else:
|
47 |
+
return "127.0.0.1"
|
core/loss.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import lpips
|
4 |
+
from model.vgg_arch import VGGFeatureExtractor
|
5 |
+
|
6 |
+
class PerceptualLoss(nn.Module):
|
7 |
+
"""Perceptual loss with commonly used style loss.
|
8 |
+
|
9 |
+
Args:
|
10 |
+
layer_weights (dict): The weight for each layer of vgg feature.
|
11 |
+
Here is an example: {'conv5_4': 1.}, which means the conv5_4
|
12 |
+
feature layer (before relu5_4) will be extracted with weight
|
13 |
+
1.0 in calculting losses.
|
14 |
+
vgg_type (str): The type of vgg network used as feature extractor.
|
15 |
+
Default: 'vgg19'.
|
16 |
+
use_input_norm (bool): If True, normalize the input image in vgg.
|
17 |
+
Default: True.
|
18 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
19 |
+
Default: False.
|
20 |
+
perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
|
21 |
+
loss will be calculated and the loss will multiplied by the
|
22 |
+
weight. Default: 1.0.
|
23 |
+
style_weight (float): If `style_weight > 0`, the style loss will be
|
24 |
+
calculated and the loss will multiplied by the weight.
|
25 |
+
Default: 0.
|
26 |
+
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(self,
|
30 |
+
layer_weights,
|
31 |
+
vgg_type='vgg19',
|
32 |
+
use_input_norm=True,
|
33 |
+
range_norm=False,
|
34 |
+
perceptual_weight=1.0,
|
35 |
+
style_weight=0.,
|
36 |
+
criterion='l1'):
|
37 |
+
super(PerceptualLoss, self).__init__()
|
38 |
+
self.perceptual_weight = perceptual_weight
|
39 |
+
self.style_weight = style_weight
|
40 |
+
self.layer_weights = layer_weights
|
41 |
+
self.vgg = VGGFeatureExtractor(
|
42 |
+
layer_name_list=list(layer_weights.keys()),
|
43 |
+
vgg_type=vgg_type,
|
44 |
+
use_input_norm=use_input_norm,
|
45 |
+
range_norm=range_norm)
|
46 |
+
|
47 |
+
self.criterion_type = criterion
|
48 |
+
if self.criterion_type == 'l1':
|
49 |
+
self.criterion = torch.nn.L1Loss()
|
50 |
+
elif self.criterion_type == 'l2':
|
51 |
+
self.criterion = torch.nn.L2loss()
|
52 |
+
elif self.criterion_type == 'mse':
|
53 |
+
self.criterion = torch.nn.MSELoss(reduction='mean')
|
54 |
+
elif self.criterion_type == 'fro':
|
55 |
+
self.criterion = None
|
56 |
+
else:
|
57 |
+
raise NotImplementedError(f'{criterion} criterion has not been supported.')
|
58 |
+
|
59 |
+
def forward(self, x, gt):
|
60 |
+
"""Forward function.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
64 |
+
gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
|
65 |
+
|
66 |
+
Returns:
|
67 |
+
Tensor: Forward results.
|
68 |
+
"""
|
69 |
+
# extract vgg features
|
70 |
+
x_features = self.vgg(x)
|
71 |
+
gt_features = self.vgg(gt.detach())
|
72 |
+
|
73 |
+
# calculate perceptual loss
|
74 |
+
if self.perceptual_weight > 0:
|
75 |
+
percep_loss = 0
|
76 |
+
for k in x_features.keys():
|
77 |
+
if self.criterion_type == 'fro':
|
78 |
+
percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
|
79 |
+
else:
|
80 |
+
percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
|
81 |
+
percep_loss *= self.perceptual_weight
|
82 |
+
else:
|
83 |
+
percep_loss = None
|
84 |
+
|
85 |
+
# calculate style loss
|
86 |
+
if self.style_weight > 0:
|
87 |
+
style_loss = 0
|
88 |
+
for k in x_features.keys():
|
89 |
+
if self.criterion_type == 'fro':
|
90 |
+
style_loss += torch.norm(
|
91 |
+
self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
|
92 |
+
else:
|
93 |
+
style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
|
94 |
+
gt_features[k])) * self.layer_weights[k]
|
95 |
+
style_loss *= self.style_weight
|
96 |
+
else:
|
97 |
+
style_loss = None
|
98 |
+
|
99 |
+
return percep_loss, style_loss
|
100 |
+
|
101 |
+
def _gram_mat(self, x):
|
102 |
+
"""Calculate Gram matrix.
|
103 |
+
|
104 |
+
Args:
|
105 |
+
x (torch.Tensor): Tensor with shape of (n, c, h, w).
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
torch.Tensor: Gram matrix.
|
109 |
+
"""
|
110 |
+
n, c, h, w = x.size()
|
111 |
+
features = x.view(n, c, w * h)
|
112 |
+
features_t = features.transpose(1, 2)
|
113 |
+
gram = features.bmm(features_t) / (c * h * w)
|
114 |
+
return gram
|
115 |
+
|
116 |
+
class LPIPSLoss(nn.Module):
|
117 |
+
def __init__(self,
|
118 |
+
loss_weight=1.0,
|
119 |
+
use_input_norm=True,
|
120 |
+
range_norm=False,):
|
121 |
+
super(LPIPSLoss, self).__init__()
|
122 |
+
self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
|
123 |
+
self.loss_weight = loss_weight
|
124 |
+
self.use_input_norm = use_input_norm
|
125 |
+
self.range_norm = range_norm
|
126 |
+
|
127 |
+
if self.use_input_norm:
|
128 |
+
# the mean is for image with range [0, 1]
|
129 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
130 |
+
# the std is for image with range [0, 1]
|
131 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
132 |
+
|
133 |
+
def forward(self, pred, target):
|
134 |
+
if self.range_norm:
|
135 |
+
pred = (pred + 1) / 2
|
136 |
+
target = (target + 1) / 2
|
137 |
+
if self.use_input_norm:
|
138 |
+
pred = (pred - self.mean) / self.std
|
139 |
+
target = (target - self.mean) / self.std
|
140 |
+
lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
|
141 |
+
return self.loss_weight * lpips_loss.mean(), None
|
142 |
+
|
143 |
+
|
144 |
+
class AdversarialLoss(nn.Module):
|
145 |
+
r"""
|
146 |
+
Adversarial loss
|
147 |
+
https://arxiv.org/abs/1711.10337
|
148 |
+
"""
|
149 |
+
def __init__(self,
|
150 |
+
type='nsgan',
|
151 |
+
target_real_label=1.0,
|
152 |
+
target_fake_label=0.0):
|
153 |
+
r"""
|
154 |
+
type = nsgan | lsgan | hinge
|
155 |
+
"""
|
156 |
+
super(AdversarialLoss, self).__init__()
|
157 |
+
self.type = type
|
158 |
+
self.register_buffer('real_label', torch.tensor(target_real_label))
|
159 |
+
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
160 |
+
|
161 |
+
if type == 'nsgan':
|
162 |
+
self.criterion = nn.BCELoss()
|
163 |
+
elif type == 'lsgan':
|
164 |
+
self.criterion = nn.MSELoss()
|
165 |
+
elif type == 'hinge':
|
166 |
+
self.criterion = nn.ReLU()
|
167 |
+
|
168 |
+
def __call__(self, outputs, is_real, is_disc=None):
|
169 |
+
if self.type == 'hinge':
|
170 |
+
if is_disc:
|
171 |
+
if is_real:
|
172 |
+
outputs = -outputs
|
173 |
+
return self.criterion(1 + outputs).mean()
|
174 |
+
else:
|
175 |
+
return (-outputs).mean()
|
176 |
+
else:
|
177 |
+
labels = (self.real_label
|
178 |
+
if is_real else self.fake_label).expand_as(outputs)
|
179 |
+
loss = self.criterion(outputs, labels)
|
180 |
+
return loss
|
core/lr_scheduler.py
ADDED
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
LR scheduler from BasicSR https://github.com/xinntao/BasicSR
|
3 |
+
"""
|
4 |
+
import math
|
5 |
+
from collections import Counter
|
6 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
7 |
+
|
8 |
+
|
9 |
+
class MultiStepRestartLR(_LRScheduler):
|
10 |
+
""" MultiStep with restarts learning rate scheme.
|
11 |
+
Args:
|
12 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
13 |
+
milestones (list): Iterations that will decrease learning rate.
|
14 |
+
gamma (float): Decrease ratio. Default: 0.1.
|
15 |
+
restarts (list): Restart iterations. Default: [0].
|
16 |
+
restart_weights (list): Restart weights at each restart iteration.
|
17 |
+
Default: [1].
|
18 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
19 |
+
"""
|
20 |
+
def __init__(self,
|
21 |
+
optimizer,
|
22 |
+
milestones,
|
23 |
+
gamma=0.1,
|
24 |
+
restarts=(0, ),
|
25 |
+
restart_weights=(1, ),
|
26 |
+
last_epoch=-1):
|
27 |
+
self.milestones = Counter(milestones)
|
28 |
+
self.gamma = gamma
|
29 |
+
self.restarts = restarts
|
30 |
+
self.restart_weights = restart_weights
|
31 |
+
assert len(self.restarts) == len(
|
32 |
+
self.restart_weights), 'restarts and their weights do not match.'
|
33 |
+
super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
|
34 |
+
|
35 |
+
def get_lr(self):
|
36 |
+
if self.last_epoch in self.restarts:
|
37 |
+
weight = self.restart_weights[self.restarts.index(self.last_epoch)]
|
38 |
+
return [
|
39 |
+
group['initial_lr'] * weight
|
40 |
+
for group in self.optimizer.param_groups
|
41 |
+
]
|
42 |
+
if self.last_epoch not in self.milestones:
|
43 |
+
return [group['lr'] for group in self.optimizer.param_groups]
|
44 |
+
return [
|
45 |
+
group['lr'] * self.gamma**self.milestones[self.last_epoch]
|
46 |
+
for group in self.optimizer.param_groups
|
47 |
+
]
|
48 |
+
|
49 |
+
|
50 |
+
def get_position_from_periods(iteration, cumulative_period):
|
51 |
+
"""Get the position from a period list.
|
52 |
+
It will return the index of the right-closest number in the period list.
|
53 |
+
For example, the cumulative_period = [100, 200, 300, 400],
|
54 |
+
if iteration == 50, return 0;
|
55 |
+
if iteration == 210, return 2;
|
56 |
+
if iteration == 300, return 2.
|
57 |
+
Args:
|
58 |
+
iteration (int): Current iteration.
|
59 |
+
cumulative_period (list[int]): Cumulative period list.
|
60 |
+
Returns:
|
61 |
+
int: The position of the right-closest number in the period list.
|
62 |
+
"""
|
63 |
+
for i, period in enumerate(cumulative_period):
|
64 |
+
if iteration <= period:
|
65 |
+
return i
|
66 |
+
|
67 |
+
|
68 |
+
class CosineAnnealingRestartLR(_LRScheduler):
|
69 |
+
""" Cosine annealing with restarts learning rate scheme.
|
70 |
+
An example of config:
|
71 |
+
periods = [10, 10, 10, 10]
|
72 |
+
restart_weights = [1, 0.5, 0.5, 0.5]
|
73 |
+
eta_min=1e-7
|
74 |
+
It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
|
75 |
+
scheduler will restart with the weights in restart_weights.
|
76 |
+
Args:
|
77 |
+
optimizer (torch.nn.optimizer): Torch optimizer.
|
78 |
+
periods (list): Period for each cosine anneling cycle.
|
79 |
+
restart_weights (list): Restart weights at each restart iteration.
|
80 |
+
Default: [1].
|
81 |
+
eta_min (float): The mimimum lr. Default: 0.
|
82 |
+
last_epoch (int): Used in _LRScheduler. Default: -1.
|
83 |
+
"""
|
84 |
+
def __init__(self,
|
85 |
+
optimizer,
|
86 |
+
periods,
|
87 |
+
restart_weights=(1, ),
|
88 |
+
eta_min=1e-7,
|
89 |
+
last_epoch=-1):
|
90 |
+
self.periods = periods
|
91 |
+
self.restart_weights = restart_weights
|
92 |
+
self.eta_min = eta_min
|
93 |
+
assert (len(self.periods) == len(self.restart_weights)
|
94 |
+
), 'periods and restart_weights should have the same length.'
|
95 |
+
self.cumulative_period = [
|
96 |
+
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
|
97 |
+
]
|
98 |
+
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
|
99 |
+
|
100 |
+
def get_lr(self):
|
101 |
+
idx = get_position_from_periods(self.last_epoch,
|
102 |
+
self.cumulative_period)
|
103 |
+
current_weight = self.restart_weights[idx]
|
104 |
+
nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
|
105 |
+
current_period = self.periods[idx]
|
106 |
+
|
107 |
+
return [
|
108 |
+
self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
|
109 |
+
(1 + math.cos(math.pi * (
|
110 |
+
(self.last_epoch - nearest_restart) / current_period)))
|
111 |
+
for base_lr in self.base_lrs
|
112 |
+
]
|
core/metrics.py
ADDED
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from skimage import measure
|
3 |
+
from scipy import linalg
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from core.utils import to_tensors
|
10 |
+
|
11 |
+
|
12 |
+
def calculate_epe(flow1, flow2):
|
13 |
+
"""Calculate End point errors."""
|
14 |
+
|
15 |
+
epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt()
|
16 |
+
epe = epe.view(-1)
|
17 |
+
return epe.mean().item()
|
18 |
+
|
19 |
+
|
20 |
+
def calculate_psnr(img1, img2):
|
21 |
+
"""Calculate PSNR (Peak Signal-to-Noise Ratio).
|
22 |
+
Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
|
23 |
+
Args:
|
24 |
+
img1 (ndarray): Images with range [0, 255].
|
25 |
+
img2 (ndarray): Images with range [0, 255].
|
26 |
+
Returns:
|
27 |
+
float: psnr result.
|
28 |
+
"""
|
29 |
+
|
30 |
+
assert img1.shape == img2.shape, \
|
31 |
+
(f'Image shapes are differnet: {img1.shape}, {img2.shape}.')
|
32 |
+
|
33 |
+
mse = np.mean((img1 - img2)**2)
|
34 |
+
if mse == 0:
|
35 |
+
return float('inf')
|
36 |
+
return 20. * np.log10(255. / np.sqrt(mse))
|
37 |
+
|
38 |
+
|
39 |
+
def calc_psnr_and_ssim(img1, img2):
|
40 |
+
"""Calculate PSNR and SSIM for images.
|
41 |
+
img1: ndarray, range [0, 255]
|
42 |
+
img2: ndarray, range [0, 255]
|
43 |
+
"""
|
44 |
+
img1 = img1.astype(np.float64)
|
45 |
+
img2 = img2.astype(np.float64)
|
46 |
+
|
47 |
+
psnr = calculate_psnr(img1, img2)
|
48 |
+
ssim = measure.compare_ssim(img1,
|
49 |
+
img2,
|
50 |
+
data_range=255,
|
51 |
+
multichannel=True,
|
52 |
+
win_size=65)
|
53 |
+
|
54 |
+
return psnr, ssim
|
55 |
+
|
56 |
+
|
57 |
+
###########################
|
58 |
+
# I3D models
|
59 |
+
###########################
|
60 |
+
|
61 |
+
|
62 |
+
def init_i3d_model(i3d_model_path):
|
63 |
+
print(f"[Loading I3D model from {i3d_model_path} for FID score ..]")
|
64 |
+
i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits')
|
65 |
+
i3d_model.load_state_dict(torch.load(i3d_model_path))
|
66 |
+
i3d_model.to(torch.device('cuda:0'))
|
67 |
+
return i3d_model
|
68 |
+
|
69 |
+
|
70 |
+
def calculate_i3d_activations(video1, video2, i3d_model, device):
|
71 |
+
"""Calculate VFID metric.
|
72 |
+
video1: list[PIL.Image]
|
73 |
+
video2: list[PIL.Image]
|
74 |
+
"""
|
75 |
+
video1 = to_tensors()(video1).unsqueeze(0).to(device)
|
76 |
+
video2 = to_tensors()(video2).unsqueeze(0).to(device)
|
77 |
+
video1_activations = get_i3d_activations(
|
78 |
+
video1, i3d_model).cpu().numpy().flatten()
|
79 |
+
video2_activations = get_i3d_activations(
|
80 |
+
video2, i3d_model).cpu().numpy().flatten()
|
81 |
+
|
82 |
+
return video1_activations, video2_activations
|
83 |
+
|
84 |
+
|
85 |
+
def calculate_vfid(real_activations, fake_activations):
|
86 |
+
"""
|
87 |
+
Given two distribution of features, compute the FID score between them
|
88 |
+
Params:
|
89 |
+
real_activations: list[ndarray]
|
90 |
+
fake_activations: list[ndarray]
|
91 |
+
"""
|
92 |
+
m1 = np.mean(real_activations, axis=0)
|
93 |
+
m2 = np.mean(fake_activations, axis=0)
|
94 |
+
s1 = np.cov(real_activations, rowvar=False)
|
95 |
+
s2 = np.cov(fake_activations, rowvar=False)
|
96 |
+
return calculate_frechet_distance(m1, s1, m2, s2)
|
97 |
+
|
98 |
+
|
99 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
100 |
+
"""Numpy implementation of the Frechet Distance.
|
101 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
102 |
+
and X_2 ~ N(mu_2, C_2) is
|
103 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
104 |
+
Stable version by Dougal J. Sutherland.
|
105 |
+
Params:
|
106 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
107 |
+
inception net (like returned by the function 'get_predictions')
|
108 |
+
for generated samples.
|
109 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
110 |
+
representive data set.
|
111 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
112 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
113 |
+
representive data set.
|
114 |
+
Returns:
|
115 |
+
-- : The Frechet Distance.
|
116 |
+
"""
|
117 |
+
|
118 |
+
mu1 = np.atleast_1d(mu1)
|
119 |
+
mu2 = np.atleast_1d(mu2)
|
120 |
+
|
121 |
+
sigma1 = np.atleast_2d(sigma1)
|
122 |
+
sigma2 = np.atleast_2d(sigma2)
|
123 |
+
|
124 |
+
assert mu1.shape == mu2.shape, \
|
125 |
+
'Training and test mean vectors have different lengths'
|
126 |
+
assert sigma1.shape == sigma2.shape, \
|
127 |
+
'Training and test covariances have different dimensions'
|
128 |
+
|
129 |
+
diff = mu1 - mu2
|
130 |
+
|
131 |
+
# Product might be almost singular
|
132 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
133 |
+
if not np.isfinite(covmean).all():
|
134 |
+
msg = ('fid calculation produces singular product; '
|
135 |
+
'adding %s to diagonal of cov estimates') % eps
|
136 |
+
print(msg)
|
137 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
138 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
139 |
+
|
140 |
+
# Numerical error might give slight imaginary component
|
141 |
+
if np.iscomplexobj(covmean):
|
142 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
143 |
+
m = np.max(np.abs(covmean.imag))
|
144 |
+
raise ValueError('Imaginary component {}'.format(m))
|
145 |
+
covmean = covmean.real
|
146 |
+
|
147 |
+
tr_covmean = np.trace(covmean)
|
148 |
+
|
149 |
+
return (diff.dot(diff) + np.trace(sigma1) + # NOQA
|
150 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
151 |
+
|
152 |
+
|
153 |
+
def get_i3d_activations(batched_video,
|
154 |
+
i3d_model,
|
155 |
+
target_endpoint='Logits',
|
156 |
+
flatten=True,
|
157 |
+
grad_enabled=False):
|
158 |
+
"""
|
159 |
+
Get features from i3d model and flatten them to 1d feature,
|
160 |
+
valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS
|
161 |
+
VALID_ENDPOINTS = (
|
162 |
+
'Conv3d_1a_7x7',
|
163 |
+
'MaxPool3d_2a_3x3',
|
164 |
+
'Conv3d_2b_1x1',
|
165 |
+
'Conv3d_2c_3x3',
|
166 |
+
'MaxPool3d_3a_3x3',
|
167 |
+
'Mixed_3b',
|
168 |
+
'Mixed_3c',
|
169 |
+
'MaxPool3d_4a_3x3',
|
170 |
+
'Mixed_4b',
|
171 |
+
'Mixed_4c',
|
172 |
+
'Mixed_4d',
|
173 |
+
'Mixed_4e',
|
174 |
+
'Mixed_4f',
|
175 |
+
'MaxPool3d_5a_2x2',
|
176 |
+
'Mixed_5b',
|
177 |
+
'Mixed_5c',
|
178 |
+
'Logits',
|
179 |
+
'Predictions',
|
180 |
+
)
|
181 |
+
"""
|
182 |
+
with torch.set_grad_enabled(grad_enabled):
|
183 |
+
feat = i3d_model.extract_features(batched_video.transpose(1, 2),
|
184 |
+
target_endpoint)
|
185 |
+
if flatten:
|
186 |
+
feat = feat.view(feat.size(0), -1)
|
187 |
+
|
188 |
+
return feat
|
189 |
+
|
190 |
+
|
191 |
+
# This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py
|
192 |
+
# I only fix flake8 errors and do some cleaning here
|
193 |
+
|
194 |
+
|
195 |
+
class MaxPool3dSamePadding(nn.MaxPool3d):
|
196 |
+
def compute_pad(self, dim, s):
|
197 |
+
if s % self.stride[dim] == 0:
|
198 |
+
return max(self.kernel_size[dim] - self.stride[dim], 0)
|
199 |
+
else:
|
200 |
+
return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
|
201 |
+
|
202 |
+
def forward(self, x):
|
203 |
+
# compute 'same' padding
|
204 |
+
(batch, channel, t, h, w) = x.size()
|
205 |
+
pad_t = self.compute_pad(0, t)
|
206 |
+
pad_h = self.compute_pad(1, h)
|
207 |
+
pad_w = self.compute_pad(2, w)
|
208 |
+
|
209 |
+
pad_t_f = pad_t // 2
|
210 |
+
pad_t_b = pad_t - pad_t_f
|
211 |
+
pad_h_f = pad_h // 2
|
212 |
+
pad_h_b = pad_h - pad_h_f
|
213 |
+
pad_w_f = pad_w // 2
|
214 |
+
pad_w_b = pad_w - pad_w_f
|
215 |
+
|
216 |
+
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
|
217 |
+
x = F.pad(x, pad)
|
218 |
+
return super(MaxPool3dSamePadding, self).forward(x)
|
219 |
+
|
220 |
+
|
221 |
+
class Unit3D(nn.Module):
|
222 |
+
def __init__(self,
|
223 |
+
in_channels,
|
224 |
+
output_channels,
|
225 |
+
kernel_shape=(1, 1, 1),
|
226 |
+
stride=(1, 1, 1),
|
227 |
+
padding=0,
|
228 |
+
activation_fn=F.relu,
|
229 |
+
use_batch_norm=True,
|
230 |
+
use_bias=False,
|
231 |
+
name='unit_3d'):
|
232 |
+
"""Initializes Unit3D module."""
|
233 |
+
super(Unit3D, self).__init__()
|
234 |
+
|
235 |
+
self._output_channels = output_channels
|
236 |
+
self._kernel_shape = kernel_shape
|
237 |
+
self._stride = stride
|
238 |
+
self._use_batch_norm = use_batch_norm
|
239 |
+
self._activation_fn = activation_fn
|
240 |
+
self._use_bias = use_bias
|
241 |
+
self.name = name
|
242 |
+
self.padding = padding
|
243 |
+
|
244 |
+
self.conv3d = nn.Conv3d(
|
245 |
+
in_channels=in_channels,
|
246 |
+
out_channels=self._output_channels,
|
247 |
+
kernel_size=self._kernel_shape,
|
248 |
+
stride=self._stride,
|
249 |
+
padding=0, # we always want padding to be 0 here. We will
|
250 |
+
# dynamically pad based on input size in forward function
|
251 |
+
bias=self._use_bias)
|
252 |
+
|
253 |
+
if self._use_batch_norm:
|
254 |
+
self.bn = nn.BatchNorm3d(self._output_channels,
|
255 |
+
eps=0.001,
|
256 |
+
momentum=0.01)
|
257 |
+
|
258 |
+
def compute_pad(self, dim, s):
|
259 |
+
if s % self._stride[dim] == 0:
|
260 |
+
return max(self._kernel_shape[dim] - self._stride[dim], 0)
|
261 |
+
else:
|
262 |
+
return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
|
263 |
+
|
264 |
+
def forward(self, x):
|
265 |
+
# compute 'same' padding
|
266 |
+
(batch, channel, t, h, w) = x.size()
|
267 |
+
pad_t = self.compute_pad(0, t)
|
268 |
+
pad_h = self.compute_pad(1, h)
|
269 |
+
pad_w = self.compute_pad(2, w)
|
270 |
+
|
271 |
+
pad_t_f = pad_t // 2
|
272 |
+
pad_t_b = pad_t - pad_t_f
|
273 |
+
pad_h_f = pad_h // 2
|
274 |
+
pad_h_b = pad_h - pad_h_f
|
275 |
+
pad_w_f = pad_w // 2
|
276 |
+
pad_w_b = pad_w - pad_w_f
|
277 |
+
|
278 |
+
pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
|
279 |
+
x = F.pad(x, pad)
|
280 |
+
|
281 |
+
x = self.conv3d(x)
|
282 |
+
if self._use_batch_norm:
|
283 |
+
x = self.bn(x)
|
284 |
+
if self._activation_fn is not None:
|
285 |
+
x = self._activation_fn(x)
|
286 |
+
return x
|
287 |
+
|
288 |
+
|
289 |
+
class InceptionModule(nn.Module):
|
290 |
+
def __init__(self, in_channels, out_channels, name):
|
291 |
+
super(InceptionModule, self).__init__()
|
292 |
+
|
293 |
+
self.b0 = Unit3D(in_channels=in_channels,
|
294 |
+
output_channels=out_channels[0],
|
295 |
+
kernel_shape=[1, 1, 1],
|
296 |
+
padding=0,
|
297 |
+
name=name + '/Branch_0/Conv3d_0a_1x1')
|
298 |
+
self.b1a = Unit3D(in_channels=in_channels,
|
299 |
+
output_channels=out_channels[1],
|
300 |
+
kernel_shape=[1, 1, 1],
|
301 |
+
padding=0,
|
302 |
+
name=name + '/Branch_1/Conv3d_0a_1x1')
|
303 |
+
self.b1b = Unit3D(in_channels=out_channels[1],
|
304 |
+
output_channels=out_channels[2],
|
305 |
+
kernel_shape=[3, 3, 3],
|
306 |
+
name=name + '/Branch_1/Conv3d_0b_3x3')
|
307 |
+
self.b2a = Unit3D(in_channels=in_channels,
|
308 |
+
output_channels=out_channels[3],
|
309 |
+
kernel_shape=[1, 1, 1],
|
310 |
+
padding=0,
|
311 |
+
name=name + '/Branch_2/Conv3d_0a_1x1')
|
312 |
+
self.b2b = Unit3D(in_channels=out_channels[3],
|
313 |
+
output_channels=out_channels[4],
|
314 |
+
kernel_shape=[3, 3, 3],
|
315 |
+
name=name + '/Branch_2/Conv3d_0b_3x3')
|
316 |
+
self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
|
317 |
+
stride=(1, 1, 1),
|
318 |
+
padding=0)
|
319 |
+
self.b3b = Unit3D(in_channels=in_channels,
|
320 |
+
output_channels=out_channels[5],
|
321 |
+
kernel_shape=[1, 1, 1],
|
322 |
+
padding=0,
|
323 |
+
name=name + '/Branch_3/Conv3d_0b_1x1')
|
324 |
+
self.name = name
|
325 |
+
|
326 |
+
def forward(self, x):
|
327 |
+
b0 = self.b0(x)
|
328 |
+
b1 = self.b1b(self.b1a(x))
|
329 |
+
b2 = self.b2b(self.b2a(x))
|
330 |
+
b3 = self.b3b(self.b3a(x))
|
331 |
+
return torch.cat([b0, b1, b2, b3], dim=1)
|
332 |
+
|
333 |
+
|
334 |
+
class InceptionI3d(nn.Module):
|
335 |
+
"""Inception-v1 I3D architecture.
|
336 |
+
The model is introduced in:
|
337 |
+
Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
|
338 |
+
Joao Carreira, Andrew Zisserman
|
339 |
+
https://arxiv.org/pdf/1705.07750v1.pdf.
|
340 |
+
See also the Inception architecture, introduced in:
|
341 |
+
Going deeper with convolutions
|
342 |
+
Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
|
343 |
+
Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
|
344 |
+
http://arxiv.org/pdf/1409.4842v1.pdf.
|
345 |
+
"""
|
346 |
+
|
347 |
+
# Endpoints of the model in order. During construction, all the endpoints up
|
348 |
+
# to a designated `final_endpoint` are returned in a dictionary as the
|
349 |
+
# second return value.
|
350 |
+
VALID_ENDPOINTS = (
|
351 |
+
'Conv3d_1a_7x7',
|
352 |
+
'MaxPool3d_2a_3x3',
|
353 |
+
'Conv3d_2b_1x1',
|
354 |
+
'Conv3d_2c_3x3',
|
355 |
+
'MaxPool3d_3a_3x3',
|
356 |
+
'Mixed_3b',
|
357 |
+
'Mixed_3c',
|
358 |
+
'MaxPool3d_4a_3x3',
|
359 |
+
'Mixed_4b',
|
360 |
+
'Mixed_4c',
|
361 |
+
'Mixed_4d',
|
362 |
+
'Mixed_4e',
|
363 |
+
'Mixed_4f',
|
364 |
+
'MaxPool3d_5a_2x2',
|
365 |
+
'Mixed_5b',
|
366 |
+
'Mixed_5c',
|
367 |
+
'Logits',
|
368 |
+
'Predictions',
|
369 |
+
)
|
370 |
+
|
371 |
+
def __init__(self,
|
372 |
+
num_classes=400,
|
373 |
+
spatial_squeeze=True,
|
374 |
+
final_endpoint='Logits',
|
375 |
+
name='inception_i3d',
|
376 |
+
in_channels=3,
|
377 |
+
dropout_keep_prob=0.5):
|
378 |
+
"""Initializes I3D model instance.
|
379 |
+
Args:
|
380 |
+
num_classes: The number of outputs in the logit layer (default 400, which
|
381 |
+
matches the Kinetics dataset).
|
382 |
+
spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
|
383 |
+
before returning (default True).
|
384 |
+
final_endpoint: The model contains many possible endpoints.
|
385 |
+
`final_endpoint` specifies the last endpoint for the model to be built
|
386 |
+
up to. In addition to the output at `final_endpoint`, all the outputs
|
387 |
+
at endpoints up to `final_endpoint` will also be returned, in a
|
388 |
+
dictionary. `final_endpoint` must be one of
|
389 |
+
InceptionI3d.VALID_ENDPOINTS (default 'Logits').
|
390 |
+
name: A string (optional). The name of this module.
|
391 |
+
Raises:
|
392 |
+
ValueError: if `final_endpoint` is not recognized.
|
393 |
+
"""
|
394 |
+
|
395 |
+
if final_endpoint not in self.VALID_ENDPOINTS:
|
396 |
+
raise ValueError('Unknown final endpoint %s' % final_endpoint)
|
397 |
+
|
398 |
+
super(InceptionI3d, self).__init__()
|
399 |
+
self._num_classes = num_classes
|
400 |
+
self._spatial_squeeze = spatial_squeeze
|
401 |
+
self._final_endpoint = final_endpoint
|
402 |
+
self.logits = None
|
403 |
+
|
404 |
+
if self._final_endpoint not in self.VALID_ENDPOINTS:
|
405 |
+
raise ValueError('Unknown final endpoint %s' %
|
406 |
+
self._final_endpoint)
|
407 |
+
|
408 |
+
self.end_points = {}
|
409 |
+
end_point = 'Conv3d_1a_7x7'
|
410 |
+
self.end_points[end_point] = Unit3D(in_channels=in_channels,
|
411 |
+
output_channels=64,
|
412 |
+
kernel_shape=[7, 7, 7],
|
413 |
+
stride=(2, 2, 2),
|
414 |
+
padding=(3, 3, 3),
|
415 |
+
name=name + end_point)
|
416 |
+
if self._final_endpoint == end_point:
|
417 |
+
return
|
418 |
+
|
419 |
+
end_point = 'MaxPool3d_2a_3x3'
|
420 |
+
self.end_points[end_point] = MaxPool3dSamePadding(
|
421 |
+
kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
|
422 |
+
if self._final_endpoint == end_point:
|
423 |
+
return
|
424 |
+
|
425 |
+
end_point = 'Conv3d_2b_1x1'
|
426 |
+
self.end_points[end_point] = Unit3D(in_channels=64,
|
427 |
+
output_channels=64,
|
428 |
+
kernel_shape=[1, 1, 1],
|
429 |
+
padding=0,
|
430 |
+
name=name + end_point)
|
431 |
+
if self._final_endpoint == end_point:
|
432 |
+
return
|
433 |
+
|
434 |
+
end_point = 'Conv3d_2c_3x3'
|
435 |
+
self.end_points[end_point] = Unit3D(in_channels=64,
|
436 |
+
output_channels=192,
|
437 |
+
kernel_shape=[3, 3, 3],
|
438 |
+
padding=1,
|
439 |
+
name=name + end_point)
|
440 |
+
if self._final_endpoint == end_point:
|
441 |
+
return
|
442 |
+
|
443 |
+
end_point = 'MaxPool3d_3a_3x3'
|
444 |
+
self.end_points[end_point] = MaxPool3dSamePadding(
|
445 |
+
kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
|
446 |
+
if self._final_endpoint == end_point:
|
447 |
+
return
|
448 |
+
|
449 |
+
end_point = 'Mixed_3b'
|
450 |
+
self.end_points[end_point] = InceptionModule(192,
|
451 |
+
[64, 96, 128, 16, 32, 32],
|
452 |
+
name + end_point)
|
453 |
+
if self._final_endpoint == end_point:
|
454 |
+
return
|
455 |
+
|
456 |
+
end_point = 'Mixed_3c'
|
457 |
+
self.end_points[end_point] = InceptionModule(
|
458 |
+
256, [128, 128, 192, 32, 96, 64], name + end_point)
|
459 |
+
if self._final_endpoint == end_point:
|
460 |
+
return
|
461 |
+
|
462 |
+
end_point = 'MaxPool3d_4a_3x3'
|
463 |
+
self.end_points[end_point] = MaxPool3dSamePadding(
|
464 |
+
kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0)
|
465 |
+
if self._final_endpoint == end_point:
|
466 |
+
return
|
467 |
+
|
468 |
+
end_point = 'Mixed_4b'
|
469 |
+
self.end_points[end_point] = InceptionModule(
|
470 |
+
128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
|
471 |
+
if self._final_endpoint == end_point:
|
472 |
+
return
|
473 |
+
|
474 |
+
end_point = 'Mixed_4c'
|
475 |
+
self.end_points[end_point] = InceptionModule(
|
476 |
+
192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
|
477 |
+
if self._final_endpoint == end_point:
|
478 |
+
return
|
479 |
+
|
480 |
+
end_point = 'Mixed_4d'
|
481 |
+
self.end_points[end_point] = InceptionModule(
|
482 |
+
160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
|
483 |
+
if self._final_endpoint == end_point:
|
484 |
+
return
|
485 |
+
|
486 |
+
end_point = 'Mixed_4e'
|
487 |
+
self.end_points[end_point] = InceptionModule(
|
488 |
+
128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
|
489 |
+
if self._final_endpoint == end_point:
|
490 |
+
return
|
491 |
+
|
492 |
+
end_point = 'Mixed_4f'
|
493 |
+
self.end_points[end_point] = InceptionModule(
|
494 |
+
112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128],
|
495 |
+
name + end_point)
|
496 |
+
if self._final_endpoint == end_point:
|
497 |
+
return
|
498 |
+
|
499 |
+
end_point = 'MaxPool3d_5a_2x2'
|
500 |
+
self.end_points[end_point] = MaxPool3dSamePadding(
|
501 |
+
kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0)
|
502 |
+
if self._final_endpoint == end_point:
|
503 |
+
return
|
504 |
+
|
505 |
+
end_point = 'Mixed_5b'
|
506 |
+
self.end_points[end_point] = InceptionModule(
|
507 |
+
256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
|
508 |
+
name + end_point)
|
509 |
+
if self._final_endpoint == end_point:
|
510 |
+
return
|
511 |
+
|
512 |
+
end_point = 'Mixed_5c'
|
513 |
+
self.end_points[end_point] = InceptionModule(
|
514 |
+
256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
|
515 |
+
name + end_point)
|
516 |
+
if self._final_endpoint == end_point:
|
517 |
+
return
|
518 |
+
|
519 |
+
end_point = 'Logits'
|
520 |
+
self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
|
521 |
+
self.dropout = nn.Dropout(dropout_keep_prob)
|
522 |
+
self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
|
523 |
+
output_channels=self._num_classes,
|
524 |
+
kernel_shape=[1, 1, 1],
|
525 |
+
padding=0,
|
526 |
+
activation_fn=None,
|
527 |
+
use_batch_norm=False,
|
528 |
+
use_bias=True,
|
529 |
+
name='logits')
|
530 |
+
|
531 |
+
self.build()
|
532 |
+
|
533 |
+
def replace_logits(self, num_classes):
|
534 |
+
self._num_classes = num_classes
|
535 |
+
self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
|
536 |
+
output_channels=self._num_classes,
|
537 |
+
kernel_shape=[1, 1, 1],
|
538 |
+
padding=0,
|
539 |
+
activation_fn=None,
|
540 |
+
use_batch_norm=False,
|
541 |
+
use_bias=True,
|
542 |
+
name='logits')
|
543 |
+
|
544 |
+
def build(self):
|
545 |
+
for k in self.end_points.keys():
|
546 |
+
self.add_module(k, self.end_points[k])
|
547 |
+
|
548 |
+
def forward(self, x):
|
549 |
+
for end_point in self.VALID_ENDPOINTS:
|
550 |
+
if end_point in self.end_points:
|
551 |
+
x = self._modules[end_point](
|
552 |
+
x) # use _modules to work with dataparallel
|
553 |
+
|
554 |
+
x = self.logits(self.dropout(self.avg_pool(x)))
|
555 |
+
if self._spatial_squeeze:
|
556 |
+
logits = x.squeeze(3).squeeze(3)
|
557 |
+
# logits is batch X time X classes, which is what we want to work with
|
558 |
+
return logits
|
559 |
+
|
560 |
+
def extract_features(self, x, target_endpoint='Logits'):
|
561 |
+
for end_point in self.VALID_ENDPOINTS:
|
562 |
+
if end_point in self.end_points:
|
563 |
+
x = self._modules[end_point](x)
|
564 |
+
if end_point == target_endpoint:
|
565 |
+
break
|
566 |
+
if target_endpoint == 'Logits':
|
567 |
+
return x.mean(4).mean(3).mean(2)
|
568 |
+
else:
|
569 |
+
return x
|
core/prefetch_dataloader.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue as Queue
|
2 |
+
import threading
|
3 |
+
import torch
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
|
6 |
+
|
7 |
+
class PrefetchGenerator(threading.Thread):
|
8 |
+
"""A general prefetch generator.
|
9 |
+
|
10 |
+
Ref:
|
11 |
+
https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
|
12 |
+
|
13 |
+
Args:
|
14 |
+
generator: Python generator.
|
15 |
+
num_prefetch_queue (int): Number of prefetch queue.
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, generator, num_prefetch_queue):
|
19 |
+
threading.Thread.__init__(self)
|
20 |
+
self.queue = Queue.Queue(num_prefetch_queue)
|
21 |
+
self.generator = generator
|
22 |
+
self.daemon = True
|
23 |
+
self.start()
|
24 |
+
|
25 |
+
def run(self):
|
26 |
+
for item in self.generator:
|
27 |
+
self.queue.put(item)
|
28 |
+
self.queue.put(None)
|
29 |
+
|
30 |
+
def __next__(self):
|
31 |
+
next_item = self.queue.get()
|
32 |
+
if next_item is None:
|
33 |
+
raise StopIteration
|
34 |
+
return next_item
|
35 |
+
|
36 |
+
def __iter__(self):
|
37 |
+
return self
|
38 |
+
|
39 |
+
|
40 |
+
class PrefetchDataLoader(DataLoader):
|
41 |
+
"""Prefetch version of dataloader.
|
42 |
+
|
43 |
+
Ref:
|
44 |
+
https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
|
45 |
+
|
46 |
+
TODO:
|
47 |
+
Need to test on single gpu and ddp (multi-gpu). There is a known issue in
|
48 |
+
ddp.
|
49 |
+
|
50 |
+
Args:
|
51 |
+
num_prefetch_queue (int): Number of prefetch queue.
|
52 |
+
kwargs (dict): Other arguments for dataloader.
|
53 |
+
"""
|
54 |
+
|
55 |
+
def __init__(self, num_prefetch_queue, **kwargs):
|
56 |
+
self.num_prefetch_queue = num_prefetch_queue
|
57 |
+
super(PrefetchDataLoader, self).__init__(**kwargs)
|
58 |
+
|
59 |
+
def __iter__(self):
|
60 |
+
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
|
61 |
+
|
62 |
+
|
63 |
+
class CPUPrefetcher():
|
64 |
+
"""CPU prefetcher.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
loader: Dataloader.
|
68 |
+
"""
|
69 |
+
|
70 |
+
def __init__(self, loader):
|
71 |
+
self.ori_loader = loader
|
72 |
+
self.loader = iter(loader)
|
73 |
+
|
74 |
+
def next(self):
|
75 |
+
try:
|
76 |
+
return next(self.loader)
|
77 |
+
except StopIteration:
|
78 |
+
return None
|
79 |
+
|
80 |
+
def reset(self):
|
81 |
+
self.loader = iter(self.ori_loader)
|
82 |
+
|
83 |
+
|
84 |
+
class CUDAPrefetcher():
|
85 |
+
"""CUDA prefetcher.
|
86 |
+
|
87 |
+
Ref:
|
88 |
+
https://github.com/NVIDIA/apex/issues/304#
|
89 |
+
|
90 |
+
It may consums more GPU memory.
|
91 |
+
|
92 |
+
Args:
|
93 |
+
loader: Dataloader.
|
94 |
+
opt (dict): Options.
|
95 |
+
"""
|
96 |
+
|
97 |
+
def __init__(self, loader, opt):
|
98 |
+
self.ori_loader = loader
|
99 |
+
self.loader = iter(loader)
|
100 |
+
self.opt = opt
|
101 |
+
self.stream = torch.cuda.Stream()
|
102 |
+
self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
|
103 |
+
self.preload()
|
104 |
+
|
105 |
+
def preload(self):
|
106 |
+
try:
|
107 |
+
self.batch = next(self.loader) # self.batch is a dict
|
108 |
+
except StopIteration:
|
109 |
+
self.batch = None
|
110 |
+
return None
|
111 |
+
# put tensors to gpu
|
112 |
+
with torch.cuda.stream(self.stream):
|
113 |
+
for k, v in self.batch.items():
|
114 |
+
if torch.is_tensor(v):
|
115 |
+
self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
|
116 |
+
|
117 |
+
def next(self):
|
118 |
+
torch.cuda.current_stream().wait_stream(self.stream)
|
119 |
+
batch = self.batch
|
120 |
+
self.preload()
|
121 |
+
return batch
|
122 |
+
|
123 |
+
def reset(self):
|
124 |
+
self.loader = iter(self.ori_loader)
|
125 |
+
self.preload()
|
core/trainer.py
ADDED
@@ -0,0 +1,509 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
import importlib
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
13 |
+
import torchvision
|
14 |
+
from torch.utils.tensorboard import SummaryWriter
|
15 |
+
|
16 |
+
from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
|
17 |
+
from core.loss import AdversarialLoss, PerceptualLoss, LPIPSLoss
|
18 |
+
from core.dataset import TrainDataset
|
19 |
+
|
20 |
+
from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss
|
21 |
+
from model.recurrent_flow_completion import RecurrentFlowCompleteNet
|
22 |
+
|
23 |
+
from RAFT.utils.flow_viz_pt import flow_to_image
|
24 |
+
|
25 |
+
|
26 |
+
class Trainer:
|
27 |
+
def __init__(self, config):
|
28 |
+
self.config = config
|
29 |
+
self.epoch = 0
|
30 |
+
self.iteration = 0
|
31 |
+
self.num_local_frames = config['train_data_loader']['num_local_frames']
|
32 |
+
self.num_ref_frames = config['train_data_loader']['num_ref_frames']
|
33 |
+
|
34 |
+
# setup data set and data loader
|
35 |
+
self.train_dataset = TrainDataset(config['train_data_loader'])
|
36 |
+
|
37 |
+
self.train_sampler = None
|
38 |
+
self.train_args = config['trainer']
|
39 |
+
if config['distributed']:
|
40 |
+
self.train_sampler = DistributedSampler(
|
41 |
+
self.train_dataset,
|
42 |
+
num_replicas=config['world_size'],
|
43 |
+
rank=config['global_rank'])
|
44 |
+
|
45 |
+
dataloader_args = dict(
|
46 |
+
dataset=self.train_dataset,
|
47 |
+
batch_size=self.train_args['batch_size'] // config['world_size'],
|
48 |
+
shuffle=(self.train_sampler is None),
|
49 |
+
num_workers=self.train_args['num_workers'],
|
50 |
+
sampler=self.train_sampler,
|
51 |
+
drop_last=True)
|
52 |
+
|
53 |
+
self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args)
|
54 |
+
self.prefetcher = CPUPrefetcher(self.train_loader)
|
55 |
+
|
56 |
+
# set loss functions
|
57 |
+
self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
|
58 |
+
self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
|
59 |
+
self.l1_loss = nn.L1Loss()
|
60 |
+
# self.perc_loss = PerceptualLoss(
|
61 |
+
# layer_weights={'conv3_4': 0.25, 'conv4_4': 0.25, 'conv5_4': 0.5},
|
62 |
+
# use_input_norm=True,
|
63 |
+
# range_norm=True,
|
64 |
+
# criterion='l1'
|
65 |
+
# ).to(self.config['device'])
|
66 |
+
|
67 |
+
if self.config['losses']['perceptual_weight'] > 0:
|
68 |
+
self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to(self.config['device'])
|
69 |
+
|
70 |
+
# self.flow_comp_loss = FlowCompletionLoss().to(self.config['device'])
|
71 |
+
# self.flow_comp_loss = FlowCompletionLoss(self.config['device'])
|
72 |
+
|
73 |
+
# set raft
|
74 |
+
self.fix_raft = RAFT_bi(device = self.config['device'])
|
75 |
+
self.fix_flow_complete = RecurrentFlowCompleteNet('/mnt/lustre/sczhou/VQGANs/CodeMOVI/experiments_model/recurrent_flow_completion_v5_train_flowcomp_v5/gen_760000.pth')
|
76 |
+
for p in self.fix_flow_complete.parameters():
|
77 |
+
p.requires_grad = False
|
78 |
+
self.fix_flow_complete.to(self.config['device'])
|
79 |
+
self.fix_flow_complete.eval()
|
80 |
+
|
81 |
+
# self.flow_loss = FlowLoss()
|
82 |
+
|
83 |
+
# setup models including generator and discriminator
|
84 |
+
net = importlib.import_module('model.' + config['model']['net'])
|
85 |
+
self.netG = net.InpaintGenerator()
|
86 |
+
# print(self.netG)
|
87 |
+
self.netG = self.netG.to(self.config['device'])
|
88 |
+
if not self.config['model'].get('no_dis', False):
|
89 |
+
if self.config['model'].get('dis_2d', False):
|
90 |
+
self.netD = net.Discriminator_2D(
|
91 |
+
in_channels=3,
|
92 |
+
use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
|
93 |
+
else:
|
94 |
+
self.netD = net.Discriminator(
|
95 |
+
in_channels=3,
|
96 |
+
use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
|
97 |
+
self.netD = self.netD.to(self.config['device'])
|
98 |
+
|
99 |
+
self.interp_mode = self.config['model']['interp_mode']
|
100 |
+
# setup optimizers and schedulers
|
101 |
+
self.setup_optimizers()
|
102 |
+
self.setup_schedulers()
|
103 |
+
self.load()
|
104 |
+
|
105 |
+
if config['distributed']:
|
106 |
+
self.netG = DDP(self.netG,
|
107 |
+
device_ids=[self.config['local_rank']],
|
108 |
+
output_device=self.config['local_rank'],
|
109 |
+
broadcast_buffers=True,
|
110 |
+
find_unused_parameters=True)
|
111 |
+
if not self.config['model']['no_dis']:
|
112 |
+
self.netD = DDP(self.netD,
|
113 |
+
device_ids=[self.config['local_rank']],
|
114 |
+
output_device=self.config['local_rank'],
|
115 |
+
broadcast_buffers=True,
|
116 |
+
find_unused_parameters=False)
|
117 |
+
|
118 |
+
# set summary writer
|
119 |
+
self.dis_writer = None
|
120 |
+
self.gen_writer = None
|
121 |
+
self.summary = {}
|
122 |
+
if self.config['global_rank'] == 0 or (not config['distributed']):
|
123 |
+
if not self.config['model']['no_dis']:
|
124 |
+
self.dis_writer = SummaryWriter(
|
125 |
+
os.path.join(config['save_dir'], 'dis'))
|
126 |
+
self.gen_writer = SummaryWriter(
|
127 |
+
os.path.join(config['save_dir'], 'gen'))
|
128 |
+
|
129 |
+
def setup_optimizers(self):
|
130 |
+
"""Set up optimizers."""
|
131 |
+
backbone_params = []
|
132 |
+
for name, param in self.netG.named_parameters():
|
133 |
+
if param.requires_grad:
|
134 |
+
backbone_params.append(param)
|
135 |
+
else:
|
136 |
+
print(f'Params {name} will not be optimized.')
|
137 |
+
|
138 |
+
optim_params = [
|
139 |
+
{
|
140 |
+
'params': backbone_params,
|
141 |
+
'lr': self.config['trainer']['lr']
|
142 |
+
},
|
143 |
+
]
|
144 |
+
|
145 |
+
self.optimG = torch.optim.Adam(optim_params,
|
146 |
+
betas=(self.config['trainer']['beta1'],
|
147 |
+
self.config['trainer']['beta2']))
|
148 |
+
|
149 |
+
if not self.config['model']['no_dis']:
|
150 |
+
self.optimD = torch.optim.Adam(
|
151 |
+
self.netD.parameters(),
|
152 |
+
lr=self.config['trainer']['lr'],
|
153 |
+
betas=(self.config['trainer']['beta1'],
|
154 |
+
self.config['trainer']['beta2']))
|
155 |
+
|
156 |
+
def setup_schedulers(self):
|
157 |
+
"""Set up schedulers."""
|
158 |
+
scheduler_opt = self.config['trainer']['scheduler']
|
159 |
+
scheduler_type = scheduler_opt.pop('type')
|
160 |
+
|
161 |
+
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
|
162 |
+
self.scheG = MultiStepRestartLR(
|
163 |
+
self.optimG,
|
164 |
+
milestones=scheduler_opt['milestones'],
|
165 |
+
gamma=scheduler_opt['gamma'])
|
166 |
+
if not self.config['model']['no_dis']:
|
167 |
+
self.scheD = MultiStepRestartLR(
|
168 |
+
self.optimD,
|
169 |
+
milestones=scheduler_opt['milestones'],
|
170 |
+
gamma=scheduler_opt['gamma'])
|
171 |
+
elif scheduler_type == 'CosineAnnealingRestartLR':
|
172 |
+
self.scheG = CosineAnnealingRestartLR(
|
173 |
+
self.optimG,
|
174 |
+
periods=scheduler_opt['periods'],
|
175 |
+
restart_weights=scheduler_opt['restart_weights'],
|
176 |
+
eta_min=scheduler_opt['eta_min'])
|
177 |
+
if not self.config['model']['no_dis']:
|
178 |
+
self.scheD = CosineAnnealingRestartLR(
|
179 |
+
self.optimD,
|
180 |
+
periods=scheduler_opt['periods'],
|
181 |
+
restart_weights=scheduler_opt['restart_weights'],
|
182 |
+
eta_min=scheduler_opt['eta_min'])
|
183 |
+
else:
|
184 |
+
raise NotImplementedError(
|
185 |
+
f'Scheduler {scheduler_type} is not implemented yet.')
|
186 |
+
|
187 |
+
def update_learning_rate(self):
|
188 |
+
"""Update learning rate."""
|
189 |
+
self.scheG.step()
|
190 |
+
if not self.config['model']['no_dis']:
|
191 |
+
self.scheD.step()
|
192 |
+
|
193 |
+
def get_lr(self):
|
194 |
+
"""Get current learning rate."""
|
195 |
+
return self.optimG.param_groups[0]['lr']
|
196 |
+
|
197 |
+
def add_summary(self, writer, name, val):
|
198 |
+
"""Add tensorboard summary."""
|
199 |
+
if name not in self.summary:
|
200 |
+
self.summary[name] = 0
|
201 |
+
self.summary[name] += val
|
202 |
+
n = self.train_args['log_freq']
|
203 |
+
if writer is not None and self.iteration % n == 0:
|
204 |
+
writer.add_scalar(name, self.summary[name] / n, self.iteration)
|
205 |
+
self.summary[name] = 0
|
206 |
+
|
207 |
+
def load(self):
|
208 |
+
"""Load netG (and netD)."""
|
209 |
+
# get the latest checkpoint
|
210 |
+
model_path = self.config['save_dir']
|
211 |
+
# TODO: add resume name
|
212 |
+
if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
|
213 |
+
latest_epoch = open(os.path.join(model_path, 'latest.ckpt'),
|
214 |
+
'r').read().splitlines()[-1]
|
215 |
+
else:
|
216 |
+
ckpts = [
|
217 |
+
os.path.basename(i).split('.pth')[0]
|
218 |
+
for i in glob.glob(os.path.join(model_path, '*.pth'))
|
219 |
+
]
|
220 |
+
ckpts.sort()
|
221 |
+
latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None
|
222 |
+
|
223 |
+
if latest_epoch is not None:
|
224 |
+
gen_path = os.path.join(model_path,
|
225 |
+
f'gen_{int(latest_epoch):06d}.pth')
|
226 |
+
dis_path = os.path.join(model_path,
|
227 |
+
f'dis_{int(latest_epoch):06d}.pth')
|
228 |
+
opt_path = os.path.join(model_path,
|
229 |
+
f'opt_{int(latest_epoch):06d}.pth')
|
230 |
+
|
231 |
+
if self.config['global_rank'] == 0:
|
232 |
+
print(f'Loading model from {gen_path}...')
|
233 |
+
dataG = torch.load(gen_path, map_location=self.config['device'])
|
234 |
+
self.netG.load_state_dict(dataG)
|
235 |
+
if not self.config['model']['no_dis'] and self.config['model']['load_d']:
|
236 |
+
dataD = torch.load(dis_path, map_location=self.config['device'])
|
237 |
+
self.netD.load_state_dict(dataD)
|
238 |
+
|
239 |
+
data_opt = torch.load(opt_path, map_location=self.config['device'])
|
240 |
+
self.optimG.load_state_dict(data_opt['optimG'])
|
241 |
+
# self.scheG.load_state_dict(data_opt['scheG'])
|
242 |
+
if not self.config['model']['no_dis'] and self.config['model']['load_d']:
|
243 |
+
self.optimD.load_state_dict(data_opt['optimD'])
|
244 |
+
# self.scheD.load_state_dict(data_opt['scheD'])
|
245 |
+
self.epoch = data_opt['epoch']
|
246 |
+
self.iteration = data_opt['iteration']
|
247 |
+
else:
|
248 |
+
gen_path = self.config['trainer'].get('gen_path', None)
|
249 |
+
dis_path = self.config['trainer'].get('dis_path', None)
|
250 |
+
opt_path = self.config['trainer'].get('opt_path', None)
|
251 |
+
if gen_path is not None:
|
252 |
+
if self.config['global_rank'] == 0:
|
253 |
+
print(f'Loading Gen-Net from {gen_path}...')
|
254 |
+
dataG = torch.load(gen_path, map_location=self.config['device'])
|
255 |
+
self.netG.load_state_dict(dataG)
|
256 |
+
|
257 |
+
if dis_path is not None and not self.config['model']['no_dis'] and self.config['model']['load_d']:
|
258 |
+
if self.config['global_rank'] == 0:
|
259 |
+
print(f'Loading Dis-Net from {dis_path}...')
|
260 |
+
dataD = torch.load(dis_path, map_location=self.config['device'])
|
261 |
+
self.netD.load_state_dict(dataD)
|
262 |
+
if opt_path is not None:
|
263 |
+
data_opt = torch.load(opt_path, map_location=self.config['device'])
|
264 |
+
self.optimG.load_state_dict(data_opt['optimG'])
|
265 |
+
self.scheG.load_state_dict(data_opt['scheG'])
|
266 |
+
if not self.config['model']['no_dis'] and self.config['model']['load_d']:
|
267 |
+
self.optimD.load_state_dict(data_opt['optimD'])
|
268 |
+
self.scheD.load_state_dict(data_opt['scheD'])
|
269 |
+
else:
|
270 |
+
if self.config['global_rank'] == 0:
|
271 |
+
print('Warnning: There is no trained model found.'
|
272 |
+
'An initialized model will be used.')
|
273 |
+
|
274 |
+
def save(self, it):
|
275 |
+
"""Save parameters every eval_epoch"""
|
276 |
+
if self.config['global_rank'] == 0:
|
277 |
+
# configure path
|
278 |
+
gen_path = os.path.join(self.config['save_dir'],
|
279 |
+
f'gen_{it:06d}.pth')
|
280 |
+
dis_path = os.path.join(self.config['save_dir'],
|
281 |
+
f'dis_{it:06d}.pth')
|
282 |
+
opt_path = os.path.join(self.config['save_dir'],
|
283 |
+
f'opt_{it:06d}.pth')
|
284 |
+
print(f'\nsaving model to {gen_path} ...')
|
285 |
+
|
286 |
+
# remove .module for saving
|
287 |
+
if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
|
288 |
+
netG = self.netG.module
|
289 |
+
if not self.config['model']['no_dis']:
|
290 |
+
netD = self.netD.module
|
291 |
+
else:
|
292 |
+
netG = self.netG
|
293 |
+
if not self.config['model']['no_dis']:
|
294 |
+
netD = self.netD
|
295 |
+
|
296 |
+
# save checkpoints
|
297 |
+
torch.save(netG.state_dict(), gen_path)
|
298 |
+
if not self.config['model']['no_dis']:
|
299 |
+
torch.save(netD.state_dict(), dis_path)
|
300 |
+
torch.save(
|
301 |
+
{
|
302 |
+
'epoch': self.epoch,
|
303 |
+
'iteration': self.iteration,
|
304 |
+
'optimG': self.optimG.state_dict(),
|
305 |
+
'optimD': self.optimD.state_dict(),
|
306 |
+
'scheG': self.scheG.state_dict(),
|
307 |
+
'scheD': self.scheD.state_dict()
|
308 |
+
}, opt_path)
|
309 |
+
else:
|
310 |
+
torch.save(
|
311 |
+
{
|
312 |
+
'epoch': self.epoch,
|
313 |
+
'iteration': self.iteration,
|
314 |
+
'optimG': self.optimG.state_dict(),
|
315 |
+
'scheG': self.scheG.state_dict()
|
316 |
+
}, opt_path)
|
317 |
+
|
318 |
+
latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt')
|
319 |
+
os.system(f"echo {it:06d} > {latest_path}")
|
320 |
+
|
321 |
+
def train(self):
|
322 |
+
"""training entry"""
|
323 |
+
pbar = range(int(self.train_args['iterations']))
|
324 |
+
if self.config['global_rank'] == 0:
|
325 |
+
pbar = tqdm(pbar,
|
326 |
+
initial=self.iteration,
|
327 |
+
dynamic_ncols=True,
|
328 |
+
smoothing=0.01)
|
329 |
+
|
330 |
+
os.makedirs('logs', exist_ok=True)
|
331 |
+
|
332 |
+
logging.basicConfig(
|
333 |
+
level=logging.INFO,
|
334 |
+
format="%(asctime)s %(filename)s[line:%(lineno)d]"
|
335 |
+
"%(levelname)s %(message)s",
|
336 |
+
datefmt="%a, %d %b %Y %H:%M:%S",
|
337 |
+
filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log",
|
338 |
+
filemode='w')
|
339 |
+
|
340 |
+
while True:
|
341 |
+
self.epoch += 1
|
342 |
+
self.prefetcher.reset()
|
343 |
+
if self.config['distributed']:
|
344 |
+
self.train_sampler.set_epoch(self.epoch)
|
345 |
+
self._train_epoch(pbar)
|
346 |
+
if self.iteration > self.train_args['iterations']:
|
347 |
+
break
|
348 |
+
print('\nEnd training....')
|
349 |
+
|
350 |
+
def _train_epoch(self, pbar):
|
351 |
+
"""Process input and calculate loss every training epoch"""
|
352 |
+
device = self.config['device']
|
353 |
+
train_data = self.prefetcher.next()
|
354 |
+
while train_data is not None:
|
355 |
+
self.iteration += 1
|
356 |
+
frames, masks, flows_f, flows_b, _ = train_data
|
357 |
+
frames, masks = frames.to(device), masks.to(device).float()
|
358 |
+
l_t = self.num_local_frames
|
359 |
+
b, t, c, h, w = frames.size()
|
360 |
+
gt_local_frames = frames[:, :l_t, ...]
|
361 |
+
local_masks = masks[:, :l_t, ...].contiguous()
|
362 |
+
|
363 |
+
masked_frames = frames * (1 - masks)
|
364 |
+
masked_local_frames = masked_frames[:, :l_t, ...]
|
365 |
+
# get gt optical flow
|
366 |
+
if flows_f[0] == 'None' or flows_b[0] == 'None':
|
367 |
+
gt_flows_bi = self.fix_raft(gt_local_frames)
|
368 |
+
else:
|
369 |
+
gt_flows_bi = (flows_f.to(device), flows_b.to(device))
|
370 |
+
|
371 |
+
# ---- complete flow ----
|
372 |
+
pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
|
373 |
+
pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)
|
374 |
+
# pred_flows_bi = gt_flows_bi
|
375 |
+
|
376 |
+
# ---- image propagation ----
|
377 |
+
prop_imgs, updated_local_masks = self.netG.module.img_propagation(masked_local_frames, pred_flows_bi, local_masks, interpolation=self.interp_mode)
|
378 |
+
updated_masks = masks.clone()
|
379 |
+
updated_masks[:, :l_t, ...] = updated_local_masks.view(b, l_t, 1, h, w)
|
380 |
+
updated_frames = masked_frames.clone()
|
381 |
+
prop_local_frames = gt_local_frames * (1-local_masks) + prop_imgs.view(b, l_t, 3, h, w) * local_masks # merge
|
382 |
+
updated_frames[:, :l_t, ...] = prop_local_frames
|
383 |
+
|
384 |
+
# ---- feature propagation + Transformer ----
|
385 |
+
pred_imgs = self.netG(updated_frames, pred_flows_bi, masks, updated_masks, l_t)
|
386 |
+
pred_imgs = pred_imgs.view(b, -1, c, h, w)
|
387 |
+
|
388 |
+
# get the local frames
|
389 |
+
pred_local_frames = pred_imgs[:, :l_t, ...]
|
390 |
+
comp_local_frames = gt_local_frames * (1. - local_masks) + pred_local_frames * local_masks
|
391 |
+
comp_imgs = frames * (1. - masks) + pred_imgs * masks
|
392 |
+
|
393 |
+
gen_loss = 0
|
394 |
+
dis_loss = 0
|
395 |
+
# optimize net_g
|
396 |
+
if not self.config['model']['no_dis']:
|
397 |
+
for p in self.netD.parameters():
|
398 |
+
p.requires_grad = False
|
399 |
+
|
400 |
+
self.optimG.zero_grad()
|
401 |
+
|
402 |
+
# generator l1 loss
|
403 |
+
hole_loss = self.l1_loss(pred_imgs * masks, frames * masks)
|
404 |
+
hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
|
405 |
+
gen_loss += hole_loss
|
406 |
+
self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())
|
407 |
+
|
408 |
+
valid_loss = self.l1_loss(pred_imgs * (1 - masks), frames * (1 - masks))
|
409 |
+
valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight']
|
410 |
+
gen_loss += valid_loss
|
411 |
+
self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())
|
412 |
+
|
413 |
+
# perceptual loss
|
414 |
+
if self.config['losses']['perceptual_weight'] > 0:
|
415 |
+
perc_loss = self.perc_loss(pred_imgs.view(-1,3,h,w), frames.view(-1,3,h,w))[0] * self.config['losses']['perceptual_weight']
|
416 |
+
gen_loss += perc_loss
|
417 |
+
self.add_summary(self.gen_writer, 'loss/perc_loss', perc_loss.item())
|
418 |
+
|
419 |
+
# gan loss
|
420 |
+
if not self.config['model']['no_dis']:
|
421 |
+
# generator adversarial loss
|
422 |
+
gen_clip = self.netD(comp_imgs)
|
423 |
+
gan_loss = self.adversarial_loss(gen_clip, True, False)
|
424 |
+
gan_loss = gan_loss * self.config['losses']['adversarial_weight']
|
425 |
+
gen_loss += gan_loss
|
426 |
+
self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())
|
427 |
+
gen_loss.backward()
|
428 |
+
self.optimG.step()
|
429 |
+
|
430 |
+
if not self.config['model']['no_dis']:
|
431 |
+
# optimize net_d
|
432 |
+
for p in self.netD.parameters():
|
433 |
+
p.requires_grad = True
|
434 |
+
self.optimD.zero_grad()
|
435 |
+
|
436 |
+
# discriminator adversarial loss
|
437 |
+
real_clip = self.netD(frames)
|
438 |
+
fake_clip = self.netD(comp_imgs.detach())
|
439 |
+
dis_real_loss = self.adversarial_loss(real_clip, True, True)
|
440 |
+
dis_fake_loss = self.adversarial_loss(fake_clip, False, True)
|
441 |
+
dis_loss += (dis_real_loss + dis_fake_loss) / 2
|
442 |
+
self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
|
443 |
+
self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
|
444 |
+
dis_loss.backward()
|
445 |
+
self.optimD.step()
|
446 |
+
|
447 |
+
self.update_learning_rate()
|
448 |
+
|
449 |
+
# write image to tensorboard
|
450 |
+
if self.iteration % 200 == 0:
|
451 |
+
# img to cpu
|
452 |
+
t = 0
|
453 |
+
gt_local_frames_cpu = ((gt_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
|
454 |
+
masked_local_frames = ((masked_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
|
455 |
+
prop_local_frames_cpu = ((prop_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
|
456 |
+
pred_local_frames_cpu = ((pred_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
|
457 |
+
img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t],
|
458 |
+
prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1)
|
459 |
+
img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True)
|
460 |
+
if self.gen_writer is not None:
|
461 |
+
self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration)
|
462 |
+
|
463 |
+
t = 5
|
464 |
+
if masked_local_frames.shape[1] > 5:
|
465 |
+
img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t],
|
466 |
+
prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1)
|
467 |
+
img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True)
|
468 |
+
if self.gen_writer is not None:
|
469 |
+
self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration)
|
470 |
+
|
471 |
+
# flow to cpu
|
472 |
+
gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu()
|
473 |
+
masked_flows_forward_cpu = (gt_flows_forward_cpu[0] * (1-local_masks[0][0].cpu())).to(gt_flows_forward_cpu)
|
474 |
+
pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu()
|
475 |
+
|
476 |
+
flow_results = torch.cat([gt_flows_forward_cpu[0], masked_flows_forward_cpu, pred_flows_forward_cpu[0]], 1)
|
477 |
+
if self.gen_writer is not None:
|
478 |
+
self.gen_writer.add_image('img/flow:gt-pred', flow_results, self.iteration)
|
479 |
+
|
480 |
+
# console logs
|
481 |
+
if self.config['global_rank'] == 0:
|
482 |
+
pbar.update(1)
|
483 |
+
if not self.config['model']['no_dis']:
|
484 |
+
pbar.set_description((f"d: {dis_loss.item():.3f}; "
|
485 |
+
f"hole: {hole_loss.item():.3f}; "
|
486 |
+
f"valid: {valid_loss.item():.3f}"))
|
487 |
+
else:
|
488 |
+
pbar.set_description((f"hole: {hole_loss.item():.3f}; "
|
489 |
+
f"valid: {valid_loss.item():.3f}"))
|
490 |
+
|
491 |
+
if self.iteration % self.train_args['log_freq'] == 0:
|
492 |
+
if not self.config['model']['no_dis']:
|
493 |
+
logging.info(f"[Iter {self.iteration}] "
|
494 |
+
f"d: {dis_loss.item():.4f}; "
|
495 |
+
f"hole: {hole_loss.item():.4f}; "
|
496 |
+
f"valid: {valid_loss.item():.4f}")
|
497 |
+
else:
|
498 |
+
logging.info(f"[Iter {self.iteration}] "
|
499 |
+
f"hole: {hole_loss.item():.4f}; "
|
500 |
+
f"valid: {valid_loss.item():.4f}")
|
501 |
+
|
502 |
+
# saving models
|
503 |
+
if self.iteration % self.train_args['save_freq'] == 0:
|
504 |
+
self.save(int(self.iteration))
|
505 |
+
|
506 |
+
if self.iteration > self.train_args['iterations']:
|
507 |
+
break
|
508 |
+
|
509 |
+
train_data = self.prefetcher.next()
|
core/trainer_flow_w_edge.py
ADDED
@@ -0,0 +1,380 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import logging
|
4 |
+
import importlib
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
13 |
+
|
14 |
+
from torch.utils.tensorboard import SummaryWriter
|
15 |
+
|
16 |
+
from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
|
17 |
+
from core.dataset import TrainDataset
|
18 |
+
|
19 |
+
from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss
|
20 |
+
|
21 |
+
# from skimage.feature import canny
|
22 |
+
from model.canny.canny_filter import Canny
|
23 |
+
from RAFT.utils.flow_viz_pt import flow_to_image
|
24 |
+
|
25 |
+
|
26 |
+
class Trainer:
|
27 |
+
def __init__(self, config):
|
28 |
+
self.config = config
|
29 |
+
self.epoch = 0
|
30 |
+
self.iteration = 0
|
31 |
+
self.num_local_frames = config['train_data_loader']['num_local_frames']
|
32 |
+
self.num_ref_frames = config['train_data_loader']['num_ref_frames']
|
33 |
+
|
34 |
+
# setup data set and data loader
|
35 |
+
self.train_dataset = TrainDataset(config['train_data_loader'])
|
36 |
+
|
37 |
+
self.train_sampler = None
|
38 |
+
self.train_args = config['trainer']
|
39 |
+
if config['distributed']:
|
40 |
+
self.train_sampler = DistributedSampler(
|
41 |
+
self.train_dataset,
|
42 |
+
num_replicas=config['world_size'],
|
43 |
+
rank=config['global_rank'])
|
44 |
+
|
45 |
+
dataloader_args = dict(
|
46 |
+
dataset=self.train_dataset,
|
47 |
+
batch_size=self.train_args['batch_size'] // config['world_size'],
|
48 |
+
shuffle=(self.train_sampler is None),
|
49 |
+
num_workers=self.train_args['num_workers'],
|
50 |
+
sampler=self.train_sampler,
|
51 |
+
drop_last=True)
|
52 |
+
|
53 |
+
self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args)
|
54 |
+
self.prefetcher = CPUPrefetcher(self.train_loader)
|
55 |
+
|
56 |
+
# set raft
|
57 |
+
self.fix_raft = RAFT_bi(device = self.config['device'])
|
58 |
+
self.flow_loss = FlowLoss()
|
59 |
+
self.edge_loss = EdgeLoss()
|
60 |
+
self.canny = Canny(sigma=(2,2), low_threshold=0.1, high_threshold=0.2)
|
61 |
+
|
62 |
+
# setup models including generator and discriminator
|
63 |
+
net = importlib.import_module('model.' + config['model']['net'])
|
64 |
+
self.netG = net.RecurrentFlowCompleteNet()
|
65 |
+
# print(self.netG)
|
66 |
+
self.netG = self.netG.to(self.config['device'])
|
67 |
+
|
68 |
+
# setup optimizers and schedulers
|
69 |
+
self.setup_optimizers()
|
70 |
+
self.setup_schedulers()
|
71 |
+
self.load()
|
72 |
+
|
73 |
+
if config['distributed']:
|
74 |
+
self.netG = DDP(self.netG,
|
75 |
+
device_ids=[self.config['local_rank']],
|
76 |
+
output_device=self.config['local_rank'],
|
77 |
+
broadcast_buffers=True,
|
78 |
+
find_unused_parameters=True)
|
79 |
+
|
80 |
+
# set summary writer
|
81 |
+
self.dis_writer = None
|
82 |
+
self.gen_writer = None
|
83 |
+
self.summary = {}
|
84 |
+
if self.config['global_rank'] == 0 or (not config['distributed']):
|
85 |
+
self.gen_writer = SummaryWriter(
|
86 |
+
os.path.join(config['save_dir'], 'gen'))
|
87 |
+
|
88 |
+
def setup_optimizers(self):
|
89 |
+
"""Set up optimizers."""
|
90 |
+
backbone_params = []
|
91 |
+
for name, param in self.netG.named_parameters():
|
92 |
+
if param.requires_grad:
|
93 |
+
backbone_params.append(param)
|
94 |
+
else:
|
95 |
+
print(f'Params {name} will not be optimized.')
|
96 |
+
|
97 |
+
optim_params = [
|
98 |
+
{
|
99 |
+
'params': backbone_params,
|
100 |
+
'lr': self.config['trainer']['lr']
|
101 |
+
},
|
102 |
+
]
|
103 |
+
|
104 |
+
self.optimG = torch.optim.Adam(optim_params,
|
105 |
+
betas=(self.config['trainer']['beta1'],
|
106 |
+
self.config['trainer']['beta2']))
|
107 |
+
|
108 |
+
|
109 |
+
def setup_schedulers(self):
|
110 |
+
"""Set up schedulers."""
|
111 |
+
scheduler_opt = self.config['trainer']['scheduler']
|
112 |
+
scheduler_type = scheduler_opt.pop('type')
|
113 |
+
|
114 |
+
if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
|
115 |
+
self.scheG = MultiStepRestartLR(
|
116 |
+
self.optimG,
|
117 |
+
milestones=scheduler_opt['milestones'],
|
118 |
+
gamma=scheduler_opt['gamma'])
|
119 |
+
elif scheduler_type == 'CosineAnnealingRestartLR':
|
120 |
+
self.scheG = CosineAnnealingRestartLR(
|
121 |
+
self.optimG,
|
122 |
+
periods=scheduler_opt['periods'],
|
123 |
+
restart_weights=scheduler_opt['restart_weights'])
|
124 |
+
else:
|
125 |
+
raise NotImplementedError(
|
126 |
+
f'Scheduler {scheduler_type} is not implemented yet.')
|
127 |
+
|
128 |
+
def update_learning_rate(self):
|
129 |
+
"""Update learning rate."""
|
130 |
+
self.scheG.step()
|
131 |
+
|
132 |
+
def get_lr(self):
|
133 |
+
"""Get current learning rate."""
|
134 |
+
return self.optimG.param_groups[0]['lr']
|
135 |
+
|
136 |
+
def add_summary(self, writer, name, val):
|
137 |
+
"""Add tensorboard summary."""
|
138 |
+
if name not in self.summary:
|
139 |
+
self.summary[name] = 0
|
140 |
+
self.summary[name] += val
|
141 |
+
n = self.train_args['log_freq']
|
142 |
+
if writer is not None and self.iteration % n == 0:
|
143 |
+
writer.add_scalar(name, self.summary[name] / n, self.iteration)
|
144 |
+
self.summary[name] = 0
|
145 |
+
|
146 |
+
def load(self):
|
147 |
+
"""Load netG."""
|
148 |
+
# get the latest checkpoint
|
149 |
+
model_path = self.config['save_dir']
|
150 |
+
if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
|
151 |
+
latest_epoch = open(os.path.join(model_path, 'latest.ckpt'),
|
152 |
+
'r').read().splitlines()[-1]
|
153 |
+
else:
|
154 |
+
ckpts = [
|
155 |
+
os.path.basename(i).split('.pth')[0]
|
156 |
+
for i in glob.glob(os.path.join(model_path, '*.pth'))
|
157 |
+
]
|
158 |
+
ckpts.sort()
|
159 |
+
latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None
|
160 |
+
|
161 |
+
if latest_epoch is not None:
|
162 |
+
gen_path = os.path.join(model_path, f'gen_{int(latest_epoch):06d}.pth')
|
163 |
+
opt_path = os.path.join(model_path,f'opt_{int(latest_epoch):06d}.pth')
|
164 |
+
|
165 |
+
if self.config['global_rank'] == 0:
|
166 |
+
print(f'Loading model from {gen_path}...')
|
167 |
+
dataG = torch.load(gen_path, map_location=self.config['device'])
|
168 |
+
self.netG.load_state_dict(dataG)
|
169 |
+
|
170 |
+
|
171 |
+
data_opt = torch.load(opt_path, map_location=self.config['device'])
|
172 |
+
self.optimG.load_state_dict(data_opt['optimG'])
|
173 |
+
self.scheG.load_state_dict(data_opt['scheG'])
|
174 |
+
|
175 |
+
self.epoch = data_opt['epoch']
|
176 |
+
self.iteration = data_opt['iteration']
|
177 |
+
|
178 |
+
else:
|
179 |
+
if self.config['global_rank'] == 0:
|
180 |
+
print('Warnning: There is no trained model found.'
|
181 |
+
'An initialized model will be used.')
|
182 |
+
|
183 |
+
def save(self, it):
|
184 |
+
"""Save parameters every eval_epoch"""
|
185 |
+
if self.config['global_rank'] == 0:
|
186 |
+
# configure path
|
187 |
+
gen_path = os.path.join(self.config['save_dir'],
|
188 |
+
f'gen_{it:06d}.pth')
|
189 |
+
opt_path = os.path.join(self.config['save_dir'],
|
190 |
+
f'opt_{it:06d}.pth')
|
191 |
+
print(f'\nsaving model to {gen_path} ...')
|
192 |
+
|
193 |
+
# remove .module for saving
|
194 |
+
if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
|
195 |
+
netG = self.netG.module
|
196 |
+
else:
|
197 |
+
netG = self.netG
|
198 |
+
|
199 |
+
# save checkpoints
|
200 |
+
torch.save(netG.state_dict(), gen_path)
|
201 |
+
torch.save(
|
202 |
+
{
|
203 |
+
'epoch': self.epoch,
|
204 |
+
'iteration': self.iteration,
|
205 |
+
'optimG': self.optimG.state_dict(),
|
206 |
+
'scheG': self.scheG.state_dict()
|
207 |
+
}, opt_path)
|
208 |
+
|
209 |
+
latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt')
|
210 |
+
os.system(f"echo {it:06d} > {latest_path}")
|
211 |
+
|
212 |
+
def train(self):
|
213 |
+
"""training entry"""
|
214 |
+
pbar = range(int(self.train_args['iterations']))
|
215 |
+
if self.config['global_rank'] == 0:
|
216 |
+
pbar = tqdm(pbar,
|
217 |
+
initial=self.iteration,
|
218 |
+
dynamic_ncols=True,
|
219 |
+
smoothing=0.01)
|
220 |
+
|
221 |
+
os.makedirs('logs', exist_ok=True)
|
222 |
+
|
223 |
+
logging.basicConfig(
|
224 |
+
level=logging.INFO,
|
225 |
+
format="%(asctime)s %(filename)s[line:%(lineno)d]"
|
226 |
+
"%(levelname)s %(message)s",
|
227 |
+
datefmt="%a, %d %b %Y %H:%M:%S",
|
228 |
+
filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log",
|
229 |
+
filemode='w')
|
230 |
+
|
231 |
+
while True:
|
232 |
+
self.epoch += 1
|
233 |
+
self.prefetcher.reset()
|
234 |
+
if self.config['distributed']:
|
235 |
+
self.train_sampler.set_epoch(self.epoch)
|
236 |
+
self._train_epoch(pbar)
|
237 |
+
if self.iteration > self.train_args['iterations']:
|
238 |
+
break
|
239 |
+
print('\nEnd training....')
|
240 |
+
|
241 |
+
# def get_edges(self, flows): # fgvc
|
242 |
+
# # (b, t, 2, H, W)
|
243 |
+
# b, t, _, h, w = flows.shape
|
244 |
+
# flows = flows.view(-1, 2, h, w)
|
245 |
+
# flows_list = flows.permute(0, 2, 3, 1).cpu().numpy()
|
246 |
+
# edges = []
|
247 |
+
# for f in list(flows_list):
|
248 |
+
# flows_gray = (f[:, :, 0] ** 2 + f[:, :, 1] ** 2) ** 0.5
|
249 |
+
# if flows_gray.max() < 1:
|
250 |
+
# flows_gray = flows_gray*0
|
251 |
+
# else:
|
252 |
+
# flows_gray = flows_gray / flows_gray.max()
|
253 |
+
|
254 |
+
# edge = canny(flows_gray, sigma=2, low_threshold=0.1, high_threshold=0.2) # fgvc
|
255 |
+
# edge = torch.from_numpy(edge).view(1, 1, h, w).float()
|
256 |
+
# edges.append(edge)
|
257 |
+
# edges = torch.stack(edges, dim=0).to(self.config['device'])
|
258 |
+
# edges = edges.view(b, t, 1, h, w)
|
259 |
+
# return edges
|
260 |
+
|
261 |
+
def get_edges(self, flows):
|
262 |
+
# (b, t, 2, H, W)
|
263 |
+
b, t, _, h, w = flows.shape
|
264 |
+
flows = flows.view(-1, 2, h, w)
|
265 |
+
flows_gray = (flows[:, 0, None] ** 2 + flows[:, 1, None] ** 2) ** 0.5
|
266 |
+
if flows_gray.max() < 1:
|
267 |
+
flows_gray = flows_gray*0
|
268 |
+
else:
|
269 |
+
flows_gray = flows_gray / flows_gray.max()
|
270 |
+
|
271 |
+
magnitude, edges = self.canny(flows_gray.float())
|
272 |
+
edges = edges.view(b, t, 1, h, w)
|
273 |
+
return edges
|
274 |
+
|
275 |
+
def _train_epoch(self, pbar):
|
276 |
+
"""Process input and calculate loss every training epoch"""
|
277 |
+
device = self.config['device']
|
278 |
+
train_data = self.prefetcher.next()
|
279 |
+
while train_data is not None:
|
280 |
+
self.iteration += 1
|
281 |
+
frames, masks, flows_f, flows_b, _ = train_data
|
282 |
+
frames, masks = frames.to(device), masks.to(device)
|
283 |
+
masks = masks.float()
|
284 |
+
|
285 |
+
l_t = self.num_local_frames
|
286 |
+
b, t, c, h, w = frames.size()
|
287 |
+
gt_local_frames = frames[:, :l_t, ...]
|
288 |
+
local_masks = masks[:, :l_t, ...].contiguous()
|
289 |
+
|
290 |
+
# get gt optical flow
|
291 |
+
if flows_f[0] == 'None' or flows_b[0] == 'None':
|
292 |
+
gt_flows_bi = self.fix_raft(gt_local_frames)
|
293 |
+
else:
|
294 |
+
gt_flows_bi = (flows_f.to(device), flows_b.to(device))
|
295 |
+
|
296 |
+
# get gt edge
|
297 |
+
gt_edges_forward = self.get_edges(gt_flows_bi[0])
|
298 |
+
gt_edges_backward = self.get_edges(gt_flows_bi[1])
|
299 |
+
gt_edges_bi = [gt_edges_forward, gt_edges_backward]
|
300 |
+
|
301 |
+
# complete flow
|
302 |
+
pred_flows_bi, pred_edges_bi = self.netG.module.forward_bidirect_flow(gt_flows_bi, local_masks)
|
303 |
+
|
304 |
+
# optimize net_g
|
305 |
+
self.optimG.zero_grad()
|
306 |
+
|
307 |
+
# compulte flow_loss
|
308 |
+
flow_loss, warp_loss = self.flow_loss(pred_flows_bi, gt_flows_bi, local_masks, gt_local_frames)
|
309 |
+
flow_loss = flow_loss * self.config['losses']['flow_weight']
|
310 |
+
warp_loss = warp_loss * 0.01
|
311 |
+
self.add_summary(self.gen_writer, 'loss/flow_loss', flow_loss.item())
|
312 |
+
self.add_summary(self.gen_writer, 'loss/warp_loss', warp_loss.item())
|
313 |
+
|
314 |
+
# compute edge loss
|
315 |
+
edge_loss = self.edge_loss(pred_edges_bi, gt_edges_bi, local_masks)
|
316 |
+
edge_loss = edge_loss*1.0
|
317 |
+
self.add_summary(self.gen_writer, 'loss/edge_loss', edge_loss.item())
|
318 |
+
|
319 |
+
loss = flow_loss + warp_loss + edge_loss
|
320 |
+
loss.backward()
|
321 |
+
self.optimG.step()
|
322 |
+
self.update_learning_rate()
|
323 |
+
|
324 |
+
# write image to tensorboard
|
325 |
+
# if self.iteration % 200 == 0:
|
326 |
+
if self.iteration % 200 == 0 and self.gen_writer is not None:
|
327 |
+
t = 5
|
328 |
+
# forward to cpu
|
329 |
+
gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu()
|
330 |
+
masked_flows_forward_cpu = (gt_flows_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_flows_forward_cpu)
|
331 |
+
pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu()
|
332 |
+
|
333 |
+
flow_results = torch.cat([gt_flows_forward_cpu[t], masked_flows_forward_cpu, pred_flows_forward_cpu[t]], 1)
|
334 |
+
self.gen_writer.add_image('img/flow-f:gt-pred', flow_results, self.iteration)
|
335 |
+
|
336 |
+
# backward to cpu
|
337 |
+
gt_flows_backward_cpu = flow_to_image(gt_flows_bi[1][0]).cpu()
|
338 |
+
masked_flows_backward_cpu = (gt_flows_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_flows_backward_cpu)
|
339 |
+
pred_flows_backward_cpu = flow_to_image(pred_flows_bi[1][0]).cpu()
|
340 |
+
|
341 |
+
flow_results = torch.cat([gt_flows_backward_cpu[t], masked_flows_backward_cpu, pred_flows_backward_cpu[t]], 1)
|
342 |
+
self.gen_writer.add_image('img/flow-b:gt-pred', flow_results, self.iteration)
|
343 |
+
|
344 |
+
# TODO: show edge
|
345 |
+
# forward
|
346 |
+
gt_edges_forward_cpu = gt_edges_bi[0][0].cpu()
|
347 |
+
masked_edges_forward_cpu = (gt_edges_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_edges_forward_cpu)
|
348 |
+
pred_edges_forward_cpu = pred_edges_bi[0][0].cpu()
|
349 |
+
|
350 |
+
edge_results = torch.cat([gt_edges_forward_cpu[t], masked_edges_forward_cpu, pred_edges_forward_cpu[t]], 1)
|
351 |
+
self.gen_writer.add_image('img/edge-f:gt-pred', edge_results, self.iteration)
|
352 |
+
# backward
|
353 |
+
gt_edges_backward_cpu = gt_edges_bi[1][0].cpu()
|
354 |
+
masked_edges_backward_cpu = (gt_edges_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_edges_backward_cpu)
|
355 |
+
pred_edges_backward_cpu = pred_edges_bi[1][0].cpu()
|
356 |
+
|
357 |
+
edge_results = torch.cat([gt_edges_backward_cpu[t], masked_edges_backward_cpu, pred_edges_backward_cpu[t]], 1)
|
358 |
+
self.gen_writer.add_image('img/edge-b:gt-pred', edge_results, self.iteration)
|
359 |
+
|
360 |
+
# console logs
|
361 |
+
if self.config['global_rank'] == 0:
|
362 |
+
pbar.update(1)
|
363 |
+
pbar.set_description((f"flow: {flow_loss.item():.3f}; "
|
364 |
+
f"warp: {warp_loss.item():.3f}; "
|
365 |
+
f"edge: {edge_loss.item():.3f}; "
|
366 |
+
f"lr: {self.get_lr()}"))
|
367 |
+
|
368 |
+
if self.iteration % self.train_args['log_freq'] == 0:
|
369 |
+
logging.info(f"[Iter {self.iteration}] "
|
370 |
+
f"flow: {flow_loss.item():.4f}; "
|
371 |
+
f"warp: {warp_loss.item():.4f}")
|
372 |
+
|
373 |
+
# saving models
|
374 |
+
if self.iteration % self.train_args['save_freq'] == 0:
|
375 |
+
self.save(int(self.iteration))
|
376 |
+
|
377 |
+
if self.iteration > self.train_args['iterations']:
|
378 |
+
break
|
379 |
+
|
380 |
+
train_data = self.prefetcher.next()
|
core/utils.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import io
|
3 |
+
import cv2
|
4 |
+
import random
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image, ImageOps
|
7 |
+
import zipfile
|
8 |
+
import math
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import matplotlib
|
12 |
+
import matplotlib.patches as patches
|
13 |
+
from matplotlib.path import Path
|
14 |
+
from matplotlib import pyplot as plt
|
15 |
+
from torchvision import transforms
|
16 |
+
|
17 |
+
# matplotlib.use('agg')
|
18 |
+
|
19 |
+
# ###########################################################################
|
20 |
+
# Directory IO
|
21 |
+
# ###########################################################################
|
22 |
+
|
23 |
+
|
24 |
+
def read_dirnames_under_root(root_dir):
|
25 |
+
dirnames = [
|
26 |
+
name for i, name in enumerate(sorted(os.listdir(root_dir)))
|
27 |
+
if os.path.isdir(os.path.join(root_dir, name))
|
28 |
+
]
|
29 |
+
print(f'Reading directories under {root_dir}, num: {len(dirnames)}')
|
30 |
+
return dirnames
|
31 |
+
|
32 |
+
|
33 |
+
class TrainZipReader(object):
|
34 |
+
file_dict = dict()
|
35 |
+
|
36 |
+
def __init__(self):
|
37 |
+
super(TrainZipReader, self).__init__()
|
38 |
+
|
39 |
+
@staticmethod
|
40 |
+
def build_file_dict(path):
|
41 |
+
file_dict = TrainZipReader.file_dict
|
42 |
+
if path in file_dict:
|
43 |
+
return file_dict[path]
|
44 |
+
else:
|
45 |
+
file_handle = zipfile.ZipFile(path, 'r')
|
46 |
+
file_dict[path] = file_handle
|
47 |
+
return file_dict[path]
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def imread(path, idx):
|
51 |
+
zfile = TrainZipReader.build_file_dict(path)
|
52 |
+
filelist = zfile.namelist()
|
53 |
+
filelist.sort()
|
54 |
+
data = zfile.read(filelist[idx])
|
55 |
+
#
|
56 |
+
im = Image.open(io.BytesIO(data))
|
57 |
+
return im
|
58 |
+
|
59 |
+
|
60 |
+
class TestZipReader(object):
|
61 |
+
file_dict = dict()
|
62 |
+
|
63 |
+
def __init__(self):
|
64 |
+
super(TestZipReader, self).__init__()
|
65 |
+
|
66 |
+
@staticmethod
|
67 |
+
def build_file_dict(path):
|
68 |
+
file_dict = TestZipReader.file_dict
|
69 |
+
if path in file_dict:
|
70 |
+
return file_dict[path]
|
71 |
+
else:
|
72 |
+
file_handle = zipfile.ZipFile(path, 'r')
|
73 |
+
file_dict[path] = file_handle
|
74 |
+
return file_dict[path]
|
75 |
+
|
76 |
+
@staticmethod
|
77 |
+
def imread(path, idx):
|
78 |
+
zfile = TestZipReader.build_file_dict(path)
|
79 |
+
filelist = zfile.namelist()
|
80 |
+
filelist.sort()
|
81 |
+
data = zfile.read(filelist[idx])
|
82 |
+
file_bytes = np.asarray(bytearray(data), dtype=np.uint8)
|
83 |
+
im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
|
84 |
+
im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
|
85 |
+
# im = Image.open(io.BytesIO(data))
|
86 |
+
return im
|
87 |
+
|
88 |
+
|
89 |
+
# ###########################################################################
|
90 |
+
# Data augmentation
|
91 |
+
# ###########################################################################
|
92 |
+
|
93 |
+
|
94 |
+
def to_tensors():
|
95 |
+
return transforms.Compose([Stack(), ToTorchFormatTensor()])
|
96 |
+
|
97 |
+
|
98 |
+
class GroupRandomHorizontalFlowFlip(object):
|
99 |
+
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
|
100 |
+
"""
|
101 |
+
def __call__(self, img_group, flowF_group, flowB_group):
|
102 |
+
v = random.random()
|
103 |
+
if v < 0.5:
|
104 |
+
ret_img = [
|
105 |
+
img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group
|
106 |
+
]
|
107 |
+
ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group]
|
108 |
+
ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group]
|
109 |
+
return ret_img, ret_flowF, ret_flowB
|
110 |
+
else:
|
111 |
+
return img_group, flowF_group, flowB_group
|
112 |
+
|
113 |
+
|
114 |
+
class GroupRandomHorizontalFlip(object):
|
115 |
+
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
|
116 |
+
"""
|
117 |
+
def __call__(self, img_group, is_flow=False):
|
118 |
+
v = random.random()
|
119 |
+
if v < 0.5:
|
120 |
+
ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
|
121 |
+
if is_flow:
|
122 |
+
for i in range(0, len(ret), 2):
|
123 |
+
# invert flow pixel values when flipping
|
124 |
+
ret[i] = ImageOps.invert(ret[i])
|
125 |
+
return ret
|
126 |
+
else:
|
127 |
+
return img_group
|
128 |
+
|
129 |
+
|
130 |
+
class Stack(object):
|
131 |
+
def __init__(self, roll=False):
|
132 |
+
self.roll = roll
|
133 |
+
|
134 |
+
def __call__(self, img_group):
|
135 |
+
mode = img_group[0].mode
|
136 |
+
if mode == '1':
|
137 |
+
img_group = [img.convert('L') for img in img_group]
|
138 |
+
mode = 'L'
|
139 |
+
if mode == 'L':
|
140 |
+
return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
|
141 |
+
elif mode == 'RGB':
|
142 |
+
if self.roll:
|
143 |
+
return np.stack([np.array(x)[:, :, ::-1] for x in img_group],
|
144 |
+
axis=2)
|
145 |
+
else:
|
146 |
+
return np.stack(img_group, axis=2)
|
147 |
+
else:
|
148 |
+
raise NotImplementedError(f"Image mode {mode}")
|
149 |
+
|
150 |
+
|
151 |
+
class ToTorchFormatTensor(object):
|
152 |
+
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
|
153 |
+
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
|
154 |
+
def __init__(self, div=True):
|
155 |
+
self.div = div
|
156 |
+
|
157 |
+
def __call__(self, pic):
|
158 |
+
if isinstance(pic, np.ndarray):
|
159 |
+
# numpy img: [L, C, H, W]
|
160 |
+
img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
|
161 |
+
else:
|
162 |
+
# handle PIL Image
|
163 |
+
img = torch.ByteTensor(torch.ByteStorage.from_buffer(
|
164 |
+
pic.tobytes()))
|
165 |
+
img = img.view(pic.size[1], pic.size[0], len(pic.mode))
|
166 |
+
# put it from HWC to CHW format
|
167 |
+
# yikes, this transpose takes 80% of the loading time/CPU
|
168 |
+
img = img.transpose(0, 1).transpose(0, 2).contiguous()
|
169 |
+
img = img.float().div(255) if self.div else img.float()
|
170 |
+
return img
|
171 |
+
|
172 |
+
|
173 |
+
# ###########################################################################
|
174 |
+
# Create masks with random shape
|
175 |
+
# ###########################################################################
|
176 |
+
|
177 |
+
|
178 |
+
def create_random_shape_with_random_motion(video_length,
|
179 |
+
imageHeight=240,
|
180 |
+
imageWidth=432):
|
181 |
+
# get a random shape
|
182 |
+
height = random.randint(imageHeight // 3, imageHeight - 1)
|
183 |
+
width = random.randint(imageWidth // 3, imageWidth - 1)
|
184 |
+
edge_num = random.randint(6, 8)
|
185 |
+
ratio = random.randint(6, 8) / 10
|
186 |
+
|
187 |
+
region = get_random_shape(edge_num=edge_num,
|
188 |
+
ratio=ratio,
|
189 |
+
height=height,
|
190 |
+
width=width)
|
191 |
+
region_width, region_height = region.size
|
192 |
+
# get random position
|
193 |
+
x, y = random.randint(0, imageHeight - region_height), random.randint(
|
194 |
+
0, imageWidth - region_width)
|
195 |
+
velocity = get_random_velocity(max_speed=3)
|
196 |
+
m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
|
197 |
+
m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
|
198 |
+
masks = [m.convert('L')]
|
199 |
+
# return fixed masks
|
200 |
+
if random.uniform(0, 1) > 0.5:
|
201 |
+
return masks * video_length
|
202 |
+
# return moving masks
|
203 |
+
for _ in range(video_length - 1):
|
204 |
+
x, y, velocity = random_move_control_points(x,
|
205 |
+
y,
|
206 |
+
imageHeight,
|
207 |
+
imageWidth,
|
208 |
+
velocity,
|
209 |
+
region.size,
|
210 |
+
maxLineAcceleration=(3,
|
211 |
+
0.5),
|
212 |
+
maxInitSpeed=3)
|
213 |
+
m = Image.fromarray(
|
214 |
+
np.zeros((imageHeight, imageWidth)).astype(np.uint8))
|
215 |
+
m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
|
216 |
+
masks.append(m.convert('L'))
|
217 |
+
return masks
|
218 |
+
|
219 |
+
|
220 |
+
def create_random_shape_with_random_motion_zoom_rotation(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432):
|
221 |
+
# get a random shape
|
222 |
+
assert zoomin < 1, "Zoom-in parameter must be smaller than 1"
|
223 |
+
assert zoomout > 1, "Zoom-out parameter must be larger than 1"
|
224 |
+
assert rotmin < rotmax, "Minimum value of rotation must be smaller than maximun value !"
|
225 |
+
height = random.randint(imageHeight//3, imageHeight-1)
|
226 |
+
width = random.randint(imageWidth//3, imageWidth-1)
|
227 |
+
edge_num = random.randint(6, 8)
|
228 |
+
ratio = random.randint(6, 8)/10
|
229 |
+
region = get_random_shape(
|
230 |
+
edge_num=edge_num, ratio=ratio, height=height, width=width)
|
231 |
+
region_width, region_height = region.size
|
232 |
+
# get random position
|
233 |
+
x, y = random.randint(
|
234 |
+
0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
|
235 |
+
velocity = get_random_velocity(max_speed=3)
|
236 |
+
m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
|
237 |
+
m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
|
238 |
+
masks = [m.convert('L')]
|
239 |
+
# return fixed masks
|
240 |
+
if random.uniform(0, 1) > 0.5:
|
241 |
+
return masks*video_length # -> directly copy all the base masks
|
242 |
+
# return moving masks
|
243 |
+
for _ in range(video_length-1):
|
244 |
+
x, y, velocity = random_move_control_points(
|
245 |
+
x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
|
246 |
+
m = Image.fromarray(
|
247 |
+
np.zeros((imageHeight, imageWidth)).astype(np.uint8))
|
248 |
+
### add by kaidong, to simulate zoon-in, zoom-out and rotation
|
249 |
+
extra_transform = random.uniform(0, 1)
|
250 |
+
# zoom in and zoom out
|
251 |
+
if extra_transform > 0.75:
|
252 |
+
resize_coefficient = random.uniform(zoomin, zoomout)
|
253 |
+
region = region.resize((math.ceil(region_width * resize_coefficient), math.ceil(region_height * resize_coefficient)), Image.NEAREST)
|
254 |
+
m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
|
255 |
+
region_width, region_height = region.size
|
256 |
+
# rotation
|
257 |
+
elif extra_transform > 0.5:
|
258 |
+
m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
|
259 |
+
m = m.rotate(random.randint(rotmin, rotmax))
|
260 |
+
# region_width, region_height = region.size
|
261 |
+
### end
|
262 |
+
else:
|
263 |
+
m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
|
264 |
+
masks.append(m.convert('L'))
|
265 |
+
return masks
|
266 |
+
|
267 |
+
|
268 |
+
def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
|
269 |
+
'''
|
270 |
+
There is the initial point and 3 points per cubic bezier curve.
|
271 |
+
Thus, the curve will only pass though n points, which will be the sharp edges.
|
272 |
+
The other 2 modify the shape of the bezier curve.
|
273 |
+
edge_num, Number of possibly sharp edges
|
274 |
+
points_num, number of points in the Path
|
275 |
+
ratio, (0, 1) magnitude of the perturbation from the unit circle,
|
276 |
+
'''
|
277 |
+
points_num = edge_num*3 + 1
|
278 |
+
angles = np.linspace(0, 2*np.pi, points_num)
|
279 |
+
codes = np.full(points_num, Path.CURVE4)
|
280 |
+
codes[0] = Path.MOVETO
|
281 |
+
# Using this instead of Path.CLOSEPOLY avoids an innecessary straight line
|
282 |
+
verts = np.stack((np.cos(angles), np.sin(angles))).T * \
|
283 |
+
(2*ratio*np.random.random(points_num)+1-ratio)[:, None]
|
284 |
+
verts[-1, :] = verts[0, :]
|
285 |
+
path = Path(verts, codes)
|
286 |
+
# draw paths into images
|
287 |
+
fig = plt.figure()
|
288 |
+
ax = fig.add_subplot(111)
|
289 |
+
patch = patches.PathPatch(path, facecolor='black', lw=2)
|
290 |
+
ax.add_patch(patch)
|
291 |
+
ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
|
292 |
+
ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
|
293 |
+
ax.axis('off') # removes the axis to leave only the shape
|
294 |
+
fig.canvas.draw()
|
295 |
+
# convert plt images into numpy images
|
296 |
+
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
297 |
+
data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
|
298 |
+
plt.close(fig)
|
299 |
+
# postprocess
|
300 |
+
data = cv2.resize(data, (width, height))[:, :, 0]
|
301 |
+
data = (1 - np.array(data > 0).astype(np.uint8))*255
|
302 |
+
corrdinates = np.where(data > 0)
|
303 |
+
xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max(
|
304 |
+
corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1])
|
305 |
+
region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
|
306 |
+
return region
|
307 |
+
|
308 |
+
|
309 |
+
def random_accelerate(velocity, maxAcceleration, dist='uniform'):
|
310 |
+
speed, angle = velocity
|
311 |
+
d_speed, d_angle = maxAcceleration
|
312 |
+
if dist == 'uniform':
|
313 |
+
speed += np.random.uniform(-d_speed, d_speed)
|
314 |
+
angle += np.random.uniform(-d_angle, d_angle)
|
315 |
+
elif dist == 'guassian':
|
316 |
+
speed += np.random.normal(0, d_speed / 2)
|
317 |
+
angle += np.random.normal(0, d_angle / 2)
|
318 |
+
else:
|
319 |
+
raise NotImplementedError(
|
320 |
+
f'Distribution type {dist} is not supported.')
|
321 |
+
return (speed, angle)
|
322 |
+
|
323 |
+
|
324 |
+
def get_random_velocity(max_speed=3, dist='uniform'):
|
325 |
+
if dist == 'uniform':
|
326 |
+
speed = np.random.uniform(max_speed)
|
327 |
+
elif dist == 'guassian':
|
328 |
+
speed = np.abs(np.random.normal(0, max_speed / 2))
|
329 |
+
else:
|
330 |
+
raise NotImplementedError(
|
331 |
+
f'Distribution type {dist} is not supported.')
|
332 |
+
angle = np.random.uniform(0, 2 * np.pi)
|
333 |
+
return (speed, angle)
|
334 |
+
|
335 |
+
|
336 |
+
def random_move_control_points(X,
|
337 |
+
Y,
|
338 |
+
imageHeight,
|
339 |
+
imageWidth,
|
340 |
+
lineVelocity,
|
341 |
+
region_size,
|
342 |
+
maxLineAcceleration=(3, 0.5),
|
343 |
+
maxInitSpeed=3):
|
344 |
+
region_width, region_height = region_size
|
345 |
+
speed, angle = lineVelocity
|
346 |
+
X += int(speed * np.cos(angle))
|
347 |
+
Y += int(speed * np.sin(angle))
|
348 |
+
lineVelocity = random_accelerate(lineVelocity,
|
349 |
+
maxLineAcceleration,
|
350 |
+
dist='guassian')
|
351 |
+
if ((X > imageHeight - region_height) or (X < 0)
|
352 |
+
or (Y > imageWidth - region_width) or (Y < 0)):
|
353 |
+
lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
|
354 |
+
new_X = np.clip(X, 0, imageHeight - region_height)
|
355 |
+
new_Y = np.clip(Y, 0, imageWidth - region_width)
|
356 |
+
return new_X, new_Y, lineVelocity
|
357 |
+
|
358 |
+
|
359 |
+
if __name__ == '__main__':
|
360 |
+
|
361 |
+
trials = 10
|
362 |
+
for _ in range(trials):
|
363 |
+
video_length = 10
|
364 |
+
# The returned masks are either stationary (50%) or moving (50%)
|
365 |
+
masks = create_random_shape_with_random_motion(video_length,
|
366 |
+
imageHeight=240,
|
367 |
+
imageWidth=432)
|
368 |
+
|
369 |
+
for m in masks:
|
370 |
+
cv2.imshow('mask', np.array(m))
|
371 |
+
cv2.waitKey(500)
|
datasets/davis/test.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"bear": 82, "blackswan": 50, "bmx-bumps": 90, "bmx-trees": 80, "boat": 75, "breakdance": 84, "breakdance-flare": 71, "bus": 80, "camel": 90, "car-roundabout": 75, "car-shadow": 40, "car-turn": 80, "cows": 104, "dance-jump": 60, "dance-twirl": 90, "dog": 60, "dog-agility": 25, "drift-chicane": 52, "drift-straight": 50, "drift-turn": 64, "elephant": 80, "flamingo": 80, "goat": 90, "hike": 80, "hockey": 75, "horsejump-high": 50, "horsejump-low": 60, "kite-surf": 50, "kite-walk": 80, "libby": 49, "lucia": 70, "mallard-fly": 70, "mallard-water": 80, "motocross-bumps": 60, "motocross-jump": 40, "motorbike": 43, "paragliding": 70, "paragliding-launch": 80, "parkour": 100, "rhino": 90, "rollerblade": 35, "scooter-black": 43, "scooter-gray": 75, "soapbox": 99, "soccerball": 48, "stroller": 91, "surf": 55, "swing": 60, "tennis": 70, "train": 80}
|
datasets/davis/train.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"baseball": 90, "basketball-game": 77, "bears-ball": 78, "bmx-rider": 85, "butterfly": 80, "car-competition": 66, "cat": 52, "chairlift": 99, "circus": 73, "city-ride": 70, "crafting": 45, "curling": 76, "dog-competition": 85, "dolphins-show": 74, "dribbling": 49, "drone-flying": 70, "ducks": 75, "elephant-hyenas": 55, "giraffes": 88, "gym-ball": 69, "helicopter-landing": 77, "horse-race": 80, "horses-kids": 78, "hurdles-race": 55, "ice-hockey": 52, "jet-ski": 83, "juggling-selfie": 78, "kayak-race": 63, "kids-robot": 75, "landing": 35, "luggage": 83, "mantaray": 73, "marbles": 70, "mascot": 78, "mermaid": 78, "monster-trucks": 99, "motorbike-indoors": 79, "motorbike-race": 88, "music-band": 87, "obstacles": 81, "obstacles-race": 48, "peacock": 75, "plane-exhibition": 73, "puppet": 100, "robot-battle": 85, "robotic-arm": 82, "rodeo": 85, "sea-turtle": 90, "skydiving-jumping": 75, "snowboard-race": 75, "snowboard-sand": 55, "surfer": 80, "swimmer": 86, "table-tennis": 70, "tram": 84, "trucks-race": 78, "twist-dance": 83, "volleyball-beach": 73, "water-slide": 88, "weightlifting": 90}
|
datasets/youtube-vos/test.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"0070461469": 91, "00bd64cb00": 180, "00fef116ee": 96, "012257ffcf": 180, "01475d1fe7": 180, "0163b18674": 96, "017fa2adaa": 180, "0232ba85ed": 180, "02b1a46f42": 180, "02caec8ac0": 91, "047436c72c": 96, "0481e165b4": 150, "04f98557e7": 144, "05e73c3ecb": 96, "08f95ce1ff": 144, "0b6db1c6fd": 96, "0bd8c18197": 180, "0c6d13ee2c": 91, "0c7ba00455": 96, "0cba3e52eb": 91, "0d16524447": 150, "0d4827437d": 150, "0d62fa582a": 180, "0e1f91c0d7": 91, "0ef454b3f0": 91, "10e18fcf0c": 96, "11105e147e": 91, "11444b16da": 91, "11a4df37a4": 180, "11b3298d6a": 96, "13006c4c7e": 96, "1345523ba1": 180, "144a16eb12": 180, "15a6536e74": 180, "1616507c9e": 180, "1655f4782a": 92, "16608ccef6": 96, "16bc05b66c": 150, "16f1e1779b": 96, "17caf00e26": 96, "18f1e2f716": 91, "191a0bfcdf": 180, "19d4acf831": 91, "1a1dc21969": 96, "1a72d9fcea": 150, "1a92c81edd": 180, "1b2c2022a3": 96, "1d1601d079": 180, "1db7b25d1c": 180, "1dee5b7b5a": 150, "1e0c2e54f2": 96, "1e458b1539": 92, "1e6ac08c86": 91, "1e790eae99": 56, "1ed0c6ca5b": 96, "1edbdb6d18": 180, "1f2015e056": 96, "215ac56b15": 180, "2233485b49": 96, "224d171af6": 180, "237c6ebaf4": 91, "2462c51412": 96, "24bf968338": 180, "250d5953a0": 150, "25bcf222fb": 180, "25ea8feecf": 150, "25fc493839": 92, "262f69837e": 180, "264ca20298": 180, "26d8d48248": 51, "270f84c5e5": 91, "27889bc0fe": 180, "29b87846e7": 96, "29d2e79171": 180, "2a44411a3d": 180, "2b426fd330": 180, "2c4c4e2d5b": 180, "2c4c718eda": 180, "2c962c1bbe": 180, "2cc841341c": 92, "2cf6c4d17e": 91, "2d7ef0be04": 180, "2e5e52c6c8": 150, "2ef6fce8c6": 144, "3014e769bf": 180, "30d5f163b6": 180, "318df73d6a": 90, "31fbb9df3c": 96, "3255fcad2f": 180, "3303eea8e4": 91, "3447c30052": 150, "362722660c": 180, "37e0b4642b": 91, "383e51ed93": 180, "386b050bd0": 41, "3876ba3136": 180, "388ec2934c": 180, "38b45d9c6b": 96, "396680839c": 150, "39ffa3a4a4": 180, "3b0291b2be": 150, "3b333693f4": 180, "3bde1da2cf": 96, "3c5f4e6672": 91, "3c80682cc6": 92, "3ce634a1c1": 180, "3d6a761295": 96, "3da878c317": 91, "3db571b7ee": 96, "3e2336812c": 180, "3f16b04d6d": 96, "3fbbc75c5e": 180, "4015a1e1cc": 87, "406cd7bd48": 91, "407b87ba26": 91, "40a5628dcc": 91, "41af239f5e": 180, "42c671b285": 180, "42de37f462": 180, "4381c60a2f": 180, "4445dc0af5": 180, "44a3419d24": 180, "4566034eaf": 51, "45877fd086": 180, "4595935b88": 91, "4923010cfe": 96, "49b6d81ee8": 180, "4a39c34139": 180, "4a5a9fde01": 144, "4a90394892": 180, "4af10534e4": 180, "4af307f5bc": 180, "4be0ac97df": 91, "4be9025726": 91, "4c18a7bfab": 91, "4c269afea9": 91, "4c3db058db": 179, "4e1ef26a1e": 96, "50f4c0195b": 150, "50f89963c0": 96, "5105c5e4b8": 180, "51d60e4f93": 46, "51ee638399": 96, "522ea1a892": 180, "528e9f30e7": 91, "532efb206a": 180, "544b1486ac": 91, "5592eb680c": 180, "562fadda3a": 91, "568b30cf93": 150, "575f0e2d8e": 91, "5767fe466c": 150, "581c78d558": 180, "5a0ddcf128": 96, "5adf056317": 144, "5b33c701ce": 180, "5b8f636b33": 150, "5b9d26b1d7": 180, "5c24813a0b": 180, "5d0b35f30f": 46, "5e130392e1": 96, "5e41efe5bc": 180, "5e75de78ae": 91, "5fc34880f7": 180, "60912d6bab": 96, "612c96383d": 180, "61e5fd2205": 144, "620e350d23": 180, "62c27fcaaf": 180, "637c22d967": 91, "63eaebe4a2": 96, "63fd6b311e": 180, "64099f32ab": 180, "65643c4b34": 96, "660a88feb5": 180, "664b8d0c9f": 150, "665a7947b0": 180, "66affc2e86": 180, "673b1c03c9": 96, "67780f49c2": 91, "679a24b7bd": 180, "680d35b75b": 144, "68364a69ef": 180, "683bfaf498": 180, "68e883ff28": 180, "691f63f681": 180, "69f2d3146c": 96, "6c5c018237": 91, "6caa33f43a": 96, "6d2c7cc107": 180, "6d55effbbe": 144, "6d6b09b420": 51, "6d715acc3e": 180, "6e89b7359d": 96, "6e9428d555": 150, "6e9feafa2b": 91, "6eced45fee": 180, "6ef0b3282c": 96, "6f9019f0ea": 91, "6fe0ee9b7c": 180, "6ff74d4995": 180, "712b6ec68e": 96, "71680a627f": 96, "716aad4b56": 180, "721c2cda07": 180, "72218d52ac": 96, "7286b8aac9": 91, "728ba7998d": 91, "73b2b9af5f": 96, "7452941f4f": 180, "759d8249dd": 91, "75a55907dc": 150, "75f3a2a19e": 150, "77e7e4b1a1": 144, "7898e6542c": 180, "78e639c2c4": 91, "79091168f8": 180, "7ad5af3fe6": 180, "7b1a7dec16": 150, "7b36c4c3db": 180, "7b455d07cc": 150, "7bce4cfa48": 180, "7c064444d0": 144, "7c8014406a": 91, "7cb70182e5": 96, "7d04e540f5": 91, "7d5df020bf": 96, "7dfda4322c": 96, "7e6a27cc7c": 96, "7e9e344bf4": 180, "7eb9424a53": 180, "7ec8ea61f4": 91, "7fd2806fb0": 180, "8006501830": 150, "8014aeb412": 180, "80d1d22999": 180, "812f31be15": 144, "81312af68f": 92, "82843a1676": 150, "835aea9584": 36, "8366c67e9b": 180, "8467aa6c5c": 180, "8470ee5f48": 180, "8473ae2c60": 180, "8519765a65": 150, "851f73e4fc": 96, "85621c2c81": 150, "85b045995c": 180, "860c0a7cf8": 92, "861bd4b31e": 180, "8639adb930": 180, "8683e4d414": 150, "8687e892ff": 180, "86c5907811": 180, "870c197c8b": 180, "87de455fb7": 180, "87e1975888": 96, "87f5d4903c": 96, "883ede763d": 150, "88b84fe107": 91, "88ee198ce0": 91, "89d148a39f": 96, "89f3d789c5": 180, "8a22bb6c32": 180, "8a76048654": 180, "8a99d63296": 97, "8b0697f61a": 96, "8b722babfb": 180, "8ba5691030": 180, "8bdd52a66b": 150, "8c427b6a57": 180, "8cb68f36f6": 91, "8cbf0d6194": 180, "8d1ab4a2ed": 91, "8d55a5aebb": 180, "8d8c5906bd": 180, "8eb95e2e56": 150, "8f99788aa7": 180, "8fa5b3778f": 91, "9009ab4811": 91, "90c10e44cf": 91, "90c2c5c336": 96, "9124189275": 91, "91ee8300e7": 144, "9246556dfd": 91, "9323741e3b": 150, "94a33d3d20": 180, "9584210f86": 91, "9637e3b658": 51, "966c4c022e": 180, "9781e083b5": 180, "990d358980": 180, "995c087687": 150, "99a7d42674": 144, "99f056c109": 180, "9a29032b9c": 180, "9b07fc4cf6": 180, "9b5aa49509": 96, "9b5abb8108": 91, "9be210e984": 150, "9c3c28740e": 180, "9cace717c5": 180, "9d3ff7c1c1": 91, "9d8c66d92c": 150, "9eaa2f1fcc": 91, "9f1967f60f": 96, "9fa359e1cb": 150, "9fca469ddd": 96, "9ff11b620a": 180, "9ff655b9a3": 180, "a029b21901": 180, "a0c7eedeb8": 144, "a15e70486b": 180, "a35bef8bbf": 180, "a4309379a2": 91, "a51335af59": 96, "a5690fb3bf": 180, "a5b71f76fb": 86, "a5c8b1f945": 150, "a635426233": 150, "a73cc75b81": 144, "a7863d3903": 180, "a88f1fd4e3": 144, "aa2e90aa98": 144, "aab5ecf878": 91, "aafc5edf08": 96, "ab49400ffe": 180, "acd7b890f6": 91, "ad3ee9b86b": 180, "ad5fda372c": 144, "adb2040e5f": 91, "ae30aed29d": 180, "ae57b941a0": 180, "aeb9de8f66": 41, "af658a277c": 91, "af881cd801": 150, "b016a85236": 180, "b0313efe37": 96, "b19d6e149a": 120, "b19f091836": 180, "b2304e81df": 144, "b2d23dcf3a": 150, "b3cee57f31": 36, "b41a7ebfc6": 180, "b455f801b5": 46, "b47336c07b": 96, "b499ce791f": 180, "b52d26ddf9": 96, "b5c525cb08": 180, "b5d3b9be03": 91, "b6386bc3ce": 96, "b748b0f3be": 180, "b75e9ea782": 180, "b8237af453": 180, "b8a2104720": 96, "b8d6f92a65": 96, "b8f93a4094": 180, "bb0a1708ea": 180, "bb2245ab94": 180, "bb4ae8019f": 180, "bbdc38baa0": 76, "bbfe438d63": 96, "bc2be9fdc8": 96, "bcc00265f4": 96, "bd42cc48e4": 150, "bd43315417": 180, "bd85b04982": 51, "bda3146a46": 96, "be2b40d82a": 150, "c0f856e4de": 96, "c1bfacba4a": 91, "c1dcd30fb2": 96, "c285ede7f3": 180, "c2a6163d39": 150, "c3517ebed5": 86, "c3aabac30c": 180, "c3bb62a2f7": 144, "c454f19e90": 150, "c4c410ccd7": 180, "c5b94822e3": 180, "c64e9d1f7e": 91, "c682d1748f": 150, "c6d04b1ca3": 180, "c6dda81d86": 180, "c71623ab0c": 180, "c7db88a9db": 144, "c80ecb97d6": 150, "c8dd4de705": 180, "c915c8cbba": 150, "cb25a994d8": 144, "cba3e31e88": 91, "cc43a853e2": 180, "cc6c653874": 180, "cc718c7746": 180, "cc7e050f7f": 144, "cd14ed8653": 144, "cd5e4efaad": 46, "cddf78284d": 86, "cde37afe57": 144, "ce358eaf23": 150, "ce45145721": 91, "ce7d4af66d": 180, "ce9fb4bd8e": 91, "cec4db17a0": 180, "cecdd82d3c": 180, "ceea39e735": 180, "cf3e28c92a": 180, "cf8c671dab": 150, "cfd1e8166f": 96, "cfe7d98e50": 150, "cff0bbcba8": 96, "d1219663b7": 180, "d18ea7cd51": 180, "d1ed509b94": 91, "d22c5d5908": 81, "d2c6c7d8f6": 96, "d380084b7c": 91, "d3a2586e34": 180, "d3b1039c67": 180, "d3b25a44b3": 180, "d3f1d615b1": 180, "d7203fdab6": 96, "d76e963754": 96, "d7b3892660": 66, "d8b3e257da": 150, "d8b93e6bb1": 180, "d949468ad6": 180, "da553b619f": 180, "daac20af89": 180, "db8bf2430a": 180, "dbd729449a": 180, "dc0928b157": 91, "dc9aa0b8c0": 180, "dcc0637430": 180, "dcd3e1b53e": 86, "de1854f657": 101, "deb31e46cf": 96, "debccf2743": 150, "decf924833": 150, "e08b241b91": 180, "e0daa3b339": 180, "e1a52251b7": 180, "e1fc6d5237": 91, "e228ce16fd": 96, "e36dbb2ab7": 91, "e3dcf7a45e": 180, "e411e957af": 180, "e412e6a76b": 180, "e45a003b97": 179, "e60826ddf9": 91, "e6295c843b": 96, "e62c23b62b": 150, "e6b7a8fe73": 180, "e6f0e3131c": 180, "e7a3f8884e": 180, "e7c176739c": 180, "e965cd989b": 86, "e989440f7b": 150, "e98d115b9c": 81, "ea5f8c74d6": 180, "ea8a5b5a78": 96, "eaad295e8c": 150, "eaf4947f74": 180, "eb65451f4b": 92, "eb79c39e8e": 180, "eb92c92912": 96, "ebbb88e5f5": 180, "ec9b46eb6c": 180, "eca0be379d": 180, "ed33e8efb7": 66, "eda3a7bbb1": 150, "ee3ff10184": 180, "eec8403cc8": 91, "eee2db8829": 150, "ef22b8a227": 91, "ef8737ca22": 180, "eff7c1c098": 180, "f00dc892b2": 96, "f019c9ff98": 96, "f01edcbffb": 179, "f0866da89c": 180, "f12eb5256e": 180, "f1df2ea2dc": 180, "f29119c644": 180, "f3419f3a62": 150, "f35029f76d": 180, "f39dc2240d": 180, "f3aa63fa74": 150, "f3f3c201bd": 180, "f4865471b4": 96, "f505ae958c": 91, "f7605e73cd": 150, "f7917687d6": 180, "f7d310e219": 180, "f7e25f87b2": 180, "f94cd39525": 91, "f9f9aa431c": 180, "fa666fcc95": 66, "fb10740465": 180, "fb25b14e48": 91, "fb28ec1ba3": 150, "fbdda5ec7b": 96, "fbdf2180ee": 150, "fc0db37221": 91, "fd237cf4fb": 180, "fe36582e18": 180, "fef14bb2f2": 180, "ffe59ed1c1": 150}
|
datasets/youtube-vos/train.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"003234408d": 180, "0043f083b5": 96, "0044fa5fba": 87, "005a527edd": 144, "0065b171f9": 180, "00917dcfc4": 96, "00a23ccf53": 180, "00ad5016a4": 91, "01082ae388": 150, "011ac0a06f": 180, "013099c098": 91, "0155498c85": 180, "01694ad9c8": 91, "017ac35701": 180, "01b80e8e1a": 61, "01baa5a4e1": 150, "01c3111683": 180, "01c4cb5ffe": 180, "01c76f0a82": 96, "01c783268c": 180, "01e64dd36a": 91, "01ed275c6e": 96, "01ff60d1fa": 180, "020cd28cd2": 150, "02264db755": 180, "0248626d9a": 91, "02668dbffa": 150, "0274193026": 144, "02d28375aa": 180, "02f3a5c4df": 46, "031ccc99b1": 91, "0321b18c10": 92, "0348a45bca": 180, "0355e92655": 92, "0358b938c1": 91, "0368107cf1": 96, "0379ddf557": 180, "038b2cc71d": 91, "038c15a5dd": 178, "03a06cc98a": 96, "03a63e187f": 180, "03c95b4dae": 92, "03e2b57b0e": 150, "04194e1248": 180, "04259896e2": 180, "0444918a5f": 96, "04460a7a52": 180, "04474174a4": 180, "0450095513": 150, "045f00aed2": 180, "04667fabaa": 180, "04735c5030": 91, "04990d1915": 92, "04d62d9d98": 96, "04f21da964": 180, "04fbad476e": 180, "04fe256562": 96, "0503bf89c9": 150, "0536c9eed0": 92, "054acb238f": 180, "05579ca250": 150, "056c200404": 96, "05774f3a2c": 180, "058a7592c8": 96, "05a0a513df": 96, "05a569d8aa": 91, "05aa652648": 150, "05d7715782": 96, "05e0b0f28f": 150, "05fdbbdd7a": 66, "05ffcfed85": 180, "0630391881": 150, "06840b2bbe": 91, "068f7dce6f": 180, "0693719753": 150, "06ce2b51fb": 91, "06e224798e": 180, "06ee361788": 91, "06fbb3fa2c": 90, "0700264286": 96, "070c918ca7": 180, "07129e14a4": 180, "07177017e9": 86, "07238ffc58": 180, "07353b2a89": 150, "0738493cbf": 87, "075926c651": 87, "075c701292": 180, "0762ea9a30": 96, "07652ee4af": 150, "076f206928": 96, "077d32af19": 96, "079049275c": 144, "07913cdda7": 92, "07a11a35e8": 180, "07ac33b6df": 150, "07b6e8fda8": 46, "07c62c3d11": 180, "07cc1c7d74": 180, "080196ef01": 180, "081207976e": 96, "081ae4fa44": 150, "081d8250cb": 96, "082900c5d4": 96, "0860df21e2": 180, "0866d4c5e3": 91, "0891ac2eb6": 81, "08931bc458": 180, "08aa2705d5": 180, "08c8450db7": 96, "08d50b926c": 180, "08e1e4de15": 180, "08e48c1a48": 92, "08f561c65e": 180, "08feb87790": 96, "09049f6fe3": 150, "092e4ff450": 180, "09338adea8": 180, "093c335ccc": 144, "0970d28339": 180, "0974a213dc": 96, "097b471ed8": 96, "0990941758": 180, "09a348f4fa": 150, "09a6841288": 96, "09c5bad17b": 96, "09c9ce80c7": 180, "09ff54fef4": 150, "0a23765d15": 91, "0a275e7f12": 96, "0a2f2bd294": 96, "0a7a2514aa": 96, "0a7b27fde9": 180, "0a8c467cc3": 180, "0ac8c560ae": 96, "0b1627e896": 96, "0b285c47f6": 144, "0b34ec1d55": 180, "0b5b5e8e5a": 96, "0b68535614": 180, "0b6f9105fc": 180, "0b7dbfa3cb": 91, "0b9cea51ca": 180, "0b9d012be8": 180, "0bcfc4177d": 96, "0bd37b23c1": 96, "0bd864064c": 158, "0c11c6bf7b": 180, "0c26bc77ac": 180, "0c3a04798c": 96, "0c44a9d545": 180, "0c817cc390": 180, "0ca839ee9a": 180, "0cd7ac0ac0": 150, "0ce06e0121": 180, "0cfe974a89": 180, "0d2fcc0dcd": 96, "0d3aad05d2": 144, "0d40b015f4": 180, "0d97fba242": 91, "0d9cc80d7e": 51, "0dab85b6d3": 144, "0db5c427a5": 96, "0dbaf284f1": 97, "0de4923598": 97, "0df28a9101": 150, "0e04f636c4": 150, "0e05f0e232": 180, "0e0930474b": 91, "0e27472bea": 180, "0e30020549": 144, "0e621feb6c": 180, "0e803c7d73": 91, "0e9ebe4e3c": 92, "0e9f2785ec": 96, "0ea68d418b": 96, "0eb403a222": 96, "0ee92053d6": 97, "0eefca067f": 150, "0f17fa6fcb": 180, "0f1ac8e9a3": 180, "0f202e9852": 91, "0f2ab8b1ff": 180, "0f51a78756": 150, "0f5fbe16b0": 180, "0f6072077b": 91, "0f6b69b2f4": 180, "0f6c2163de": 144, "0f74ec5599": 180, "0f9683715b": 96, "0fa7b59356": 180, "0fb173695b": 96, "0fc958cde2": 150, "0fe7b1a621": 180, "0ffcdb491c": 96, "101caff7d4": 96, "1022fe8417": 96, "1032e80b37": 96, "103f501680": 180, "104e64565f": 96, "104f1ab997": 91, "106242403f": 96, "10b31f5431": 180, "10eced835e": 91, "110d26fa3a": 150, "1122c1d16a": 180, "1145b49a5f": 180, "11485838c2": 96, "114e7676ec": 180, "1157472b95": 180, "115ee1072c": 91, "1171141012": 150, "117757b4b8": 180, "1178932d2f": 180, "117cc76bda": 180, "1180cbf814": 180, "1187bbd0e3": 96, "1197e44b26": 180, "119cf20728": 180, "119dd54871": 180, "11a0c3b724": 91, "11a6ba8c94": 180, "11c722a456": 180, "11cbcb0b4d": 96, "11ccf5e99d": 96, "11ce6f452e": 91, "11e53de6f2": 46, "11feabe596": 150, "120cb9514d": 180, "12156b25b3": 180, "122896672d": 180, "1232b2f1d4": 36, "1233ac8596": 97, "1239c87234": 180, "1250423f7c": 96, "1257a1bc67": 180, "125d1b19dd": 180, "126d203967": 180, "1295e19071": 96, "12ad198c54": 144, "12bddb2bcb": 150, "12ec9b93ee": 180, "12eebedc35": 91, "132852e094": 180, "1329409f2a": 180, "13325cfa14": 96, "1336440745": 180, "134d06dbf9": 97, "135625b53d": 144, "13870016f9": 92, "13960b3c84": 96, "13adaad9d9": 180, "13ae097e20": 180, "13e3070469": 96, "13f6a8c20d": 144, "1416925cf2": 92, "142d2621f5": 91, "145d5d7c03": 180, "145fdc3ac5": 180, "1471274fa7": 76, "14a6b5a139": 180, "14c21cea0d": 180, "14dae0dc93": 96, "14f9bd22b5": 180, "14fd28ae99": 180, "15097d5d4e": 144, "150ea711f2": 180, "1514e3563f": 180, "152aaa3a9e": 180, "152b7d3bd7": 150, "15617297cc": 180, "15abbe0c52": 150, "15d1fb3de5": 180, "15f67b0fab": 180, "161eb59aad": 96, "16288ea47f": 180, "164410ce62": 91, "165c3c8cd4": 96, "165c42b41b": 91, "165ec9e22b": 144, "1669502269": 91, "16763cccbb": 150, "16adde065e": 96, "16af445362": 96, "16afd538ad": 150, "16c3fa4d5d": 96, "16d1d65c27": 180, "16e8599e94": 180, "16fe9fb444": 91, "1705796b02": 96, "1724db7671": 144, "17418e81ea": 180, "175169edbb": 144, "17622326fd": 180, "17656bae77": 91, "17b0d94172": 61, "17c220e4f6": 180, "17c7bcd146": 96, "17cb4afe89": 180, "17cd79a434": 180, "17d18604c3": 96, "17d8ca1a37": 150, "17e33f4330": 180, "17f7a6d805": 150, "180abc8378": 180, "183ba3d652": 96, "185bf64702": 96, "18913cc690": 91, "1892651815": 180, "189ac8208a": 91, "189b44e92c": 97, "18ac264b76": 150, "18b245ab49": 91, "18b5cebc34": 150, "18bad52083": 180, "18bb5144d5": 180, "18c6f205c5": 96, "1903f9ea15": 96, "1917b209f2": 91, "191e74c01d": 150, "19367bb94e": 180, "193ffaa217": 91, "19696b67d3": 96, "197f3ab6f3": 180, "1981e763cc": 180, "198afe39ae": 144, "19a6e62b9b": 150, "19b60d5335": 180, "19c00c11f9": 150, "19e061eb88": 91, "19e8bc6178": 86, "19ee80dac6": 180, "1a25a9170a": 180, "1a359a6c1a": 150, "1a3e87c566": 150, "1a5fe06b00": 91, "1a6c0fbd1e": 144, "1a6f3b5a4b": 96, "1a8afbad92": 92, "1a8bdc5842": 150, "1a95752aca": 150, "1a9c131cb7": 180, "1aa3da3ee3": 150, "1ab27ec7ea": 56, "1abf16d21d": 150, "1acd0f993b": 180, "1ad202e499": 180, "1af8d2395d": 180, "1afd39a1fa": 91, "1b2d31306f": 180, "1b3fa67f0e": 92, "1b43fa74b4": 150, "1b73ea9fc2": 92, "1b7e8bb255": 96, "1b8680f8cd": 180, "1b883843c0": 91, "1b8898785b": 180, "1b88ba1aa4": 180, "1b96a498e5": 150, "1bbc4c274f": 96, "1bd87fe9ab": 66, "1c4090c75b": 180, "1c41934f84": 96, "1c72b04b56": 180, "1c87955a3a": 150, "1c9f9eb792": 180, "1ca240fede": 96, "1ca5673803": 180, "1cada35274": 180, "1cb44b920d": 180, "1cd10e62be": 150, "1d3087d5e5": 180, "1d3685150a": 92, "1d6ff083aa": 96, "1d746352a6": 66, "1da256d146": 91, "1da4e956b1": 180, "1daf812218": 150, "1dba687bce": 180, "1dce57d05d": 86, "1de4a9e537": 97, "1dec5446c8": 180, "1dfbe6f586": 150, "1e1a18c45a": 180, "1e1e42529d": 76, "1e4be70796": 96, "1eb60959c8": 180, "1ec8b2566b": 180, "1ecdc2941c": 180, "1ee0ac70ff": 87, "1ef8e17def": 91, "1f1a2a9fc0": 86, "1f1beb8daa": 150, "1f2609ee13": 180, "1f3876f8d0": 144, "1f4ec0563d": 150, "1f64955634": 96, "1f7d31b5b2": 96, "1f8014b7fd": 96, "1f9c7d10f1": 180, "1fa350df76": 96, "1fc9538993": 180, "1fe2f0ec59": 150, "2000c02f9d": 180, "20142b2f05": 180, "201a8d75e5": 150, "2023b3ee4f": 180, "202b767bbc": 92, "203594a418": 180, "2038987336": 150, "2039c3aecb": 96, "204a90d81f": 150, "207bc6cf01": 144, "208833d1d1": 180, "20c6d8b362": 46, "20e3e52e0a": 96, "2117fa0c14": 180, "211bc5d102": 150, "2120d9c3c3": 150, "2125235a49": 180, "21386f5978": 92, "2142af8795": 150, "215dfc0f73": 96, "217bae91e5": 180, "217c0d44e4": 150, "219057c87b": 150, "21d0edbf81": 96, "21df87ad76": 96, "21f1d089f5": 96, "21f4019116": 180, "222597030f": 91, "222904eb5b": 92, "223a0e0657": 180, "223bd973ab": 92, "22472f7395": 150, "224e7c833e": 96, "225aba51d9": 86, "2261d421ea": 180, "2263a8782b": 180, "2268cb1ffd": 150, "2268e93b0a": 61, "2293c99f3f": 180, "22a1141970": 91, "22b13084b2": 180, "22d9f5ab0c": 180, "22f02efe3a": 144, "232c09b75b": 150, "2350d71b4b": 180, "2376440551": 180, "2383d8aafd": 144, "238b84e67f": 96, "238d4b86f6": 91, "238d947c6b": 46, "23993ce90d": 180, "23b0c8a9ab": 150, "23b3beafcc": 156, "23d80299fe": 92, "23f404a9fc": 96, "240118e58a": 178, "2431dec2fd": 180, "24440e0ac7": 97, "2457274dbc": 180, "2465bf515d": 91, "246b142c4d": 180, "247d729e36": 96, "2481ceafeb": 150, "24866b4e6a": 150, "2489d78320": 180, "24ab0b83e8": 180, "24b0868d92": 180, "24b5207cd9": 96, "24ddf05c03": 92, "250116161c": 71, "256ad2e3fc": 180, "256bd83d5e": 180, "256dcc8ab8": 180, "2589956baa": 150, "258b3b33c6": 91, "25ad437e29": 96, "25ae395636": 180, "25c750c6db": 150, "25d2c3fe5d": 180, "25dc80db7c": 96, "25f97e926f": 180, "26011bc28b": 150, "260846ffbe": 180, "260dd9ad33": 66, "267964ee57": 92, "2680861931": 96, "268ac7d3fc": 180, "26b895d91e": 71, "26bc786d4f": 91, "26ddd2ef12": 180, "26de3d18ca": 150, "26f7784762": 180, "2703e52a6a": 180, "270ed80c12": 180, "2719b742ab": 180, "272f4163d0": 180, "27303333e1": 96, "27659fa7d6": 180, "279214115d": 180, "27a5f92a9c": 97, "27cf2af1f3": 150, "27f0d5f8a2": 86, "28075f33c1": 180, "281629cb41": 96, "282b0d51f5": 96, "282fcab00b": 96, "28449fa0dc": 180, "28475208ca": 96, "285580b7c4": 180, "285b69e223": 150, "288c117201": 150, "28a8eb9623": 180, "28bf9c3cf3": 180, "28c6b8f86a": 180, "28c972dacd": 144, "28d9fa6016": 96, "28e392de91": 144, "28f4a45190": 150, "298c844fc9": 91, "29a0356a2b": 180, "29d779f9e3": 76, "29dde5f12b": 86, "29de7b6579": 150, "29e630bdd0": 144, "29f2332d30": 144, "2a18873352": 92, "2a3824ff31": 91, "2a559dd27f": 96, "2a5c09acbd": 76, "2a63eb1524": 96, "2a6a30a4ea": 150, "2a6d9099d1": 180, "2a821394e3": 81, "2a8c5b1342": 96, "2abc8d66d2": 96, "2ac9ef904a": 46, "2b08f37364": 150, "2b351bfd7d": 180, "2b659a49d7": 66, "2b69ee5c26": 96, "2b6c30bbbd": 180, "2b88561cf2": 144, "2b8b14954e": 180, "2ba621c750": 150, "2bab50f9a7": 180, "2bb00c2434": 91, "2bbde474ef": 92, "2bdd82fb86": 150, "2be06fb855": 96, "2bf545c2f5": 180, "2bffe4cf9a": 96, "2c04b887b7": 144, "2c05209105": 180, "2c0ad8cf39": 180, "2c11fedca8": 56, "2c1a94ebfb": 91, "2c1e8c8e2f": 180, "2c29fabcf1": 96, "2c2c076c01": 180, "2c3ea7ee7d": 92, "2c41fa0648": 87, "2c44bb6d1c": 96, "2c54cfbb78": 180, "2c5537eddf": 180, "2c6e63b7de": 150, "2cb10c6a7e": 180, "2cbcd5ccd1": 180, "2cc5d9c5f6": 180, "2cd01cf915": 180, "2cdbf5f0a7": 91, "2ce660f123": 96, "2cf114677e": 150, "2d01eef98e": 180, "2d03593bdc": 96, "2d183ac8c4": 180, "2d33ad3935": 96, "2d3991d83e": 150, "2d4333577b": 180, "2d4d015c64": 96, "2d8f5e5025": 144, "2d900bdb8e": 180, "2d9a1a1d49": 46, "2db0576a5c": 180, "2dc0838721": 180, "2dcc417f82": 150, "2df005b843": 180, "2df356de14": 180, "2e00393d96": 61, "2e03b8127a": 180, "2e0f886168": 96, "2e2bf37e6d": 180, "2e42410932": 87, "2ea78f46e4": 180, "2ebb017a26": 180, "2ee2edba2a": 96, "2efb07554a": 180, "2f17e4fc1e": 96, "2f2c65c2f3": 144, "2f2d9b33be": 150, "2f309c206b": 180, "2f53822e88": 144, "2f53998171": 96, "2f5b0c89b1": 180, "2f680909e6": 180, "2f710f66bd": 180, "2f724132b9": 91, "2f7e3517ae": 91, "2f96f5fc6f": 180, "2f97d9fecb": 96, "2fbfa431ec": 96, "2fc9520b53": 180, "2fcd9f4c62": 180, "2feb30f208": 87, "2ff7f5744f": 150, "30085a2cc6": 96, "30176e3615": 56, "301f72ee11": 92, "3026bb2f61": 180, "30318465dc": 150, "3054ca937d": 180, "306121e726": 92, "3064ad91e8": 180, "307444a47f": 180, "307bbb7409": 91, "30a20194ab": 144, "30c35c64a4": 150, "30dbdb2cd6": 91, "30fc77d72f": 150, "310021b58b": 96, "3113140ee8": 144, "3150b2ee57": 180, "31539918c4": 180, "318dfe2ce2": 144, "3193da4835": 91, "319f725ad9": 180, "31bbd0d793": 91, "322505c47f": 180, "322b237865": 92, "322da43910": 97, "3245e049fb": 66, "324c4c38f6": 180, "324e35111a": 150, "3252398f09": 150, "327dc4cabf": 180, "328d918c7d": 180, "3290c0de97": 96, "3299ae3116": 180, "32a7cd687b": 150, "33098cedb4": 92, "3332334ac4": 180, "334cb835ac": 180, "3355e056eb": 180, "33639a2847": 180, "3373891cdc": 180, "337975816b": 180, "33e29d7e91": 96, "34046fe4f2": 180, "3424f58959": 180, "34370a710f": 92, "343bc6a65a": 179, "3450382ef7": 144, "3454303a08": 180, "346aacf439": 180, "346e92ff37": 180, "34a5ece7dd": 144, "34b109755a": 180, "34d1b37101": 96, "34dd2c70a7": 180, "34efa703df": 180, "34fbee00a6": 150, "3504df2fda": 96, "35195a56a1": 150, "351c822748": 180, "351cfd6bc5": 180, "3543d8334c": 180, "35573455c7": 96, "35637a827f": 96, "357a710863": 92, "358bf16f9e": 96, "35ab34cc34": 180, "35c6235b8d": 91, "35d01a438a": 180, "3605019d3b": 96, "3609bc3f88": 92, "360e25da17": 97, "36299c687c": 96, "362c5bc56e": 180, "3649228783": 150, "365b0501ea": 92, "365f459863": 180, "369893f3ad": 180, "369c9977e1": 180, "369dde050a": 96, "36c7dac02f": 180, "36d5b1493b": 180, "36f5cc68fd": 91, "3735480d18": 180, "374b479880": 97, "375a49d38f": 180, "375a5c0e09": 180, "376bda9651": 144, "377db65f60": 144, "37c19d1087": 46, "37d4ae24fc": 96, "37ddce7f8b": 180, "37e10d33af": 180, "37e45c6247": 96, "37fa0001e8": 180, "3802d458c0": 150, "382caa3cb4": 91, "383bb93111": 91, "388843df90": 180, "38924f4a7f": 92, "38b00f93d7": 92, "38c197c10e": 96, "38c9c3d801": 180, "38eb2bf67f": 92, "38fe9b3ed1": 180, "390352cced": 180, "390c51b987": 96, "390ca6f1d6": 144, "392bc0f8a1": 96, "392ecb43bd": 92, "3935291688": 150, "3935e63b41": 180, "394454fa9c": 180, "394638fc8b": 96, "39545e20b7": 180, "397abeae8f": 180, "3988074b88": 91, "398f5d5f19": 174, "39bc49a28c": 180, "39befd99fb": 144, "39c3c7bf55": 180, "39d584b09f": 91, "39f6f6ffb1": 180, "3a079fb484": 180, "3a0d3a81b7": 150, "3a1d55d22b": 82, "3a20a7583e": 96, "3a2c1f66e5": 150, "3a33f4d225": 180, "3a3bf84b13": 144, "3a4565e5ec": 144, "3a4e32ed5e": 180, "3a7ad86ce0": 180, "3a7bdde9b8": 180, "3a98867cbe": 91, "3aa3f1c9e8": 150, "3aa7fce8b6": 91, "3aa876887d": 96, "3ab807ded6": 96, "3ab9b1a85a": 96, "3adac8d7da": 180, "3ae1a4016f": 96, "3ae2deaec2": 180, "3ae81609d6": 144, "3af847e62f": 92, "3b23792b84": 144, "3b3b0af2ee": 150, "3b512dad74": 144, "3b6c7988f6": 91, "3b6e983b5b": 180, "3b74a0fc20": 180, "3b7a50b80d": 180, "3b96d3492f": 180, "3b9ad0c5a9": 150, "3b9ba0894a": 180, "3bb4e10ed7": 144, "3bd9a9b515": 150, "3beef45388": 96, "3c019c0a24": 96, "3c090704aa": 96, "3c2784fc0d": 144, "3c47ab95f8": 150, "3c4db32d74": 91, "3c5ff93faf": 180, "3c700f073e": 180, "3c713cbf2f": 91, "3c8320669c": 180, "3c90d225ee": 180, "3cadbcc404": 96, "3cb9be84a5": 150, "3cc37fd487": 91, "3cc6f90cb2": 92, "3cd5e035ef": 180, "3cdf03531b": 178, "3cdf828f59": 180, "3d254b0bca": 180, "3d5aeac5ba": 180, "3d690473e1": 180, "3d69fed2fb": 96, "3d8997aeb6": 96, "3db0d6b07e": 96, "3db1ddb8cf": 180, "3db907ac77": 180, "3dcbc0635b": 150, "3dd48ed55f": 144, "3de4ac4ec4": 92, "3decd63d88": 180, "3e04a6be11": 180, "3e108fb65a": 96, "3e1448b01c": 150, "3e16c19634": 180, "3e2845307e": 61, "3e38336da5": 96, "3e3a819865": 180, "3e3e4be915": 96, "3e680622d7": 91, "3e7d2aeb07": 96, "3e7d8f363d": 180, "3e91f10205": 26, "3ea4c49bbe": 144, "3eb39d11ab": 180, "3ec273c8d5": 96, "3ed3f91271": 76, "3ee062a2fd": 180, "3eede9782c": 180, "3ef2fa99cb": 180, "3efc6e9892": 92, "3f0b0dfddd": 96, "3f0c860359": 91, "3f18728586": 180, "3f3b15f083": 96, "3f45a470ad": 46, "3f4f3bc803": 150, "3fd96c5267": 91, "3fea675fab": 91, "3fee8cbc9f": 96, "3fff16d112": 180, "401888b36c": 144, "4019231330": 150, "402316532d": 180, "402680df52": 180, "404d02e0c0": 150, "40709263a8": 81, "4083cfbe15": 150, "40a96c5cb1": 96, "40b8e50f82": 91, "40f4026bf5": 144, "4100b57a3a": 150, "41059fdd0b": 180, "41124e36de": 144, "4122aba5f9": 180, "413bab0f0d": 96, "4164faee0b": 180, "418035eec9": 180, "4182d51532": 96, "418bb97e10": 144, "41a34c20e7": 96, "41dab05200": 180, "41ff6d5e2a": 77, "420caf0859": 56, "42264230ba": 96, "425a0c96e0": 91, "42da96b87c": 180, "42eb5a5b0f": 180, "42f17cd14d": 91, "42f5c61c49": 180, "42ffdcdee9": 180, "432f9884f9": 91, "43326d9940": 150, "4350f3ab60": 144, "4399ffade3": 96, "43a6c21f37": 150, "43b5555faa": 180, "43d63b752a": 180, "4416bdd6ac": 92, "4444753edd": 76, "444aa274e7": 150, "444d4e0596": 150, "446b8b5f7a": 96, "4478f694bb": 91, "44b1da0d87": 92, "44b4dad8c9": 96, "44b5ece1b9": 180, "44d239b24e": 150, "44eaf8f51e": 180, "44f4f57099": 96, "44f7422af2": 180, "450787ac97": 180, "4523656564": 96, "4536c882e5": 180, "453b65daa4": 180, "454f227427": 91, "45636d806a": 180, "456fb9362e": 91, "457e717a14": 150, "45a89f35e1": 180, "45bf0e947d": 150, "45c36a9eab": 150, "45d9fc1357": 174, "45f8128b97": 180, "4607f6c03c": 91, "46146dfd39": 92, "4620e66b1e": 150, "4625f3f2d3": 96, "462b22f263": 96, "4634736113": 180, "463c0f4fdd": 180, "46565a75f8": 96, "46630b55ae": 56, "466839cb37": 91, "466ba4ae0c": 180, "4680236c9d": 180, "46bf4e8709": 91, "46e18e42f1": 150, "46f5093c59": 180, "47269e0499": 92, "472da1c484": 144, "47354fab09": 180, "4743bb84a7": 92, "474a796272": 180, "4783d2ab87": 96, "479cad5da3": 180, "479f5d7ef6": 96, "47a05fbd1d": 96, "4804ee2767": 97, "4810c3fbca": 180, "482fb439c2": 150, "48375af288": 96, "484ab44de4": 96, "485f3944cd": 96, "4867b84887": 150, "486a8ac57e": 180, "486e69c5bd": 180, "48812cf33e": 150, "4894b3b9ea": 180, "48bd66517d": 180, "48d83b48a4": 91, "49058178b8": 46, "4918d10ff0": 91, "4932911f80": 150, "49405b7900": 180, "49972c2d14": 150, "499bf07002": 96, "49b16e9377": 180, "49c104258e": 144, "49c879f82d": 96, "49e7326789": 180, "49ec3e406a": 91, "49fbf0c98a": 96, "4a0255c865": 180, "4a088fe99a": 96, "4a341402d0": 180, "4a3471bdf5": 96, "4a4b50571c": 144, "4a50f3d2e9": 96, "4a6e3faaa1": 180, "4a7191f08a": 150, "4a86fcfc30": 180, "4a885fa3ef": 144, "4a8af115de": 21, "4aa2e0f865": 180, "4aa9d6527f": 180, "4abb74bb52": 96, "4ae13de1cd": 91, "4af8cb323f": 97, "4b02c272b3": 180, "4b19c529fb": 96, "4b2974eff4": 180, "4b3154c159": 95, "4b54d2587f": 180, "4b556740ff": 144, "4b67aa9ef6": 178, "4b97cc7b8d": 96, "4baa1ed4aa": 91, "4bc8c676bb": 96, "4beaea4dbe": 180, "4bf5763d24": 96, "4bffa92b67": 138, "4c25dfa8ec": 96, "4c397b6fd4": 180, "4c51e75d66": 150, "4c7710908f": 180, "4c9b5017be": 180, "4ca2ffc361": 92, "4cad2e93bc": 150, "4cd427b535": 180, "4cd9a4b1ef": 180, "4cdfe3c2b2": 180, "4cef87b649": 96, "4cf208e9b3": 180, "4cf5bc3e60": 92, "4cfdd73249": 91, "4cff5c9e42": 180, "4d26d41091": 96, "4d5c23c554": 180, "4d67c59727": 150, "4d983cad9f": 180, "4da0d00b55": 144, "4daa179861": 91, "4dadd57153": 92, "4db117e6c5": 91, "4de4ce4dea": 180, "4dfaee19e5": 180, "4dfdd7fab0": 180, "4e3f346aa5": 92, "4e49c2a9c7": 56, "4e4e06a749": 180, "4e70279712": 96, "4e72856cc7": 91, "4e752f8075": 180, "4e7a28907f": 66, "4e824b9247": 180, "4e82b1df57": 180, "4e87a639bc": 180, "4ea77bfd15": 150, "4eb6fc23a2": 180, "4ec9da329e": 96, "4efb9a0720": 180, "4f062fbc63": 96, "4f35be0e0b": 96, "4f37e86797": 91, "4f414dd6e7": 180, "4f424abded": 180, "4f470cc3ae": 144, "4f601d255a": 150, "4f7386a1ab": 144, "4f824d3dcd": 91, "4f827b0751": 144, "4f8db33a13": 180, "4fa160f8a3": 180, "4fa9c30a45": 180, "4facd8f0e8": 96, "4fca07ad01": 91, "4fded94004": 180, "4fdfef4dea": 91, "4feb3ac01f": 92, "4fffec8479": 96, "500c835a86": 180, "50168342bf": 180, "50243cffdc": 180, "5031d5a036": 180, "504dd9c0fd": 96, "50568fbcfb": 180, "5069c7c5b3": 180, "508189ac91": 180, "50b6b3d4b7": 91, "50c6f4fe3e": 86, "50cce40173": 180, "50efbe152f": 180, "50f290b95d": 91, "5104aa1fea": 96, "5110dc72c0": 180, "511e8ecd7f": 150, "513aada14e": 92, "5158d6e985": 180, "5161e1fa57": 180, "51794ddd58": 96, "517d276725": 91, "51a597ee04": 51, "51b37b6d97": 96, "51b5dc30a0": 96, "51e85b347b": 180, "51eea1fdac": 150, "51eef778af": 91, "51f384721c": 76, "521cfadcb4": 180, "52355da42f": 96, "5247d4b160": 180, "524b470fd0": 180, "524cee1534": 96, "5252195e8a": 91, "5255c9ca97": 144, "525928f46f": 96, "526df007a7": 180, "529b12de78": 91, "52c7a3d653": 150, "52c8ec0373": 91, "52d225ed52": 96, "52ee406d9e": 180, "52ff1ccd4a": 96, "53143511e8": 180, "5316d11eb7": 96, "53253f2362": 180, "534a560609": 91, "5352c4a70e": 180, "536096501f": 92, "536b17bcea": 180, "5380eaabff": 144, "5390a43a54": 180, "53af427bb2": 91, "53bf5964ce": 180, "53c30110b5": 96, "53cad8e44a": 150, "53d9c45013": 91, "53e274f1b5": 150, "53e32d21ea": 96, "540850e1c7": 96, "540cb31cfe": 180, "541c4da30f": 91, "541d7935d7": 180, "545468262b": 180, "5458647306": 144, "54657855cd": 96, "547b3fb23b": 180, "5497dc3712": 150, "549c56f1d4": 96, "54a4260bb1": 150, "54b98b8d5e": 180, "54e1054b0f": 91, "54e8867b83": 180, "54ebe34f6e": 180, "5519b4ad13": 86, "551acbffd5": 150, "55341f42da": 180, "5566ab97e1": 91, "556c79bbf2": 144, "5589637cc4": 180, "558aa072f0": 180, "559824b6f6": 91, "55c1764e90": 180, "55eda6c77e": 180, "562d173565": 150, "5665c024cb": 96, "566cef4959": 91, "5675d78833": 144, "5678a91bd8": 180, "567a2b4bd0": 180, "569c282890": 86, "56cc449917": 150, "56e71f3e07": 150, "56f09b9d92": 180, "56fc0e8cf9": 144, "571ca79c71": 91, "57243657cf": 144, "57246af7d1": 91, "57427393e9": 96, "574b682c19": 180, "578f211b86": 180, "5790ac295d": 91, "579393912d": 180, "57a344ab1a": 180, "57bd3bcda4": 180, "57bfb7fa4c": 150, "57c010175e": 180, "57c457cc75": 180, "57c7fc2183": 150, "57d5289a01": 61, "58045fde85": 96, "58163c37cd": 150, "582d463e5c": 180, "5851739c15": 180, "585dd0f208": 66, "587250f3c3": 180, "589e4cc1de": 180, "589f65f5d5": 180, "58a07c17d5": 180, "58adc6d8b6": 76, "58b9bcf656": 96, "58c374917e": 96, "58fc75fd42": 87, "5914c30f05": 96, "59323787d5": 150, "5937b08d69": 96, "594065ddd7": 96, "595a0ceea6": 91, "59623ec40b": 91, "597ff7ef78": 150, "598935ef05": 46, "598c2ad3b2": 180, "59a6459751": 180, "59b175e138": 96, "59bf0a149f": 180, "59d53d1649": 180, "59e3e6fae7": 180, "59fe33e560": 180, "5a13a73fe5": 96, "5a25c22770": 150, "5a4a785006": 96, "5a50640995": 180, "5a75f7a1cf": 96, "5a841e59ad": 180, "5a91c5ab6d": 150, "5ab49d9de0": 96, "5aba1057fe": 180, "5abe46ba6d": 91, "5ac7c88d0c": 180, "5aeb95cc7d": 92, "5af15e4fc3": 91, "5afe381ae4": 96, "5b07b4229d": 51, "5b1001cc4f": 180, "5b1df237d2": 180, "5b263013bf": 91, "5b27d19f0b": 180, "5b48ae16c5": 96, "5b5babc719": 180, "5baaebdf00": 180, "5bab55cdbe": 180, "5bafef6e79": 96, "5bc77844da": 180, "5bd1f84545": 180, "5bddc3ba25": 180, "5bdf7c20d2": 180, "5bf23bc9d3": 180, "5c01f6171a": 144, "5c021681b7": 96, "5c185cff1d": 180, "5c42aba280": 180, "5c44bf8ab6": 180, "5c4c574894": 144, "5c52fa4662": 76, "5c6ea7dac3": 96, "5c74315dc2": 180, "5c7668855e": 92, "5c83e96778": 180, "5ca36173e4": 96, "5cac477371": 97, "5cb0cb1b2f": 96, "5cb0cfb98f": 144, "5cb49a19cf": 180, "5cbf7dc388": 180, "5d0e07d126": 96, "5d1e24b6e3": 81, "5d663000ff": 150, "5da6b2dc5d": 180, "5de9b90f24": 61, "5e08de0ed7": 180, "5e1011df9a": 87, "5e1ce354fd": 150, "5e35512dd7": 180, "5e418b25f9": 96, "5e4849935a": 144, "5e4ee19663": 96, "5e886ef78f": 96, "5e8d00b974": 180, "5e8d59dc31": 180, "5ed838bd5c": 96, "5edda6ee5a": 180, "5ede4d2f7a": 144, "5ede9767da": 144, "5ee23ca60e": 87, "5eec4d9fe5": 96, "5eecf07824": 180, "5eef7ed4f4": 91, "5ef5860ac6": 144, "5ef6573a99": 96, "5f1193e72b": 91, "5f29ced797": 96, "5f32cf521e": 150, "5f51876986": 96, "5f6ebe94a9": 86, "5f6f14977c": 91, "5f808d0d2d": 91, "5fb8aded6a": 180, "5fba90767d": 96, "5fd1c7a3df": 92, "5fd3da9f68": 91, "5fee2570ae": 180, "5ff66140d6": 180, "5ff8b85b53": 180, "600803c0f6": 180, "600be7f53e": 96, "6024888af8": 180, "603189a03c": 96, "6057307f6e": 180, "6061ddbb65": 96, "606c86c455": 180, "60c61cc2e5": 180, "60e51ff1ae": 150, "610e38b751": 150, "61344be2f6": 180, "6135e27185": 96, "614afe7975": 150, "614e571886": 180, "614e7078db": 96, "619812a1a7": 96, "61b481a78b": 96, "61c7172650": 180, "61cf7e40d2": 96, "61d08ef5a1": 46, "61da008958": 96, "61ed178ecb": 61, "61f5d1282c": 92, "61fd977e49": 144, "621584cffe": 180, "625817a927": 180, "625892cf0b": 96, "625b89d28a": 91, "629995af95": 150, "62a0840bb5": 180, "62ad6e121c": 87, "62d6ece152": 91, "62ede7b2da": 91, "62f025e1bc": 180, "6316faaebc": 97, "63281534dc": 150, "634058dda0": 144, "6353f09384": 180, "6363c87314": 180, "636e4872e0": 180, "637681cd6b": 180, "6376d49f31": 180, "6377809ec2": 180, "63936d7de5": 96, "639bddef11": 150, "63d37e9fd3": 180, "63d90c2bae": 96, "63e544a5d6": 180, "63ebbcf874": 96, "63fff40b31": 180, "6406c72e4d": 61, "64148128be": 96, "6419386729": 150, "643092bc41": 96, "644081b88d": 144, "64453cf61d": 180, "644bad9729": 96, "6454f548fd": 180, "645913b63a": 180, "64750b825f": 180, "64a43876b7": 96, "64dd6c83e3": 92, "64e05bf46e": 96, "64f55f1478": 150, "650b0165e4": 180, "651066ed39": 180, "652b67d960": 180, "653821d680": 180, "6538d00d73": 180, "65866dce22": 150, "6589565c8c": 150, "659832db64": 180, "65ab7e1d98": 180, "65b7dda462": 180, "65bd5eb4f5": 180, "65dcf115ab": 91, "65e9825801": 180, "65f9afe51c": 91, "65ff12bcb5": 180, "666b660284": 180, "6671643f31": 180, "668364b372": 96, "66852243cb": 96, "6693a52081": 180, "669b572898": 180, "66e98e78f5": 91, "670f12e88f": 180, "674c12c92d": 91, "675c27208a": 180, "675ed3e1ca": 144, "67741db50a": 96, "678a2357eb": 70, "67b0f4d562": 180, "67cfbff9b1": 180, "67e717d6bd": 91, "67ea169a3b": 92, "67ea809e0e": 180, "681249baa3": 180, "683de643d9": 180, "6846ac20df": 96, "6848e012ef": 96, "684bcd8812": 96, "684dc1c40c": 96, "685a1fa9cf": 91, "686dafaac9": 144, "68807d8601": 96, "6893778c77": 96, "6899d2dabe": 91, "68a2fad4ab": 180, "68cb45fda3": 180, "68cc4a1970": 96, "68dcb40675": 180, "68ea4a8c3d": 180, "68f6e7fbf0": 96, "68fa8300b4": 180, "69023db81f": 96, "6908ccf557": 91, "691a111e7c": 180, "6927723ba5": 180, "692ca0e1a2": 97, "692eb57b63": 180, "69340faa52": 96, "693cbf0c9d": 180, "6942f684ad": 96, "6944fc833b": 180, "69491c0ebf": 91, "695b61a2b0": 96, "6979b4d83f": 180, "697d4fdb02": 144, "69910460a4": 180, "6997636670": 180, "69a436750b": 96, "69aebf7669": 180, "69b8c17047": 180, "69c67f109f": 180, "69e0e7b868": 180, "69ea9c09d1": 180, "69f0af42a6": 97, "6a078cdcc7": 144, "6a37a91708": 71, "6a42176f2e": 180, "6a48e4aea8": 96, "6a5977be3a": 180, "6a5de0535f": 180, "6a80d2e2e5": 96, "6a96c8815d": 180, "6a986084e2": 96, "6aa8e50445": 92, "6ab9dce449": 150, "6abf0ba6b2": 180, "6acc6049d9": 96, "6adb31756c": 180, "6ade215eb0": 96, "6afb7d50e4": 144, "6afd692f1a": 180, "6b0b1044fe": 91, "6b17c67633": 180, "6b1b6ef28b": 92, "6b1e04d00d": 180, "6b2261888d": 96, "6b25d6528a": 144, "6b3a24395c": 150, "6b685eb75b": 96, "6b79be238c": 92, "6b928b7ba6": 96, "6b9c43c25a": 180, "6ba99cc41f": 91, "6bdab62bcd": 86, "6bf2e853b1": 180, "6bf584200f": 180, "6bf95df2b9": 150, "6c0949c51c": 180, "6c11a5f11f": 96, "6c23d89189": 61, "6c4387daf5": 96, "6c4ce479a4": 86, "6c5123e4bc": 96, "6c54265f16": 92, "6c56848429": 96, "6c623fac5f": 36, "6c81b014e9": 96, "6c99ea7c31": 92, "6c9d29d509": 91, "6c9e3b7d1a": 91, "6ca006e283": 96, "6caeb928d6": 180, "6cb2ee722a": 180, "6cbfd32c5e": 180, "6cc791250b": 150, "6cccc985e0": 96, "6d12e30c48": 180, "6d4bf200ad": 180, "6d6d2b8843": 91, "6d6eea5682": 180, "6d7a3d0c21": 96, "6d7efa9b9e": 180, "6da21f5c91": 180, "6da6adabc0": 150, "6dd2827fbb": 96, "6dd36705b9": 131, "6df3637557": 180, "6dfe55e9e5": 150, "6e1a21ba55": 96, "6e2f834767": 180, "6e36e4929a": 96, "6e4f460caf": 96, "6e618d26b6": 56, "6ead4670f7": 180, "6eaff19b9f": 180, "6eb2e1cd9e": 180, "6eb30b3b5a": 96, "6eca26c202": 91, "6ecad29e52": 96, "6ef0b44654": 96, "6efcfe9275": 180, "6f4789045c": 180, "6f49f522ef": 96, "6f67d7c4c4": 180, "6f96e91d81": 144, "6fc6fce380": 180, "6fc9b44c00": 96, "6fce7f3226": 150, "6fdf1ca888": 150, "702fd8b729": 180, "70405185d2": 180, "7053e4f41e": 180, "707bf4ce41": 87, "7082544248": 81, "708535b72a": 96, "7094ac0f60": 180, "70a6b875fa": 180, "70c3e97e41": 180, "7106b020ab": 91, "711dce6fe2": 96, "7136a4453f": 180, "7143fb084f": 180, "714d902095": 150, "7151c53b32": 150, "715357be94": 180, "7163b8085f": 150, "716df1aa59": 150, "71caded286": 150, "71d2665f35": 91, "71d67b9e19": 96, "71e06dda39": 180, "720b398b9c": 91, "720e3fa04c": 150, "720e7a5f1e": 91, "721bb6f2cb": 91, "722803f4f2": 92, "72552a07c9": 91, "726243a205": 96, "72690ef572": 46, "728cda9b65": 86, "728e81c319": 91, "72a810a799": 180, "72acb8cdf6": 180, "72b01281f9": 180, "72cac683e4": 91, "72cadebbce": 180, "72cae058a5": 180, "72d8dba870": 180, "72e8d1c1ff": 96, "72edc08285": 180, "72f04f1a38": 81, "731b825695": 144, "7320b49b13": 180, "732626383b": 87, "732df1eb05": 150, "73329902ab": 150, "733798921e": 150, "733824d431": 150, "734ea0d7fb": 91, "735a7cf7b9": 144, "7367a42892": 91, "7368d5c053": 180, "738e5a0a14": 180, "73c6ae7711": 96, "73e1852735": 150, "73e4e5cc74": 150, "73eac9156b": 180, "73f8441a88": 91, "7419e2ab3f": 91, "74267f68b9": 91, "7435690c8c": 46, "747c44785c": 81, "747f1b1f2f": 144, "748b2d5c01": 96, "74d4cee0a4": 91, "74ec2b3073": 91, "74ef677020": 96, "750be4c4d8": 96, "75172d4ac8": 96, "75285a7eb1": 180, "75504539c3": 91, "7550949b1d": 96, "7551cbd537": 150, "75595b453d": 91, "7559b4b0ec": 91, "755bd1fbeb": 96, "756f76f74d": 180, "7570ca7f3c": 180, "757a69746e": 180, "757cac96c6": 180, "7584129dc3": 144, "75a058dbcd": 91, "75b09ce005": 96, "75cae39a8f": 180, "75cee6caf0": 180, "75cf58fb2c": 91, "75d5c2f32a": 180, "75eaf5669d": 96, "75f7937438": 180, "75f99bd3b3": 96, "75fa586876": 92, "7613df1f84": 150, "762e1b3487": 96, "76379a3e69": 180, "764271f0f3": 92, "764503c499": 86, "7660005554": 46, "7666351b84": 96, "76693db153": 51, "767856368b": 92, "768671f652": 180, "768802b80d": 180, "76962c7ed2": 71, "76a75f4eee": 150, "76b90809f7": 180, "770a441457": 96, "772a0fa402": 180, "772f2ffc3e": 91, "774f6c2175": 180, "77610860e0": 56, "777e58ff3d": 96, "77920f1708": 150, "7799df28e7": 180, "779e847a9a": 81, "77ba4edc72": 96, "77c834dc43": 41, "77d8aa8691": 180, "77e7f38f4d": 144, "77eea6845e": 96, "7806308f33": 91, "78254660ea": 91, "7828af8bff": 180, "784398620a": 71, "784d201b12": 96, "78613981ed": 180, "78896c6baf": 92, "78aff3ebc0": 150, "78c7c03716": 91, "78d3676361": 91, "78e29dd4c3": 150, "78f1a1a54f": 91, "79208585cd": 180, "792218456c": 180, "7923bad550": 150, "794e6fc49f": 96, "796e6762ce": 180, "797cd21f71": 150, "79921b21c2": 150, "79a5778027": 180, "79bc006280": 180, "79bf95e624": 91, "79d9e00c55": 91, "79e20fc008": 96, "79e9db913e": 180, "79f014085e": 91, "79fcbb433a": 150, "7a13a5dfaa": 180, "7a14bc9a36": 96, "7a3c535f70": 96, "7a446a51e9": 91, "7a56e759c5": 91, "7a5f46198d": 86, "7a626ec98d": 92, "7a802264c4": 180, "7a8b5456ca": 180, "7abdff3086": 150, "7aecf9f7ac": 150, "7b0fd09c28": 96, "7b18b3db87": 180, "7b39fe7371": 144, "7b49e03d4c": 180, "7b5388c9f1": 180, "7b5cf7837f": 180, "7b733d31d8": 180, "7b74fd7b98": 180, "7b918ccb8a": 150, "7ba3ce3485": 96, "7bb0abc031": 180, "7bb5bb25cd": 180, "7bb7dac673": 92, "7bc7761b8c": 180, "7bf3820566": 96, "7c03a18ec1": 96, "7c078f211b": 150, "7c37d7991a": 71, "7c4ec17eff": 144, "7c649c2aaf": 180, "7c73340ab7": 91, "7c78a2266d": 180, "7c88ce3c5b": 180, "7ca6843a72": 180, "7cc9258dee": 96, "7cec7296ae": 46, "7d0ffa68a4": 96, "7d11b4450f": 81, "7d1333fcbe": 96, "7d18074fef": 91, "7d18c8c716": 96, "7d508fb027": 180, "7d55f791f0": 180, "7d74e3c2f6": 150, "7d783f67a9": 96, "7d83a5d854": 150, "7dd409947e": 180, "7de45f75e5": 150, "7e0cd25696": 150, "7e1922575c": 96, "7e1e3bbcc1": 180, "7e24023274": 180, "7e2f212fd3": 96, "7e6d1cc1f4": 180, "7e7cdcb284": 144, "7e9b6bef69": 66, "7ea5b49283": 92, "7eb2605d96": 91, "7eb26b8485": 180, "7ecd1f0c69": 96, "7f02b3cfe2": 180, "7f1723f0d5": 97, "7f21063c3a": 81, "7f3658460e": 91, "7f54132e48": 144, "7f559f9d4a": 144, "7f5faedf8b": 96, "7f838baf2b": 180, "7fa5f527e3": 96, "7ff84d66dd": 150, "802b45c8c4": 180, "804382b1ad": 180, "804c558adb": 96, "804f6338a4": 180, "8056117b89": 150, "806b6223ab": 96, "8088bda461": 46, "80b790703b": 180, "80c4a94706": 96, "80ce2e351b": 180, "80db581acd": 96, "80e12193df": 150, "80e41b608f": 180, "80f16b016d": 91, "81541b3725": 91, "8175486e6a": 96, "8179095000": 180, "8193671178": 180, "81a58d2c6b": 150, "81aa1286fb": 96, "81dffd30fb": 96, "8200245704": 41, "823e7a86e8": 46, "824973babb": 144, "824ca5538f": 180, "827171a845": 180, "8273a03530": 180, "827cf4f886": 91, "82b865c7dd": 180, "82c1517708": 91, "82d15514d6": 150, "82e117b900": 179, "82fec06574": 150, "832b5ef379": 97, "83424c9fbf": 180, "8345358fb8": 71, "834b50b31b": 180, "835e3b67d7": 97, "836ea92b15": 90, "837c618777": 144, "838eb3bd89": 180, "839381063f": 91, "839bc71489": 180, "83a8151377": 180, "83ae88d217": 180, "83ca8bcad0": 180, "83ce590d7f": 180, "83d3130ba0": 36, "83d40bcba5": 86, "83daba503a": 144, "83de906ec0": 180, "84044f37f3": 180, "84696b5a5e": 96, "84752191a3": 91, "847eeeb2e0": 180, "848e7835a0": 96, "84a4b29286": 180, "84a4bf147d": 66, "84be115c09": 144, "84d95c4350": 180, "84e0922cf7": 150, "84f0cfc665": 96, "8515f6db22": 180, "851f2f32c1": 91, "852a4d6067": 150, "854c48b02a": 96, "857a387c86": 180, "859633d56a": 96, "85a4f4a639": 144, "85ab85510c": 180, "85b1eda0d9": 92, "85dc1041c6": 96, "85e081f3c7": 150, "85f75187ad": 96, "8604bb2b75": 96, "860745b042": 150, "863b4049d7": 180, "8643de22d0": 180, "8647d06439": 46, "864ffce4fe": 180, "8662d9441a": 180, "8666521b13": 76, "868d6a0685": 91, "869fa45998": 91, "86a40b655d": 150, "86a8ae4223": 92, "86b2180703": 180, "86c85d27df": 180, "86d3755680": 180, "86e61829a1": 180, "871015806c": 91, "871e409c5c": 180, "8744b861ce": 96, "8749369ba0": 180, "878a299541": 144, "8792c193a0": 96, "8799ab0118": 96, "87d1f7d741": 180, "882b9e4500": 180, "885673ea17": 180, "8859dedf41": 96, "8873ab2806": 91, "887a93b198": 180, "8883e991a9": 86, "8891aa6dfa": 91, "8899d8cbcd": 91, "88b8274d67": 180, "88d3b80af6": 91, "88ede83da2": 180, "88f345941b": 180, "890976d6da": 91, "8909bde9ab": 91, "8929c7d5d9": 180, "89363acf76": 150, "89379487e0": 96, "8939db6354": 180, "893f658345": 144, "8953138465": 180, "895c96d671": 180, "895cbf96f9": 180, "895e8b29a7": 91, "898fa256c8": 180, "89986c60be": 180, "89b874547b": 180, "89bdb021d5": 144, "89c802ff9c": 96, "89d6336c2b": 180, "89ebb27334": 91, "8a27e2407c": 96, "8a31f7bca5": 96, "8a4a2fc105": 96, "8a5d6c619c": 96, "8a75ad7924": 180, "8aa817e4ed": 87, "8aad0591eb": 180, "8aca214360": 180, "8ae168c71b": 96, "8b0cfbab97": 21, "8b3645d826": 96, "8b3805dbd4": 180, "8b473f0f5d": 180, "8b4f6d1186": 180, "8b4fb018b7": 66, "8b518ee936": 92, "8b523bdfd6": 150, "8b52fb5fba": 91, "8b91036e5c": 144, "8b99a77ac5": 180, "8ba04b1e7b": 96, "8ba782192f": 180, "8bbeaad78b": 96, "8bd1b45776": 180, "8bd7a2dda6": 150, "8bdb091ccf": 180, "8be56f165d": 96, "8be950d00f": 96, "8bf84e7d45": 180, "8bffc4374b": 66, "8bfff50747": 180, "8c09867481": 144, "8c0a3251c3": 180, "8c3015cccb": 180, "8c469815cf": 96, "8c9ccfedc7": 91, "8ca1af9f3c": 150, "8ca3f6e6c1": 96, "8ca6a4f60f": 96, "8cac6900fe": 96, "8cba221a1e": 180, "8cbbe62ccd": 180, "8d064b29e2": 92, "8d167e7c08": 91, "8d4ab94e1c": 96, "8d81f6f899": 180, "8d87897d66": 91, "8dcccd2bd2": 180, "8dcfb878a8": 150, "8dd3ab71b9": 91, "8dda6bf10f": 96, "8ddd51ca94": 180, "8dea22c533": 180, "8def5bd3bf": 96, "8e1848197c": 91, "8e3a83cf2d": 91, "8e478e73f3": 91, "8e98ae3c84": 96, "8ea6687ab0": 180, "8eb0d315c1": 91, "8ec10891f9": 150, "8ec3065ec2": 180, "8ecf51a971": 150, "8eddbab9f7": 91, "8ee198467a": 180, "8ee2368f40": 180, "8ef595ce82": 150, "8f0a653ad7": 150, "8f1204a732": 150, "8f1600f7f6": 91, "8f16366707": 96, "8f1ce0a411": 92, "8f2e05e814": 91, "8f320d0e09": 96, "8f3b4a84ad": 91, "8f3fdad3da": 96, "8f5d3622d8": 96, "8f62a2c633": 180, "8f81c9405a": 97, "8f8c974d53": 120, "8f918598b6": 96, "8ff61619f6": 96, "9002761b41": 96, "90107941f3": 92, "90118a42ee": 96, "902bc16b37": 91, "903e87e0d6": 144, "9041a0f489": 96, "9047bf3222": 51, "9057bfa502": 150, "90617b0954": 92, "9076f4b6db": 180, "9077e69b08": 144, "909655b4a6": 96, "909c2eca88": 180, "909dbd1b76": 180, "90bc4a319a": 180, "90c7a87887": 96, "90cc785ddd": 96, "90d300f09b": 180, "9101ea9b1b": 96, "9108130458": 150, "911ac9979b": 150, "9151cad9b5": 97, "9153762797": 180, "91634ee0c9": 91, "916942666f": 76, "9198cfb4ea": 180, "919ac864d6": 180, "91b67d58d4": 180, "91bb8df281": 150, "91be106477": 91, "91c33b4290": 180, "91ca7dd9f3": 144, "91d095f869": 180, "91f107082e": 180, "920329dd5e": 180, "920c959958": 150, "92128fbf4b": 144, "9223dacb40": 150, "923137bb7f": 61, "9268e1f88a": 180, "927647fe08": 150, "9276f5ba47": 150, "92a28cd233": 71, "92b5c1fc6d": 144, "92c46be756": 180, "92dabbe3a0": 96, "92e3159361": 180, "92ebab216a": 180, "934bdc2893": 180, "9359174efc": 180, "935d97dd2f": 91, "935feaba1b": 96, "93901858ee": 150, "939378f6d6": 91, "939bdf742e": 96, "93a22bee7e": 96, "93da9aeddf": 91, "93e2feacce": 180, "93e6f1fdf9": 96, "93e811e393": 180, "93e85d8fd3": 180, "93f623d716": 180, "93ff35e801": 46, "94031f12f2": 96, "94091a4873": 180, "94125907e3": 87, "9418653742": 91, "941c870569": 101, "94209c86f0": 180, "9437c715eb": 76, "9445c3eca2": 91, "9467c8617c": 96, "946d71fb5d": 96, "948f3ae6fb": 180, "9498baa359": 96, "94a33abeab": 91, "94bf1af5e3": 144, "94cf3a8025": 96, "94db712ac8": 180, "94e4b66cff": 92, "94e76cbaf6": 180, "950be91db1": 180, "952058e2d0": 92, "952633c37f": 96, "952ec313fe": 87, "9533fc037c": 96, "9574b81269": 92, "9579b73761": 180, "957f7bc48b": 180, "958073d2b0": 150, "9582e0eb33": 71, "9584092d0b": 91, "95b58b8004": 150, "95bd88da55": 180, "95f74a9959": 180, "962781c601": 180, "962f045bf5": 91, "964ad23b44": 91, "967b90590e": 144, "967bffe201": 86, "96825c4714": 81, "968492136a": 96, "9684ef9d64": 86, "968c41829e": 91, "96a856ef9a": 180, "96dfc49961": 180, "96e1a5b4f8": 180, "96e6ff0917": 150, "96fb88e9d7": 96, "96fbe5fc23": 150, "96fc924050": 96, "9715cc83dc": 180, "9720eff40f": 180, "972c187c0d": 180, "97476eb38d": 180, "97659ed431": 180, "9773492949": 96, "97756b264f": 96, "977bff0d10": 96, "97ab569ff3": 96, "97ba838008": 180, "97d9d008c7": 150, "97e59f09fa": 96, "97eb642e56": 96, "98043e2d14": 96, "981ff580cf": 180, "983e66cbfc": 96, "984f0f1c36": 180, "98595f2bb4": 91, "985c3be474": 91, "9869a12362": 180, "986b5a5e18": 180, "9877af5063": 180, "98911292da": 180, "9893a3cf77": 97, "9893d9202d": 91, "98a8b06e7f": 91, "98ac6f93d9": 150, "98b6974d12": 96, "98ba3c9417": 180, "98c7c00a19": 96, "98d044f206": 96, "98e909f9d1": 150, "98fe7f0410": 150, "990f2742c7": 96, "992bd0779a": 180, "994b9b47ba": 150, "9955b76bf5": 91, "9966f3adac": 46, "997117a654": 180, "999d53d841": 150, "99c04108d3": 180, "99c4277aee": 96, "99c6b1acf2": 96, "99dc8bb20b": 180, "99fcba71e5": 150, "99fecd4efb": 92, "9a02c70ba2": 96, "9a08e7a6f8": 180, "9a2f2c0f86": 81, "9a3254a76e": 92, "9a3570a020": 180, "9a39112493": 180, "9a4e9fd399": 180, "9a50af4bfb": 180, "9a68631d24": 150, "9a72318dbf": 92, "9a767493b7": 180, "9a7fc1548b": 96, "9a84ccf6a7": 150, "9a9c0e15b7": 96, "9adf06d89b": 150, "9b22b54ee4": 91, "9b473fc8fe": 96, "9b4f081782": 180, "9b997664ba": 180, "9bc454e109": 180, "9bccfd04de": 96, "9bce4583a2": 96, "9bebf1b87f": 158, "9bfc50d261": 180, "9c166c86ff": 96, "9c293ef4d7": 144, "9c29c047b0": 91, "9c3bc2e2a7": 96, "9c3ce23bd1": 91, "9c404cac0c": 180, "9c5180d23a": 144, "9c7feca6e4": 144, "9caa49d3ff": 180, "9cb2f1b646": 180, "9ce6f765c3": 91, "9cfee34031": 180, "9d01f08ec6": 180, "9d04c280b8": 91, "9d12ceaddc": 180, "9d15f8cb3c": 180, "9d2101e9bf": 180, "9d407c3aeb": 96, "9ddefc6165": 180, "9df0b1e298": 96, "9e16f115d8": 144, "9e249b4982": 96, "9e29b1982c": 92, "9e493e4773": 180, "9e4c752cd0": 91, "9e4de40671": 96, "9e6319faeb": 96, "9e6ddbb52d": 91, "9eadcea74f": 180, "9ecec5f8ea": 46, "9efb47b595": 96, "9f30bfe61e": 72, "9f3734c3a4": 180, "9f5b858101": 180, "9f66640cda": 180, "9f913803e9": 180, "9f97bc74c8": 180, "9fbad86e20": 180, "9fc2bad316": 180, "9fc5c3af78": 150, "9fcb310255": 92, "9fcc256871": 91, "9fd2fd4d47": 180, "a0071ae316": 96, "a023141022": 56, "a046399a74": 96, "a066e739c1": 150, "a06722ba82": 96, "a07a15dd64": 180, "a07b47f694": 180, "a09c39472e": 144, "a0b208fe2e": 91, "a0b61c959e": 96, "a0bc6c611d": 180, "a0e6da5ba2": 91, "a1193d6490": 96, "a14ef483ff": 91, "a14f709908": 180, "a15ccc5658": 96, "a16062456f": 180, "a174e8d989": 91, "a177c2733c": 150, "a17c62e764": 92, "a18ad065fc": 150, "a1aaf63216": 96, "a1bb65fb91": 150, "a1bd8e5349": 91, "a1dfdd0cac": 180, "a2052e4f6c": 96, "a20fd34693": 96, "a21ffe4d81": 150, "a22349e647": 180, "a235d01ec1": 180, "a24f63e8a2": 180, "a2554c9f6d": 46, "a263ce8a87": 180, "a29bfc29ec": 91, "a2a80072d4": 150, "a2a800ab63": 180, "a2bcd10a33": 180, "a2bdaff3b0": 91, "a2c146ab0d": 91, "a2c996e429": 96, "a2dc51ebe8": 180, "a2e6608bfa": 180, "a2f2a55f01": 96, "a301869dea": 180, "a31fccd2cc": 180, "a34f440f33": 180, "a35e0206da": 180, "a36bdc4cab": 180, "a36e8c79d8": 71, "a378053b20": 144, "a37db3a2b3": 91, "a38950ebc2": 180, "a39a0eb433": 91, "a39c9bca52": 180, "a3a945dc8c": 91, "a3b40a0c1e": 150, "a3b8588550": 91, "a3c502bec3": 180, "a3f2878017": 180, "a3f4d58010": 180, "a3f51855c3": 150, "a402dc0dfe": 21, "a4065a7eda": 180, "a412bb2fef": 180, "a416b56b53": 96, "a41ec95906": 91, "a43299e362": 180, "a4757bd7af": 96, "a48c53c454": 180, "a49dcf9ad5": 150, "a4a506521f": 180, "a4ba7753d9": 180, "a4bac06849": 91, "a4f05d681c": 91, "a50c10060f": 150, "a50eb5a0ea": 150, "a5122c6ec6": 150, "a522b1aa79": 96, "a590915345": 180, "a5b5b59139": 96, "a5b77abe43": 180, "a5c2b2c3e1": 96, "a5cd17bb11": 180, "a5da03aef1": 180, "a5dd11de0d": 150, "a5ea2b93b6": 150, "a5eaeac80b": 180, "a5ec5b0265": 144, "a5f350a87e": 180, "a5f472caf4": 96, "a6027a53cf": 180, "a61715bb1b": 180, "a61cf4389d": 150, "a61d9bbd9b": 180, "a6470dbbf5": 150, "a64a40f3eb": 76, "a653d5c23b": 180, "a65bd23cb5": 150, "a66e0b7ad4": 180, "a66fc5053c": 91, "a68259572b": 180, "a6a810a92c": 150, "a6bc36937f": 91, "a6c3a374e9": 180, "a6d8a4228d": 180, "a6f4e0817f": 180, "a71e0481f5": 96, "a7203deb2d": 150, "a7392d4438": 150, "a73d3c3902": 180, "a7491f1578": 150, "a74b9ca19c": 180, "a77b7a91df": 150, "a78195a5f5": 150, "a78758d4ce": 180, "a7e6d6c29a": 96, "a800d85e88": 51, "a832fa8790": 180, "a83d06410d": 150, "a8999af004": 180, "a8f78125b9": 180, "a907b18df1": 150, "a919392446": 150, "a965504e88": 96, "a96b84b8d2": 96, "a973f239cd": 91, "a977126596": 180, "a9804f2a08": 91, "a984e56893": 96, "a99738f24c": 91, "a99bdd0079": 144, "a9c9c1517e": 178, "a9cbf9c41b": 150, "a9e42e3c0c": 150, "aa07b7c1c0": 180, "aa175e5ec7": 96, "aa1a338630": 96, "aa27d7b868": 96, "aa45f1caaf": 91, "aa49e46432": 96, "aa51934e1b": 180, "aa6287bb6c": 96, "aa6d999971": 180, "aa85278334": 96, "aab33f0e2a": 180, "aaba004362": 180, "aade4cf385": 180, "aae78feda4": 91, "aaed233bf3": 180, "aaff16c2db": 96, "ab199e8dfb": 96, "ab23b78715": 96, "ab2e1b5577": 180, "ab33a18ded": 96, "ab45078265": 180, "ab56201494": 180, "ab90f0d24b": 180, "abab2e6c20": 180, "abb50c8697": 92, "abbe2d15a0": 180, "abbe73cd21": 150, "abe61a11bb": 180, "abeae8ce21": 150, "ac2b431d5f": 150, "ac2cb1b9eb": 150, "ac31fcd6d0": 91, "ac3d3a126d": 180, "ac46bd8087": 180, "ac783ef388": 180, "acb73e4297": 150, "acbf581760": 180, "accafc3531": 96, "acf2c4b745": 96, "acf44293a2": 96, "acf736a27b": 90, "acff336758": 180, "ad1fe56886": 92, "ad28f9b9d9": 91, "ad2de9f80e": 180, "ad397527b2": 97, "ad3d1cfbcb": 86, "ad3fada9d9": 180, "ad4108ee8e": 180, "ad54468654": 66, "ad573f7d31": 96, "ad6255bc29": 180, "ad65ebaa07": 144, "ad97cc064a": 96, "adabbd1cc4": 180, "adb0b5a270": 180, "adc648f890": 150, "add21ee467": 180, "adfd15ceef": 180, "adfdd52eac": 96, "ae01cdab63": 180, "ae0b50ff4f": 96, "ae13ee3d70": 180, "ae1bcbd423": 180, "ae20d09dea": 180, "ae2cecf5f6": 56, "ae3bc4a0ef": 180, "ae499c7514": 92, "ae628f2cd4": 150, "ae8545d581": 86, "ae93214fe6": 150, "ae9cd16dbf": 46, "aeba9ac967": 180, "aebb242b5c": 150, "aed4e0b4c4": 86, "aedd71f125": 180, "aef3e2cb0e": 180, "af0b54cee3": 96, "af3de54c7a": 180, "af5fd24a36": 150, "af8826d084": 91, "af8ad72057": 180, "afb71e22c5": 92, "afcb331e1f": 96, "afe1a35c1e": 150, "b01080b5d3": 180, "b05ad0d345": 96, "b0623a6232": 91, "b064dbd4b7": 96, "b06ed37831": 96, "b06f5888e6": 92, "b08dcc490e": 91, "b0a68228dc": 92, "b0aece727f": 144, "b0b0731606": 96, "b0c7f11f9f": 180, "b0cca8b830": 180, "b0dd580a89": 180, "b0de66ca08": 180, "b0df7c5c5c": 96, "b0f5295608": 96, "b11099eb09": 180, "b132a53086": 91, "b1399fac64": 180, "b13abc0c69": 96, "b1457e3b5e": 180, "b15bf4453b": 91, "b179c4a82d": 96, "b17ee70e8c": 180, "b190b1aa65": 96, "b19b3e22c0": 180, "b19c561fab": 180, "b1d1cd2e6e": 92, "b1d7c03927": 91, "b1d7fe2753": 180, "b1f540a4bd": 96, "b1fc9c64e1": 96, "b1fcbb3ced": 180, "b220939e93": 96, "b22099b419": 180, "b241e95235": 96, "b2432ae86d": 180, "b2456267df": 180, "b247940d01": 150, "b24af1c35c": 180, "b24f600420": 97, "b24fe36b2a": 150, "b258fb0b7d": 180, "b26b219919": 96, "b26d9904de": 96, "b274456ce1": 180, "b27b28d581": 72, "b2a26bc912": 180, "b2a9c51e1b": 180, "b2b0baf470": 180, "b2b2756fe7": 96, "b2ce7699e3": 180, "b2edc76bd2": 150, "b2f6b52100": 180, "b30bf47bcd": 180, "b34105a4e9": 91, "b372a82edf": 150, "b3779a1962": 96, "b379ab4ff5": 46, "b37a1d69e3": 150, "b37c01396e": 180, "b382b09e25": 150, "b3996e4ba5": 180, "b3d9ca2aee": 180, "b3dde1e1e9": 180, "b3eb7f05eb": 86, "b40b25055c": 91, "b41e0f1f19": 91, "b44e32a42b": 91, "b4805ae9cd": 46, "b4807569a5": 97, "b48efceb3e": 150, "b493c25c7f": 180, "b4b565aba1": 150, "b4b715a15b": 180, "b4d0c90bf4": 91, "b4d84bc371": 180, "b4e5ad97aa": 180, "b4eaea9e6b": 150, "b50f4b90d5": 180, "b53f675641": 150, "b54278cd43": 180, "b554843889": 150, "b573c0677a": 180, "b58d853734": 180, "b5943b18ab": 180, "b5a09a83f3": 71, "b5aae1fe25": 91, "b5b9da5364": 97, "b5eb64d419": 91, "b5ebb1d000": 96, "b5f1c0c96a": 96, "b5f7fece90": 180, "b6070de1bb": 180, "b60a76fe73": 86, "b61f998772": 96, "b62c943664": 96, "b63094ba0c": 180, "b64fca8100": 96, "b673e7dcfb": 96, "b678b7db00": 180, "b68fc1b217": 180, "b69926d9fa": 96, "b6a1df3764": 180, "b6a4859528": 96, "b6b4738b78": 96, "b6b4f847b7": 150, "b6b8d502d4": 150, "b6bb00e366": 180, "b6d65a9eef": 180, "b6d79a0845": 180, "b6e9ec577f": 91, "b6ec609f7b": 163, "b6f92a308d": 180, "b70a2c0ab1": 46, "b70a5a0d50": 180, "b70c052f2f": 150, "b70d231781": 92, "b72ac6e10b": 180, "b7302d8226": 92, "b73867d769": 150, "b751e767f2": 180, "b76df6e059": 96, "b77e5eddef": 92, "b7a2c2c83c": 96, "b7bcbe6466": 180, "b7c2a469c4": 180, "b7d69da8f0": 144, "b7f31b7c36": 61, "b7f675fb98": 46, "b7fb871660": 51, "b82e5ad1c9": 91, "b841cfb932": 96, "b84b8ae665": 180, "b85b78ac2b": 180, "b86c17caa6": 180, "b86e50d82d": 96, "b871db031a": 66, "b87d56925a": 96, "b8aaa59b75": 92, "b8c03d1091": 180, "b8c3210036": 46, "b8e16df00b": 144, "b8f34cf72e": 91, "b8fb75864e": 150, "b9004db86c": 180, "b9166cbae9": 92, "b920b256a6": 180, "b938d79dff": 20, "b93963f214": 180, "b941aef1a0": 144, "b94d34d14e": 96, "b964c57da4": 96, "b96a95bc7a": 180, "b96c57d2c7": 144, "b9b6bdde0c": 180, "b9bcb3e0f2": 96, "b9d3b92169": 180, "b9dd4b306c": 180, "b9f43ef41e": 92, "ba1f03c811": 96, "ba3a775d7b": 180, "ba3c7f2a31": 150, "ba3fcd417d": 180, "ba5e1f4faa": 150, "ba795f3089": 96, "ba8a291e6a": 150, "ba98512f97": 92, "bac9db04f5": 180, "baedae3442": 180, "baff40d29d": 180, "bb04e28695": 96, "bb1b0ee89f": 96, "bb1c770fe7": 150, "bb1fc34f99": 150, "bb2d220506": 180, "bb334e5cdb": 91, "bb337f9830": 81, "bb721eb9aa": 96, "bb87ff58bd": 96, "bb89a6b18a": 87, "bbaa9a036a": 144, "bbb4302dda": 180, "bbd31510cf": 96, "bbe0256a75": 180, "bc141b9ad5": 91, "bc17ab8a99": 150, "bc318160de": 180, "bc3b9ee033": 91, "bc4240b43c": 96, "bc4ce49105": 91, "bc4f71372d": 96, "bc6b8d6371": 180, "bcaad44ad7": 150, "bcc241b081": 91, "bcc5d8095e": 96, "bcd1d39afb": 96, "bd0d849da4": 180, "bd0e9ed437": 150, "bd2c94730f": 180, "bd321d2be6": 61, "bd3ec46511": 91, "bd5b2e2848": 41, "bd7e02b139": 96, "bd96f9943a": 180, "bda224cb25": 91, "bda4a82837": 96, "bdb74e333f": 180, "bdccd69dde": 96, "bddcc15521": 180, "be116aab29": 150, "be15e18f1e": 150, "be1a284edb": 180, "be2a367a7b": 180, "be376082d0": 150, "be3e3cffbd": 51, "be5d1d89a0": 180, "be8b72fe37": 180, "be9b29e08e": 91, "bea1f6e62c": 97, "bea83281b5": 92, "beb921a4c9": 96, "bec5e9edcd": 180, "beeb8a3f92": 150, "bf2232b58d": 96, "bf28751739": 150, "bf443804e8": 180, "bf461df850": 150, "bf5374f122": 180, "bf551a6f60": 180, "bf8d0f5ada": 96, "bf961167a6": 92, "bfab1ad8f9": 150, "bfcb05d88d": 96, "bfd8f6e6c9": 92, "bfd91d0742": 150, "bfe262322f": 87, "c013f42ed7": 180, "c01878083f": 180, "c01faff1ed": 180, "c046fd0edb": 150, "c053e35f97": 91, "c079a6482d": 96, "c0847b521a": 96, "c0a1e06710": 180, "c0e8d4635c": 96, "c0e973ad85": 96, "c0f49c6579": 92, "c0f5b222d7": 96, "c10d07c90d": 180, "c1268d998c": 96, "c130c3fc0c": 180, "c14826ad5e": 180, "c15b922281": 180, "c16f09cb63": 180, "c18e19d922": 180, "c1c830a735": 96, "c1e8aeea45": 180, "c20a5ccc99": 180, "c20fd5e597": 180, "c219d6f8dc": 150, "c2406ae462": 96, "c26f7b5824": 180, "c279e641ee": 96, "c27adaeac5": 180, "c2a35c1cda": 96, "c2a9903b8b": 180, "c2b62567c1": 96, "c2b974ec8c": 150, "c2baaff7bf": 91, "c2be6900f2": 180, "c304dd44d5": 180, "c307f33da2": 96, "c30a7b62c9": 92, "c3128733ee": 180, "c31fa6c598": 180, "c325c8201e": 96, "c32d4aa5d1": 180, "c33f28249a": 144, "c34365e2d7": 180, "c3457af795": 96, "c34d120a88": 180, "c3509e728d": 96, "c35e4fa6c4": 180, "c36240d96f": 150, "c3641dfc5a": 92, "c37b17a4a9": 180, "c39559ddf6": 180, "c3b0c6e180": 96, "c3b3d82e6c": 180, "c3be369fdb": 91, "c3bf1e40c2": 97, "c3c760b015": 96, "c3dd38bf98": 150, "c3e4274614": 91, "c3edc48cbd": 180, "c41e6587f5": 96, "c4272227b0": 96, "c42917fe82": 86, "c438858117": 180, "c44676563f": 180, "c44beb7472": 180, "c45411dacb": 91, "c4571bedc8": 91, "c46deb2956": 180, "c479ee052e": 180, "c47d551843": 180, "c49f07d46d": 180, "c4cc40c1fc": 97, "c4f256f5d5": 144, "c4f5b1ddcc": 180, "c4ff9b4885": 150, "c52bce43db": 66, "c544da6854": 180, "c55784c766": 180, "c557b69fbf": 180, "c593a3f7ab": 92, "c598faa682": 180, "c5ab1f09c8": 180, "c5b6da8602": 96, "c5b9128d94": 96, "c5e845c6b7": 150, "c5fba7b341": 150, "c60897f093": 96, "c61fe6ed7c": 96, "c62188c536": 96, "c64035b2e2": 150, "c69689f177": 180, "c6a12c131f": 51, "c6bb6d2d5c": 180, "c6c18e860f": 150, "c6d9526e0d": 180, "c6e55c33f0": 96, "c7030b28bd": 96, "c70682c7cc": 180, "c70f9be8c5": 87, "c71f30d7b6": 180, "c73c8e747f": 180, "c760eeb8b3": 144, "c7637cab0a": 150, "c7a1a17308": 87, "c7bf937af5": 91, "c7c2860db3": 180, "c7cef4aee2": 91, "c7ebfc5d57": 180, "c813dcf13c": 91, "c82235a49a": 96, "c82a7619a1": 180, "c82ecb90cb": 180, "c844f03dc7": 96, "c8557963f3": 91, "c89147e6e8": 180, "c8a46ff0c8": 150, "c8ab107dd5": 97, "c8b869a04a": 96, "c8c7b306a6": 91, "c8c8b28781": 180, "c8d79e3163": 180, "c8edab0415": 150, "c8f494f416": 96, "c8f6cba9fd": 150, "c909ceea97": 92, "c9188f4980": 180, "c922365dd4": 96, "c92c8c3c75": 96, "c937eb0b83": 91, "c94b31b5e5": 180, "c95cd17749": 180, "c96379c03c": 180, "c96465ee65": 180, "c965afa713": 144, "c9734b451f": 92, "c9862d82dc": 180, "c98b6fe013": 180, "c9999b7c48": 180, "c99e92aaf0": 97, "c9b3a8fbda": 150, "c9bf64e965": 96, "c9c3cb3797": 91, "c9d1c60cd0": 144, "c9de9c22c4": 96, "ca1828fa54": 96, "ca346f17eb": 180, "ca3787d3d3": 150, "ca4b99cbac": 96, "ca91c69e3b": 71, "ca91e99105": 46, "caa8e97f81": 96, "caac5807f8": 180, "cabba242c2": 96, "cad5a656a9": 180, "cad673e375": 180, "cad8a85930": 150, "cae7b0a02b": 180, "cae7ef3184": 180, "caeb6b6cbb": 150, "caecf0a5db": 91, "cb15312003": 76, "cb2e35d610": 150, "cb35a87504": 150, "cb3f22b0cf": 96, "cbb410da64": 91, "cc8728052e": 150, "cc892997b8": 180, "cce03c2a9b": 144, "cd47a23e31": 92, "cd4dc03dc0": 180, "cd5ae611da": 96, "cd603bb9d1": 144, "cd8f49734c": 180, "cdc6b1c032": 92, "cdcfe008ad": 144, "cdd57027c2": 96, "ce1af99b4b": 150, "ce1bc5743a": 150, "ce25872021": 97, "ce2776f78f": 180, "ce49b1f474": 180, "ce4f0a266f": 180, "ce5641b195": 180, "ce6866aa19": 180, "ce712ed3c9": 91, "ce7d1c8117": 144, "ce7dbeaa88": 180, "ce9b015a5e": 180, "cea7697b25": 96, "cebbd826cf": 150, "cec3415361": 150, "cec41ad4f4": 180, "ced49d26df": 180, "ced7705ab2": 144, "cef824a1e1": 92, "cf13f5c95a": 144, "cf4376a52d": 180, "cf85ab28b5": 180, "cfc2e50b9d": 150, "cfcd571fff": 144, "cfd9d4ae47": 180, "cfda2dcce5": 150, "cff035928b": 91, "cff8191891": 46, "d01608c2a5": 96, "d01a8f1f83": 144, "d021d68bca": 180, "d04258ca14": 150, "d0483573dc": 150, "d04a90aaff": 180, "d05279c0bd": 180, "d0696bd5fc": 91, "d072fda75b": 178, "d0a83bcd9f": 150, "d0ab39112e": 180, "d0acde820f": 96, "d0b4442c71": 144, "d0c65e9e95": 180, "d0fb600c73": 150, "d107a1457c": 61, "d123d674c1": 66, "d14d1e9289": 96, "d154e3388e": 96, "d177e9878a": 96, "d1802f69f8": 150, "d182c4483a": 180, "d195d31128": 180, "d200838929": 180, "d205e3cff5": 180, "d247420c4c": 180, "d2484bff33": 66, "d26f6ed9b0": 150, "d280fcd1cb": 180, "d2857f0faa": 180, "d292a50c7f": 46, "d295ea2dc7": 96, "d2a58b4fa6": 91, "d2b026739a": 150, "d2ebe0890f": 180, "d2ede5d862": 91, "d301ca58cc": 150, "d3069da8bb": 91, "d343d4a77d": 150, "d355e634ef": 86, "d367fb5253": 91, "d36d16358e": 76, "d38bc77e2c": 101, "d38d1679e2": 144, "d3932ad4bd": 97, "d3987b2930": 180, "d39934abe3": 144, "d3ae1c3f4c": 92, "d3b088e593": 87, "d3e6e05e16": 150, "d3eefae7c5": 144, "d3f55f5ab8": 180, "d3f5c309cc": 61, "d4034a7fdf": 180, "d4193011f3": 144, "d429c67630": 180, "d42c0ff975": 180, "d44a764409": 180, "d44e6acd1d": 66, "d45158c175": 150, "d454e8444f": 150, "d45f62717e": 180, "d48ebdcf74": 180, "d49ab52a25": 86, "d4a607ad81": 92, "d4b063c7db": 144, "d4da13e9ba": 96, "d4dd1a7d00": 180, "d4f4f7c9c3": 96, "d521aba02e": 180, "d535bb1b97": 92, "d53b955f78": 96, "d55cb7a205": 92, "d55f247a45": 150, "d5695544d8": 180, "d5853d9b8b": 180, "d5b6c6d94a": 96, "d5cae12834": 150, "d5df027f0c": 144, "d5ee40e5d0": 180, "d600046f73": 144, "d632fd3510": 144, "d6476cad55": 180, "d65a7bae86": 150, "d664c89912": 150, "d689658f06": 180, "d6917db4be": 96, "d69967143e": 96, "d699d3d798": 91, "d69f757a3f": 180, "d6ac0e065c": 91, "d6c02bfda5": 96, "d6c1b5749e": 92, "d6e12ef6cc": 92, "d6eed152c4": 180, "d6faaaf726": 96, "d704766646": 180, "d708e1350c": 180, "d7135cf104": 180, "d7157a9f44": 46, "d719cf9316": 96, "d724134cfd": 144, "d73a60a244": 180, "d7411662da": 144, "d74875ea7c": 96, "d756f5a694": 91, "d7572b7d8a": 180, "d763bd6d96": 180, "d7697c8b13": 96, "d7797196b4": 150, "d79c834768": 180, "d7b34e5d73": 91, "d7bb6b37a7": 150, "d7c7e064a6": 180, "d7fbf545b3": 96, "d82a0aa15b": 180, "d847e24abd": 144, "d8596701b7": 144, "d86101499c": 144, "d87069ba86": 150, "d87160957b": 144, "d874654b52": 91, "d88a403092": 96, "d8aee40f3f": 144, "d8e77a222d": 91, "d8eb07c381": 180, "d9010348a1": 66, "d90e3cf281": 91, "d92532c7b2": 180, "d927fae122": 150, "d95707bca8": 91, "d973b31c00": 144, "d991cb471d": 180, "d992c69d37": 150, "d99d770820": 180, "d9b63abc11": 180, "d9db6f1983": 144, "d9e52be2d2": 96, "d9edc82650": 150, "da01070697": 96, "da070ea4b7": 180, "da080507b9": 150, "da0e944cc4": 180, "da28d94ff4": 96, "da5d78b9d1": 180, "da6003fc72": 150, "da690fee9f": 180, "da6c68708f": 180, "da7a816676": 144, "dac361e828": 180, "dac71659b8": 144, "dad980385d": 96, "daebc12b77": 150, "db0968cdd3": 150, "db231a7100": 92, "db59282ace": 91, "db7f267c3f": 180, "dba35b87fd": 96, "dbba735a50": 86, "dbca076acd": 180, "dbd66dc3ac": 180, "dbdc3c292b": 180, "dbf4a5b32b": 180, "dbfc417d28": 180, "dc1745e0a2": 91, "dc32a44804": 180, "dc34b35e30": 150, "dc504a4f79": 92, "dc704dd647": 180, "dc71bc6918": 92, "dc7771b3be": 180, "dcf8c93617": 96, "dd0f4c9fb9": 180, "dd415df125": 120, "dd601f9a3f": 144, "dd61d903df": 150, "dd77583736": 150, "dd8636bd8b": 180, "dd9fe6c6ac": 92, "ddb2da4c14": 180, "ddcd450d47": 144, "dde8e67fb4": 76, "ddfc3f04d3": 150, "de2ab79dfa": 180, "de2f35b2fd": 91, "de30990a51": 180, "de36b216da": 96, "de37403340": 180, "de46e4943b": 96, "de4ddbccb1": 180, "de5e480f05": 96, "de6a9382ca": 96, "de74a601d3": 180, "de827c510d": 92, "ded6069f7b": 180, "defb71c741": 96, "df01f277f1": 180, "df05214b82": 92, "df0638b0a0": 46, "df11931ffe": 180, "df1b0e4620": 180, "df20a8650d": 92, "df2bc56d7c": 180, "df365282c6": 180, "df39a0d9df": 96, "df3c430c24": 91, "df5536cfb9": 180, "df59cfd91d": 97, "df5e2152b3": 66, "df741313c9": 96, "df7626172f": 92, "df8ad5deb9": 180, "df96aa609a": 180, "df9705605c": 180, "df9c91c4da": 180, "dfc0d3d27a": 180, "dfdbf91a99": 180, "e00baaae9b": 180, "e0a938c6e7": 91, "e0b2ceee6f": 150, "e0bdb5dfae": 36, "e0be1f6e17": 96, "e0c478f775": 150, "e0de82caa7": 180, "e0f217dd59": 91, "e0f7208874": 180, "e0fb58395e": 180, "e1194c2e9d": 150, "e11adcd05d": 180, "e128124b9d": 87, "e1495354e4": 180, "e1561d6d4b": 180, "e158805399": 91, "e16945b951": 46, "e19edcd34b": 180, "e1a1544285": 180, "e1ab7957f4": 150, "e1d26d35be": 96, "e1e957085b": 96, "e1f14510fa": 180, "e214b160f4": 180, "e2167379b8": 150, "e21acb20ab": 180, "e221105579": 180, "e22ddf8a1b": 180, "e22de45950": 96, "e22ffc469b": 180, "e23cca5244": 96, "e252f46f0b": 180, "e25fa6cf39": 180, "e26e486026": 150, "e275760245": 96, "e27bbedbfe": 92, "e29e9868a8": 180, "e2b37ff8af": 96, "e2b608d309": 180, "e2bef4da9a": 96, "e2c87a6421": 96, "e2ea25542c": 144, "e2fb1d6497": 178, "e2fcc99117": 91, "e33c18412a": 71, "e348377191": 91, "e352cb59c8": 180, "e36ac982f0": 91, "e391bc981e": 96, "e39e3e0a06": 96, "e3bf38265f": 51, "e3d5b2cd21": 150, "e3d60e82d5": 46, "e3e3245492": 96, "e3e4134877": 150, "e3f4635e03": 180, "e4004ee048": 180, "e402d1afa5": 180, "e415093d27": 71, "e41ceb5d81": 180, "e424653b78": 96, "e42b6d3dbb": 96, "e42d60f0d4": 180, "e436d0ff1e": 180, "e43d7ae2c5": 92, "e4428801bc": 97, "e44e0b4917": 180, "e470345ede": 180, "e48e8b4263": 180, "e4922e3726": 180, "e4936852bb": 96, "e495f32c60": 41, "e499228f26": 150, "e4af66e163": 180, "e4b2095f58": 180, "e4d19c8283": 180, "e4d4872dab": 96, "e4e2983570": 41, "e4eaa63aab": 91, "e4ef0a3a34": 91, "e4f8e5f46e": 96, "e4ffb6d0dd": 71, "e53e21aa02": 180, "e57f4f668b": 180, "e588433c1e": 96, "e597442c99": 150, "e5abc0e96b": 91, "e5be628030": 180, "e5ce96a55d": 61, "e5d6b70a9f": 81, "e5fde1574c": 92, "e625e1d27b": 180, "e6261d2348": 91, "e6267d46bc": 96, "e6295f223f": 180, "e63463d8c6": 96, "e6387bd1e0": 180, "e653883384": 96, "e65f134e0b": 150, "e668ef5664": 180, "e672ccd250": 92, "e674510b20": 91, "e676107765": 150, "e699da0cdf": 180, "e6be243065": 46, "e6deab5e0b": 76, "e6f065f2b9": 96, "e71629e7b5": 96, "e72a7d7b0b": 150, "e72f6104e1": 92, "e75a466eea": 72, "e76c55933f": 150, "e7784ec8ad": 180, "e78922e5e6": 47, "e78d450a9c": 91, "e7c6354e77": 91, "e7c8de1fce": 150, "e7ea10db28": 150, "e803918710": 180, "e8073a140b": 180, "e828dd02db": 150, "e845994987": 150, "e8485a2615": 96, "e85c5118a7": 180, "e88b6736e4": 180, "e8962324e3": 91, "e8b3018d36": 91, "e8cee8bf0b": 150, "e8d97ebece": 144, "e8da49ea6a": 96, "e8ed1a3ccf": 180, "e8f7904326": 72, "e8f8341dec": 180, "e8fa21eb13": 180, "e90c10fc4c": 150, "e914b8cac8": 180, "e92b6bfea4": 46, "e92e1b7623": 150, "e93f83e512": 92, "e9422ad240": 46, "e9460b55f9": 180, "e9502628f6": 180, "e950befd5f": 180, "e9582bdd1b": 91, "e95e5afe0f": 96, "e97cfac475": 96, "e98d57d99c": 91, "e98eda8978": 92, "e99706b555": 41, "e9bc0760ba": 91, "e9d3c78bf3": 87, "e9ec1b7ea8": 144, "ea065cc205": 180, "ea138b6617": 150, "ea16d3fd48": 180, "ea2545d64b": 180, "ea286a581c": 150, "ea320da917": 96, "ea345f3627": 91, "ea3b94a591": 180, "ea444a37eb": 71, "ea4a01216b": 180, "ea5672ffa8": 81, "eaa99191cb": 150, "eaab4d746c": 91, "eac7a59bc1": 150, "ead5d3835a": 96, "eaec65cfa7": 180, "eaed1a87be": 180, "eb2f821c6f": 180, "eb383cb82e": 91, "eb6992fe02": 150, "eb6ac20a01": 92, "eb6d7ab39e": 96, "eb7921facd": 180, "eb8fce51a6": 180, "ebbb90e9f9": 91, "ebbf5c9ee1": 180, "ebc4ec32e6": 91, "ebe56e5ef8": 180, "ec1299aee4": 97, "ec139ff675": 180, "ec193e1a01": 180, "ec28252938": 150, "ec387be051": 180, "ec3d4fac00": 91, "ec4186ce12": 95, "ec579c2f96": 91, "ecae59b782": 180, "ecb33a0448": 180, "ece6bc9e92": 150, "ecfedd4035": 92, "ecfff22fd6": 180, "ed3291c3d6": 180, "ed3cd5308d": 180, "ed3e6fc1a5": 180, "ed72ae8825": 180, "ed7455da68": 92, "ed844e879f": 150, "ed8f814b2b": 92, "ed911a1f63": 180, "ed9ff4f649": 180, "eda8ab984b": 180, "edb8878849": 96, "edbfdfe1b4": 180, "edd22c46a2": 96, "edd663afa3": 180, "ede3552eae": 96, "edeab61ee0": 174, "ee07583fc0": 150, "ee316eaed6": 91, "ee3f509537": 150, "ee40a1e491": 92, "ee4bf100f1": 180, "ee6f9b01f9": 180, "ee947ed771": 96, "ee9706ac7f": 91, "ee9a7840ae": 180, "eeb90cb569": 180, "eebf45e5c5": 92, "eeed0c7d73": 87, "ef0061a309": 96, "ef07f1a655": 96, "ef0a8e8f35": 56, "ef232a2aed": 150, "ef308ad2e9": 180, "ef44945428": 96, "ef45ce3035": 180, "ef5dde449d": 180, "ef5e770988": 144, "ef6359cea3": 96, "ef65268834": 180, "ef6cb5eae0": 86, "ef78972bc2": 150, "ef8cfcfc4f": 82, "ef96501dd0": 150, "ef9a2e976b": 91, "efb24f950f": 180, "efce0c1868": 180, "efe5ac6901": 91, "efe828affa": 180, "efea4e0523": 144, "f0268aa627": 180, "f0483250c8": 180, "f04cf99ee6": 62, "f05b189097": 96, "f08928c6d3": 96, "f09d74856f": 150, "f0a7607d63": 180, "f0ad38da27": 71, "f0c34e1213": 92, "f0c7f86c29": 180, "f0dfa18ba7": 150, "f0eb3179f7": 180, "f119bab27d": 150, "f14409b6a3": 180, "f1489baff4": 86, "f14c18cf6a": 180, "f15c607b92": 180, "f1af214222": 97, "f1b77bd309": 180, "f1ba9e1a3e": 180, "f1d99239eb": 66, "f1dc710cf4": 180, "f1ec5c08fa": 97, "f22648fe12": 180, "f22d21f1f1": 144, "f233257395": 91, "f23e95dbe5": 96, "f2445b1572": 150, "f253b3486d": 144, "f277c7a6a4": 91, "f2ab2b84d6": 87, "f2b7c9b1f3": 150, "f2b83d5ce5": 180, "f2c276018f": 150, "f2cfd94d64": 150, "f2dd6e3add": 150, "f2e7653f16": 180, "f2f333ad06": 96, "f2f55d6713": 180, "f2fdb6abec": 180, "f305a56d9f": 46, "f3085d6570": 96, "f3325c3338": 180, "f3400f1204": 180, "f34497c932": 97, "f34a56525e": 91, "f36483c824": 96, "f3704d5663": 91, "f3734c4913": 150, "f38e5aa5b4": 86, "f3986fba44": 180, "f3a0ffc7d9": 180, "f3b24a7d28": 96, "f3e6c35ec3": 180, "f3fc0ea80b": 96, "f40a683fbe": 180, "f4207ca554": 180, "f4377499c2": 150, "f46184f393": 144, "f46c2d0a6d": 180, "f46c364dca": 180, "f46f7a0b63": 180, "f46fe141b0": 91, "f470b9aeb0": 180, "f47eb7437f": 96, "f48b535719": 92, "f49e4866ac": 180, "f4aa882cfd": 180, "f4daa3dbd5": 96, "f4dd51ac35": 91, "f507a1b9dc": 96, "f51c5ac84b": 86, "f52104164b": 180, "f54c67b9bb": 96, "f5966cadd2": 180, "f5bddf5598": 91, "f5d85cfd17": 92, "f5e2e7d6a0": 96, "f5f051e9b4": 180, "f5f8a93a76": 150, "f6283e8af5": 96, "f635e9568b": 180, "f6474735be": 144, "f659251be2": 150, "f66981af4e": 96, "f6708fa398": 87, "f697fe8e8f": 96, "f6adb12c42": 76, "f6c7906ca4": 180, "f6cd0a8016": 144, "f6d6f15ae7": 144, "f6e501892c": 96, "f6f59d986f": 180, "f6fe8c90a5": 180, "f714160545": 144, "f74c3888d7": 180, "f7782c430e": 150, "f7783ae5f2": 96, "f77ab47923": 97, "f788a98327": 91, "f7961ac1f0": 96, "f7a71e7574": 150, "f7a8521432": 180, "f7afbf4947": 150, "f7b7cd5f44": 81, "f7cf4b4a39": 92, "f7d49799ad": 150, "f7e0c9bb83": 180, "f7e5b84928": 96, "f7e6bd58be": 96, "f7f2a38ac6": 96, "f7f6cb2d6d": 150, "f83f19e796": 76, "f85796a921": 91, "f8603c26b2": 180, "f8819b42ec": 144, "f891f8eaa1": 96, "f89288d10c": 92, "f895ae8cc1": 180, "f8af30d4b6": 97, "f8b4ac12f1": 180, "f8c3fb2b01": 180, "f8c8de2764": 180, "f8db369b40": 92, "f8fcb6a78c": 180, "f94aafdeef": 180, "f95d217b70": 96, "f9681d5103": 92, "f9750192a4": 91, "f9823a32c2": 96, "f991ddb4c2": 96, "f99d535567": 96, "f9ae3d98b7": 144, "f9b6217959": 91, "f9bd1fabf5": 96, "f9c68eaa64": 180, "f9d3e04c4f": 92, "f9daf64494": 180, "f9e4cc5a0a": 96, "f9ea6b7f31": 96, "f9f3852526": 180, "fa04c615cf": 150, "fa08e00a56": 180, "fa4370d74d": 180, "fa67744af3": 180, "fa88d48a92": 150, "fa8b904cc9": 92, "fa9526bdf1": 150, "fa9b9d2426": 150, "fad633fbe1": 150, "faf5222dc3": 91, "faff0e15f1": 180, "fb08c64e8c": 180, "fb23455a7f": 150, "fb2e19fa6e": 180, "fb34dfbb77": 180, "fb47fcea1e": 96, "fb49738155": 180, "fb4cbc514b": 71, "fb4e6062f7": 180, "fb5ba7ad6e": 96, "fb63cd1236": 96, "fb81157a07": 180, "fb92abdaeb": 180, "fba22a6848": 92, "fbaca0c9df": 180, "fbc645f602": 96, "fbd77444cd": 96, "fbe53dc8e8": 96, "fbe541dd73": 97, "fbe8488798": 91, "fbfd25174f": 96, "fc28cb305e": 97, "fc33b1ffd6": 150, "fc6186f0bb": 180, "fc918e3a40": 150, "fc96cda9d8": 150, "fc9832eea4": 150, "fcb10d0f81": 180, "fcd20a2509": 180, "fcf637e3ab": 92, "fcfd81727f": 96, "fd31890379": 180, "fd33551c28": 144, "fd542da05e": 144, "fd6789b3fe": 180, "fd77828200": 180, "fd7af75f4d": 150, "fdb28d0fbb": 150, "fdb3d1fb1e": 82, "fdb8b04124": 96, "fdc6e3d581": 91, "fdfce7e6fc": 180, "fe0f76d41b": 180, "fe24b0677d": 180, "fe3c02699d": 144, "fe58b48235": 96, "fe6a5596b8": 91, "fe6c244f63": 66, "fe7afec086": 180, "fe985d510a": 144, "fe9db35d15": 96, "fea8ffcd36": 144, "feb1080388": 180, "fed208bfca": 180, "feda5ad1c2": 180, "feec95b386": 91, "ff15a5eff6": 144, "ff204daf4b": 96, "ff25f55852": 180, "ff2ada194f": 180, "ff2ce142e8": 96, "ff49d36d20": 180, "ff5a1ec4f3": 180, "ff66152b25": 180, "ff692fdc56": 180, "ff773b1a1e": 96, "ff97129478": 144, "ffb904207d": 180, "ffc43fc345": 150, "fffe5f8df6": 180}
|
inference_propainter.py
ADDED
@@ -0,0 +1,475 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import os
|
3 |
+
import cv2
|
4 |
+
import argparse
|
5 |
+
import imageio
|
6 |
+
import numpy as np
|
7 |
+
import scipy.ndimage
|
8 |
+
from PIL import Image
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torchvision
|
13 |
+
|
14 |
+
from model.modules.flow_comp_raft import RAFT_bi
|
15 |
+
from model.recurrent_flow_completion import RecurrentFlowCompleteNet
|
16 |
+
from model.propainter import InpaintGenerator
|
17 |
+
from utils.download_util import load_file_from_url
|
18 |
+
from core.utils import to_tensors
|
19 |
+
from model.misc import get_device
|
20 |
+
|
21 |
+
import warnings
|
22 |
+
warnings.filterwarnings("ignore")
|
23 |
+
|
24 |
+
pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
|
25 |
+
|
26 |
+
def imwrite(img, file_path, params=None, auto_mkdir=True):
|
27 |
+
if auto_mkdir:
|
28 |
+
dir_name = os.path.abspath(os.path.dirname(file_path))
|
29 |
+
os.makedirs(dir_name, exist_ok=True)
|
30 |
+
return cv2.imwrite(file_path, img, params)
|
31 |
+
|
32 |
+
|
33 |
+
# resize frames
|
34 |
+
def resize_frames(frames, size=None):
|
35 |
+
if size is not None:
|
36 |
+
out_size = size
|
37 |
+
process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
|
38 |
+
frames = [f.resize(process_size) for f in frames]
|
39 |
+
else:
|
40 |
+
out_size = frames[0].size
|
41 |
+
process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
|
42 |
+
if not out_size == process_size:
|
43 |
+
frames = [f.resize(process_size) for f in frames]
|
44 |
+
|
45 |
+
return frames, process_size, out_size
|
46 |
+
|
47 |
+
|
48 |
+
# read frames from video
|
49 |
+
def read_frame_from_videos(frame_root):
|
50 |
+
if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
|
51 |
+
video_name = os.path.basename(frame_root)[:-4]
|
52 |
+
vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec') # RGB
|
53 |
+
frames = list(vframes.numpy())
|
54 |
+
frames = [Image.fromarray(f) for f in frames]
|
55 |
+
fps = info['video_fps']
|
56 |
+
else:
|
57 |
+
video_name = os.path.basename(frame_root)
|
58 |
+
frames = []
|
59 |
+
fr_lst = sorted(os.listdir(frame_root))
|
60 |
+
for fr in fr_lst:
|
61 |
+
frame = cv2.imread(os.path.join(frame_root, fr))
|
62 |
+
frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
|
63 |
+
frames.append(frame)
|
64 |
+
fps = None
|
65 |
+
size = frames[0].size
|
66 |
+
|
67 |
+
return frames, fps, size, video_name
|
68 |
+
|
69 |
+
|
70 |
+
def binary_mask(mask, th=0.1):
|
71 |
+
mask[mask>th] = 1
|
72 |
+
mask[mask<=th] = 0
|
73 |
+
return mask
|
74 |
+
|
75 |
+
|
76 |
+
# read frame-wise masks
|
77 |
+
def read_mask(mpath, length, size, flow_mask_dilates=8, mask_dilates=5):
|
78 |
+
masks_img = []
|
79 |
+
masks_dilated = []
|
80 |
+
flow_masks = []
|
81 |
+
|
82 |
+
if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
|
83 |
+
masks_img = [Image.open(mpath)]
|
84 |
+
else:
|
85 |
+
mnames = sorted(os.listdir(mpath))
|
86 |
+
for mp in mnames:
|
87 |
+
masks_img.append(Image.open(os.path.join(mpath, mp)))
|
88 |
+
|
89 |
+
for mask_img in masks_img:
|
90 |
+
if size is not None:
|
91 |
+
mask_img = mask_img.resize(size, Image.NEAREST)
|
92 |
+
mask_img = np.array(mask_img.convert('L'))
|
93 |
+
|
94 |
+
# Dilate 8 pixel so that all known pixel is trustworthy
|
95 |
+
if flow_mask_dilates > 0:
|
96 |
+
flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
|
97 |
+
else:
|
98 |
+
flow_mask_img = binary_mask(mask_img).astype(np.uint8)
|
99 |
+
# Close the small holes inside the foreground objects
|
100 |
+
# flow_mask_img = cv2.morphologyEx(flow_mask_img, cv2.MORPH_CLOSE, np.ones((21, 21),np.uint8)).astype(bool)
|
101 |
+
# flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.uint8)
|
102 |
+
flow_masks.append(Image.fromarray(flow_mask_img * 255))
|
103 |
+
|
104 |
+
if mask_dilates > 0:
|
105 |
+
mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
|
106 |
+
else:
|
107 |
+
mask_img = binary_mask(mask_img).astype(np.uint8)
|
108 |
+
masks_dilated.append(Image.fromarray(mask_img * 255))
|
109 |
+
|
110 |
+
if len(masks_img) == 1:
|
111 |
+
flow_masks = flow_masks * length
|
112 |
+
masks_dilated = masks_dilated * length
|
113 |
+
|
114 |
+
return flow_masks, masks_dilated
|
115 |
+
|
116 |
+
|
117 |
+
def extrapolation(video_ori, scale):
|
118 |
+
"""Prepares the data for video outpainting.
|
119 |
+
"""
|
120 |
+
nFrame = len(video_ori)
|
121 |
+
imgW, imgH = video_ori[0].size
|
122 |
+
|
123 |
+
# Defines new FOV.
|
124 |
+
imgH_extr = int(scale[0] * imgH)
|
125 |
+
imgW_extr = int(scale[1] * imgW)
|
126 |
+
imgH_extr = imgH_extr - imgH_extr % 8
|
127 |
+
imgW_extr = imgW_extr - imgW_extr % 8
|
128 |
+
H_start = int((imgH_extr - imgH) / 2)
|
129 |
+
W_start = int((imgW_extr - imgW) / 2)
|
130 |
+
|
131 |
+
# Extrapolates the FOV for video.
|
132 |
+
frames = []
|
133 |
+
for v in video_ori:
|
134 |
+
frame = np.zeros(((imgH_extr, imgW_extr, 3)), dtype=np.uint8)
|
135 |
+
frame[H_start: H_start + imgH, W_start: W_start + imgW, :] = v
|
136 |
+
frames.append(Image.fromarray(frame))
|
137 |
+
|
138 |
+
# Generates the mask for missing region.
|
139 |
+
masks_dilated = []
|
140 |
+
flow_masks = []
|
141 |
+
|
142 |
+
dilate_h = 4 if H_start > 10 else 0
|
143 |
+
dilate_w = 4 if W_start > 10 else 0
|
144 |
+
mask = np.ones(((imgH_extr, imgW_extr)), dtype=np.uint8)
|
145 |
+
|
146 |
+
mask[H_start+dilate_h: H_start+imgH-dilate_h,
|
147 |
+
W_start+dilate_w: W_start+imgW-dilate_w] = 0
|
148 |
+
flow_masks.append(Image.fromarray(mask * 255))
|
149 |
+
|
150 |
+
mask[H_start: H_start+imgH, W_start: W_start+imgW] = 0
|
151 |
+
masks_dilated.append(Image.fromarray(mask * 255))
|
152 |
+
|
153 |
+
flow_masks = flow_masks * nFrame
|
154 |
+
masks_dilated = masks_dilated * nFrame
|
155 |
+
|
156 |
+
return frames, flow_masks, masks_dilated, (imgW_extr, imgH_extr)
|
157 |
+
|
158 |
+
|
159 |
+
def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
|
160 |
+
ref_index = []
|
161 |
+
if ref_num == -1:
|
162 |
+
for i in range(0, length, ref_stride):
|
163 |
+
if i not in neighbor_ids:
|
164 |
+
ref_index.append(i)
|
165 |
+
else:
|
166 |
+
start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
|
167 |
+
end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
|
168 |
+
for i in range(start_idx, end_idx, ref_stride):
|
169 |
+
if i not in neighbor_ids:
|
170 |
+
if len(ref_index) > ref_num:
|
171 |
+
break
|
172 |
+
ref_index.append(i)
|
173 |
+
return ref_index
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
if __name__ == '__main__':
|
178 |
+
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
179 |
+
device = get_device()
|
180 |
+
|
181 |
+
parser = argparse.ArgumentParser()
|
182 |
+
parser.add_argument(
|
183 |
+
'-i', '--video', type=str, default='inputs/object_removal/bmx-trees', help='Path of the input video or image folder.')
|
184 |
+
parser.add_argument(
|
185 |
+
'-m', '--mask', type=str, default='inputs/object_removal/bmx-trees_mask', help='Path of the mask(s) or mask folder.')
|
186 |
+
parser.add_argument(
|
187 |
+
'-o', '--output', type=str, default='results', help='Output folder. Default: results')
|
188 |
+
parser.add_argument(
|
189 |
+
"--resize_ratio", type=float, default=1.0, help='Resize scale for processing video.')
|
190 |
+
parser.add_argument(
|
191 |
+
'--height', type=int, default=-1, help='Height of the processing video.')
|
192 |
+
parser.add_argument(
|
193 |
+
'--width', type=int, default=-1, help='Width of the processing video.')
|
194 |
+
parser.add_argument(
|
195 |
+
'--mask_dilation', type=int, default=4, help='Mask dilation for video and flow masking.')
|
196 |
+
parser.add_argument(
|
197 |
+
"--ref_stride", type=int, default=10, help='Stride of global reference frames.')
|
198 |
+
parser.add_argument(
|
199 |
+
"--neighbor_length", type=int, default=10, help='Length of local neighboring frames.')
|
200 |
+
parser.add_argument(
|
201 |
+
"--subvideo_length", type=int, default=80, help='Length of sub-video for long video inference.')
|
202 |
+
parser.add_argument(
|
203 |
+
"--raft_iter", type=int, default=20, help='Iterations for RAFT inference.')
|
204 |
+
parser.add_argument(
|
205 |
+
'--mode', default='video_inpainting', choices=['video_inpainting', 'video_outpainting'], help="Modes: video_inpainting / video_outpainting")
|
206 |
+
parser.add_argument(
|
207 |
+
'--scale_h', type=float, default=1.0, help='Outpainting scale of height for video_outpainting mode.')
|
208 |
+
parser.add_argument(
|
209 |
+
'--scale_w', type=float, default=1.2, help='Outpainting scale of width for video_outpainting mode.')
|
210 |
+
parser.add_argument(
|
211 |
+
'--save_fps', type=int, default=24, help='Frame per second. Default: 24')
|
212 |
+
parser.add_argument(
|
213 |
+
'--save_frames', action='store_true', help='Save output frames. Default: False')
|
214 |
+
parser.add_argument(
|
215 |
+
'--fp16', action='store_true', help='Use fp16 (half precision) during inference. Default: fp32 (single precision).')
|
216 |
+
|
217 |
+
args = parser.parse_args()
|
218 |
+
|
219 |
+
# Use fp16 precision during inference to reduce running memory cost
|
220 |
+
use_half = True if args.fp16 else False
|
221 |
+
|
222 |
+
|
223 |
+
frames, fps, size, video_name = read_frame_from_videos(args.video)
|
224 |
+
if not args.width == -1 and not args.height == -1:
|
225 |
+
size = (args.width, args.height)
|
226 |
+
if not args.resize_ratio == 1.0:
|
227 |
+
size = (int(args.resize_ratio * size[0]), int(args.resize_ratio * size[1]))
|
228 |
+
|
229 |
+
frames, size, out_size = resize_frames(frames, size)
|
230 |
+
|
231 |
+
fps = args.save_fps if fps is None else fps
|
232 |
+
save_root = os.path.join(args.output, video_name)
|
233 |
+
if not os.path.exists(save_root):
|
234 |
+
os.makedirs(save_root, exist_ok=True)
|
235 |
+
|
236 |
+
if args.mode == 'video_inpainting':
|
237 |
+
frames_len = len(frames)
|
238 |
+
flow_masks, masks_dilated = read_mask(args.mask, frames_len, size,
|
239 |
+
flow_mask_dilates=args.mask_dilation,
|
240 |
+
mask_dilates=args.mask_dilation)
|
241 |
+
w, h = size
|
242 |
+
elif args.mode == 'video_outpainting':
|
243 |
+
assert args.scale_h is not None and args.scale_w is not None, 'Please provide a outpainting scale (s_h, s_w).'
|
244 |
+
frames, flow_masks, masks_dilated, size = extrapolation(frames, (args.scale_h, args.scale_w))
|
245 |
+
w, h = size
|
246 |
+
else:
|
247 |
+
raise NotImplementedError
|
248 |
+
|
249 |
+
# for saving the masked frames or video
|
250 |
+
masked_frame_for_save = []
|
251 |
+
for i in range(len(frames)):
|
252 |
+
mask_ = np.expand_dims(np.array(masks_dilated[i]),2).repeat(3, axis=2)/255.
|
253 |
+
img = np.array(frames[i])
|
254 |
+
green = np.zeros([h, w, 3])
|
255 |
+
green[:,:,1] = 255
|
256 |
+
alpha = 0.6
|
257 |
+
# alpha = 1.0
|
258 |
+
fuse_img = (1-alpha)*img + alpha*green
|
259 |
+
fuse_img = mask_ * fuse_img + (1-mask_)*img
|
260 |
+
masked_frame_for_save.append(fuse_img.astype(np.uint8))
|
261 |
+
|
262 |
+
frames_inp = [np.array(f).astype(np.uint8) for f in frames]
|
263 |
+
frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
|
264 |
+
flow_masks = to_tensors()(flow_masks).unsqueeze(0)
|
265 |
+
masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
|
266 |
+
frames, flow_masks, masks_dilated = frames.to(device), flow_masks.to(device), masks_dilated.to(device)
|
267 |
+
|
268 |
+
|
269 |
+
##############################################
|
270 |
+
# set up RAFT and flow competition model
|
271 |
+
##############################################
|
272 |
+
ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'raft-things.pth'),
|
273 |
+
model_dir='weights', progress=True, file_name=None)
|
274 |
+
fix_raft = RAFT_bi(ckpt_path, device)
|
275 |
+
|
276 |
+
ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'),
|
277 |
+
model_dir='weights', progress=True, file_name=None)
|
278 |
+
fix_flow_complete = RecurrentFlowCompleteNet(ckpt_path)
|
279 |
+
for p in fix_flow_complete.parameters():
|
280 |
+
p.requires_grad = False
|
281 |
+
fix_flow_complete.to(device)
|
282 |
+
fix_flow_complete.eval()
|
283 |
+
|
284 |
+
|
285 |
+
##############################################
|
286 |
+
# set up ProPainter model
|
287 |
+
##############################################
|
288 |
+
ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'ProPainter.pth'),
|
289 |
+
model_dir='weights', progress=True, file_name=None)
|
290 |
+
model = InpaintGenerator(model_path=ckpt_path).to(device)
|
291 |
+
model.eval()
|
292 |
+
|
293 |
+
|
294 |
+
##############################################
|
295 |
+
# ProPainter inference
|
296 |
+
##############################################
|
297 |
+
video_length = frames.size(1)
|
298 |
+
print(f'\nProcessing: {video_name} [{video_length} frames]...')
|
299 |
+
with torch.no_grad():
|
300 |
+
# ---- compute flow ----
|
301 |
+
if frames.size(-1) <= 640:
|
302 |
+
short_clip_len = 12
|
303 |
+
elif frames.size(-1) <= 720:
|
304 |
+
short_clip_len = 8
|
305 |
+
elif frames.size(-1) <= 1280:
|
306 |
+
short_clip_len = 4
|
307 |
+
else:
|
308 |
+
short_clip_len = 2
|
309 |
+
|
310 |
+
# use fp32 for RAFT
|
311 |
+
if frames.size(1) > short_clip_len:
|
312 |
+
gt_flows_f_list, gt_flows_b_list = [], []
|
313 |
+
for f in range(0, video_length, short_clip_len):
|
314 |
+
end_f = min(video_length, f + short_clip_len)
|
315 |
+
if f == 0:
|
316 |
+
flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter)
|
317 |
+
else:
|
318 |
+
flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter)
|
319 |
+
|
320 |
+
gt_flows_f_list.append(flows_f)
|
321 |
+
gt_flows_b_list.append(flows_b)
|
322 |
+
torch.cuda.empty_cache()
|
323 |
+
|
324 |
+
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
|
325 |
+
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
|
326 |
+
gt_flows_bi = (gt_flows_f, gt_flows_b)
|
327 |
+
else:
|
328 |
+
gt_flows_bi = fix_raft(frames, iters=args.raft_iter)
|
329 |
+
torch.cuda.empty_cache()
|
330 |
+
|
331 |
+
|
332 |
+
if use_half:
|
333 |
+
frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
|
334 |
+
gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
|
335 |
+
fix_flow_complete = fix_flow_complete.half()
|
336 |
+
model = model.half()
|
337 |
+
|
338 |
+
|
339 |
+
# ---- complete flow ----
|
340 |
+
flow_length = gt_flows_bi[0].size(1)
|
341 |
+
if flow_length > args.subvideo_length:
|
342 |
+
pred_flows_f, pred_flows_b = [], []
|
343 |
+
pad_len = 5
|
344 |
+
for f in range(0, flow_length, args.subvideo_length):
|
345 |
+
s_f = max(0, f - pad_len)
|
346 |
+
e_f = min(flow_length, f + args.subvideo_length + pad_len)
|
347 |
+
pad_len_s = max(0, f) - s_f
|
348 |
+
pad_len_e = e_f - min(flow_length, f + args.subvideo_length)
|
349 |
+
pred_flows_bi_sub, _ = fix_flow_complete.forward_bidirect_flow(
|
350 |
+
(gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
|
351 |
+
flow_masks[:, s_f:e_f+1])
|
352 |
+
pred_flows_bi_sub = fix_flow_complete.combine_flow(
|
353 |
+
(gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
|
354 |
+
pred_flows_bi_sub,
|
355 |
+
flow_masks[:, s_f:e_f+1])
|
356 |
+
|
357 |
+
pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
|
358 |
+
pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
|
359 |
+
torch.cuda.empty_cache()
|
360 |
+
|
361 |
+
pred_flows_f = torch.cat(pred_flows_f, dim=1)
|
362 |
+
pred_flows_b = torch.cat(pred_flows_b, dim=1)
|
363 |
+
pred_flows_bi = (pred_flows_f, pred_flows_b)
|
364 |
+
else:
|
365 |
+
pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
|
366 |
+
pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
|
367 |
+
torch.cuda.empty_cache()
|
368 |
+
|
369 |
+
|
370 |
+
# ---- image propagation ----
|
371 |
+
masked_frames = frames * (1 - masks_dilated)
|
372 |
+
subvideo_length_img_prop = min(100, args.subvideo_length) # ensure a minimum of 100 frames for image propagation
|
373 |
+
if video_length > subvideo_length_img_prop:
|
374 |
+
updated_frames, updated_masks = [], []
|
375 |
+
pad_len = 10
|
376 |
+
for f in range(0, video_length, subvideo_length_img_prop):
|
377 |
+
s_f = max(0, f - pad_len)
|
378 |
+
e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
|
379 |
+
pad_len_s = max(0, f) - s_f
|
380 |
+
pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)
|
381 |
+
|
382 |
+
b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
|
383 |
+
pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
|
384 |
+
prop_imgs_sub, updated_local_masks_sub = model.img_propagation(masked_frames[:, s_f:e_f],
|
385 |
+
pred_flows_bi_sub,
|
386 |
+
masks_dilated[:, s_f:e_f],
|
387 |
+
'nearest')
|
388 |
+
updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
|
389 |
+
prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
|
390 |
+
updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)
|
391 |
+
|
392 |
+
updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
393 |
+
updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
|
394 |
+
torch.cuda.empty_cache()
|
395 |
+
|
396 |
+
updated_frames = torch.cat(updated_frames, dim=1)
|
397 |
+
updated_masks = torch.cat(updated_masks, dim=1)
|
398 |
+
else:
|
399 |
+
b, t, _, _, _ = masks_dilated.size()
|
400 |
+
prop_imgs, updated_local_masks = model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
|
401 |
+
updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
|
402 |
+
updated_masks = updated_local_masks.view(b, t, 1, h, w)
|
403 |
+
torch.cuda.empty_cache()
|
404 |
+
|
405 |
+
|
406 |
+
ori_frames = frames_inp
|
407 |
+
comp_frames = [None] * video_length
|
408 |
+
|
409 |
+
neighbor_stride = args.neighbor_length // 2
|
410 |
+
if video_length > args.subvideo_length:
|
411 |
+
ref_num = args.subvideo_length // args.ref_stride
|
412 |
+
else:
|
413 |
+
ref_num = -1
|
414 |
+
|
415 |
+
# ---- feature propagation + transformer ----
|
416 |
+
for f in tqdm(range(0, video_length, neighbor_stride)):
|
417 |
+
neighbor_ids = [
|
418 |
+
i for i in range(max(0, f - neighbor_stride),
|
419 |
+
min(video_length, f + neighbor_stride + 1))
|
420 |
+
]
|
421 |
+
ref_ids = get_ref_index(f, neighbor_ids, video_length, args.ref_stride, ref_num)
|
422 |
+
selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
|
423 |
+
selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
|
424 |
+
selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
|
425 |
+
selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])
|
426 |
+
|
427 |
+
with torch.no_grad():
|
428 |
+
# 1.0 indicates mask
|
429 |
+
l_t = len(neighbor_ids)
|
430 |
+
|
431 |
+
# pred_img = selected_imgs # results of image propagation
|
432 |
+
pred_img = model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
|
433 |
+
|
434 |
+
pred_img = pred_img.view(-1, 3, h, w)
|
435 |
+
|
436 |
+
pred_img = (pred_img + 1) / 2
|
437 |
+
pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
|
438 |
+
binary_masks = masks_dilated[0, neighbor_ids, :, :, :].cpu().permute(
|
439 |
+
0, 2, 3, 1).numpy().astype(np.uint8)
|
440 |
+
for i in range(len(neighbor_ids)):
|
441 |
+
idx = neighbor_ids[i]
|
442 |
+
img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
|
443 |
+
+ ori_frames[idx] * (1 - binary_masks[i])
|
444 |
+
if comp_frames[idx] is None:
|
445 |
+
comp_frames[idx] = img
|
446 |
+
else:
|
447 |
+
comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
|
448 |
+
|
449 |
+
comp_frames[idx] = comp_frames[idx].astype(np.uint8)
|
450 |
+
|
451 |
+
torch.cuda.empty_cache()
|
452 |
+
|
453 |
+
# save each frame
|
454 |
+
if args.save_frames:
|
455 |
+
for idx in range(video_length):
|
456 |
+
f = comp_frames[idx]
|
457 |
+
f = cv2.resize(f, out_size, interpolation = cv2.INTER_CUBIC)
|
458 |
+
f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
|
459 |
+
img_save_root = os.path.join(save_root, 'frames', str(idx).zfill(4)+'.png')
|
460 |
+
imwrite(f, img_save_root)
|
461 |
+
|
462 |
+
|
463 |
+
# if args.mode == 'video_outpainting':
|
464 |
+
# comp_frames = [i[10:-10,10:-10] for i in comp_frames]
|
465 |
+
# masked_frame_for_save = [i[10:-10,10:-10] for i in masked_frame_for_save]
|
466 |
+
|
467 |
+
# save videos frame
|
468 |
+
masked_frame_for_save = [cv2.resize(f, out_size) for f in masked_frame_for_save]
|
469 |
+
comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
|
470 |
+
imageio.mimwrite(os.path.join(save_root, 'masked_in.mp4'), masked_frame_for_save, fps=fps, quality=7)
|
471 |
+
imageio.mimwrite(os.path.join(save_root, 'inpaint_out.mp4'), comp_frames, fps=fps, quality=7)
|
472 |
+
|
473 |
+
print(f'\nAll results are saved in {save_root}')
|
474 |
+
|
475 |
+
torch.cuda.empty_cache()
|
model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
model/canny/canny_filter.py
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from typing import Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
from .gaussian import gaussian_blur2d
|
9 |
+
from .kernels import get_canny_nms_kernel, get_hysteresis_kernel
|
10 |
+
from .sobel import spatial_gradient
|
11 |
+
|
12 |
+
def rgb_to_grayscale(image, rgb_weights = None):
|
13 |
+
if len(image.shape) < 3 or image.shape[-3] != 3:
|
14 |
+
raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
|
15 |
+
|
16 |
+
if rgb_weights is None:
|
17 |
+
# 8 bit images
|
18 |
+
if image.dtype == torch.uint8:
|
19 |
+
rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8)
|
20 |
+
# floating point images
|
21 |
+
elif image.dtype in (torch.float16, torch.float32, torch.float64):
|
22 |
+
rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
|
23 |
+
else:
|
24 |
+
raise TypeError(f"Unknown data type: {image.dtype}")
|
25 |
+
else:
|
26 |
+
# is tensor that we make sure is in the same device/dtype
|
27 |
+
rgb_weights = rgb_weights.to(image)
|
28 |
+
|
29 |
+
# unpack the color image channels with RGB order
|
30 |
+
r = image[..., 0:1, :, :]
|
31 |
+
g = image[..., 1:2, :, :]
|
32 |
+
b = image[..., 2:3, :, :]
|
33 |
+
|
34 |
+
w_r, w_g, w_b = rgb_weights.unbind()
|
35 |
+
return w_r * r + w_g * g + w_b * b
|
36 |
+
|
37 |
+
|
38 |
+
def canny(
|
39 |
+
input: torch.Tensor,
|
40 |
+
low_threshold: float = 0.1,
|
41 |
+
high_threshold: float = 0.2,
|
42 |
+
kernel_size: Tuple[int, int] = (5, 5),
|
43 |
+
sigma: Tuple[float, float] = (1, 1),
|
44 |
+
hysteresis: bool = True,
|
45 |
+
eps: float = 1e-6,
|
46 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
47 |
+
r"""Find edges of the input image and filters them using the Canny algorithm.
|
48 |
+
|
49 |
+
.. image:: _static/img/canny.png
|
50 |
+
|
51 |
+
Args:
|
52 |
+
input: input image tensor with shape :math:`(B,C,H,W)`.
|
53 |
+
low_threshold: lower threshold for the hysteresis procedure.
|
54 |
+
high_threshold: upper threshold for the hysteresis procedure.
|
55 |
+
kernel_size: the size of the kernel for the gaussian blur.
|
56 |
+
sigma: the standard deviation of the kernel for the gaussian blur.
|
57 |
+
hysteresis: if True, applies the hysteresis edge tracking.
|
58 |
+
Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
|
59 |
+
eps: regularization number to avoid NaN during backprop.
|
60 |
+
|
61 |
+
Returns:
|
62 |
+
- the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
|
63 |
+
- the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
|
64 |
+
|
65 |
+
.. note::
|
66 |
+
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
|
67 |
+
canny.html>`__.
|
68 |
+
|
69 |
+
Example:
|
70 |
+
>>> input = torch.rand(5, 3, 4, 4)
|
71 |
+
>>> magnitude, edges = canny(input) # 5x3x4x4
|
72 |
+
>>> magnitude.shape
|
73 |
+
torch.Size([5, 1, 4, 4])
|
74 |
+
>>> edges.shape
|
75 |
+
torch.Size([5, 1, 4, 4])
|
76 |
+
"""
|
77 |
+
if not isinstance(input, torch.Tensor):
|
78 |
+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
|
79 |
+
|
80 |
+
if not len(input.shape) == 4:
|
81 |
+
raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
|
82 |
+
|
83 |
+
if low_threshold > high_threshold:
|
84 |
+
raise ValueError(
|
85 |
+
"Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format(
|
86 |
+
low_threshold, high_threshold
|
87 |
+
)
|
88 |
+
)
|
89 |
+
|
90 |
+
if low_threshold < 0 and low_threshold > 1:
|
91 |
+
raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")
|
92 |
+
|
93 |
+
if high_threshold < 0 and high_threshold > 1:
|
94 |
+
raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")
|
95 |
+
|
96 |
+
device: torch.device = input.device
|
97 |
+
dtype: torch.dtype = input.dtype
|
98 |
+
|
99 |
+
# To Grayscale
|
100 |
+
if input.shape[1] == 3:
|
101 |
+
input = rgb_to_grayscale(input)
|
102 |
+
|
103 |
+
# Gaussian filter
|
104 |
+
blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma)
|
105 |
+
|
106 |
+
# Compute the gradients
|
107 |
+
gradients: torch.Tensor = spatial_gradient(blurred, normalized=False)
|
108 |
+
|
109 |
+
# Unpack the edges
|
110 |
+
gx: torch.Tensor = gradients[:, :, 0]
|
111 |
+
gy: torch.Tensor = gradients[:, :, 1]
|
112 |
+
|
113 |
+
# Compute gradient magnitude and angle
|
114 |
+
magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
|
115 |
+
angle: torch.Tensor = torch.atan2(gy, gx)
|
116 |
+
|
117 |
+
# Radians to Degrees
|
118 |
+
angle = 180.0 * angle / math.pi
|
119 |
+
|
120 |
+
# Round angle to the nearest 45 degree
|
121 |
+
angle = torch.round(angle / 45) * 45
|
122 |
+
|
123 |
+
# Non-maximal suppression
|
124 |
+
nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype)
|
125 |
+
nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)
|
126 |
+
|
127 |
+
# Get the indices for both directions
|
128 |
+
positive_idx: torch.Tensor = (angle / 45) % 8
|
129 |
+
positive_idx = positive_idx.long()
|
130 |
+
|
131 |
+
negative_idx: torch.Tensor = ((angle / 45) + 4) % 8
|
132 |
+
negative_idx = negative_idx.long()
|
133 |
+
|
134 |
+
# Apply the non-maximum suppression to the different directions
|
135 |
+
channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx)
|
136 |
+
channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx)
|
137 |
+
|
138 |
+
channel_select_filtered: torch.Tensor = torch.stack(
|
139 |
+
[channel_select_filtered_positive, channel_select_filtered_negative], 1
|
140 |
+
)
|
141 |
+
|
142 |
+
is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0
|
143 |
+
|
144 |
+
magnitude = magnitude * is_max
|
145 |
+
|
146 |
+
# Threshold
|
147 |
+
edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0)
|
148 |
+
|
149 |
+
low: torch.Tensor = magnitude > low_threshold
|
150 |
+
high: torch.Tensor = magnitude > high_threshold
|
151 |
+
|
152 |
+
edges = low * 0.5 + high * 0.5
|
153 |
+
edges = edges.to(dtype)
|
154 |
+
|
155 |
+
# Hysteresis
|
156 |
+
if hysteresis:
|
157 |
+
edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
|
158 |
+
hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype)
|
159 |
+
|
160 |
+
while ((edges_old - edges).abs() != 0).any():
|
161 |
+
weak: torch.Tensor = (edges == 0.5).float()
|
162 |
+
strong: torch.Tensor = (edges == 1).float()
|
163 |
+
|
164 |
+
hysteresis_magnitude: torch.Tensor = F.conv2d(
|
165 |
+
edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
|
166 |
+
)
|
167 |
+
hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
|
168 |
+
hysteresis_magnitude = hysteresis_magnitude * weak + strong
|
169 |
+
|
170 |
+
edges_old = edges.clone()
|
171 |
+
edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5
|
172 |
+
|
173 |
+
edges = hysteresis_magnitude
|
174 |
+
|
175 |
+
return magnitude, edges
|
176 |
+
|
177 |
+
|
178 |
+
class Canny(nn.Module):
|
179 |
+
r"""Module that finds edges of the input image and filters them using the Canny algorithm.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
input: input image tensor with shape :math:`(B,C,H,W)`.
|
183 |
+
low_threshold: lower threshold for the hysteresis procedure.
|
184 |
+
high_threshold: upper threshold for the hysteresis procedure.
|
185 |
+
kernel_size: the size of the kernel for the gaussian blur.
|
186 |
+
sigma: the standard deviation of the kernel for the gaussian blur.
|
187 |
+
hysteresis: if True, applies the hysteresis edge tracking.
|
188 |
+
Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
|
189 |
+
eps: regularization number to avoid NaN during backprop.
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
- the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
|
193 |
+
- the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
|
194 |
+
|
195 |
+
Example:
|
196 |
+
>>> input = torch.rand(5, 3, 4, 4)
|
197 |
+
>>> magnitude, edges = Canny()(input) # 5x3x4x4
|
198 |
+
>>> magnitude.shape
|
199 |
+
torch.Size([5, 1, 4, 4])
|
200 |
+
>>> edges.shape
|
201 |
+
torch.Size([5, 1, 4, 4])
|
202 |
+
"""
|
203 |
+
|
204 |
+
def __init__(
|
205 |
+
self,
|
206 |
+
low_threshold: float = 0.1,
|
207 |
+
high_threshold: float = 0.2,
|
208 |
+
kernel_size: Tuple[int, int] = (5, 5),
|
209 |
+
sigma: Tuple[float, float] = (1, 1),
|
210 |
+
hysteresis: bool = True,
|
211 |
+
eps: float = 1e-6,
|
212 |
+
) -> None:
|
213 |
+
super().__init__()
|
214 |
+
|
215 |
+
if low_threshold > high_threshold:
|
216 |
+
raise ValueError(
|
217 |
+
"Invalid input thresholds. low_threshold should be\
|
218 |
+
smaller than the high_threshold. Got: {}>{}".format(
|
219 |
+
low_threshold, high_threshold
|
220 |
+
)
|
221 |
+
)
|
222 |
+
|
223 |
+
if low_threshold < 0 or low_threshold > 1:
|
224 |
+
raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")
|
225 |
+
|
226 |
+
if high_threshold < 0 or high_threshold > 1:
|
227 |
+
raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")
|
228 |
+
|
229 |
+
# Gaussian blur parameters
|
230 |
+
self.kernel_size = kernel_size
|
231 |
+
self.sigma = sigma
|
232 |
+
|
233 |
+
# Double threshold
|
234 |
+
self.low_threshold = low_threshold
|
235 |
+
self.high_threshold = high_threshold
|
236 |
+
|
237 |
+
# Hysteresis
|
238 |
+
self.hysteresis = hysteresis
|
239 |
+
|
240 |
+
self.eps: float = eps
|
241 |
+
|
242 |
+
def __repr__(self) -> str:
|
243 |
+
return ''.join(
|
244 |
+
(
|
245 |
+
f'{type(self).__name__}(',
|
246 |
+
', '.join(
|
247 |
+
f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_')
|
248 |
+
),
|
249 |
+
')',
|
250 |
+
)
|
251 |
+
)
|
252 |
+
|
253 |
+
def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
254 |
+
return canny(
|
255 |
+
input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps
|
256 |
+
)
|
model/canny/filter.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from .kernels import normalize_kernel2d
|
7 |
+
|
8 |
+
|
9 |
+
def _compute_padding(kernel_size: List[int]) -> List[int]:
|
10 |
+
"""Compute padding tuple."""
|
11 |
+
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
|
12 |
+
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
|
13 |
+
if len(kernel_size) < 2:
|
14 |
+
raise AssertionError(kernel_size)
|
15 |
+
computed = [k - 1 for k in kernel_size]
|
16 |
+
|
17 |
+
# for even kernels we need to do asymmetric padding :(
|
18 |
+
out_padding = 2 * len(kernel_size) * [0]
|
19 |
+
|
20 |
+
for i in range(len(kernel_size)):
|
21 |
+
computed_tmp = computed[-(i + 1)]
|
22 |
+
|
23 |
+
pad_front = computed_tmp // 2
|
24 |
+
pad_rear = computed_tmp - pad_front
|
25 |
+
|
26 |
+
out_padding[2 * i + 0] = pad_front
|
27 |
+
out_padding[2 * i + 1] = pad_rear
|
28 |
+
|
29 |
+
return out_padding
|
30 |
+
|
31 |
+
|
32 |
+
def filter2d(
|
33 |
+
input: torch.Tensor,
|
34 |
+
kernel: torch.Tensor,
|
35 |
+
border_type: str = 'reflect',
|
36 |
+
normalized: bool = False,
|
37 |
+
padding: str = 'same',
|
38 |
+
) -> torch.Tensor:
|
39 |
+
r"""Convolve a tensor with a 2d kernel.
|
40 |
+
|
41 |
+
The function applies a given kernel to a tensor. The kernel is applied
|
42 |
+
independently at each depth channel of the tensor. Before applying the
|
43 |
+
kernel, the function applies padding according to the specified mode so
|
44 |
+
that the output remains in the same shape.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
input: the input tensor with shape of
|
48 |
+
:math:`(B, C, H, W)`.
|
49 |
+
kernel: the kernel to be convolved with the input
|
50 |
+
tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`.
|
51 |
+
border_type: the padding mode to be applied before convolving.
|
52 |
+
The expected modes are: ``'constant'``, ``'reflect'``,
|
53 |
+
``'replicate'`` or ``'circular'``.
|
54 |
+
normalized: If True, kernel will be L1 normalized.
|
55 |
+
padding: This defines the type of padding.
|
56 |
+
2 modes available ``'same'`` or ``'valid'``.
|
57 |
+
|
58 |
+
Return:
|
59 |
+
torch.Tensor: the convolved tensor of same size and numbers of channels
|
60 |
+
as the input with shape :math:`(B, C, H, W)`.
|
61 |
+
|
62 |
+
Example:
|
63 |
+
>>> input = torch.tensor([[[
|
64 |
+
... [0., 0., 0., 0., 0.],
|
65 |
+
... [0., 0., 0., 0., 0.],
|
66 |
+
... [0., 0., 5., 0., 0.],
|
67 |
+
... [0., 0., 0., 0., 0.],
|
68 |
+
... [0., 0., 0., 0., 0.],]]])
|
69 |
+
>>> kernel = torch.ones(1, 3, 3)
|
70 |
+
>>> filter2d(input, kernel, padding='same')
|
71 |
+
tensor([[[[0., 0., 0., 0., 0.],
|
72 |
+
[0., 5., 5., 5., 0.],
|
73 |
+
[0., 5., 5., 5., 0.],
|
74 |
+
[0., 5., 5., 5., 0.],
|
75 |
+
[0., 0., 0., 0., 0.]]]])
|
76 |
+
"""
|
77 |
+
if not isinstance(input, torch.Tensor):
|
78 |
+
raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")
|
79 |
+
|
80 |
+
if not isinstance(kernel, torch.Tensor):
|
81 |
+
raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")
|
82 |
+
|
83 |
+
if not isinstance(border_type, str):
|
84 |
+
raise TypeError(f"Input border_type is not string. Got {type(border_type)}")
|
85 |
+
|
86 |
+
if border_type not in ['constant', 'reflect', 'replicate', 'circular']:
|
87 |
+
raise ValueError(
|
88 |
+
f"Invalid border type, we expect 'constant', \
|
89 |
+
'reflect', 'replicate', 'circular'. Got:{border_type}"
|
90 |
+
)
|
91 |
+
|
92 |
+
if not isinstance(padding, str):
|
93 |
+
raise TypeError(f"Input padding is not string. Got {type(padding)}")
|
94 |
+
|
95 |
+
if padding not in ['valid', 'same']:
|
96 |
+
raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}")
|
97 |
+
|
98 |
+
if not len(input.shape) == 4:
|
99 |
+
raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
|
100 |
+
|
101 |
+
if (not len(kernel.shape) == 3) and not ((kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0])):
|
102 |
+
raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}")
|
103 |
+
|
104 |
+
# prepare kernel
|
105 |
+
b, c, h, w = input.shape
|
106 |
+
tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
|
107 |
+
|
108 |
+
if normalized:
|
109 |
+
tmp_kernel = normalize_kernel2d(tmp_kernel)
|
110 |
+
|
111 |
+
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
|
112 |
+
|
113 |
+
height, width = tmp_kernel.shape[-2:]
|
114 |
+
|
115 |
+
# pad the input tensor
|
116 |
+
if padding == 'same':
|
117 |
+
padding_shape: List[int] = _compute_padding([height, width])
|
118 |
+
input = F.pad(input, padding_shape, mode=border_type)
|
119 |
+
|
120 |
+
# kernel and input tensor reshape to align element-wise or batch-wise params
|
121 |
+
tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
|
122 |
+
input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
|
123 |
+
|
124 |
+
# convolve the tensor with the kernel.
|
125 |
+
output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
|
126 |
+
|
127 |
+
if padding == 'same':
|
128 |
+
out = output.view(b, c, h, w)
|
129 |
+
else:
|
130 |
+
out = output.view(b, c, h - height + 1, w - width + 1)
|
131 |
+
|
132 |
+
return out
|
133 |
+
|
134 |
+
|
135 |
+
def filter2d_separable(
|
136 |
+
input: torch.Tensor,
|
137 |
+
kernel_x: torch.Tensor,
|
138 |
+
kernel_y: torch.Tensor,
|
139 |
+
border_type: str = 'reflect',
|
140 |
+
normalized: bool = False,
|
141 |
+
padding: str = 'same',
|
142 |
+
) -> torch.Tensor:
|
143 |
+
r"""Convolve a tensor with two 1d kernels, in x and y directions.
|
144 |
+
|
145 |
+
The function applies a given kernel to a tensor. The kernel is applied
|
146 |
+
independently at each depth channel of the tensor. Before applying the
|
147 |
+
kernel, the function applies padding according to the specified mode so
|
148 |
+
that the output remains in the same shape.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
input: the input tensor with shape of
|
152 |
+
:math:`(B, C, H, W)`.
|
153 |
+
kernel_x: the kernel to be convolved with the input
|
154 |
+
tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`.
|
155 |
+
kernel_y: the kernel to be convolved with the input
|
156 |
+
tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`.
|
157 |
+
border_type: the padding mode to be applied before convolving.
|
158 |
+
The expected modes are: ``'constant'``, ``'reflect'``,
|
159 |
+
``'replicate'`` or ``'circular'``.
|
160 |
+
normalized: If True, kernel will be L1 normalized.
|
161 |
+
padding: This defines the type of padding.
|
162 |
+
2 modes available ``'same'`` or ``'valid'``.
|
163 |
+
|
164 |
+
Return:
|
165 |
+
torch.Tensor: the convolved tensor of same size and numbers of channels
|
166 |
+
as the input with shape :math:`(B, C, H, W)`.
|
167 |
+
|
168 |
+
Example:
|
169 |
+
>>> input = torch.tensor([[[
|
170 |
+
... [0., 0., 0., 0., 0.],
|
171 |
+
... [0., 0., 0., 0., 0.],
|
172 |
+
... [0., 0., 5., 0., 0.],
|
173 |
+
... [0., 0., 0., 0., 0.],
|
174 |
+
... [0., 0., 0., 0., 0.],]]])
|
175 |
+
>>> kernel = torch.ones(1, 3)
|
176 |
+
|
177 |
+
>>> filter2d_separable(input, kernel, kernel, padding='same')
|
178 |
+
tensor([[[[0., 0., 0., 0., 0.],
|
179 |
+
[0., 5., 5., 5., 0.],
|
180 |
+
[0., 5., 5., 5., 0.],
|
181 |
+
[0., 5., 5., 5., 0.],
|
182 |
+
[0., 0., 0., 0., 0.]]]])
|
183 |
+
"""
|
184 |
+
out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding)
|
185 |
+
out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding)
|
186 |
+
return out
|
187 |
+
|
188 |
+
|
189 |
+
def filter3d(
|
190 |
+
input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False
|
191 |
+
) -> torch.Tensor:
|
192 |
+
r"""Convolve a tensor with a 3d kernel.
|
193 |
+
|
194 |
+
The function applies a given kernel to a tensor. The kernel is applied
|
195 |
+
independently at each depth channel of the tensor. Before applying the
|
196 |
+
kernel, the function applies padding according to the specified mode so
|
197 |
+
that the output remains in the same shape.
|
198 |
+
|
199 |
+
Args:
|
200 |
+
input: the input tensor with shape of
|
201 |
+
:math:`(B, C, D, H, W)`.
|
202 |
+
kernel: the kernel to be convolved with the input
|
203 |
+
tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`.
|
204 |
+
border_type: the padding mode to be applied before convolving.
|
205 |
+
The expected modes are: ``'constant'``,
|
206 |
+
``'replicate'`` or ``'circular'``.
|
207 |
+
normalized: If True, kernel will be L1 normalized.
|
208 |
+
|
209 |
+
Return:
|
210 |
+
the convolved tensor of same size and numbers of channels
|
211 |
+
as the input with shape :math:`(B, C, D, H, W)`.
|
212 |
+
|
213 |
+
Example:
|
214 |
+
>>> input = torch.tensor([[[
|
215 |
+
... [[0., 0., 0., 0., 0.],
|
216 |
+
... [0., 0., 0., 0., 0.],
|
217 |
+
... [0., 0., 0., 0., 0.],
|
218 |
+
... [0., 0., 0., 0., 0.],
|
219 |
+
... [0., 0., 0., 0., 0.]],
|
220 |
+
... [[0., 0., 0., 0., 0.],
|
221 |
+
... [0., 0., 0., 0., 0.],
|
222 |
+
... [0., 0., 5., 0., 0.],
|
223 |
+
... [0., 0., 0., 0., 0.],
|
224 |
+
... [0., 0., 0., 0., 0.]],
|
225 |
+
... [[0., 0., 0., 0., 0.],
|
226 |
+
... [0., 0., 0., 0., 0.],
|
227 |
+
... [0., 0., 0., 0., 0.],
|
228 |
+
... [0., 0., 0., 0., 0.],
|
229 |
+
... [0., 0., 0., 0., 0.]]
|
230 |
+
... ]]])
|
231 |
+
>>> kernel = torch.ones(1, 3, 3, 3)
|
232 |
+
>>> filter3d(input, kernel)
|
233 |
+
tensor([[[[[0., 0., 0., 0., 0.],
|
234 |
+
[0., 5., 5., 5., 0.],
|
235 |
+
[0., 5., 5., 5., 0.],
|
236 |
+
[0., 5., 5., 5., 0.],
|
237 |
+
[0., 0., 0., 0., 0.]],
|
238 |
+
<BLANKLINE>
|
239 |
+
[[0., 0., 0., 0., 0.],
|
240 |
+
[0., 5., 5., 5., 0.],
|
241 |
+
[0., 5., 5., 5., 0.],
|
242 |
+
[0., 5., 5., 5., 0.],
|
243 |
+
[0., 0., 0., 0., 0.]],
|
244 |
+
<BLANKLINE>
|
245 |
+
[[0., 0., 0., 0., 0.],
|
246 |
+
[0., 5., 5., 5., 0.],
|
247 |
+
[0., 5., 5., 5., 0.],
|
248 |
+
[0., 5., 5., 5., 0.],
|
249 |
+
[0., 0., 0., 0., 0.]]]]])
|
250 |
+
"""
|
251 |
+
if not isinstance(input, torch.Tensor):
|
252 |
+
raise TypeError(f"Input border_type is not torch.Tensor. Got {type(input)}")
|
253 |
+
|
254 |
+
if not isinstance(kernel, torch.Tensor):
|
255 |
+
raise TypeError(f"Input border_type is not torch.Tensor. Got {type(kernel)}")
|
256 |
+
|
257 |
+
if not isinstance(border_type, str):
|
258 |
+
raise TypeError(f"Input border_type is not string. Got {type(kernel)}")
|
259 |
+
|
260 |
+
if not len(input.shape) == 5:
|
261 |
+
raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
|
262 |
+
|
263 |
+
if not len(kernel.shape) == 4 and kernel.shape[0] != 1:
|
264 |
+
raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}")
|
265 |
+
|
266 |
+
# prepare kernel
|
267 |
+
b, c, d, h, w = input.shape
|
268 |
+
tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
|
269 |
+
|
270 |
+
if normalized:
|
271 |
+
bk, dk, hk, wk = kernel.shape
|
272 |
+
tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel)
|
273 |
+
|
274 |
+
tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)
|
275 |
+
|
276 |
+
# pad the input tensor
|
277 |
+
depth, height, width = tmp_kernel.shape[-3:]
|
278 |
+
padding_shape: List[int] = _compute_padding([depth, height, width])
|
279 |
+
input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)
|
280 |
+
|
281 |
+
# kernel and input tensor reshape to align element-wise or batch-wise params
|
282 |
+
tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
|
283 |
+
input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1))
|
284 |
+
|
285 |
+
# convolve the tensor with the kernel.
|
286 |
+
output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
|
287 |
+
|
288 |
+
return output.view(b, c, d, h, w)
|
model/canny/gaussian.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from .filter import filter2d, filter2d_separable
|
7 |
+
from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d
|
8 |
+
|
9 |
+
|
10 |
+
def gaussian_blur2d(
|
11 |
+
input: torch.Tensor,
|
12 |
+
kernel_size: Tuple[int, int],
|
13 |
+
sigma: Tuple[float, float],
|
14 |
+
border_type: str = 'reflect',
|
15 |
+
separable: bool = True,
|
16 |
+
) -> torch.Tensor:
|
17 |
+
r"""Create an operator that blurs a tensor using a Gaussian filter.
|
18 |
+
|
19 |
+
.. image:: _static/img/gaussian_blur2d.png
|
20 |
+
|
21 |
+
The operator smooths the given tensor with a gaussian kernel by convolving
|
22 |
+
it to each channel. It supports batched operation.
|
23 |
+
|
24 |
+
Arguments:
|
25 |
+
input: the input tensor with shape :math:`(B,C,H,W)`.
|
26 |
+
kernel_size: the size of the kernel.
|
27 |
+
sigma: the standard deviation of the kernel.
|
28 |
+
border_type: the padding mode to be applied before convolving.
|
29 |
+
The expected modes are: ``'constant'``, ``'reflect'``,
|
30 |
+
``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
|
31 |
+
separable: run as composition of two 1d-convolutions.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
the blurred tensor with shape :math:`(B, C, H, W)`.
|
35 |
+
|
36 |
+
.. note::
|
37 |
+
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
|
38 |
+
gaussian_blur.html>`__.
|
39 |
+
|
40 |
+
Examples:
|
41 |
+
>>> input = torch.rand(2, 4, 5, 5)
|
42 |
+
>>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5))
|
43 |
+
>>> output.shape
|
44 |
+
torch.Size([2, 4, 5, 5])
|
45 |
+
"""
|
46 |
+
if separable:
|
47 |
+
kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1])
|
48 |
+
kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0])
|
49 |
+
out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type)
|
50 |
+
else:
|
51 |
+
kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma)
|
52 |
+
out = filter2d(input, kernel[None], border_type)
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
class GaussianBlur2d(nn.Module):
|
57 |
+
r"""Create an operator that blurs a tensor using a Gaussian filter.
|
58 |
+
|
59 |
+
The operator smooths the given tensor with a gaussian kernel by convolving
|
60 |
+
it to each channel. It supports batched operation.
|
61 |
+
|
62 |
+
Arguments:
|
63 |
+
kernel_size: the size of the kernel.
|
64 |
+
sigma: the standard deviation of the kernel.
|
65 |
+
border_type: the padding mode to be applied before convolving.
|
66 |
+
The expected modes are: ``'constant'``, ``'reflect'``,
|
67 |
+
``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
|
68 |
+
separable: run as composition of two 1d-convolutions.
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
the blurred tensor.
|
72 |
+
|
73 |
+
Shape:
|
74 |
+
- Input: :math:`(B, C, H, W)`
|
75 |
+
- Output: :math:`(B, C, H, W)`
|
76 |
+
|
77 |
+
Examples::
|
78 |
+
|
79 |
+
>>> input = torch.rand(2, 4, 5, 5)
|
80 |
+
>>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5))
|
81 |
+
>>> output = gauss(input) # 2x4x5x5
|
82 |
+
>>> output.shape
|
83 |
+
torch.Size([2, 4, 5, 5])
|
84 |
+
"""
|
85 |
+
|
86 |
+
def __init__(
|
87 |
+
self,
|
88 |
+
kernel_size: Tuple[int, int],
|
89 |
+
sigma: Tuple[float, float],
|
90 |
+
border_type: str = 'reflect',
|
91 |
+
separable: bool = True,
|
92 |
+
) -> None:
|
93 |
+
super().__init__()
|
94 |
+
self.kernel_size: Tuple[int, int] = kernel_size
|
95 |
+
self.sigma: Tuple[float, float] = sigma
|
96 |
+
self.border_type = border_type
|
97 |
+
self.separable = separable
|
98 |
+
|
99 |
+
def __repr__(self) -> str:
|
100 |
+
return (
|
101 |
+
self.__class__.__name__
|
102 |
+
+ '(kernel_size='
|
103 |
+
+ str(self.kernel_size)
|
104 |
+
+ ', '
|
105 |
+
+ 'sigma='
|
106 |
+
+ str(self.sigma)
|
107 |
+
+ ', '
|
108 |
+
+ 'border_type='
|
109 |
+
+ self.border_type
|
110 |
+
+ 'separable='
|
111 |
+
+ str(self.separable)
|
112 |
+
+ ')'
|
113 |
+
)
|
114 |
+
|
115 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
116 |
+
return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable)
|
model/canny/kernels.py
ADDED
@@ -0,0 +1,690 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from math import sqrt
|
3 |
+
from typing import List, Optional, Tuple
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor:
|
9 |
+
r"""Normalize both derivative and smoothing kernel."""
|
10 |
+
if len(input.size()) < 2:
|
11 |
+
raise TypeError(f"input should be at least 2D tensor. Got {input.size()}")
|
12 |
+
norm: torch.Tensor = input.abs().sum(dim=-1).sum(dim=-1)
|
13 |
+
return input / (norm.unsqueeze(-1).unsqueeze(-1))
|
14 |
+
|
15 |
+
|
16 |
+
def gaussian(window_size: int, sigma: float) -> torch.Tensor:
|
17 |
+
device, dtype = None, None
|
18 |
+
if isinstance(sigma, torch.Tensor):
|
19 |
+
device, dtype = sigma.device, sigma.dtype
|
20 |
+
x = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2
|
21 |
+
if window_size % 2 == 0:
|
22 |
+
x = x + 0.5
|
23 |
+
|
24 |
+
gauss = torch.exp((-x.pow(2.0) / (2 * sigma**2)).float())
|
25 |
+
return gauss / gauss.sum()
|
26 |
+
|
27 |
+
|
28 |
+
def gaussian_discrete_erf(window_size: int, sigma) -> torch.Tensor:
|
29 |
+
r"""Discrete Gaussian by interpolating the error function.
|
30 |
+
|
31 |
+
Adapted from:
|
32 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
33 |
+
"""
|
34 |
+
device = sigma.device if isinstance(sigma, torch.Tensor) else None
|
35 |
+
sigma = torch.as_tensor(sigma, dtype=torch.float, device=device)
|
36 |
+
x = torch.arange(window_size).float() - window_size // 2
|
37 |
+
t = 0.70710678 / torch.abs(sigma)
|
38 |
+
gauss = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf())
|
39 |
+
gauss = gauss.clamp(min=0)
|
40 |
+
return gauss / gauss.sum()
|
41 |
+
|
42 |
+
|
43 |
+
def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor:
|
44 |
+
r"""Adapted from:
|
45 |
+
|
46 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
47 |
+
"""
|
48 |
+
if torch.abs(x) < 3.75:
|
49 |
+
y = (x / 3.75) * (x / 3.75)
|
50 |
+
return 1.0 + y * (
|
51 |
+
3.5156229 + y * (3.0899424 + y * (1.2067492 + y * (0.2659732 + y * (0.360768e-1 + y * 0.45813e-2))))
|
52 |
+
)
|
53 |
+
ax = torch.abs(x)
|
54 |
+
y = 3.75 / ax
|
55 |
+
ans = 0.916281e-2 + y * (-0.2057706e-1 + y * (0.2635537e-1 + y * (-0.1647633e-1 + y * 0.392377e-2)))
|
56 |
+
coef = 0.39894228 + y * (0.1328592e-1 + y * (0.225319e-2 + y * (-0.157565e-2 + y * ans)))
|
57 |
+
return (torch.exp(ax) / torch.sqrt(ax)) * coef
|
58 |
+
|
59 |
+
|
60 |
+
def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor:
|
61 |
+
r"""adapted from:
|
62 |
+
|
63 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
64 |
+
"""
|
65 |
+
if torch.abs(x) < 3.75:
|
66 |
+
y = (x / 3.75) * (x / 3.75)
|
67 |
+
ans = 0.51498869 + y * (0.15084934 + y * (0.2658733e-1 + y * (0.301532e-2 + y * 0.32411e-3)))
|
68 |
+
return torch.abs(x) * (0.5 + y * (0.87890594 + y * ans))
|
69 |
+
ax = torch.abs(x)
|
70 |
+
y = 3.75 / ax
|
71 |
+
ans = 0.2282967e-1 + y * (-0.2895312e-1 + y * (0.1787654e-1 - y * 0.420059e-2))
|
72 |
+
ans = 0.39894228 + y * (-0.3988024e-1 + y * (-0.362018e-2 + y * (0.163801e-2 + y * (-0.1031555e-1 + y * ans))))
|
73 |
+
ans = ans * torch.exp(ax) / torch.sqrt(ax)
|
74 |
+
return -ans if x < 0.0 else ans
|
75 |
+
|
76 |
+
|
77 |
+
def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor:
|
78 |
+
r"""adapted from:
|
79 |
+
|
80 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
81 |
+
"""
|
82 |
+
if n < 2:
|
83 |
+
raise ValueError("n must be greater than 1.")
|
84 |
+
if x == 0.0:
|
85 |
+
return x
|
86 |
+
device = x.device
|
87 |
+
tox = 2.0 / torch.abs(x)
|
88 |
+
ans = torch.tensor(0.0, device=device)
|
89 |
+
bip = torch.tensor(0.0, device=device)
|
90 |
+
bi = torch.tensor(1.0, device=device)
|
91 |
+
m = int(2 * (n + int(sqrt(40.0 * n))))
|
92 |
+
for j in range(m, 0, -1):
|
93 |
+
bim = bip + float(j) * tox * bi
|
94 |
+
bip = bi
|
95 |
+
bi = bim
|
96 |
+
if abs(bi) > 1.0e10:
|
97 |
+
ans = ans * 1.0e-10
|
98 |
+
bi = bi * 1.0e-10
|
99 |
+
bip = bip * 1.0e-10
|
100 |
+
if j == n:
|
101 |
+
ans = bip
|
102 |
+
ans = ans * _modified_bessel_0(x) / bi
|
103 |
+
return -ans if x < 0.0 and (n % 2) == 1 else ans
|
104 |
+
|
105 |
+
|
106 |
+
def gaussian_discrete(window_size, sigma) -> torch.Tensor:
|
107 |
+
r"""Discrete Gaussian kernel based on the modified Bessel functions.
|
108 |
+
|
109 |
+
Adapted from:
|
110 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
|
111 |
+
"""
|
112 |
+
device = sigma.device if isinstance(sigma, torch.Tensor) else None
|
113 |
+
sigma = torch.as_tensor(sigma, dtype=torch.float, device=device)
|
114 |
+
sigma2 = sigma * sigma
|
115 |
+
tail = int(window_size // 2)
|
116 |
+
out_pos: List[Optional[torch.Tensor]] = [None] * (tail + 1)
|
117 |
+
out_pos[0] = _modified_bessel_0(sigma2)
|
118 |
+
out_pos[1] = _modified_bessel_1(sigma2)
|
119 |
+
for k in range(2, len(out_pos)):
|
120 |
+
out_pos[k] = _modified_bessel_i(k, sigma2)
|
121 |
+
out = out_pos[:0:-1]
|
122 |
+
out.extend(out_pos)
|
123 |
+
out = torch.stack(out) * torch.exp(sigma2) # type: ignore
|
124 |
+
return out / out.sum() # type: ignore
|
125 |
+
|
126 |
+
|
127 |
+
def laplacian_1d(window_size) -> torch.Tensor:
|
128 |
+
r"""One could also use the Laplacian of Gaussian formula to design the filter."""
|
129 |
+
|
130 |
+
filter_1d = torch.ones(window_size)
|
131 |
+
filter_1d[window_size // 2] = 1 - window_size
|
132 |
+
laplacian_1d: torch.Tensor = filter_1d
|
133 |
+
return laplacian_1d
|
134 |
+
|
135 |
+
|
136 |
+
def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor:
|
137 |
+
r"""Utility function that returns a box filter."""
|
138 |
+
kx: float = float(kernel_size[0])
|
139 |
+
ky: float = float(kernel_size[1])
|
140 |
+
scale: torch.Tensor = torch.tensor(1.0) / torch.tensor([kx * ky])
|
141 |
+
tmp_kernel: torch.Tensor = torch.ones(1, kernel_size[0], kernel_size[1])
|
142 |
+
return scale.to(tmp_kernel.dtype) * tmp_kernel
|
143 |
+
|
144 |
+
|
145 |
+
def get_binary_kernel2d(window_size: Tuple[int, int]) -> torch.Tensor:
|
146 |
+
r"""Create a binary kernel to extract the patches.
|
147 |
+
|
148 |
+
If the window size is HxW will create a (H*W)xHxW kernel.
|
149 |
+
"""
|
150 |
+
window_range: int = window_size[0] * window_size[1]
|
151 |
+
kernel: torch.Tensor = torch.zeros(window_range, window_range)
|
152 |
+
for i in range(window_range):
|
153 |
+
kernel[i, i] += 1.0
|
154 |
+
return kernel.view(window_range, 1, window_size[0], window_size[1])
|
155 |
+
|
156 |
+
|
157 |
+
def get_sobel_kernel_3x3() -> torch.Tensor:
|
158 |
+
"""Utility function that returns a sobel kernel of 3x3."""
|
159 |
+
return torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
|
160 |
+
|
161 |
+
|
162 |
+
def get_sobel_kernel_5x5_2nd_order() -> torch.Tensor:
|
163 |
+
"""Utility function that returns a 2nd order sobel kernel of 5x5."""
|
164 |
+
return torch.tensor(
|
165 |
+
[
|
166 |
+
[-1.0, 0.0, 2.0, 0.0, -1.0],
|
167 |
+
[-4.0, 0.0, 8.0, 0.0, -4.0],
|
168 |
+
[-6.0, 0.0, 12.0, 0.0, -6.0],
|
169 |
+
[-4.0, 0.0, 8.0, 0.0, -4.0],
|
170 |
+
[-1.0, 0.0, 2.0, 0.0, -1.0],
|
171 |
+
]
|
172 |
+
)
|
173 |
+
|
174 |
+
|
175 |
+
def _get_sobel_kernel_5x5_2nd_order_xy() -> torch.Tensor:
|
176 |
+
"""Utility function that returns a 2nd order sobel kernel of 5x5."""
|
177 |
+
return torch.tensor(
|
178 |
+
[
|
179 |
+
[-1.0, -2.0, 0.0, 2.0, 1.0],
|
180 |
+
[-2.0, -4.0, 0.0, 4.0, 2.0],
|
181 |
+
[0.0, 0.0, 0.0, 0.0, 0.0],
|
182 |
+
[2.0, 4.0, 0.0, -4.0, -2.0],
|
183 |
+
[1.0, 2.0, 0.0, -2.0, -1.0],
|
184 |
+
]
|
185 |
+
)
|
186 |
+
|
187 |
+
|
188 |
+
def get_diff_kernel_3x3() -> torch.Tensor:
|
189 |
+
"""Utility function that returns a first order derivative kernel of 3x3."""
|
190 |
+
return torch.tensor([[-0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [-0.0, 0.0, 0.0]])
|
191 |
+
|
192 |
+
|
193 |
+
def get_diff_kernel3d(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
194 |
+
"""Utility function that returns a first order derivative kernel of 3x3x3."""
|
195 |
+
kernel: torch.Tensor = torch.tensor(
|
196 |
+
[
|
197 |
+
[
|
198 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
199 |
+
[[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]],
|
200 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
201 |
+
],
|
202 |
+
[
|
203 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
204 |
+
[[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]],
|
205 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
206 |
+
],
|
207 |
+
[
|
208 |
+
[[0.0, 0.0, 0.0], [0.0, -0.5, 0.0], [0.0, 0.0, 0.0]],
|
209 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
210 |
+
[[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.0]],
|
211 |
+
],
|
212 |
+
],
|
213 |
+
device=device,
|
214 |
+
dtype=dtype,
|
215 |
+
)
|
216 |
+
return kernel.unsqueeze(1)
|
217 |
+
|
218 |
+
|
219 |
+
def get_diff_kernel3d_2nd_order(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
220 |
+
"""Utility function that returns a first order derivative kernel of 3x3x3."""
|
221 |
+
kernel: torch.Tensor = torch.tensor(
|
222 |
+
[
|
223 |
+
[
|
224 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
225 |
+
[[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]],
|
226 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
227 |
+
],
|
228 |
+
[
|
229 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
230 |
+
[[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]],
|
231 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
232 |
+
],
|
233 |
+
[
|
234 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
235 |
+
[[0.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]],
|
236 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
237 |
+
],
|
238 |
+
[
|
239 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
240 |
+
[[1.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0]],
|
241 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
242 |
+
],
|
243 |
+
[
|
244 |
+
[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, -1.0, 0.0]],
|
245 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
246 |
+
[[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
|
247 |
+
],
|
248 |
+
[
|
249 |
+
[[0.0, 0.0, 0.0], [1.0, 0.0, -1.0], [0.0, 0.0, 0.0]],
|
250 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
251 |
+
[[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
|
252 |
+
],
|
253 |
+
],
|
254 |
+
device=device,
|
255 |
+
dtype=dtype,
|
256 |
+
)
|
257 |
+
return kernel.unsqueeze(1)
|
258 |
+
|
259 |
+
|
260 |
+
def get_sobel_kernel2d() -> torch.Tensor:
|
261 |
+
kernel_x: torch.Tensor = get_sobel_kernel_3x3()
|
262 |
+
kernel_y: torch.Tensor = kernel_x.transpose(0, 1)
|
263 |
+
return torch.stack([kernel_x, kernel_y])
|
264 |
+
|
265 |
+
|
266 |
+
def get_diff_kernel2d() -> torch.Tensor:
|
267 |
+
kernel_x: torch.Tensor = get_diff_kernel_3x3()
|
268 |
+
kernel_y: torch.Tensor = kernel_x.transpose(0, 1)
|
269 |
+
return torch.stack([kernel_x, kernel_y])
|
270 |
+
|
271 |
+
|
272 |
+
def get_sobel_kernel2d_2nd_order() -> torch.Tensor:
|
273 |
+
gxx: torch.Tensor = get_sobel_kernel_5x5_2nd_order()
|
274 |
+
gyy: torch.Tensor = gxx.transpose(0, 1)
|
275 |
+
gxy: torch.Tensor = _get_sobel_kernel_5x5_2nd_order_xy()
|
276 |
+
return torch.stack([gxx, gxy, gyy])
|
277 |
+
|
278 |
+
|
279 |
+
def get_diff_kernel2d_2nd_order() -> torch.Tensor:
|
280 |
+
gxx: torch.Tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]])
|
281 |
+
gyy: torch.Tensor = gxx.transpose(0, 1)
|
282 |
+
gxy: torch.Tensor = torch.tensor([[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, -1.0]])
|
283 |
+
return torch.stack([gxx, gxy, gyy])
|
284 |
+
|
285 |
+
|
286 |
+
def get_spatial_gradient_kernel2d(mode: str, order: int) -> torch.Tensor:
|
287 |
+
r"""Function that returns kernel for 1st or 2nd order image gradients, using one of the following operators:
|
288 |
+
|
289 |
+
sobel, diff.
|
290 |
+
"""
|
291 |
+
if mode not in ['sobel', 'diff']:
|
292 |
+
raise TypeError(
|
293 |
+
"mode should be either sobel\
|
294 |
+
or diff. Got {}".format(
|
295 |
+
mode
|
296 |
+
)
|
297 |
+
)
|
298 |
+
if order not in [1, 2]:
|
299 |
+
raise TypeError(
|
300 |
+
"order should be either 1 or 2\
|
301 |
+
Got {}".format(
|
302 |
+
order
|
303 |
+
)
|
304 |
+
)
|
305 |
+
if mode == 'sobel' and order == 1:
|
306 |
+
kernel: torch.Tensor = get_sobel_kernel2d()
|
307 |
+
elif mode == 'sobel' and order == 2:
|
308 |
+
kernel = get_sobel_kernel2d_2nd_order()
|
309 |
+
elif mode == 'diff' and order == 1:
|
310 |
+
kernel = get_diff_kernel2d()
|
311 |
+
elif mode == 'diff' and order == 2:
|
312 |
+
kernel = get_diff_kernel2d_2nd_order()
|
313 |
+
else:
|
314 |
+
raise NotImplementedError("")
|
315 |
+
return kernel
|
316 |
+
|
317 |
+
|
318 |
+
def get_spatial_gradient_kernel3d(mode: str, order: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
319 |
+
r"""Function that returns kernel for 1st or 2nd order scale pyramid gradients, using one of the following
|
320 |
+
operators: sobel, diff."""
|
321 |
+
if mode not in ['sobel', 'diff']:
|
322 |
+
raise TypeError(
|
323 |
+
"mode should be either sobel\
|
324 |
+
or diff. Got {}".format(
|
325 |
+
mode
|
326 |
+
)
|
327 |
+
)
|
328 |
+
if order not in [1, 2]:
|
329 |
+
raise TypeError(
|
330 |
+
"order should be either 1 or 2\
|
331 |
+
Got {}".format(
|
332 |
+
order
|
333 |
+
)
|
334 |
+
)
|
335 |
+
if mode == 'sobel':
|
336 |
+
raise NotImplementedError("Sobel kernel for 3d gradient is not implemented yet")
|
337 |
+
if mode == 'diff' and order == 1:
|
338 |
+
kernel = get_diff_kernel3d(device, dtype)
|
339 |
+
elif mode == 'diff' and order == 2:
|
340 |
+
kernel = get_diff_kernel3d_2nd_order(device, dtype)
|
341 |
+
else:
|
342 |
+
raise NotImplementedError("")
|
343 |
+
return kernel
|
344 |
+
|
345 |
+
|
346 |
+
def get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
|
347 |
+
r"""Function that returns Gaussian filter coefficients.
|
348 |
+
|
349 |
+
Args:
|
350 |
+
kernel_size: filter size. It should be odd and positive.
|
351 |
+
sigma: gaussian standard deviation.
|
352 |
+
force_even: overrides requirement for odd kernel size.
|
353 |
+
|
354 |
+
Returns:
|
355 |
+
1D tensor with gaussian filter coefficients.
|
356 |
+
|
357 |
+
Shape:
|
358 |
+
- Output: :math:`(\text{kernel_size})`
|
359 |
+
|
360 |
+
Examples:
|
361 |
+
|
362 |
+
>>> get_gaussian_kernel1d(3, 2.5)
|
363 |
+
tensor([0.3243, 0.3513, 0.3243])
|
364 |
+
|
365 |
+
>>> get_gaussian_kernel1d(5, 1.5)
|
366 |
+
tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
|
367 |
+
"""
|
368 |
+
if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
|
369 |
+
raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
|
370 |
+
window_1d: torch.Tensor = gaussian(kernel_size, sigma)
|
371 |
+
return window_1d
|
372 |
+
|
373 |
+
|
374 |
+
def get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
|
375 |
+
r"""Function that returns Gaussian filter coefficients based on the modified Bessel functions. Adapted from:
|
376 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py.
|
377 |
+
|
378 |
+
Args:
|
379 |
+
kernel_size: filter size. It should be odd and positive.
|
380 |
+
sigma: gaussian standard deviation.
|
381 |
+
force_even: overrides requirement for odd kernel size.
|
382 |
+
|
383 |
+
Returns:
|
384 |
+
1D tensor with gaussian filter coefficients.
|
385 |
+
|
386 |
+
Shape:
|
387 |
+
- Output: :math:`(\text{kernel_size})`
|
388 |
+
|
389 |
+
Examples:
|
390 |
+
|
391 |
+
>>> get_gaussian_discrete_kernel1d(3, 2.5)
|
392 |
+
tensor([0.3235, 0.3531, 0.3235])
|
393 |
+
|
394 |
+
>>> get_gaussian_discrete_kernel1d(5, 1.5)
|
395 |
+
tensor([0.1096, 0.2323, 0.3161, 0.2323, 0.1096])
|
396 |
+
"""
|
397 |
+
if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
|
398 |
+
raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
|
399 |
+
window_1d = gaussian_discrete(kernel_size, sigma)
|
400 |
+
return window_1d
|
401 |
+
|
402 |
+
|
403 |
+
def get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
|
404 |
+
r"""Function that returns Gaussian filter coefficients by interpolating the error function, adapted from:
|
405 |
+
https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py.
|
406 |
+
|
407 |
+
Args:
|
408 |
+
kernel_size: filter size. It should be odd and positive.
|
409 |
+
sigma: gaussian standard deviation.
|
410 |
+
force_even: overrides requirement for odd kernel size.
|
411 |
+
|
412 |
+
Returns:
|
413 |
+
1D tensor with gaussian filter coefficients.
|
414 |
+
|
415 |
+
Shape:
|
416 |
+
- Output: :math:`(\text{kernel_size})`
|
417 |
+
|
418 |
+
Examples:
|
419 |
+
|
420 |
+
>>> get_gaussian_erf_kernel1d(3, 2.5)
|
421 |
+
tensor([0.3245, 0.3511, 0.3245])
|
422 |
+
|
423 |
+
>>> get_gaussian_erf_kernel1d(5, 1.5)
|
424 |
+
tensor([0.1226, 0.2331, 0.2887, 0.2331, 0.1226])
|
425 |
+
"""
|
426 |
+
if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
|
427 |
+
raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
|
428 |
+
window_1d = gaussian_discrete_erf(kernel_size, sigma)
|
429 |
+
return window_1d
|
430 |
+
|
431 |
+
|
432 |
+
def get_gaussian_kernel2d(
|
433 |
+
kernel_size: Tuple[int, int], sigma: Tuple[float, float], force_even: bool = False
|
434 |
+
) -> torch.Tensor:
|
435 |
+
r"""Function that returns Gaussian filter matrix coefficients.
|
436 |
+
|
437 |
+
Args:
|
438 |
+
kernel_size: filter sizes in the x and y direction.
|
439 |
+
Sizes should be odd and positive.
|
440 |
+
sigma: gaussian standard deviation in the x and y
|
441 |
+
direction.
|
442 |
+
force_even: overrides requirement for odd kernel size.
|
443 |
+
|
444 |
+
Returns:
|
445 |
+
2D tensor with gaussian filter matrix coefficients.
|
446 |
+
|
447 |
+
Shape:
|
448 |
+
- Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
|
449 |
+
|
450 |
+
Examples:
|
451 |
+
>>> get_gaussian_kernel2d((3, 3), (1.5, 1.5))
|
452 |
+
tensor([[0.0947, 0.1183, 0.0947],
|
453 |
+
[0.1183, 0.1478, 0.1183],
|
454 |
+
[0.0947, 0.1183, 0.0947]])
|
455 |
+
>>> get_gaussian_kernel2d((3, 5), (1.5, 1.5))
|
456 |
+
tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370],
|
457 |
+
[0.0462, 0.0899, 0.1123, 0.0899, 0.0462],
|
458 |
+
[0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
|
459 |
+
"""
|
460 |
+
if not isinstance(kernel_size, tuple) or len(kernel_size) != 2:
|
461 |
+
raise TypeError(f"kernel_size must be a tuple of length two. Got {kernel_size}")
|
462 |
+
if not isinstance(sigma, tuple) or len(sigma) != 2:
|
463 |
+
raise TypeError(f"sigma must be a tuple of length two. Got {sigma}")
|
464 |
+
ksize_x, ksize_y = kernel_size
|
465 |
+
sigma_x, sigma_y = sigma
|
466 |
+
kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even)
|
467 |
+
kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even)
|
468 |
+
kernel_2d: torch.Tensor = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
|
469 |
+
return kernel_2d
|
470 |
+
|
471 |
+
|
472 |
+
def get_laplacian_kernel1d(kernel_size: int) -> torch.Tensor:
|
473 |
+
r"""Function that returns the coefficients of a 1D Laplacian filter.
|
474 |
+
|
475 |
+
Args:
|
476 |
+
kernel_size: filter size. It should be odd and positive.
|
477 |
+
|
478 |
+
Returns:
|
479 |
+
1D tensor with laplacian filter coefficients.
|
480 |
+
|
481 |
+
Shape:
|
482 |
+
- Output: math:`(\text{kernel_size})`
|
483 |
+
|
484 |
+
Examples:
|
485 |
+
>>> get_laplacian_kernel1d(3)
|
486 |
+
tensor([ 1., -2., 1.])
|
487 |
+
>>> get_laplacian_kernel1d(5)
|
488 |
+
tensor([ 1., 1., -4., 1., 1.])
|
489 |
+
"""
|
490 |
+
if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
|
491 |
+
raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}")
|
492 |
+
window_1d: torch.Tensor = laplacian_1d(kernel_size)
|
493 |
+
return window_1d
|
494 |
+
|
495 |
+
|
496 |
+
def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor:
|
497 |
+
r"""Function that returns Gaussian filter matrix coefficients.
|
498 |
+
|
499 |
+
Args:
|
500 |
+
kernel_size: filter size should be odd.
|
501 |
+
|
502 |
+
Returns:
|
503 |
+
2D tensor with laplacian filter matrix coefficients.
|
504 |
+
|
505 |
+
Shape:
|
506 |
+
- Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
|
507 |
+
|
508 |
+
Examples:
|
509 |
+
>>> get_laplacian_kernel2d(3)
|
510 |
+
tensor([[ 1., 1., 1.],
|
511 |
+
[ 1., -8., 1.],
|
512 |
+
[ 1., 1., 1.]])
|
513 |
+
>>> get_laplacian_kernel2d(5)
|
514 |
+
tensor([[ 1., 1., 1., 1., 1.],
|
515 |
+
[ 1., 1., 1., 1., 1.],
|
516 |
+
[ 1., 1., -24., 1., 1.],
|
517 |
+
[ 1., 1., 1., 1., 1.],
|
518 |
+
[ 1., 1., 1., 1., 1.]])
|
519 |
+
"""
|
520 |
+
if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
|
521 |
+
raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}")
|
522 |
+
|
523 |
+
kernel = torch.ones((kernel_size, kernel_size))
|
524 |
+
mid = kernel_size // 2
|
525 |
+
kernel[mid, mid] = 1 - kernel_size**2
|
526 |
+
kernel_2d: torch.Tensor = kernel
|
527 |
+
return kernel_2d
|
528 |
+
|
529 |
+
|
530 |
+
def get_pascal_kernel_2d(kernel_size: int, norm: bool = True) -> torch.Tensor:
|
531 |
+
"""Generate pascal filter kernel by kernel size.
|
532 |
+
|
533 |
+
Args:
|
534 |
+
kernel_size: height and width of the kernel.
|
535 |
+
norm: if to normalize the kernel or not. Default: True.
|
536 |
+
|
537 |
+
Returns:
|
538 |
+
kernel shaped as :math:`(kernel_size, kernel_size)`
|
539 |
+
|
540 |
+
Examples:
|
541 |
+
>>> get_pascal_kernel_2d(1)
|
542 |
+
tensor([[1.]])
|
543 |
+
>>> get_pascal_kernel_2d(4)
|
544 |
+
tensor([[0.0156, 0.0469, 0.0469, 0.0156],
|
545 |
+
[0.0469, 0.1406, 0.1406, 0.0469],
|
546 |
+
[0.0469, 0.1406, 0.1406, 0.0469],
|
547 |
+
[0.0156, 0.0469, 0.0469, 0.0156]])
|
548 |
+
>>> get_pascal_kernel_2d(4, norm=False)
|
549 |
+
tensor([[1., 3., 3., 1.],
|
550 |
+
[3., 9., 9., 3.],
|
551 |
+
[3., 9., 9., 3.],
|
552 |
+
[1., 3., 3., 1.]])
|
553 |
+
"""
|
554 |
+
a = get_pascal_kernel_1d(kernel_size)
|
555 |
+
|
556 |
+
filt = a[:, None] * a[None, :]
|
557 |
+
if norm:
|
558 |
+
filt = filt / torch.sum(filt)
|
559 |
+
return filt
|
560 |
+
|
561 |
+
|
562 |
+
def get_pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch.Tensor:
|
563 |
+
"""Generate Yang Hui triangle (Pascal's triangle) by a given number.
|
564 |
+
|
565 |
+
Args:
|
566 |
+
kernel_size: height and width of the kernel.
|
567 |
+
norm: if to normalize the kernel or not. Default: False.
|
568 |
+
|
569 |
+
Returns:
|
570 |
+
kernel shaped as :math:`(kernel_size,)`
|
571 |
+
|
572 |
+
Examples:
|
573 |
+
>>> get_pascal_kernel_1d(1)
|
574 |
+
tensor([1.])
|
575 |
+
>>> get_pascal_kernel_1d(2)
|
576 |
+
tensor([1., 1.])
|
577 |
+
>>> get_pascal_kernel_1d(3)
|
578 |
+
tensor([1., 2., 1.])
|
579 |
+
>>> get_pascal_kernel_1d(4)
|
580 |
+
tensor([1., 3., 3., 1.])
|
581 |
+
>>> get_pascal_kernel_1d(5)
|
582 |
+
tensor([1., 4., 6., 4., 1.])
|
583 |
+
>>> get_pascal_kernel_1d(6)
|
584 |
+
tensor([ 1., 5., 10., 10., 5., 1.])
|
585 |
+
"""
|
586 |
+
pre: List[float] = []
|
587 |
+
cur: List[float] = []
|
588 |
+
for i in range(kernel_size):
|
589 |
+
cur = [1.0] * (i + 1)
|
590 |
+
|
591 |
+
for j in range(1, i // 2 + 1):
|
592 |
+
value = pre[j - 1] + pre[j]
|
593 |
+
cur[j] = value
|
594 |
+
if i != 2 * j:
|
595 |
+
cur[-j - 1] = value
|
596 |
+
pre = cur
|
597 |
+
|
598 |
+
out = torch.as_tensor(cur)
|
599 |
+
if norm:
|
600 |
+
out = out / torch.sum(out)
|
601 |
+
return out
|
602 |
+
|
603 |
+
|
604 |
+
def get_canny_nms_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
605 |
+
"""Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
|
606 |
+
kernel: torch.Tensor = torch.tensor(
|
607 |
+
[
|
608 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]],
|
609 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
|
610 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]],
|
611 |
+
[[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]],
|
612 |
+
[[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
613 |
+
[[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
614 |
+
[[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
615 |
+
[[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
|
616 |
+
],
|
617 |
+
device=device,
|
618 |
+
dtype=dtype,
|
619 |
+
)
|
620 |
+
return kernel.unsqueeze(1)
|
621 |
+
|
622 |
+
|
623 |
+
def get_hysteresis_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
624 |
+
"""Utility function that returns the 3x3 kernels for the Canny hysteresis."""
|
625 |
+
kernel: torch.Tensor = torch.tensor(
|
626 |
+
[
|
627 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
|
628 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
|
629 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
|
630 |
+
[[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
|
631 |
+
[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
632 |
+
[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
633 |
+
[[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
634 |
+
[[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
|
635 |
+
],
|
636 |
+
device=device,
|
637 |
+
dtype=dtype,
|
638 |
+
)
|
639 |
+
return kernel.unsqueeze(1)
|
640 |
+
|
641 |
+
|
642 |
+
def get_hanning_kernel1d(kernel_size: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
643 |
+
r"""Returns Hanning (also known as Hann) kernel, used in signal processing and KCF tracker.
|
644 |
+
|
645 |
+
.. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right)
|
646 |
+
\\qquad 0 \\leq n \\leq M-1
|
647 |
+
|
648 |
+
See further in numpy docs https://numpy.org/doc/stable/reference/generated/numpy.hanning.html
|
649 |
+
|
650 |
+
Args:
|
651 |
+
kernel_size: The size the of the kernel. It should be positive.
|
652 |
+
|
653 |
+
Returns:
|
654 |
+
1D tensor with Hanning filter coefficients.
|
655 |
+
.. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right)
|
656 |
+
|
657 |
+
Shape:
|
658 |
+
- Output: math:`(\text{kernel_size})`
|
659 |
+
|
660 |
+
Examples:
|
661 |
+
>>> get_hanning_kernel1d(4)
|
662 |
+
tensor([0.0000, 0.7500, 0.7500, 0.0000])
|
663 |
+
"""
|
664 |
+
if not isinstance(kernel_size, int) or kernel_size <= 2:
|
665 |
+
raise TypeError(f"ksize must be an positive integer > 2. Got {kernel_size}")
|
666 |
+
|
667 |
+
x: torch.Tensor = torch.arange(kernel_size, device=device, dtype=dtype)
|
668 |
+
x = 0.5 - 0.5 * torch.cos(2.0 * math.pi * x / float(kernel_size - 1))
|
669 |
+
return x
|
670 |
+
|
671 |
+
|
672 |
+
def get_hanning_kernel2d(kernel_size: Tuple[int, int], device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
|
673 |
+
r"""Returns 2d Hanning kernel, used in signal processing and KCF tracker.
|
674 |
+
|
675 |
+
Args:
|
676 |
+
kernel_size: The size of the kernel for the filter. It should be positive.
|
677 |
+
|
678 |
+
Returns:
|
679 |
+
2D tensor with Hanning filter coefficients.
|
680 |
+
.. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right)
|
681 |
+
|
682 |
+
Shape:
|
683 |
+
- Output: math:`(\text{kernel_size[0], kernel_size[1]})`
|
684 |
+
"""
|
685 |
+
if kernel_size[0] <= 2 or kernel_size[1] <= 2:
|
686 |
+
raise TypeError(f"ksize must be an tuple of positive integers > 2. Got {kernel_size}")
|
687 |
+
ky: torch.Tensor = get_hanning_kernel1d(kernel_size[0], device, dtype)[None].T
|
688 |
+
kx: torch.Tensor = get_hanning_kernel1d(kernel_size[1], device, dtype)[None]
|
689 |
+
kernel2d = ky @ kx
|
690 |
+
return kernel2d
|
model/canny/sobel.py
ADDED
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d
|
6 |
+
|
7 |
+
|
8 |
+
def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor:
|
9 |
+
r"""Compute the first order image derivative in both x and y using a Sobel operator.
|
10 |
+
|
11 |
+
.. image:: _static/img/spatial_gradient.png
|
12 |
+
|
13 |
+
Args:
|
14 |
+
input: input image tensor with shape :math:`(B, C, H, W)`.
|
15 |
+
mode: derivatives modality, can be: `sobel` or `diff`.
|
16 |
+
order: the order of the derivatives.
|
17 |
+
normalized: whether the output is normalized.
|
18 |
+
|
19 |
+
Return:
|
20 |
+
the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`.
|
21 |
+
|
22 |
+
.. note::
|
23 |
+
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
|
24 |
+
filtering_edges.html>`__.
|
25 |
+
|
26 |
+
Examples:
|
27 |
+
>>> input = torch.rand(1, 3, 4, 4)
|
28 |
+
>>> output = spatial_gradient(input) # 1x3x2x4x4
|
29 |
+
>>> output.shape
|
30 |
+
torch.Size([1, 3, 2, 4, 4])
|
31 |
+
"""
|
32 |
+
if not isinstance(input, torch.Tensor):
|
33 |
+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
|
34 |
+
|
35 |
+
if not len(input.shape) == 4:
|
36 |
+
raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
|
37 |
+
# allocate kernel
|
38 |
+
kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order)
|
39 |
+
if normalized:
|
40 |
+
kernel = normalize_kernel2d(kernel)
|
41 |
+
|
42 |
+
# prepare kernel
|
43 |
+
b, c, h, w = input.shape
|
44 |
+
tmp_kernel: torch.Tensor = kernel.to(input).detach()
|
45 |
+
tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
|
46 |
+
|
47 |
+
# convolve input tensor with sobel kernel
|
48 |
+
kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
|
49 |
+
|
50 |
+
# Pad with "replicate for spatial dims, but with zeros for channel
|
51 |
+
spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
|
52 |
+
out_channels: int = 3 if order == 2 else 2
|
53 |
+
padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None]
|
54 |
+
|
55 |
+
return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w)
|
56 |
+
|
57 |
+
|
58 |
+
def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor:
|
59 |
+
r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
|
60 |
+
|
61 |
+
Args:
|
62 |
+
input: input features tensor with shape :math:`(B, C, D, H, W)`.
|
63 |
+
mode: derivatives modality, can be: `sobel` or `diff`.
|
64 |
+
order: the order of the derivatives.
|
65 |
+
|
66 |
+
Return:
|
67 |
+
the spatial gradients of the input feature map with shape math:`(B, C, 3, D, H, W)`
|
68 |
+
or :math:`(B, C, 6, D, H, W)`.
|
69 |
+
|
70 |
+
Examples:
|
71 |
+
>>> input = torch.rand(1, 4, 2, 4, 4)
|
72 |
+
>>> output = spatial_gradient3d(input)
|
73 |
+
>>> output.shape
|
74 |
+
torch.Size([1, 4, 3, 2, 4, 4])
|
75 |
+
"""
|
76 |
+
if not isinstance(input, torch.Tensor):
|
77 |
+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
|
78 |
+
|
79 |
+
if not len(input.shape) == 5:
|
80 |
+
raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
|
81 |
+
b, c, d, h, w = input.shape
|
82 |
+
dev = input.device
|
83 |
+
dtype = input.dtype
|
84 |
+
if (mode == 'diff') and (order == 1):
|
85 |
+
# we go for the special case implementation due to conv3d bad speed
|
86 |
+
x: torch.Tensor = F.pad(input, 6 * [1], 'replicate')
|
87 |
+
center = slice(1, -1)
|
88 |
+
left = slice(0, -2)
|
89 |
+
right = slice(2, None)
|
90 |
+
out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype)
|
91 |
+
out[..., 0, :, :, :] = x[..., center, center, right] - x[..., center, center, left]
|
92 |
+
out[..., 1, :, :, :] = x[..., center, right, center] - x[..., center, left, center]
|
93 |
+
out[..., 2, :, :, :] = x[..., right, center, center] - x[..., left, center, center]
|
94 |
+
out = 0.5 * out
|
95 |
+
else:
|
96 |
+
# prepare kernel
|
97 |
+
# allocate kernel
|
98 |
+
kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order)
|
99 |
+
|
100 |
+
tmp_kernel: torch.Tensor = kernel.to(input).detach()
|
101 |
+
tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1)
|
102 |
+
|
103 |
+
# convolve input tensor with grad kernel
|
104 |
+
kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
|
105 |
+
|
106 |
+
# Pad with "replicate for spatial dims, but with zeros for channel
|
107 |
+
spatial_pad = [
|
108 |
+
kernel.size(2) // 2,
|
109 |
+
kernel.size(2) // 2,
|
110 |
+
kernel.size(3) // 2,
|
111 |
+
kernel.size(3) // 2,
|
112 |
+
kernel.size(4) // 2,
|
113 |
+
kernel.size(4) // 2,
|
114 |
+
]
|
115 |
+
out_ch: int = 6 if order == 2 else 3
|
116 |
+
out = F.conv3d(F.pad(input, spatial_pad, 'replicate'), kernel_flip, padding=0, groups=c).view(
|
117 |
+
b, c, out_ch, d, h, w
|
118 |
+
)
|
119 |
+
return out
|
120 |
+
|
121 |
+
|
122 |
+
def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor:
|
123 |
+
r"""Compute the Sobel operator and returns the magnitude per channel.
|
124 |
+
|
125 |
+
.. image:: _static/img/sobel.png
|
126 |
+
|
127 |
+
Args:
|
128 |
+
input: the input image with shape :math:`(B,C,H,W)`.
|
129 |
+
normalized: if True, L1 norm of the kernel is set to 1.
|
130 |
+
eps: regularization number to avoid NaN during backprop.
|
131 |
+
|
132 |
+
Return:
|
133 |
+
the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`.
|
134 |
+
|
135 |
+
.. note::
|
136 |
+
See a working example `here <https://kornia-tutorials.readthedocs.io/en/latest/
|
137 |
+
filtering_edges.html>`__.
|
138 |
+
|
139 |
+
Example:
|
140 |
+
>>> input = torch.rand(1, 3, 4, 4)
|
141 |
+
>>> output = sobel(input) # 1x3x4x4
|
142 |
+
>>> output.shape
|
143 |
+
torch.Size([1, 3, 4, 4])
|
144 |
+
"""
|
145 |
+
if not isinstance(input, torch.Tensor):
|
146 |
+
raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
|
147 |
+
|
148 |
+
if not len(input.shape) == 4:
|
149 |
+
raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
|
150 |
+
|
151 |
+
# comput the x/y gradients
|
152 |
+
edges: torch.Tensor = spatial_gradient(input, normalized=normalized)
|
153 |
+
|
154 |
+
# unpack the edges
|
155 |
+
gx: torch.Tensor = edges[:, :, 0]
|
156 |
+
gy: torch.Tensor = edges[:, :, 1]
|
157 |
+
|
158 |
+
# compute gradient maginitude
|
159 |
+
magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
|
160 |
+
|
161 |
+
return magnitude
|
162 |
+
|
163 |
+
|
164 |
+
class SpatialGradient(nn.Module):
|
165 |
+
r"""Compute the first order image derivative in both x and y using a Sobel operator.
|
166 |
+
|
167 |
+
Args:
|
168 |
+
mode: derivatives modality, can be: `sobel` or `diff`.
|
169 |
+
order: the order of the derivatives.
|
170 |
+
normalized: whether the output is normalized.
|
171 |
+
|
172 |
+
Return:
|
173 |
+
the sobel edges of the input feature map.
|
174 |
+
|
175 |
+
Shape:
|
176 |
+
- Input: :math:`(B, C, H, W)`
|
177 |
+
- Output: :math:`(B, C, 2, H, W)`
|
178 |
+
|
179 |
+
Examples:
|
180 |
+
>>> input = torch.rand(1, 3, 4, 4)
|
181 |
+
>>> output = SpatialGradient()(input) # 1x3x2x4x4
|
182 |
+
"""
|
183 |
+
|
184 |
+
def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None:
|
185 |
+
super().__init__()
|
186 |
+
self.normalized: bool = normalized
|
187 |
+
self.order: int = order
|
188 |
+
self.mode: str = mode
|
189 |
+
|
190 |
+
def __repr__(self) -> str:
|
191 |
+
return (
|
192 |
+
self.__class__.__name__ + '('
|
193 |
+
'order=' + str(self.order) + ', ' + 'normalized=' + str(self.normalized) + ', ' + 'mode=' + self.mode + ')'
|
194 |
+
)
|
195 |
+
|
196 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
197 |
+
return spatial_gradient(input, self.mode, self.order, self.normalized)
|
198 |
+
|
199 |
+
|
200 |
+
class SpatialGradient3d(nn.Module):
|
201 |
+
r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
|
202 |
+
|
203 |
+
Args:
|
204 |
+
mode: derivatives modality, can be: `sobel` or `diff`.
|
205 |
+
order: the order of the derivatives.
|
206 |
+
|
207 |
+
Return:
|
208 |
+
the spatial gradients of the input feature map.
|
209 |
+
|
210 |
+
Shape:
|
211 |
+
- Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them.
|
212 |
+
- Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)`
|
213 |
+
|
214 |
+
Examples:
|
215 |
+
>>> input = torch.rand(1, 4, 2, 4, 4)
|
216 |
+
>>> output = SpatialGradient3d()(input)
|
217 |
+
>>> output.shape
|
218 |
+
torch.Size([1, 4, 3, 2, 4, 4])
|
219 |
+
"""
|
220 |
+
|
221 |
+
def __init__(self, mode: str = 'diff', order: int = 1) -> None:
|
222 |
+
super().__init__()
|
223 |
+
self.order: int = order
|
224 |
+
self.mode: str = mode
|
225 |
+
self.kernel = get_spatial_gradient_kernel3d(mode, order)
|
226 |
+
return
|
227 |
+
|
228 |
+
def __repr__(self) -> str:
|
229 |
+
return self.__class__.__name__ + '(' 'order=' + str(self.order) + ', ' + 'mode=' + self.mode + ')'
|
230 |
+
|
231 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
|
232 |
+
return spatial_gradient3d(input, self.mode, self.order)
|
233 |
+
|
234 |
+
|
235 |
+
class Sobel(nn.Module):
|
236 |
+
r"""Compute the Sobel operator and returns the magnitude per channel.
|
237 |
+
|
238 |
+
Args:
|
239 |
+
normalized: if True, L1 norm of the kernel is set to 1.
|
240 |
+
eps: regularization number to avoid NaN during backprop.
|
241 |
+
|
242 |
+
Return:
|
243 |
+
the sobel edge gradient magnitudes map.
|
244 |
+
|
245 |
+
Shape:
|
246 |
+
- Input: :math:`(B, C, H, W)`
|
247 |
+
- Output: :math:`(B, C, H, W)`
|
248 |
+
|
249 |
+
Examples:
|
250 |
+
>>> input = torch.rand(1, 3, 4, 4)
|
251 |
+
>>> output = Sobel()(input) # 1x3x4x4
|
252 |
+
"""
|
253 |
+
|
254 |
+
def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None:
|
255 |
+
super().__init__()
|
256 |
+
self.normalized: bool = normalized
|
257 |
+
self.eps: float = eps
|
258 |
+
|
259 |
+
def __repr__(self) -> str:
|
260 |
+
return self.__class__.__name__ + '(' 'normalized=' + str(self.normalized) + ')'
|
261 |
+
|
262 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
263 |
+
return sobel(input, self.normalized, self.eps)
|
model/misc.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import random
|
4 |
+
import time
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import logging
|
8 |
+
import numpy as np
|
9 |
+
from os import path as osp
|
10 |
+
|
11 |
+
def constant_init(module, val, bias=0):
|
12 |
+
if hasattr(module, 'weight') and module.weight is not None:
|
13 |
+
nn.init.constant_(module.weight, val)
|
14 |
+
if hasattr(module, 'bias') and module.bias is not None:
|
15 |
+
nn.init.constant_(module.bias, bias)
|
16 |
+
|
17 |
+
initialized_logger = {}
|
18 |
+
def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
|
19 |
+
"""Get the root logger.
|
20 |
+
The logger will be initialized if it has not been initialized. By default a
|
21 |
+
StreamHandler will be added. If `log_file` is specified, a FileHandler will
|
22 |
+
also be added.
|
23 |
+
Args:
|
24 |
+
logger_name (str): root logger name. Default: 'basicsr'.
|
25 |
+
log_file (str | None): The log filename. If specified, a FileHandler
|
26 |
+
will be added to the root logger.
|
27 |
+
log_level (int): The root logger level. Note that only the process of
|
28 |
+
rank 0 is affected, while other processes will set the level to
|
29 |
+
"Error" and be silent most of the time.
|
30 |
+
Returns:
|
31 |
+
logging.Logger: The root logger.
|
32 |
+
"""
|
33 |
+
logger = logging.getLogger(logger_name)
|
34 |
+
# if the logger has been initialized, just return it
|
35 |
+
if logger_name in initialized_logger:
|
36 |
+
return logger
|
37 |
+
|
38 |
+
format_str = '%(asctime)s %(levelname)s: %(message)s'
|
39 |
+
stream_handler = logging.StreamHandler()
|
40 |
+
stream_handler.setFormatter(logging.Formatter(format_str))
|
41 |
+
logger.addHandler(stream_handler)
|
42 |
+
logger.propagate = False
|
43 |
+
|
44 |
+
if log_file is not None:
|
45 |
+
logger.setLevel(log_level)
|
46 |
+
# add file handler
|
47 |
+
# file_handler = logging.FileHandler(log_file, 'w')
|
48 |
+
file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
|
49 |
+
file_handler.setFormatter(logging.Formatter(format_str))
|
50 |
+
file_handler.setLevel(log_level)
|
51 |
+
logger.addHandler(file_handler)
|
52 |
+
initialized_logger[logger_name] = True
|
53 |
+
return logger
|
54 |
+
|
55 |
+
|
56 |
+
IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
|
57 |
+
torch.__version__)[0][:3])] >= [1, 12, 0]
|
58 |
+
|
59 |
+
def gpu_is_available():
|
60 |
+
if IS_HIGH_VERSION:
|
61 |
+
if torch.backends.mps.is_available():
|
62 |
+
return True
|
63 |
+
return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
|
64 |
+
|
65 |
+
def get_device(gpu_id=None):
|
66 |
+
if gpu_id is None:
|
67 |
+
gpu_str = ''
|
68 |
+
elif isinstance(gpu_id, int):
|
69 |
+
gpu_str = f':{gpu_id}'
|
70 |
+
else:
|
71 |
+
raise TypeError('Input should be int value.')
|
72 |
+
|
73 |
+
if IS_HIGH_VERSION:
|
74 |
+
if torch.backends.mps.is_available():
|
75 |
+
return torch.device('mps'+gpu_str)
|
76 |
+
return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
|
77 |
+
|
78 |
+
|
79 |
+
def set_random_seed(seed):
|
80 |
+
"""Set random seeds."""
|
81 |
+
random.seed(seed)
|
82 |
+
np.random.seed(seed)
|
83 |
+
torch.manual_seed(seed)
|
84 |
+
torch.cuda.manual_seed(seed)
|
85 |
+
torch.cuda.manual_seed_all(seed)
|
86 |
+
|
87 |
+
|
88 |
+
def get_time_str():
|
89 |
+
return time.strftime('%Y%m%d_%H%M%S', time.localtime())
|
90 |
+
|
91 |
+
|
92 |
+
def scandir(dir_path, suffix=None, recursive=False, full_path=False):
|
93 |
+
"""Scan a directory to find the interested files.
|
94 |
+
|
95 |
+
Args:
|
96 |
+
dir_path (str): Path of the directory.
|
97 |
+
suffix (str | tuple(str), optional): File suffix that we are
|
98 |
+
interested in. Default: None.
|
99 |
+
recursive (bool, optional): If set to True, recursively scan the
|
100 |
+
directory. Default: False.
|
101 |
+
full_path (bool, optional): If set to True, include the dir_path.
|
102 |
+
Default: False.
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
A generator for all the interested files with relative pathes.
|
106 |
+
"""
|
107 |
+
|
108 |
+
if (suffix is not None) and not isinstance(suffix, (str, tuple)):
|
109 |
+
raise TypeError('"suffix" must be a string or tuple of strings')
|
110 |
+
|
111 |
+
root = dir_path
|
112 |
+
|
113 |
+
def _scandir(dir_path, suffix, recursive):
|
114 |
+
for entry in os.scandir(dir_path):
|
115 |
+
if not entry.name.startswith('.') and entry.is_file():
|
116 |
+
if full_path:
|
117 |
+
return_path = entry.path
|
118 |
+
else:
|
119 |
+
return_path = osp.relpath(entry.path, root)
|
120 |
+
|
121 |
+
if suffix is None:
|
122 |
+
yield return_path
|
123 |
+
elif return_path.endswith(suffix):
|
124 |
+
yield return_path
|
125 |
+
else:
|
126 |
+
if recursive:
|
127 |
+
yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
|
128 |
+
else:
|
129 |
+
continue
|
130 |
+
|
131 |
+
return _scandir(dir_path, suffix=suffix, recursive=recursive)
|
model/modules/base_module.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
from functools import reduce
|
6 |
+
|
7 |
+
class BaseNetwork(nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(BaseNetwork, self).__init__()
|
10 |
+
|
11 |
+
def print_network(self):
|
12 |
+
if isinstance(self, list):
|
13 |
+
self = self[0]
|
14 |
+
num_params = 0
|
15 |
+
for param in self.parameters():
|
16 |
+
num_params += param.numel()
|
17 |
+
print(
|
18 |
+
'Network [%s] was created. Total number of parameters: %.1f million. '
|
19 |
+
'To see the architecture, do print(network).' %
|
20 |
+
(type(self).__name__, num_params / 1000000))
|
21 |
+
|
22 |
+
def init_weights(self, init_type='normal', gain=0.02):
|
23 |
+
'''
|
24 |
+
initialize network's weights
|
25 |
+
init_type: normal | xavier | kaiming | orthogonal
|
26 |
+
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
|
27 |
+
'''
|
28 |
+
def init_func(m):
|
29 |
+
classname = m.__class__.__name__
|
30 |
+
if classname.find('InstanceNorm2d') != -1:
|
31 |
+
if hasattr(m, 'weight') and m.weight is not None:
|
32 |
+
nn.init.constant_(m.weight.data, 1.0)
|
33 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
34 |
+
nn.init.constant_(m.bias.data, 0.0)
|
35 |
+
elif hasattr(m, 'weight') and (classname.find('Conv') != -1
|
36 |
+
or classname.find('Linear') != -1):
|
37 |
+
if init_type == 'normal':
|
38 |
+
nn.init.normal_(m.weight.data, 0.0, gain)
|
39 |
+
elif init_type == 'xavier':
|
40 |
+
nn.init.xavier_normal_(m.weight.data, gain=gain)
|
41 |
+
elif init_type == 'xavier_uniform':
|
42 |
+
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
|
43 |
+
elif init_type == 'kaiming':
|
44 |
+
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
|
45 |
+
elif init_type == 'orthogonal':
|
46 |
+
nn.init.orthogonal_(m.weight.data, gain=gain)
|
47 |
+
elif init_type == 'none': # uses pytorch's default init method
|
48 |
+
m.reset_parameters()
|
49 |
+
else:
|
50 |
+
raise NotImplementedError(
|
51 |
+
'initialization method [%s] is not implemented' %
|
52 |
+
init_type)
|
53 |
+
if hasattr(m, 'bias') and m.bias is not None:
|
54 |
+
nn.init.constant_(m.bias.data, 0.0)
|
55 |
+
|
56 |
+
self.apply(init_func)
|
57 |
+
|
58 |
+
# propagate to children
|
59 |
+
for m in self.children():
|
60 |
+
if hasattr(m, 'init_weights'):
|
61 |
+
m.init_weights(init_type, gain)
|
62 |
+
|
63 |
+
|
64 |
+
class Vec2Feat(nn.Module):
|
65 |
+
def __init__(self, channel, hidden, kernel_size, stride, padding):
|
66 |
+
super(Vec2Feat, self).__init__()
|
67 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
68 |
+
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
|
69 |
+
self.embedding = nn.Linear(hidden, c_out)
|
70 |
+
self.kernel_size = kernel_size
|
71 |
+
self.stride = stride
|
72 |
+
self.padding = padding
|
73 |
+
self.bias_conv = nn.Conv2d(channel,
|
74 |
+
channel,
|
75 |
+
kernel_size=3,
|
76 |
+
stride=1,
|
77 |
+
padding=1)
|
78 |
+
|
79 |
+
def forward(self, x, t, output_size):
|
80 |
+
b_, _, _, _, c_ = x.shape
|
81 |
+
x = x.view(b_, -1, c_)
|
82 |
+
feat = self.embedding(x)
|
83 |
+
b, _, c = feat.size()
|
84 |
+
feat = feat.view(b * t, -1, c).permute(0, 2, 1)
|
85 |
+
feat = F.fold(feat,
|
86 |
+
output_size=output_size,
|
87 |
+
kernel_size=self.kernel_size,
|
88 |
+
stride=self.stride,
|
89 |
+
padding=self.padding)
|
90 |
+
feat = self.bias_conv(feat)
|
91 |
+
return feat
|
92 |
+
|
93 |
+
|
94 |
+
class FusionFeedForward(nn.Module):
|
95 |
+
def __init__(self, dim, hidden_dim=1960, t2t_params=None):
|
96 |
+
super(FusionFeedForward, self).__init__()
|
97 |
+
# We set hidden_dim as a default to 1960
|
98 |
+
self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
|
99 |
+
self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
|
100 |
+
assert t2t_params is not None
|
101 |
+
self.t2t_params = t2t_params
|
102 |
+
self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49
|
103 |
+
|
104 |
+
def forward(self, x, output_size):
|
105 |
+
n_vecs = 1
|
106 |
+
for i, d in enumerate(self.t2t_params['kernel_size']):
|
107 |
+
n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
|
108 |
+
(d - 1) - 1) / self.t2t_params['stride'][i] + 1)
|
109 |
+
|
110 |
+
x = self.fc1(x)
|
111 |
+
b, n, c = x.size()
|
112 |
+
normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
|
113 |
+
normalizer = F.fold(normalizer,
|
114 |
+
output_size=output_size,
|
115 |
+
kernel_size=self.t2t_params['kernel_size'],
|
116 |
+
padding=self.t2t_params['padding'],
|
117 |
+
stride=self.t2t_params['stride'])
|
118 |
+
|
119 |
+
x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
|
120 |
+
output_size=output_size,
|
121 |
+
kernel_size=self.t2t_params['kernel_size'],
|
122 |
+
padding=self.t2t_params['padding'],
|
123 |
+
stride=self.t2t_params['stride'])
|
124 |
+
|
125 |
+
x = F.unfold(x / normalizer,
|
126 |
+
kernel_size=self.t2t_params['kernel_size'],
|
127 |
+
padding=self.t2t_params['padding'],
|
128 |
+
stride=self.t2t_params['stride']).permute(
|
129 |
+
0, 2, 1).contiguous().view(b, n, c)
|
130 |
+
x = self.fc2(x)
|
131 |
+
return x
|
model/modules/deformconv.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch.nn import init as init
|
4 |
+
from torch.nn.modules.utils import _pair, _single
|
5 |
+
import math
|
6 |
+
|
7 |
+
class ModulatedDeformConv2d(nn.Module):
|
8 |
+
def __init__(self,
|
9 |
+
in_channels,
|
10 |
+
out_channels,
|
11 |
+
kernel_size,
|
12 |
+
stride=1,
|
13 |
+
padding=0,
|
14 |
+
dilation=1,
|
15 |
+
groups=1,
|
16 |
+
deform_groups=1,
|
17 |
+
bias=True):
|
18 |
+
super(ModulatedDeformConv2d, self).__init__()
|
19 |
+
|
20 |
+
self.in_channels = in_channels
|
21 |
+
self.out_channels = out_channels
|
22 |
+
self.kernel_size = _pair(kernel_size)
|
23 |
+
self.stride = stride
|
24 |
+
self.padding = padding
|
25 |
+
self.dilation = dilation
|
26 |
+
self.groups = groups
|
27 |
+
self.deform_groups = deform_groups
|
28 |
+
self.with_bias = bias
|
29 |
+
# enable compatibility with nn.Conv2d
|
30 |
+
self.transposed = False
|
31 |
+
self.output_padding = _single(0)
|
32 |
+
|
33 |
+
self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
|
34 |
+
if bias:
|
35 |
+
self.bias = nn.Parameter(torch.Tensor(out_channels))
|
36 |
+
else:
|
37 |
+
self.register_parameter('bias', None)
|
38 |
+
self.init_weights()
|
39 |
+
|
40 |
+
def init_weights(self):
|
41 |
+
n = self.in_channels
|
42 |
+
for k in self.kernel_size:
|
43 |
+
n *= k
|
44 |
+
stdv = 1. / math.sqrt(n)
|
45 |
+
self.weight.data.uniform_(-stdv, stdv)
|
46 |
+
if self.bias is not None:
|
47 |
+
self.bias.data.zero_()
|
48 |
+
|
49 |
+
if hasattr(self, 'conv_offset'):
|
50 |
+
self.conv_offset.weight.data.zero_()
|
51 |
+
self.conv_offset.bias.data.zero_()
|
52 |
+
|
53 |
+
def forward(self, x, offset, mask):
|
54 |
+
pass
|
model/modules/flow_comp_raft.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
from RAFT import RAFT
|
7 |
+
from model.modules.flow_loss_utils import flow_warp, ternary_loss2
|
8 |
+
|
9 |
+
|
10 |
+
def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'):
|
11 |
+
"""Initializes the RAFT model.
|
12 |
+
"""
|
13 |
+
args = argparse.ArgumentParser()
|
14 |
+
args.raft_model = model_path
|
15 |
+
args.small = False
|
16 |
+
args.mixed_precision = False
|
17 |
+
args.alternate_corr = False
|
18 |
+
model = torch.nn.DataParallel(RAFT(args))
|
19 |
+
model.load_state_dict(torch.load(args.raft_model, map_location='cpu'))
|
20 |
+
model = model.module
|
21 |
+
|
22 |
+
model.to(device)
|
23 |
+
|
24 |
+
return model
|
25 |
+
|
26 |
+
|
27 |
+
class RAFT_bi(nn.Module):
|
28 |
+
"""Flow completion loss"""
|
29 |
+
def __init__(self, model_path='weights/raft-things.pth', device='cuda'):
|
30 |
+
super().__init__()
|
31 |
+
self.fix_raft = initialize_RAFT(model_path, device=device)
|
32 |
+
|
33 |
+
for p in self.fix_raft.parameters():
|
34 |
+
p.requires_grad = False
|
35 |
+
|
36 |
+
self.l1_criterion = nn.L1Loss()
|
37 |
+
self.eval()
|
38 |
+
|
39 |
+
def forward(self, gt_local_frames, iters=20):
|
40 |
+
b, l_t, c, h, w = gt_local_frames.size()
|
41 |
+
# print(gt_local_frames.shape)
|
42 |
+
|
43 |
+
with torch.no_grad():
|
44 |
+
gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w)
|
45 |
+
gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w)
|
46 |
+
# print(gtlf_1.shape)
|
47 |
+
|
48 |
+
_, gt_flows_forward = self.fix_raft(gtlf_1, gtlf_2, iters=iters, test_mode=True)
|
49 |
+
_, gt_flows_backward = self.fix_raft(gtlf_2, gtlf_1, iters=iters, test_mode=True)
|
50 |
+
|
51 |
+
|
52 |
+
gt_flows_forward = gt_flows_forward.view(b, l_t-1, 2, h, w)
|
53 |
+
gt_flows_backward = gt_flows_backward.view(b, l_t-1, 2, h, w)
|
54 |
+
|
55 |
+
return gt_flows_forward, gt_flows_backward
|
56 |
+
|
57 |
+
|
58 |
+
##################################################################################
|
59 |
+
def smoothness_loss(flow, cmask):
|
60 |
+
delta_u, delta_v, mask = smoothness_deltas(flow)
|
61 |
+
loss_u = charbonnier_loss(delta_u, cmask)
|
62 |
+
loss_v = charbonnier_loss(delta_v, cmask)
|
63 |
+
return loss_u + loss_v
|
64 |
+
|
65 |
+
|
66 |
+
def smoothness_deltas(flow):
|
67 |
+
"""
|
68 |
+
flow: [b, c, h, w]
|
69 |
+
"""
|
70 |
+
mask_x = create_mask(flow, [[0, 0], [0, 1]])
|
71 |
+
mask_y = create_mask(flow, [[0, 1], [0, 0]])
|
72 |
+
mask = torch.cat((mask_x, mask_y), dim=1)
|
73 |
+
mask = mask.to(flow.device)
|
74 |
+
filter_x = torch.tensor([[0, 0, 0.], [0, 1, -1], [0, 0, 0]])
|
75 |
+
filter_y = torch.tensor([[0, 0, 0.], [0, 1, 0], [0, -1, 0]])
|
76 |
+
weights = torch.ones([2, 1, 3, 3])
|
77 |
+
weights[0, 0] = filter_x
|
78 |
+
weights[1, 0] = filter_y
|
79 |
+
weights = weights.to(flow.device)
|
80 |
+
|
81 |
+
flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
|
82 |
+
delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
|
83 |
+
delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
|
84 |
+
return delta_u, delta_v, mask
|
85 |
+
|
86 |
+
|
87 |
+
def second_order_loss(flow, cmask):
|
88 |
+
delta_u, delta_v, mask = second_order_deltas(flow)
|
89 |
+
loss_u = charbonnier_loss(delta_u, cmask)
|
90 |
+
loss_v = charbonnier_loss(delta_v, cmask)
|
91 |
+
return loss_u + loss_v
|
92 |
+
|
93 |
+
|
94 |
+
def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001):
|
95 |
+
"""
|
96 |
+
Compute the generalized charbonnier loss of the difference tensor x
|
97 |
+
All positions where mask == 0 are not taken into account
|
98 |
+
x: a tensor of shape [b, c, h, w]
|
99 |
+
mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as
|
100 |
+
the number of channels of x. Entries should be 0 or 1
|
101 |
+
return: loss
|
102 |
+
"""
|
103 |
+
b, c, h, w = x.shape
|
104 |
+
norm = b * c * h * w
|
105 |
+
error = torch.pow(torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha)
|
106 |
+
if mask is not None:
|
107 |
+
error = mask * error
|
108 |
+
if truncate is not None:
|
109 |
+
error = torch.min(error, truncate)
|
110 |
+
return torch.sum(error) / norm
|
111 |
+
|
112 |
+
|
113 |
+
def second_order_deltas(flow):
|
114 |
+
"""
|
115 |
+
consider the single flow first
|
116 |
+
flow shape: [b, c, h, w]
|
117 |
+
"""
|
118 |
+
# create mask
|
119 |
+
mask_x = create_mask(flow, [[0, 0], [1, 1]])
|
120 |
+
mask_y = create_mask(flow, [[1, 1], [0, 0]])
|
121 |
+
mask_diag = create_mask(flow, [[1, 1], [1, 1]])
|
122 |
+
mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1)
|
123 |
+
mask = mask.to(flow.device)
|
124 |
+
|
125 |
+
filter_x = torch.tensor([[0, 0, 0.], [1, -2, 1], [0, 0, 0]])
|
126 |
+
filter_y = torch.tensor([[0, 1, 0.], [0, -2, 0], [0, 1, 0]])
|
127 |
+
filter_diag1 = torch.tensor([[1, 0, 0.], [0, -2, 0], [0, 0, 1]])
|
128 |
+
filter_diag2 = torch.tensor([[0, 0, 1.], [0, -2, 0], [1, 0, 0]])
|
129 |
+
weights = torch.ones([4, 1, 3, 3])
|
130 |
+
weights[0] = filter_x
|
131 |
+
weights[1] = filter_y
|
132 |
+
weights[2] = filter_diag1
|
133 |
+
weights[3] = filter_diag2
|
134 |
+
weights = weights.to(flow.device)
|
135 |
+
|
136 |
+
# split the flow into flow_u and flow_v, conv them with the weights
|
137 |
+
flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
|
138 |
+
delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
|
139 |
+
delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
|
140 |
+
return delta_u, delta_v, mask
|
141 |
+
|
142 |
+
def create_mask(tensor, paddings):
|
143 |
+
"""
|
144 |
+
tensor shape: [b, c, h, w]
|
145 |
+
paddings: [2 x 2] shape list, the first row indicates up and down paddings
|
146 |
+
the second row indicates left and right paddings
|
147 |
+
| |
|
148 |
+
| x |
|
149 |
+
| x * x |
|
150 |
+
| x |
|
151 |
+
| |
|
152 |
+
"""
|
153 |
+
shape = tensor.shape
|
154 |
+
inner_height = shape[2] - (paddings[0][0] + paddings[0][1])
|
155 |
+
inner_width = shape[3] - (paddings[1][0] + paddings[1][1])
|
156 |
+
inner = torch.ones([inner_height, inner_width])
|
157 |
+
torch_paddings = [paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]] # left, right, up and down
|
158 |
+
mask2d = F.pad(inner, pad=torch_paddings)
|
159 |
+
mask3d = mask2d.unsqueeze(0).repeat(shape[0], 1, 1)
|
160 |
+
mask4d = mask3d.unsqueeze(1)
|
161 |
+
return mask4d.detach()
|
162 |
+
|
163 |
+
def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1):
|
164 |
+
if scale_factor != 1:
|
165 |
+
current_frame = F.interpolate(current_frame, scale_factor=1 / scale_factor, mode='bilinear')
|
166 |
+
shift_frame = F.interpolate(shift_frame, scale_factor=1 / scale_factor, mode='bilinear')
|
167 |
+
warped_sc = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1))
|
168 |
+
noc_mask = torch.exp(-50. * torch.sum(torch.abs(current_frame - warped_sc), dim=1).pow(2)).unsqueeze(1)
|
169 |
+
warped_comp_sc = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1))
|
170 |
+
loss = ternary_loss2(current_frame, warped_comp_sc, noc_mask, mask)
|
171 |
+
return loss
|
172 |
+
|
173 |
+
class FlowLoss(nn.Module):
|
174 |
+
def __init__(self):
|
175 |
+
super().__init__()
|
176 |
+
self.l1_criterion = nn.L1Loss()
|
177 |
+
|
178 |
+
def forward(self, pred_flows, gt_flows, masks, frames):
|
179 |
+
# pred_flows: b t-1 2 h w
|
180 |
+
loss = 0
|
181 |
+
warp_loss = 0
|
182 |
+
h, w = pred_flows[0].shape[-2:]
|
183 |
+
masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()]
|
184 |
+
frames0 = frames[:,:-1,...]
|
185 |
+
frames1 = frames[:,1:,...]
|
186 |
+
current_frames = [frames0, frames1]
|
187 |
+
next_frames = [frames1, frames0]
|
188 |
+
for i in range(len(pred_flows)):
|
189 |
+
# print(pred_flows[i].shape)
|
190 |
+
combined_flow = pred_flows[i] * masks[i] + gt_flows[i] * (1-masks[i])
|
191 |
+
l1_loss = self.l1_criterion(pred_flows[i] * masks[i], gt_flows[i] * masks[i]) / torch.mean(masks[i])
|
192 |
+
l1_loss += self.l1_criterion(pred_flows[i] * (1-masks[i]), gt_flows[i] * (1-masks[i])) / torch.mean((1-masks[i]))
|
193 |
+
|
194 |
+
smooth_loss = smoothness_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w))
|
195 |
+
smooth_loss2 = second_order_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w))
|
196 |
+
|
197 |
+
warp_loss_i = ternary_loss(combined_flow.reshape(-1,2,h,w), gt_flows[i].reshape(-1,2,h,w),
|
198 |
+
masks[i].reshape(-1,1,h,w), current_frames[i].reshape(-1,3,h,w), next_frames[i].reshape(-1,3,h,w))
|
199 |
+
|
200 |
+
loss += l1_loss + smooth_loss + smooth_loss2
|
201 |
+
|
202 |
+
warp_loss += warp_loss_i
|
203 |
+
|
204 |
+
return loss, warp_loss
|
205 |
+
|
206 |
+
|
207 |
+
def edgeLoss(preds_edges, edges):
|
208 |
+
"""
|
209 |
+
|
210 |
+
Args:
|
211 |
+
preds_edges: with shape [b, c, h , w]
|
212 |
+
edges: with shape [b, c, h, w]
|
213 |
+
|
214 |
+
Returns: Edge losses
|
215 |
+
|
216 |
+
"""
|
217 |
+
mask = (edges > 0.5).float()
|
218 |
+
b, c, h, w = mask.shape
|
219 |
+
num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,].
|
220 |
+
num_neg = c * h * w - num_pos # Shape: [b,].
|
221 |
+
neg_weights = (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3)
|
222 |
+
pos_weights = (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3)
|
223 |
+
weight = neg_weights * mask + pos_weights * (1 - mask) # weight for debug
|
224 |
+
losses = F.binary_cross_entropy_with_logits(preds_edges.float(), edges.float(), weight=weight, reduction='none')
|
225 |
+
loss = torch.mean(losses)
|
226 |
+
return loss
|
227 |
+
|
228 |
+
class EdgeLoss(nn.Module):
|
229 |
+
def __init__(self):
|
230 |
+
super().__init__()
|
231 |
+
|
232 |
+
def forward(self, pred_edges, gt_edges, masks):
|
233 |
+
# pred_flows: b t-1 1 h w
|
234 |
+
loss = 0
|
235 |
+
h, w = pred_edges[0].shape[-2:]
|
236 |
+
masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()]
|
237 |
+
for i in range(len(pred_edges)):
|
238 |
+
# print(f'edges_{i}', torch.sum(gt_edges[i])) # debug
|
239 |
+
combined_edge = pred_edges[i] * masks[i] + gt_edges[i] * (1-masks[i])
|
240 |
+
edge_loss = (edgeLoss(pred_edges[i].reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w)) \
|
241 |
+
+ 5 * edgeLoss(combined_edge.reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w)))
|
242 |
+
loss += edge_loss
|
243 |
+
|
244 |
+
return loss
|
245 |
+
|
246 |
+
|
247 |
+
class FlowSimpleLoss(nn.Module):
|
248 |
+
def __init__(self):
|
249 |
+
super().__init__()
|
250 |
+
self.l1_criterion = nn.L1Loss()
|
251 |
+
|
252 |
+
def forward(self, pred_flows, gt_flows):
|
253 |
+
# pred_flows: b t-1 2 h w
|
254 |
+
loss = 0
|
255 |
+
h, w = pred_flows[0].shape[-2:]
|
256 |
+
h_orig, w_orig = gt_flows[0].shape[-2:]
|
257 |
+
pred_flows = [f.view(-1, 2, h, w) for f in pred_flows]
|
258 |
+
gt_flows = [f.view(-1, 2, h_orig, w_orig) for f in gt_flows]
|
259 |
+
|
260 |
+
ds_factor = 1.0*h/h_orig
|
261 |
+
gt_flows = [F.interpolate(f, scale_factor=ds_factor, mode='area') * ds_factor for f in gt_flows]
|
262 |
+
for i in range(len(pred_flows)):
|
263 |
+
loss += self.l1_criterion(pred_flows[i], gt_flows[i])
|
264 |
+
|
265 |
+
return loss
|
model/modules/flow_loss_utils.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
def flow_warp(x,
|
7 |
+
flow,
|
8 |
+
interpolation='bilinear',
|
9 |
+
padding_mode='zeros',
|
10 |
+
align_corners=True):
|
11 |
+
"""Warp an image or a feature map with optical flow.
|
12 |
+
Args:
|
13 |
+
x (Tensor): Tensor with size (n, c, h, w).
|
14 |
+
flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
|
15 |
+
a two-channel, denoting the width and height relative offsets.
|
16 |
+
Note that the values are not normalized to [-1, 1].
|
17 |
+
interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
|
18 |
+
Default: 'bilinear'.
|
19 |
+
padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
|
20 |
+
Default: 'zeros'.
|
21 |
+
align_corners (bool): Whether align corners. Default: True.
|
22 |
+
Returns:
|
23 |
+
Tensor: Warped image or feature map.
|
24 |
+
"""
|
25 |
+
if x.size()[-2:] != flow.size()[1:3]:
|
26 |
+
raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
|
27 |
+
f'flow ({flow.size()[1:3]}) are not the same.')
|
28 |
+
_, _, h, w = x.size()
|
29 |
+
# create mesh grid
|
30 |
+
device = flow.device
|
31 |
+
grid_y, grid_x = torch.meshgrid(torch.arange(0, h, device=device), torch.arange(0, w, device=device))
|
32 |
+
grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
|
33 |
+
grid.requires_grad = False
|
34 |
+
|
35 |
+
grid_flow = grid + flow
|
36 |
+
# scale grid_flow to [-1,1]
|
37 |
+
grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
|
38 |
+
grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
|
39 |
+
grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
|
40 |
+
output = F.grid_sample(x,
|
41 |
+
grid_flow,
|
42 |
+
mode=interpolation,
|
43 |
+
padding_mode=padding_mode,
|
44 |
+
align_corners=align_corners)
|
45 |
+
return output
|
46 |
+
|
47 |
+
|
48 |
+
# def image_warp(image, flow):
|
49 |
+
# b, c, h, w = image.size()
|
50 |
+
# device = image.device
|
51 |
+
# flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right
|
52 |
+
# flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension
|
53 |
+
# x = np.linspace(-1, 1, w)
|
54 |
+
# y = np.linspace(-1, 1, h)
|
55 |
+
# X, Y = np.meshgrid(x, y)
|
56 |
+
# grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3),
|
57 |
+
# torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device)
|
58 |
+
# output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros')
|
59 |
+
# return output
|
60 |
+
|
61 |
+
|
62 |
+
def length_sq(x):
|
63 |
+
return torch.sum(torch.square(x), dim=1, keepdim=True)
|
64 |
+
|
65 |
+
|
66 |
+
def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
|
67 |
+
flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x))
|
68 |
+
flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1)) # wf(wb(x))
|
69 |
+
flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x))
|
70 |
+
flow_diff_bw = flow_bw + flow_fw_warped # wb + wf(wb(x))
|
71 |
+
|
72 |
+
mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))|
|
73 |
+
mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) # |wb| + |wf(wb(x))|
|
74 |
+
occ_thresh_fw = alpha1 * mag_sq_fw + alpha2
|
75 |
+
occ_thresh_bw = alpha1 * mag_sq_bw + alpha2
|
76 |
+
|
77 |
+
fb_occ_fw = (length_sq(flow_diff_fw) > occ_thresh_fw).float()
|
78 |
+
fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float()
|
79 |
+
|
80 |
+
return fb_occ_fw, fb_occ_bw # fb_occ_fw -> frame2 area occluded by frame1, fb_occ_bw -> frame1 area occluded by frame2
|
81 |
+
|
82 |
+
|
83 |
+
def rgb2gray(image):
|
84 |
+
gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2]
|
85 |
+
gray_image = gray_image.unsqueeze(1)
|
86 |
+
return gray_image
|
87 |
+
|
88 |
+
|
89 |
+
def ternary_transform(image, max_distance=1):
|
90 |
+
device = image.device
|
91 |
+
patch_size = 2 * max_distance + 1
|
92 |
+
intensities = rgb2gray(image) * 255
|
93 |
+
out_channels = patch_size * patch_size
|
94 |
+
w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size)
|
95 |
+
weights = torch.from_numpy(w).float().to(device)
|
96 |
+
patches = F.conv2d(intensities, weights, stride=1, padding=1)
|
97 |
+
transf = patches - intensities
|
98 |
+
transf_norm = transf / torch.sqrt(0.81 + torch.square(transf))
|
99 |
+
return transf_norm
|
100 |
+
|
101 |
+
|
102 |
+
def hamming_distance(t1, t2):
|
103 |
+
dist = torch.square(t1 - t2)
|
104 |
+
dist_norm = dist / (0.1 + dist)
|
105 |
+
dist_sum = torch.sum(dist_norm, dim=1, keepdim=True)
|
106 |
+
return dist_sum
|
107 |
+
|
108 |
+
|
109 |
+
def create_mask(mask, paddings):
|
110 |
+
"""
|
111 |
+
padding: [[top, bottom], [left, right]]
|
112 |
+
"""
|
113 |
+
shape = mask.shape
|
114 |
+
inner_height = shape[2] - (paddings[0][0] + paddings[0][1])
|
115 |
+
inner_width = shape[3] - (paddings[1][0] + paddings[1][1])
|
116 |
+
inner = torch.ones([inner_height, inner_width])
|
117 |
+
|
118 |
+
mask2d = F.pad(inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]])
|
119 |
+
mask3d = mask2d.unsqueeze(0)
|
120 |
+
mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1)
|
121 |
+
return mask4d.detach()
|
122 |
+
|
123 |
+
|
124 |
+
def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1):
|
125 |
+
"""
|
126 |
+
|
127 |
+
Args:
|
128 |
+
frame1: torch tensor, with shape [b * t, c, h, w]
|
129 |
+
warp_frame21: torch tensor, with shape [b * t, c, h, w]
|
130 |
+
confMask: confidence mask, with shape [b * t, c, h, w]
|
131 |
+
masks: torch tensor, with shape [b * t, c, h, w]
|
132 |
+
max_distance: maximum distance.
|
133 |
+
|
134 |
+
Returns: ternary loss
|
135 |
+
|
136 |
+
"""
|
137 |
+
t1 = ternary_transform(frame1)
|
138 |
+
t21 = ternary_transform(warp_frame21)
|
139 |
+
dist = hamming_distance(t1, t21)
|
140 |
+
loss = torch.mean(dist * confMask * masks) / torch.mean(masks)
|
141 |
+
return loss
|
142 |
+
|
model/modules/sparse_transformer.py
ADDED
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from functools import reduce
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
class SoftSplit(nn.Module):
|
8 |
+
def __init__(self, channel, hidden, kernel_size, stride, padding):
|
9 |
+
super(SoftSplit, self).__init__()
|
10 |
+
self.kernel_size = kernel_size
|
11 |
+
self.stride = stride
|
12 |
+
self.padding = padding
|
13 |
+
self.t2t = nn.Unfold(kernel_size=kernel_size,
|
14 |
+
stride=stride,
|
15 |
+
padding=padding)
|
16 |
+
c_in = reduce((lambda x, y: x * y), kernel_size) * channel
|
17 |
+
self.embedding = nn.Linear(c_in, hidden)
|
18 |
+
|
19 |
+
def forward(self, x, b, output_size):
|
20 |
+
f_h = int((output_size[0] + 2 * self.padding[0] -
|
21 |
+
(self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
|
22 |
+
f_w = int((output_size[1] + 2 * self.padding[1] -
|
23 |
+
(self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
|
24 |
+
|
25 |
+
feat = self.t2t(x)
|
26 |
+
feat = feat.permute(0, 2, 1)
|
27 |
+
# feat shape [b*t, num_vec, ks*ks*c]
|
28 |
+
feat = self.embedding(feat)
|
29 |
+
# feat shape after embedding [b, t*num_vec, hidden]
|
30 |
+
feat = feat.view(b, -1, f_h, f_w, feat.size(2))
|
31 |
+
return feat
|
32 |
+
|
33 |
+
|
34 |
+
class SoftComp(nn.Module):
|
35 |
+
def __init__(self, channel, hidden, kernel_size, stride, padding):
|
36 |
+
super(SoftComp, self).__init__()
|
37 |
+
self.relu = nn.LeakyReLU(0.2, inplace=True)
|
38 |
+
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
|
39 |
+
self.embedding = nn.Linear(hidden, c_out)
|
40 |
+
self.kernel_size = kernel_size
|
41 |
+
self.stride = stride
|
42 |
+
self.padding = padding
|
43 |
+
self.bias_conv = nn.Conv2d(channel,
|
44 |
+
channel,
|
45 |
+
kernel_size=3,
|
46 |
+
stride=1,
|
47 |
+
padding=1)
|
48 |
+
|
49 |
+
def forward(self, x, t, output_size):
|
50 |
+
b_, _, _, _, c_ = x.shape
|
51 |
+
x = x.view(b_, -1, c_)
|
52 |
+
feat = self.embedding(x)
|
53 |
+
b, _, c = feat.size()
|
54 |
+
feat = feat.view(b * t, -1, c).permute(0, 2, 1)
|
55 |
+
feat = F.fold(feat,
|
56 |
+
output_size=output_size,
|
57 |
+
kernel_size=self.kernel_size,
|
58 |
+
stride=self.stride,
|
59 |
+
padding=self.padding)
|
60 |
+
feat = self.bias_conv(feat)
|
61 |
+
return feat
|
62 |
+
|
63 |
+
|
64 |
+
class FusionFeedForward(nn.Module):
|
65 |
+
def __init__(self, dim, hidden_dim=1960, t2t_params=None):
|
66 |
+
super(FusionFeedForward, self).__init__()
|
67 |
+
# We set hidden_dim as a default to 1960
|
68 |
+
self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
|
69 |
+
self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
|
70 |
+
assert t2t_params is not None
|
71 |
+
self.t2t_params = t2t_params
|
72 |
+
self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49
|
73 |
+
|
74 |
+
def forward(self, x, output_size):
|
75 |
+
n_vecs = 1
|
76 |
+
for i, d in enumerate(self.t2t_params['kernel_size']):
|
77 |
+
n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
|
78 |
+
(d - 1) - 1) / self.t2t_params['stride'][i] + 1)
|
79 |
+
|
80 |
+
x = self.fc1(x)
|
81 |
+
b, n, c = x.size()
|
82 |
+
normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
|
83 |
+
normalizer = F.fold(normalizer,
|
84 |
+
output_size=output_size,
|
85 |
+
kernel_size=self.t2t_params['kernel_size'],
|
86 |
+
padding=self.t2t_params['padding'],
|
87 |
+
stride=self.t2t_params['stride'])
|
88 |
+
|
89 |
+
x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
|
90 |
+
output_size=output_size,
|
91 |
+
kernel_size=self.t2t_params['kernel_size'],
|
92 |
+
padding=self.t2t_params['padding'],
|
93 |
+
stride=self.t2t_params['stride'])
|
94 |
+
|
95 |
+
x = F.unfold(x / normalizer,
|
96 |
+
kernel_size=self.t2t_params['kernel_size'],
|
97 |
+
padding=self.t2t_params['padding'],
|
98 |
+
stride=self.t2t_params['stride']).permute(
|
99 |
+
0, 2, 1).contiguous().view(b, n, c)
|
100 |
+
x = self.fc2(x)
|
101 |
+
return x
|
102 |
+
|
103 |
+
|
104 |
+
def window_partition(x, window_size, n_head):
|
105 |
+
"""
|
106 |
+
Args:
|
107 |
+
x: shape is (B, T, H, W, C)
|
108 |
+
window_size (tuple[int]): window size
|
109 |
+
Returns:
|
110 |
+
windows: (B, num_windows_h, num_windows_w, n_head, T, window_size, window_size, C//n_head)
|
111 |
+
"""
|
112 |
+
B, T, H, W, C = x.shape
|
113 |
+
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], window_size[1], n_head, C//n_head)
|
114 |
+
windows = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
|
115 |
+
return windows
|
116 |
+
|
117 |
+
class SparseWindowAttention(nn.Module):
|
118 |
+
def __init__(self, dim, n_head, window_size, pool_size=(4,4), qkv_bias=True, attn_drop=0., proj_drop=0.,
|
119 |
+
pooling_token=True):
|
120 |
+
super().__init__()
|
121 |
+
assert dim % n_head == 0
|
122 |
+
# key, query, value projections for all heads
|
123 |
+
self.key = nn.Linear(dim, dim, qkv_bias)
|
124 |
+
self.query = nn.Linear(dim, dim, qkv_bias)
|
125 |
+
self.value = nn.Linear(dim, dim, qkv_bias)
|
126 |
+
# regularization
|
127 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
128 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
129 |
+
# output projection
|
130 |
+
self.proj = nn.Linear(dim, dim)
|
131 |
+
self.n_head = n_head
|
132 |
+
self.window_size = window_size
|
133 |
+
self.pooling_token = pooling_token
|
134 |
+
if self.pooling_token:
|
135 |
+
ks, stride = pool_size, pool_size
|
136 |
+
self.pool_layer = nn.Conv2d(dim, dim, kernel_size=ks, stride=stride, padding=(0, 0), groups=dim)
|
137 |
+
self.pool_layer.weight.data.fill_(1. / (pool_size[0] * pool_size[1]))
|
138 |
+
self.pool_layer.bias.data.fill_(0)
|
139 |
+
# self.expand_size = tuple(i // 2 for i in window_size)
|
140 |
+
self.expand_size = tuple((i + 1) // 2 for i in window_size)
|
141 |
+
|
142 |
+
if any(i > 0 for i in self.expand_size):
|
143 |
+
# get mask for rolled k and rolled v
|
144 |
+
mask_tl = torch.ones(self.window_size[0], self.window_size[1])
|
145 |
+
mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
|
146 |
+
mask_tr = torch.ones(self.window_size[0], self.window_size[1])
|
147 |
+
mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
|
148 |
+
mask_bl = torch.ones(self.window_size[0], self.window_size[1])
|
149 |
+
mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
|
150 |
+
mask_br = torch.ones(self.window_size[0], self.window_size[1])
|
151 |
+
mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
|
152 |
+
masrool_k = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
|
153 |
+
self.register_buffer("valid_ind_rolled", masrool_k.nonzero(as_tuple=False).view(-1))
|
154 |
+
|
155 |
+
self.max_pool = nn.MaxPool2d(window_size, window_size, (0, 0))
|
156 |
+
|
157 |
+
|
158 |
+
def forward(self, x, mask=None, T_ind=None, attn_mask=None):
|
159 |
+
b, t, h, w, c = x.shape # 20 36
|
160 |
+
w_h, w_w = self.window_size[0], self.window_size[1]
|
161 |
+
c_head = c // self.n_head
|
162 |
+
n_wh = math.ceil(h / self.window_size[0])
|
163 |
+
n_ww = math.ceil(w / self.window_size[1])
|
164 |
+
new_h = n_wh * self.window_size[0] # 20
|
165 |
+
new_w = n_ww * self.window_size[1] # 36
|
166 |
+
pad_r = new_w - w
|
167 |
+
pad_b = new_h - h
|
168 |
+
# reverse order
|
169 |
+
if pad_r > 0 or pad_b > 0:
|
170 |
+
x = F.pad(x,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0)
|
171 |
+
mask = F.pad(mask,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0)
|
172 |
+
|
173 |
+
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
174 |
+
q = self.query(x)
|
175 |
+
k = self.key(x)
|
176 |
+
v = self.value(x)
|
177 |
+
win_q = window_partition(q.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head)
|
178 |
+
win_k = window_partition(k.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head)
|
179 |
+
win_v = window_partition(v.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head)
|
180 |
+
# roll_k and roll_v
|
181 |
+
if any(i > 0 for i in self.expand_size):
|
182 |
+
(k_tl, v_tl) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v))
|
183 |
+
(k_tr, v_tr) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v))
|
184 |
+
(k_bl, v_bl) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v))
|
185 |
+
(k_br, v_br) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v))
|
186 |
+
|
187 |
+
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
|
188 |
+
lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head),
|
189 |
+
(k_tl, k_tr, k_bl, k_br))
|
190 |
+
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
|
191 |
+
lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head),
|
192 |
+
(v_tl, v_tr, v_bl, v_br))
|
193 |
+
rool_k = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 4).contiguous()
|
194 |
+
rool_v = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 4).contiguous() # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
|
195 |
+
# mask out tokens in current window
|
196 |
+
rool_k = rool_k[:, :, :, :, self.valid_ind_rolled]
|
197 |
+
rool_v = rool_v[:, :, :, :, self.valid_ind_rolled]
|
198 |
+
roll_N = rool_k.shape[4]
|
199 |
+
rool_k = rool_k.view(b, n_wh*n_ww, self.n_head, t, roll_N, c // self.n_head)
|
200 |
+
rool_v = rool_v.view(b, n_wh*n_ww, self.n_head, t, roll_N, c // self.n_head)
|
201 |
+
win_k = torch.cat((win_k, rool_k), dim=4)
|
202 |
+
win_v = torch.cat((win_v, rool_v), dim=4)
|
203 |
+
else:
|
204 |
+
win_k = win_k
|
205 |
+
win_v = win_v
|
206 |
+
|
207 |
+
# pool_k and pool_v
|
208 |
+
if self.pooling_token:
|
209 |
+
pool_x = self.pool_layer(x.view(b*t, new_h, new_w, c).permute(0,3,1,2))
|
210 |
+
_, _, p_h, p_w = pool_x.shape
|
211 |
+
pool_x = pool_x.permute(0,2,3,1).view(b, t, p_h, p_w, c)
|
212 |
+
# pool_k
|
213 |
+
pool_k = self.key(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c]
|
214 |
+
pool_k = pool_k.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6)
|
215 |
+
pool_k = pool_k.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head)
|
216 |
+
win_k = torch.cat((win_k, pool_k), dim=4)
|
217 |
+
# pool_v
|
218 |
+
pool_v = self.value(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c]
|
219 |
+
pool_v = pool_v.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6)
|
220 |
+
pool_v = pool_v.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head)
|
221 |
+
win_v = torch.cat((win_v, pool_v), dim=4)
|
222 |
+
|
223 |
+
# [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
|
224 |
+
out = torch.zeros_like(win_q)
|
225 |
+
l_t = mask.size(1)
|
226 |
+
|
227 |
+
mask = self.max_pool(mask.view(b * l_t, new_h, new_w))
|
228 |
+
mask = mask.view(b, l_t, n_wh*n_ww)
|
229 |
+
mask = torch.sum(mask, dim=1) # [b, n_wh*n_ww]
|
230 |
+
for i in range(win_q.shape[0]):
|
231 |
+
### For masked windows
|
232 |
+
mask_ind_i = mask[i].nonzero(as_tuple=False).view(-1)
|
233 |
+
# mask out quary in current window
|
234 |
+
# [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
|
235 |
+
mask_n = len(mask_ind_i)
|
236 |
+
if mask_n > 0:
|
237 |
+
win_q_t = win_q[i, mask_ind_i].view(mask_n, self.n_head, t*w_h*w_w, c_head)
|
238 |
+
win_k_t = win_k[i, mask_ind_i]
|
239 |
+
win_v_t = win_v[i, mask_ind_i]
|
240 |
+
# mask out key and value
|
241 |
+
if T_ind is not None:
|
242 |
+
# key [n_wh*n_ww, n_head, t, w_h*w_w, c_head]
|
243 |
+
win_k_t = win_k_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head)
|
244 |
+
# value
|
245 |
+
win_v_t = win_v_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head)
|
246 |
+
else:
|
247 |
+
win_k_t = win_k_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head)
|
248 |
+
win_v_t = win_v_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head)
|
249 |
+
|
250 |
+
att_t = (win_q_t @ win_k_t.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_t.size(-1)))
|
251 |
+
att_t = F.softmax(att_t, dim=-1)
|
252 |
+
att_t = self.attn_drop(att_t)
|
253 |
+
y_t = att_t @ win_v_t
|
254 |
+
|
255 |
+
out[i, mask_ind_i] = y_t.view(-1, self.n_head, t, w_h*w_w, c_head)
|
256 |
+
|
257 |
+
### For unmasked windows
|
258 |
+
unmask_ind_i = (mask[i] == 0).nonzero(as_tuple=False).view(-1)
|
259 |
+
# mask out quary in current window
|
260 |
+
# [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
|
261 |
+
win_q_s = win_q[i, unmask_ind_i]
|
262 |
+
win_k_s = win_k[i, unmask_ind_i, :, :, :w_h*w_w]
|
263 |
+
win_v_s = win_v[i, unmask_ind_i, :, :, :w_h*w_w]
|
264 |
+
|
265 |
+
att_s = (win_q_s @ win_k_s.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_s.size(-1)))
|
266 |
+
att_s = F.softmax(att_s, dim=-1)
|
267 |
+
att_s = self.attn_drop(att_s)
|
268 |
+
y_s = att_s @ win_v_s
|
269 |
+
out[i, unmask_ind_i] = y_s
|
270 |
+
|
271 |
+
# re-assemble all head outputs side by side
|
272 |
+
out = out.view(b, n_wh, n_ww, self.n_head, t, w_h, w_w, c_head)
|
273 |
+
out = out.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(b, t, new_h, new_w, c)
|
274 |
+
|
275 |
+
|
276 |
+
if pad_r > 0 or pad_b > 0:
|
277 |
+
out = out[:, :, :h, :w, :]
|
278 |
+
|
279 |
+
# output projection
|
280 |
+
out = self.proj_drop(self.proj(out))
|
281 |
+
return out
|
282 |
+
|
283 |
+
|
284 |
+
class TemporalSparseTransformer(nn.Module):
|
285 |
+
def __init__(self, dim, n_head, window_size, pool_size,
|
286 |
+
norm_layer=nn.LayerNorm, t2t_params=None):
|
287 |
+
super().__init__()
|
288 |
+
self.window_size = window_size
|
289 |
+
self.attention = SparseWindowAttention(dim, n_head, window_size, pool_size)
|
290 |
+
self.norm1 = norm_layer(dim)
|
291 |
+
self.norm2 = norm_layer(dim)
|
292 |
+
self.mlp = FusionFeedForward(dim, t2t_params=t2t_params)
|
293 |
+
|
294 |
+
def forward(self, x, fold_x_size, mask=None, T_ind=None):
|
295 |
+
"""
|
296 |
+
Args:
|
297 |
+
x: image tokens, shape [B T H W C]
|
298 |
+
fold_x_size: fold feature size, shape [60 108]
|
299 |
+
mask: mask tokens, shape [B T H W 1]
|
300 |
+
Returns:
|
301 |
+
out_tokens: shape [B T H W C]
|
302 |
+
"""
|
303 |
+
B, T, H, W, C = x.shape # 20 36
|
304 |
+
|
305 |
+
shortcut = x
|
306 |
+
x = self.norm1(x)
|
307 |
+
att_x = self.attention(x, mask, T_ind)
|
308 |
+
|
309 |
+
# FFN
|
310 |
+
x = shortcut + att_x
|
311 |
+
y = self.norm2(x)
|
312 |
+
x = x + self.mlp(y.view(B, T * H * W, C), fold_x_size).view(B, T, H, W, C)
|
313 |
+
|
314 |
+
return x
|
315 |
+
|
316 |
+
|
317 |
+
class TemporalSparseTransformerBlock(nn.Module):
|
318 |
+
def __init__(self, dim, n_head, window_size, pool_size, depths, t2t_params=None):
|
319 |
+
super().__init__()
|
320 |
+
blocks = []
|
321 |
+
for i in range(depths):
|
322 |
+
blocks.append(
|
323 |
+
TemporalSparseTransformer(dim, n_head, window_size, pool_size, t2t_params=t2t_params)
|
324 |
+
)
|
325 |
+
self.transformer = nn.Sequential(*blocks)
|
326 |
+
self.depths = depths
|
327 |
+
|
328 |
+
def forward(self, x, fold_x_size, l_mask=None, t_dilation=2):
|
329 |
+
"""
|
330 |
+
Args:
|
331 |
+
x: image tokens, shape [B T H W C]
|
332 |
+
fold_x_size: fold feature size, shape [60 108]
|
333 |
+
l_mask: local mask tokens, shape [B T H W 1]
|
334 |
+
Returns:
|
335 |
+
out_tokens: shape [B T H W C]
|
336 |
+
"""
|
337 |
+
assert self.depths % t_dilation == 0, 'wrong t_dilation input.'
|
338 |
+
T = x.size(1)
|
339 |
+
T_ind = [torch.arange(i, T, t_dilation) for i in range(t_dilation)] * (self.depths // t_dilation)
|
340 |
+
|
341 |
+
for i in range(0, self.depths):
|
342 |
+
x = self.transformer[i](x, fold_x_size, l_mask, T_ind[i])
|
343 |
+
|
344 |
+
return x
|
model/modules/spectral_norm.py
ADDED
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Spectral Normalization from https://arxiv.org/abs/1802.05957
|
3 |
+
"""
|
4 |
+
import torch
|
5 |
+
from torch.nn.functional import normalize
|
6 |
+
|
7 |
+
|
8 |
+
class SpectralNorm(object):
|
9 |
+
# Invariant before and after each forward call:
|
10 |
+
# u = normalize(W @ v)
|
11 |
+
# NB: At initialization, this invariant is not enforced
|
12 |
+
|
13 |
+
_version = 1
|
14 |
+
|
15 |
+
# At version 1:
|
16 |
+
# made `W` not a buffer,
|
17 |
+
# added `v` as a buffer, and
|
18 |
+
# made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
|
19 |
+
|
20 |
+
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
|
21 |
+
self.name = name
|
22 |
+
self.dim = dim
|
23 |
+
if n_power_iterations <= 0:
|
24 |
+
raise ValueError(
|
25 |
+
'Expected n_power_iterations to be positive, but '
|
26 |
+
'got n_power_iterations={}'.format(n_power_iterations))
|
27 |
+
self.n_power_iterations = n_power_iterations
|
28 |
+
self.eps = eps
|
29 |
+
|
30 |
+
def reshape_weight_to_matrix(self, weight):
|
31 |
+
weight_mat = weight
|
32 |
+
if self.dim != 0:
|
33 |
+
# permute dim to front
|
34 |
+
weight_mat = weight_mat.permute(
|
35 |
+
self.dim,
|
36 |
+
*[d for d in range(weight_mat.dim()) if d != self.dim])
|
37 |
+
height = weight_mat.size(0)
|
38 |
+
return weight_mat.reshape(height, -1)
|
39 |
+
|
40 |
+
def compute_weight(self, module, do_power_iteration):
|
41 |
+
# NB: If `do_power_iteration` is set, the `u` and `v` vectors are
|
42 |
+
# updated in power iteration **in-place**. This is very important
|
43 |
+
# because in `DataParallel` forward, the vectors (being buffers) are
|
44 |
+
# broadcast from the parallelized module to each module replica,
|
45 |
+
# which is a new module object created on the fly. And each replica
|
46 |
+
# runs its own spectral norm power iteration. So simply assigning
|
47 |
+
# the updated vectors to the module this function runs on will cause
|
48 |
+
# the update to be lost forever. And the next time the parallelized
|
49 |
+
# module is replicated, the same randomly initialized vectors are
|
50 |
+
# broadcast and used!
|
51 |
+
#
|
52 |
+
# Therefore, to make the change propagate back, we rely on two
|
53 |
+
# important behaviors (also enforced via tests):
|
54 |
+
# 1. `DataParallel` doesn't clone storage if the broadcast tensor
|
55 |
+
# is already on correct device; and it makes sure that the
|
56 |
+
# parallelized module is already on `device[0]`.
|
57 |
+
# 2. If the out tensor in `out=` kwarg has correct shape, it will
|
58 |
+
# just fill in the values.
|
59 |
+
# Therefore, since the same power iteration is performed on all
|
60 |
+
# devices, simply updating the tensors in-place will make sure that
|
61 |
+
# the module replica on `device[0]` will update the _u vector on the
|
62 |
+
# parallized module (by shared storage).
|
63 |
+
#
|
64 |
+
# However, after we update `u` and `v` in-place, we need to **clone**
|
65 |
+
# them before using them to normalize the weight. This is to support
|
66 |
+
# backproping through two forward passes, e.g., the common pattern in
|
67 |
+
# GAN training: loss = D(real) - D(fake). Otherwise, engine will
|
68 |
+
# complain that variables needed to do backward for the first forward
|
69 |
+
# (i.e., the `u` and `v` vectors) are changed in the second forward.
|
70 |
+
weight = getattr(module, self.name + '_orig')
|
71 |
+
u = getattr(module, self.name + '_u')
|
72 |
+
v = getattr(module, self.name + '_v')
|
73 |
+
weight_mat = self.reshape_weight_to_matrix(weight)
|
74 |
+
|
75 |
+
if do_power_iteration:
|
76 |
+
with torch.no_grad():
|
77 |
+
for _ in range(self.n_power_iterations):
|
78 |
+
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
|
79 |
+
# are the first left and right singular vectors.
|
80 |
+
# This power iteration produces approximations of `u` and `v`.
|
81 |
+
v = normalize(torch.mv(weight_mat.t(), u),
|
82 |
+
dim=0,
|
83 |
+
eps=self.eps,
|
84 |
+
out=v)
|
85 |
+
u = normalize(torch.mv(weight_mat, v),
|
86 |
+
dim=0,
|
87 |
+
eps=self.eps,
|
88 |
+
out=u)
|
89 |
+
if self.n_power_iterations > 0:
|
90 |
+
# See above on why we need to clone
|
91 |
+
u = u.clone()
|
92 |
+
v = v.clone()
|
93 |
+
|
94 |
+
sigma = torch.dot(u, torch.mv(weight_mat, v))
|
95 |
+
weight = weight / sigma
|
96 |
+
return weight
|
97 |
+
|
98 |
+
def remove(self, module):
|
99 |
+
with torch.no_grad():
|
100 |
+
weight = self.compute_weight(module, do_power_iteration=False)
|
101 |
+
delattr(module, self.name)
|
102 |
+
delattr(module, self.name + '_u')
|
103 |
+
delattr(module, self.name + '_v')
|
104 |
+
delattr(module, self.name + '_orig')
|
105 |
+
module.register_parameter(self.name,
|
106 |
+
torch.nn.Parameter(weight.detach()))
|
107 |
+
|
108 |
+
def __call__(self, module, inputs):
|
109 |
+
setattr(
|
110 |
+
module, self.name,
|
111 |
+
self.compute_weight(module, do_power_iteration=module.training))
|
112 |
+
|
113 |
+
def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
|
114 |
+
# Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
|
115 |
+
# (the invariant at top of this class) and `u @ W @ v = sigma`.
|
116 |
+
# This uses pinverse in case W^T W is not invertible.
|
117 |
+
v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
|
118 |
+
weight_mat.t(), u.unsqueeze(1)).squeeze(1)
|
119 |
+
return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
|
120 |
+
|
121 |
+
@staticmethod
|
122 |
+
def apply(module, name, n_power_iterations, dim, eps):
|
123 |
+
for k, hook in module._forward_pre_hooks.items():
|
124 |
+
if isinstance(hook, SpectralNorm) and hook.name == name:
|
125 |
+
raise RuntimeError(
|
126 |
+
"Cannot register two spectral_norm hooks on "
|
127 |
+
"the same parameter {}".format(name))
|
128 |
+
|
129 |
+
fn = SpectralNorm(name, n_power_iterations, dim, eps)
|
130 |
+
weight = module._parameters[name]
|
131 |
+
|
132 |
+
with torch.no_grad():
|
133 |
+
weight_mat = fn.reshape_weight_to_matrix(weight)
|
134 |
+
|
135 |
+
h, w = weight_mat.size()
|
136 |
+
# randomly initialize `u` and `v`
|
137 |
+
u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
|
138 |
+
v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
|
139 |
+
|
140 |
+
delattr(module, fn.name)
|
141 |
+
module.register_parameter(fn.name + "_orig", weight)
|
142 |
+
# We still need to assign weight back as fn.name because all sorts of
|
143 |
+
# things may assume that it exists, e.g., when initializing weights.
|
144 |
+
# However, we can't directly assign as it could be an nn.Parameter and
|
145 |
+
# gets added as a parameter. Instead, we register weight.data as a plain
|
146 |
+
# attribute.
|
147 |
+
setattr(module, fn.name, weight.data)
|
148 |
+
module.register_buffer(fn.name + "_u", u)
|
149 |
+
module.register_buffer(fn.name + "_v", v)
|
150 |
+
|
151 |
+
module.register_forward_pre_hook(fn)
|
152 |
+
|
153 |
+
module._register_state_dict_hook(SpectralNormStateDictHook(fn))
|
154 |
+
module._register_load_state_dict_pre_hook(
|
155 |
+
SpectralNormLoadStateDictPreHook(fn))
|
156 |
+
return fn
|
157 |
+
|
158 |
+
|
159 |
+
# This is a top level class because Py2 pickle doesn't like inner class nor an
|
160 |
+
# instancemethod.
|
161 |
+
class SpectralNormLoadStateDictPreHook(object):
|
162 |
+
# See docstring of SpectralNorm._version on the changes to spectral_norm.
|
163 |
+
def __init__(self, fn):
|
164 |
+
self.fn = fn
|
165 |
+
|
166 |
+
# For state_dict with version None, (assuming that it has gone through at
|
167 |
+
# least one training forward), we have
|
168 |
+
#
|
169 |
+
# u = normalize(W_orig @ v)
|
170 |
+
# W = W_orig / sigma, where sigma = u @ W_orig @ v
|
171 |
+
#
|
172 |
+
# To compute `v`, we solve `W_orig @ x = u`, and let
|
173 |
+
# v = x / (u @ W_orig @ x) * (W / W_orig).
|
174 |
+
def __call__(self, state_dict, prefix, local_metadata, strict,
|
175 |
+
missing_keys, unexpected_keys, error_msgs):
|
176 |
+
fn = self.fn
|
177 |
+
version = local_metadata.get('spectral_norm',
|
178 |
+
{}).get(fn.name + '.version', None)
|
179 |
+
if version is None or version < 1:
|
180 |
+
with torch.no_grad():
|
181 |
+
weight_orig = state_dict[prefix + fn.name + '_orig']
|
182 |
+
# weight = state_dict.pop(prefix + fn.name)
|
183 |
+
# sigma = (weight_orig / weight).mean()
|
184 |
+
weight_mat = fn.reshape_weight_to_matrix(weight_orig)
|
185 |
+
u = state_dict[prefix + fn.name + '_u']
|
186 |
+
# v = fn._solve_v_and_rescale(weight_mat, u, sigma)
|
187 |
+
# state_dict[prefix + fn.name + '_v'] = v
|
188 |
+
|
189 |
+
|
190 |
+
# This is a top level class because Py2 pickle doesn't like inner class nor an
|
191 |
+
# instancemethod.
|
192 |
+
class SpectralNormStateDictHook(object):
|
193 |
+
# See docstring of SpectralNorm._version on the changes to spectral_norm.
|
194 |
+
def __init__(self, fn):
|
195 |
+
self.fn = fn
|
196 |
+
|
197 |
+
def __call__(self, module, state_dict, prefix, local_metadata):
|
198 |
+
if 'spectral_norm' not in local_metadata:
|
199 |
+
local_metadata['spectral_norm'] = {}
|
200 |
+
key = self.fn.name + '.version'
|
201 |
+
if key in local_metadata['spectral_norm']:
|
202 |
+
raise RuntimeError(
|
203 |
+
"Unexpected key in metadata['spectral_norm']: {}".format(key))
|
204 |
+
local_metadata['spectral_norm'][key] = self.fn._version
|
205 |
+
|
206 |
+
|
207 |
+
def spectral_norm(module,
|
208 |
+
name='weight',
|
209 |
+
n_power_iterations=1,
|
210 |
+
eps=1e-12,
|
211 |
+
dim=None):
|
212 |
+
r"""Applies spectral normalization to a parameter in the given module.
|
213 |
+
|
214 |
+
.. math::
|
215 |
+
\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
|
216 |
+
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
|
217 |
+
|
218 |
+
Spectral normalization stabilizes the training of discriminators (critics)
|
219 |
+
in Generative Adversarial Networks (GANs) by rescaling the weight tensor
|
220 |
+
with spectral norm :math:`\sigma` of the weight matrix calculated using
|
221 |
+
power iteration method. If the dimension of the weight tensor is greater
|
222 |
+
than 2, it is reshaped to 2D in power iteration method to get spectral
|
223 |
+
norm. This is implemented via a hook that calculates spectral norm and
|
224 |
+
rescales weight before every :meth:`~Module.forward` call.
|
225 |
+
|
226 |
+
See `Spectral Normalization for Generative Adversarial Networks`_ .
|
227 |
+
|
228 |
+
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
|
229 |
+
|
230 |
+
Args:
|
231 |
+
module (nn.Module): containing module
|
232 |
+
name (str, optional): name of weight parameter
|
233 |
+
n_power_iterations (int, optional): number of power iterations to
|
234 |
+
calculate spectral norm
|
235 |
+
eps (float, optional): epsilon for numerical stability in
|
236 |
+
calculating norms
|
237 |
+
dim (int, optional): dimension corresponding to number of outputs,
|
238 |
+
the default is ``0``, except for modules that are instances of
|
239 |
+
ConvTranspose{1,2,3}d, when it is ``1``
|
240 |
+
|
241 |
+
Returns:
|
242 |
+
The original module with the spectral norm hook
|
243 |
+
|
244 |
+
Example::
|
245 |
+
|
246 |
+
>>> m = spectral_norm(nn.Linear(20, 40))
|
247 |
+
>>> m
|
248 |
+
Linear(in_features=20, out_features=40, bias=True)
|
249 |
+
>>> m.weight_u.size()
|
250 |
+
torch.Size([40])
|
251 |
+
|
252 |
+
"""
|
253 |
+
if dim is None:
|
254 |
+
if isinstance(module,
|
255 |
+
(torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
|
256 |
+
torch.nn.ConvTranspose3d)):
|
257 |
+
dim = 1
|
258 |
+
else:
|
259 |
+
dim = 0
|
260 |
+
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
|
261 |
+
return module
|
262 |
+
|
263 |
+
|
264 |
+
def remove_spectral_norm(module, name='weight'):
|
265 |
+
r"""Removes the spectral normalization reparameterization from a module.
|
266 |
+
|
267 |
+
Args:
|
268 |
+
module (Module): containing module
|
269 |
+
name (str, optional): name of weight parameter
|
270 |
+
|
271 |
+
Example:
|
272 |
+
>>> m = spectral_norm(nn.Linear(40, 10))
|
273 |
+
>>> remove_spectral_norm(m)
|
274 |
+
"""
|
275 |
+
for k, hook in module._forward_pre_hooks.items():
|
276 |
+
if isinstance(hook, SpectralNorm) and hook.name == name:
|
277 |
+
hook.remove(module)
|
278 |
+
del module._forward_pre_hooks[k]
|
279 |
+
return module
|
280 |
+
|
281 |
+
raise ValueError("spectral_norm of '{}' not found in {}".format(
|
282 |
+
name, module))
|
283 |
+
|
284 |
+
|
285 |
+
def use_spectral_norm(module, use_sn=False):
|
286 |
+
if use_sn:
|
287 |
+
return spectral_norm(module)
|
288 |
+
return module
|
model/propainter.py
ADDED
@@ -0,0 +1,532 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
''' Towards An End-to-End Framework for Video Inpainting
|
2 |
+
'''
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
import torchvision
|
8 |
+
|
9 |
+
from einops import rearrange
|
10 |
+
|
11 |
+
from model.modules.base_module import BaseNetwork
|
12 |
+
from model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp
|
13 |
+
from model.modules.spectral_norm import spectral_norm as _spectral_norm
|
14 |
+
from model.modules.flow_loss_utils import flow_warp
|
15 |
+
from model.modules.deformconv import ModulatedDeformConv2d
|
16 |
+
|
17 |
+
from .misc import constant_init
|
18 |
+
|
19 |
+
def length_sq(x):
|
20 |
+
return torch.sum(torch.square(x), dim=1, keepdim=True)
|
21 |
+
|
22 |
+
def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
|
23 |
+
flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x))
|
24 |
+
flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x))
|
25 |
+
|
26 |
+
mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))|
|
27 |
+
occ_thresh_fw = alpha1 * mag_sq_fw + alpha2
|
28 |
+
|
29 |
+
# fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).float()
|
30 |
+
fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).to(flow_fw)
|
31 |
+
return fb_valid_fw
|
32 |
+
|
33 |
+
|
34 |
+
class DeformableAlignment(ModulatedDeformConv2d):
|
35 |
+
"""Second-order deformable alignment module."""
|
36 |
+
def __init__(self, *args, **kwargs):
|
37 |
+
# self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
|
38 |
+
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 3)
|
39 |
+
|
40 |
+
super(DeformableAlignment, self).__init__(*args, **kwargs)
|
41 |
+
|
42 |
+
self.conv_offset = nn.Sequential(
|
43 |
+
nn.Conv2d(2*self.out_channels + 2 + 1 + 2, self.out_channels, 3, 1, 1),
|
44 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
45 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
46 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
47 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
48 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
49 |
+
nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
|
50 |
+
)
|
51 |
+
self.init_offset()
|
52 |
+
|
53 |
+
def init_offset(self):
|
54 |
+
constant_init(self.conv_offset[-1], val=0, bias=0)
|
55 |
+
|
56 |
+
def forward(self, x, cond_feat, flow):
|
57 |
+
out = self.conv_offset(cond_feat)
|
58 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
59 |
+
|
60 |
+
# offset
|
61 |
+
offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
|
62 |
+
offset = offset + flow.flip(1).repeat(1, offset.size(1) // 2, 1, 1)
|
63 |
+
|
64 |
+
# mask
|
65 |
+
mask = torch.sigmoid(mask)
|
66 |
+
|
67 |
+
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias,
|
68 |
+
self.stride, self.padding,
|
69 |
+
self.dilation, mask)
|
70 |
+
|
71 |
+
|
72 |
+
class BidirectionalPropagation(nn.Module):
|
73 |
+
def __init__(self, channel, learnable=True):
|
74 |
+
super(BidirectionalPropagation, self).__init__()
|
75 |
+
self.deform_align = nn.ModuleDict()
|
76 |
+
self.backbone = nn.ModuleDict()
|
77 |
+
self.channel = channel
|
78 |
+
self.prop_list = ['backward_1', 'forward_1']
|
79 |
+
self.learnable = learnable
|
80 |
+
|
81 |
+
if self.learnable:
|
82 |
+
for i, module in enumerate(self.prop_list):
|
83 |
+
self.deform_align[module] = DeformableAlignment(
|
84 |
+
channel, channel, 3, padding=1, deform_groups=16)
|
85 |
+
|
86 |
+
self.backbone[module] = nn.Sequential(
|
87 |
+
nn.Conv2d(2*channel+2, channel, 3, 1, 1),
|
88 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
89 |
+
nn.Conv2d(channel, channel, 3, 1, 1),
|
90 |
+
)
|
91 |
+
|
92 |
+
self.fuse = nn.Sequential(
|
93 |
+
nn.Conv2d(2*channel+2, channel, 3, 1, 1),
|
94 |
+
nn.LeakyReLU(negative_slope=0.2, inplace=True),
|
95 |
+
nn.Conv2d(channel, channel, 3, 1, 1),
|
96 |
+
)
|
97 |
+
|
98 |
+
def binary_mask(self, mask, th=0.1):
|
99 |
+
mask[mask>th] = 1
|
100 |
+
mask[mask<=th] = 0
|
101 |
+
# return mask.float()
|
102 |
+
return mask.to(mask)
|
103 |
+
|
104 |
+
def forward(self, x, flows_forward, flows_backward, mask, interpolation='bilinear'):
|
105 |
+
"""
|
106 |
+
x shape : [b, t, c, h, w]
|
107 |
+
return [b, t, c, h, w]
|
108 |
+
"""
|
109 |
+
|
110 |
+
# For backward warping
|
111 |
+
# pred_flows_forward for backward feature propagation
|
112 |
+
# pred_flows_backward for forward feature propagation
|
113 |
+
b, t, c, h, w = x.shape
|
114 |
+
feats, masks = {}, {}
|
115 |
+
feats['input'] = [x[:, i, :, :, :] for i in range(0, t)]
|
116 |
+
masks['input'] = [mask[:, i, :, :, :] for i in range(0, t)]
|
117 |
+
|
118 |
+
prop_list = ['backward_1', 'forward_1']
|
119 |
+
cache_list = ['input'] + prop_list
|
120 |
+
|
121 |
+
for p_i, module_name in enumerate(prop_list):
|
122 |
+
feats[module_name] = []
|
123 |
+
masks[module_name] = []
|
124 |
+
|
125 |
+
if 'backward' in module_name:
|
126 |
+
frame_idx = range(0, t)
|
127 |
+
frame_idx = frame_idx[::-1]
|
128 |
+
flow_idx = frame_idx
|
129 |
+
flows_for_prop = flows_forward
|
130 |
+
flows_for_check = flows_backward
|
131 |
+
else:
|
132 |
+
frame_idx = range(0, t)
|
133 |
+
flow_idx = range(-1, t - 1)
|
134 |
+
flows_for_prop = flows_backward
|
135 |
+
flows_for_check = flows_forward
|
136 |
+
|
137 |
+
for i, idx in enumerate(frame_idx):
|
138 |
+
feat_current = feats[cache_list[p_i]][idx]
|
139 |
+
mask_current = masks[cache_list[p_i]][idx]
|
140 |
+
|
141 |
+
if i == 0:
|
142 |
+
feat_prop = feat_current
|
143 |
+
mask_prop = mask_current
|
144 |
+
else:
|
145 |
+
flow_prop = flows_for_prop[:, flow_idx[i], :, :, :]
|
146 |
+
flow_check = flows_for_check[:, flow_idx[i], :, :, :]
|
147 |
+
flow_vaild_mask = fbConsistencyCheck(flow_prop, flow_check)
|
148 |
+
feat_warped = flow_warp(feat_prop, flow_prop.permute(0, 2, 3, 1), interpolation)
|
149 |
+
|
150 |
+
if self.learnable:
|
151 |
+
cond = torch.cat([feat_current, feat_warped, flow_prop, flow_vaild_mask, mask_current], dim=1)
|
152 |
+
feat_prop = self.deform_align[module_name](feat_prop, cond, flow_prop)
|
153 |
+
mask_prop = mask_current
|
154 |
+
else:
|
155 |
+
mask_prop_valid = flow_warp(mask_prop, flow_prop.permute(0, 2, 3, 1))
|
156 |
+
mask_prop_valid = self.binary_mask(mask_prop_valid)
|
157 |
+
|
158 |
+
union_vaild_mask = self.binary_mask(mask_current*flow_vaild_mask*(1-mask_prop_valid))
|
159 |
+
feat_prop = union_vaild_mask * feat_warped + (1-union_vaild_mask) * feat_current
|
160 |
+
# update mask
|
161 |
+
mask_prop = self.binary_mask(mask_current*(1-(flow_vaild_mask*(1-mask_prop_valid))))
|
162 |
+
|
163 |
+
# refine
|
164 |
+
if self.learnable:
|
165 |
+
feat = torch.cat([feat_current, feat_prop, mask_current], dim=1)
|
166 |
+
feat_prop = feat_prop + self.backbone[module_name](feat)
|
167 |
+
# feat_prop = self.backbone[module_name](feat_prop)
|
168 |
+
|
169 |
+
feats[module_name].append(feat_prop)
|
170 |
+
masks[module_name].append(mask_prop)
|
171 |
+
|
172 |
+
# end for
|
173 |
+
if 'backward' in module_name:
|
174 |
+
feats[module_name] = feats[module_name][::-1]
|
175 |
+
masks[module_name] = masks[module_name][::-1]
|
176 |
+
|
177 |
+
outputs_b = torch.stack(feats['backward_1'], dim=1).view(-1, c, h, w)
|
178 |
+
outputs_f = torch.stack(feats['forward_1'], dim=1).view(-1, c, h, w)
|
179 |
+
|
180 |
+
if self.learnable:
|
181 |
+
mask_in = mask.view(-1, 2, h, w)
|
182 |
+
masks_b, masks_f = None, None
|
183 |
+
outputs = self.fuse(torch.cat([outputs_b, outputs_f, mask_in], dim=1)) + x.view(-1, c, h, w)
|
184 |
+
else:
|
185 |
+
masks_b = torch.stack(masks['backward_1'], dim=1)
|
186 |
+
masks_f = torch.stack(masks['forward_1'], dim=1)
|
187 |
+
outputs = outputs_f
|
188 |
+
|
189 |
+
return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \
|
190 |
+
outputs.view(b, -1, c, h, w), masks_f
|
191 |
+
|
192 |
+
|
193 |
+
class Encoder(nn.Module):
|
194 |
+
def __init__(self):
|
195 |
+
super(Encoder, self).__init__()
|
196 |
+
self.group = [1, 2, 4, 8, 1]
|
197 |
+
self.layers = nn.ModuleList([
|
198 |
+
nn.Conv2d(5, 64, kernel_size=3, stride=2, padding=1),
|
199 |
+
nn.LeakyReLU(0.2, inplace=True),
|
200 |
+
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
|
201 |
+
nn.LeakyReLU(0.2, inplace=True),
|
202 |
+
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
203 |
+
nn.LeakyReLU(0.2, inplace=True),
|
204 |
+
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
|
205 |
+
nn.LeakyReLU(0.2, inplace=True),
|
206 |
+
nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
|
207 |
+
nn.LeakyReLU(0.2, inplace=True),
|
208 |
+
nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
|
209 |
+
nn.LeakyReLU(0.2, inplace=True),
|
210 |
+
nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
|
211 |
+
nn.LeakyReLU(0.2, inplace=True),
|
212 |
+
nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
|
213 |
+
nn.LeakyReLU(0.2, inplace=True),
|
214 |
+
nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
|
215 |
+
nn.LeakyReLU(0.2, inplace=True)
|
216 |
+
])
|
217 |
+
|
218 |
+
def forward(self, x):
|
219 |
+
bt, c, _, _ = x.size()
|
220 |
+
# h, w = h//4, w//4
|
221 |
+
out = x
|
222 |
+
for i, layer in enumerate(self.layers):
|
223 |
+
if i == 8:
|
224 |
+
x0 = out
|
225 |
+
_, _, h, w = x0.size()
|
226 |
+
if i > 8 and i % 2 == 0:
|
227 |
+
g = self.group[(i - 8) // 2]
|
228 |
+
x = x0.view(bt, g, -1, h, w)
|
229 |
+
o = out.view(bt, g, -1, h, w)
|
230 |
+
out = torch.cat([x, o], 2).view(bt, -1, h, w)
|
231 |
+
out = layer(out)
|
232 |
+
return out
|
233 |
+
|
234 |
+
|
235 |
+
class deconv(nn.Module):
|
236 |
+
def __init__(self,
|
237 |
+
input_channel,
|
238 |
+
output_channel,
|
239 |
+
kernel_size=3,
|
240 |
+
padding=0):
|
241 |
+
super().__init__()
|
242 |
+
self.conv = nn.Conv2d(input_channel,
|
243 |
+
output_channel,
|
244 |
+
kernel_size=kernel_size,
|
245 |
+
stride=1,
|
246 |
+
padding=padding)
|
247 |
+
|
248 |
+
def forward(self, x):
|
249 |
+
x = F.interpolate(x,
|
250 |
+
scale_factor=2,
|
251 |
+
mode='bilinear',
|
252 |
+
align_corners=True)
|
253 |
+
return self.conv(x)
|
254 |
+
|
255 |
+
|
256 |
+
class InpaintGenerator(BaseNetwork):
|
257 |
+
def __init__(self, init_weights=True, model_path=None):
|
258 |
+
super(InpaintGenerator, self).__init__()
|
259 |
+
channel = 128
|
260 |
+
hidden = 512
|
261 |
+
|
262 |
+
# encoder
|
263 |
+
self.encoder = Encoder()
|
264 |
+
|
265 |
+
# decoder
|
266 |
+
self.decoder = nn.Sequential(
|
267 |
+
deconv(channel, 128, kernel_size=3, padding=1),
|
268 |
+
nn.LeakyReLU(0.2, inplace=True),
|
269 |
+
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
|
270 |
+
nn.LeakyReLU(0.2, inplace=True),
|
271 |
+
deconv(64, 64, kernel_size=3, padding=1),
|
272 |
+
nn.LeakyReLU(0.2, inplace=True),
|
273 |
+
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
|
274 |
+
|
275 |
+
# soft split and soft composition
|
276 |
+
kernel_size = (7, 7)
|
277 |
+
padding = (3, 3)
|
278 |
+
stride = (3, 3)
|
279 |
+
t2t_params = {
|
280 |
+
'kernel_size': kernel_size,
|
281 |
+
'stride': stride,
|
282 |
+
'padding': padding
|
283 |
+
}
|
284 |
+
self.ss = SoftSplit(channel, hidden, kernel_size, stride, padding)
|
285 |
+
self.sc = SoftComp(channel, hidden, kernel_size, stride, padding)
|
286 |
+
self.max_pool = nn.MaxPool2d(kernel_size, stride, padding)
|
287 |
+
|
288 |
+
# feature propagation module
|
289 |
+
self.img_prop_module = BidirectionalPropagation(3, learnable=False)
|
290 |
+
self.feat_prop_module = BidirectionalPropagation(128, learnable=True)
|
291 |
+
|
292 |
+
|
293 |
+
depths = 8
|
294 |
+
num_heads = 4
|
295 |
+
window_size = (5, 9)
|
296 |
+
pool_size = (4, 4)
|
297 |
+
self.transformers = TemporalSparseTransformerBlock(dim=hidden,
|
298 |
+
n_head=num_heads,
|
299 |
+
window_size=window_size,
|
300 |
+
pool_size=pool_size,
|
301 |
+
depths=depths,
|
302 |
+
t2t_params=t2t_params)
|
303 |
+
if init_weights:
|
304 |
+
self.init_weights()
|
305 |
+
|
306 |
+
|
307 |
+
if model_path is not None:
|
308 |
+
print('Pretrained ProPainter has loaded...')
|
309 |
+
ckpt = torch.load(model_path, map_location='cpu')
|
310 |
+
self.load_state_dict(ckpt, strict=True)
|
311 |
+
|
312 |
+
# print network parameter number
|
313 |
+
self.print_network()
|
314 |
+
|
315 |
+
def img_propagation(self, masked_frames, completed_flows, masks, interpolation='nearest'):
|
316 |
+
_, _, prop_frames, updated_masks = self.img_prop_module(masked_frames, completed_flows[0], completed_flows[1], masks, interpolation)
|
317 |
+
return prop_frames, updated_masks
|
318 |
+
|
319 |
+
def forward(self, masked_frames, completed_flows, masks_in, masks_updated, num_local_frames, interpolation='bilinear', t_dilation=2):
|
320 |
+
"""
|
321 |
+
Args:
|
322 |
+
masks_in: original mask
|
323 |
+
masks_updated: updated mask after image propagation
|
324 |
+
"""
|
325 |
+
|
326 |
+
l_t = num_local_frames
|
327 |
+
b, t, _, ori_h, ori_w = masked_frames.size()
|
328 |
+
|
329 |
+
# extracting features
|
330 |
+
enc_feat = self.encoder(torch.cat([masked_frames.view(b * t, 3, ori_h, ori_w),
|
331 |
+
masks_in.view(b * t, 1, ori_h, ori_w),
|
332 |
+
masks_updated.view(b * t, 1, ori_h, ori_w)], dim=1))
|
333 |
+
_, c, h, w = enc_feat.size()
|
334 |
+
local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
|
335 |
+
ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
|
336 |
+
fold_feat_size = (h, w)
|
337 |
+
|
338 |
+
ds_flows_f = F.interpolate(completed_flows[0].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0
|
339 |
+
ds_flows_b = F.interpolate(completed_flows[1].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0
|
340 |
+
ds_mask_in = F.interpolate(masks_in.reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, t, 1, h, w)
|
341 |
+
ds_mask_in_local = ds_mask_in[:, :l_t]
|
342 |
+
ds_mask_updated_local = F.interpolate(masks_updated[:,:l_t].reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, l_t, 1, h, w)
|
343 |
+
|
344 |
+
|
345 |
+
if self.training:
|
346 |
+
mask_pool_l = self.max_pool(ds_mask_in.view(-1, 1, h, w))
|
347 |
+
mask_pool_l = mask_pool_l.view(b, t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1))
|
348 |
+
else:
|
349 |
+
mask_pool_l = self.max_pool(ds_mask_in_local.view(-1, 1, h, w))
|
350 |
+
mask_pool_l = mask_pool_l.view(b, l_t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1))
|
351 |
+
|
352 |
+
|
353 |
+
prop_mask_in = torch.cat([ds_mask_in_local, ds_mask_updated_local], dim=2)
|
354 |
+
_, _, local_feat, _ = self.feat_prop_module(local_feat, ds_flows_f, ds_flows_b, prop_mask_in, interpolation)
|
355 |
+
enc_feat = torch.cat((local_feat, ref_feat), dim=1)
|
356 |
+
|
357 |
+
trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_feat_size)
|
358 |
+
mask_pool_l = rearrange(mask_pool_l, 'b t c h w -> b t h w c').contiguous()
|
359 |
+
trans_feat = self.transformers(trans_feat, fold_feat_size, mask_pool_l, t_dilation=t_dilation)
|
360 |
+
trans_feat = self.sc(trans_feat, t, fold_feat_size)
|
361 |
+
trans_feat = trans_feat.view(b, t, -1, h, w)
|
362 |
+
|
363 |
+
enc_feat = enc_feat + trans_feat
|
364 |
+
|
365 |
+
if self.training:
|
366 |
+
output = self.decoder(enc_feat.view(-1, c, h, w))
|
367 |
+
output = torch.tanh(output).view(b, t, 3, ori_h, ori_w)
|
368 |
+
else:
|
369 |
+
output = self.decoder(enc_feat[:, :l_t].view(-1, c, h, w))
|
370 |
+
output = torch.tanh(output).view(b, l_t, 3, ori_h, ori_w)
|
371 |
+
|
372 |
+
return output
|
373 |
+
|
374 |
+
|
375 |
+
# ######################################################################
|
376 |
+
# Discriminator for Temporal Patch GAN
|
377 |
+
# ######################################################################
|
378 |
+
class Discriminator(BaseNetwork):
|
379 |
+
def __init__(self,
|
380 |
+
in_channels=3,
|
381 |
+
use_sigmoid=False,
|
382 |
+
use_spectral_norm=True,
|
383 |
+
init_weights=True):
|
384 |
+
super(Discriminator, self).__init__()
|
385 |
+
self.use_sigmoid = use_sigmoid
|
386 |
+
nf = 32
|
387 |
+
|
388 |
+
self.conv = nn.Sequential(
|
389 |
+
spectral_norm(
|
390 |
+
nn.Conv3d(in_channels=in_channels,
|
391 |
+
out_channels=nf * 1,
|
392 |
+
kernel_size=(3, 5, 5),
|
393 |
+
stride=(1, 2, 2),
|
394 |
+
padding=1,
|
395 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
396 |
+
# nn.InstanceNorm2d(64, track_running_stats=False),
|
397 |
+
nn.LeakyReLU(0.2, inplace=True),
|
398 |
+
spectral_norm(
|
399 |
+
nn.Conv3d(nf * 1,
|
400 |
+
nf * 2,
|
401 |
+
kernel_size=(3, 5, 5),
|
402 |
+
stride=(1, 2, 2),
|
403 |
+
padding=(1, 2, 2),
|
404 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
405 |
+
# nn.InstanceNorm2d(128, track_running_stats=False),
|
406 |
+
nn.LeakyReLU(0.2, inplace=True),
|
407 |
+
spectral_norm(
|
408 |
+
nn.Conv3d(nf * 2,
|
409 |
+
nf * 4,
|
410 |
+
kernel_size=(3, 5, 5),
|
411 |
+
stride=(1, 2, 2),
|
412 |
+
padding=(1, 2, 2),
|
413 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
414 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
415 |
+
nn.LeakyReLU(0.2, inplace=True),
|
416 |
+
spectral_norm(
|
417 |
+
nn.Conv3d(nf * 4,
|
418 |
+
nf * 4,
|
419 |
+
kernel_size=(3, 5, 5),
|
420 |
+
stride=(1, 2, 2),
|
421 |
+
padding=(1, 2, 2),
|
422 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
423 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
424 |
+
nn.LeakyReLU(0.2, inplace=True),
|
425 |
+
spectral_norm(
|
426 |
+
nn.Conv3d(nf * 4,
|
427 |
+
nf * 4,
|
428 |
+
kernel_size=(3, 5, 5),
|
429 |
+
stride=(1, 2, 2),
|
430 |
+
padding=(1, 2, 2),
|
431 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
432 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
433 |
+
nn.LeakyReLU(0.2, inplace=True),
|
434 |
+
nn.Conv3d(nf * 4,
|
435 |
+
nf * 4,
|
436 |
+
kernel_size=(3, 5, 5),
|
437 |
+
stride=(1, 2, 2),
|
438 |
+
padding=(1, 2, 2)))
|
439 |
+
|
440 |
+
if init_weights:
|
441 |
+
self.init_weights()
|
442 |
+
|
443 |
+
def forward(self, xs):
|
444 |
+
# T, C, H, W = xs.shape (old)
|
445 |
+
# B, T, C, H, W (new)
|
446 |
+
xs_t = torch.transpose(xs, 1, 2)
|
447 |
+
feat = self.conv(xs_t)
|
448 |
+
if self.use_sigmoid:
|
449 |
+
feat = torch.sigmoid(feat)
|
450 |
+
out = torch.transpose(feat, 1, 2) # B, T, C, H, W
|
451 |
+
return out
|
452 |
+
|
453 |
+
|
454 |
+
class Discriminator_2D(BaseNetwork):
|
455 |
+
def __init__(self,
|
456 |
+
in_channels=3,
|
457 |
+
use_sigmoid=False,
|
458 |
+
use_spectral_norm=True,
|
459 |
+
init_weights=True):
|
460 |
+
super(Discriminator_2D, self).__init__()
|
461 |
+
self.use_sigmoid = use_sigmoid
|
462 |
+
nf = 32
|
463 |
+
|
464 |
+
self.conv = nn.Sequential(
|
465 |
+
spectral_norm(
|
466 |
+
nn.Conv3d(in_channels=in_channels,
|
467 |
+
out_channels=nf * 1,
|
468 |
+
kernel_size=(1, 5, 5),
|
469 |
+
stride=(1, 2, 2),
|
470 |
+
padding=(0, 2, 2),
|
471 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
472 |
+
# nn.InstanceNorm2d(64, track_running_stats=False),
|
473 |
+
nn.LeakyReLU(0.2, inplace=True),
|
474 |
+
spectral_norm(
|
475 |
+
nn.Conv3d(nf * 1,
|
476 |
+
nf * 2,
|
477 |
+
kernel_size=(1, 5, 5),
|
478 |
+
stride=(1, 2, 2),
|
479 |
+
padding=(0, 2, 2),
|
480 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
481 |
+
# nn.InstanceNorm2d(128, track_running_stats=False),
|
482 |
+
nn.LeakyReLU(0.2, inplace=True),
|
483 |
+
spectral_norm(
|
484 |
+
nn.Conv3d(nf * 2,
|
485 |
+
nf * 4,
|
486 |
+
kernel_size=(1, 5, 5),
|
487 |
+
stride=(1, 2, 2),
|
488 |
+
padding=(0, 2, 2),
|
489 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
490 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
491 |
+
nn.LeakyReLU(0.2, inplace=True),
|
492 |
+
spectral_norm(
|
493 |
+
nn.Conv3d(nf * 4,
|
494 |
+
nf * 4,
|
495 |
+
kernel_size=(1, 5, 5),
|
496 |
+
stride=(1, 2, 2),
|
497 |
+
padding=(0, 2, 2),
|
498 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
499 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
500 |
+
nn.LeakyReLU(0.2, inplace=True),
|
501 |
+
spectral_norm(
|
502 |
+
nn.Conv3d(nf * 4,
|
503 |
+
nf * 4,
|
504 |
+
kernel_size=(1, 5, 5),
|
505 |
+
stride=(1, 2, 2),
|
506 |
+
padding=(0, 2, 2),
|
507 |
+
bias=not use_spectral_norm), use_spectral_norm),
|
508 |
+
# nn.InstanceNorm2d(256, track_running_stats=False),
|
509 |
+
nn.LeakyReLU(0.2, inplace=True),
|
510 |
+
nn.Conv3d(nf * 4,
|
511 |
+
nf * 4,
|
512 |
+
kernel_size=(1, 5, 5),
|
513 |
+
stride=(1, 2, 2),
|
514 |
+
padding=(0, 2, 2)))
|
515 |
+
|
516 |
+
if init_weights:
|
517 |
+
self.init_weights()
|
518 |
+
|
519 |
+
def forward(self, xs):
|
520 |
+
# T, C, H, W = xs.shape (old)
|
521 |
+
# B, T, C, H, W (new)
|
522 |
+
xs_t = torch.transpose(xs, 1, 2)
|
523 |
+
feat = self.conv(xs_t)
|
524 |
+
if self.use_sigmoid:
|
525 |
+
feat = torch.sigmoid(feat)
|
526 |
+
out = torch.transpose(feat, 1, 2) # B, T, C, H, W
|
527 |
+
return out
|
528 |
+
|
529 |
+
def spectral_norm(module, mode=True):
|
530 |
+
if mode:
|
531 |
+
return _spectral_norm(module)
|
532 |
+
return module
|
model/recurrent_flow_completion.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import torchvision
|
5 |
+
|
6 |
+
from model.modules.deformconv import ModulatedDeformConv2d
|
7 |
+
from .misc import constant_init
|
8 |
+
|
9 |
+
class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
|
10 |
+
"""Second-order deformable alignment module."""
|
11 |
+
def __init__(self, *args, **kwargs):
|
12 |
+
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 5)
|
13 |
+
|
14 |
+
super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
|
15 |
+
|
16 |
+
self.conv_offset = nn.Sequential(
|
17 |
+
nn.Conv2d(3 * self.out_channels, self.out_channels, 3, 1, 1),
|
18 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
19 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
20 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
21 |
+
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
|
22 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
23 |
+
nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
|
24 |
+
)
|
25 |
+
self.init_offset()
|
26 |
+
|
27 |
+
def init_offset(self):
|
28 |
+
constant_init(self.conv_offset[-1], val=0, bias=0)
|
29 |
+
|
30 |
+
def forward(self, x, extra_feat):
|
31 |
+
out = self.conv_offset(extra_feat)
|
32 |
+
o1, o2, mask = torch.chunk(out, 3, dim=1)
|
33 |
+
|
34 |
+
# offset
|
35 |
+
offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
|
36 |
+
offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
|
37 |
+
offset = torch.cat([offset_1, offset_2], dim=1)
|
38 |
+
|
39 |
+
# mask
|
40 |
+
mask = torch.sigmoid(mask)
|
41 |
+
|
42 |
+
return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias,
|
43 |
+
self.stride, self.padding,
|
44 |
+
self.dilation, mask)
|
45 |
+
|
46 |
+
class BidirectionalPropagation(nn.Module):
|
47 |
+
def __init__(self, channel):
|
48 |
+
super(BidirectionalPropagation, self).__init__()
|
49 |
+
modules = ['backward_', 'forward_']
|
50 |
+
self.deform_align = nn.ModuleDict()
|
51 |
+
self.backbone = nn.ModuleDict()
|
52 |
+
self.channel = channel
|
53 |
+
|
54 |
+
for i, module in enumerate(modules):
|
55 |
+
self.deform_align[module] = SecondOrderDeformableAlignment(
|
56 |
+
2 * channel, channel, 3, padding=1, deform_groups=16)
|
57 |
+
|
58 |
+
self.backbone[module] = nn.Sequential(
|
59 |
+
nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
|
60 |
+
nn.LeakyReLU(negative_slope=0.1, inplace=True),
|
61 |
+
nn.Conv2d(channel, channel, 3, 1, 1),
|
62 |
+
)
|
63 |
+
|
64 |
+
self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
|
65 |
+
|
66 |
+
def forward(self, x):
|
67 |
+
"""
|
68 |
+
x shape : [b, t, c, h, w]
|
69 |
+
return [b, t, c, h, w]
|
70 |
+
"""
|
71 |
+
b, t, c, h, w = x.shape
|
72 |
+
feats = {}
|
73 |
+
feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
|
74 |
+
|
75 |
+
for module_name in ['backward_', 'forward_']:
|
76 |
+
|
77 |
+
feats[module_name] = []
|
78 |
+
|
79 |
+
frame_idx = range(0, t)
|
80 |
+
mapping_idx = list(range(0, len(feats['spatial'])))
|
81 |
+
mapping_idx += mapping_idx[::-1]
|
82 |
+
|
83 |
+
if 'backward' in module_name:
|
84 |
+
frame_idx = frame_idx[::-1]
|
85 |
+
|
86 |
+
feat_prop = x.new_zeros(b, self.channel, h, w)
|
87 |
+
for i, idx in enumerate(frame_idx):
|
88 |
+
feat_current = feats['spatial'][mapping_idx[idx]]
|
89 |
+
if i > 0:
|
90 |
+
cond_n1 = feat_prop
|
91 |
+
|
92 |
+
# initialize second-order features
|
93 |
+
feat_n2 = torch.zeros_like(feat_prop)
|
94 |
+
cond_n2 = torch.zeros_like(cond_n1)
|
95 |
+
if i > 1: # second-order features
|
96 |
+
feat_n2 = feats[module_name][-2]
|
97 |
+
cond_n2 = feat_n2
|
98 |
+
|
99 |
+
cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1) # condition information, cond(flow warped 1st/2nd feature)
|
100 |
+
feat_prop = torch.cat([feat_prop, feat_n2], dim=1) # two order feat_prop -1 & -2
|
101 |
+
feat_prop = self.deform_align[module_name](feat_prop, cond)
|
102 |
+
|
103 |
+
# fuse current features
|
104 |
+
feat = [feat_current] + \
|
105 |
+
[feats[k][idx] for k in feats if k not in ['spatial', module_name]] \
|
106 |
+
+ [feat_prop]
|
107 |
+
|
108 |
+
feat = torch.cat(feat, dim=1)
|
109 |
+
# embed current features
|
110 |
+
feat_prop = feat_prop + self.backbone[module_name](feat)
|
111 |
+
|
112 |
+
feats[module_name].append(feat_prop)
|
113 |
+
|
114 |
+
# end for
|
115 |
+
if 'backward' in module_name:
|
116 |
+
feats[module_name] = feats[module_name][::-1]
|
117 |
+
|
118 |
+
outputs = []
|
119 |
+
for i in range(0, t):
|
120 |
+
align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
|
121 |
+
align_feats = torch.cat(align_feats, dim=1)
|
122 |
+
outputs.append(self.fusion(align_feats))
|
123 |
+
|
124 |
+
return torch.stack(outputs, dim=1) + x
|
125 |
+
|
126 |
+
|
127 |
+
class deconv(nn.Module):
|
128 |
+
def __init__(self,
|
129 |
+
input_channel,
|
130 |
+
output_channel,
|
131 |
+
kernel_size=3,
|
132 |
+
padding=0):
|
133 |
+
super().__init__()
|
134 |
+
self.conv = nn.Conv2d(input_channel,
|
135 |
+
output_channel,
|
136 |
+
kernel_size=kernel_size,
|
137 |
+
stride=1,
|
138 |
+
padding=padding)
|
139 |
+
|
140 |
+
def forward(self, x):
|
141 |
+
x = F.interpolate(x,
|
142 |
+
scale_factor=2,
|
143 |
+
mode='bilinear',
|
144 |
+
align_corners=True)
|
145 |
+
return self.conv(x)
|
146 |
+
|
147 |
+
|
148 |
+
class P3DBlock(nn.Module):
|
149 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_residual=0, bias=True):
|
150 |
+
super().__init__()
|
151 |
+
self.conv1 = nn.Sequential(
|
152 |
+
nn.Conv3d(in_channels, out_channels, kernel_size=(1, kernel_size, kernel_size),
|
153 |
+
stride=(1, stride, stride), padding=(0, padding, padding), bias=bias),
|
154 |
+
nn.LeakyReLU(0.2, inplace=True)
|
155 |
+
)
|
156 |
+
self.conv2 = nn.Sequential(
|
157 |
+
nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1),
|
158 |
+
padding=(2, 0, 0), dilation=(2, 1, 1), bias=bias)
|
159 |
+
)
|
160 |
+
self.use_residual = use_residual
|
161 |
+
|
162 |
+
def forward(self, feats):
|
163 |
+
feat1 = self.conv1(feats)
|
164 |
+
feat2 = self.conv2(feat1)
|
165 |
+
if self.use_residual:
|
166 |
+
output = feats + feat2
|
167 |
+
else:
|
168 |
+
output = feat2
|
169 |
+
return output
|
170 |
+
|
171 |
+
|
172 |
+
class EdgeDetection(nn.Module):
|
173 |
+
def __init__(self, in_ch=2, out_ch=1, mid_ch=16):
|
174 |
+
super().__init__()
|
175 |
+
self.projection = nn.Sequential(
|
176 |
+
nn.Conv2d(in_ch, mid_ch, 3, 1, 1),
|
177 |
+
nn.LeakyReLU(0.2, inplace=True)
|
178 |
+
)
|
179 |
+
|
180 |
+
self.mid_layer_1 = nn.Sequential(
|
181 |
+
nn.Conv2d(mid_ch, mid_ch, 3, 1, 1),
|
182 |
+
nn.LeakyReLU(0.2, inplace=True)
|
183 |
+
)
|
184 |
+
|
185 |
+
self.mid_layer_2 = nn.Sequential(
|
186 |
+
nn.Conv2d(mid_ch, mid_ch, 3, 1, 1)
|
187 |
+
)
|
188 |
+
|
189 |
+
self.l_relu = nn.LeakyReLU(0.01, inplace=True)
|
190 |
+
|
191 |
+
self.out_layer = nn.Conv2d(mid_ch, out_ch, 1, 1, 0)
|
192 |
+
|
193 |
+
def forward(self, flow):
|
194 |
+
flow = self.projection(flow)
|
195 |
+
edge = self.mid_layer_1(flow)
|
196 |
+
edge = self.mid_layer_2(edge)
|
197 |
+
edge = self.l_relu(flow + edge)
|
198 |
+
edge = self.out_layer(edge)
|
199 |
+
edge = torch.sigmoid(edge)
|
200 |
+
return edge
|
201 |
+
|
202 |
+
|
203 |
+
class RecurrentFlowCompleteNet(nn.Module):
|
204 |
+
def __init__(self, model_path=None):
|
205 |
+
super().__init__()
|
206 |
+
self.downsample = nn.Sequential(
|
207 |
+
nn.Conv3d(3, 32, kernel_size=(1, 5, 5), stride=(1, 2, 2),
|
208 |
+
padding=(0, 2, 2), padding_mode='replicate'),
|
209 |
+
nn.LeakyReLU(0.2, inplace=True)
|
210 |
+
)
|
211 |
+
|
212 |
+
self.encoder1 = nn.Sequential(
|
213 |
+
P3DBlock(32, 32, 3, 1, 1),
|
214 |
+
nn.LeakyReLU(0.2, inplace=True),
|
215 |
+
P3DBlock(32, 64, 3, 2, 1),
|
216 |
+
nn.LeakyReLU(0.2, inplace=True)
|
217 |
+
) # 4x
|
218 |
+
|
219 |
+
self.encoder2 = nn.Sequential(
|
220 |
+
P3DBlock(64, 64, 3, 1, 1),
|
221 |
+
nn.LeakyReLU(0.2, inplace=True),
|
222 |
+
P3DBlock(64, 128, 3, 2, 1),
|
223 |
+
nn.LeakyReLU(0.2, inplace=True)
|
224 |
+
) # 8x
|
225 |
+
|
226 |
+
self.mid_dilation = nn.Sequential(
|
227 |
+
nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 3, 3), dilation=(1, 3, 3)), # p = d*(k-1)/2
|
228 |
+
nn.LeakyReLU(0.2, inplace=True),
|
229 |
+
nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 2, 2), dilation=(1, 2, 2)),
|
230 |
+
nn.LeakyReLU(0.2, inplace=True),
|
231 |
+
nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 1, 1), dilation=(1, 1, 1)),
|
232 |
+
nn.LeakyReLU(0.2, inplace=True)
|
233 |
+
)
|
234 |
+
|
235 |
+
# feature propagation module
|
236 |
+
self.feat_prop_module = BidirectionalPropagation(128)
|
237 |
+
|
238 |
+
self.decoder2 = nn.Sequential(
|
239 |
+
nn.Conv2d(128, 128, 3, 1, 1),
|
240 |
+
nn.LeakyReLU(0.2, inplace=True),
|
241 |
+
deconv(128, 64, 3, 1),
|
242 |
+
nn.LeakyReLU(0.2, inplace=True)
|
243 |
+
) # 4x
|
244 |
+
|
245 |
+
self.decoder1 = nn.Sequential(
|
246 |
+
nn.Conv2d(64, 64, 3, 1, 1),
|
247 |
+
nn.LeakyReLU(0.2, inplace=True),
|
248 |
+
deconv(64, 32, 3, 1),
|
249 |
+
nn.LeakyReLU(0.2, inplace=True)
|
250 |
+
) # 2x
|
251 |
+
|
252 |
+
self.upsample = nn.Sequential(
|
253 |
+
nn.Conv2d(32, 32, 3, padding=1),
|
254 |
+
nn.LeakyReLU(0.2, inplace=True),
|
255 |
+
deconv(32, 2, 3, 1)
|
256 |
+
)
|
257 |
+
|
258 |
+
# edge loss
|
259 |
+
self.edgeDetector = EdgeDetection(in_ch=2, out_ch=1, mid_ch=16)
|
260 |
+
|
261 |
+
# Need to initial the weights of MSDeformAttn specifically
|
262 |
+
for m in self.modules():
|
263 |
+
if isinstance(m, SecondOrderDeformableAlignment):
|
264 |
+
m.init_offset()
|
265 |
+
|
266 |
+
if model_path is not None:
|
267 |
+
print('Pretrained flow completion model has loaded...')
|
268 |
+
ckpt = torch.load(model_path, map_location='cpu')
|
269 |
+
self.load_state_dict(ckpt, strict=True)
|
270 |
+
|
271 |
+
|
272 |
+
def forward(self, masked_flows, masks):
|
273 |
+
# masked_flows: b t-1 2 h w
|
274 |
+
# masks: b t-1 2 h w
|
275 |
+
b, t, _, h, w = masked_flows.size()
|
276 |
+
masked_flows = masked_flows.permute(0,2,1,3,4)
|
277 |
+
masks = masks.permute(0,2,1,3,4)
|
278 |
+
|
279 |
+
inputs = torch.cat((masked_flows, masks), dim=1)
|
280 |
+
|
281 |
+
x = self.downsample(inputs)
|
282 |
+
|
283 |
+
feat_e1 = self.encoder1(x)
|
284 |
+
feat_e2 = self.encoder2(feat_e1) # b c t h w
|
285 |
+
feat_mid = self.mid_dilation(feat_e2) # b c t h w
|
286 |
+
feat_mid = feat_mid.permute(0,2,1,3,4) # b t c h w
|
287 |
+
|
288 |
+
feat_prop = self.feat_prop_module(feat_mid)
|
289 |
+
feat_prop = feat_prop.view(-1, 128, h//8, w//8) # b*t c h w
|
290 |
+
|
291 |
+
_, c, _, h_f, w_f = feat_e1.shape
|
292 |
+
feat_e1 = feat_e1.permute(0,2,1,3,4).contiguous().view(-1, c, h_f, w_f) # b*t c h w
|
293 |
+
feat_d2 = self.decoder2(feat_prop) + feat_e1
|
294 |
+
|
295 |
+
_, c, _, h_f, w_f = x.shape
|
296 |
+
x = x.permute(0,2,1,3,4).contiguous().view(-1, c, h_f, w_f) # b*t c h w
|
297 |
+
|
298 |
+
feat_d1 = self.decoder1(feat_d2)
|
299 |
+
|
300 |
+
flow = self.upsample(feat_d1)
|
301 |
+
if self.training:
|
302 |
+
edge = self.edgeDetector(flow)
|
303 |
+
edge = edge.view(b, t, 1, h, w)
|
304 |
+
else:
|
305 |
+
edge = None
|
306 |
+
|
307 |
+
flow = flow.view(b, t, 2, h, w)
|
308 |
+
|
309 |
+
return flow, edge
|
310 |
+
|
311 |
+
|
312 |
+
def forward_bidirect_flow(self, masked_flows_bi, masks):
|
313 |
+
"""
|
314 |
+
Args:
|
315 |
+
masked_flows_bi: [masked_flows_f, masked_flows_b] | (b t-1 2 h w), (b t-1 2 h w)
|
316 |
+
masks: b t 1 h w
|
317 |
+
"""
|
318 |
+
masks_forward = masks[:, :-1, ...].contiguous()
|
319 |
+
masks_backward = masks[:, 1:, ...].contiguous()
|
320 |
+
|
321 |
+
# mask flow
|
322 |
+
masked_flows_forward = masked_flows_bi[0] * (1-masks_forward)
|
323 |
+
masked_flows_backward = masked_flows_bi[1] * (1-masks_backward)
|
324 |
+
|
325 |
+
# -- completion --
|
326 |
+
# forward
|
327 |
+
pred_flows_forward, pred_edges_forward = self.forward(masked_flows_forward, masks_forward)
|
328 |
+
|
329 |
+
# backward
|
330 |
+
masked_flows_backward = torch.flip(masked_flows_backward, dims=[1])
|
331 |
+
masks_backward = torch.flip(masks_backward, dims=[1])
|
332 |
+
pred_flows_backward, pred_edges_backward = self.forward(masked_flows_backward, masks_backward)
|
333 |
+
pred_flows_backward = torch.flip(pred_flows_backward, dims=[1])
|
334 |
+
if self.training:
|
335 |
+
pred_edges_backward = torch.flip(pred_edges_backward, dims=[1])
|
336 |
+
|
337 |
+
return [pred_flows_forward, pred_flows_backward], [pred_edges_forward, pred_edges_backward]
|
338 |
+
|
339 |
+
|
340 |
+
def combine_flow(self, masked_flows_bi, pred_flows_bi, masks):
|
341 |
+
masks_forward = masks[:, :-1, ...].contiguous()
|
342 |
+
masks_backward = masks[:, 1:, ...].contiguous()
|
343 |
+
|
344 |
+
pred_flows_forward = pred_flows_bi[0] * masks_forward + masked_flows_bi[0] * (1-masks_forward)
|
345 |
+
pred_flows_backward = pred_flows_bi[1] * masks_backward + masked_flows_bi[1] * (1-masks_backward)
|
346 |
+
|
347 |
+
return pred_flows_forward, pred_flows_backward
|
model/vgg_arch.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
from collections import OrderedDict
|
4 |
+
from torch import nn as nn
|
5 |
+
from torchvision.models import vgg as vgg
|
6 |
+
|
7 |
+
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
|
8 |
+
NAMES = {
|
9 |
+
'vgg11': [
|
10 |
+
'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
|
11 |
+
'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
|
12 |
+
'pool5'
|
13 |
+
],
|
14 |
+
'vgg13': [
|
15 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
16 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
|
17 |
+
'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
|
18 |
+
],
|
19 |
+
'vgg16': [
|
20 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
21 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
|
22 |
+
'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
|
23 |
+
'pool5'
|
24 |
+
],
|
25 |
+
'vgg19': [
|
26 |
+
'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
|
27 |
+
'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
|
28 |
+
'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
|
29 |
+
'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
|
30 |
+
]
|
31 |
+
}
|
32 |
+
|
33 |
+
|
34 |
+
def insert_bn(names):
|
35 |
+
"""Insert bn layer after each conv.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
names (list): The list of layer names.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
list: The list of layer names with bn layers.
|
42 |
+
"""
|
43 |
+
names_bn = []
|
44 |
+
for name in names:
|
45 |
+
names_bn.append(name)
|
46 |
+
if 'conv' in name:
|
47 |
+
position = name.replace('conv', '')
|
48 |
+
names_bn.append('bn' + position)
|
49 |
+
return names_bn
|
50 |
+
|
51 |
+
class VGGFeatureExtractor(nn.Module):
|
52 |
+
"""VGG network for feature extraction.
|
53 |
+
|
54 |
+
In this implementation, we allow users to choose whether use normalization
|
55 |
+
in the input feature and the type of vgg network. Note that the pretrained
|
56 |
+
path must fit the vgg type.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
layer_name_list (list[str]): Forward function returns the corresponding
|
60 |
+
features according to the layer_name_list.
|
61 |
+
Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
|
62 |
+
vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
|
63 |
+
use_input_norm (bool): If True, normalize the input image. Importantly,
|
64 |
+
the input feature must in the range [0, 1]. Default: True.
|
65 |
+
range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
|
66 |
+
Default: False.
|
67 |
+
requires_grad (bool): If true, the parameters of VGG network will be
|
68 |
+
optimized. Default: False.
|
69 |
+
remove_pooling (bool): If true, the max pooling operations in VGG net
|
70 |
+
will be removed. Default: False.
|
71 |
+
pooling_stride (int): The stride of max pooling operation. Default: 2.
|
72 |
+
"""
|
73 |
+
|
74 |
+
def __init__(self,
|
75 |
+
layer_name_list,
|
76 |
+
vgg_type='vgg19',
|
77 |
+
use_input_norm=True,
|
78 |
+
range_norm=False,
|
79 |
+
requires_grad=False,
|
80 |
+
remove_pooling=False,
|
81 |
+
pooling_stride=2):
|
82 |
+
super(VGGFeatureExtractor, self).__init__()
|
83 |
+
|
84 |
+
self.layer_name_list = layer_name_list
|
85 |
+
self.use_input_norm = use_input_norm
|
86 |
+
self.range_norm = range_norm
|
87 |
+
|
88 |
+
self.names = NAMES[vgg_type.replace('_bn', '')]
|
89 |
+
if 'bn' in vgg_type:
|
90 |
+
self.names = insert_bn(self.names)
|
91 |
+
|
92 |
+
# only borrow layers that will be used to avoid unused params
|
93 |
+
max_idx = 0
|
94 |
+
for v in layer_name_list:
|
95 |
+
idx = self.names.index(v)
|
96 |
+
if idx > max_idx:
|
97 |
+
max_idx = idx
|
98 |
+
|
99 |
+
if os.path.exists(VGG_PRETRAIN_PATH):
|
100 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=False)
|
101 |
+
state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
|
102 |
+
vgg_net.load_state_dict(state_dict)
|
103 |
+
else:
|
104 |
+
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
|
105 |
+
|
106 |
+
features = vgg_net.features[:max_idx + 1]
|
107 |
+
|
108 |
+
modified_net = OrderedDict()
|
109 |
+
for k, v in zip(self.names, features):
|
110 |
+
if 'pool' in k:
|
111 |
+
# if remove_pooling is true, pooling operation will be removed
|
112 |
+
if remove_pooling:
|
113 |
+
continue
|
114 |
+
else:
|
115 |
+
# in some cases, we may want to change the default stride
|
116 |
+
modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
|
117 |
+
else:
|
118 |
+
modified_net[k] = v
|
119 |
+
|
120 |
+
self.vgg_net = nn.Sequential(modified_net)
|
121 |
+
|
122 |
+
if not requires_grad:
|
123 |
+
self.vgg_net.eval()
|
124 |
+
for param in self.parameters():
|
125 |
+
param.requires_grad = False
|
126 |
+
else:
|
127 |
+
self.vgg_net.train()
|
128 |
+
for param in self.parameters():
|
129 |
+
param.requires_grad = True
|
130 |
+
|
131 |
+
if self.use_input_norm:
|
132 |
+
# the mean is for image with range [0, 1]
|
133 |
+
self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
|
134 |
+
# the std is for image with range [0, 1]
|
135 |
+
self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
|
136 |
+
|
137 |
+
def forward(self, x):
|
138 |
+
"""Forward function.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
x (Tensor): Input tensor with shape (n, c, h, w).
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
Tensor: Forward results.
|
145 |
+
"""
|
146 |
+
if self.range_norm:
|
147 |
+
x = (x + 1) / 2
|
148 |
+
if self.use_input_norm:
|
149 |
+
x = (x - self.mean) / self.std
|
150 |
+
output = {}
|
151 |
+
|
152 |
+
for key, layer in self.vgg_net._modules.items():
|
153 |
+
x = layer(x)
|
154 |
+
if key in self.layer_name_list:
|
155 |
+
output[key] = x.clone()
|
156 |
+
|
157 |
+
return output
|
requirements.txt
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
av
|
2 |
+
addict
|
3 |
+
einops
|
4 |
+
future
|
5 |
+
numpy
|
6 |
+
scipy
|
7 |
+
opencv-python
|
8 |
+
matplotlib
|
9 |
+
scikit-image
|
10 |
+
torch>=1.7.1
|
11 |
+
torchvision>=0.8.2
|
12 |
+
imageio-ffmpeg
|
13 |
+
pyyaml
|
14 |
+
requests
|
15 |
+
timm
|
16 |
+
yapf
|
17 |
+
progressbar2
|
18 |
+
gdown
|
19 |
+
gitpython
|
20 |
+
git+https://github.com/cheind/py-thin-plate-spline
|
21 |
+
hickle
|
22 |
+
tensorboard
|
23 |
+
numpy
|
24 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
25 |
+
gradio
|
26 |
+
opencv-python
|
27 |
+
matplotlib
|
28 |
+
pyyaml
|
29 |
+
av
|
30 |
+
openmim
|
31 |
+
tqdm
|
32 |
+
psutil
|
33 |
+
omegaconf
|
scripts/compute_flow.py
ADDED
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import sys
|
3 |
+
sys.path.append(".")
|
4 |
+
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
import argparse
|
8 |
+
from PIL import Image
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torchvision import transforms
|
12 |
+
|
13 |
+
from RAFT import RAFT
|
14 |
+
from utils.flow_util import *
|
15 |
+
|
16 |
+
def imwrite(img, file_path, params=None, auto_mkdir=True):
|
17 |
+
if auto_mkdir:
|
18 |
+
dir_name = os.path.abspath(os.path.dirname(file_path))
|
19 |
+
os.makedirs(dir_name, exist_ok=True)
|
20 |
+
return cv2.imwrite(file_path, img, params)
|
21 |
+
|
22 |
+
def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'):
|
23 |
+
"""Initializes the RAFT model.
|
24 |
+
"""
|
25 |
+
args = argparse.ArgumentParser()
|
26 |
+
args.raft_model = model_path
|
27 |
+
args.small = False
|
28 |
+
args.mixed_precision = False
|
29 |
+
args.alternate_corr = False
|
30 |
+
|
31 |
+
model = torch.nn.DataParallel(RAFT(args))
|
32 |
+
model.load_state_dict(torch.load(args.raft_model))
|
33 |
+
|
34 |
+
model = model.module
|
35 |
+
model.to(device)
|
36 |
+
model.eval()
|
37 |
+
|
38 |
+
return model
|
39 |
+
|
40 |
+
|
41 |
+
if __name__ == '__main__':
|
42 |
+
device = 'cuda'
|
43 |
+
|
44 |
+
parser = argparse.ArgumentParser()
|
45 |
+
parser.add_argument('-i', '--root_path', type=str, default='your_dataset_root/youtube-vos/JPEGImages')
|
46 |
+
parser.add_argument('-o', '--save_path', type=str, default='your_dataset_root/youtube-vos/Flows_flo')
|
47 |
+
parser.add_argument('--height', type=int, default=240)
|
48 |
+
parser.add_argument('--width', type=int, default=432)
|
49 |
+
|
50 |
+
args = parser.parse_args()
|
51 |
+
|
52 |
+
# Flow model
|
53 |
+
RAFT_model = initialize_RAFT(device=device)
|
54 |
+
|
55 |
+
root_path = args.root_path
|
56 |
+
save_path = args.save_path
|
57 |
+
h_new, w_new = (args.height, args.width)
|
58 |
+
|
59 |
+
file_list = sorted(os.listdir(root_path))
|
60 |
+
for f in file_list:
|
61 |
+
print(f'Processing: {f} ...')
|
62 |
+
m_list = sorted(os.listdir(os.path.join(root_path, f)))
|
63 |
+
len_m = len(m_list)
|
64 |
+
for i in range(len_m-1):
|
65 |
+
img1_path = os.path.join(root_path, f, m_list[i])
|
66 |
+
img2_path = os.path.join(root_path, f, m_list[i+1])
|
67 |
+
img1 = Image.fromarray(cv2.imread(img1_path))
|
68 |
+
img2 = Image.fromarray(cv2.imread(img2_path))
|
69 |
+
|
70 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
71 |
+
|
72 |
+
img1 = transform(img1).unsqueeze(0).to(device)[:,[2,1,0],:,:]
|
73 |
+
img2 = transform(img2).unsqueeze(0).to(device)[:,[2,1,0],:,:]
|
74 |
+
|
75 |
+
# upsize to a multiple of 16
|
76 |
+
# h, w = img1.shape[2:4]
|
77 |
+
# w_new = w if (w % 16) == 0 else 16 * (w // 16 + 1)
|
78 |
+
# h_new = h if (h % 16) == 0 else 16 * (h // 16 + 1)
|
79 |
+
|
80 |
+
|
81 |
+
img1 = F.interpolate(input=img1,
|
82 |
+
size=(h_new, w_new),
|
83 |
+
mode='bilinear',
|
84 |
+
align_corners=False)
|
85 |
+
img2 = F.interpolate(input=img2,
|
86 |
+
size=(h_new, w_new),
|
87 |
+
mode='bilinear',
|
88 |
+
align_corners=False)
|
89 |
+
|
90 |
+
with torch.no_grad():
|
91 |
+
img1 = img1*2 - 1
|
92 |
+
img2 = img2*2 - 1
|
93 |
+
|
94 |
+
_, flow_f = RAFT_model(img1, img2, iters=20, test_mode=True)
|
95 |
+
_, flow_b = RAFT_model(img2, img1, iters=20, test_mode=True)
|
96 |
+
|
97 |
+
|
98 |
+
flow_f = flow_f[0].permute(1,2,0).cpu().numpy()
|
99 |
+
flow_b = flow_b[0].permute(1,2,0).cpu().numpy()
|
100 |
+
|
101 |
+
# flow_f = resize_flow(flow_f, w_new, h_new)
|
102 |
+
# flow_b = resize_flow(flow_b, w_new, h_new)
|
103 |
+
|
104 |
+
save_flow_f = os.path.join(save_path, f, f'{m_list[i][:-4]}_{m_list[i+1][:-4]}_f.flo')
|
105 |
+
save_flow_b = os.path.join(save_path, f, f'{m_list[i+1][:-4]}_{m_list[i][:-4]}_b.flo')
|
106 |
+
|
107 |
+
flowwrite(flow_f, save_flow_f, quantize=False)
|
108 |
+
flowwrite(flow_b, save_flow_b, quantize=False)
|
scripts/evaluate_flow_completion.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import sys
|
3 |
+
sys.path.append(".")
|
4 |
+
|
5 |
+
import cv2
|
6 |
+
import os
|
7 |
+
import numpy as np
|
8 |
+
import argparse
|
9 |
+
from PIL import Image
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch.utils.data import DataLoader
|
13 |
+
|
14 |
+
from core.dataset import TestDataset
|
15 |
+
from model.modules.flow_comp_raft import RAFT_bi
|
16 |
+
from model.recurrent_flow_completion import RecurrentFlowCompleteNet
|
17 |
+
|
18 |
+
from RAFT.utils.flow_viz_pt import flow_to_image
|
19 |
+
|
20 |
+
import cvbase
|
21 |
+
import imageio
|
22 |
+
from time import time
|
23 |
+
|
24 |
+
import warnings
|
25 |
+
warnings.filterwarnings("ignore")
|
26 |
+
|
27 |
+
def create_dir(dir):
|
28 |
+
"""Creates a directory if not exist.
|
29 |
+
"""
|
30 |
+
if not os.path.exists(dir):
|
31 |
+
os.makedirs(dir)
|
32 |
+
|
33 |
+
def save_flows(output, videoFlowF, videoFlowB):
|
34 |
+
# create_dir(os.path.join(output, 'forward_flo'))
|
35 |
+
# create_dir(os.path.join(output, 'backward_flo'))
|
36 |
+
create_dir(os.path.join(output, 'forward_png'))
|
37 |
+
create_dir(os.path.join(output, 'backward_png'))
|
38 |
+
N = videoFlowF.shape[-1]
|
39 |
+
for i in range(N):
|
40 |
+
forward_flow = videoFlowF[..., i]
|
41 |
+
backward_flow = videoFlowB[..., i]
|
42 |
+
forward_flow_vis = cvbase.flow2rgb(forward_flow)
|
43 |
+
backward_flow_vis = cvbase.flow2rgb(backward_flow)
|
44 |
+
# cvbase.write_flow(forward_flow, os.path.join(output, 'forward_flo', '{:05d}.flo'.format(i)))
|
45 |
+
# cvbase.write_flow(backward_flow, os.path.join(output, 'backward_flo', '{:05d}.flo'.format(i)))
|
46 |
+
forward_flow_vis = (forward_flow_vis*255.0).astype(np.uint8)
|
47 |
+
backward_flow_vis = (backward_flow_vis*255.0).astype(np.uint8)
|
48 |
+
imageio.imwrite(os.path.join(output, 'forward_png', '{:05d}.png'.format(i)), forward_flow_vis)
|
49 |
+
imageio.imwrite(os.path.join(output, 'backward_png', '{:05d}.png'.format(i)), backward_flow_vis)
|
50 |
+
|
51 |
+
def tensor2np(array):
|
52 |
+
array = torch.stack(array, dim=-1).squeeze(0).permute(1, 2, 0, 3).cpu().numpy()
|
53 |
+
return array
|
54 |
+
|
55 |
+
def main_worker(args):
|
56 |
+
# set up datasets and data loader
|
57 |
+
args.size = (args.width, args.height)
|
58 |
+
test_dataset = TestDataset(vars(args))
|
59 |
+
|
60 |
+
test_loader = DataLoader(test_dataset,
|
61 |
+
batch_size=1,
|
62 |
+
shuffle=False,
|
63 |
+
num_workers=args.num_workers)
|
64 |
+
|
65 |
+
# set up models
|
66 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
67 |
+
fix_raft = RAFT_bi(args.raft_model_path, device)
|
68 |
+
|
69 |
+
fix_flow_complete = RecurrentFlowCompleteNet(args.fc_model_path)
|
70 |
+
for p in fix_flow_complete.parameters():
|
71 |
+
p.requires_grad = False
|
72 |
+
fix_flow_complete.to(device)
|
73 |
+
fix_flow_complete.eval()
|
74 |
+
|
75 |
+
total_frame_epe = []
|
76 |
+
time_all = []
|
77 |
+
|
78 |
+
print('Start evaluation...')
|
79 |
+
# create results directory
|
80 |
+
result_path = os.path.join('results_flow', f'{args.dataset}')
|
81 |
+
if not os.path.exists(result_path):
|
82 |
+
os.makedirs(result_path)
|
83 |
+
|
84 |
+
eval_summary = open(os.path.join(result_path, f"{args.dataset}_metrics.txt"), "w")
|
85 |
+
|
86 |
+
for index, items in enumerate(test_loader):
|
87 |
+
frames, masks, flows_f, flows_b, video_name, frames_PIL = items
|
88 |
+
local_masks = masks.float().to(device)
|
89 |
+
|
90 |
+
video_length = frames.size(1)
|
91 |
+
|
92 |
+
if args.load_flow:
|
93 |
+
gt_flows_bi = (flows_f.to(device), flows_b.to(device))
|
94 |
+
else:
|
95 |
+
short_len = 60
|
96 |
+
if frames.size(1) > short_len:
|
97 |
+
gt_flows_f_list, gt_flows_b_list = [], []
|
98 |
+
for f in range(0, video_length, short_len):
|
99 |
+
end_f = min(video_length, f + short_len)
|
100 |
+
if f == 0:
|
101 |
+
flows_f, flows_b = fix_raft(frames[:,f:end_f], iters=args.raft_iter)
|
102 |
+
else:
|
103 |
+
flows_f, flows_b = fix_raft(frames[:,f-1:end_f], iters=args.raft_iter)
|
104 |
+
|
105 |
+
gt_flows_f_list.append(flows_f)
|
106 |
+
gt_flows_b_list.append(flows_b)
|
107 |
+
gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
|
108 |
+
gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
|
109 |
+
gt_flows_bi = (gt_flows_f, gt_flows_b)
|
110 |
+
else:
|
111 |
+
gt_flows_bi = fix_raft(frames, iters=20)
|
112 |
+
|
113 |
+
torch.cuda.synchronize()
|
114 |
+
time_start = time()
|
115 |
+
|
116 |
+
# flow_length = flows_f.size(1)
|
117 |
+
# f_stride = 30
|
118 |
+
# pred_flows_f = []
|
119 |
+
# pred_flows_b = []
|
120 |
+
# suffix = flow_length%f_stride
|
121 |
+
# last = flow_length//f_stride
|
122 |
+
# for f in range(0, flow_length, f_stride):
|
123 |
+
# gt_flows_bi_i = (gt_flows_bi[0][:,f:f+f_stride], gt_flows_bi[1][:,f:f+f_stride])
|
124 |
+
# pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi_i, local_masks[:,f:f+f_stride+1])
|
125 |
+
# pred_flows_f_i, pred_flows_b_i = fix_flow_complete.combine_flow(gt_flows_bi_i, pred_flows_bi, local_masks[:,f:f+f_stride+1])
|
126 |
+
# pred_flows_f.append(pred_flows_f_i)
|
127 |
+
# pred_flows_b.append(pred_flows_b_i)
|
128 |
+
# pred_flows_f = torch.cat(pred_flows_f, dim=1)
|
129 |
+
# pred_flows_b = torch.cat(pred_flows_b, dim=1)
|
130 |
+
# pred_flows_bi = (pred_flows_f, pred_flows_b)
|
131 |
+
|
132 |
+
pred_flows_bi, _ = fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
|
133 |
+
pred_flows_bi = fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)
|
134 |
+
|
135 |
+
torch.cuda.synchronize()
|
136 |
+
time_i = time() - time_start
|
137 |
+
time_i = time_i*1.0/frames.size(1)
|
138 |
+
|
139 |
+
time_all = time_all+[time_i]*frames.size(1)
|
140 |
+
|
141 |
+
cur_video_epe = []
|
142 |
+
|
143 |
+
epe1 = torch.mean(torch.sum((flows_f - pred_flows_bi[0].cpu())**2, dim=2).sqrt())
|
144 |
+
epe2 = torch.mean(torch.sum((flows_b - pred_flows_bi[1].cpu())**2, dim=2).sqrt())
|
145 |
+
|
146 |
+
cur_video_epe.append(epe1.numpy())
|
147 |
+
cur_video_epe.append(epe2.numpy())
|
148 |
+
|
149 |
+
total_frame_epe = total_frame_epe+[epe1.numpy()]*flows_f.size(1)
|
150 |
+
total_frame_epe = total_frame_epe+[epe2.numpy()]*flows_f.size(1)
|
151 |
+
|
152 |
+
cur_epe = sum(cur_video_epe) / len(cur_video_epe)
|
153 |
+
avg_time = sum(time_all) / len(time_all)
|
154 |
+
print(
|
155 |
+
f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}'
|
156 |
+
)
|
157 |
+
eval_summary.write(
|
158 |
+
f'[{index+1:3}/{len(test_loader)}] Name: {str(video_name):25} | EPE: {cur_epe:.4f} | Time: {avg_time:.4f}\n'
|
159 |
+
)
|
160 |
+
|
161 |
+
# saving images for evaluating warpping errors
|
162 |
+
if args.save_results:
|
163 |
+
forward_flows = pred_flows_bi[0].cpu().permute(1,0,2,3,4)
|
164 |
+
backward_flows = pred_flows_bi[1].cpu().permute(1,0,2,3,4)
|
165 |
+
# forward_flows = flows_f.cpu().permute(1,0,2,3,4)
|
166 |
+
# backward_flows = flows_b.cpu().permute(1,0,2,3,4)
|
167 |
+
videoFlowF = list(forward_flows)
|
168 |
+
videoFlowB = list(backward_flows)
|
169 |
+
|
170 |
+
videoFlowF = tensor2np(videoFlowF)
|
171 |
+
videoFlowB = tensor2np(videoFlowB)
|
172 |
+
|
173 |
+
save_frame_path = os.path.join(result_path, video_name[0])
|
174 |
+
save_flows(save_frame_path, videoFlowF, videoFlowB)
|
175 |
+
|
176 |
+
avg_frame_epe = sum(total_frame_epe) / len(total_frame_epe)
|
177 |
+
|
178 |
+
print(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | | Time: {avg_time:.4f}')
|
179 |
+
eval_summary.write(f'Finish evaluation... Average Frame EPE: {avg_frame_epe:.4f} | | Time: {avg_time:.4f}\n')
|
180 |
+
eval_summary.close()
|
181 |
+
|
182 |
+
if __name__ == '__main__':
|
183 |
+
parser = argparse.ArgumentParser()
|
184 |
+
parser.add_argument('--height', type=int, default=240)
|
185 |
+
parser.add_argument('--width', type=int, default=432)
|
186 |
+
parser.add_argument('--raft_model_path', default='weights/raft-things.pth', type=str)
|
187 |
+
parser.add_argument('--fc_model_path', default='weights/recurrent_flow_completion.pth', type=str)
|
188 |
+
parser.add_argument('--dataset', choices=['davis', 'youtube-vos'], type=str)
|
189 |
+
parser.add_argument('--video_root', default='dataset_root', type=str)
|
190 |
+
parser.add_argument('--mask_root', default='mask_root', type=str)
|
191 |
+
parser.add_argument('--flow_root', default='flow_ground_truth_root', type=str)
|
192 |
+
parser.add_argument('--load_flow', default=False, type=bool)
|
193 |
+
parser.add_argument("--raft_iter", type=int, default=20)
|
194 |
+
parser.add_argument('--save_results', action='store_true')
|
195 |
+
parser.add_argument('--num_workers', default=4, type=int)
|
196 |
+
args = parser.parse_args()
|
197 |
+
main_worker(args)
|