Liangrj5 committed
Commit · ebf5d87
Parent(s): 6dd9459
init
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- baselines/__init__.py +0 -0
- baselines/__pycache__/__init__.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/README.md +25 -0
- baselines/clip_alignment_with_language/__init__.py +0 -0
- baselines/clip_alignment_with_language/__pycache__/__init__.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/__pycache__/config.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/__pycache__/inference.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/__pycache__/model.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/__pycache__/proposal_retrieval_dataset.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/config.py +207 -0
- baselines/clip_alignment_with_language/inference.py +672 -0
- baselines/clip_alignment_with_language/local_utils/__init__.py +0 -0
- baselines/clip_alignment_with_language/local_utils/__pycache__/__init__.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/local_utils/__pycache__/compute_proposal_upper_bound.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/local_utils/__pycache__/proposal.cpython-311.pyc +0 -0
- baselines/clip_alignment_with_language/local_utils/compute_proposal_upper_bound.py +117 -0
- baselines/clip_alignment_with_language/local_utils/proposal.py +181 -0
- baselines/clip_alignment_with_language/local_utils/tvr_proposal_test_log.txt +61 -0
- baselines/clip_alignment_with_language/mix_model_prediction.py +86 -0
- baselines/clip_alignment_with_language/model.py +299 -0
- baselines/clip_alignment_with_language/proposal_retrieval_dataset.py +587 -0
- baselines/clip_alignment_with_language/scripts/compute_upper_bound.sh +23 -0
- baselines/clip_alignment_with_language/scripts/inference.sh +17 -0
- baselines/clip_alignment_with_language/scripts/inference_mix.sh +27 -0
- baselines/clip_alignment_with_language/scripts/inference_with_external.sh +54 -0
- baselines/clip_alignment_with_language/scripts/re_train_cal.sh +21 -0
- baselines/clip_alignment_with_language/scripts/re_train_mcn.sh +21 -0
- baselines/clip_alignment_with_language/scripts/train.sh +80 -0
- baselines/clip_alignment_with_language/train.py +310 -0
- baselines/crossmodal_moment_localization/README.md +2 -0
- baselines/crossmodal_moment_localization/__init__.py +0 -0
- baselines/crossmodal_moment_localization/__pycache__/__init__.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/config.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/inference.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/model_components.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/model_xml.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/ndcg_iou_topk.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/optimization.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/__pycache__/start_end_dataset.cpython-311.pyc +0 -0
- baselines/crossmodal_moment_localization/config.py +276 -0
- baselines/crossmodal_moment_localization/inference.py +414 -0
- baselines/crossmodal_moment_localization/model_components.py +317 -0
- baselines/crossmodal_moment_localization/model_xml.py +642 -0
- baselines/crossmodal_moment_localization/ndcg_iou_topk.py +68 -0
- baselines/crossmodal_moment_localization/optimization.py +338 -0
- baselines/crossmodal_moment_localization/scripts/eval.sh +14 -0
- baselines/crossmodal_moment_localization/scripts/inference.sh +18 -0
- baselines/crossmodal_moment_localization/scripts/inference_with_external.sh +40 -0
- baselines/crossmodal_moment_localization/scripts/train.sh +70 -0
- baselines/crossmodal_moment_localization/start_end_dataset.py +393 -0
baselines/__init__.py
ADDED
File without changes
baselines/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (176 Bytes).
baselines/clip_alignment_with_language/README.md
ADDED
@@ -0,0 +1,25 @@
# Clip Alignment With Language

This folder contains the CAL model described in the paper:

```
@article{Escorcia2019TemporalLO,
  title={Temporal Localization of Moments in Video Collections with Natural Language},
  author={Victor Escorcia and Mattia Soldan and Josef Sivic and Bernard Ghanem and Bryan Russell},
  journal={ArXiv},
  year={2019},
  volume={abs/1907.12763}
}
```

It also resembles the MCN model described in:

```
@article{Hendricks2017LocalizingMI,
  title={Localizing Moments in Video with Natural Language},
  author={Lisa Anne Hendricks and Oliver Wang and Eli Shechtman and Josef Sivic and Trevor Darrell and Bryan C. Russell},
  journal={2017 IEEE International Conference on Computer Vision (ICCV)},
  year={2017},
  pages={5804-5813}
}
```

Disclaimer: this code was implemented by [Jie Lei](http://www.cs.unc.edu/~jielei/) for the TVR dataset; it does not guarantee reproducibility of the original authors' results.
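Both CAL and MCN rank candidate moments by embedding the query and each proposal into a shared space and training with a margin-based ranking objective (see `--margin` and `--loss_type` in `config.py` below). As a quick orientation, here is a minimal, self-contained sketch of a hinge ranking loss of that kind; it is an illustrative assumption, not the `model.py` implementation added in this commit, and all tensor names are made up.

```python
# Minimal sketch of a hinge (max-margin) ranking loss of the kind CAL/MCN-style
# models train with; names and shapes are illustrative, not taken from model.py.
import torch
import torch.nn.functional as F


def hinge_ranking_loss(query_emb, pos_moment_emb, neg_moment_emb, margin=0.1):
    """All inputs are (batch, dim) tensors; smaller distance = better match.

    The positive moment is pushed to be at least `margin` closer to the
    query than the negative moment.
    """
    pos_dist = ((query_emb - pos_moment_emb) ** 2).sum(dim=1)  # (batch,)
    neg_dist = ((query_emb - neg_moment_emb) ** 2).sum(dim=1)  # (batch,)
    return F.relu(margin + pos_dist - neg_dist).mean()


if __name__ == "__main__":
    q = torch.randn(4, 256)
    pos = q + 0.01 * torch.randn(4, 256)  # near the query
    neg = torch.randn(4, 256)             # random negatives
    print(hinge_ranking_loss(q, pos, neg, margin=0.1))
```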
baselines/clip_alignment_with_language/__init__.py
ADDED
File without changes
baselines/clip_alignment_with_language/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (205 Bytes).
baselines/clip_alignment_with_language/__pycache__/config.cpython-311.pyc
ADDED
Binary file (17.8 kB).
baselines/clip_alignment_with_language/__pycache__/inference.cpython-311.pyc
ADDED
Binary file (43 kB).
baselines/clip_alignment_with_language/__pycache__/model.cpython-311.pyc
ADDED
Binary file (15.8 kB).
baselines/clip_alignment_with_language/__pycache__/proposal_retrieval_dataset.cpython-311.pyc
ADDED
Binary file (37 kB).
baselines/clip_alignment_with_language/config.py
ADDED
@@ -0,0 +1,207 @@
import os
import time
import torch
import argparse

from utils.basic_utils import mkdirp, load_json, save_json, make_zipfile
from baselines.clip_alignment_with_language.local_utils.proposal import ProposalConfigs


class BaseOptions(object):
    saved_option_filename = "opt.json"
    ckpt_filename = "model.ckpt"
    tensorboard_log_dir = "tensorboard_log"
    train_log_filename = "train.log.txt"
    eval_log_filename = "eval.log.txt"

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False
        self.opt = None

    def initialize(self):
        self.initialized = True
        self.parser.add_argument("--dset_name", type=str, choices=["tvr"])
        self.parser.add_argument("--eval_split_name", type=str, default="val",
                                 help="should match keys in corpus_path, must set for VCMR")
        self.parser.add_argument("--debug", action="store_true",
                                 help="debug (fast) mode, break all loops, do not load all data into memory.")
        self.parser.add_argument("--data_ratio", type=float, default=1.0,
                                 help="how many training and eval data to use. 1.0: use all, 0.1: use 10%. "
                                      "Use a small portion for debug purposes. Note this is different from --debug, "
                                      "which works by breaking the loops; typically they are not used together.")
        self.parser.add_argument("--results_root", type=str, default="results")
        self.parser.add_argument("--exp_id", type=str, default="res", help="id of the current run")
        self.parser.add_argument("--seed", type=int, default=2018, help="random seed")
        self.parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
        self.parser.add_argument("--device_ids", type=int, nargs="+", default=[0], help="GPU ids to run the job")
        self.parser.add_argument("--num_workers", type=int, default=8,
                                 help="num subprocesses used to load the data, 0: use main process")
        self.parser.add_argument("--no_core_driver", action="store_true",
                                 help="hdf5 driver, default use `core` (load into RAM), if specified, use `None`")
        self.parser.add_argument("--no_pin_memory", action="store_true",
                                 help="Don't use pin_memory=True for dataloader. "
                                      "ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")

        # training config
        self.parser.add_argument("--lr", type=float, default=0.05, help="learning rate")
        self.parser.add_argument("--wd", type=float, default=0, help="weight decay")
        self.parser.add_argument("--momentum", type=float, default=0.95, help="momentum for SGD")
        self.parser.add_argument("--n_epoch", type=int, default=108, help="number of epochs to run")
        self.parser.add_argument("--max_es_cnt", type=int, default=108, help="number of epochs to early stop")
        self.parser.add_argument("--bsz", type=int, default=128, help="mini-batch size")
        self.parser.add_argument("--eval_query_bsz", type=int, default=1000,
                                 help="mini-batch size at inference, for query")
        self.parser.add_argument("--eval_proposal_bsz", type=int, default=200,
                                 help="mini-batch size at inference, for proposals")
        self.parser.add_argument("--eval_untrained", action="store_true", help="Evaluate on un-trained model")
        self.parser.add_argument("--grad_clip", type=float, default=-1, help="perform gradient clip, -1: disable")
        self.parser.add_argument("--margin", type=float, default=0.1, help="margin for hinge loss")
        self.parser.add_argument("--inter_loss_weight", type=float, default=0.4,
                                 help="weight for the inter-video ranking loss")
        self.parser.add_argument("--loss_type", type=str, default="hinge", choices=["hinge", "lse"],
                                 help="att loss type, can be hinge loss or its smooth approximation LogSumExp")

        # Model and Data config
        self.parser.add_argument("--max_sub_l", type=int, default=50,
                                 help="max length of all sub sentences, 97.71%% are under 50 for 3 sentences")
        self.parser.add_argument("--max_desc_l", type=int, default=30, help="max length of descriptions")
        self.parser.add_argument("--pos_iou_thd", type=float, default=0.7, help="moments with IoU >= as positive")
        self.parser.add_argument("--neg_iou_thd", type=float, default=0.35, help="moments with IoU < as negative")

        self.parser.add_argument("--train_path", type=str, default=None)
        self.parser.add_argument("--eval_path", type=str, default=None,
                                 help="Evaluating during training, for Dev set. If None, will only do training; "
                                      "anet_cap and charades_sta have no dev set, so None")
        self.parser.add_argument("--external_train_vr_res_path", type=str, default=None,
                                 help="if set, use external video retrieval results to guide "
                                      "inter-video negative sampling. ")
        self.parser.add_argument("--init_ckpt_path", type=str, default=None,
                                 help="init model parameters from checkpoint. Use absolute path")
        self.parser.add_argument("--external_inference_vr_res_path", type=str, default=None,
                                 help="if set, use external video retrieval results to guide evaluation. ")
        self.parser.add_argument("--use_glove", action="store_true", help="Use GloVe instead of BERT features")
        self.parser.add_argument("--word2idx_path", type=str,
                                 help="a dict, {word: word_idx, ...}, "
                                      "special tokens are {<pad>: 0, <unk>: 1, <eos>: 2}")
        self.parser.add_argument("--vocab_size", type=int, default=-1,
                                 help="Set automatically to len(word2idx)")
        self.parser.add_argument("--glove_path", type=str,
                                 help="path to file containing the GloVe embeddings for words in word2idx")
        self.parser.add_argument("--desc_bert_path", type=str, default=None)
        self.parser.add_argument("--sub_bert_path", type=str, default=None)
        self.parser.add_argument("--sub_feat_size", type=int, default=768, help="feature dim for sub feature")
        self.parser.add_argument("--desc_feat_size", type=int, default=768)
        self.parser.add_argument("--ctx_mode", type=str,
                                 choices=["video", "sub", "tef", "video_sub", "video_tef", "sub_tef", "video_sub_tef"],
                                 help="which context to use, a combination of [video, sub, tef]")
        self.parser.add_argument("--corpus_path", type=str, default=None)
        self.parser.add_argument("--vid_feat_path", type=str, default="")
        self.parser.add_argument("--no_norm_vfeat", action="store_true",
                                 help="Do not normalize video feat; use it when using the i3d_resnet concat feat")
        self.parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not normalize text feat")
        self.parser.add_argument("--clip_length", type=float, default=None,
                                 help="each video will be uniformly segmented into small clips; "
                                      "will be automatically loaded from ProposalConfigs if None")
        self.parser.add_argument("--vid_feat_size", type=int, help="feature dim for video feature")

        self.parser.add_argument("--model_type", default="cal", choices=["cal", "mcn"])
        self.parser.add_argument("--embedding_size", type=int, default=768)
        self.parser.add_argument("--lstm_hidden_size", type=int, default=256)
        self.parser.add_argument("--visual_hidden_size", type=int, default=256)
        self.parser.add_argument("--output_size", type=int, default=256)

        # post processing
        self.parser.add_argument("--nms_thd", type=float, default=-1,
                                 help="additionally use non-maximum suppression "
                                      "(or non-minimum suppression for distance) "
                                      "to post-process the predictions. "
                                      "-1: do not use nms. 0.6 for charades_sta, 0.5 for anet_cap.")
        self.parser.add_argument("--max_after_nms", type=int, default=100, help="Stores at most max_after_nms for eval")
        self.parser.add_argument("--max_before_nms", type=int, default=300, help="Max before nms")
        self.parser.add_argument("--use_intermediate", action="store_true",
                                 help="Whether to use/save intermediate results to results directory. "
                                      "Might want use this if we are going to ")

    def save_args(self, opt):
        args = vars(opt)
        # Save settings
        if not isinstance(self, TestOptions):
            option_file_path = os.path.join(opt.results_dir, self.saved_option_filename)  # not a yaml file indeed
            save_json(args, option_file_path, save_pretty=True)

    def parse(self):
        if not self.initialized:
            self.initialize()
        opt = self.parser.parse_args()

        if opt.debug:
            opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
            opt.no_core_driver = True
            opt.num_workers = 0

        if isinstance(self, TestOptions):
            # modify model_dir to absolute path
            opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
            saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
            for arg in saved_options:  # use saved options to overwrite all BaseOptions args.
                if arg not in ["results_root", "num_workers", "nms_thd", "debug", "eval_split_name", "eval_path",
                               "use_intermediate", "external_inference_vr_res_path"]:
                    setattr(opt, arg, saved_options[arg])
            # opt.no_core_driver = True
        else:
            if opt.exp_id is None:
                raise ValueError("--exp_id is required for a training option!")

            if opt.clip_length is None:
                opt.clip_length = ProposalConfigs[opt.dset_name]["clip_length"]
            opt.results_dir = os.path.join(opt.results_root,
                                           "-".join([opt.dset_name, opt.model_type, opt.ctx_mode, opt.exp_id,
                                                     time.strftime("%Y_%m_%d_%H_%M_%S")]))
            mkdirp(opt.results_dir)
            # save a copy of current code
            code_dir = os.path.dirname(os.path.realpath(__file__))
            code_zip_filename = os.path.join(opt.results_dir, "code.zip")
            make_zipfile(code_dir, code_zip_filename,
                         enclosing_dir="code",
                         exclude_dirs_substring="results",
                         exclude_dirs=["results", "debug_results", "__pycache__"],
                         exclude_extensions=[".pyc", ".ipynb", ".swap"])

        self.save_args(opt)

        if "sub" in opt.ctx_mode:
            assert opt.dset_name == "tvr", "sub is only supported for tvr dataset"

        if "video" in opt.ctx_mode and opt.vid_feat_size > 3000:  # 3072, the normalized concatenation of resnet+i3d
            assert opt.no_norm_vfeat

        opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
        opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
        opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
        opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
        opt.device = torch.device("cuda:%d" % opt.device_ids[0] if opt.device >= 0 else "cpu")
        opt.h5driver = None if opt.no_core_driver else "core"
        # num_workers > 1 will only work with "core" mode, i.e., memory-mapped hdf5
        opt.pin_memory = not opt.no_pin_memory
        opt.num_workers = 1 if opt.no_core_driver else opt.num_workers

        # Display settings
        print("------------ Options -------------\n{}\n-------------------"
              .format({str(k): str(v) for k, v in sorted(vars(opt).items())}))
        self.opt = opt
        return opt


class TestOptions(BaseOptions):
    """add additional options for evaluating"""
    def initialize(self):
        BaseOptions.initialize(self)
        # also need to specify --eval_split_name
        self.parser.add_argument("--eval_id", type=str, help="evaluation id")
        self.parser.add_argument("--model_dir", type=str,
                                 help="dir containing the model file, will be converted to absolute path afterwards")
        self.parser.add_argument("--tasks", type=str, nargs="+", choices=["VCMR", "SVMR", "VR"], default="SVMR",
                                 help="Which tasks to run. "
                                      "VCMR: Video Corpus Moment Retrieval; "
                                      "SVMR: Single Video Moment Retrieval; "
                                      "VR: regular Video Retrieval.")
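For reference, `inference.py` (later in this diff) consumes these options via `TestOptions().parse()`, and `train.py` presumably does the same with `BaseOptions`. Below is a hypothetical driver sketch showing how a caller would drive the options class; the flag values are placeholders, not the settings used in the released `scripts/*.sh`.

```python
# Hypothetical driver for BaseOptions; flag values below are placeholders,
# not the released script settings.
import sys
from baselines.clip_alignment_with_language.config import BaseOptions

if __name__ == "__main__":
    # Simulate a command line for illustration.
    sys.argv = ["train.py",
                "--dset_name", "tvr",
                "--ctx_mode", "video_sub",
                "--vid_feat_size", "1024",
                "--exp_id", "demo_run"]
    # parse() creates results_dir, zips the code, resolves device/h5driver, etc.
    opt = BaseOptions().parse()
    print(opt.results_dir, opt.device, opt.clip_length)
```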
baselines/clip_alignment_with_language/inference.py
ADDED
@@ -0,0 +1,672 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import math
|
4 |
+
import pprint
|
5 |
+
import numpy as np
|
6 |
+
from tqdm import tqdm, trange
|
7 |
+
from collections import defaultdict, OrderedDict
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.backends.cudnn as cudnn
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
|
13 |
+
from baselines.clip_alignment_with_language.config import TestOptions
|
14 |
+
from baselines.clip_alignment_with_language.model import CALWithSub
|
15 |
+
from baselines.clip_alignment_with_language.proposal_retrieval_dataset import \
|
16 |
+
proposal_retrieval_collate, ProposalRetrievalEvalDataset, prepare_batch_inputs
|
17 |
+
from utils.basic_utils import save_jsonl, save_json, load_json
|
18 |
+
from utils.temporal_nms import temporal_non_maximum_suppression
|
19 |
+
from utils.tensor_utils import pad_sequences_1d
|
20 |
+
from standalone_eval.eval import eval_retrieval
|
21 |
+
|
22 |
+
import logging
|
23 |
+
|
24 |
+
logger = logging.getLogger(__name__)
|
25 |
+
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
|
26 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
27 |
+
level=logging.INFO)
|
28 |
+
|
29 |
+
|
30 |
+
def combine_single_video_proposal_embeddings(proposals_embedding_list, proposals_mask_list):
|
31 |
+
"""
|
32 |
+
Args:
|
33 |
+
proposals_embedding_list: list(torch.Tensor), bsz * (N_prop, N_clips, D_o)
|
34 |
+
proposals_mask_list: list(torch.Tensor), bsz * (N_prop, N_clips)
|
35 |
+
"""
|
36 |
+
if len(proposals_embedding_list) == 1:
|
37 |
+
return proposals_embedding_list[0], proposals_mask_list[0]
|
38 |
+
else: # > 1
|
39 |
+
max_n_clips = max([e.shape[1] for e in proposals_embedding_list])
|
40 |
+
n_proposals = sum([len(e) for e in proposals_embedding_list])
|
41 |
+
d = proposals_embedding_list[0].shape[2]
|
42 |
+
proposals_embedding = proposals_embedding_list[0].new_zeros((n_proposals, max_n_clips, d))
|
43 |
+
proposals_mask = proposals_mask_list[0].new_zeros((n_proposals, max_n_clips))
|
44 |
+
mask_lengths = [0, ] + [len(m) for m in proposals_mask_list]
|
45 |
+
mask_cumsum_lengths = np.cumsum(mask_lengths)
|
46 |
+
for idx, (e, m) in enumerate(zip(proposals_embedding_list, proposals_mask_list)):
|
47 |
+
proposals_embedding[mask_cumsum_lengths[idx]:mask_cumsum_lengths[idx + 1], :e.shape[1]] = e
|
48 |
+
proposals_mask[mask_cumsum_lengths[idx]:mask_cumsum_lengths[idx + 1], :m.shape[1]] = m
|
49 |
+
return proposals_embedding, proposals_mask
|
50 |
+
|
51 |
+
|
52 |
+
def compute_query_embeddings(model, eval_dataset, opt, load_gt_vid_name):
|
53 |
+
"""Use val set to do evaluation, remember to run with torch.no_grad().
|
54 |
+
estimated size 20,000 (query) * 100 (hsz) * 4 / (1024**2) = 7.63 MB
|
55 |
+
"""
|
56 |
+
model.eval()
|
57 |
+
eval_dataset.set_data_mode("query")
|
58 |
+
eval_dataset.load_gt_vid_name_for_query(load_gt_vid_name)
|
59 |
+
query_eval_loader = DataLoader(eval_dataset,
|
60 |
+
collate_fn=proposal_retrieval_collate,
|
61 |
+
batch_size=opt.eval_query_bsz,
|
62 |
+
num_workers=opt.num_workers,
|
63 |
+
shuffle=False,
|
64 |
+
pin_memory=opt.pin_memory)
|
65 |
+
global_meta_list = [] # list(dicts)
|
66 |
+
# n_query = min(len(eval_dataset), opt.eval_query_bsz) if opt.debug else len(eval_dataset)
|
67 |
+
n_query = len(eval_dataset)
|
68 |
+
global_query_embedding = torch.empty((n_query,
|
69 |
+
model.config.output_size),
|
70 |
+
dtype=torch.float32, device=opt.device) # (N_q, D_o)
|
71 |
+
for idx, batch in tqdm(enumerate(query_eval_loader),
|
72 |
+
desc="Computing q embedding",
|
73 |
+
total=len(query_eval_loader)):
|
74 |
+
global_meta_list.extend(batch[0])
|
75 |
+
model_inputs = prepare_batch_inputs(batch[1], device=opt.device, non_blocking=opt.pin_memory)
|
76 |
+
global_query_embedding[idx * opt.eval_query_bsz: (idx + 1) * opt.eval_query_bsz] = \
|
77 |
+
model.query_encoder(**model_inputs)
|
78 |
+
|
79 |
+
if opt.debug:
|
80 |
+
break
|
81 |
+
return global_meta_list, global_query_embedding
|
82 |
+
|
83 |
+
|
84 |
+
def compute_proposal_embeddings(model, eval_dataset, opt):
|
85 |
+
"""Use val set to do evaluation, remember to run with torch.no_grad().
|
86 |
+
estimated 1000 (videos) * 300 (proposals) * 20 (clips) * 100 (hsz) * 4 / (1024 ** 3) = 2.24 GB
|
87 |
+
"""
|
88 |
+
model.eval()
|
89 |
+
eval_dataset.set_data_mode("context")
|
90 |
+
global_meta_list = [] # list(dicts)
|
91 |
+
global_proposal_video_embedding_list = [] # list(torch.tensor), N_videos * [N_prop, N_clips, D_o]
|
92 |
+
global_proposal_sub_embedding_list = [] # list(torch.tensor), N_videos * [N_prop, N_clips, D_o]
|
93 |
+
global_proposal_video_mask_list = [] # list(torch.tensor), N_videos * [N_prop, N_clips]
|
94 |
+
global_proposal_sub_mask_list = [] # list(torch.tensor), N_videos * [N_prop, N_clips]
|
95 |
+
for idx, single_video_info in tqdm(enumerate(eval_dataset),
|
96 |
+
desc="Computing prop embedding for videos",
|
97 |
+
total=len(eval_dataset)):
|
98 |
+
global_meta_list.append(single_video_info["meta"])
|
99 |
+
if model.use_video or model.tef_only:
|
100 |
+
proposals_features_list = single_video_info["model_inputs"]["video_moment_features_list"]
|
101 |
+
proposals_mask_list = single_video_info["model_inputs"]["video_moment_mask_list"]
|
102 |
+
proposals_mask_list = [e.to(opt.device, non_blocking=opt.pin_memory) for e in proposals_mask_list]
|
103 |
+
proposals_embedding_list = [] # (N_prop, D_o)
|
104 |
+
for feat in proposals_features_list:
|
105 |
+
proposals_embedding_list.append(
|
106 |
+
model.moment_encoder(feat.to(opt.device, non_blocking=opt.pin_memory), module_name="video"))
|
107 |
+
p, m = combine_single_video_proposal_embeddings(proposals_embedding_list, proposals_mask_list)
|
108 |
+
global_proposal_video_embedding_list.append(p)
|
109 |
+
global_proposal_video_mask_list.append(m)
|
110 |
+
else:
|
111 |
+
global_proposal_video_embedding_list.append(None)
|
112 |
+
|
113 |
+
if model.use_sub:
|
114 |
+
proposals_features_list = single_video_info["model_inputs"]["sub_moment_features_list"]
|
115 |
+
proposals_mask_list = single_video_info["model_inputs"]["sub_moment_mask_list"]
|
116 |
+
proposals_mask_list = [e.to(opt.device, non_blocking=opt.pin_memory) for e in proposals_mask_list]
|
117 |
+
proposals_embedding_list = [] # (N_prop, D_o)
|
118 |
+
for feat in proposals_features_list:
|
119 |
+
proposals_embedding_list.append(
|
120 |
+
model.moment_encoder(feat.to(opt.device, non_blocking=opt.pin_memory), module_name="sub"))
|
121 |
+
p, m = combine_single_video_proposal_embeddings(proposals_embedding_list, proposals_mask_list)
|
122 |
+
global_proposal_sub_embedding_list.append(p)
|
123 |
+
global_proposal_sub_mask_list.append(m)
|
124 |
+
else:
|
125 |
+
global_proposal_sub_embedding_list.append(None)
|
126 |
+
|
127 |
+
if opt.debug and idx == 100:
|
128 |
+
break
|
129 |
+
global_proposal_mask_list = global_proposal_sub_mask_list if model.use_sub else global_proposal_video_mask_list
|
130 |
+
return global_meta_list, global_proposal_video_embedding_list, \
|
131 |
+
global_proposal_sub_embedding_list, global_proposal_mask_list
|
132 |
+
|
133 |
+
|
134 |
+
def compute_query_proposal_distance(model, eval_dataset, opt, tasks=("SVMR",)):
|
135 |
+
"""compute and save query and video proposal embeddings,
|
136 |
+
tasks: SVMR (single video moment retrieval), VCMR (video corpus moment retrieval)
|
137 |
+
"""
|
138 |
+
is_svmr = "SVMR" in tasks
|
139 |
+
is_vcmr = "VCMR" in tasks
|
140 |
+
query_meta_list, query_embed = compute_query_embeddings(model, eval_dataset, opt,
|
141 |
+
load_gt_vid_name=is_svmr)
|
142 |
+
video_meta_list, video_prop_embed_list, sub_prop_embed_list, prop_mask_list = \
|
143 |
+
compute_proposal_embeddings(model, eval_dataset, opt)
|
144 |
+
|
145 |
+
eval_res = dict(
|
146 |
+
query_meta=query_meta_list, # N_q * dict()
|
147 |
+
video_meta=video_meta_list, # N_videos * dict()
|
148 |
+
video2idx=eval_dataset.video2idx, # dict {vid_name: index}
|
149 |
+
query_prop_dist_vcmr=[], # N_videos * (N_q, N_prop), note N_prop is changing for each video.
|
150 |
+
query_prop_dist_svmr=[], # N_q * (N_prop, ), each query has a GT video, no need to calc. for all.
|
151 |
+
)
|
152 |
+
if is_vcmr:
|
153 |
+
for v_prop_embed, s_prop_embed, prop_mask in tqdm(
|
154 |
+
zip(video_prop_embed_list, sub_prop_embed_list, prop_mask_list),
|
155 |
+
desc="Computing VCMR q to prop dist for videos",
|
156 |
+
total=len(video_prop_embed_list)):
|
157 |
+
query_prop_dist = model.compute_cdist_inference(
|
158 |
+
query_embed, v_prop_embed, s_prop_embed, prop_mask) # (N_q, N_prop)
|
159 |
+
eval_res["query_prop_dist_vcmr"].append(query_prop_dist.cpu())
|
160 |
+
if opt.debug:
|
161 |
+
break
|
162 |
+
|
163 |
+
if is_svmr:
|
164 |
+
if opt.debug:
|
165 |
+
debug_query_meta = []
|
166 |
+
# this is different from video2idx
|
167 |
+
svmr_video2meta_idx = {e["vid_name"]: idx for idx, e in enumerate(video_meta_list)}
|
168 |
+
# logger.info("svmr_video2idx {}".format(list(svmr_video2idx.keys())[:3]))
|
169 |
+
for single_q_embed, single_q_meta in tqdm(zip(query_embed, query_meta_list),
|
170 |
+
desc="Computing SVMR q to prop dist for videos",
|
171 |
+
total=len(query_embed)):
|
172 |
+
# logger.info("single_q_meta[vid_name] {}".format(single_q_meta["vid_name"]))
|
173 |
+
if opt.debug:
|
174 |
+
if single_q_meta["vid_name"] not in svmr_video2meta_idx:
|
175 |
+
continue
|
176 |
+
debug_query_meta.append(single_q_meta)
|
177 |
+
q_gt_vid_meta_idx = svmr_video2meta_idx[single_q_meta["vid_name"]]
|
178 |
+
v_prop_embed = video_prop_embed_list[q_gt_vid_meta_idx] # [N_prop, N_clips, D_o]
|
179 |
+
s_prop_embed = sub_prop_embed_list[q_gt_vid_meta_idx] # [N_prop, N_clips, D_o]
|
180 |
+
prop_mask = prop_mask_list[q_gt_vid_meta_idx] # [N_prop, N_clips]
|
181 |
+
query_prop_dist = model.compute_cdist_inference(
|
182 |
+
single_q_embed.unsqueeze(0), v_prop_embed, s_prop_embed, prop_mask) # (1, N_prop)
|
183 |
+
eval_res["query_prop_dist_svmr"].append(query_prop_dist.squeeze(0).cpu().numpy())
|
184 |
+
if opt.debug:
|
185 |
+
eval_res["query_meta"] = debug_query_meta
|
186 |
+
return eval_res
|
187 |
+
|
188 |
+
|
189 |
+
def filter_vcmr_by_nms(all_video_predictions, nms_threshold=0.6,
|
190 |
+
max_before_nms=1000, max_after_nms=100, score_col_idx=3):
|
191 |
+
""" Apply non-maximum suppression for all the predictions for each video.
|
192 |
+
1) group predictions by video index
|
193 |
+
2) apply nms individually for each video index group
|
194 |
+
3) combine and sort the predictions
|
195 |
+
Args:
|
196 |
+
all_video_predictions: list(sublist),
|
197 |
+
Each sublist is [video_idx (int), st (float), ed(float), score (float)]
|
198 |
+
Note the scores are negative distances.
|
199 |
+
nms_threshold: float
|
200 |
+
max_before_nms: int
|
201 |
+
max_after_nms: int
|
202 |
+
score_col_idx: int
|
203 |
+
Returns:
|
204 |
+
|
205 |
+
"""
|
206 |
+
predictions_neg_by_video_group = defaultdict(list)
|
207 |
+
for pred in all_video_predictions[:max_before_nms]:
|
208 |
+
predictions_neg_by_video_group[pred[0]].append(pred[1:]) # [st (float), ed(float), score (float)]
|
209 |
+
|
210 |
+
predictions_by_video_group_neg_after_nms = dict()
|
211 |
+
for video_idx, grouped_preds in predictions_neg_by_video_group.items():
|
212 |
+
predictions_by_video_group_neg_after_nms[video_idx] = \
|
213 |
+
temporal_non_maximum_suppression(grouped_preds, nms_threshold=nms_threshold)
|
214 |
+
|
215 |
+
predictions_after_nms = []
|
216 |
+
for video_idx, grouped_preds in predictions_by_video_group_neg_after_nms.items():
|
217 |
+
for pred in grouped_preds:
|
218 |
+
pred = [video_idx] + pred # [video_idx (int), st (float), ed(float), score (float)]
|
219 |
+
predictions_after_nms.append(pred)
|
220 |
+
|
221 |
+
# ranking happens across videos
|
222 |
+
predictions_after_nms = sorted(predictions_after_nms,
|
223 |
+
key=lambda x: x[score_col_idx],
|
224 |
+
reverse=True)[:max_after_nms] # descending order
|
225 |
+
return predictions_after_nms
|
226 |
+
|
227 |
+
|
228 |
+
def post_processing_vcmr_nms(vcmr_res, nms_thd=0.6, max_before_nms=1000, max_after_nms=100):
|
229 |
+
"""
|
230 |
+
vcmr_res: list(dict), each dict is{
|
231 |
+
"desc": str,
|
232 |
+
"desc_id": int,
|
233 |
+
"predictions": list(sublist) # each sublist is
|
234 |
+
[video_idx (int), st (float), ed(float), score (float)], video_idx could be different
|
235 |
+
}
|
236 |
+
"""
|
237 |
+
processed_vcmr_res = []
|
238 |
+
for e in vcmr_res:
|
239 |
+
e["predictions"] = filter_vcmr_by_nms(e["predictions"],
|
240 |
+
nms_threshold=nms_thd,
|
241 |
+
max_before_nms=max_before_nms,
|
242 |
+
max_after_nms=max_after_nms)
|
243 |
+
processed_vcmr_res.append(e)
|
244 |
+
return processed_vcmr_res
|
245 |
+
|
246 |
+
|
247 |
+
def post_processing_svmr_nms(svmr_res, nms_thd=0.6, max_before_nms=1000, max_after_nms=100):
|
248 |
+
"""
|
249 |
+
svmr_res: list(dict), each dict is
|
250 |
+
{"desc": str,
|
251 |
+
"desc_id": int,
|
252 |
+
"predictions": list(sublist) # each sublist is
|
253 |
+
[video_idx (int), st (float), ed(float), score (float)], video_idx is the same.
|
254 |
+
}
|
255 |
+
"""
|
256 |
+
processed_svmr_res = []
|
257 |
+
for e in svmr_res:
|
258 |
+
# the predictions are sorted inside the nms func.
|
259 |
+
_predictions = [d[1:] for d in e["predictions"][:max_before_nms]]
|
260 |
+
_predictions = temporal_non_maximum_suppression(
|
261 |
+
_predictions, nms_threshold=nms_thd)[:max_after_nms]
|
262 |
+
_video_id = e["predictions"][0][0] # video_id is the same for all predictions
|
263 |
+
e["predictions"] = [[_video_id, ] + d for d in _predictions]
|
264 |
+
processed_svmr_res.append(e)
|
265 |
+
return processed_svmr_res
|
266 |
+
|
267 |
+
|
268 |
+
def generate_vcmr_predictions_from_res_with_external(eval_res, max_prop_per_query=300, query_bsz_in_sort=1000):
|
269 |
+
""" This function is for Video Corpus Moment Retrieval (VCMR).
|
270 |
+
Generate prediction file which could be evaluated using standalone_eval.eval.
|
271 |
+
Args:
|
272 |
+
eval_res: dict(
|
273 |
+
query_meta=query_meta_list, # N_q * dict(), each dict is {"desc_id": int, "desc": str}
|
274 |
+
video_meta=video_meta_list, # N_videos * dict(), {"vid_name": str, "duration": float, "proposals": ndarray}
|
275 |
+
video2idx=eval_dataset.video2idx, # dict {vid_name: index}
|
276 |
+
video_bsz_in_sort=[], # N_videos * (N_q, N_prop)
|
277 |
+
)
|
278 |
+
max_prop_per_query: int or None. If None, generate ranking for all possible moments, else generate top {}.
|
279 |
+
query_bsz_in_sort: int, only sort a subset of queries at a time, it will be too large to sort all queries.
|
280 |
+
return:
|
281 |
+
list(dicts): each dict is dict(desc=str, desc_id=int, predictions=list(sublist)),
|
282 |
+
each sublist is [vid_name (str), st (float), ed (float), score (float)], score is negative distance.
|
283 |
+
"""
|
284 |
+
# video2idx
|
285 |
+
video2idx = eval_res["video2idx"]
|
286 |
+
video_meta = eval_res["video_meta"]
|
287 |
+
query_meta = eval_res["query_meta"]
|
288 |
+
video_idx2meta_idx = {video2idx[m["vid_name"]]: i for i, m in enumerate(video_meta)}
|
289 |
+
external_query2video = eval_res["external_query2video"] if "external_query2video" in eval_res else None
|
290 |
+
# 「query idx: [video meta idx]」
|
291 |
+
external_query2video_meta_idx = {k: [video_idx2meta_idx[e] for e in v] for k, v in external_query2video.items()}
|
292 |
+
|
293 |
+
external_ordered_video_meta_indices = torch.LongTensor(
|
294 |
+
[external_query2video_meta_idx[e["desc_id"]] for e in query_meta]) # (Nq, 5)
|
295 |
+
top_n_retrieved = external_ordered_video_meta_indices.shape[1]
|
296 |
+
|
297 |
+
# (N_videos, N_prop, N_q), (N_videos, N_prop)
|
298 |
+
padded_dist, padded_mask = pad_sequences_1d([e.transpose(0, 1) for e in eval_res["query_prop_dist_vcmr"]],
|
299 |
+
dtype=eval_res["query_prop_dist_vcmr"][0].dtype,
|
300 |
+
device=eval_res["query_prop_dist_vcmr"][0].device)
|
301 |
+
# putting 'NaN' into the invalid bits, torch.sort considers 'NaN' as larger than any number!!!
|
302 |
+
padded_dist += (padded_mask.unsqueeze(2) == 0).float() * 1e10
|
303 |
+
n_videos, n_prop, n_q = padded_dist.shape
|
304 |
+
padded_dist = padded_dist.permute(2, 0, 1) # (N_q, N_videos, N_prop)
|
305 |
+
|
306 |
+
# get only top retrieved, N_videos now decreased to top_n_retrieved
|
307 |
+
row_indices = torch.arange(n_q, device=padded_dist.device)
|
308 |
+
padded_dist = torch.stack([
|
309 |
+
padded_dist[row_indices, external_ordered_video_meta_indices[:, col_idx]]
|
310 |
+
for col_idx in range(top_n_retrieved)], dim=1) # (N_q, 5, N_prop)
|
311 |
+
n_videos = top_n_retrieved
|
312 |
+
|
313 |
+
padded_dist = padded_dist.view(n_q, -1).contiguous() # (N_q, N_video*N_prop)
|
314 |
+
print("n_videos, n_prop, n_q {}".format((n_videos, n_prop, n_q)))
|
315 |
+
print("padded_dist, {}".format(padded_dist.shape))
|
316 |
+
|
317 |
+
sorted_distances, sorted_indices = torch.topk(padded_dist.to(torch.device("cuda:0"), non_blocking=True),
|
318 |
+
k=min(max_prop_per_query, n_videos * n_prop),
|
319 |
+
dim=1, largest=False, sorted=True) # (N_q, max_prop_per_query) * 2
|
320 |
+
print("orted_distances {}, sorted_indices {}".format(sorted_distances.shape, sorted_indices.shape))
|
321 |
+
sorted_distances = - sorted_distances.cpu().numpy()
|
322 |
+
|
323 |
+
# (N_q, max_prop_per_query) * 2, prop_indices: inside video indices.
|
324 |
+
video_meta_indices_retrieved = torch.floor(sorted_indices.float() / n_prop).long().cpu().numpy()
|
325 |
+
# map back to original video idx (not video meta idx, but real video idx)
|
326 |
+
video_indices = np.array([[external_query2video[query_meta[i]["desc_id"]][j] for j in r]
|
327 |
+
for i, r in enumerate(video_meta_indices_retrieved)]) # (N_q, max_prop_per_query)
|
328 |
+
prop_indices = torch.remainder(sorted_indices, n_prop).cpu().numpy() # (N_q, max_prop_per_query)
|
329 |
+
print("video_indices {}, prop_indices {}".format(video_indices.shape, prop_indices.shape))
|
330 |
+
|
331 |
+
vr_res = []
|
332 |
+
for i in trange(n_q, desc="[VR] Loop over queries to generate predictions"):
|
333 |
+
row = video_indices[i]
|
334 |
+
score_row = - sorted_distances[i]
|
335 |
+
cur_vr_redictions = []
|
336 |
+
for j, video_idx in enumerate(row):
|
337 |
+
cur_vr_redictions.append([int(video_idx), 0, 0, float(score_row[j])])
|
338 |
+
cur_query_pred = dict(
|
339 |
+
desc_id=query_meta[i]["desc_id"],
|
340 |
+
desc=query_meta[i]["desc"],
|
341 |
+
predictions=cur_vr_redictions
|
342 |
+
)
|
343 |
+
vr_res.append(cur_query_pred)
|
344 |
+
|
345 |
+
vcmr_res = []
|
346 |
+
logger.debug("sorted_indices {}".format(sorted_indices.shape))
|
347 |
+
logger.debug("sorted_distances {}".format(sorted_distances.shape))
|
348 |
+
out_bounds_cnt = 0
|
349 |
+
for idx, (v_row_indices, p_row_indices) in tqdm(enumerate(zip(video_indices, prop_indices)),
|
350 |
+
desc="[VCMR] Loop over queries to generate predictions",
|
351 |
+
total=n_q): # query
|
352 |
+
sorted_distances_row = - sorted_distances[idx] # converted to negative distance
|
353 |
+
# [video_idx(int), st(float), ed(float), score(float)]
|
354 |
+
cur_ranked_predictions = []
|
355 |
+
for col_idx, (v_col_idx, p_col_idx) in enumerate(zip(v_row_indices, p_row_indices)):
|
356 |
+
cur_proposals = eval_res["video_meta"][video_idx2meta_idx[v_col_idx]]["proposals"]
|
357 |
+
cur_pred = []
|
358 |
+
cur_pred += [int(v_col_idx), ]
|
359 |
+
# what is wrong with the indexing below??? (out of bounds), but results seems fine???
|
360 |
+
# Not a bug. Since there might be less than max_before_nms proposals from the top retrieved videos
|
361 |
+
if p_col_idx >= len(cur_proposals):
|
362 |
+
out_bounds_cnt += 1
|
363 |
+
p_col_idx = len(cur_proposals)-1
|
364 |
+
cur_pred += cur_proposals[p_col_idx].tolist()
|
365 |
+
cur_pred += [float(sorted_distances_row[col_idx])]
|
366 |
+
cur_ranked_predictions.append(cur_pred)
|
367 |
+
cur_query_pred = dict(
|
368 |
+
desc_id=eval_res["query_meta"][idx]["desc_id"],
|
369 |
+
desc=eval_res["query_meta"][idx]["desc"],
|
370 |
+
predictions=cur_ranked_predictions
|
371 |
+
)
|
372 |
+
vcmr_res.append(cur_query_pred)
|
373 |
+
logger.info("[DEBUG] out_bounds_cnt {}".format(out_bounds_cnt))
|
374 |
+
return vcmr_res, vr_res
|
375 |
+
|
376 |
+
|
377 |
+
def generate_vcmr_predictions_from_res(eval_res, max_prop_per_query=300, query_bsz_in_sort=1000):
|
378 |
+
""" This function is for Video Corpus Moment Retrieval (VCMR).
|
379 |
+
Generate prediction file which could be evaluated using standalone_eval.eval.
|
380 |
+
Args:
|
381 |
+
eval_res: dict(
|
382 |
+
query_meta=query_meta_list, # N_q * dict(), each dict is {"desc_id": int, "desc": str}
|
383 |
+
video_meta=video_meta_list, # N_videos * dict(), {"vid_name": str, "duration": float, "proposals": ndarray}
|
384 |
+
video2idx=eval_dataset.video2idx, # dict {vid_name: index}
|
385 |
+
video_bsz_in_sort=[], # N_videos * (N_q, N_prop)
|
386 |
+
)
|
387 |
+
max_prop_per_query: int or None. If None, generate ranking for all possible moments, else generate top {}.
|
388 |
+
query_bsz_in_sort: int, only sort a subset of queries at a time, it will be too large to sort all queries.
|
389 |
+
return:
|
390 |
+
list(dicts): each dict is dict(desc=str, desc_id=int, predictions=list(sublist)),
|
391 |
+
each sublist is [vid_name (str), st (float), ed (float), score (float)], score is negative distance.
|
392 |
+
"""
|
393 |
+
# video2idx
|
394 |
+
video2idx = eval_res["video2idx"]
|
395 |
+
|
396 |
+
# (N_videos, N_prop, N_q), (N_videos, N_prop)
|
397 |
+
padded_dist, padded_mask = pad_sequences_1d([e.transpose(0, 1) for e in eval_res["query_prop_dist_vcmr"]],
|
398 |
+
dtype=eval_res["query_prop_dist_vcmr"][0].dtype,
|
399 |
+
device=eval_res["query_prop_dist_vcmr"][0].device)
|
400 |
+
# putting 'NaN' into the invalid bits, torch.sort considers 'NaN' as larger than any number!!!
|
401 |
+
padded_dist += (padded_mask.unsqueeze(2) == 0).float() * 1e10
|
402 |
+
n_videos, n_prop, n_q = padded_dist.shape
|
403 |
+
print("n_videos, n_prop, n_q {}".format((n_videos, n_prop, n_q)))
|
404 |
+
padded_dist = padded_dist.view(n_videos * n_prop, n_q).transpose(0, 1).contiguous() # (N_q, N_video*N_prop)
|
405 |
+
print("padded_dist, {}".format(padded_dist.shape))
|
406 |
+
|
407 |
+
sorted_distances, sorted_indices = torch.topk(padded_dist.to(torch.device("cuda:0"), non_blocking=True),
|
408 |
+
k=min(max_prop_per_query, n_videos * n_prop),
|
409 |
+
dim=1, largest=False, sorted=True) # (N_q, max_prop_per_query) * 2
|
410 |
+
sorted_distances = - sorted_distances.cpu().numpy()
|
411 |
+
|
412 |
+
# (N_q, max_prop_per_query) * 2, prop_indices: inside video indices.
|
413 |
+
video_meta_indices = torch.floor(sorted_indices.float() / n_prop).long().cpu().numpy()
|
414 |
+
prop_indices = torch.remainder(sorted_indices, n_prop).cpu().numpy()
|
415 |
+
|
416 |
+
vr_res = []
|
417 |
+
query_meta = eval_res["query_meta"]
|
418 |
+
for i in trange(n_q, desc="[VR] Loop over queries to generate predictions"):
|
419 |
+
row = video_meta_indices[i]
|
420 |
+
score_row = - sorted_distances[i]
|
421 |
+
cur_vr_redictions = []
|
422 |
+
for j, meta_idx in enumerate(row):
|
423 |
+
video_idx = video2idx[eval_res["video_meta"][meta_idx]["vid_name"]]
|
424 |
+
cur_vr_redictions.append([video_idx, 0, 0, float(score_row[j])])
|
425 |
+
cur_query_pred = dict(
|
426 |
+
desc_id=query_meta[i]["desc_id"],
|
427 |
+
desc=query_meta[i]["desc"],
|
428 |
+
predictions=cur_vr_redictions
|
429 |
+
)
|
430 |
+
vr_res.append(cur_query_pred)
|
431 |
+
|
432 |
+
vcmr_res = []
|
433 |
+
logger.debug("sorted_indices {}".format(sorted_indices.shape))
|
434 |
+
logger.debug("sorted_distances {}".format(sorted_distances.shape))
|
435 |
+
for idx, (vm_row_indices, p_row_indices) in tqdm(enumerate(zip(video_meta_indices, prop_indices)),
|
436 |
+
desc="[VCMR] Loop over queries to generate predictions",
|
437 |
+
total=n_q): # query
|
438 |
+
sorted_distances_row = - sorted_distances[idx] # converted to negative distance
|
439 |
+
# [video_idx(int), st(float), ed(float), score(float)]
|
440 |
+
cur_ranked_predictions = []
|
441 |
+
for col_idx, (v_col_idx, p_col_idx) in enumerate(zip(vm_row_indices, p_row_indices)):
|
442 |
+
cur_pred = []
|
443 |
+
cur_pred += [video2idx[eval_res["video_meta"][v_col_idx]["vid_name"]], ]
|
444 |
+
cur_pred += eval_res["video_meta"][v_col_idx]["proposals"][p_col_idx].tolist()
|
445 |
+
cur_pred += [float(sorted_distances_row[col_idx])]
|
446 |
+
cur_ranked_predictions.append(cur_pred)
|
447 |
+
cur_query_pred = dict(
|
448 |
+
desc_id=eval_res["query_meta"][idx]["desc_id"],
|
449 |
+
desc=eval_res["query_meta"][idx]["desc"],
|
450 |
+
predictions=cur_ranked_predictions
|
451 |
+
)
|
452 |
+
vcmr_res.append(cur_query_pred)
|
453 |
+
return vcmr_res, vr_res
|
454 |
+
|
455 |
+
|
456 |
+
def generate_svmr_predictions_from_res(eval_res, max_prop_per_query=None):
|
457 |
+
""" This function is for Video Corpus Moment Retrieval (VCMR).
|
458 |
+
Generate prediction file which could be evaluated using standalone_eval.eval.
|
459 |
+
Args:
|
460 |
+
eval_res: dict(
|
461 |
+
query_meta=query_meta_list, # N_q * dict(), each dict is {"desc_id": int, "desc": str}
|
462 |
+
video_meta=video_meta_list, # N_videos * dict(), {"vid_name": str, "duration": float, "proposals": ndarray}
|
463 |
+
video2idx=eval_dataset.video2idx, # dict {vid_name: index}
|
464 |
+
query_prop_dist_svmr=[], # N_q * (N_prop, )
|
465 |
+
)
|
466 |
+
max_prop_per_query: not used
|
467 |
+
return:
|
468 |
+
list(dicts): each dict is dict(desc=str, desc_id=int, predictions=list(sublist)),
|
469 |
+
each sublist is [vid_name (str), st (float), ed (float), score (float)], score is negative distance.
|
470 |
+
"""
|
471 |
+
video2idx = eval_res["video2idx"]
|
472 |
+
|
473 |
+
svmr_res = []
|
474 |
+
svmr_video2meta_idx = {e["vid_name"]: idx for idx, e in enumerate(eval_res["video_meta"])}
|
475 |
+
for idx, (q_p_dist, q_m) in tqdm(enumerate(zip(eval_res["query_prop_dist_svmr"], eval_res["query_meta"])),
|
476 |
+
desc="Loop over queries to generate predictions",
|
477 |
+
total=len(eval_res["query_prop_dist_svmr"])): # query
|
478 |
+
sorted_indices = np.argsort(q_p_dist) # (N_prop, ) # ascending order, distance
|
479 |
+
if max_prop_per_query is not None:
|
480 |
+
sorted_indices = sorted_indices[:max_prop_per_query]
|
481 |
+
v_eval_idx = video2idx[q_m["vid_name"]]
|
482 |
+
v_meta_idx = svmr_video2meta_idx[q_m["vid_name"]]
|
483 |
+
proposals = eval_res["video_meta"][v_meta_idx]["proposals"] # (N_p, 2)
|
484 |
+
# [video_idx(int), st(float), ed(float), score(float)]
|
485 |
+
cur_ranked_predictions = [
|
486 |
+
[v_eval_idx, ] + proposals[sort_idx].tolist() + [- round(float(q_p_dist[sort_idx]), 4), ]
|
487 |
+
for sort_idx in sorted_indices]
|
488 |
+
cur_query_pred = dict(
|
489 |
+
desc_id=q_m["desc_id"],
|
490 |
+
desc=q_m["desc"],
|
491 |
+
predictions=cur_ranked_predictions
|
492 |
+
)
|
493 |
+
svmr_res.append(cur_query_pred)
|
494 |
+
return svmr_res
|
495 |
+
|
496 |
+
|
497 |
+
POST_PROCESSING_MMS_FUNC = {
|
498 |
+
"SVMR": post_processing_svmr_nms,
|
499 |
+
"VCMR": post_processing_vcmr_nms
|
500 |
+
}
|
501 |
+
|
502 |
+
|
503 |
+
def get_submission_top_n(submission, top_n=100):
|
504 |
+
def get_prediction_top_n(list_dict_predictions, top_n):
|
505 |
+
top_n_res = []
|
506 |
+
for e in list_dict_predictions:
|
507 |
+
e["predictions"] = e["predictions"][:top_n]
|
508 |
+
top_n_res.append(e)
|
509 |
+
return top_n_res
|
510 |
+
|
511 |
+
top_n_submission = dict(video2idx=submission["video2idx"], )
|
512 |
+
for k in submission:
|
513 |
+
if k != "video2idx":
|
514 |
+
top_n_submission[k] = get_prediction_top_n(submission[k], top_n)
|
515 |
+
return top_n_submission
|
516 |
+
|
517 |
+
|
518 |
+
def load_external_vr_res(external_vr_res_path, top_n_vr_videos=5):
|
519 |
+
"""return a mapping from desc_id to top retrieved video id"""
|
520 |
+
external_vr_res = load_json(external_vr_res_path)
|
521 |
+
external_vr_res = get_submission_top_n(external_vr_res, top_n=top_n_vr_videos)["VR"]
|
522 |
+
query2video = {e["desc_id"]: [sub_e[0] for sub_e in e["predictions"]] for e in external_vr_res}
|
523 |
+
return query2video
|
524 |
+
|
525 |
+
|
526 |
+
def eval_epoch(model, eval_dataset, opt, save_submission_filename,
|
527 |
+
tasks=("SVMR",), max_before_nms=1000, max_after_nms=100):
|
528 |
+
model.eval()
|
529 |
+
logger.info("Computing scores")
|
530 |
+
logger.info("Start timing")
|
531 |
+
# times = [] # do not use
|
532 |
+
# for _ in range(3):
|
533 |
+
# st_time = time.time()
|
534 |
+
if opt.use_intermediate:
|
535 |
+
intermediate_cache_path = os.path.join(opt.results_dir, "{}_eval_res.pt".format(opt.eval_split_name))
|
536 |
+
if not os.path.exists(intermediate_cache_path):
|
537 |
+
logger.info("Saving intermediate results {}.".format(intermediate_cache_path))
|
538 |
+
eval_res = compute_query_proposal_distance(model, eval_dataset, opt, tasks=tasks)
|
539 |
+
torch.save(eval_res, intermediate_cache_path)
|
540 |
+
else:
|
541 |
+
logger.info("Loading intermediate results {}.".format(intermediate_cache_path))
|
542 |
+
eval_res = torch.load(intermediate_cache_path)
|
543 |
+
else:
|
544 |
+
logger.info("Running without saving intermediate results, you might want to turn on --use_intermediate.")
|
545 |
+
eval_res = compute_query_proposal_distance(model, eval_dataset, opt, tasks=tasks)
|
546 |
+
# del model # We dont need model anymore
|
547 |
+
|
548 |
+
# eval_res = compute_query_proposal_distance(model, eval_dataset, opt, tasks=tasks)
|
549 |
+
|
550 |
+
logger.info("Generating predictions from scores")
|
551 |
+
eval_submission_raw = dict(video2idx=eval_res["video2idx"])
|
552 |
+
if "SVMR" in tasks:
|
553 |
+
eval_submission_raw["SVMR"] = generate_svmr_predictions_from_res(
|
554 |
+
eval_res, max_prop_per_query=max_before_nms)
|
555 |
+
# vcmr_loading_time = 0
|
556 |
+
if "VCMR" in tasks:
|
557 |
+
if opt.external_inference_vr_res_path is not None:
|
558 |
+
logger.info("Using external VR results from {}".format(opt.external_inference_vr_res_path))
|
559 |
+
# vcmr_loading_time = time.time()
|
560 |
+
eval_res["external_query2video"] = load_external_vr_res(
|
561 |
+
opt.external_inference_vr_res_path, top_n_vr_videos=5)
|
562 |
+
# vcmr_loading_time = time.time() - vcmr_loading_time
|
563 |
+
vcmr_res, vr_res = generate_vcmr_predictions_from_res_with_external(
|
564 |
+
eval_res, max_prop_per_query=max_before_nms)
|
565 |
+
else:
|
566 |
+
vcmr_res, vr_res = generate_vcmr_predictions_from_res(
|
567 |
+
eval_res, max_prop_per_query=max_before_nms)
|
568 |
+
eval_submission_raw["VCMR"] = vcmr_res
|
569 |
+
eval_submission_raw["VR"] = vr_res
|
570 |
+
# times += [time.time() - st_time - vcmr_loading_time]
|
571 |
+
# times = torch.FloatTensor(times)
|
572 |
+
IOU_THDS = (0.5, 0.7)
|
573 |
+
|
574 |
+
logger.info("Saving/Evaluating before nms results")
|
575 |
+
    submission_path = os.path.join(opt.results_dir, save_submission_filename)
    eval_submission = get_submission_top_n(eval_submission_raw, top_n=max_after_nms)
    if max_after_nms < 1000:
        save_json(eval_submission, submission_path)
    else:
        torch.save(eval_submission, submission_path.replace(".json", ".pt"))

    metrics = eval_retrieval(eval_submission, eval_dataset.query_data,
                             iou_thds=IOU_THDS, match_number=not opt.debug, verbose=opt.debug,
                             use_desc_type=opt.dset_name == "tvr")
    # metrics["time_avg"] = float(times.mean())
    # metrics["time_std"] = float(times.std())
    save_metrics_path = submission_path.replace(".json", "_metrics.json")
    save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False)
    latest_file_paths = [submission_path, save_metrics_path]

    if opt.nms_thd != -1:
        logger.info("Performing nms with nms_thd {}".format(opt.nms_thd))
        eval_submission_after_nms = dict(video2idx=eval_submission_raw["video2idx"])
        for k, nms_func in POST_PROCESSING_MMS_FUNC.items():
            if k in eval_submission_raw:
                eval_submission_after_nms[k] = nms_func(eval_submission_raw[k],
                                                        nms_thd=opt.nms_thd,
                                                        max_before_nms=max_before_nms,
                                                        max_after_nms=max_after_nms)

        logger.info("Saving/Evaluating nms results")
        submission_nms_path = submission_path.replace(".json", "_nms_thd_{}.json".format(opt.nms_thd))
        save_json(eval_submission_after_nms, submission_nms_path)
        metrics_nms = eval_retrieval(eval_submission_after_nms, eval_dataset.query_data,
                                     iou_thds=IOU_THDS, match_number=not opt.debug, verbose=opt.debug)
        save_metrics_nms_path = submission_nms_path.replace(".json", "_metrics.json")
        save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False)
        latest_file_paths += [submission_nms_path, save_metrics_nms_path]
    else:
        metrics_nms = None
    return metrics, metrics_nms, latest_file_paths


def setup_model(opt):
    """Load model from checkpoint and move to specified device"""
    checkpoint = torch.load(opt.ckpt_filepath)
    model = CALWithSub(checkpoint["model_cfg"])
    model.load_state_dict(checkpoint["model"])
    logger.info("Loaded model saved at epoch {} from checkpoint: {}"
                .format(checkpoint["epoch"], opt.ckpt_filepath))

    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)
        if len(opt.device_ids) > 1:
            logger.info("Use multi GPU: {}".format(opt.device_ids))
            model = torch.nn.DataParallel(model, device_ids=opt.device_ids)  # use multi GPU
    return model


def start_inference():
    logger.info("Setup config, data and model...")
    opt = TestOptions().parse()
    cudnn.benchmark = False
    cudnn.deterministic = True

    assert opt.eval_path is not None
    eval_dataset = ProposalRetrievalEvalDataset(
        dset_name=opt.dset_name,
        model_type=opt.model_type,
        eval_split_name=opt.eval_split_name,  # should only be val set
        data_path=opt.eval_path,
        desc_bert_path_or_handler=opt.desc_bert_path,
        sub_bert_path_or_handler=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        corpus_path=opt.corpus_path,
        vid_feat_path_or_handler=opt.vid_feat_path,
        clip_length=opt.clip_length,
        eval_proposal_bsz=opt.eval_proposal_bsz,
        ctx_mode=opt.ctx_mode,
        data_mode="query",
        h5driver=opt.h5driver,
        data_ratio=opt.data_ratio,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat,
    )

    model = setup_model(opt)
    save_submission_filename = \
        "inference_{}_{}_{}_predictions_{}.json".format(
            opt.dset_name, opt.eval_split_name, opt.eval_id, "_".join(opt.tasks))
    logger.info("Starting inference...")
    with torch.no_grad():
        metrics_no_nms, metrics_nms, latest_file_paths = \
            eval_epoch(model, eval_dataset, opt, save_submission_filename, tasks=opt.tasks,
                       max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms)
    logger.info("metrics_no_nms \n{}".format(pprint.pformat(metrics_no_nms, indent=4)))
    logger.info("metrics_nms \n{}".format(pprint.pformat(metrics_nms, indent=4)))


if __name__ == '__main__':
    start_inference()
|
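Note: `setup_model` above expects a checkpoint dict with exactly three keys ("model", "model_cfg", "epoch"). The following is only a minimal sketch of how a compatible checkpoint could be written; the function name `save_cal_checkpoint` and its arguments are illustrative assumptions, not code from this repository.

import torch

def save_cal_checkpoint(model, model_cfg, epoch, ckpt_filepath):
    # setup_model() reads exactly these three keys
    checkpoint = {
        "model": model.state_dict(),  # parameters restored via load_state_dict
        "model_cfg": model_cfg,       # config object passed to CALWithSub(...)
        "epoch": epoch,               # only used for the log message
    }
    torch.save(checkpoint, ckpt_filepath)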
baselines/clip_alignment_with_language/local_utils/__init__.py
ADDED
File without changes
baselines/clip_alignment_with_language/local_utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (217 Bytes).
baselines/clip_alignment_with_language/local_utils/__pycache__/compute_proposal_upper_bound.cpython-311.pyc
ADDED
Binary file (8.16 kB).
baselines/clip_alignment_with_language/local_utils/__pycache__/proposal.cpython-311.pyc
ADDED
Binary file (7.9 kB).
baselines/clip_alignment_with_language/local_utils/compute_proposal_upper_bound.py
ADDED
@@ -0,0 +1,117 @@
"""
Compute oracle upper bound for a given proposal method, which acts like
a reversed recall, where we recall the GT timestamp pairs in the set of
generated proposals.
"""
import pprint
import numpy as np
from tqdm import tqdm
from collections import Counter
from utils.basic_utils import load_jsonl, save_json
from standalone_eval.eval import compute_temporal_iou_batch
from baselines.clip_alignment_with_language.local_utils.proposal import get_proposal_interface, ProposalConfigs


def get_didemo_agreed_ts(times_list):
    """
    input example: [[1, 1], [1, 1], [1, 1], [0, 0]],
    return: [1, 1]"""
    times_str_list = [tuple(e) for e in times_list]
    times_str_list_counter = Counter(times_str_list)
    most_frequent_times = times_str_list_counter.most_common(1)[0][0]
    return most_frequent_times


def get_proposals_for_single_desc_video_pair(single_data, proposal_fn, dset_name):
    proposal_info = dict(
        vid_name=single_data["vid_name"],
        desc_id=single_data["desc_id"],
        gt_ts=single_data["ts"] if dset_name != "didemo" else get_didemo_agreed_ts(single_data["ts"]),
        proposals=proposal_fn(video_id="", metadata={"duration": single_data["duration"]}),
    )
    proposal_info["proposal_ious"] = compute_temporal_iou_batch(
        proposal_info["proposals"], proposal_info["gt_ts"])
    return proposal_info


def get_proposals_for_videos(datalist, dset_name):
    """datalist list(dict): each dict is
    {"desc_id": str/int, "duration": float, "ts": [st (float), ed (float)], ...}
    Note for the DiDeMo dataset, the "ts" entry is a list of [st (float), ed (float)] from different annotators;
    here we use the most frequent ts and break ties by randomly sampling one.
    """
    proposal_interface = get_proposal_interface(dset_name)
    video_proposals_list = []
    for e in tqdm(datalist, desc="Computing video proposals"):
        video_proposals_list.append(
            get_proposals_for_single_desc_video_pair(e, proposal_interface, dset_name))
    return video_proposals_list


def is_recalled_single_moment(proposal_ious, iou_thds=(0.5, 0.7)):
    """
    Args:
        proposal_ious: np.ndarray, shape (N_proposal, )
        iou_thds: set, temporal IoU thresholds

    Returns:
        list(bool), len == len(iou_thds), indicates whether a recall under an iou_thd is found.
    """
    recalled = [False, ] * len(iou_thds)
    for idx, iou_thd in enumerate(iou_thds):
        recalled[idx] = np.sum(proposal_ious >= iou_thd) >= 1  # at least one
    return recalled


def compute_proposal_recall_upper_bound(video_proposals_list, iou_thds=(0.5, 0.7)):
    """video_proposals_list from get_proposals_for_videos()"""
    iou_corrects = np.empty((len(video_proposals_list), 2), dtype=np.float32)
    for idx, d in tqdm(enumerate(video_proposals_list),
                       desc="Computing recall for videos",
                       total=len(video_proposals_list)):
        iou_corrects[idx] = is_recalled_single_moment(d["proposal_ious"],
                                                      iou_thds=iou_thds)
    recall_by_iou = {iou_thd: float(np.mean(iou_corrects[:, idx]))
                     for idx, iou_thd in enumerate(iou_thds)}
    return recall_by_iou


def main_compute_upper_bound():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("-dset_name", type=str, choices=["tvr"])
    parser.add_argument("-eval_file_path", type=str, help="path to the file containing data to be evaluated")
    parser.add_argument("-save_path", type=str, help="path to save the results")
    parser.add_argument("-verbose", action="store_true")
    args = parser.parse_args()

    eval_datalist = load_jsonl(args.eval_file_path)
    video_proposals_list = get_proposals_for_videos(eval_datalist, args.dset_name)
    recall_metrics = compute_proposal_recall_upper_bound(video_proposals_list, iou_thds=(0.5, 0.7))

    video_proposals_list_by_video = {}
    for p in video_proposals_list:
        if p["vid_name"] in video_proposals_list_by_video:
            continue
        else:
            video_proposals_list_by_video[p["vid_name"]] = p
    video_proposals_list_by_video = list(video_proposals_list_by_video.values())
    total_n_clips_in_proposals = \
        np.sum([np.sum(e["proposals"][:, 1] - e["proposals"][:, 0]) for e in video_proposals_list_by_video])

    results = dict(
        avg_num_proposals=float(np.mean([len(e["proposals"]) for e in video_proposals_list_by_video])),
        total_num_proposals=int(np.sum([len(e["proposals"]) for e in video_proposals_list_by_video])),
        recall_metrics=recall_metrics,
        dset_name=args.dset_name,
        filename=args.eval_file_path,
        proposal_config=ProposalConfigs[args.dset_name]
    )
    results["avg_clip_per_proposal"] = total_n_clips_in_proposals / results["total_num_proposals"]
    save_json(results, args.save_path, save_pretty=True)
    if args.verbose:
        pprint.pprint(results)


if __name__ == '__main__':
    main_compute_upper_bound()
|
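A tiny worked example of the recall upper bound above; the two "videos" and their proposal IoUs are made up, only the function and import path come from the file just shown.

import numpy as np
from baselines.clip_alignment_with_language.local_utils.compute_proposal_upper_bound import \
    compute_proposal_recall_upper_bound

# First video: one proposal reaches IoU 0.72, so it is recalled at both 0.5 and 0.7.
# Second video: best proposal IoU is 0.55, recalled at 0.5 only.
fake_proposals_list = [
    {"proposal_ious": np.array([0.10, 0.72, 0.30])},
    {"proposal_ious": np.array([0.55, 0.20])},
]
recall = compute_proposal_recall_upper_bound(fake_proposals_list, iou_thds=(0.5, 0.7))
print(recall)  # {0.5: 1.0, 0.7: 0.5}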
baselines/clip_alignment_with_language/local_utils/proposal.py
ADDED
@@ -0,0 +1,181 @@
# MIT License
#
# Copyright (c) 2018 Victor Escorcia Castillo
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================
"""
Group multiple methods to generate salient temporal windows in a video"""
import itertools
import numpy as np

PROPOSAL_SCHEMES = ['DidemoICCV17SS', 'SlidingWindowMSRSS']


class TemporalProposalsBase:
    """Base class (signature) to generate temporal candidates in a video"""
    def __call__(self, video_id, metadata=None, feature_collection=None):
        raise NotImplementedError('Implement with the signature above')


class DidemoICCV17SS(TemporalProposalsBase):
    """Original search space of moments proposed in ICCV-2017

    Attributes:
        clip_length_min (float) : minimum length, in seconds, of a video clip.
        proposals (numpy array) : of shape [21, 2] representing all the
            possible temporal segments of valid annotations of DiDeMo dataset.
            It represents the search space of a temporal localization
            algorithm.

    Reference: Hendricks et al. Localizing Moments in Video with Natural
        Language. ICCV 2017.
    """
    clip_length_min = 5.0

    def __init__(self, *args, dtype=np.float32, **kwargs):
        clips_indices = [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
        for i in itertools.combinations(range(len(clips_indices)), 2):
            clips_indices.append(i)
        self.proposals = np.array(clips_indices, dtype=dtype)
        self.proposals *= self.clip_length_min
        self.proposals[:, 1] += self.clip_length_min

    def __call__(self, *args, **kwargs):
        return self.proposals


class SlidingWindowMSRSS(TemporalProposalsBase):
    """Multi-scale sliding window with relative stride within the same scale

    Attributes:
        length (float) : length of smallest window.
        scales (sequence of int) : duration of moments relative to
            `length`.
        stride (float) : relative stride between two windows with the same
            duration. We use a different stride for each scale, rounding it
            towards a multiple of `length`. Note that the minimum stride for
            any window is `length` itself.
        dtype (numpy.dtype) :
    """

    def __init__(self, length, scales, stride=0.5, round_base=0.5, dtype=np.float32):
        self.length = length
        self.scales = scales
        self.round_base = round_base
        self.relative_stride = stride
        # pick strides per scale that are multiples of length
        self.strides = [max(round(s * stride / round_base) * round_base, round_base)
                        * length for s in scales]
        self.dtype = dtype
        assert len(scales) > 0

    def sliding_windows(self, t_end, t_start=0):
        """sliding canonical windows over a given time interval"""
        windows_ = []
        for i, stride in enumerate(self.strides):
            num_i = np.ceil((t_end - t_start) / stride)
            windows_i = np.empty((int(num_i), 2), dtype=np.float32)
            windows_i[:, 0] = np.arange(t_start, t_end, stride)
            windows_i[:, 1] = windows_i[:, 0] + self.length * self.scales[i]
            windows_i[windows_i[:, 1] > t_end, 1] = t_end
            windows_.append(windows_i)
            # print("--------------------------------{}".format(i))
            # print(windows_i)
            # import sys
            # sys.exit(1)
        windows = np.concatenate(windows_, axis=0)
        # Hacky way to make windows fit inside video
        # It implies windows at the end may not belong to the set spanned by
        # length and scales.
        return np.unique(windows, axis=0)

    def __call__(self, video_id, metadata=None, feature_collection=None):
        """return: (N_window, 2), each row contains (start, end)"""
        duration = metadata.get('duration')
        assert duration is not None
        return self.sliding_windows(duration)


ProposalConfigs = {
    "didemo": {
        "proposal_interface": "DidemoICCV17SS",
        "clip_length": 2.5,
    },
    "tvr": {
        "length": 3,  # min proposal length
        "scales": [1, 2, 4, 8],
        "stride": 0.3,
        "round_base": 1,
        "min_proposal_length": 3,  # length * min(scales)
        "clip_length": 1.5,  # length should be divisible by clip_length
        "proposal_interface": "SlidingWindowMSRSS",
    },
    "anet_cap": {
        "length": 5,
        "scales": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26],
        "stride": 0.3,
        "round_base": 1,
        "min_proposal_length": 10,  # length * min(scales)
        "clip_length": 5,  # length * min(scales) / 2
        "proposal_interface": "SlidingWindowMSRSS",
    },
    "charades_sta": {
        "length": 3,
        "scales": [2, 3, 4, 5, 6, 7, 8],
        "stride": 0.3,
        "round_base": 1,
        "min_proposal_length": 6,  # length * min(scales)
        "clip_length": 3,  # length * min(scales) / 2
        "proposal_interface": "SlidingWindowMSRSS",
    },
    "profiling": {
        "length": 5,
        "scales": [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        "stride": 0.3,
        "round_base": 1,
        "clip_length": 5,  # length * min(scales) / 2
        "proposal_interface": "SlidingWindowMSRSS",
    },
}
"""
'min_clip_length' is used to uniformly segment the video into smaller clips; it is half of
the 'min_proposal_length'. Thus we can enforce that each moment has at least 2 clips.
"""


def get_proposal_interface(dset_name):
    """ dset_name (str): one of ["tvr"] """
    assert dset_name in ProposalConfigs
    if dset_name == "didemo":
        return DidemoICCV17SS()
    else:
        arg_names = ["length", "scales", "stride", "round_base"]
        func_args = {k: ProposalConfigs[dset_name][k] for k in arg_names}
        return SlidingWindowMSRSS(**func_args)


if __name__ == '__main__':
    test_fns_args = [(DidemoICCV17SS, (),),
                     (SlidingWindowMSRSS, (1.5, [2, 4, 6, 12]))]
    for fn_i, args_i in test_fns_args:
        proposal_fn = fn_i(*args_i)
        x = proposal_fn('hola', {'duration': 15})
        if fn_i == DidemoICCV17SS:
            assert len(x) == 21
|
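A usage sketch for the proposal interface above, separate from the file's own test block. It assumes the module import path as uploaded in this commit; the 15-second duration is a made-up example.

from baselines.clip_alignment_with_language.local_utils.proposal import get_proposal_interface

proposal_fn = get_proposal_interface("tvr")  # SlidingWindowMSRSS(length=3, scales=[1, 2, 4, 8], ...)
proposals = proposal_fn(video_id="", metadata={"duration": 15})
print(proposals.shape)  # (N_window, 2); e.g. (13, 2) unique windows for this 15s video
print(proposals[:3])    # each row is (start, end) in seconds, aligned to the 3s base length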
baselines/clip_alignment_with_language/local_utils/tvr_proposal_test_log.txt
ADDED
@@ -0,0 +1,61 @@
"""
{'avg_num_proposals': 158.30197338228544,
 'dset_name': 'tvr',
 'filename': 'data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
 'proposal_config': {'length': 3,
                     'proposal_interface': 'SlidingWindowMSRSS',
                     'round_base': 1,
                     'scales': [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16],
                     'stride': 0.3},
 'recall_metrics': {0.5: 0.8927030563354492, 0.7: 0.6690225005149841},
 'total_num_proposals': 344940}


{'avg_num_proposals': 213.3295089490592,
 'dset_name': 'tvr',
 'filename': 'data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
 'proposal_config': {'length': 3,
                     'min_clip_length': 1.5,
                     'min_proposal_length': 3,
                     'proposal_interface': 'SlidingWindowMSRSS',
                     'round_base': 0.5,
                     'scales': [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16],
                     'stride': 0.3},
 'recall_metrics': {0.5: 0.9612666368484497, 0.7: 0.8215695023536682},
 'total_num_proposals': 464845}
--


{'avg_num_proposals': 213.3295089490592,
 'dset_name': 'tvr',
 'filename': '../../data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
 'proposal_config': {'length': 3,
                     'proposal_interface': 'SlidingWindowMSRSS',
                     'round_base': 0.5,
                     'scales': [1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16],
                     'stride': 0.3},
 'recall_metrics': {0.5: 0.9612666368484497, 0.7: 0.8215695023536682}}


{'avg_num_proposals': 263.3845800826067,
 'dset_name': 'tvr',
 'filename': '../../data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
 'proposal_config': {'length': 3,
                     'proposal_interface': 'SlidingWindowMSRSS',
                     'round_base': 0.5,
                     'scales': [0.5, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16],
                     'stride': 0.3},
 'recall_metrics': {0.5: 0.9841211438179016, 0.7: 0.8567232489585876}}


{'avg_num_proposals': 242.97246443322626,
 'dset_name': 'tvr',
 'filename': '../../data/retrieval_release_data_with_ids/tvr_val_release.jsonl',
 'proposal_config': {'length': 3,
                     'proposal_interface': 'SlidingWindowMSRSS',
                     'round_base': 0.5,
                     'scales': [0.5, 1, 2, 3, 4, 5, 6, 7, 8],
                     'stride': 0.3},
 'recall_metrics': {0.5: 0.9608076810836792, 0.7: 0.8212941884994507}}
"""
|
baselines/clip_alignment_with_language/mix_model_prediction.py
ADDED
@@ -0,0 +1,86 @@
"""
Implement the CAL + CAL (TEF) model mentioned in
```
@article{Escorcia2019TemporalLO,
  title={Temporal Localization of Moments in Video Collections with Natural Language},
  author={Victor Escorcia and Mattia Soldan and Josef Sivic and Bernard Ghanem and Bryan Russell},
  journal={ArXiv},
  year={2019},
  volume={abs/1907.12763}
}
```

Methods:
1. Take the top-200 predictions for each query from CAL, then use CAL (TEF) to re-rank them.
2. This is approximated by re-ranking the top-200 CAL predictions with the top-1000 CAL (TEF)
   predictions -- we assume the former are all covered by the latter.
"""

import torch
import subprocess
import numpy as np
from tqdm import tqdm
from utils.basic_utils import load_json, save_json


def load_saved_res(pred_path):
    if pred_path.endswith(".json"):
        pred = load_json(pred_path)
    else:
        pred = torch.load(pred_path)
    vcmr_res = {e["desc_id"]: e for e in pred["VCMR"]}
    video2idx = pred["video2idx"]
    return vcmr_res, video2idx


def main_mix_results(pred_path, tef_pred_path, save_path, max_after_nms=100):
    """
    Args:
        pred_path: contains top-200 VCMR predictions
        tef_pred_path: contains top-1000 VCMR predictions
        save_path: path to save the re-ranked predictions
        max_after_nms: int, number of re-ranked predictions to keep per query
    Returns:
        None, writes the re-ranked predictions to save_path
    """
    vcmr_res, video2idx = load_saved_res(pred_path)
    tef_vcmr_res, video2idx = load_saved_res(tef_pred_path)

    reranked_vcmr_res = {}
    num_valid = []
    for desc_id, preds in tqdm(vcmr_res.items(), desc="Loop over the predictions"):
        tef_preds = tef_vcmr_res[desc_id]["predictions"]
        pred_moments = set([tuple(e[:3]) for e in preds["predictions"]])
        reranked_moments = [e for e in tef_preds if tuple(e[:3]) in pred_moments][:max_after_nms]
        num_valid += [len(reranked_moments)]
        if len(reranked_moments) != 100:
            reranked_moments += reranked_moments[:100 - len(reranked_moments)]
        reranked_vcmr_res[desc_id] = dict(
            predictions=reranked_moments,
            desc_id=desc_id,
            desc=preds["desc"]
        )

    print("There are {} moments found on average".format(np.mean(num_valid)))
    reranked_predictions = dict(
        VCMR=list(reranked_vcmr_res.values()),
        video2idx=video2idx
    )

    save_json(reranked_predictions, save_path)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--pred_path", type=str, help="path to prediction res")
    parser.add_argument("--tef_pred_path", type=str, help="path to TEF prediction res")
    parser.add_argument("--save_path", type=str, help="path to save the re-ranked predictions, same dir as --pred_path")
    parser.add_argument("--gt_path", type=str, help="path to ground truth file")
    args = parser.parse_args()

    main_mix_results(args.pred_path, args.tef_pred_path, args.save_path)

    metrics_path = args.save_path.replace(".json", "_metrics.json")
    eval_cmd = "python standalone_eval/eval.py --submission_path " + args.save_path + " --gt_path " + args.gt_path + \
               " --save_path " + metrics_path
    results = subprocess.run(eval_cmd, shell=True)
|
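A toy illustration of the re-ranking rule used in main_mix_results above: keep the CAL (TEF) ordering, restricted to moments that also appear in CAL's top list. The moment tuples and video indices below are made up.

cal_top = [(11, 0.0, 3.0), (11, 3.0, 6.0)]                     # (video_idx, st, ed) from CAL
tef_ranked = [(11, 3.0, 6.0), (12, 0.0, 9.0), (11, 0.0, 3.0)]  # CAL (TEF) ordering
pred_moments = set(cal_top)
reranked = [m for m in tef_ranked if m[:3] in pred_moments]
print(reranked)  # [(11, 3.0, 6.0), (11, 0.0, 3.0)] -- TEF order, CAL candidate set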
baselines/clip_alignment_with_language/model.py
ADDED
@@ -0,0 +1,299 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.model_utils import RNNEncoder
from easydict import EasyDict as edict


cal_base_cfg = edict(
    visual_input_size=2048,  # changes based on visual input type
    textual_input_size=768,
    query_feat_size=768,
    visual_hidden_size=500,
    output_size=100,
    embedding_size=768,
    lstm_hidden_size=1000,
    margin=0.1,  # margin for ranking loss
    loss_type="hinge",  # loss type, 'hinge' or 'lse'
    inter_loss_weight=0.4,  # weight for inter negatives
    ctx_mode="video"
)


class CAL(nn.Module):
    def __init__(self, config):
        super(CAL, self).__init__()
        self.config = config

        self.moment_mlp = nn.Sequential(
            nn.Linear(config.visual_input_size, config.visual_hidden_size),
            nn.ReLU(True),
            nn.Linear(config.visual_hidden_size, config.output_size),
        )

        self.query_lstm = RNNEncoder(word_embedding_size=config.embedding_size,
                                     hidden_size=config.lstm_hidden_size,
                                     bidirectional=False,
                                     rnn_type="lstm",
                                     dropout_p=0,
                                     n_layers=1,
                                     return_outputs=False)

        self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)

    def moment_encoder(self, moment_feat):
        """moment_feat: (N, L_clip, D_v)"""
        return F.normalize(self.moment_mlp(moment_feat), p=2, dim=-1)  # (N, L_clip, D_o)

    def query_encoder(self, query_feat, query_mask):
        """
        Args:
            query_feat: (N, L_q, D_q), torch.float32
            query_mask: (N, L_q), torch.float32, with 1 indicates valid query, 0 indicates mask
        """
        _, hidden = self.query_lstm(query_feat, torch.sum(query_mask, dim=1).long())
        return F.normalize(self.query_linear(hidden), p=2, dim=-1)  # (N, D_o)

    def compute_pdist(self, query_embedding, moment_feat, moment_mask):
        """ pairwise L2 distance
        Args:
            query_embedding: (N, D_o)
            moment_feat: (N, L_clip, D_v)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        moment_embedding = self.moment_encoder(moment_feat)  # (N, L_clip, D_o)
        moment_clip_dist = torch.sum((moment_embedding - query_embedding.unsqueeze(1)) ** 2, dim=2)  # (N, L_clip)
        moment_dist = torch.sum(moment_clip_dist * moment_mask, dim=1) / moment_mask.sum(1)  # (N, )
        return moment_dist  # (N, )

    @classmethod
    def compute_cdist_inference(cls, query_embeddings, moment_embeddings, moment_mask):
        """ Compute L2 distance for every possible pair of queries and proposals. This is different from
        compute_pdist as the latter computes only pairs at each row.
        Args:
            query_embeddings: (N_q, D_o)
            moment_embeddings: (N_prop, N_clips, D_o)
            moment_mask: (N_prop, N_clips)
        return:
            query_moment_scores: (N_q, N_prop)
        """
        # sync device
        query_device = query_embeddings.device  # convert to cuda if we want to use GPU
        if moment_embeddings.device != query_device:
            moment_embeddings = moment_embeddings.to(query_device)
            moment_mask = moment_mask.to(query_device)

        # compute
        n_query = query_embeddings.shape[0]
        n_prop, n_clips, d = moment_embeddings.shape
        query_clip_dist = torch.cdist(
            query_embeddings, moment_embeddings.reshape(-1, d), p=2) ** 2  # (N_q, N_prop * N_clips)
        query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)
        query_moment_dist = torch.sum(
            query_clip_dist * moment_mask.unsqueeze(0), dim=2) / moment_mask.sum(1).unsqueeze(0)
        return query_moment_dist  # (N_q, N_prop)

    def forward(self, query_feat, query_mask, pos_moment_feat, pos_moment_mask,
                intra_neg_moment_feat, intra_neg_moment_mask,
                inter_neg_moment_feat, inter_neg_moment_mask):
        """
        Args:
            query_feat: (N, L, D_q)
            query_mask: (N, L)
            pos_moment_feat: (N, L_clip_1, D_v)
            pos_moment_mask: (N, L_clip_1)
            intra_neg_moment_feat: (N, L_clip_2, D_v)
            intra_neg_moment_mask: (N, L_clip_2)
            inter_neg_moment_feat: (N, L_clip_3, D_v)
            inter_neg_moment_mask: (N, L_clip_3)
        """
        query_embed = self.query_encoder(query_feat, query_mask)  # (N, D_o)
        pos_dist = self.compute_pdist(query_embed, pos_moment_feat, pos_moment_mask)  # (N, )
        intra_neg_dist = self.compute_pdist(query_embed, intra_neg_moment_feat, intra_neg_moment_mask)  # (N, )
        if self.config.inter_loss_weight == 0:  # should be zero for tef_only method.
            loss_inter = 0.
        else:
            inter_neg_dist = self.compute_pdist(query_embed, inter_neg_moment_feat, inter_neg_moment_mask)  # (N, )
            loss_inter = self.calc_loss(pos_dist, inter_neg_dist)

        loss = self.calc_loss(pos_dist, intra_neg_dist) + self.config.inter_loss_weight * loss_inter
        return loss

    def calc_loss(self, pos_dist, neg_dist):
        """ Note here we encourage positive distance to be smaller than negative distance.
        Args:
            pos_dist: (N, ), torch.float32
            neg_dist: (N, ), torch.float32
        """
        if self.config.loss_type == "hinge":  # max(0, m + S_pos - S_neg)
            return torch.clamp(self.config.margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
        elif self.config.loss_type == "lse":  # log[1 + exp(S_pos - S_neg)]
            return torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")


class CALWithSub(nn.Module):
    def __init__(self, config):
        super(CALWithSub, self).__init__()
        self.config = config
        self.use_video = "video" in config.ctx_mode
        self.use_sub = "sub" in config.ctx_mode
        self.use_tef = "tef" in config.ctx_mode
        self.tef_only = self.use_tef and not self.use_video and not self.use_sub

        if self.use_video or self.tef_only:
            self.video_moment_mlp = nn.Sequential(
                nn.Linear(config.visual_input_size, config.visual_hidden_size),
                nn.ReLU(True),
                nn.Linear(config.visual_hidden_size, config.output_size),
            )

        if self.use_sub:
            self.sub_moment_mlp = nn.Sequential(
                nn.Linear(config.textual_input_size, config.visual_hidden_size),
                nn.ReLU(True),
                nn.Linear(config.visual_hidden_size, config.output_size),
            )

        self.query_lstm = RNNEncoder(word_embedding_size=config.query_feat_size,
                                     hidden_size=config.lstm_hidden_size,
                                     bidirectional=False,
                                     rnn_type="lstm",
                                     dropout_p=0,
                                     n_layers=1,
                                     return_outputs=False)

        self.query_linear = nn.Linear(config.lstm_hidden_size, config.output_size)

    def moment_encoder(self, moment_feat, module_name="video"):
        """moment_feat: (N, L_clip, D_v)"""
        if moment_feat is not None:
            encoder = getattr(self, module_name + "_moment_mlp")
            return F.normalize(encoder(moment_feat), p=2, dim=-1)  # (N, L_clip, D_o)
        else:
            return None

    def query_encoder(self, query_feat, query_mask):
        """
        Args:
            query_feat: (N, L_q, D_q), torch.float32
            query_mask: (N, L_q), torch.float32, with 1 indicates valid query, 0 indicates mask
        """
        _, hidden = self.query_lstm(query_feat, torch.sum(query_mask, dim=1).long())
        return F.normalize(self.query_linear(hidden), p=2, dim=-1)  # (N, D_o)

    def _compute_pdist(self, query_embedding, moment_feat, moment_mask, module_name="video"):
        """ pairwise L2 distance
        Args:
            query_embedding: (N, D_o)
            moment_feat: (N, L_clip, D_v)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        moment_embedding = self.moment_encoder(moment_feat, module_name=module_name)  # (N, L_clip, D_o)
        moment_clip_dist = torch.sum((moment_embedding - query_embedding.unsqueeze(1)) ** 2, dim=2)  # (N, L_clip)
        moment_dist = torch.sum(moment_clip_dist * moment_mask, dim=1) / moment_mask.sum(1)  # (N, )
        return moment_dist  # (N, )

    def compute_pdist(self, query_embedding, moment_video_feat, moment_sub_feat, moment_mask):
        """ pairwise L2 distance
        Args:
            query_embedding: (N, D_o)
            moment_video_feat: (N, L_clip, D_v)
            moment_sub_feat: (N, L_clip, D_t)
            moment_mask: (N, L_clip), torch.float32, where 1 indicates valid, 0 indicates padding
        """
        divisor = (self.use_video or self.tef_only) + self.use_sub
        video_moment_dist = self._compute_pdist(query_embedding, moment_video_feat, moment_mask, module_name="video") \
            if self.use_video or self.tef_only else 0
        sub_moment_dist = self._compute_pdist(query_embedding, moment_sub_feat, moment_mask, module_name="sub") \
            if self.use_sub else 0
        return (video_moment_dist + sub_moment_dist) / divisor  # (N, )

    def _compute_cdist_inference(self, query_embeddings, moment_embeddings, moment_mask):
        """ Compute L2 distance for every possible pair of queries and proposals. This is different from
        compute_pdist as the latter computes only pairs at each row.
        Args:
            query_embeddings: (N_q, D_o)
            moment_embeddings: (N_prop, N_clips, D_o)
            moment_mask: (N_prop, N_clips)
        return:
            query_moment_scores: (N_q, N_prop)
        """
        # sync device
        query_device = query_embeddings.device  # convert to cuda if we want to use GPU
        if moment_embeddings.device != query_device:
            moment_embeddings = moment_embeddings.to(query_device)
            moment_mask = moment_mask.to(query_device)

        # compute
        n_query = query_embeddings.shape[0]
        n_prop, n_clips, d = moment_embeddings.shape
        query_clip_dist = torch.cdist(
            query_embeddings, moment_embeddings.reshape(-1, d), p=2) ** 2  # (N_q, N_prop * N_clips)
        query_clip_dist = query_clip_dist.reshape(n_query, n_prop, n_clips)
        query_moment_dist = torch.sum(
            query_clip_dist * moment_mask.unsqueeze(0), dim=2) / moment_mask.sum(1).unsqueeze(0)
        return query_moment_dist  # (N_q, N_prop)

    def compute_cdist_inference(self, query_embeddings, video_moment_embeddings, sub_moment_embeddings, moment_mask):
        divisor = (self.use_video or self.tef_only) + self.use_sub
        video_moment_dist = self._compute_cdist_inference(query_embeddings, video_moment_embeddings, moment_mask) \
            if self.use_video or self.tef_only else 0
        sub_moment_dist = self._compute_cdist_inference(query_embeddings, sub_moment_embeddings, moment_mask) \
            if self.use_sub else 0
        return (video_moment_dist + sub_moment_dist) / divisor  # (N_q, N_prop)

    def forward(self, query_feat, query_mask, pos_moment_video_feat, pos_moment_video_mask,
                intra_neg_moment_video_feat, intra_neg_moment_video_mask,
                inter_neg_moment_video_feat, inter_neg_moment_video_mask,
                pos_moment_sub_feat, pos_moment_sub_mask,
                intra_neg_moment_sub_feat, intra_neg_moment_sub_mask,
                inter_neg_moment_sub_feat, inter_neg_moment_sub_mask):
        """
        Args:
            query_feat: (N, L, D_q)
            query_mask: (N, L)
            pos_moment_video_feat: (N, L_clip_1, D_v)
            pos_moment_video_mask: (N, L_clip_1)
            intra_neg_moment_video_feat: (N, L_clip_2, D_v)
            intra_neg_moment_video_mask: (N, L_clip_2)
            inter_neg_moment_video_feat: (N, L_clip_3, D_v)
            inter_neg_moment_video_mask: (N, L_clip_3)
            pos_moment_sub_feat:
            pos_moment_sub_mask:
            intra_neg_moment_sub_feat:
            intra_neg_moment_sub_mask:
            inter_neg_moment_sub_feat:
            inter_neg_moment_sub_mask:
        """
        query_embed = self.query_encoder(query_feat, query_mask)  # (N, D_o)
        pos_dist = self.compute_pdist(
            query_embed, pos_moment_video_feat, pos_moment_sub_feat,
            moment_mask=pos_moment_sub_mask if self.use_sub else pos_moment_video_mask)  # (N, )
        intra_neg_dist = self.compute_pdist(
            query_embed, intra_neg_moment_video_feat, intra_neg_moment_sub_feat,
            moment_mask=intra_neg_moment_sub_mask if self.use_sub else intra_neg_moment_video_mask)  # (N, )
        if self.config.inter_loss_weight == 0:  # should be zero for tef_only method.
            loss_inter = 0.
        else:
            inter_neg_dist = self.compute_pdist(
                query_embed, inter_neg_moment_video_feat, inter_neg_moment_sub_feat,
                moment_mask=inter_neg_moment_sub_mask if self.use_sub else inter_neg_moment_video_mask)  # (N, )
            loss_inter = self.calc_loss(pos_dist, inter_neg_dist)

        loss = self.calc_loss(pos_dist, intra_neg_dist) + self.config.inter_loss_weight * loss_inter
        return loss

    def calc_loss(self, pos_dist, neg_dist):
        """ Note here we encourage positive distance to be smaller than negative distance.
        Args:
            pos_dist: (N, ), torch.float32
            neg_dist: (N, ), torch.float32
        """
        if self.config.loss_type == "hinge":  # max(0, m + S_pos - S_neg)
            return torch.clamp(self.config.margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
        elif self.config.loss_type == "lse":  # log[1 + exp(S_pos - S_neg)]
            return torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")
|
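A tiny numeric sketch mirroring calc_loss above with the cal_base_cfg default margin of 0.1; the distance values are made up.

import torch

pos_dist = torch.tensor([0.20, 0.50])  # distances of positive moments (smaller is better)
neg_dist = torch.tensor([0.60, 0.45])  # distances of negative moments
margin = 0.1

hinge = torch.clamp(margin + pos_dist - neg_dist, min=0).sum() / len(pos_dist)
lse = torch.log1p(torch.exp(pos_dist - neg_dist)).sum() / len(pos_dist)
print(hinge.item())  # 0.075 = (0 + 0.15) / 2: the first pair satisfies the margin, the second violates it by 0.15
print(lse.item())    # the smooth (log-sum-exp style) alternative to the hinge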
baselines/clip_alignment_with_language/proposal_retrieval_dataset.py
ADDED
@@ -0,0 +1,587 @@
1 |
+
"""
|
2 |
+
Dataset for clip model
|
3 |
+
"""
|
4 |
+
import logging
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
import numpy as np
|
8 |
+
import h5py
|
9 |
+
import math
|
10 |
+
import random
|
11 |
+
from utils.basic_utils import load_jsonl, load_json, l2_normalize_np_array
|
12 |
+
from utils.tensor_utils import pad_sequences_1d
|
13 |
+
from baselines.clip_alignment_with_language.local_utils.proposal import get_proposal_interface
|
14 |
+
from baselines.clip_alignment_with_language.local_utils.compute_proposal_upper_bound import \
|
15 |
+
get_didemo_agreed_ts
|
16 |
+
from standalone_eval.eval import compute_temporal_iou_batch
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
class ProposalRetrievalDataset(Dataset):
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
dset_name, str, ["tvr"]
|
25 |
+
ctx_mode: str,
|
26 |
+
pos_iou_thd: float, in [0, 1], >= pos_iou_thd are defined as positive
|
27 |
+
neg_iou_thd: float, in [0, 1], < neg_iou_thd are defined as negative
|
28 |
+
Return:
|
29 |
+
a dict: {
|
30 |
+
"meta": {
|
31 |
+
"desc_id": int,
|
32 |
+
"desc": str,
|
33 |
+
"vid_name": str,
|
34 |
+
"duration": float,
|
35 |
+
"ts": [st (float), ed (float)], seconds, ground_truth timestamps
|
36 |
+
"pos_moment": [st (float), ed (float)], seconds, IoU with "ts" >= pos_iou_thd
|
37 |
+
"intra_neg_moment": [st (float), ed (float)], seconds, IoU with "ts" < neg_iou_thd
|
38 |
+
"inter_neg_vid_name": str,
|
39 |
+
"inter_neg_duration": float,
|
40 |
+
"inter_neg_moment": [st (float), ed (float)], seconds, IoU with "ts" < neg_iou_thd
|
41 |
+
}
|
42 |
+
"model_inputs": {
|
43 |
+
"desc_feat": torch.tensor, (L, D_t)
|
44 |
+
"pos_moment_feat": torch.tensor, (n_clip_in_moment, D)
|
45 |
+
"intra_neg_moment_feat": torch.tensor, (n_clip_in_moment, D)
|
46 |
+
"inter_neg_moment_feat": torch.tensor, (n_clip_in_moment, D)
|
47 |
+
}
|
48 |
+
}
|
49 |
+
"""
|
50 |
+
def __init__(self, dset_name, data_path, desc_bert_path, sub_bert_path, max_desc_len,
|
51 |
+
vid_feat_path, clip_length, vid_feat_size, sub_feat_size=0, ctx_mode="video_tef",
|
52 |
+
pos_iou_thd=0.7, neg_iou_thd=0.3, h5driver=None, data_ratio=1.0,
|
53 |
+
normalize_vfeat=True, normalize_tfeat=True, model_type="cal",
|
54 |
+
external_train_vr_res_path=None, corpus_path=None):
|
55 |
+
self.dset_name = dset_name
|
56 |
+
self.model_type = model_type
|
57 |
+
self.pool_local = model_type == "mcn" # pool local feature
|
58 |
+
self.data_path = data_path
|
59 |
+
self.data_ratio = data_ratio
|
60 |
+
|
61 |
+
self.desc_bert_path = desc_bert_path
|
62 |
+
self.max_desc_len = max_desc_len
|
63 |
+
self.sub_bert_path = sub_bert_path
|
64 |
+
|
65 |
+
self.vid_feat_path = vid_feat_path
|
66 |
+
self.clip_length = clip_length
|
67 |
+
self.ctx_mode = ctx_mode
|
68 |
+
|
69 |
+
self.pos_iou_thd = pos_iou_thd
|
70 |
+
self.neg_iou_thd = neg_iou_thd
|
71 |
+
|
72 |
+
self.vid_feat_output_size = 2 * vid_feat_size * ("video" in ctx_mode) + 2 * ("tef" in ctx_mode)
|
73 |
+
self.sub_feat_output_size = 2 * sub_feat_size * ("sub" in ctx_mode) + 2 * ("tef" in ctx_mode)
|
74 |
+
|
75 |
+
# prepare desc data
|
76 |
+
self.data = load_jsonl(data_path)
|
77 |
+
if self.data_ratio != 1:
|
78 |
+
n_examples = int(len(self.data) * data_ratio)
|
79 |
+
self.data = self.data[:n_examples]
|
80 |
+
logger.info("Using {}% of the data: {} examples".format(data_ratio * 100, n_examples))
|
81 |
+
|
82 |
+
self.proposal_fn = get_proposal_interface(dset_name)
|
83 |
+
if self.ctx_mode != "tef":
|
84 |
+
self.vid_feat_h5 = h5py.File(self.vid_feat_path, "r", driver=h5driver)
|
85 |
+
self.desc_bert_h5 = h5py.File(self.desc_bert_path, "r", driver=h5driver)
|
86 |
+
if "sub" in self.ctx_mode:
|
87 |
+
self.sub_bert_h5 = h5py.File(self.sub_bert_path, "r", driver=h5driver)
|
88 |
+
self.normalize_vfeat = normalize_vfeat
|
89 |
+
self.normalize_tfeat = normalize_tfeat
|
90 |
+
self.use_video = "video" in self.ctx_mode
|
91 |
+
self.use_sub = "sub" in self.ctx_mode
|
92 |
+
self.use_tef = "tef" in self.ctx_mode
|
93 |
+
|
94 |
+
if external_train_vr_res_path is not None:
|
95 |
+
video_data = load_json(corpus_path)["train"]
|
96 |
+
# {video_idx: [vid_name, vid_duration]}
|
97 |
+
video_idx2name_dur_pair = {v[1]: [k, v[0]] for k, v in video_data.items()}
|
98 |
+
external_vr_res = load_json(external_train_vr_res_path)
|
99 |
+
# {desc_id: [(vid_name, vid_duration), ...]}
|
100 |
+
self.desc_id2video_names_dur_pairs = \
|
101 |
+
{e["desc_id"]: [video_idx2name_dur_pair[int(sub_e[0])] for sub_e in e["predictions"]]
|
102 |
+
for e in external_vr_res["VR"]} # ordered
|
103 |
+
|
104 |
+
def __len__(self):
|
105 |
+
return len(self.data)
|
106 |
+
|
107 |
+
def __getitem__(self, index):
|
108 |
+
raw_data = self.data[index]
|
109 |
+
|
110 |
+
# initialize with basic data
|
111 |
+
meta = dict(
|
112 |
+
desc_id=raw_data["desc_id"],
|
113 |
+
desc=raw_data["desc"],
|
114 |
+
vid_name=raw_data["vid_name"],
|
115 |
+
duration=raw_data["duration"],
|
116 |
+
ts=raw_data["ts"] if self.dset_name != "didemo" else get_didemo_agreed_ts(raw_data["ts"]),
|
117 |
+
)
|
118 |
+
model_inputs = dict()
|
119 |
+
query_feat = self.desc_bert_h5[str(raw_data["desc_id"])][:self.max_desc_len]
|
120 |
+
if self.normalize_tfeat:
|
121 |
+
query_feat = l2_normalize_np_array(query_feat)
|
122 |
+
model_inputs["query_feat"] = torch.from_numpy(query_feat)
|
123 |
+
|
124 |
+
# sample positive and negative moments
|
125 |
+
meta["pos_moment"] = self.align_ts_to_clip_boundaries(meta["duration"], meta["ts"])
|
126 |
+
meta["intra_neg_moment"] = self.sample_intra_neg_moment(meta["duration"], meta["ts"])
|
127 |
+
meta["inter_neg_moment"], meta["inter_neg_vid_name"], meta["inter_neg_duration"] = \
|
128 |
+
self.sample_inter_video_negative(meta["vid_name"], meta["pos_moment"] / meta["duration"],
|
129 |
+
desc_id=meta["desc_id"])
|
130 |
+
|
131 |
+
pos_tef, intra_neg_tef, inter_neg_tef = (None,) * 3
|
132 |
+
if self.use_tef:
|
133 |
+
pos_tef = meta["pos_moment"] / meta["duration"] # temporal endpoint feature, (2, )
|
134 |
+
intra_neg_tef = meta["intra_neg_moment"] / meta["duration"]
|
135 |
+
inter_neg_tef = meta["inter_neg_moment"] / meta["inter_neg_duration"]
|
136 |
+
|
137 |
+
if self.use_video:
|
138 |
+
pos_v_feat = self.vid_feat_h5[meta["vid_name"]] # (N_frm, D)
|
139 |
+
neg_v_feat = self.vid_feat_h5[meta["inter_neg_vid_name"]]
|
140 |
+
pos_v_ctx_feat = np.mean(pos_v_feat, axis=0)
|
141 |
+
neg_v_ctx_feat = np.mean(neg_v_feat, axis=0)
|
142 |
+
if self.normalize_vfeat:
|
143 |
+
pos_v_ctx_feat = l2_normalize_np_array(pos_v_ctx_feat)
|
144 |
+
neg_v_ctx_feat = l2_normalize_np_array(neg_v_ctx_feat)
|
145 |
+
pos_moment_v_feat = self.get_moment_feat(pos_v_feat, meta["pos_moment"],
|
146 |
+
normalize=self.normalize_vfeat,
|
147 |
+
fix_outbound=True, pool_local=self.pool_local)
|
148 |
+
intra_neg_moment_v_feat = self.get_moment_feat(pos_v_feat, meta["intra_neg_moment"],
|
149 |
+
normalize=self.normalize_vfeat,
|
150 |
+
fix_outbound=True, pool_local=self.pool_local)
|
151 |
+
inter_neg_moment_v_feat = self.get_moment_feat(neg_v_feat, meta["inter_neg_moment"],
|
152 |
+
normalize=self.normalize_vfeat,
|
153 |
+
fix_outbound=True, pool_local=self.pool_local)
|
154 |
+
|
155 |
+
# concat features, [video_clip_feat; video_context_feat; temporal_endpoint_feat]
|
156 |
+
model_inputs["pos_moment_video_feat"] = self.concat_feat_adv(
|
157 |
+
moment_feats=[pos_moment_v_feat, pos_v_ctx_feat], tef=pos_tef, ctx_mode=self.ctx_mode)
|
158 |
+
model_inputs["intra_neg_moment_video_feat"] = self.concat_feat_adv(
|
159 |
+
moment_feats=[intra_neg_moment_v_feat, pos_v_ctx_feat], tef=intra_neg_tef, ctx_mode=self.ctx_mode)
|
160 |
+
model_inputs["inter_neg_moment_video_feat"] = self.concat_feat_adv(
|
161 |
+
moment_feats=[inter_neg_moment_v_feat, neg_v_ctx_feat], tef=inter_neg_tef, ctx_mode=self.ctx_mode)
|
162 |
+
else:
|
163 |
+
for k in ["pos_moment_video_feat", "intra_neg_moment_video_feat", "inter_neg_moment_video_feat"]:
|
164 |
+
model_inputs[k] = torch.zeros((2, 2))
|
165 |
+
|
166 |
+
if self.use_sub:  # no need for ctx feature, as the features are already contextualized
|
167 |
+
pos_s_feat = self.sub_bert_h5[meta["vid_name"]] # (N_words, D_t)
|
168 |
+
neg_s_feat = self.sub_bert_h5[meta["inter_neg_vid_name"]]
|
169 |
+
pos_s_ctx_feat = np.mean(pos_s_feat, axis=0)
|
170 |
+
neg_s_ctx_feat = np.mean(neg_s_feat, axis=0)
|
171 |
+
if self.normalize_tfeat:
|
172 |
+
pos_s_ctx_feat = l2_normalize_np_array(pos_s_ctx_feat)
|
173 |
+
neg_s_ctx_feat = l2_normalize_np_array(neg_s_ctx_feat)
|
174 |
+
pos_moment_s_feat = self.get_moment_feat(pos_s_feat, meta["pos_moment"],
|
175 |
+
normalize=self.normalize_tfeat,
|
176 |
+
fix_outbound=True, pool_local=self.pool_local)
|
177 |
+
intra_neg_moment_s_feat = self.get_moment_feat(pos_s_feat, meta["intra_neg_moment"],
|
178 |
+
normalize=self.normalize_tfeat,
|
179 |
+
fix_outbound=True, pool_local=self.pool_local)
|
180 |
+
inter_neg_moment_s_feat = self.get_moment_feat(neg_s_feat, meta["inter_neg_moment"],
|
181 |
+
normalize=self.normalize_tfeat,
|
182 |
+
fix_outbound=True, pool_local=self.pool_local)
|
183 |
+
|
184 |
+
# concat features, [sub_clip_feat; sub_context_feat; temporal_endpoint_feat]
|
185 |
+
model_inputs["pos_moment_sub_feat"] = self.concat_feat_adv(
|
186 |
+
moment_feats=[pos_moment_s_feat, pos_s_ctx_feat], tef=pos_tef, ctx_mode=self.ctx_mode)
|
187 |
+
model_inputs["intra_neg_moment_sub_feat"] = self.concat_feat_adv(
|
188 |
+
moment_feats=[intra_neg_moment_s_feat, pos_s_ctx_feat], tef=intra_neg_tef, ctx_mode=self.ctx_mode)
|
189 |
+
model_inputs["inter_neg_moment_sub_feat"] = self.concat_feat_adv(
|
190 |
+
moment_feats=[inter_neg_moment_s_feat, neg_s_ctx_feat], tef=inter_neg_tef, ctx_mode=self.ctx_mode)
|
191 |
+
else:
|
192 |
+
for k in ["pos_moment_sub_feat", "intra_neg_moment_sub_feat", "inter_neg_moment_sub_feat"]:
|
193 |
+
model_inputs[k] = torch.zeros((2, 2))
|
194 |
+
|
195 |
+
if not self.use_sub and not self.use_video and self.use_tef: # use video stream
|
196 |
+
model_inputs["pos_moment_video_feat"] = \
|
197 |
+
self.concat_feat_adv(tef=pos_tef, ctx_mode=self.ctx_mode)
|
198 |
+
model_inputs["intra_neg_moment_video_feat"] = \
|
199 |
+
self.concat_feat_adv(tef=intra_neg_tef, ctx_mode=self.ctx_mode)
|
200 |
+
model_inputs["inter_neg_moment_video_feat"] = \
|
201 |
+
self.concat_feat_adv(tef=inter_neg_tef, ctx_mode=self.ctx_mode)
|
202 |
+
return dict(meta=meta, model_inputs=model_inputs)
|
203 |
+
|
204 |
+
def align_ts_to_clip_boundaries(self, duration, ts):
|
205 |
+
""" # TODO Do we really need this???
|
206 |
+
Generate a moment [st, ed] that is closest to a clip boundary,
|
207 |
+
st and ed must be a multiple of self.clip_length, and ed <= duration
|
208 |
+
duration: float,
|
209 |
+
ts: [st (float), ed (float)], ground_truth ts
|
210 |
+
"""
|
211 |
+
clip_aligned_ts = np.array([math.floor(ts[0] / self.clip_length),
|
212 |
+
math.ceil(ts[1] / self.clip_length)]) * self.clip_length
|
213 |
+
clip_aligned_ts[1] = min(clip_aligned_ts[1], duration)
|
214 |
+
return clip_aligned_ts
|
215 |
+
|
216 |
+
def sample_intra_neg_moment(self, duration, ts):
|
217 |
+
""" Generate a intra negative moment given the video duration and the GT ts.
|
218 |
+
The returned moment will be aligned to clip boundaries.
|
219 |
+
1) neg_moment has at least 2 clips
|
220 |
+
2) its iou with ts should be < self.neg_iou_thd
|
221 |
+
Args:
|
222 |
+
duration: float
|
223 |
+
ts: [st (float), ed (float)], ground_truth ts
|
224 |
+
|
225 |
+
Returns:
|
226 |
+
|
227 |
+
"""
|
228 |
+
max_n_search = 5 # search at most max_n_search times, so the program will not be stuck in infinite loops.
|
229 |
+
sampled_moments = self.sample_ts_at_clip_boundaries(duration, n_pairs=max_n_search) # (n_pairs, 2)
|
230 |
+
sampled_moments_ious = compute_temporal_iou_batch(sampled_moments, ts) # (n_pairs, )
|
231 |
+
smallest_iou_idx = np.argmin(sampled_moments_ious)
|
232 |
+
sampled_moment = sampled_moments[smallest_iou_idx]
|
233 |
+
# only a small number (<20 with max_n_search==10) of samples are wrong,
|
234 |
+
# usually when the video_duration is too short.
|
235 |
+
# if sampled_moments_ious[smallest_iou_idx] >= self.neg_iou_thd:
|
236 |
+
# logger.warning("the sampled intra-neg might be wrong. "
|
237 |
+
# "v_dur {}, ts {}, sampled neg moment {}, iou {}"
|
238 |
+
# .format(duration, ts, sampled_moment, sampled_moments_ious[smallest_iou_idx]))
|
239 |
+
return sampled_moment
|
240 |
+
|
241 |
+
def sample_ts_at_clip_boundaries(self, duration, n_pairs=1):
|
242 |
+
"""sample n_pairs moment at clip boundaries, each has at least two clips."""
|
243 |
+
# '+ self.clip_length' since we assume indexing using [clip_st_idx, clip_ed_idx),
|
244 |
+
moments = np.random.randint(0, np.ceil(duration / self.clip_length), size=(n_pairs, 2))
|
245 |
+
moments = np.sort(moments, axis=1) * self.clip_length
|
246 |
+
less_equal = moments[:, 1] - moments[:, 0] <= self.clip_length
|
247 |
+
start_zero = moments[:, 0] == 0
|
248 |
+
moments[:, 1][less_equal * start_zero] += self.clip_length
|
249 |
+
moments[:, 0][less_equal * (start_zero == False)] -= self.clip_length # keep as bool!!!
|
250 |
+
return moments
|
251 |
+
|
252 |
+
def sample_inter_video_negative(self, pos_vid_name, normalized_pos_moment, desc_id=None):
|
253 |
+
"""Sample a negative moment --> negative video + similar normalized moment.
|
254 |
+
1) they are not from the same video
|
255 |
+
Args:
|
256 |
+
pos_vid_name: str,
|
257 |
+
normalized_pos_moment: np.ndarray, (2, ), value in [0, 1], normalized by duration.
|
258 |
+
desc_id: str
|
259 |
+
Returns:
|
260 |
+
moment: np.ndarray, (2, ), ts aligned to clip boundaries.
|
261 |
+
|
262 |
+
"""
|
263 |
+
use_guided_negative = hasattr(self, "desc_id2video_names_dur_pairs")
|
264 |
+
if use_guided_negative:
|
265 |
+
top_videos = self.desc_id2video_names_dur_pairs[desc_id]
|
266 |
+
max_idx = len(top_videos) - 1
|
267 |
+
|
268 |
+
while True: # usually only run once.
|
269 |
+
if use_guided_negative:
|
270 |
+
sampled_idx = min(max_idx, int(random.expovariate(0.1)))
|
271 |
+
sampled_video_name, sampled_video_dur = top_videos[sampled_idx]
|
272 |
+
else:
|
273 |
+
neg_vid_data = self.data[int(random.random() * len(self))]
|
274 |
+
sampled_video_name, sampled_video_dur = neg_vid_data["vid_name"], neg_vid_data["duration"]
|
275 |
+
if sampled_video_name != pos_vid_name:
|
276 |
+
inter_neg_moment = self.align_ts_to_clip_boundaries(
|
277 |
+
sampled_video_dur, sampled_video_dur * normalized_pos_moment)
|
278 |
+
break
|
279 |
+
|
280 |
+
return inter_neg_moment, sampled_video_name, sampled_video_dur
|
281 |
+
|
282 |
+
@classmethod
|
283 |
+
def get_clip_indices_from_moments(cls, moment, clip_length):
|
284 |
+
clip_st_ed_indices = moment / clip_length
|
285 |
+
return math.floor(clip_st_ed_indices[0]), math.ceil(clip_st_ed_indices[1])
|
286 |
+
|
287 |
+
def get_moment_feat(self, vid_feat, moment, normalize=True, fix_outbound=False, pool_local=False):
|
288 |
+
"""Each moment contains multiple clips.
|
289 |
+
Inside means [moment[0], moment[1]] (seconds)
|
290 |
+
Args:
|
291 |
+
vid_feat: np.ndarray, (N_clips, D)
|
292 |
+
moment: [st (float), ed (float)], np.ndarray
|
293 |
+
normalize: L2 normalize features
|
294 |
+
fix_outbound: bool,
|
295 |
+
pool_local: whether to mean pool the features
|
296 |
+
Returns:
|
297 |
+
moment_feature: np.ndarray, ((moment[1] - moment[0]) / clip_length, D) or (D, )
|
298 |
+
"""
|
299 |
+
clip_st_idx, clip_ed_idx = self.get_clip_indices_from_moments(moment, self.clip_length)
|
300 |
+
if fix_outbound:
|
301 |
+
vid_feat_len = len(vid_feat)
|
302 |
+
if clip_st_idx >= vid_feat_len:
|
303 |
+
clip_st_idx = vid_feat_len - 2
|
304 |
+
moment_feat = vid_feat[clip_st_idx:clip_ed_idx] # indexed as [st, ed)
|
305 |
+
if pool_local:
|
306 |
+
moment_feat = np.mean(moment_feat, axis=0, keepdims=True)
|
307 |
+
if normalize:
|
308 |
+
moment_feat = l2_normalize_np_array(moment_feat)
|
309 |
+
return moment_feat # (n_clip_in_moment, D) or (D, )
|
310 |
+
|
311 |
+
@classmethod
|
312 |
+
def concat_feat_adv(cls, moment_feats=None, tef=None, to_torch=True, ctx_mode="tef"):
|
313 |
+
""" Concat moment_feat with other_feats and tef. All the features should be L2 normalized before concatenating
|
314 |
+
Args:
|
315 |
+
moment_feats: list of feats, one of them might be None. Other possible values are
|
316 |
+
ctx_feat (D, ) or sub(vid)_moment_feat (N_p, N_clips, D_t) or (N_clips, D_t).
|
317 |
+
The first non-None feature array is used as base for the rest to concatenate with.
|
318 |
+
tef: (N_p, 2) or (2, ), np.ndarray
|
319 |
+
to_torch: convert resulting np.ndarray to torch.tensor
|
320 |
+
ctx_mode:
|
321 |
+
"""
|
322 |
+
if ctx_mode == "tef":
|
323 |
+
assembled_feat = np.expand_dims(tef, axis=-2)
|
324 |
+
else: # concat moment_feat with all other_feats
|
325 |
+
moment_feats = [e for e in moment_feats if e is not None] # remove possible None (placeholder)
|
326 |
+
extra_dims = moment_feats[0].shape[:-1] # all others will need to broadcast to match it.
|
327 |
+
if isinstance(extra_dims, int): # happens when len(moment_feat.shape) == 2
|
328 |
+
extra_dims = (extra_dims, )
|
329 |
+
last_dim_lengths = [0, ] + [e.shape[-1] for e in moment_feats]
|
330 |
+
if "tef" in ctx_mode: # add tef
|
331 |
+
last_dim_lengths += [2, ]
|
332 |
+
moment_feats += [np.expand_dims(tef, axis=-2), ]
|
333 |
+
|
334 |
+
if len(moment_feats) > 1:
|
335 |
+
assembled_feat = np.empty(extra_dims + (sum(last_dim_lengths), ), dtype=np.float32)
|
336 |
+
last_dim_lengths_cumsum = [sum(last_dim_lengths[0:idx+1]) for idx in range(len(last_dim_lengths))]
|
337 |
+
for idx, feat in enumerate(moment_feats):
|
338 |
+
assembled_feat[..., last_dim_lengths_cumsum[idx]:last_dim_lengths_cumsum[idx+1]] = feat
|
339 |
+
else:
|
340 |
+
assembled_feat = moment_feats[0]
|
341 |
+
|
342 |
+
if to_torch:
|
343 |
+
return torch.from_numpy(assembled_feat)
|
344 |
+
else:
|
345 |
+
return assembled_feat # (N_prop, N_clips, D_concat) or (N_clips, D_concat)
|
346 |
+
|
347 |
+
|
348 |
+
class ProposalRetrievalEvalDataset(Dataset):
|
349 |
+
"""
|
350 |
+
init_data_mode: `video_query` or `video_only` or `query_only`,
|
351 |
+
it indicates which data to load when initializing the Dataset object.
|
352 |
+
data_mode: `context` or `query`, it indicates which data to return from self.__getitem__()
|
353 |
+
desc_bert_path_or_handler: h5py.File object or str path
|
354 |
+
vid_feat_path_or_handler: h5py.File object or str path
|
355 |
+
eval_proposal_bsz: the proposals for a single video will be sorted by length and batched here, with
|
356 |
+
a maximum batch size of eval_proposal_bsz. A single video might yield multiple batches of proposals.
|
357 |
+
load_gt_video: load GroundTruth Video, useful when evaluating single video moment retrieval.
|
358 |
+
data_ratio: percentage of query data to use.
|
359 |
+
"""
|
360 |
+
def __init__(self, dset_name, eval_split_name, data_path=None,
|
361 |
+
desc_bert_path_or_handler=None, max_desc_len=None,
|
362 |
+
sub_bert_path_or_handler=None, vid_feat_path_or_handler=None,
|
363 |
+
corpus_path=None, clip_length=None,
|
364 |
+
eval_proposal_bsz=None, ctx_mode="tef", data_mode="context",
|
365 |
+
h5driver=None, data_ratio=1.0, normalize_vfeat=True,
|
366 |
+
normalize_tfeat=True, max_n_proposals=90, model_type="cal"):
|
367 |
+
self.dset_name = dset_name
|
368 |
+
self.model_type = model_type
|
369 |
+
self.pool_local = model_type == "mcn" # pool local feature
|
370 |
+
self.eval_split_name = eval_split_name
|
371 |
+
self.ctx_mode = ctx_mode
|
372 |
+
self.load_gt_video = False
|
373 |
+
self.data_ratio = data_ratio # only affect query data
|
374 |
+
self.normalize_vfeat = normalize_vfeat
|
375 |
+
self.normalize_tfeat = normalize_tfeat
|
376 |
+
self.max_n_proposals = max_n_proposals
|
377 |
+
|
378 |
+
self.data_mode = None
|
379 |
+
self.set_data_mode(data_mode)
|
380 |
+
|
381 |
+
self.max_desc_len = max_desc_len
|
382 |
+
self.data_path = data_path
|
383 |
+
self.query_data = load_jsonl(data_path)
|
384 |
+
if data_ratio != 1:
|
385 |
+
n_examples = int(len(self.query_data) * data_ratio)
|
386 |
+
self.query_data = self.query_data[:n_examples]
|
387 |
+
logger.info("Using {}% of the data: {} examples".format(data_ratio * 100, n_examples))
|
388 |
+
if isinstance(desc_bert_path_or_handler, h5py.File):
|
389 |
+
self.desc_bert_h5 = desc_bert_path_or_handler
|
390 |
+
else:
|
391 |
+
self.desc_bert_h5 = h5py.File(desc_bert_path_or_handler, "r", driver=h5driver)
|
392 |
+
|
393 |
+
video_data = load_json(corpus_path)[self.eval_split_name]
|
394 |
+
self.video_data = [{"vid_name": k, "duration": v[0]} for k, v in video_data.items()]
|
395 |
+
self.video2idx = {k: v[1] for k, v in video_data.items()}
|
396 |
+
self.eval_proposal_bsz = eval_proposal_bsz
|
397 |
+
self.clip_length = clip_length
|
398 |
+
self.proposal_fn = get_proposal_interface(dset_name)
|
399 |
+
|
400 |
+
self.use_video = "video" in self.ctx_mode
|
401 |
+
self.use_sub = "sub" in self.ctx_mode
|
402 |
+
self.use_tef = "tef" in self.ctx_mode
|
403 |
+
|
404 |
+
if self.use_video:
|
405 |
+
if isinstance(vid_feat_path_or_handler, h5py.File):
|
406 |
+
self.vid_feat_h5 = vid_feat_path_or_handler
|
407 |
+
else: # str path
|
408 |
+
self.vid_feat_h5 = h5py.File(vid_feat_path_or_handler, "r", driver=h5driver)
|
409 |
+
|
410 |
+
if self.use_sub:
|
411 |
+
if isinstance(sub_bert_path_or_handler, h5py.File):
|
412 |
+
self.sub_bert_h5 = sub_bert_path_or_handler
|
413 |
+
else: # str path
|
414 |
+
self.sub_bert_h5 = h5py.File(sub_bert_path_or_handler, "r", driver=h5driver)
|
415 |
+
|
416 |
+
def set_data_mode(self, data_mode):
|
417 |
+
"""context or query"""
|
418 |
+
assert data_mode in ["context", "query"]
|
419 |
+
self.data_mode = data_mode
|
420 |
+
|
421 |
+
def load_gt_vid_name_for_query(self, load_gt_video):
|
422 |
+
"""load_gt_video: bool, affect the returned value of self._get_item_query"""
|
423 |
+
assert "vid_name" in self.query_data[0]
|
424 |
+
self.load_gt_video = load_gt_video
|
425 |
+
|
426 |
+
def __len__(self):
|
427 |
+
if self.data_mode == "context":
|
428 |
+
return len(self.video_data)
|
429 |
+
else:
|
430 |
+
return len(self.query_data)
|
431 |
+
|
432 |
+
def __getitem__(self, index):
|
433 |
+
if self.data_mode == "context":
|
434 |
+
return self._get_item_context(index)
|
435 |
+
else:
|
436 |
+
return self._get_item_query(index)
|
437 |
+
|
438 |
+
def _get_item_query(self, index):
|
439 |
+
"""Need to batch"""
|
440 |
+
raw_data = self.query_data[index]
|
441 |
+
|
442 |
+
meta = dict(
|
443 |
+
desc_id=raw_data["desc_id"],
|
444 |
+
desc=raw_data["desc"],
|
445 |
+
vid_name=raw_data["vid_name"] if self.load_gt_video else None
|
446 |
+
)
|
447 |
+
|
448 |
+
model_inputs = dict()
|
449 |
+
query_feat = self.desc_bert_h5[str(raw_data["desc_id"])][:self.max_desc_len]
|
450 |
+
if self.normalize_tfeat:
|
451 |
+
query_feat = l2_normalize_np_array(query_feat)
|
452 |
+
model_inputs["query_feat"] = torch.from_numpy(query_feat)
|
453 |
+
return dict(meta=meta, model_inputs=model_inputs)
|
454 |
+
|
455 |
+
def _get_item_context(self, index):
|
456 |
+
"""No need to batch, since it has already been batched here"""
|
457 |
+
raw_data = self.video_data[index]
|
458 |
+
|
459 |
+
# get proposals and sort by length in ascending order, for more efficient batching
|
460 |
+
proposals = self.proposal_fn(
|
461 |
+
video_id="", metadata={"duration": raw_data["duration"]}) # np.ndarray (N_p, 2)
|
462 |
+
proposals_lengths = proposals[:, 1] - proposals[:, 0] # seconds
|
463 |
+
sorted_proposal_indices = np.argsort(proposals_lengths)[:self.max_n_proposals]
|
464 |
+
sorted_proposals = proposals[sorted_proposal_indices]
|
465 |
+
|
466 |
+
# initialize with basic data
|
467 |
+
meta = dict(
|
468 |
+
vid_name=raw_data["vid_name"],
|
469 |
+
duration=raw_data["duration"],
|
470 |
+
proposals=sorted_proposals
|
471 |
+
)
|
472 |
+
model_inputs = dict()
|
473 |
+
|
474 |
+
n_proposal_batches = math.ceil(1.0 * len(sorted_proposals) / self.eval_proposal_bsz)
|
475 |
+
|
476 |
+
tef_batched_list = [None, ] * n_proposal_batches
|
477 |
+
t_moments_mask_list = [None, ] * n_proposal_batches
|
478 |
+
if self.use_tef:
|
479 |
+
tef_array = sorted_proposals / meta["duration"] # (N_p, 2)
|
480 |
+
for batch_idx in range(n_proposal_batches):
|
481 |
+
st_m_idx = batch_idx * self.eval_proposal_bsz
|
482 |
+
ed_m_idx = (batch_idx + 1) * self.eval_proposal_bsz
|
483 |
+
tef_batched_list[batch_idx] = tef_array[st_m_idx:ed_m_idx]
|
484 |
+
t_moments_mask_list[batch_idx] = \
|
485 |
+
np.ones((len(tef_batched_list[batch_idx]), 1), dtype=np.float32)
|
486 |
+
if not self.use_video and not self.use_sub:  # tef-only, neither video nor subtitle stream
|
487 |
+
model_inputs["video_moment_features_list"] = [
|
488 |
+
ProposalRetrievalDataset.concat_feat_adv(tef=t, ctx_mode=self.ctx_mode) for t in tef_batched_list]
|
489 |
+
model_inputs["video_moment_mask_list"] = [torch.from_numpy(e) for e in t_moments_mask_list]
|
490 |
+
|
491 |
+
# extract/group/pad
|
492 |
+
if self.use_video:
|
493 |
+
v_feat = self.vid_feat_h5[meta["vid_name"]] # (N_frm, D)
|
494 |
+
v_ctx_feat = np.mean(v_feat, axis=0) # (D, )
|
495 |
+
if self.normalize_vfeat:
|
496 |
+
v_ctx_feat = l2_normalize_np_array(v_ctx_feat)
|
497 |
+
v_padded_moments_features_list, v_moments_mask_list = \
|
498 |
+
self.get_batched_moment_feat_for_all_proposals(v_feat, sorted_proposals,
|
499 |
+
pool_local=self.pool_local,
|
500 |
+
normalize=self.normalize_vfeat)
|
501 |
+
|
502 |
+
model_inputs["video_moment_features_list"] = [ProposalRetrievalDataset.concat_feat_adv(
|
503 |
+
moment_feats=[v, v_ctx_feat], tef=t, ctx_mode=self.ctx_mode)
|
504 |
+
for v, t in zip(v_padded_moments_features_list, tef_batched_list)]
|
505 |
+
model_inputs["video_moment_mask_list"] = [torch.from_numpy(e) for e in v_moments_mask_list]
|
506 |
+
|
507 |
+
if self.use_sub:
|
508 |
+
s_feat = self.sub_bert_h5[meta["vid_name"]] # (N_frm, D)
|
509 |
+
s_ctx_feat = np.mean(s_feat, axis=0) # (D, )
|
510 |
+
if self.normalize_tfeat:
|
511 |
+
s_ctx_feat = l2_normalize_np_array(s_ctx_feat)
|
512 |
+
s_padded_moments_features_list, s_moments_mask_list = \
|
513 |
+
self.get_batched_moment_feat_for_all_proposals(s_feat, sorted_proposals,
|
514 |
+
pool_local=self.pool_local,
|
515 |
+
normalize=self.normalize_tfeat)
|
516 |
+
model_inputs["sub_moment_features_list"] = [ProposalRetrievalDataset.concat_feat_adv(
|
517 |
+
moment_feats=[s, s_ctx_feat], tef=t, ctx_mode=self.ctx_mode)
|
518 |
+
for s, t in zip(s_padded_moments_features_list, tef_batched_list)]
|
519 |
+
model_inputs["sub_moment_mask_list"] = [torch.from_numpy(e) for e in s_moments_mask_list]
|
520 |
+
return dict(meta=meta, model_inputs=model_inputs)
|
521 |
+
|
522 |
+
def get_batched_moment_feat_for_all_proposals(self, feature, moments, pool_local=False, normalize=True):
|
523 |
+
"""proposals of the same video wil be segmented into multiple batches to accomodate GPU memory
|
524 |
+
pool_local: pool local feature into a single vector
|
525 |
+
"""
|
526 |
+
n_proposal_batches = math.ceil(1.0 * len(moments) / self.eval_proposal_bsz)
|
527 |
+
padded_moments_features_list = [None, ] * n_proposal_batches
|
528 |
+
moments_mask_list = [None, ] * n_proposal_batches
|
529 |
+
moments_features = self.get_moment_feat_for_all_proposals(
|
530 |
+
feature, moments, normalize=normalize, pool_local=pool_local) # N_p * [(N_clips, D), ]
|
531 |
+
for batch_idx in range(n_proposal_batches):
|
532 |
+
st_m_idx = batch_idx * self.eval_proposal_bsz
|
533 |
+
ed_m_idx = (batch_idx + 1) * self.eval_proposal_bsz
|
534 |
+
padded_moments_features, moments_mask = \
|
535 |
+
pad_sequences_1d(moments_features[st_m_idx:ed_m_idx], dtype=np.float32)
|
536 |
+
padded_moments_features_list[batch_idx] = padded_moments_features
|
537 |
+
moments_mask_list[batch_idx] = moments_mask
|
538 |
+
assert np.sum(np.sum(moments_mask, axis=1) == 0) == 0, " err {}".format(moments_mask)
|
539 |
+
assert np.sum(np.sum(moments_mask_list[0], axis=1) == 0) == 0, " err {}".format(moments_mask_list)
|
540 |
+
return padded_moments_features_list, moments_mask_list
|
541 |
+
|
542 |
+
def get_moment_feat_for_all_proposals(self, vid_feat, moments, normalize=True, pool_local=False):
|
543 |
+
"""Each moment is comprised of multiple clips
|
544 |
+
Args:
|
545 |
+
vid_feat: np.ndarray, (N_clips, D)
|
546 |
+
moments: np.ndarray, (N_p, 2), each row is [st (float), ed (float)],
|
547 |
+
normalize: L2 normalize
|
548 |
+
pool_local:
|
549 |
+
Returns:
|
550 |
+
moments_features: list(np.ndarray), [(N_clips, D), ] * N_p, N_clips is changing.
|
551 |
+
"""
|
552 |
+
if normalize and not pool_local:
|
553 |
+
vid_feat = l2_normalize_np_array(vid_feat)
|
554 |
+
vid_feat_len = len(vid_feat)
|
555 |
+
moments_st_clip_indices = np.floor(moments[:, 0] / self.clip_length).astype(np.int64).clip(0, vid_feat_len-1)
|
556 |
+
moments_ed_clip_indices = np.ceil(moments[:, 1] / self.clip_length).astype(np.int64).clip(1, vid_feat_len)
|
557 |
+
moments_features = []
|
558 |
+
for st_idx, ed_idx, m in zip(moments_st_clip_indices, moments_ed_clip_indices, moments):
|
559 |
+
feat = vid_feat[st_idx:ed_idx]
|
560 |
+
if pool_local:
|
561 |
+
feat = np.mean(feat, axis=0, keepdims=True)
|
562 |
+
if normalize:
|
563 |
+
feat = l2_normalize_np_array(feat)
|
564 |
+
moments_features.append(feat)
|
565 |
+
return moments_features
|
566 |
+
|
567 |
+
|
568 |
+
def proposal_retrieval_collate(batch):
|
569 |
+
batch_meta = [e["meta"] for e in batch] # seems no need to collate ?
|
570 |
+
|
571 |
+
model_inputs_keys = batch[0]["model_inputs"].keys()
|
572 |
+
batched_data = {k: pad_sequences_1d([e["model_inputs"][k] for e in batch], dtype=torch.float32)
|
573 |
+
for k in model_inputs_keys}
|
574 |
+
return batch_meta, batched_data
|
575 |
+
|
576 |
+
|
577 |
+
def prepare_batch_inputs(batched_model_inputs, device, non_blocking=False):
|
578 |
+
model_inputs = {}
|
579 |
+
for k, v in batched_model_inputs.items():
|
580 |
+
model_inputs[k] = v[0].to(device, non_blocking=non_blocking)
|
581 |
+
model_inputs[k.replace("feat", "mask")] = v[1].to(device, non_blocking=non_blocking)
|
582 |
+
return model_inputs
|
583 |
+
|
584 |
+
|
585 |
+
if __name__ == '__main__':
|
586 |
+
from baselines.clip_alignment_with_language.config import BaseOptions
|
587 |
+
options = BaseOptions().parse()
|
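The eval dataset above is consumed in two passes: `context` mode returns pre-batched proposal features per video, while `query` mode returns per-query features that go through proposal_retrieval_collate and prepare_batch_inputs. A minimal usage sketch; the .h5 paths and batch sizes are illustrative placeholders, the jsonl/json paths follow the training scripts further below:

from torch.utils.data import DataLoader
from baselines.clip_alignment_with_language.proposal_retrieval_dataset import \
    ProposalRetrievalEvalDataset, proposal_retrieval_collate, prepare_batch_inputs

eval_dataset = ProposalRetrievalEvalDataset(
    dset_name="tvr", eval_split_name="val",
    data_path="data/tvr_val_release.jsonl",
    desc_bert_path_or_handler="path/to/desc_bert.h5",   # placeholder
    vid_feat_path_or_handler="path/to/vid_feat.h5",     # placeholder
    corpus_path="data/tvr_video2dur_idx.json",
    max_desc_len=30, clip_length=1.5, eval_proposal_bsz=200,
    ctx_mode="video_tef", data_mode="context")

# pass 1: per-video proposal features, already batched inside __getitem__
for idx in range(len(eval_dataset)):
    item = eval_dataset[idx]
    proposal_feat_batches = item["model_inputs"]["video_moment_features_list"]

# pass 2: queries, padded by the collate function and moved to the device
eval_dataset.set_data_mode("query")
query_loader = DataLoader(eval_dataset, batch_size=50, collate_fn=proposal_retrieval_collate)
for batch_meta, batched_inputs in query_loader:
    model_inputs = prepare_batch_inputs(batched_inputs, device="cuda")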
baselines/clip_alignment_with_language/scripts/compute_upper_bound.sh
ADDED
@@ -0,0 +1,23 @@
1 |
+
#!/usr/bin/env bash
|
2 |
+
# run at project root dir
|
3 |
+
dset_name=$1 # see case below
|
4 |
+
split_name=$2 # train/val/test, some datasets may not support all the 3 splits
|
5 |
+
result_dir="baselines/clip_alignment_with_language/results"
|
6 |
+
|
7 |
+
echo "Running with dataset ${dset_name} with split ${split_name}"
|
8 |
+
case ${dset_name} in
|
9 |
+
tvr) # only supports train/val
|
10 |
+
eval_file_path=data/tvr_${split_name}_release.jsonl
|
11 |
+
save_path=${result_dir}/tvr_${split_name}_proposal_upper_bound.json
|
12 |
+
;;
|
13 |
+
*)
|
14 |
+
echo -n "Unknown argument"
|
15 |
+
;;
|
16 |
+
esac
|
17 |
+
|
18 |
+
echo "Running evaluation"
|
19 |
+
python baselines/clip_alignment_with_language/local_utils/compute_proposal_upper_bound.py \
|
20 |
+
-dset_name=${dset_name} \
|
21 |
+
-eval_file_path=${eval_file_path} \
|
22 |
+
-save_path=${save_path} \
|
23 |
+
-verbose
|
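A typical invocation, run from the project root (the case statement above only handles the tvr train/val splits):

bash baselines/clip_alignment_with_language/scripts/compute_upper_bound.sh tvr val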
baselines/clip_alignment_with_language/scripts/inference.sh
ADDED
@@ -0,0 +1,17 @@
1 |
+
#!/usr/bin/env bash
|
2 |
+
# run at project root dir
|
3 |
+
# Usage:
|
4 |
+
# bash baselines/clip_alignment_with_language/scripts/inference.sh ANY_OTHER_PYTHON_ARGS
|
5 |
+
model_dir=$1
|
6 |
+
eval_split_name=$2
|
7 |
+
eval_path=data/tvr_${eval_split_name}_release.jsonl
|
8 |
+
tasks=(VR)
|
9 |
+
tasks+=(SVMR)
|
10 |
+
tasks+=(VCMR)
|
11 |
+
echo "tasks ${tasks[@]}"
|
12 |
+
python baselines/clip_alignment_with_language/inference.py \
|
13 |
+
--model_dir ${model_dir} \
|
14 |
+
--tasks ${tasks[@]} \
|
15 |
+
--eval_split_name ${eval_split_name} \
|
16 |
+
--eval_path ${eval_path} \
|
17 |
+
${@:3}
|
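Example invocation from the project root; the model directory name below is one of the result directories referenced elsewhere in these scripts and stands in for whichever trained model is being evaluated:

bash baselines/clip_alignment_with_language/scripts/inference.sh \
    tvr-cal-video_sub-res-2019_11_05_14_32_59 val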
baselines/clip_alignment_with_language/scripts/inference_mix.sh
ADDED
@@ -0,0 +1,27 @@
1 |
+
#!/usr/bin/env bash
|
2 |
+
# run at project root dir
|
3 |
+
# Usage:
|
4 |
+
# bash baselines/clip_alignment_with_language/scripts/inference_mix.sh
|
5 |
+
eval_model=$1 # [mcn, cal], retrain models should only be paired with mee
|
6 |
+
project_root=/net/bvisionserver14/playpen-ssd/jielei/projects/video_retrieval/baselines/clip_alignment_with_language/results
|
7 |
+
|
8 |
+
# setup eval model
|
9 |
+
if [[ ${eval_model} == mcn ]]; then
|
10 |
+
pred_dir=tvr-mcn-video_sub-res-2019_11_05_14_16_40
|
11 |
+
tef_pred_dir=tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57
|
12 |
+
elif [[ ${eval_model} == cal ]]; then
|
13 |
+
pred_dir=tvr-cal-video_sub-res-2019_11_05_14_32_59
|
14 |
+
tef_pred_dir=tvr-cal-video_sub_tef-res-2019_11_05_14_25_49
|
15 |
+
fi
|
16 |
+
|
17 |
+
pred_path=${project_root}/${pred_dir}/inference_tvr_test_public_max200_predictions_VR_SVMR_VCMR.json
|
18 |
+
save_path=${project_root}/${pred_dir}/inference_tvr_test_public_max200_predictions_VR_SVMR_VCMR_rerank_${tef_pred_dir}.json
|
19 |
+
tef_pred_path=${project_root}/${tef_pred_dir}/inference_tvr_test_public_max10000_predictions_VCMR.pt
|
20 |
+
gt_path=data/tvr_test_public_archive.jsonl
|
21 |
+
|
22 |
+
|
23 |
+
python baselines/clip_alignment_with_language/mix_model_prediction.py \
|
24 |
+
--pred_path=${pred_path} \
|
25 |
+
--tef_pred_path=${tef_pred_path} \
|
26 |
+
--gt_path=${gt_path} \
|
27 |
+
--save_path=${save_path}
|
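This script re-ranks the predictions of a non-TEF model with the VCMR predictions of its TEF counterpart; the result directories are hard-coded above, so only the model family is selected on the command line:

bash baselines/clip_alignment_with_language/scripts/inference_mix.sh cal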
baselines/clip_alignment_with_language/scripts/inference_with_external.sh
ADDED
@@ -0,0 +1,54 @@
1 |
+
#!/usr/bin/env bash
|
2 |
+
# run at project root dir
|
3 |
+
# Usage:
|
4 |
+
# bash baselines/clip_alignment_with_language/scripts/inference_with_external.sh
|
5 |
+
#model_dir=$1
|
6 |
+
# DO not use NMS, since it gives worse results
|
7 |
+
eval_model=$1 # [mcn, mcn_tef, cal, cal_tef, mcn_tef_retrain, cal_tef_retrain], retrained models should only be paired with mee
|
8 |
+
external_model=$2 # [mee, mcn, cal]
|
9 |
+
eval_split_name=$3
|
10 |
+
eval_path=data/tvr_${eval_split_name}_release.jsonl
|
11 |
+
project_root=/net/bvisionserver14/playpen-ssd/jielei/projects/video_retrieval/baselines
|
12 |
+
|
13 |
+
# setup eval model
|
14 |
+
if [[ ${eval_model} == mcn ]]; then
|
15 |
+
eval_model_dir=tvr-mcn-video_sub-res-2019_11_05_14_16_40
|
16 |
+
elif [[ ${eval_model} == mcn_tef ]]; then
|
17 |
+
eval_model_dir=tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57
|
18 |
+
elif [[ ${eval_model} == cal ]]; then
|
19 |
+
eval_model_dir=tvr-cal-video_sub-res-2019_11_05_14_32_59
|
20 |
+
elif [[ ${eval_model} == cal_tef ]]; then
|
21 |
+
eval_model_dir=tvr-cal-video_sub_tef-res-2019_11_05_14_25_49
|
22 |
+
elif [[ ${eval_model} == mcn_tef_retrain ]]; then
|
23 |
+
eval_model_dir=tvr-mcn-video_sub_tef-+ex_vr_mee_tvr-video_sub-res-2019_11_06_00_33_39_tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57+-2019_11_06_02_26_49
|
24 |
+
elif [[ ${eval_model} == cal_tef_retrain ]]; then
|
25 |
+
eval_model_dir=tvr-cal-video_sub_tef-+ex_vr_mee_tvr-video_sub-res-2019_11_06_00_33_39_tvr-cal-video_sub_tef-res-2019_11_05_14_25_49+-2019_11_06_03_12_15
|
26 |
+
fi
|
27 |
+
|
28 |
+
# setup external
|
29 |
+
if [[ ${external_model} == mee ]]; then
|
30 |
+
external_model_dir=tvr-video_sub-res-2019_11_06_00_33_39
|
31 |
+
external_inference_vr_res_path=${project_root}/mixture_embedding_experts/results/${external_model_dir}/inference_tvr_${eval_split_name}_None_predictions_VR.json
|
32 |
+
elif [[ ${external_model} == mcn ]]; then
|
33 |
+
external_model_dir=tvr-mcn-video_sub-res-2019_11_05_14_16_40
|
34 |
+
external_inference_vr_res_path=${project_root}/clip_alignment_with_language/results/${external_model_dir}/inference_tvr_${eval_split_name}_None_predictions_VR_SVMR_VCMR.json
|
35 |
+
elif [[ ${external_model} == cal ]]; then
|
36 |
+
external_model_dir=tvr-cal-video_sub-res-2019_11_05_14_32_59
|
37 |
+
external_inference_vr_res_path=${project_root}/clip_alignment_with_language/results/${external_model_dir}/inference_tvr_${eval_split_name}_None_predictions_VR_SVMR_VCMR.json
|
38 |
+
fi
|
39 |
+
|
40 |
+
tasks=(VR)
|
41 |
+
tasks+=(SVMR)
|
42 |
+
tasks+=(VCMR)
|
43 |
+
echo "tasks ${tasks[@]}"
|
44 |
+
python baselines/clip_alignment_with_language/inference.py \
|
45 |
+
--model_dir ${eval_model_dir} \
|
46 |
+
--tasks ${tasks[@]} \
|
47 |
+
--eval_split_name ${eval_split_name} \
|
48 |
+
--eval_path ${eval_path} \
|
49 |
+
--external_inference_vr_res_path ${external_inference_vr_res_path} \
|
50 |
+
--eval_id ${external_model_dir} \
|
51 |
+
${@:4}
|
52 |
+
|
53 |
+
#--use_intermediate \ # temporary removed
|
54 |
+
|
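Typical usage, with the evaluated model and the external VR model chosen from the hard-coded directories above:

bash baselines/clip_alignment_with_language/scripts/inference_with_external.sh cal_tef mee val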
baselines/clip_alignment_with_language/scripts/re_train_cal.sh
ADDED
@@ -0,0 +1,21 @@
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
lr=0.00005
|
4 |
+
n_epoch=20
|
5 |
+
project_root=/net/bvisionserver14/playpen-ssd/jielei/projects/video_retrieval
|
6 |
+
ckpt_filename="model.ckpt"
|
7 |
+
init_ckpt_path=${project_root}/baselines/clip_alignment_with_language/results/tvr-cal-video_sub_tef-res-2019_11_05_14_25_49/${ckpt_filename}
|
8 |
+
exp_id=+ex_vr_mee_tvr-video_sub-res-2019_11_06_00_33_39_tvr-cal-video_sub_tef-res-2019_11_05_14_25_49+
|
9 |
+
external_train_vr_res_path=${project_root}/baselines/mixture_embedding_experts/results/tvr-video_sub-res-2019_11_06_00_33_39/inference_tvr_train_None_predictions_VR.json
|
10 |
+
model_type=cal
|
11 |
+
|
12 |
+
bash baselines/clip_alignment_with_language/scripts/train.sh tvr video_sub_tef resnet_i3d \
|
13 |
+
--no_norm_vfeat \
|
14 |
+
--model_type ${model_type} \
|
15 |
+
--exp_id ${exp_id} \
|
16 |
+
--init_ckpt_path ${init_ckpt_path} \
|
17 |
+
--external_train_vr_res_path ${external_train_vr_res_path} \
|
18 |
+
--lr ${lr} \
|
19 |
+
--n_epoch ${n_epoch} \
|
20 |
+
--max_es_cnt 5 \
|
21 |
+
${@:1}
|
baselines/clip_alignment_with_language/scripts/re_train_mcn.sh
ADDED
@@ -0,0 +1,21 @@
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
lr=0.00005
|
4 |
+
n_epoch=20
|
5 |
+
project_root=/net/bvisionserver14/playpen-ssd/jielei/projects/video_retrieval
|
6 |
+
ckpt_filename="model.ckpt"
|
7 |
+
init_ckpt_path=${project_root}/baselines/clip_alignment_with_language/results/tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57/${ckpt_filename}
|
8 |
+
exp_id=+ex_vr_mee_tvr-video_sub-res-2019_11_06_00_33_39_tvr-mcn-video_sub_tef-res-2019_11_05_14_14_57+
|
9 |
+
external_train_vr_res_path=${project_root}/baselines/mixture_embedding_experts/results/tvr-video_sub-res-2019_11_06_00_33_39/inference_tvr_train_None_predictions_VR.json
|
10 |
+
model_type=mcn
|
11 |
+
|
12 |
+
bash baselines/clip_alignment_with_language/scripts/train.sh tvr video_sub_tef resnet_i3d \
|
13 |
+
--no_norm_vfeat \
|
14 |
+
--model_type ${model_type} \
|
15 |
+
--exp_id ${exp_id} \
|
16 |
+
--init_ckpt_path ${init_ckpt_path} \
|
17 |
+
--external_train_vr_res_path ${external_train_vr_res_path} \
|
18 |
+
--lr ${lr} \
|
19 |
+
--n_epoch ${n_epoch} \
|
20 |
+
--max_es_cnt 5 \
|
21 |
+
${@:1}
|
baselines/clip_alignment_with_language/scripts/train.sh
ADDED
@@ -0,0 +1,80 @@
1 |
+
#!/usr/bin/env bash
|
2 |
+
# run at project root dir
|
3 |
+
# Usage:
|
4 |
+
# bash baselines/clip_alignment_with_language/scripts/train.sh tvr all ANY_OTHER_PYTHON_ARGS
|
5 |
+
# if re-training, also give --init_ckpt_path and --external_train_vr_res_path, and consider using a lower lr
|
6 |
+
dset_name=$1 # see case below
|
7 |
+
ctx_mode=$2 # ["video", "sub", "tef", "video_sub", "video_tef", "sub_tef", "video_sub_tef"]
|
8 |
+
vid_feat_type=$3 # [resnet, i3d, resnet_i3d, none] , none for subtitles only models
|
9 |
+
feature_root=data/tvr_feature_release
|
10 |
+
results_root=baselines/clip_alignment_with_language/results
|
11 |
+
vid_feat_size=2048
|
12 |
+
extra_args=()
|
13 |
+
|
14 |
+
if [[ ${ctx_mode} == *"sub"* ]] || [[ ${ctx_mode} == "sub" ]]; then
|
15 |
+
if [[ ${dset_name} != "tvr" ]]; then
|
16 |
+
echo "The use of subtitles is only supported in tvr."
|
17 |
+
exit 1
|
18 |
+
fi
|
19 |
+
fi
|
20 |
+
|
21 |
+
|
22 |
+
case ${dset_name} in
|
23 |
+
tvr)
|
24 |
+
train_path=data/tvr_train_release.jsonl
|
25 |
+
corpus_path=data/tvr_video2dur_idx.json
|
26 |
+
desc_bert_path=${feature_root}/bert_feature/query_only/tvr_query_pretrained_w_query.h5
|
27 |
+
vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_cl-1.5.h5
|
28 |
+
clip_length=1.5
|
29 |
+
eval_split_name=val
|
30 |
+
nms_thd=-1
|
31 |
+
extra_args+=(--eval_path)
|
32 |
+
extra_args+=(data/tvr_val_release.jsonl)
|
33 |
+
|
34 |
+
if [[ ${vid_feat_type} == "i3d" ]]; then
|
35 |
+
echo "Using I3D feature with shape 1024"
|
36 |
+
vid_feat_path=${feature_root}/video_feature/tvr_i3d_rgb600_avg_cl-1.5.h5
|
37 |
+
vid_feat_size=1024
|
38 |
+
elif [[ ${vid_feat_type} == "resnet" ]]; then
|
39 |
+
echo "Using ResNet feature with shape 2048"
|
40 |
+
vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_cl-1.5.h5
|
41 |
+
vid_feat_size=2048
|
42 |
+
elif [[ ${vid_feat_type} == "resnet_i3d" ]]; then
|
43 |
+
echo "Using concatenated ResNet and I3D feature with shape 2048+1024"
|
44 |
+
vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_i3d_rgb600_avg_cat_cl-1.5.h5
|
45 |
+
vid_feat_size=3072
|
46 |
+
extra_args+=(--no_norm_vfeat) # since they are already normalized.
|
47 |
+
fi
|
48 |
+
|
49 |
+
if [[ ${ctx_mode} == *"sub"* ]] || [[ ${ctx_mode} == "sub" ]]; then
|
50 |
+
echo "Running with sub."
|
51 |
+
desc_bert_path=${feature_root}/bert_feature/sub_query/tvr_query_pretrained_w_sub_query.h5 # overwrite
|
52 |
+
sub_bert_path=${feature_root}/bert_feature/sub_query/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5
|
53 |
+
sub_feat_size=768
|
54 |
+
extra_args+=(--sub_feat_size)
|
55 |
+
extra_args+=(${sub_feat_size})
|
56 |
+
extra_args+=(--sub_bert_path)
|
57 |
+
extra_args+=(${sub_bert_path})
|
58 |
+
fi
|
59 |
+
;;
|
60 |
+
*)
|
61 |
+
echo -n "Unknown argument"
|
62 |
+
;;
|
63 |
+
esac
|
64 |
+
|
65 |
+
echo "Start training with dataset [${dset_name}] in Context Mode [${ctx_mode}]"
|
66 |
+
echo "Extra args ${extra_args[@]}"
|
67 |
+
python baselines/clip_alignment_with_language/train.py \
|
68 |
+
--dset_name=${dset_name} \
|
69 |
+
--eval_split_name=${eval_split_name} \
|
70 |
+
--nms_thd=${nms_thd} \
|
71 |
+
--results_root=${results_root} \
|
72 |
+
--train_path=${train_path} \
|
73 |
+
--desc_bert_path=${desc_bert_path} \
|
74 |
+
--corpus_path=${corpus_path} \
|
75 |
+
--vid_feat_path=${vid_feat_path} \
|
76 |
+
--clip_length=${clip_length} \
|
77 |
+
--vid_feat_size=${vid_feat_size} \
|
78 |
+
--ctx_mode=${ctx_mode} \
|
79 |
+
${extra_args[@]} \
|
80 |
+
${@:4}
|
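For reference, an invocation matching the argument parsing above (positional args: dset_name, ctx_mode, vid_feat_type; everything after the third argument is forwarded to train.py). The --exp_id value is a placeholder; --model_type and --exp_id mirror the flags passed by the re_train_*.sh scripts:

bash baselines/clip_alignment_with_language/scripts/train.sh tvr video_sub_tef resnet_i3d \
    --model_type cal --exp_id my_cal_run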
baselines/clip_alignment_with_language/train.py
ADDED
@@ -0,0 +1,310 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import json
|
4 |
+
import pprint
|
5 |
+
import random
|
6 |
+
import numpy as np
|
7 |
+
from collections import OrderedDict
|
8 |
+
from easydict import EasyDict as EDict
|
9 |
+
from tqdm import tqdm, trange
|
10 |
+
|
11 |
+
import torch
|
12 |
+
import torch.nn as nn
|
13 |
+
import torch.backends.cudnn as cudnn
|
14 |
+
from torch.utils.data import DataLoader
|
15 |
+
from torch.utils.tensorboard import SummaryWriter
|
16 |
+
|
17 |
+
from baselines.clip_alignment_with_language.config import BaseOptions
|
18 |
+
from baselines.clip_alignment_with_language.model import CALWithSub
|
19 |
+
from baselines.clip_alignment_with_language.proposal_retrieval_dataset import \
|
20 |
+
ProposalRetrievalDataset, proposal_retrieval_collate, ProposalRetrievalEvalDataset, prepare_batch_inputs
|
21 |
+
from baselines.clip_alignment_with_language.inference import eval_epoch, start_inference
|
22 |
+
from utils.basic_utils import save_jsonl, save_json, AverageMeter
|
23 |
+
from utils.model_utils import count_parameters
|
24 |
+
|
25 |
+
|
26 |
+
import logging
|
27 |
+
logger = logging.getLogger(__name__)
|
28 |
+
logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s",
|
29 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
30 |
+
level=logging.INFO)
|
31 |
+
|
32 |
+
|
33 |
+
def set_seed(seed, use_cuda=True):
|
34 |
+
random.seed(seed)
|
35 |
+
np.random.seed(seed)
|
36 |
+
torch.manual_seed(seed)
|
37 |
+
if use_cuda:
|
38 |
+
torch.cuda.manual_seed_all(seed)
|
39 |
+
|
40 |
+
|
41 |
+
def train_epoch(model, train_loader, optimizer, opt, epoch_i):
|
42 |
+
model.train()
|
43 |
+
|
44 |
+
# init meters
|
45 |
+
dataloading_time = AverageMeter()
|
46 |
+
prepare_inputs_time = AverageMeter()
|
47 |
+
model_forward_time = AverageMeter()
|
48 |
+
model_backward_time = AverageMeter()
|
49 |
+
loss_meter = AverageMeter()
|
50 |
+
|
51 |
+
num_training_examples = len(train_loader)
|
52 |
+
timer_dataloading = time.time()
|
53 |
+
for batch_idx, batch in tqdm(enumerate(train_loader),
|
54 |
+
desc="Training Iteration",
|
55 |
+
total=num_training_examples):
|
56 |
+
dataloading_time.update(time.time() - timer_dataloading)
|
57 |
+
|
58 |
+
# continue
|
59 |
+
timer_start = time.time()
|
60 |
+
model_inputs = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory)
|
61 |
+
prepare_inputs_time.update(time.time() - timer_start)
|
62 |
+
# logger.info("model_inputs {}"
|
63 |
+
# .format({k: (type(k), v.shape if isinstance(v, torch.Tensor) else v)
|
64 |
+
# for k, v in model_inputs.items()}))
|
65 |
+
# logger.info("model_inputs \n{}".format({k: (type(v), v.shape, v.dtype) for k, v in model_inputs.items()}))
|
66 |
+
timer_start = time.time()
|
67 |
+
loss = model(**model_inputs)
|
68 |
+
model_forward_time.update(time.time() - timer_start)
|
69 |
+
timer_start = time.time()
|
70 |
+
optimizer.zero_grad()
|
71 |
+
loss.backward()
|
72 |
+
if opt.grad_clip != -1:
|
73 |
+
nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
|
74 |
+
optimizer.step()
|
75 |
+
model_backward_time.update(time.time() - timer_start)
|
76 |
+
|
77 |
+
global_step = epoch_i * num_training_examples + batch_idx
|
78 |
+
opt.writer.add_scalar("Train/LR", float(optimizer.param_groups[0]["lr"]), global_step)
|
79 |
+
opt.writer.add_scalar("Train/Loss", float(loss), global_step)
|
80 |
+
loss_meter.update(float(loss))
|
81 |
+
|
82 |
+
timer_dataloading = time.time()
|
83 |
+
if opt.debug and batch_idx == 3:
|
84 |
+
break
|
85 |
+
to_write = opt.train_log_txt_formatter.format(
|
86 |
+
time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
|
87 |
+
epoch=epoch_i,
|
88 |
+
loss_str=str(loss_meter.avg))
|
89 |
+
with open(opt.train_log_filepath, "a") as f:
|
90 |
+
f.write(to_write)
|
91 |
+
print("Epoch time stats:")
|
92 |
+
print("dataloading_time: max {dataloading_time.max} "
|
93 |
+
"min {dataloading_time.min} avg {dataloading_time.avg}\n"
|
94 |
+
"prepare_inputs_time: max {prepare_inputs_time.max} "
|
95 |
+
"min {prepare_inputs_time.min} avg {prepare_inputs_time.avg}\n"
|
96 |
+
"model_forward_time: max {model_forward_time.max} "
|
97 |
+
"min {model_forward_time.min} avg {model_forward_time.avg}\n"
|
98 |
+
"model_backward_time: max {model_backward_time.max} "
|
99 |
+
"min {model_backward_time.min} avg {model_backward_time.avg}\n"
|
100 |
+
"".format(dataloading_time=dataloading_time, prepare_inputs_time=prepare_inputs_time,
|
101 |
+
model_forward_time=model_forward_time, model_backward_time=model_backward_time))
|
102 |
+
|
103 |
+
|
104 |
+
def train(model, train_dataset, val_dataset, opt):
|
105 |
+
# Prepare optimizer
|
106 |
+
optimizer = torch.optim.SGD(
|
107 |
+
filter(lambda p: p.requires_grad, model.parameters()),
|
108 |
+
lr=opt.lr,
|
109 |
+
weight_decay=opt.wd,
|
110 |
+
momentum=opt.momentum)
|
111 |
+
# decay the lr by a factor of 10 every 30 epochs
|
112 |
+
scheduler = torch.optim.lr_scheduler.StepLR(
|
113 |
+
optimizer,
|
114 |
+
step_size=30,
|
115 |
+
gamma=0.1
|
116 |
+
)
|
117 |
+
|
118 |
+
train_loader = DataLoader(train_dataset,
|
119 |
+
collate_fn=proposal_retrieval_collate,
|
120 |
+
batch_size=opt.bsz,
|
121 |
+
num_workers=opt.num_workers,
|
122 |
+
shuffle=True,
|
123 |
+
pin_memory=opt.pin_memory)
|
124 |
+
|
125 |
+
prev_best_score = 0.
|
126 |
+
es_cnt = 0
|
127 |
+
start_epoch = -1 if opt.eval_untrained else 0
|
128 |
+
eval_tasks_at_training = ["SVMR", ]
|
129 |
+
save_submission_filename = \
|
130 |
+
"latest_{}_{}_predictions_{}.json".format(opt.dset_name, opt.eval_split_name, "_".join(eval_tasks_at_training))
|
131 |
+
for epoch_i in trange(start_epoch, opt.n_epoch, desc="Epoch"):
|
132 |
+
if epoch_i > -1:
|
133 |
+
with torch.autograd.detect_anomaly():
|
134 |
+
train_epoch(model, train_loader, optimizer, opt, epoch_i)
|
135 |
+
global_step = (epoch_i + 1) * len(train_loader)
|
136 |
+
scheduler.step()
|
137 |
+
if opt.eval_path is not None:
|
138 |
+
with torch.no_grad():
|
139 |
+
metrics_no_nms, metrics_nms, latest_file_paths = \
|
140 |
+
eval_epoch(model, val_dataset, opt, save_submission_filename, tasks=eval_tasks_at_training,
|
141 |
+
max_before_nms=300, max_after_nms=100)
|
142 |
+
logger.info("metrics_no_nms {}".format(
|
143 |
+
pprint.pformat(rm_key_from_odict(metrics_no_nms, rm_suffix="by_type"), indent=4)))
|
144 |
+
logger.info("metrics_nms \n{}".format(pprint.pformat(metrics_nms, indent=4)))
|
145 |
+
|
146 |
+
to_write = opt.eval_log_txt_formatter.format(
|
147 |
+
time_str=time.strftime("%Y_%m_%d_%H_%M_%S"),
|
148 |
+
epoch=epoch_i,
|
149 |
+
eval_metrics_str=json.dumps(metrics_no_nms))
|
150 |
+
with open(opt.eval_log_filepath, "a") as f:
|
151 |
+
f.write(to_write)
|
152 |
+
|
153 |
+
# metrics = metrics_nms if metrics_nms is not None else metrics_no_nms
|
154 |
+
metrics = metrics_no_nms
|
155 |
+
# early stop/ log / save model
|
156 |
+
for task_type, task_metrics in metrics.items():
|
157 |
+
for iou_thd in [0.5, 0.7]:
|
158 |
+
opt.writer.add_scalars("Eval/{}-{}".format(task_type, iou_thd),
|
159 |
+
{k: v for k, v in task_metrics.items() if str(iou_thd) in k},
|
160 |
+
global_step)
|
161 |
+
|
162 |
+
# use the most strict metric available
|
163 |
+
if metrics["SVMR"]["0.5-r1"] > prev_best_score:
|
164 |
+
es_cnt = 0
|
165 |
+
prev_best_score = metrics["SVMR"]["0.5-r1"]
|
166 |
+
|
167 |
+
checkpoint = {
|
168 |
+
"model": model.state_dict(),
|
169 |
+
"model_cfg": model.config,
|
170 |
+
"epoch": epoch_i}
|
171 |
+
torch.save(checkpoint, opt.ckpt_filepath)
|
172 |
+
|
173 |
+
best_file_paths = [e.replace("latest", "best") for e in latest_file_paths]
|
174 |
+
for src, tgt in zip(latest_file_paths, best_file_paths):
|
175 |
+
os.renames(src, tgt)
|
176 |
+
logger.info("The checkpoint file has been updated.")
|
177 |
+
else:
|
178 |
+
es_cnt += 1
|
179 |
+
if es_cnt > opt.max_es_cnt: # early stop
|
180 |
+
with open(opt.train_log_filepath, "a") as f:
|
181 |
+
f.write("Early Stop at epoch {}".format(epoch_i))
|
182 |
+
logger.info("Early stop at {} with SVMR 0.5-r1 {}".format(epoch_i, prev_best_score))
|
183 |
+
break
|
184 |
+
else:
|
185 |
+
checkpoint = {
|
186 |
+
"model": model.state_dict(),
|
187 |
+
"model_cfg": model.config,
|
188 |
+
"epoch": epoch_i}
|
189 |
+
torch.save(checkpoint, opt.ckpt_filepath)
|
190 |
+
|
191 |
+
if opt.debug:
|
192 |
+
break
|
193 |
+
|
194 |
+
opt.writer.close()
|
195 |
+
|
196 |
+
|
197 |
+
def rm_key_from_odict(odict_obj, rm_suffix):
|
198 |
+
"""remove key entry from the OrderedDict"""
|
199 |
+
return OrderedDict([(k, v) for k, v in odict_obj.items() if rm_suffix not in k])
|
200 |
+
|
201 |
+
|
202 |
+
def start_training():
|
203 |
+
logger.info("Setup config, data and model...")
|
204 |
+
opt = BaseOptions().parse()
|
205 |
+
set_seed(opt.seed)
|
206 |
+
if opt.debug: # keep the model run deterministically
|
207 |
+
# 'cudnn.benchmark = True' enables auto-selection of the best algorithm for a fixed input/net config.
|
208 |
+
# Enable this only when input size is fixed.
|
209 |
+
cudnn.benchmark = False
|
210 |
+
cudnn.deterministic = True
|
211 |
+
|
212 |
+
opt.writer = SummaryWriter(opt.tensorboard_log_dir)
|
213 |
+
opt.train_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Loss] {loss_str}\n"
|
214 |
+
opt.eval_log_txt_formatter = "{time_str} [Epoch] {epoch:03d} [Metrics] {eval_metrics_str}\n"
|
215 |
+
|
216 |
+
train_dataset = ProposalRetrievalDataset(
|
217 |
+
dset_name=opt.dset_name,
|
218 |
+
model_type=opt.model_type,
|
219 |
+
data_path=opt.train_path,
|
220 |
+
desc_bert_path=opt.desc_bert_path,
|
221 |
+
sub_bert_path=opt.sub_bert_path,
|
222 |
+
max_desc_len=opt.max_desc_l,
|
223 |
+
vid_feat_path=opt.vid_feat_path,
|
224 |
+
clip_length=opt.clip_length,
|
225 |
+
vid_feat_size=opt.vid_feat_size,
|
226 |
+
sub_feat_size=opt.sub_feat_size,
|
227 |
+
ctx_mode=opt.ctx_mode,
|
228 |
+
pos_iou_thd=opt.pos_iou_thd,
|
229 |
+
neg_iou_thd=opt.neg_iou_thd,
|
230 |
+
h5driver=opt.h5driver,
|
231 |
+
data_ratio=opt.data_ratio,
|
232 |
+
normalize_vfeat=not opt.no_norm_vfeat,
|
233 |
+
normalize_tfeat=not opt.no_norm_tfeat,
|
234 |
+
external_train_vr_res_path=opt.external_train_vr_res_path, # If not None, used to guide negative sampling
|
235 |
+
corpus_path=opt.corpus_path,
|
236 |
+
)
|
237 |
+
|
238 |
+
if opt.eval_path is not None:
|
239 |
+
eval_dataset = ProposalRetrievalEvalDataset(
|
240 |
+
dset_name=opt.dset_name,
|
241 |
+
model_type=opt.model_type,
|
242 |
+
eval_split_name=opt.eval_split_name, # should only be val set
|
243 |
+
data_path=opt.eval_path,
|
244 |
+
desc_bert_path_or_handler=train_dataset.desc_bert_h5,
|
245 |
+
sub_bert_path_or_handler=train_dataset.sub_bert_h5 if "sub" in opt.ctx_mode else None,
|
246 |
+
max_desc_len=opt.max_desc_l,
|
247 |
+
corpus_path=opt.corpus_path,
|
248 |
+
vid_feat_path_or_handler=train_dataset.vid_feat_h5 if "video" in opt.ctx_mode else None,
|
249 |
+
clip_length=opt.clip_length,
|
250 |
+
eval_proposal_bsz=opt.eval_proposal_bsz,
|
251 |
+
ctx_mode=opt.ctx_mode,
|
252 |
+
data_mode="query",
|
253 |
+
h5driver=opt.h5driver,
|
254 |
+
data_ratio=opt.data_ratio,
|
255 |
+
normalize_vfeat=not opt.no_norm_vfeat,
|
256 |
+
normalize_tfeat=not opt.no_norm_tfeat,
|
257 |
+
)
|
258 |
+
else:
|
259 |
+
eval_dataset = None
|
260 |
+
|
261 |
+
model_config = EDict(
|
262 |
+
visual_input_size=train_dataset.vid_feat_output_size, # changes based on visual input type
|
263 |
+
textual_input_size=train_dataset.sub_feat_output_size,
|
264 |
+
query_feat_size=opt.desc_feat_size,
|
265 |
+
visual_hidden_size=opt.visual_hidden_size, #
|
266 |
+
output_size=opt.output_size,
|
267 |
+
embedding_size=opt.embedding_size,
|
268 |
+
lstm_hidden_size=opt.lstm_hidden_size,
|
269 |
+
margin=opt.margin, # margin for ranking loss
|
270 |
+
loss_type=opt.loss_type, # loss type, 'hinge' or 'lse'
|
271 |
+
inter_loss_weight=opt.inter_loss_weight * (opt.ctx_mode == "tef"), # weight for inter negatives
|
272 |
+
ctx_mode=opt.ctx_mode
|
273 |
+
)
|
274 |
+
logger.info("model_config {}".format(model_config))
|
275 |
+
|
276 |
+
model = CALWithSub(model_config)
|
277 |
+
if opt.device.type == "cuda":
|
278 |
+
logger.info("CUDA enabled.")
|
279 |
+
model.to(opt.device)
|
280 |
+
if len(opt.device_ids) > 1:
|
281 |
+
logger.info("Use multi GPU", opt.device_ids)
|
282 |
+
model = torch.nn.DataParallel(model, device_ids=opt.device_ids) # use multi GPU
|
283 |
+
|
284 |
+
if opt.init_ckpt_path is not None:
|
285 |
+
checkpoint = torch.load(opt.init_ckpt_path)
|
286 |
+
model.load_state_dict(checkpoint["model"])
|
287 |
+
logger.info("Loaded model saved at epoch {} from checkpoint: {}"
|
288 |
+
.format(checkpoint["epoch"], opt.init_ckpt_path))
|
289 |
+
count_parameters(model)
|
290 |
+
|
291 |
+
logger.info("Start Training...")
|
292 |
+
train(model, train_dataset, eval_dataset, opt)
|
293 |
+
return opt.results_dir, opt.eval_split_name, opt.eval_path, opt.debug
|
294 |
+
|
295 |
+
|
296 |
+
if __name__ == '__main__':
|
297 |
+
model_dir, eval_split_name, eval_path, debug = start_training()
|
298 |
+
if not debug:
|
299 |
+
model_dir = model_dir.split(os.sep)[-1]
|
300 |
+
tasks = ["SVMR", "VCMR"]
|
301 |
+
input_args = ["--model_dir", model_dir,
|
302 |
+
"--eval_split_name", eval_split_name,
|
303 |
+
"--eval_path", eval_path,
|
304 |
+
"--tasks"] + tasks
|
305 |
+
|
306 |
+
import sys
|
307 |
+
sys.argv[1:] = input_args
|
308 |
+
logger.info("\n\n\nFINISHED TRAINING!!!")
|
309 |
+
logger.info("Evaluating model in {}".format(model_dir))
|
310 |
+
start_inference()
|
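The loss returned by model(**model_inputs) is defined in model.py (not shown in this section); the model config's loss_type selects between a hinge and an LSE ranking objective with margin opt.margin. A minimal, illustrative sketch of those two variants, assuming similarity scores where higher means a better query-moment match (the actual model may operate on distances instead); this is not the code from this commit:

import torch

def ranking_loss(pos_scores, neg_scores, margin=0.1, loss_type="hinge"):
    """pos_scores / neg_scores: (N, ) scores for matched / mismatched pairs."""
    if loss_type == "hinge":
        # max-margin triplet loss
        return torch.clamp(margin + neg_scores - pos_scores, min=0).mean()
    elif loss_type == "lse":
        # smooth LogSumExp approximation of the hinge
        return torch.log1p(torch.exp(margin + neg_scores - pos_scores)).mean()
    raise ValueError("loss_type must be 'hinge' or 'lse'")

# toy usage
pos = torch.tensor([0.9, 0.7, 0.8])
neg = torch.tensor([0.2, 0.6, 0.9])
print(ranking_loss(pos, neg, loss_type="lse"))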
baselines/crossmodal_moment_localization/README.md
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Cross-modal Moment Localization (XML)
|
2 |
+
===
|
baselines/crossmodal_moment_localization/__init__.py
ADDED
File without changes
|
baselines/crossmodal_moment_localization/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (207 Bytes).
|
|
baselines/crossmodal_moment_localization/__pycache__/config.cpython-311.pyc
ADDED
Binary file (23.3 kB).
|
|
baselines/crossmodal_moment_localization/__pycache__/inference.cpython-311.pyc
ADDED
Binary file (24.1 kB).
|
|
baselines/crossmodal_moment_localization/__pycache__/model_components.cpython-311.pyc
ADDED
Binary file (19.8 kB).
|
|
baselines/crossmodal_moment_localization/__pycache__/model_xml.cpython-311.pyc
ADDED
Binary file (39.8 kB).
|
|
baselines/crossmodal_moment_localization/__pycache__/ndcg_iou_topk.cpython-311.pyc
ADDED
Binary file (5.64 kB).
|
|
baselines/crossmodal_moment_localization/__pycache__/optimization.cpython-311.pyc
ADDED
Binary file (18.8 kB).
|
|
baselines/crossmodal_moment_localization/__pycache__/start_end_dataset.cpython-311.pyc
ADDED
Binary file (19.5 kB).
|
|
baselines/crossmodal_moment_localization/config.py
ADDED
@@ -0,0 +1,276 @@
1 |
+
import os
|
2 |
+
import time
|
3 |
+
import torch
|
4 |
+
import argparse
|
5 |
+
|
6 |
+
from utils.basic_utils import mkdirp, load_json, save_json, make_zipfile
|
7 |
+
from baselines.clip_alignment_with_language.local_utils.proposal import ProposalConfigs
|
8 |
+
|
9 |
+
|
10 |
+
class BaseOptions(object):
|
11 |
+
saved_option_filename = "opt.json"
|
12 |
+
ckpt_filename = "model.ckpt"
|
13 |
+
tensorboard_log_dir = "tensorboard_log"
|
14 |
+
train_log_filename = "train.log.txt"
|
15 |
+
eval_log_filename = "eval.log.txt"
|
16 |
+
|
17 |
+
def __init__(self):
|
18 |
+
self.parser = argparse.ArgumentParser()
|
19 |
+
self.initialized = False
|
20 |
+
self.opt = None
|
21 |
+
|
22 |
+
def initialize(self):
|
23 |
+
self.initialized = True
|
24 |
+
self.parser.add_argument("--dset_name", type=str, choices=["tvr"])
|
25 |
+
self.parser.add_argument("--model_name", type=str)
|
26 |
+
self.parser.add_argument("--eval_split_name", type=str, default="val",
|
27 |
+
help="should match keys in corpus_path, must set for VCMR")
|
28 |
+
self.parser.add_argument("--debug", action="store_true",
|
29 |
+
help="debug (fast) mode, break all loops, do not load all data into memory.")
|
30 |
+
self.parser.add_argument("--data_ratio", type=float, default=1.0,
|
31 |
+
help="how many training and eval data to use. 1.0: use all, 0.1: use 10%."
|
32 |
+
"Use small portion for debug purposes. Note this is different from --debug, "
|
33 |
+
"which works by breaking the loops, typically they are not used together.")
|
34 |
+
self.parser.add_argument("--results_root", type=str, default="results")
|
35 |
+
self.parser.add_argument("--exp_id", type=str, default=None, help="id of this run, required at training")
|
36 |
+
self.parser.add_argument("--seed", type=int, default=2018, help="random seed")
|
37 |
+
self.parser.add_argument("--device", type=int, default=0, help="0 cuda, -1 cpu")
|
38 |
+
self.parser.add_argument("--device_ids", type=int, nargs="+", default=[0], help="GPU ids to run the job")
|
39 |
+
self.parser.add_argument("--num_workers", type=int, default=4,
|
40 |
+
help="num subprocesses used to load the data, 0: use main process")
|
41 |
+
self.parser.add_argument("--no_core_driver", action="store_true",
|
42 |
+
help="hdf5 driver, default use `core` (load into RAM), if specified, use `None`")
|
43 |
+
self.parser.add_argument("--no_pin_memory", action="store_true",
|
44 |
+
help="Don't use pin_memory=True for dataloader. "
|
45 |
+
"ref: https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/4")
|
46 |
+
|
47 |
+
# training config
|
48 |
+
self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
|
49 |
+
self.parser.add_argument("--lr_warmup_proportion", type=float, default=0.01,
|
50 |
+
help="Proportion of training to perform linear learning rate warmup for. "
|
51 |
+
"E.g., 0.1 = 10% of training.")
|
52 |
+
self.parser.add_argument("--wd", type=float, default=0.01, help="weight decay")
|
53 |
+
self.parser.add_argument("--n_epoch", type=int, default=100, help="number of epochs to run")
|
54 |
+
self.parser.add_argument("--max_es_cnt", type=int, default=10,
|
55 |
+
help="number of epochs to early stop, use -1 to disable early stop")
|
56 |
+
self.parser.add_argument("--stop_task", type=str, default="VCMR", choices=["VCMR", "SVMR", "VR"],
|
57 |
+
help="Use metric associated with stop_task for early stop")
|
58 |
+
self.parser.add_argument("--eval_tasks_at_training", type=str, nargs="+",
|
59 |
+
default=["VCMR"], choices=["VCMR", "SVMR", "VR"],
|
60 |
+
help="evaluate and report numbers for tasks specified here.")
|
61 |
+
self.parser.add_argument("--bsz", type=int, default=128, help="mini-batch size")
|
62 |
+
self.parser.add_argument("--eval_query_bsz", type=int, default=50,
|
63 |
+
help="mini-batch size at inference, for query")
|
64 |
+
self.parser.add_argument("--eval_context_bsz", type=int, default=200,
|
65 |
+
help="mini-batch size at inference, for video/sub")
|
66 |
+
self.parser.add_argument("--eval_untrained", action="store_true", help="Evaluate on un-trained model")
|
67 |
+
self.parser.add_argument("--grad_clip", type=float, default=-1, help="perform gradient clip, -1: disable")
|
68 |
+
self.parser.add_argument("--margin", type=float, default=0.1, help="margin for hinge loss")
|
69 |
+
self.parser.add_argument("--lw_neg_q", type=float, default=1,
|
70 |
+
help="weight for ranking loss with negative query and positive context")
|
71 |
+
self.parser.add_argument("--lw_neg_ctx", type=float, default=1,
|
72 |
+
help="weight for ranking loss with positive query and negative context")
|
73 |
+
self.parser.add_argument("--lw_st_ed", type=float, default=0.01, help="weight for st ed prediction loss")
|
74 |
+
self.parser.add_argument("--train_span_start_epoch", type=int, default=0,
|
75 |
+
help="which epoch to start training span prediction, -1 to disable")
|
76 |
+
self.parser.add_argument("--ranking_loss_type", type=str, default="hinge", choices=["hinge", "lse"],
|
77 |
+
help="att loss type, can be hinge loss or its smooth approximation LogSumExp")
|
78 |
+
self.parser.add_argument("--hard_negtiave_start_epoch", type=int, default=20,
|
79 |
+
help="which epoch to start hard negative sampling for video-level ranking loss,"
|
80 |
+
"use -1 to disable")
|
81 |
+
self.parser.add_argument("--hard_pool_size", type=int, default=20,
|
82 |
+
help="hard negatives are still sampled, but from a harder pool.")
|
83 |
+
|
84 |
+
# Model and Data config
|
85 |
+
self.parser.add_argument("--max_sub_l", type=int, default=50,
|
86 |
+
help="max length of all sub sentence 97.71 under 50 for 3 sentences")
|
87 |
+
self.parser.add_argument("--max_desc_l", type=int, default=30, help="max length of descriptions")
|
88 |
+
self.parser.add_argument("--max_ctx_l", type=int, default=100,
|
89 |
+
help="max number of snippets, 100 for tvr clip_length=1.5, oly 109/21825 > 100")
|
90 |
+
|
91 |
+
self.parser.add_argument("--train_path", type=str, default=None)
|
92 |
+
self.parser.add_argument("--val_path", type=str, default=None)
|
93 |
+
self.parser.add_argument("--test_path", type=str, default=None)
|
94 |
+
self.parser.add_argument("--external_inference_vr_res_path", type=str, default=None,
|
95 |
+
help="if set, use external video retrieval results to guide evaluation. ")
|
96 |
+
self.parser.add_argument("--use_glove", action="store_true", help="Use GloVe instead of BERT features")
|
97 |
+
self.parser.add_argument("--word2idx_path", type=str,
|
98 |
+
help="a dict, {word: word_idx, ...}, "
|
99 |
+
"special tokens are {<pad>: 0, <unk>: 1, <eos>: 2}")
|
100 |
+
self.parser.add_argument("--vocab_size", type=int, default=-1,
|
101 |
+
help="Set automatically to len(word2idx)")
|
102 |
+
self.parser.add_argument("--glove_path", type=str,
|
103 |
+
help="path to file containing the GloVe embeddings for words in word2idx")
|
104 |
+
self.parser.add_argument("--desc_bert_path", type=str, default=None)
|
105 |
+
self.parser.add_argument("--sub_bert_path", type=str, default=None)
|
106 |
+
self.parser.add_argument("--sub_feat_size", type=int, default=768, help="feature dim for sub feature")
|
107 |
+
self.parser.add_argument("--q_feat_size", type=int, default=768, help="feature dim for sub feature")
|
108 |
+
self.parser.add_argument("--ctx_mode", type=str, choices=["video", "sub", "video_sub", "tef",
|
109 |
+
"video_tef", "sub_tef", "video_sub_tef"],
|
110 |
+
help="which context to use. a combination of [video, sub, tef]")
|
111 |
+
self.parser.add_argument("--corpus_path", type=str, default=None)
|
112 |
+
self.parser.add_argument("--vid_feat_path", type=str, default="")
|
113 |
+
self.parser.add_argument("--no_norm_vfeat", action="store_true",
|
114 |
+
help="Do not do normalization on video feat, use it only when using resnet_i3d feat")
|
115 |
+
self.parser.add_argument("--no_norm_tfeat", action="store_true", help="Do not do normalization on text feat")
|
116 |
+
self.parser.add_argument("--clip_length", type=float, default=None,
|
117 |
+
help="each video will be uniformly segmented into small clips, "
|
118 |
+
"will automatically loaded from ProposalConfigs if None")
|
119 |
+
self.parser.add_argument("--vid_feat_size", type=int, help="feature dim for video feature")
|
120 |
+
|
121 |
+
self.parser.add_argument("--span_predictor_type", type=str, default="conv", choices=["conv", "cat_linear"],
|
122 |
+
help="how to generate span predictions, "
|
123 |
+
"conv: apply 1D-Conv layer on top of NxL dot product of query and clips"
|
124 |
+
"cat_linear: cat the query and clips then use a linear layer to give output. "
|
125 |
+
"Note cat_linear is implemented as first project query and clips into scores, "
|
126 |
+
"separately, then sum them up, this should be similar to first cat then project.")
|
127 |
+
self.parser.add_argument("--stack_conv_predictor_conv_kernel_sizes", type=int, default=-1, nargs="+",
|
128 |
+
help="combine the results from conv edge detectors of all sizes specified."
|
129 |
+
"-1: disable. If specified, will ignore --conv_kernel_size option."
|
130 |
+
"This flag is only used when --merge_two_stream and --span_predictor_type conv!")
|
131 |
+
self.parser.add_argument("--encoder_type", type=str, default="transformer",
|
132 |
+
choices=["gru", "lstm", "transformer", "cnn"])
|
133 |
+
self.parser.add_argument("--add_pe_rnn", action="store_true",
|
134 |
+
help="Add positional encoding for GRU and LSTM encoder as well")
|
135 |
+
self.parser.add_argument("--no_merge_two_stream", action="store_true", help="do not merge video and subtitles")
|
136 |
+
self.parser.add_argument("--no_cross_att", action="store_true",
|
137 |
+
help="Use cross-attention for modeling video and subtitles")
|
138 |
+
self.parser.add_argument("--no_self_att", action="store_true", help="do not use self attention")
|
139 |
+
self.parser.add_argument("--no_modular", action="store_true", help="do not use modular attention")
|
140 |
+
self.parser.add_argument("--pe_type", type=str, default="cosine", choices=["none", "linear", "cosine"],
|
141 |
+
help="Only for query encoding")
|
142 |
+
self.parser.add_argument("--max_position_embeddings", type=int, default=300)
|
143 |
+
self.parser.add_argument("--hidden_size", type=int, default=256)
|
144 |
+
self.parser.add_argument("--n_heads", type=int, default=4)
|
145 |
+
self.parser.add_argument("--input_drop", type=float, default=0.1, help="Applied to all inputs")
|
146 |
+
self.parser.add_argument("--drop", type=float, default=0.1, help="Applied to all other layers")
|
147 |
+
self.parser.add_argument("--cross_att_drop", type=float, default=0.1, help="Applied to cross-att")
|
148 |
+
self.parser.add_argument("--conv_kernel_size", type=int, default=5)
|
149 |
+
self.parser.add_argument("--conv_stride", type=int, default=1)
|
150 |
+
self.parser.add_argument("--initializer_range", type=float, default=0.02,
|
151 |
+
help="initializer range for linear layer")
|
152 |
+
self.parser.add_argument("--eval_num_per_epoch", type=float)
|
153 |
+
|
154 |
+
# post processing
|
155 |
+
self.parser.add_argument("--min_pred_l", type=int, default=2,
|
156 |
+
help="constrain the [st, ed] with ed - st >= 2"
|
157 |
+
"(2 clips with length 1.5 each, 3 secs in total"
|
158 |
+
"this is the min length for proposal-based method)")
|
159 |
+
self.parser.add_argument("--max_pred_l", type=int, default=16,
|
160 |
+
help="constrain the [st, ed] pairs with ed - st <= 16, 24 secs in total"
|
161 |
+
"(16 clips with length 1.5 each, "
|
162 |
+
"this is the max length for proposal-based method)")
|
163 |
+
self.parser.add_argument("--q2c_alpha", type=float, default=20,
|
164 |
+
help="give more importance to top scored videos' spans, "
|
165 |
+
"the new score will be: s_new = exp(alpha * s), "
|
166 |
+
"higher alpha indicates more importance. Note s in [-1, 1]")
|
167 |
+
|
168 |
+
self.parser.add_argument("--max_before_nms", type=int, default=200)
|
169 |
+
self.parser.add_argument("--max_vcmr_video", type=int, default=100,
|
170 |
+
help="re-ranking in top-max_vcmr_video")
|
171 |
+
self.parser.add_argument("--nms_thd", type=float, default=-1,
|
172 |
+
help="additionally use non-maximum suppression "
|
173 |
+
"(or non-minimum suppression for distance)"
|
174 |
+
"to post-processing the predictions. "
|
175 |
+
"-1: do not use nms. 0.6 for charades_sta, 0.5 for anet_cap,")
|
176 |
+
|
177 |
+
def display_save(self, opt):
|
178 |
+
args = vars(opt)
|
179 |
+
# Display settings
|
180 |
+
print("------------ Options -------------\n{}\n-------------------"
|
181 |
+
.format({str(k): str(v) for k, v in sorted(args.items())}))
|
182 |
+
|
183 |
+
# Save settings
|
184 |
+
if not isinstance(self, TestOptions):
|
185 |
+
option_file_path = os.path.join(opt.results_dir, self.saved_option_filename) # not yaml file indeed
|
186 |
+
save_json(args, option_file_path, save_pretty=True)
|
187 |
+
|
188 |
+
def parse(self):
|
189 |
+
if not self.initialized:
|
190 |
+
self.initialize()
|
191 |
+
opt = self.parser.parse_args()
|
192 |
+
|
193 |
+
if opt.debug:
|
194 |
+
opt.results_root = os.path.sep.join(opt.results_root.split(os.path.sep)[:-1] + ["debug_results", ])
|
195 |
+
opt.no_core_driver = True
|
196 |
+
opt.num_workers = 0
|
197 |
+
opt.eval_query_bsz = 100
|
198 |
+
|
199 |
+
if isinstance(self, TestOptions):
|
200 |
+
# modify model_dir to absolute path
|
201 |
+
opt.model_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "results", opt.model_dir)
|
202 |
+
saved_options = load_json(os.path.join(opt.model_dir, self.saved_option_filename))
|
203 |
+
for arg in saved_options: # use saved options to overwrite all BaseOptions args.
|
204 |
+
if arg not in ["results_root", "num_workers", "nms_thd", "debug",
|
205 |
+
"eval_split_name", "eval_path", "eval_query_bsz", "eval_context_bsz",
|
206 |
+
"max_pred_l", "min_pred_l", "external_inference_vr_res_path"]:
|
207 |
+
setattr(opt, arg, saved_options[arg])
|
208 |
+
# opt.no_core_driver = True
|
209 |
+
else:
|
210 |
+
if opt.exp_id is None:
|
211 |
+
raise ValueError("--exp_id is required for at a training option!")
|
212 |
+
|
213 |
+
if opt.clip_length is None:
|
214 |
+
opt.clip_length = ProposalConfigs[opt.dset_name]["clip_length"]
|
215 |
+
print("Loaded clip_length {} from proposal config file".format(opt.clip_length))
|
216 |
+
opt.results_dir = os.path.join(opt.results_root, "_".join([opt.model_name, opt.exp_id, time.strftime("%Y%m%d_%H%M%S")]))
|
217 |
+
mkdirp(opt.results_dir)
|
218 |
+
# save a copy of current code
|
219 |
+
code_dir = os.path.dirname(os.path.realpath(__file__))
|
220 |
+
code_zip_filename = os.path.join(opt.results_dir, "code.zip")
|
221 |
+
make_zipfile(code_dir, code_zip_filename,
|
222 |
+
enclosing_dir="code",
|
223 |
+
exclude_dirs_substring="results",
|
224 |
+
exclude_dirs=["results", "debug_results", "__pycache__"],
|
225 |
+
exclude_extensions=[".pyc", ".ipynb", ".swap"],)
|
226 |
+
|
227 |
+
self.display_save(opt)
|
228 |
+
|
229 |
+
if "sub" in opt.ctx_mode:
|
230 |
+
assert opt.dset_name == "tvr", "sub is only supported for tvr dataset"
|
231 |
+
|
232 |
+
if opt.hard_negtiave_start_epoch != -1:
|
233 |
+
if opt.hard_pool_size > opt.bsz:
|
234 |
+
print("[WARNING] hard_pool_size is larger than bsz")
|
235 |
+
|
236 |
+
assert opt.stop_task in opt.eval_tasks_at_training
|
237 |
+
opt.ckpt_filepath = os.path.join(opt.results_dir, self.ckpt_filename)
|
238 |
+
opt.train_log_filepath = os.path.join(opt.results_dir, self.train_log_filename)
|
239 |
+
opt.eval_log_filepath = os.path.join(opt.results_dir, self.eval_log_filename)
|
240 |
+
opt.tensorboard_log_dir = os.path.join(opt.results_dir, self.tensorboard_log_dir)
|
241 |
+
opt.device = torch.device("cuda:%d" % opt.device_ids[0] if opt.device >= 0 else "cpu")
|
242 |
+
opt.h5driver = None if opt.no_core_driver else "core"
|
243 |
+
# num_workers > 1 will only work with "core" mode, i.e., memory-mapped hdf5
|
244 |
+
opt.num_workers = 1 if opt.no_core_driver else opt.num_workers
|
245 |
+
opt.pin_memory = not opt.no_pin_memory
|
246 |
+
|
247 |
+
if "video" in opt.ctx_mode and opt.vid_feat_size > 3000: # 3072, the normalized concatenation of resnet+i3d
|
248 |
+
assert opt.no_norm_vfeat
|
249 |
+
|
250 |
+
if "tef" in opt.ctx_mode and "video" in opt.ctx_mode:
|
251 |
+
opt.vid_feat_size += 2
|
252 |
+
if "tef" in opt.ctx_mode and "sub" in opt.ctx_mode:
|
253 |
+
opt.sub_feat_size += 2
|
254 |
+
|
255 |
+
if "video" not in opt.ctx_mode or "sub" not in opt.ctx_mode:
|
256 |
+
opt.no_merge_two_stream = True
|
257 |
+
opt.no_cross_att = True
|
258 |
+
|
259 |
+
self.opt = opt
|
260 |
+
return opt
|
261 |
+
|
262 |
+
|
263 |
+
class TestOptions(BaseOptions):
|
264 |
+
"""add additional options for evaluating"""
|
265 |
+
def initialize(self):
|
266 |
+
BaseOptions.initialize(self)
|
267 |
+
# also need to specify --eval_split_name
|
268 |
+
self.parser.add_argument("--eval_id", type=str, help="evaluation id")
|
269 |
+
self.parser.add_argument("--model_dir", type=str,
|
270 |
+
help="dir contains the model file, will be converted to absolute path afterwards")
|
271 |
+
self.parser.add_argument("--tasks", type=str, nargs="+",
|
272 |
+
choices=["VCMR", "SVMR", "VR"], default=["VCMR", "SVMR", "VR"],
|
273 |
+
help="Which tasks to run."
|
274 |
+
"VCMR: Video Corpus Moment Retrieval;"
|
275 |
+
"SVMR: Single Video Moment Retrieval;"
|
276 |
+
"VR: regular Video Retrieval. (will be performed automatically with VCMR)")
|
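For reference, a small standalone numpy sketch (not part of the repository; the similarity values are made up) of the --q2c_alpha re-weighting described in the help string above, s_new = exp(alpha * s):

import numpy as np

scores = np.array([0.9, 0.7, 0.1, -0.3])     # hypothetical query-to-video cosine similarities in [-1, 1]
for alpha in (1, 20):
    rescaled = np.exp(alpha * scores)        # s_new = exp(alpha * s)
    print(alpha, rescaled / rescaled.sum())  # larger alpha concentrates the mass on the top-scored video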
baselines/crossmodal_moment_localization/inference.py
ADDED
@@ -0,0 +1,414 @@
import os
import copy
import math
import time
import pprint
from tqdm import tqdm, trange
import numpy as np

import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from baselines.crossmodal_moment_localization.config import TestOptions
from baselines.crossmodal_moment_localization.model_xml import XML
from baselines.crossmodal_moment_localization.start_end_dataset import \
    start_end_collate, StartEndEvalDataset, prepare_batch_inputs
from baselines.clip_alignment_with_language.inference import \
    get_submission_top_n, post_processing_vcmr_nms, post_processing_svmr_nms
from utils.basic_utils import save_json, load_json
from utils.tensor_utils import find_max_triples_from_upper_triangle_product
from standalone_eval.eval import eval_retrieval

import logging
from ndcg_iou_topk import calculate_ndcg_iou

logger = logging.getLogger(__name__)  # module-level logger used by setup_model and start_inference


def compute_context_info(model, eval_dataset, opt):
    """Use val set to do evaluation, remember to run with torch.no_grad().
    estimated 2200 (videos) * 100 (frm) * 500 (hsz) * 4 (B) * 2 (video/sub) * 2 (layers) / (1024 ** 3) ~= 1.6 GB
    max_n_videos: only consider max_n_videos videos for each query to return st_ed scores.
    """
    model.eval()
    # eval_dataset.set_data_mode("context")
    context_dataloader = DataLoader(eval_dataset,
                                    collate_fn=start_end_collate,
                                    batch_size=opt.eval_context_bsz,
                                    num_workers=opt.num_workers,
                                    shuffle=False,
                                    pin_memory=opt.pin_memory)

    metas = []  # list(dicts)
    video_feat1 = []
    video_feat2 = []
    video_mask = []
    sub_feat1 = []
    sub_feat2 = []
    sub_mask = []
    for idx, batch in tqdm(enumerate(context_dataloader),
                           desc="Computing query2video scores",
                           total=len(context_dataloader)):
        metas.extend(batch[0])
        model_inputs = prepare_batch_inputs(batch[1], device=opt.device, non_blocking=opt.pin_memory)

        _video_feat1, _video_feat2, _sub_feat1, _sub_feat2 = model.encode_context(
            model_inputs["video_feat"], model_inputs["video_mask"],
            model_inputs["sub_feat"], model_inputs["sub_mask"])
        if "video" in opt.ctx_mode:
            video_feat1.append(_video_feat1)
            video_feat2.append(_video_feat2)
            video_mask.append(model_inputs["video_mask"])
        if "sub" in opt.ctx_mode:
            sub_feat1.append(_sub_feat1)
            sub_feat2.append(_sub_feat2)
            sub_mask.append(model_inputs["sub_mask"])

    def cat_tensor(tensor_list):
        if len(tensor_list) == 0:
            return None
        else:
            seq_l = [e.shape[1] for e in tensor_list]
            b_sizes = [e.shape[0] for e in tensor_list]
            b_sizes_cumsum = np.cumsum([0] + b_sizes)
            if len(tensor_list[0].shape) == 3:
                hsz = tensor_list[0].shape[2]
                res_tensor = tensor_list[0].new_zeros(sum(b_sizes), max(seq_l), hsz)
            elif len(tensor_list[0].shape) == 2:
                res_tensor = tensor_list[0].new_zeros(sum(b_sizes), max(seq_l))
            else:
                raise ValueError("Only support 2/3 dimensional tensors")
            for i, e in enumerate(tensor_list):
                res_tensor[b_sizes_cumsum[i]:b_sizes_cumsum[i+1], :seq_l[i]] = e
            return res_tensor

    return metas, dict(
        video_feat1=cat_tensor(video_feat1),  # (N_videos, L, hsz)
        video_feat2=cat_tensor(video_feat2),
        video_mask=cat_tensor(video_mask),  # (N_videos, L)
        sub_feat1=cat_tensor(sub_feat1),
        sub_feat2=cat_tensor(sub_feat2),
        sub_mask=cat_tensor(sub_mask),
    )


def index_if_not_none(input_tensor, indices):
    if input_tensor is None:
        return input_tensor
    else:
        return input_tensor[indices]


def generate_min_max_length_mask(array_shape, min_l, max_l):
    """The last two dimensions form an upper-triangle matrix with the upper-right corner masked,
    below is the case for 4x4.
    [[0, 1, 1, 0],
     [0, 0, 1, 1],
     [0, 0, 0, 1],
     [0, 0, 0, 0]]

    Args:
        array_shape: tuple, shape of the score array; the last two dimensions should be the same
        min_l: int, minimum length of predicted span
        max_l: int, maximum length of predicted span

    Returns:
        mask array of shape (1, ..., 1, L, L), valid entries are 1
    """
    single_dims = (1, ) * (len(array_shape) - 2)
    mask_shape = single_dims + array_shape[-2:]
    extra_length_mask_array = np.ones(mask_shape, dtype=np.float32)  # (1, ..., 1, L, L)
    mask_triu = np.triu(extra_length_mask_array, k=min_l)
    mask_triu_reversed = 1 - np.triu(extra_length_mask_array, k=max_l)
    final_prob_mask = mask_triu * mask_triu_reversed
    return final_prob_mask  # valid entries are 1


def get_svmr_res_from_st_ed_probs(svmr_gt_st_probs, svmr_gt_ed_probs, query_metas, video2idx,
                                  clip_length, min_pred_l, max_pred_l, max_before_nms):
    """
    Args:
        svmr_gt_st_probs: np.ndarray (N_queries, L), value range [0, 1]
        svmr_gt_ed_probs: np.ndarray (N_queries, L), value range [0, 1]
        query_metas:
        video2idx:
        clip_length: float, how long each clip is in seconds
        min_pred_l: int, minimum number of clips
        max_pred_l: int, maximum number of clips
        max_before_nms: get top-max_before_nms predictions for each query

    Returns:
        list(dict), one prediction entry per query
    """
    svmr_res = []
    query_vid_names = [e["vid_name"] for e in query_metas]

    # mask out very long spans, since most ground-truth moments are relatively short.
    st_ed_prob_product = np.einsum("bm,bn->bmn", svmr_gt_st_probs, svmr_gt_ed_probs)  # (N, L, L)
    # extra_length_mask_array = np.ones(st_ed_prob_product.shape, dtype=bool)  # (N, L, L)
    # mask_triu = np.triu(extra_length_mask_array, k=min_pred_l)
    # mask_triu_reversed = np.logical_not(np.triu(extra_length_mask_array, k=max_pred_l))
    # final_prob_mask = np.logical_and(mask_triu, mask_triu_reversed)  # valid entries are 1
    valid_prob_mask = generate_min_max_length_mask(st_ed_prob_product.shape, min_l=min_pred_l, max_l=max_pred_l)
    st_ed_prob_product *= valid_prob_mask  # invalid locations become zero!

    batched_sorted_triples = find_max_triples_from_upper_triangle_product(
        st_ed_prob_product, top_n=max_before_nms, prob_thd=None)
    for i, q_vid_name in tqdm(enumerate(query_vid_names),
                              desc="[SVMR] Loop over queries to generate predictions",
                              total=len(query_vid_names)):  # i is query_id
        q_m = query_metas[i]
        video_idx = video2idx[q_vid_name]
        _sorted_triples = batched_sorted_triples[i]
        _sorted_triples[:, 1] += 1  # as we redefined ed_idx, which is inside the moment.
        _sorted_triples[:, :2] = _sorted_triples[:, :2] * clip_length
        # [video_idx(int), st(float), ed(float), score(float)]
        cur_ranked_predictions = [[video_idx, ] + row for row in _sorted_triples.tolist()]
        cur_query_pred = dict(
            query_id=q_m["query_id"],
            desc=q_m["desc"],
            predictions=cur_ranked_predictions
        )
        svmr_res.append(cur_query_pred)
    return svmr_res


def load_external_vr_res2(external_vr_res_path, top_n_vr_videos=5):
    """return a mapping from query_id to top retrieved video info"""
    external_vr_res = load_json(external_vr_res_path)
    external_vr_res = get_submission_top_n(external_vr_res, top_n=top_n_vr_videos)["VR"]
    query2video = {e["query_id"]: e["predictions"] for e in external_vr_res}
    return query2video


def compute_query2ctx_info(model, eval_dataset, opt, video_metas, ctx_info,
                           max_before_nms=1000, max_n_videos=100, maxtopk=40):
    """Use val set to do evaluation, remember to run with torch.no_grad().
    estimated size 20,000 (query) * 500 (hsz) * 4 / (1024**2) = 38.15 MB
    max_n_videos: int, use max_n_videos videos for computing VCMR/VR results
    """
    video2idx = eval_dataset.video2idx
    # video_metas = ctx_info["video_metas"]
    if opt.external_inference_vr_res_path is not None:
        video_idx2meta_idx = {video2idx[m["vid_name"]]: i for i, m in enumerate(video_metas)}
        external_query2video = \
            load_external_vr_res2(opt.external_inference_vr_res_path, top_n_vr_videos=max_n_videos)
        # {query_id: [video meta indices]}
        external_query2video_meta_idx = \
            {k: [video_idx2meta_idx[e[0]] for e in v] for k, v in external_query2video.items()}
    else:
        external_query2video = None
        external_query2video_meta_idx = None

    model.eval()
    eval_dataset.set_data_mode("query")
    # eval_dataset.load_gt_vid_name_for_query(is_svmr)
    query_eval_loader = DataLoader(eval_dataset,
                                   collate_fn=start_end_collate,
                                   batch_size=opt.eval_query_bsz,
                                   num_workers=opt.num_workers,
                                   shuffle=False,
                                   pin_memory=opt.pin_memory)
    n_total_videos = len(video_metas)
    n_total_query = len(eval_dataset)
    bsz = opt.eval_query_bsz

    flat_st_ed_scores_sorted_indices = np.empty((n_total_query, max_before_nms), dtype=int)
    flat_st_ed_sorted_scores = np.zeros((n_total_query, max_before_nms), dtype=np.float32)
    sorted_q2c_indices = np.empty((n_total_query, max_n_videos), dtype=int)
    sorted_q2c_scores = np.empty((n_total_query, max_n_videos), dtype=np.float32)

    query_metas = []
    for idx, batch in tqdm(
            enumerate(query_eval_loader), desc="Computing q embedding", total=len(query_eval_loader)):
        _query_metas = batch[0]
        query_metas.extend(batch[0])
        model_inputs = prepare_batch_inputs(batch[1], device=opt.device, non_blocking=opt.pin_memory)
        # query_context_scores (_N_q, N_videos), st_prob, ed_prob (_N_q, N_videos, L)
        _query_context_scores, _st_probs, _ed_probs = \
            model.get_pred_from_raw_query(model_inputs["query_feat"], model_inputs["query_mask"],
                                          ctx_info["video_feat1"], ctx_info["video_feat2"],
                                          ctx_info["video_mask"],
                                          ctx_info["sub_feat1"], ctx_info["sub_feat2"],
                                          ctx_info["sub_mask"],
                                          cross=True)
        # _query_context_scores = _query_context_scores + 1  # move cosine similarity to [0, 2]
        # To give more importance to top scores: the higher opt.q2c_alpha is, the more importance they get
        _query_context_scores = torch.exp(opt.q2c_alpha * _query_context_scores)

        # normalize to get true probabilities!!!
        # the probabilities here are already (pad) masked, so only need to do softmax
        _st_probs = F.softmax(_st_probs, dim=-1)  # (_N_q, N_videos, L)
        _ed_probs = F.softmax(_ed_probs, dim=-1)

        if external_query2video is None:
            _sorted_q2c_scores, _sorted_q2c_indices = \
                torch.topk(_query_context_scores, max_n_videos, dim=1, largest=True)
        else:
            relevant_video_info = [external_query2video[qm["query_id"]] for qm in _query_metas]
            _sorted_q2c_indices = _query_context_scores.new(
                [[video_idx2meta_idx[sub_e[0]] for sub_e in e] for e in relevant_video_info]).long()
            _sorted_q2c_scores = _query_context_scores.new(
                [[sub_e[3] for sub_e in e] for e in relevant_video_info])
            _sorted_q2c_scores = torch.exp(opt.q2c_alpha * _sorted_q2c_scores)
        # collect data for vr and vcmr
        sorted_q2c_indices[idx * bsz:(idx + 1) * bsz] = _sorted_q2c_indices.cpu().numpy()
        sorted_q2c_scores[idx * bsz:(idx + 1) * bsz] = _sorted_q2c_scores.cpu().numpy()

        # Get VCMR results
        # compute combined scores
        row_indices = torch.arange(0, len(_st_probs), device=opt.device).unsqueeze(1)
        _st_probs = _st_probs[row_indices, _sorted_q2c_indices]  # (_N_q, max_n_videos, L)
        _ed_probs = _ed_probs[row_indices, _sorted_q2c_indices]

        # (_N_q, max_n_videos, L, L)
        _st_ed_scores = torch.einsum("qvm,qv,qvn->qvmn", _st_probs, _sorted_q2c_scores, _ed_probs)
        valid_prob_mask = generate_min_max_length_mask(
            _st_ed_scores.shape, min_l=opt.min_pred_l, max_l=opt.max_pred_l)
        _st_ed_scores *= torch.from_numpy(
            valid_prob_mask).to(_st_ed_scores.device)  # invalid locations become zero!

        # sort across the top-max_n_videos videos (by flattening from the 2nd dim)
        # the indices here are local indices, not global indices
        _n_q = _st_ed_scores.shape[0]
        _flat_st_ed_scores = _st_ed_scores.reshape(_n_q, -1)  # (N_q, max_n_videos*L*L)
        _flat_st_ed_sorted_scores, _flat_st_ed_scores_sorted_indices = \
            torch.sort(_flat_st_ed_scores, dim=1, descending=True)
        # collect data
        flat_st_ed_sorted_scores[idx * bsz:(idx + 1) * bsz] = \
            _flat_st_ed_sorted_scores[:, :max_before_nms].cpu().numpy()
        flat_st_ed_scores_sorted_indices[idx * bsz:(idx + 1) * bsz] = \
            _flat_st_ed_scores_sorted_indices[:, :max_before_nms].cpu().numpy()

        if opt.debug:
            break

    vcmr_res = {}
    for i, (_flat_st_ed_scores_sorted_indices, _flat_st_ed_sorted_scores) in tqdm(
            enumerate(zip(flat_st_ed_scores_sorted_indices, flat_st_ed_sorted_scores)),
            desc="[VCMR] Loop over queries to generate predictions", total=n_total_query):  # i is query_idx
        # list([video_idx(int), st(float), ed(float), score(float)])
        video_meta_indices_local, pred_st_indices, pred_ed_indices = \
            np.unravel_index(_flat_st_ed_scores_sorted_indices,
                             shape=(max_n_videos, opt.max_ctx_l, opt.max_ctx_l))
        # video_meta_indices_local refers to the indices among the top-max_n_videos
        # video_meta_indices refers to the indices in all the videos, which are the true indices
        video_meta_indices = sorted_q2c_indices[i, video_meta_indices_local]

        pred_st_in_seconds = pred_st_indices.astype(np.float32) * opt.clip_length
        pred_ed_in_seconds = pred_ed_indices.astype(np.float32) * opt.clip_length + opt.clip_length
        cur_vcmr_predictions = []
        for j, (v_meta_idx, v_score) in enumerate(zip(video_meta_indices, _flat_st_ed_sorted_scores)):  # videos
            video_idx = video2idx[video_metas[v_meta_idx]["vid_name"]]
            cur_vcmr_predictions.append(
                {
                    "video_name": video_metas[v_meta_idx]["vid_name"],
                    "timestamp": [float(pred_st_in_seconds[j]), float(pred_ed_in_seconds[j])],
                    "model_scores": float(v_score)
                }
            )
        query_id = query_metas[i]["query_id"]
        vcmr_res[query_id] = cur_vcmr_predictions[:maxtopk]
    return vcmr_res


def get_eval_res(model, eval_dataset, context_data, opt, maxtopk):
    """compute query and context embeddings, then generate VCMR predictions"""
    video_metas, context_info = compute_context_info(model, context_data, opt)
    eval_res = compute_query2ctx_info(model, eval_dataset, opt, video_metas, context_info,
                                      max_before_nms=opt.max_before_nms, max_n_videos=opt.max_vcmr_video,
                                      maxtopk=maxtopk)
    return eval_res


POST_PROCESSING_MMS_FUNC = {
    "SVMR": post_processing_svmr_nms,
    "VCMR": post_processing_vcmr_nms
}

# def get_prediction_top_n(list_dict_predictions, top_n):
#     top_n_res = []
#     for e in list_dict_predictions:
#         e["predictions"] = e["predictions"][:top_n]
#         top_n_res.append(e)
#     return top_n_res


def eval_epoch(model, eval_dataset, context_data, logger, opt, max_after_nms, iou_thds, topks):
    """max_after_nms: always set to 100, since the eval script only evaluates the top-100"""
    # IOU_THDS = (0.3, 0.5, 0.7)

    model.eval()
    pred_data = get_eval_res(model, eval_dataset, context_data, opt, max(topks))
    # pred_data = get_prediction_top_n(eval_res, top_n=max_after_nms)
    gt_data = eval_dataset.ground_truth
    average_ndcg = calculate_ndcg_iou(gt_data, pred_data, iou_thds, topks)
    return average_ndcg, pred_data


def setup_model(opt):
    """Load model from checkpoint and move to specified device"""
    checkpoint = torch.load(opt.ckpt_filepath)
    loaded_model_cfg = checkpoint["model_cfg"]
    loaded_model_cfg["stack_conv_predictor_conv_kernel_sizes"] = -1
    model = XML(loaded_model_cfg)
    model.load_state_dict(checkpoint["model"])
    logger.info("Loaded model saved at epoch {} from checkpoint: {}"
                .format(checkpoint["epoch"], opt.ckpt_filepath))

    if opt.device.type == "cuda":
        logger.info("CUDA enabled.")
        model.to(opt.device)
        if len(opt.device_ids) > 1:
            logger.info("Using multiple GPUs: {}".format(opt.device_ids))
            model = torch.nn.DataParallel(model, device_ids=opt.device_ids)  # use multi GPU
    return model


def start_inference():
    logger.info("Setup config, data and model...")
    opt = TestOptions().parse()
    cudnn.benchmark = False
    cudnn.deterministic = True

    assert opt.eval_path is not None
    eval_dataset = StartEndEvalDataset(
        dset_name=opt.dset_name,
        eval_split_name=opt.eval_split_name,  # should only be val set
        data_path=opt.eval_path,
        desc_bert_path_or_handler=opt.desc_bert_path,
        sub_bert_path_or_handler=opt.sub_bert_path,
        max_desc_len=opt.max_desc_l,
        max_ctx_len=opt.max_ctx_l,
        corpus_path=opt.corpus_path,
        vid_feat_path_or_handler=opt.vid_feat_path,
        clip_length=opt.clip_length,
        ctx_mode=opt.ctx_mode,
        data_mode="query",
        h5driver=opt.h5driver,
        data_ratio=opt.data_ratio,
        normalize_vfeat=not opt.no_norm_vfeat,
        normalize_tfeat=not opt.no_norm_tfeat
    )

    model = setup_model(opt)
    save_submission_filename = "inference_{}_{}_{}_predictions_{}.json".format(
        opt.dset_name, opt.eval_split_name, opt.eval_id, "_".join(opt.tasks))
    logger.info("Starting inference...")
    with torch.no_grad():
        # NOTE: eval_epoch above now expects (model, eval_dataset, context_data, logger, opt,
        # max_after_nms, iou_thds, topks) and returns (average_ndcg, pred_data); this standalone
        # entry point still uses the older call signature and return values.
        metrics_no_nms, metrics_nms, latest_file_paths = \
            eval_epoch(model, eval_dataset, opt, save_submission_filename,
                       tasks=opt.tasks, max_after_nms=100)
    logger.info("metrics_no_nms \n{}".format(pprint.pformat(metrics_no_nms, indent=4)))
    logger.info("metrics_nms \n{}".format(pprint.pformat(metrics_nms, indent=4)))


if __name__ == '__main__':
    start_inference()
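The VCMR/SVMR post-processing above combines two pieces: an outer product of start/end probabilities and generate_min_max_length_mask, which zeroes out spans whose length falls outside the configured bounds. A minimal numpy sketch (toy sizes and random probabilities, not repository code) of that combination:

import numpy as np

L, min_l, max_l = 6, 1, 4
st = np.random.rand(L); st /= st.sum()   # toy start-index probabilities
ed = np.random.rand(L); ed /= ed.sum()   # toy end-index probabilities
prod = np.outer(st, ed)                  # prod[s, e] = P(start=s) * P(end=e)

ones = np.ones((L, L), dtype=np.float32)
mask = np.triu(ones, k=min_l) * (1 - np.triu(ones, k=max_l))  # keep min_l <= e - s < max_l
prod *= mask                                                  # invalid spans become zero
s_idx, e_idx = np.unravel_index(prod.argmax(), prod.shape)
print("predicted span (clip indices):", s_idx, e_idx)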
baselines/crossmodal_moment_localization/model_components.py
ADDED
@@ -0,0 +1,317 @@
1 |
+
import math
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class DepthwiseSeparableConv(nn.Module):
|
8 |
+
"""
|
9 |
+
Depth-wise separable convolution uses less parameters to generate output by convolution.
|
10 |
+
:Examples:
|
11 |
+
>>> m = DepthwiseSeparableConv(300, 200, 5, dim=1)
|
12 |
+
>>> input_tensor = torch.randn(32, 300, 20)
|
13 |
+
>>> output = m(input_tensor)
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, in_ch, out_ch, k, dim=1, relu=True):
|
17 |
+
"""
|
18 |
+
:param in_ch: input hidden dimension size
|
19 |
+
:param out_ch: output hidden dimension size
|
20 |
+
:param k: kernel size
|
21 |
+
:param dim: default 1. 1D conv or 2D conv
|
22 |
+
"""
|
23 |
+
super(DepthwiseSeparableConv, self).__init__()
|
24 |
+
self.relu = relu
|
25 |
+
if dim == 1:
|
26 |
+
self.depthwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=in_ch,
|
27 |
+
kernel_size=k, groups=in_ch, padding=k//2)
|
28 |
+
self.pointwise_conv = nn.Conv1d(in_channels=in_ch, out_channels=out_ch,
|
29 |
+
kernel_size=1, padding=0)
|
30 |
+
elif dim == 2:
|
31 |
+
self.depthwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=in_ch,
|
32 |
+
kernel_size=k, groups=in_ch, padding=k//2)
|
33 |
+
self.pointwise_conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch,
|
34 |
+
kernel_size=1, padding=0)
|
35 |
+
else:
|
36 |
+
raise Exception("Incorrect dimension!")
|
37 |
+
|
38 |
+
def forward(self, x):
|
39 |
+
"""
|
40 |
+
:Input: (N, L_in, D)
|
41 |
+
:Output: (N, L_out, D)
|
42 |
+
"""
|
43 |
+
x = x.transpose(1, 2)
|
44 |
+
if self.relu:
|
45 |
+
out = F.relu(self.pointwise_conv(self.depthwise_conv(x)), inplace=True)
|
46 |
+
else:
|
47 |
+
out = self.pointwise_conv(self.depthwise_conv(x))
|
48 |
+
return out.transpose(1, 2) # (N, L, D)
|
49 |
+
|
50 |
+
|
51 |
+
class ConvEncoder(nn.Module):
|
52 |
+
def __init__(self, kernel_size=7, n_filters=128, dropout=0.1):
|
53 |
+
super(ConvEncoder, self).__init__()
|
54 |
+
self.dropout = nn.Dropout(dropout)
|
55 |
+
self.layer_norm = nn.LayerNorm(n_filters)
|
56 |
+
self.conv = DepthwiseSeparableConv(in_ch=n_filters, out_ch=n_filters, k=kernel_size, relu=True)
|
57 |
+
|
58 |
+
def forward(self, x, mask):
|
59 |
+
"""
|
60 |
+
:param x: (N, L, D)
|
61 |
+
:param mask: (N, L), is not used.
|
62 |
+
:return: (N, L, D)
|
63 |
+
"""
|
64 |
+
return self.layer_norm(self.dropout(self.conv(x)) + x) # (N, L, D)
|
65 |
+
|
66 |
+
|
67 |
+
class TrainablePositionalEncoding(nn.Module):
|
68 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
69 |
+
"""
|
70 |
+
def __init__(self, max_position_embeddings, hidden_size, dropout=0.1):
|
71 |
+
super(TrainablePositionalEncoding, self).__init__()
|
72 |
+
self.position_embeddings = nn.Embedding(max_position_embeddings, hidden_size)
|
73 |
+
self.LayerNorm = nn.LayerNorm(hidden_size)
|
74 |
+
self.dropout = nn.Dropout(dropout)
|
75 |
+
|
76 |
+
def forward(self, input_feat):
|
77 |
+
"""
|
78 |
+
Args:
|
79 |
+
input_feat: (N, L, D)
|
80 |
+
"""
|
81 |
+
bsz, seq_length = input_feat.shape[:2]
|
82 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=input_feat.device)
|
83 |
+
position_ids = position_ids.unsqueeze(0).repeat(bsz, 1) # (N, L)
|
84 |
+
|
85 |
+
position_embeddings = self.position_embeddings(position_ids)
|
86 |
+
|
87 |
+
embeddings = self.LayerNorm(input_feat + position_embeddings)
|
88 |
+
embeddings = self.dropout(embeddings)
|
89 |
+
return embeddings
|
90 |
+
|
91 |
+
|
92 |
+
class PositionEncoding(nn.Module):
|
93 |
+
"""
|
94 |
+
Add positional information to input tensor.
|
95 |
+
:Examples:
|
96 |
+
>>> model = PositionEncoding(n_filters=6, max_len=10)
|
97 |
+
>>> test_input1 = torch.zeros(3, 10, 6)
|
98 |
+
>>> output1 = model(test_input1)
|
99 |
+
>>> output1.size()
|
100 |
+
>>> test_input2 = torch.zeros(5, 3, 9, 6)
|
101 |
+
>>> output2 = model(test_input2)
|
102 |
+
>>> output2.size()
|
103 |
+
"""
|
104 |
+
|
105 |
+
def __init__(self, n_filters=128, max_len=500, pe_type="cosine"):
|
106 |
+
"""
|
107 |
+
:param n_filters: same with input hidden size
|
108 |
+
:param max_len: maximum sequence length
|
109 |
+
:param pe_type: cosine or linear or None
|
110 |
+
"""
|
111 |
+
super(PositionEncoding, self).__init__()
|
112 |
+
self.pe_type = pe_type
|
113 |
+
if pe_type != "none":
|
114 |
+
position = torch.arange(0, max_len).float().unsqueeze(1)
|
115 |
+
if pe_type == "cosine":
|
116 |
+
# Compute the positional encodings once in log space.
|
117 |
+
pe = torch.zeros(max_len, n_filters) # (L, D)
|
118 |
+
div_term = torch.exp(torch.arange(0, n_filters, 2).float() * - (math.log(10000.0) / n_filters))
|
119 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
120 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
121 |
+
elif pe_type == "linear":
|
122 |
+
pe = position / max_len
|
123 |
+
else:
|
124 |
+
raise ValueError
|
125 |
+
self.register_buffer("pe", pe) # buffer is a tensor, not a variable, (L, D)
|
126 |
+
|
127 |
+
def forward(self, x):
|
128 |
+
"""
|
129 |
+
:Input: (*, L, D)
|
130 |
+
:Output: (*, L, D) the same size as input
|
131 |
+
"""
|
132 |
+
if self.pe_type != "none":
|
133 |
+
pe = self.pe.data[:x.size(-2), :] # (#x.size(-2), n_filters)
|
134 |
+
extra_dim = len(x.size()) - 2
|
135 |
+
for _ in range(extra_dim):
|
136 |
+
pe = pe.unsqueeze(0)
|
137 |
+
x = x + pe
|
138 |
+
return x
|
139 |
+
|
140 |
+
|
141 |
+
class LinearLayer(nn.Module):
|
142 |
+
"""linear layer configurable with layer normalization, dropout, ReLU."""
|
143 |
+
|
144 |
+
def __init__(self, in_hsz, out_hsz, layer_norm=True, dropout=0.1, relu=True):
|
145 |
+
super(LinearLayer, self).__init__()
|
146 |
+
self.relu = relu
|
147 |
+
self.layer_norm = layer_norm
|
148 |
+
if layer_norm:
|
149 |
+
self.LayerNorm = nn.LayerNorm(in_hsz)
|
150 |
+
layers = [
|
151 |
+
nn.Dropout(dropout),
|
152 |
+
nn.Linear(in_hsz, out_hsz)
|
153 |
+
]
|
154 |
+
self.net = nn.Sequential(*layers)
|
155 |
+
|
156 |
+
def forward(self, x):
|
157 |
+
"""(N, L, D)"""
|
158 |
+
if self.layer_norm:
|
159 |
+
x = self.LayerNorm(x)
|
160 |
+
x = self.net(x)
|
161 |
+
if self.relu:
|
162 |
+
x = F.relu(x, inplace=True)
|
163 |
+
return x # (N, L, D)
|
164 |
+
|
165 |
+
|
166 |
+
bert_config = dict(
|
167 |
+
hidden_size=768,
|
168 |
+
intermediate_size=768,
|
169 |
+
hidden_dropout_prob=0.1,
|
170 |
+
attention_probs_dropout_prob=0.1,
|
171 |
+
num_attention_heads=4,
|
172 |
+
)
|
173 |
+
|
174 |
+
|
175 |
+
class BertLayer(nn.Module):
|
176 |
+
def __init__(self, config, use_self_attention=True):
|
177 |
+
super(BertLayer, self).__init__()
|
178 |
+
self.use_self_attention = use_self_attention
|
179 |
+
if use_self_attention:
|
180 |
+
self.attention = BertAttention(config)
|
181 |
+
self.intermediate = BertIntermediate(config)
|
182 |
+
self.output = BertOutput(config)
|
183 |
+
|
184 |
+
def forward(self, hidden_states, attention_mask):
|
185 |
+
"""
|
186 |
+
Args:
|
187 |
+
hidden_states: (N, L, D)
|
188 |
+
attention_mask: (N, L) with 1 indicate valid, 0 indicates invalid
|
189 |
+
Returns:
|
190 |
+
|
191 |
+
"""
|
192 |
+
if self.use_self_attention:
|
193 |
+
attention_output = self.attention(hidden_states, attention_mask)
|
194 |
+
else:
|
195 |
+
attention_output = hidden_states
|
196 |
+
intermediate_output = self.intermediate(attention_output)
|
197 |
+
layer_output = self.output(intermediate_output, attention_output)
|
198 |
+
return layer_output
|
199 |
+
|
200 |
+
|
201 |
+
class BertAttention(nn.Module):
|
202 |
+
def __init__(self, config):
|
203 |
+
super(BertAttention, self).__init__()
|
204 |
+
self.self = BertSelfAttention(config)
|
205 |
+
self.output = BertSelfOutput(config)
|
206 |
+
|
207 |
+
def forward(self, input_tensor, attention_mask):
|
208 |
+
"""
|
209 |
+
Args:
|
210 |
+
input_tensor: (N, L, D)
|
211 |
+
attention_mask: (N, L)
|
212 |
+
Returns:
|
213 |
+
"""
|
214 |
+
self_output = self.self(input_tensor, input_tensor, input_tensor, attention_mask)
|
215 |
+
attention_output = self.output(self_output, input_tensor)
|
216 |
+
return attention_output
|
217 |
+
|
218 |
+
|
219 |
+
class BertIntermediate(nn.Module):
|
220 |
+
def __init__(self, config):
|
221 |
+
super(BertIntermediate, self).__init__()
|
222 |
+
self.dense = nn.Sequential(
|
223 |
+
nn.Linear(config.hidden_size, config.intermediate_size),
|
224 |
+
nn.ReLU(True))
|
225 |
+
|
226 |
+
def forward(self, hidden_states):
|
227 |
+
return self.dense(hidden_states)
|
228 |
+
|
229 |
+
|
230 |
+
class BertOutput(nn.Module):
|
231 |
+
def __init__(self, config):
|
232 |
+
super(BertOutput, self).__init__()
|
233 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
234 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size)
|
235 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
236 |
+
|
237 |
+
def forward(self, hidden_states, input_tensor):
|
238 |
+
hidden_states = self.dense(hidden_states)
|
239 |
+
hidden_states = self.dropout(hidden_states)
|
240 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
241 |
+
return hidden_states
|
242 |
+
|
243 |
+
|
244 |
+
class BertSelfAttention(nn.Module):
|
245 |
+
def __init__(self, config):
|
246 |
+
super(BertSelfAttention, self).__init__()
|
247 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
248 |
+
raise ValueError(
|
249 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
250 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads))
|
251 |
+
self.num_attention_heads = config.num_attention_heads
|
252 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
253 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
254 |
+
|
255 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
256 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
257 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
258 |
+
|
259 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
260 |
+
|
261 |
+
def transpose_for_scores(self, x):
|
262 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (N, L, nh, dh)
|
263 |
+
x = x.view(*new_x_shape)
|
264 |
+
return x.permute(0, 2, 1, 3) # (N, nh, L, dh)
|
265 |
+
|
266 |
+
def forward(self, query_states, key_states, value_states, attention_mask):
|
267 |
+
"""
|
268 |
+
Args:
|
269 |
+
query_states: (N, Lq, D)
|
270 |
+
key_states: (N, L, D)
|
271 |
+
value_states: (N, L, D)
|
272 |
+
attention_mask: (N, Lq, L)
|
273 |
+
Returns:
|
274 |
+
"""
|
275 |
+
# only need to mask the dimension where the softmax (last dim) is applied, as another dim (second last)
|
276 |
+
# will be ignored in future computation anyway
|
277 |
+
attention_mask = (1 - attention_mask.unsqueeze(1)) * -10000. # (N, 1, Lq, L)
|
278 |
+
mixed_query_layer = self.query(query_states)
|
279 |
+
mixed_key_layer = self.key(key_states)
|
280 |
+
mixed_value_layer = self.value(value_states)
|
281 |
+
|
282 |
+
query_layer = self.transpose_for_scores(mixed_query_layer) # (N, nh, Lq, dh)
|
283 |
+
key_layer = self.transpose_for_scores(mixed_key_layer) # (N, nh, L, dh)
|
284 |
+
value_layer = self.transpose_for_scores(mixed_value_layer) # (N, nh, L, dh)
|
285 |
+
|
286 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
287 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) # (N, nh, Lq, L)
|
288 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
289 |
+
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
|
290 |
+
attention_scores = attention_scores + attention_mask
|
291 |
+
|
292 |
+
# Normalize the attention scores to probabilities.
|
293 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
294 |
+
|
295 |
+
# This is actually dropping out entire tokens to attend to, which might
|
296 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
297 |
+
attention_probs = self.dropout(attention_probs)
|
298 |
+
|
299 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
300 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
301 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
302 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
303 |
+
return context_layer
|
304 |
+
|
305 |
+
|
306 |
+
class BertSelfOutput(nn.Module):
|
307 |
+
def __init__(self, config):
|
308 |
+
super(BertSelfOutput, self).__init__()
|
309 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
310 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size)
|
311 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
312 |
+
|
313 |
+
def forward(self, hidden_states, input_tensor):
|
314 |
+
hidden_states = self.dense(hidden_states)
|
315 |
+
hidden_states = self.dropout(hidden_states)
|
316 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
317 |
+
return hidden_states
|
baselines/crossmodal_moment_localization/model_xml.py
ADDED
@@ -0,0 +1,642 @@
1 |
+
import math
|
2 |
+
import copy
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from easydict import EasyDict as edict
|
7 |
+
from baselines.crossmodal_moment_localization.model_components import \
|
8 |
+
BertAttention, PositionEncoding, LinearLayer, BertSelfAttention, TrainablePositionalEncoding, ConvEncoder
|
9 |
+
from utils.model_utils import RNNEncoder
|
10 |
+
|
11 |
+
base_bert_layer_config = dict(
|
12 |
+
hidden_size=768,
|
13 |
+
intermediate_size=768,
|
14 |
+
hidden_dropout_prob=0.1,
|
15 |
+
attention_probs_dropout_prob=0.1,
|
16 |
+
num_attention_heads=4,
|
17 |
+
)
|
18 |
+
|
19 |
+
xml_base_config = edict(
|
20 |
+
merge_two_stream=True, # merge only the scores
|
21 |
+
cross_att=True, # cross-attention for video and subtitles
|
22 |
+
span_predictor_type="conv",
|
23 |
+
encoder_type="transformer", # cnn, transformer, lstm, gru
|
24 |
+
add_pe_rnn=False, # add positional encoding for RNNs, (LSTM and GRU)
|
25 |
+
visual_input_size=2048, # changes based on visual input type
|
26 |
+
query_input_size=768,
|
27 |
+
sub_input_size=768,
|
28 |
+
hidden_size=500, #
|
29 |
+
conv_kernel_size=5, # conv kernel_size for st_ed predictor
|
30 |
+
stack_conv_predictor_conv_kernel_sizes=-1, # Do not use
|
31 |
+
conv_stride=1, #
|
32 |
+
max_ctx_l=100,
|
33 |
+
max_desc_l=30,
|
34 |
+
input_drop=0.1, # dropout for input
|
35 |
+
drop=0.1, # dropout for other layers
|
36 |
+
n_heads=4, # self attention heads
|
37 |
+
ctx_mode="video_sub", # which context are used. 'video', 'sub' or 'video_sub'
|
38 |
+
margin=0.1, # margin for ranking loss
|
39 |
+
ranking_loss_type="hinge", # loss type, 'hinge' or 'lse'
|
40 |
+
lw_neg_q=1, # loss weight for neg. query and pos. context
|
41 |
+
lw_neg_ctx=1, # loss weight for pos. query and neg. context
|
42 |
+
lw_st_ed=1, # loss weight for st ed prediction
|
43 |
+
use_hard_negative=False, # use hard negative at video level, we may change it during training.
|
44 |
+
hard_pool_size=20,
|
45 |
+
use_self_attention=True,
|
46 |
+
no_modular=False,
|
47 |
+
pe_type="none", # no positional encoding
|
48 |
+
initializer_range=0.02,
|
49 |
+
)
|
50 |
+
|
51 |
+
|
52 |
+
class XML(nn.Module):
|
53 |
+
def __init__(self, config):
|
54 |
+
super(XML, self).__init__()
|
55 |
+
self.config = config
|
56 |
+
# self.position_embeddings = PositionEncoding(n_filters=config.hidden_size,
|
57 |
+
# max_len=config.max_position_embeddings,
|
58 |
+
# pe_type=config.pe_type)
|
59 |
+
self.query_pos_embed = TrainablePositionalEncoding(
|
60 |
+
max_position_embeddings=config.max_desc_l,
|
61 |
+
hidden_size=config.hidden_size, dropout=config.input_drop)
|
62 |
+
self.ctx_pos_embed = TrainablePositionalEncoding(
|
63 |
+
max_position_embeddings=config.max_ctx_l,
|
64 |
+
hidden_size=config.hidden_size, dropout=config.input_drop)
|
65 |
+
self.query_input_proj = LinearLayer(config.query_input_size,
|
66 |
+
config.hidden_size,
|
67 |
+
layer_norm=True,
|
68 |
+
dropout=config.input_drop,
|
69 |
+
relu=True)
|
70 |
+
if config.encoder_type == "transformer": # self-att encoder
|
71 |
+
self.query_encoder = BertAttention(edict(
|
72 |
+
hidden_size=config.hidden_size,
|
73 |
+
intermediate_size=config.hidden_size,
|
74 |
+
hidden_dropout_prob=config.drop,
|
75 |
+
attention_probs_dropout_prob=config.drop,
|
76 |
+
num_attention_heads=config.n_heads,
|
77 |
+
))
|
78 |
+
elif config.encoder_type == "cnn":
|
79 |
+
self.query_encoder = ConvEncoder(
|
80 |
+
kernel_size=5,
|
81 |
+
n_filters=config.hidden_size,
|
82 |
+
dropout=config.drop
|
83 |
+
)
|
84 |
+
elif config.encoder_type in ["gru", "lstm"]:
|
85 |
+
self.query_encoder = RNNEncoder(
|
86 |
+
word_embedding_size=config.hidden_size,
|
87 |
+
hidden_size=config.hidden_size // 2,
|
88 |
+
bidirectional=True,
|
89 |
+
n_layers=1,
|
90 |
+
rnn_type=config.encoder_type,
|
91 |
+
return_outputs=True,
|
92 |
+
return_hidden=False
|
93 |
+
)
|
94 |
+
|
95 |
+
conv_cfg = dict(in_channels=1,
|
96 |
+
out_channels=1,
|
97 |
+
kernel_size=config.conv_kernel_size,
|
98 |
+
stride=config.conv_stride,
|
99 |
+
padding=config.conv_kernel_size // 2,
|
100 |
+
bias=False)
|
101 |
+
|
102 |
+
cross_att_cfg = edict(
|
103 |
+
hidden_size=config.hidden_size,
|
104 |
+
num_attention_heads=config.n_heads,
|
105 |
+
attention_probs_dropout_prob=config.drop
|
106 |
+
)
|
107 |
+
|
108 |
+
self.use_video = "video" in config.ctx_mode
|
109 |
+
if self.use_video:
|
110 |
+
self.video_input_proj = LinearLayer(config.visual_input_size,
|
111 |
+
config.hidden_size,
|
112 |
+
layer_norm=True,
|
113 |
+
dropout=config.input_drop,
|
114 |
+
relu=True)
|
115 |
+
self.video_encoder1 = copy.deepcopy(self.query_encoder)
|
116 |
+
self.video_encoder2 = copy.deepcopy(self.query_encoder)
|
117 |
+
if self.config.cross_att:
|
118 |
+
self.video_cross_att = BertSelfAttention(cross_att_cfg)
|
119 |
+
self.video_cross_layernorm = nn.LayerNorm(config.hidden_size)
|
120 |
+
else:
|
121 |
+
if self.config.encoder_type == "transformer":
|
122 |
+
self.video_encoder3 = copy.deepcopy(self.query_encoder)
|
123 |
+
self.video_query_linear = nn.Linear(config.hidden_size, config.hidden_size)
|
124 |
+
if config.span_predictor_type == "conv":
|
125 |
+
if not config.merge_two_stream:
|
126 |
+
self.video_st_predictor = nn.Conv1d(**conv_cfg)
|
127 |
+
self.video_ed_predictor = nn.Conv1d(**conv_cfg)
|
128 |
+
elif config.span_predictor_type == "cat_linear":
|
129 |
+
self.video_st_predictor = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(2)])
|
130 |
+
self.video_ed_predictor = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(2)])
|
131 |
+
|
132 |
+
self.use_sub = "sub" in config.ctx_mode
|
133 |
+
if self.use_sub:
|
134 |
+
self.sub_input_proj = LinearLayer(config.sub_input_size,
|
135 |
+
config.hidden_size,
|
136 |
+
layer_norm=True,
|
137 |
+
dropout=config.input_drop,
|
138 |
+
relu=True)
|
139 |
+
self.sub_encoder1 = copy.deepcopy(self.query_encoder)
|
140 |
+
self.sub_encoder2 = copy.deepcopy(self.query_encoder)
|
141 |
+
if self.config.cross_att:
|
142 |
+
self.sub_cross_att = BertSelfAttention(cross_att_cfg)
|
143 |
+
self.sub_cross_layernorm = nn.LayerNorm(config.hidden_size)
|
144 |
+
else:
|
145 |
+
if self.config.encoder_type == "transformer":
|
146 |
+
self.sub_encoder3 = copy.deepcopy(self.query_encoder)
|
147 |
+
self.sub_query_linear = nn.Linear(config.hidden_size, config.hidden_size)
|
148 |
+
if config.span_predictor_type == "conv":
|
149 |
+
if not config.merge_two_stream:
|
150 |
+
self.sub_st_predictor = nn.Conv1d(**conv_cfg)
|
151 |
+
self.sub_ed_predictor = nn.Conv1d(**conv_cfg)
|
152 |
+
elif config.span_predictor_type == "cat_linear":
|
153 |
+
self.sub_st_predictor = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(2)])
|
154 |
+
self.sub_ed_predictor = nn.ModuleList([nn.Linear(config.hidden_size, 1) for _ in range(2)])
|
155 |
+
|
156 |
+
self.modular_vector_mapping = nn.Linear(in_features=config.hidden_size,
|
157 |
+
out_features=self.use_sub + self.use_video,
|
158 |
+
bias=False)
|
159 |
+
|
160 |
+
self.temporal_criterion = nn.CrossEntropyLoss(reduction="mean")
|
161 |
+
|
162 |
+
if config.merge_two_stream and config.span_predictor_type == "conv":
|
163 |
+
if self.config.stack_conv_predictor_conv_kernel_sizes == -1:
|
164 |
+
self.merged_st_predictor = nn.Conv1d(**conv_cfg)
|
165 |
+
self.merged_ed_predictor = nn.Conv1d(**conv_cfg)
|
166 |
+
else:
|
167 |
+
print("Will be using multiple Conv layers for prediction.")
|
168 |
+
self.merged_st_predictors = nn.ModuleList()
|
169 |
+
self.merged_ed_predictors = nn.ModuleList()
|
170 |
+
num_convs = len(self.config.stack_conv_predictor_conv_kernel_sizes)
|
171 |
+
for k in self.config.stack_conv_predictor_conv_kernel_sizes:
|
172 |
+
conv_cfg = dict(in_channels=1,
|
173 |
+
out_channels=1,
|
174 |
+
kernel_size=k,
|
175 |
+
stride=config.conv_stride,
|
176 |
+
padding=k // 2,
|
177 |
+
bias=False)
|
178 |
+
self.merged_st_predictors.append(nn.Conv1d(**conv_cfg))
|
179 |
+
self.merged_ed_predictors.append(nn.Conv1d(**conv_cfg))
|
180 |
+
self.combine_st_conv = nn.Linear(num_convs, 1, bias=False)
|
181 |
+
self.combine_ed_conv = nn.Linear(num_convs, 1, bias=False)
|
182 |
+
|
183 |
+
self.reset_parameters()
|
184 |
+
|
185 |
+
def reset_parameters(self):
|
186 |
+
""" Initialize the weights."""
|
187 |
+
|
188 |
+
def re_init(module):
|
189 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
190 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
191 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
192 |
+
                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            elif isinstance(module, nn.LayerNorm):
                module.bias.data.zero_()
                module.weight.data.fill_(1.0)
            elif isinstance(module, nn.Conv1d):
                module.reset_parameters()
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()

        self.apply(re_init)

    def set_hard_negative(self, use_hard_negative, hard_pool_size):
        """use_hard_negative: bool; hard_pool_size: int, """
        self.config.use_hard_negative = use_hard_negative
        self.config.hard_pool_size = hard_pool_size

    def set_train_st_ed(self, lw_st_ed):
        """pre-train video retrieval then span prediction"""
        self.config.lw_st_ed = lw_st_ed

    def forward(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask,
                tef_feat, tef_mask, st_ed_indices):
        """
        Args:
            query_feat: (N, Lq, Dq)
            query_mask: (N, Lq)
            video_feat: (N, Lv, Dv) or None
            video_mask: (N, Lv) or None
            sub_feat: (N, Lv, Ds) or None
            sub_mask: (N, Lv) or None
            tef_feat: (N, Lv, 2) or None,
            tef_mask: (N, Lv) or None,
            st_ed_indices: (N, 2), torch.LongTensor, 1st, 2nd columns are st, ed labels respectively.
        """
        video_feat1, video_feat2, sub_feat1, sub_feat2 = \
            self.encode_context(video_feat, video_mask, sub_feat, sub_mask)

        query_context_scores, st_prob, ed_prob = \
            self.get_pred_from_raw_query(query_feat, query_mask,
                                         video_feat1, video_feat2, video_mask,
                                         sub_feat1, sub_feat2, sub_mask, cross=False)

        loss_st_ed = 0
        if self.config.lw_st_ed != 0:
            loss_st = self.temporal_criterion(st_prob, st_ed_indices[:, 0])
            loss_ed = self.temporal_criterion(ed_prob, st_ed_indices[:, 1])
            loss_st_ed = loss_st + loss_ed

        loss_neg_ctx, loss_neg_q = 0, 0
        if self.config.lw_neg_ctx != 0 or self.config.lw_neg_q != 0:
            loss_neg_ctx, loss_neg_q = self.get_video_level_loss(query_context_scores)

        loss_st_ed = self.config.lw_st_ed * loss_st_ed
        loss_neg_ctx = self.config.lw_neg_ctx * loss_neg_ctx
        loss_neg_q = self.config.lw_neg_q * loss_neg_q
        loss = loss_st_ed + loss_neg_ctx + loss_neg_q
        return loss, {"loss_st_ed": float(loss_st_ed),
                      "loss_neg_ctx": float(loss_neg_ctx),
                      "loss_neg_q": float(loss_neg_q),
                      "loss_overall": float(loss)}

    def get_visualization_data(self, query_feat, query_mask, video_feat, video_mask, sub_feat, sub_mask,
                               tef_feat, tef_mask, st_ed_indices):
        assert self.config.merge_two_stream and self.use_video and self.use_sub and not self.config.no_modular
        video_feat1, video_feat2, sub_feat1, sub_feat2 = \
            self.encode_context(video_feat, video_mask, sub_feat, sub_mask)
        encoded_query = self.encode_input(query_feat, query_mask,
                                          self.query_input_proj, self.query_encoder, self.query_pos_embed)  # (N, Lq, D)
        # (N, D), (N, D), (N, L, 2)
        video_query, sub_query, modular_att_scores = \
            self.get_modularized_queries(encoded_query, query_mask, return_modular_att=True)
        # (N, L), (N, L), (N, L)
        st_prob, ed_prob, similarity_scores, video_similarity, sub_similarity = self.get_merged_st_ed_prob(
            video_query, video_feat2, sub_query, sub_feat2, video_mask, cross=False, return_similaity=True)

        # clean up invalid bits
        data = dict(modular_att_scores=modular_att_scores.cpu().numpy(),  # (N, Lq, 2), row 0, 1 are video, sub.
                    st_prob=st_prob.cpu().numpy(),  # (N, L)
                    ed_prob=ed_prob.cpu().numpy(),  # (N, L)
                    similarity_scores=similarity_scores.cpu().numpy(),  # (N, L)
                    video_similarity=video_similarity.cpu().numpy(),  # (N, L)
                    sub_similarity=sub_similarity.cpu().numpy(),  # (N, L)
                    st_ed_indices=st_ed_indices.cpu().numpy())  # (N, L)
        query_lengths = query_mask.sum(1).to(torch.long).cpu().tolist()  # (N, )
        ctx_lengths = video_mask.sum(1).to(torch.long).cpu().tolist()  # (N, )
        # print("query_lengths {}".format((type(query_lengths), len(query_lengths), query_lengths[:10])))
        for k, v in data.items():
            if k == "modular_att_scores":
                # print(k, v, v.shape, type(v))
                data[k] = [e[:l] for l, e in zip(query_lengths, v)]  # list(e) where e is (Lq_i, 2)
            else:
                data[k] = [e[:l] for l, e in zip(ctx_lengths, v)]  # list(e) where e is (Lc_i)

        # aggregate info for each example
        datalist = []
        for idx in range(len(data["modular_att_scores"])):
            datalist.append({k: v[idx] for k, v in data.items()})
        return datalist  # list(dicts) of length N

    def encode_query(self, query_feat, query_mask):
        encoded_query = self.encode_input(query_feat, query_mask,
                                          self.query_input_proj, self.query_encoder, self.query_pos_embed)  # (N, Lq, D)
        video_query, sub_query = self.get_modularized_queries(encoded_query, query_mask)  # (N, D) * 2
        return video_query, sub_query

    def non_cross_encode_context(self, context_feat, context_mask, module_name="video"):
        encoder_layer3 = getattr(self, module_name + "_encoder3") \
            if self.config.encoder_type == "transformer" else None
        return self._non_cross_encode_context(context_feat, context_mask,
                                              input_proj_layer=getattr(self, module_name + "_input_proj"),
                                              encoder_layer1=getattr(self, module_name + "_encoder1"),
                                              encoder_layer2=getattr(self, module_name + "_encoder2"),
                                              encoder_layer3=encoder_layer3)

    def _non_cross_encode_context(self, context_feat, context_mask, input_proj_layer,
                                  encoder_layer1, encoder_layer2, encoder_layer3=None):
        """
        Args:
            context_feat: (N, L, D)
            context_mask: (N, L)
            input_proj_layer:
            encoder_layer1:
            encoder_layer2:
            encoder_layer3
        """
        context_feat1 = self.encode_input(
            context_feat, context_mask, input_proj_layer, encoder_layer1, self.ctx_pos_embed)  # (N, L, D)
        if self.config.encoder_type in ["transformer", "cnn"]:
            context_mask = context_mask.unsqueeze(1)  # (N, 1, L), torch.FloatTensor
            context_feat2 = encoder_layer2(context_feat1, context_mask)  # (N, L, D)
            if self.config.encoder_type == "transformer":
                context_feat2 = encoder_layer3(context_feat2, context_mask)
        elif self.config.encoder_type in ["gru", "lstm"]:
            context_mask = context_mask.sum(1).long()  # (N, ), torch.LongTensor
            context_feat2 = encoder_layer2(context_feat1, context_mask)[0]  # (N, L, D)
        else:
            raise NotImplementedError
        return context_feat1, context_feat2

    def encode_context(self, video_feat, video_mask, sub_feat, sub_mask):
        if self.config.cross_att:
            assert self.use_video and self.use_sub

            return self.cross_encode_context(video_feat, video_mask, sub_feat, sub_mask)
        else:
            video_feat1, video_feat2 = (None,) * 2
            if self.use_video:
                video_feat1, video_feat2 = self.non_cross_encode_context(video_feat, video_mask, module_name="video")
            sub_feat1, sub_feat2 = (None,) * 2
            if self.use_sub:
                sub_feat1, sub_feat2 = self.non_cross_encode_context(sub_feat, sub_mask, module_name="sub")
            return video_feat1, video_feat2, sub_feat1, sub_feat2

    def cross_encode_context(self, video_feat, video_mask, sub_feat, sub_mask):
        encoded_video_feat = self.encode_input(video_feat, video_mask,
                                               self.video_input_proj, self.video_encoder1, self.ctx_pos_embed)
        encoded_sub_feat = self.encode_input(sub_feat, sub_mask,
                                             self.sub_input_proj, self.sub_encoder1, self.ctx_pos_embed)
        x_encoded_video_feat = self.cross_context_encoder(
            encoded_video_feat, video_mask, encoded_sub_feat, sub_mask,
            self.video_cross_att, self.video_cross_layernorm, self.video_encoder2)  # (N, L, D)
        x_encoded_sub_feat = self.cross_context_encoder(
            encoded_sub_feat, sub_mask, encoded_video_feat, video_mask,
            self.sub_cross_att, self.sub_cross_layernorm, self.sub_encoder2)  # (N, L, D)
        return encoded_video_feat, x_encoded_video_feat, encoded_sub_feat, x_encoded_sub_feat

    def cross_context_encoder(self, main_context_feat, main_context_mask, side_context_feat, side_context_mask,
                              cross_att_layer, norm_layer, self_att_layer):
        """
        Args:
            main_context_feat: (N, Lq, D)
            main_context_mask: (N, Lq)
            side_context_feat: (N, Lk, D)
            side_context_mask: (N, Lk)
            cross_att_layer:
            norm_layer:
            self_att_layer:
        """
        cross_mask = torch.einsum("bm,bn->bmn", main_context_mask, side_context_mask)  # (N, Lq, Lk)
        cross_out = cross_att_layer(main_context_feat, side_context_feat, side_context_feat, cross_mask)  # (N, Lq, D)
        residual_out = norm_layer(cross_out + main_context_feat)
        if self.config.encoder_type in ["cnn", "transformer"]:
            return self_att_layer(residual_out, main_context_mask.unsqueeze(1))
        elif self.config.encoder_type in ["gru", "lstm"]:
            return self_att_layer(residual_out, main_context_mask.sum(1).long())[0]

    def encode_input(self, feat, mask, input_proj_layer, encoder_layer, pos_embed_layer):
        """
        Args:
            feat: (N, L, D_input), torch.float32
            mask: (N, L), torch.float32, with 1 indicates valid query, 0 indicates mask
            input_proj_layer: down project input
            encoder_layer: encoder layer
            # add_pe: bool, whether to add positional encoding
            pos_embed_layer
        """
        feat = input_proj_layer(feat)

        if self.config.encoder_type in ["cnn", "transformer"]:
            feat = pos_embed_layer(feat)
            mask = mask.unsqueeze(1)  # (N, 1, L), torch.FloatTensor
            return encoder_layer(feat, mask)  # (N, L, D_hidden)
        elif self.config.encoder_type in ["gru", "lstm"]:
            if self.config.add_pe_rnn:
                feat = pos_embed_layer(feat)
            mask = mask.sum(1).long()  # (N, ), torch.LongTensor
            return encoder_layer(feat, mask)[0]  # (N, L, D_hidden)

    def get_modularized_queries(self, encoded_query, query_mask, return_modular_att=False):
        """
        Args:
            encoded_query: (N, L, D)
            query_mask: (N, L)
            return_modular_att: bool
        """
        if self.config.no_modular:
            modular_query = torch.max(mask_logits(encoded_query, query_mask.unsqueeze(2)), dim=1)[0]  # (N, D)
            return modular_query, modular_query
        else:
            modular_attention_scores = self.modular_vector_mapping(encoded_query)  # (N, L, 2 or 1)
            modular_attention_scores = F.softmax(
                mask_logits(modular_attention_scores, query_mask.unsqueeze(2)), dim=1)
            # TODO check whether it is the same
            modular_queries = torch.einsum("blm,bld->bmd",
                                           modular_attention_scores, encoded_query)  # (N, 2 or 1, D)
            if return_modular_att:
                assert modular_queries.shape[1] == 2
                return modular_queries[:, 0], modular_queries[:, 1], modular_attention_scores
            else:
                if modular_queries.shape[1] == 2:
                    return modular_queries[:, 0], modular_queries[:, 1]  # (N, D) * 2
                else:  # 1
                    return modular_queries[:, 0], modular_queries[:, 0]  # the same

    def get_modular_weights(self, encoded_query, query_mask):
        """
        Args:
            encoded_query: (N, L, D)
            query_mask: (N, L)
        """
        max_encoded_query, _ = torch.max(mask_logits(encoded_query, query_mask.unsqueeze(2)), dim=1)  # (N, D)
        modular_weights = self.modular_weights_calculator(max_encoded_query)  # (N, 2)
        modular_weights = F.softmax(modular_weights, dim=-1)
        return modular_weights[:, 0:1], modular_weights[:, 1:2]  # (N, 1) * 2

    def get_video_level_scores(self, modularied_query, context_feat1, context_mask):
        """ Calculate video2query scores for each pair of video and query inside the batch.
        Args:
            modularied_query: (N, D)
            context_feat1: (N, L, D), output of the first transformer encoder layer
            context_mask: (N, L)
        Returns:
            context_query_scores: (N, N) score of each query w.r.t. each video inside the batch,
                diagonal positions are positive. used to get negative samples.
        """
        modularied_query = F.normalize(modularied_query, dim=-1)
        context_feat1 = F.normalize(context_feat1, dim=-1)
        query_context_scores = torch.einsum("md,nld->mln", modularied_query, context_feat1)  # (N, L, N)
        context_mask = context_mask.transpose(0, 1).unsqueeze(0)  # (1, L, N)
        query_context_scores = mask_logits(query_context_scores, context_mask)  # (N, L, N)
        query_context_scores, _ = torch.max(query_context_scores,
                                            dim=1)  # (N, N) diagonal positions are positive pairs.
        return query_context_scores

    def get_merged_st_ed_prob(self, video_query, video_feat, sub_query, sub_feat, context_mask,
                              cross=False, return_similaity=False):
        """context_mask could be either video_mask or sub_mask, since they are the same"""
        assert self.use_video and self.use_sub and self.config.span_predictor_type == "conv"
        video_query = self.video_query_linear(video_query)
        sub_query = self.sub_query_linear(sub_query)
        stack_conv = self.config.stack_conv_predictor_conv_kernel_sizes != -1
        num_convs = len(self.config.stack_conv_predictor_conv_kernel_sizes) if stack_conv else None
        if cross:
            video_similarity = torch.einsum("md,nld->mnl", video_query, video_feat)
            sub_similarity = torch.einsum("md,nld->mnl", sub_query, sub_feat)
            similarity = (video_similarity + sub_similarity) / 2  # (Nq, Nv, L)  from query to all videos.
            n_q, n_c, l = similarity.shape
            similarity = similarity.view(n_q * n_c, 1, l)
            if not stack_conv:
                st_prob = self.merged_st_predictor(similarity).view(n_q, n_c, l)  # (Nq, Nv, L)
                ed_prob = self.merged_ed_predictor(similarity).view(n_q, n_c, l)  # (Nq, Nv, L)
            else:
                st_prob_list = []
                ed_prob_list = []
                for idx in range(num_convs):
                    st_prob_list.append(self.merged_st_predictors[idx](similarity).squeeze().unsqueeze(2))
                    ed_prob_list.append(self.merged_ed_predictors[idx](similarity).squeeze().unsqueeze(2))
                # (Nq*Nv, L, 3) --> (Nq*Nv, L) -> (Nq, Nv, L)
                st_prob = self.combine_st_conv(torch.cat(st_prob_list, dim=2)).view(n_q, n_c, l)
                ed_prob = self.combine_ed_conv(torch.cat(ed_prob_list, dim=2)).view(n_q, n_c, l)
        else:
            video_similarity = torch.einsum("bd,bld->bl", video_query, video_feat)  # (N, L)
            sub_similarity = torch.einsum("bd,bld->bl", sub_query, sub_feat)  # (N, L)
            similarity = (video_similarity + sub_similarity) / 2
            if not stack_conv:
                st_prob = self.merged_st_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
                ed_prob = self.merged_ed_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
            else:
                st_prob_list = []
                ed_prob_list = []
                for idx in range(num_convs):
                    st_prob_list.append(self.merged_st_predictors[idx](similarity.unsqueeze(1)).squeeze().unsqueeze(2))
                    ed_prob_list.append(self.merged_ed_predictors[idx](similarity.unsqueeze(1)).squeeze().unsqueeze(2))
                st_prob = self.combine_st_conv(torch.cat(st_prob_list, dim=2)).squeeze()  # (N, L, 3) --> (N, L)
                ed_prob = self.combine_ed_conv(torch.cat(ed_prob_list, dim=2)).squeeze()  # (N, L, 3) --> (N, L)
        st_prob = mask_logits(st_prob, context_mask)  # (N, L)
        ed_prob = mask_logits(ed_prob, context_mask)
        if return_similaity:
            assert not cross
            return st_prob, ed_prob, similarity, video_similarity, sub_similarity
        else:
            return st_prob, ed_prob

    def get_st_ed_prob(self, modularied_query, context_feat2, context_mask,
                       module_name="video", cross=False):
        return self._get_st_ed_prob(modularied_query, context_feat2, context_mask,
                                    module_query_linear=getattr(self, module_name + "_query_linear"),
                                    st_predictor=getattr(self, module_name + "_st_predictor"),
                                    ed_predictor=getattr(self, module_name + "_ed_predictor"),
                                    cross=cross)

    def _get_st_ed_prob(self, modularied_query, context_feat2, context_mask,
                        module_query_linear, st_predictor, ed_predictor, cross=False):
        """
        Args:
            modularied_query: (N, D)
            context_feat2: (N, L, D), output of the first transformer encoder layer
            context_mask: (N, L)
            module_query_linear:
            st_predictor:
            ed_predictor:
            cross: at inference, calculate prob for each possible pairs of query and context.
        """
        query = module_query_linear(modularied_query)  # (N, D) no need to normalize here.
        if cross:
            if self.config.span_predictor_type == "conv":
                similarity = torch.einsum("md,nld->mnl", query, context_feat2)  # (Nq, Nv, L)  from query to all videos.
                n_q, n_c, l = similarity.shape
                similarity = similarity.view(n_q * n_c, 1, l)
                st_prob = st_predictor(similarity).view(n_q, n_c, l)  # (Nq, Nv, L)
                ed_prob = ed_predictor(similarity).view(n_q, n_c, l)  # (Nq, Nv, L)
            elif self.config.span_predictor_type == "cat_linear":
                st_prob_q = st_predictor[0](query).unsqueeze(1)  # (Nq, 1, 1)
                st_prob_ctx = st_predictor[1](context_feat2).squeeze().unsqueeze(0)  # (1, Nv, L)
                st_prob = st_prob_q + st_prob_ctx  # (Nq, Nv, L)
                ed_prob_q = ed_predictor[0](query).unsqueeze(1)  # (Nq, 1, 1)
                ed_prob_ctx = ed_predictor[1](context_feat2).squeeze().unsqueeze(0)  # (1, Nv, L)
                ed_prob = ed_prob_q + ed_prob_ctx  # (Nq, Nv, L)
            context_mask = context_mask.unsqueeze(0)  # (1, Nv, L)
        else:
            if self.config.span_predictor_type == "conv":
                similarity = torch.einsum("bd,bld->bl", query, context_feat2)  # (N, L)
                st_prob = st_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
                ed_prob = ed_predictor(similarity.unsqueeze(1)).squeeze()  # (N, L)
            elif self.config.span_predictor_type == "cat_linear":
                # avoid concatenation by break into smaller matrix multiplications.
                st_prob = st_predictor[0](query) + st_predictor[1](context_feat2).squeeze()  # (N, L)
                ed_prob = ed_predictor[0](query) + ed_predictor[1](context_feat2).squeeze()  # (N, L)
        st_prob = mask_logits(st_prob, context_mask)  # (N, L)
        ed_prob = mask_logits(ed_prob, context_mask)
        return st_prob, ed_prob

    def get_pred_from_raw_query(self, query_feat, query_mask,
                                video_feat1, video_feat2, video_mask,
                                sub_feat1, sub_feat2, sub_mask, cross=False):
        """
        Args:
            query_feat: (N, Lq, Dq)
            query_mask: (N, Lq)
            video_feat1: (N, Lv, D) or None
            video_feat2:
            video_mask: (N, Lv)
            sub_feat1: (N, Lv, D) or None
            sub_feat2:
            sub_mask: (N, Lv)
            cross:
        """
        video_query, sub_query = self.encode_query(query_feat, query_mask)
        divisor = self.use_sub + self.use_video

        # get video-level retrieval scores
        video_q2ctx_scores = self.get_video_level_scores(video_query, video_feat1, video_mask) if self.use_video else 0
        sub_q2ctx_scores = self.get_video_level_scores(sub_query, sub_feat1, sub_mask) if self.use_sub else 0
        q2ctx_scores = (video_q2ctx_scores + sub_q2ctx_scores) / divisor  # (N, N)

        if self.config.merge_two_stream and self.use_video and self.use_sub:
            st_prob, ed_prob = self.get_merged_st_ed_prob(
                video_query, video_feat2, sub_query, sub_feat2, video_mask, cross=cross)
        else:
            video_st_prob, video_ed_prob = self.get_st_ed_prob(
                video_query, video_feat2, video_mask, module_name="video", cross=cross) if self.use_video else (0, 0)
            sub_st_prob, sub_ed_prob = self.get_st_ed_prob(
                sub_query, sub_feat2, sub_mask, module_name="sub", cross=cross) if self.use_sub else (0, 0)
            st_prob = (video_st_prob + sub_st_prob) / divisor  # (N, Lv)
            ed_prob = (video_ed_prob + sub_ed_prob) / divisor  # (N, Lv)
        return q2ctx_scores, st_prob, ed_prob  # un-normalized masked probabilities!!!!!

    def get_video_level_loss(self, query_context_scores):
        """ ranking loss between (pos. query + pos. video) and (pos. query + neg. video) or (neg. query + pos. video)
        Args:
            query_context_scores: (N, N), cosine similarity [-1, 1],
                Each row contains the scores between the query to each of the videos inside the batch.
        """
        bsz = len(query_context_scores)
        diagonal_indices = torch.arange(bsz).to(query_context_scores.device)
        pos_scores = query_context_scores[diagonal_indices, diagonal_indices]  # (N, )
        query_context_scores_masked = copy.deepcopy(query_context_scores.data)
        # impossibly large for cosine similarity, the copy is created as modifying the original will cause error
        query_context_scores_masked[diagonal_indices, diagonal_indices] = 999
        pos_query_neg_context_scores = self.get_neg_scores(query_context_scores,
                                                           query_context_scores_masked)
        neg_query_pos_context_scores = self.get_neg_scores(query_context_scores.transpose(0, 1),
                                                           query_context_scores_masked.transpose(0, 1))
        loss_neg_ctx = self.get_ranking_loss(pos_scores, pos_query_neg_context_scores)
        loss_neg_q = self.get_ranking_loss(pos_scores, neg_query_pos_context_scores)
        return loss_neg_ctx, loss_neg_q

    def get_neg_scores(self, scores, scores_masked):
        """
        scores: (N, N), cosine similarity [-1, 1],
            Each row are scores: query --> all videos. Transposed version: video --> all queries.
        scores_masked: (N, N) the same as scores, except that the diagonal (positive) positions
            are masked with a large value.
        """
        bsz = len(scores)
        batch_indices = torch.arange(bsz).to(scores.device)
        _, sorted_scores_indices = torch.sort(scores_masked, descending=True, dim=1)
        sample_min_idx = 1  # skip the masked positive
        sample_max_idx = min(sample_min_idx + self.config.hard_pool_size, bsz) \
            if self.config.use_hard_negative else bsz
        sampled_neg_score_indices = sorted_scores_indices[
            batch_indices, torch.randint(sample_min_idx, sample_max_idx, size=(bsz,)).to(scores.device)]  # (N, )
        sampled_neg_scores = scores[batch_indices, sampled_neg_score_indices]  # (N, )
        return sampled_neg_scores

    def get_ranking_loss(self, pos_score, neg_score):
        """ Note here we encourage positive scores to be larger than negative scores.
        Args:
            pos_score: (N, ), torch.float32
            neg_score: (N, ), torch.float32
        """
        if self.config.ranking_loss_type == "hinge":  # max(0, m + S_neg - S_pos)
            return torch.clamp(self.config.margin + neg_score - pos_score, min=0).sum() / len(pos_score)
        elif self.config.ranking_loss_type == "lse":  # log[1 + exp(S_neg - S_pos)]
            return torch.log1p(torch.exp(neg_score - pos_score)).sum() / len(pos_score)
        else:
            raise NotImplementedError("Only support 'hinge' and 'lse'")


def mask_logits(target, mask):
    return target * mask + (1 - mask) * (-1e10)
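For readers skimming the diff: the video-level retrieval objective above (get_video_level_loss with get_ranking_loss) is an in-batch contrastive ranking scheme over the (N, N) query-video score matrix. The sketch below is not part of the commit; it is a simplified standalone illustration that always takes the hardest in-batch negative instead of sampling from a hard pool, and the margin and batch size are made-up values.

import torch

def toy_video_level_hinge_loss(query_context_scores, margin=0.1):
    # query_context_scores: (N, N) cosine similarities, diagonal = positive pairs,
    # mirroring the output shape of get_video_level_scores above.
    bsz = query_context_scores.size(0)
    diag = torch.arange(bsz)
    pos = query_context_scores[diag, diag]            # (N, ) positive scores
    masked = query_context_scores.clone()
    masked[diag, diag] = -1e10                         # hide positives when picking negatives
    neg_ctx = masked.max(dim=1)[0]                     # hardest negative video per query
    neg_q = masked.max(dim=0)[0]                       # hardest negative query per video
    # hinge: max(0, m + S_neg - S_pos), averaged over the batch
    loss_neg_ctx = torch.clamp(margin + neg_ctx - pos, min=0).mean()
    loss_neg_q = torch.clamp(margin + neg_q - pos, min=0).mean()
    return loss_neg_ctx, loss_neg_q

scores = torch.randn(4, 4)  # toy score matrix
print(toy_video_level_hinge_loss(scores))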
baselines/crossmodal_moment_localization/ndcg_iou_topk.py
ADDED
@@ -0,0 +1,68 @@
from utils.basic_utils import load_jsonl, save_jsonl, load_json
import pandas as pd
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import copy


def calculate_iou(pred_start: float, pred_end: float, gt_start: float, gt_end: float) -> float:
    intersection_start = max(pred_start, gt_start)
    intersection_end = min(pred_end, gt_end)
    intersection = max(0, intersection_end - intersection_start)
    union = (pred_end - pred_start) + (gt_end - gt_start) - intersection
    return intersection / union if union > 0 else 0


# Function to calculate DCG
def calculate_dcg(scores):
    return sum((2**score - 1) / np.log2(idx + 2) for idx, score in enumerate(scores))


# Function to calculate NDCG
def calculate_ndcg(pred_scores, true_scores):
    dcg = calculate_dcg(pred_scores)
    idcg = calculate_dcg(sorted(true_scores, reverse=True))
    return dcg / idcg if idcg > 0 else 0


def calculate_ndcg_iou(all_gt, all_pred, TS, KS):
    performance = defaultdict(lambda: defaultdict(list))
    performance_avg = defaultdict(lambda: defaultdict(float))
    for k in tqdm(all_pred.keys(), desc="Calculate NDCG"):
        one_pred = all_pred[k]
        one_gt = all_gt[k]

        one_gt.sort(key=lambda x: x["relevance"], reverse=True)
        for T in TS:
            one_gt_drop = copy.deepcopy(one_gt)
            predictions_with_scores = []

            for pred in one_pred:
                pred_video_name, pred_time = pred["video_name"], pred["timestamp"]
                matched_rows = [gt for gt in one_gt_drop if gt["video_name"] == pred_video_name]
                if not matched_rows:
                    pred["pred_relevance"] = 0
                else:
                    ious = [calculate_iou(pred_time[0], pred_time[1], gt["timestamp"][0], gt["timestamp"][1]) for gt in matched_rows]
                    max_iou_idx = np.argmax(ious)
                    max_iou_row = matched_rows[max_iou_idx]

                    if ious[max_iou_idx] > T:
                        pred["pred_relevance"] = max_iou_row["relevance"]
                        # Remove the matched ground truth row
                        original_idx = one_gt_drop.index(max_iou_row)
                        one_gt_drop.pop(original_idx)
                    else:
                        pred["pred_relevance"] = 0
                predictions_with_scores.append(pred)
            for K in KS:
                true_scores = [gt["relevance"] for gt in one_gt][:K]
                pred_scores = [pred["pred_relevance"] for pred in predictions_with_scores][:K]
                ndcg_score = calculate_ndcg(pred_scores, true_scores)
                performance[K][T].append(ndcg_score)
    for K, vs in performance.items():
        for T, v in vs.items():
            performance_avg[K][T] = np.mean(v)
    return performance_avg
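A minimal usage sketch of calculate_ndcg_iou, assuming the repository root is on PYTHONPATH so the module (and its utils.basic_utils dependency) can be imported. The query id, video names, relevance levels, and the IoU thresholds / K values are illustrative assumptions, not values taken from the commit; the dict layout follows the keys the function reads ("video_name", "timestamp", "relevance").

from baselines.crossmodal_moment_localization.ndcg_iou_topk import calculate_ndcg_iou

all_gt = {
    "q1": [  # hypothetical ground-truth moments for one query
        {"video_name": "video_a", "timestamp": [10.5, 18.0], "relevance": 3},
        {"video_name": "video_b", "timestamp": [0.0, 6.0], "relevance": 1},
    ]
}
all_pred = {
    "q1": [  # ranked predictions for the same query, best first
        {"video_name": "video_a", "timestamp": [11.0, 17.0]},
        {"video_name": "video_b", "timestamp": [30.0, 40.0]},
    ]
}
# NDCG@K where a prediction only earns relevance if its IoU with an unmatched GT moment exceeds T
performance = calculate_ndcg_iou(all_gt, all_pred, TS=[0.3, 0.5], KS=[10])
for K, per_T in performance.items():
    for T, score in per_T.items():
        print("NDCG@{} IoU>{}: {:.4f}".format(K, T, score))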
baselines/crossmodal_moment_localization/optimization.py
ADDED
@@ -0,0 +1,338 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for BERT model."""

import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
import abc
import sys

logger = logging.getLogger(__name__)


if sys.version_info >= (3, 4):
    ABC = abc.ABC
else:
    ABC = abc.ABCMeta('ABC', (), {})


class _LRSchedule(ABC):
    """ Parent of all LRSchedules here. """
    warn_t_total = False  # is set to True for schedules where progressing beyond t_total steps doesn't make sense

    def __init__(self, warmup=0.002, t_total=-1, **kw):
        """
        :param warmup: what fraction of t_total steps will be used for linear warmup
        :param t_total: how many training steps (updates) are planned
        :param kw:
        """
        super(_LRSchedule, self).__init__(**kw)
        if t_total < 0:
            logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
        if not 0.0 <= warmup < 1.0 and not warmup == -1:
            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
        warmup = max(warmup, 0.)
        self.warmup, self.t_total = float(warmup), float(t_total)
        self.warned_for_t_total_at_progress = -1

    def get_lr(self, step, nowarn=False):
        """
        :param step: which of t_total steps we're on
        :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
        :return: learning rate multiplier for current update
        """
        if self.t_total < 0:
            return 1.
        progress = float(step) / self.t_total
        ret = self.get_lr_(progress)
        # warning for exceeding t_total (only active with warmup_linear
        if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
            logger.warning(
                "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
                .format(ret, self.__class__.__name__))
            self.warned_for_t_total_at_progress = progress
        # end warning
        return ret

    @abc.abstractmethod
    def get_lr_(self, progress):
        """
        :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
        :return: learning rate multiplier for current update
        """
        return 1.


class ConstantLR(_LRSchedule):
    def get_lr_(self, progress):
        return 1.


class WarmupCosineSchedule(_LRSchedule):
    """
    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
    Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
    If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
    """
    warn_t_total = True

    def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
        """
        :param warmup: see LRSchedule
        :param t_total: see LRSchedule
        :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
        :param kw:
        """
        super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
        self.cycles = cycles

    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        else:
            progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
            return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))


class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
    """
    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
    If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
    learning rate (with hard restarts).
    """
    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
        super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
        assert(cycles >= 1.)

    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        else:
            progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
            ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
            return ret


class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
    """
    All training progress is divided in `cycles` (default=1.) parts of equal length.
    Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
    followed by a learning rate decreasing from 1. to 0. following a cosine curve.
    """
    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
        assert(warmup * cycles < 1.)
        warmup = warmup * cycles if warmup >= 0 else warmup
        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)

    def get_lr_(self, progress):
        progress = progress * self.cycles % 1.
        if progress < self.warmup:
            return progress / self.warmup
        else:
            progress = (progress - self.warmup) / (1 - self.warmup)  # progress after warmup
            ret = 0.5 * (1. + math.cos(math.pi * progress))
            return ret


class WarmupConstantSchedule(_LRSchedule):
    """
    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
    Keeps learning rate equal to 1. after warmup.
    """
    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        return 1.


class WarmupLinearSchedule(_LRSchedule):
    """
    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
    Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
    """
    warn_t_total = True

    def get_lr_(self, progress):
        if progress < self.warmup:
            return progress / self.warmup
        return max((progress - 1.) / (self.warmup - 1.), 0.)


SCHEDULES = {
    None: ConstantLR,
    "none": ConstantLR,
    "warmup_cosine": WarmupCosineSchedule,
    "warmup_constant": WarmupConstantSchedule,
    "warmup_linear": WarmupLinearSchedule
}


class EMA(object):
    """ Exponential Moving Average for model parameters.
    references:
    [1] https://github.com/BangLiu/QANet-PyTorch/blob/master/model/modules/ema.py
    [2] https://github.com/hengruo/QANet-pytorch/blob/e2de07cd2c711d525f5ffee35c3764335d4b501d/main.py"""
    def __init__(self, decay):
        self.decay = decay
        self.shadow = {}
        self.original = {}

    def register(self, name, val):
        self.shadow[name] = val.clone()

    def __call__(self, model, step):
        decay = min(self.decay, (1 + step) / (10.0 + step))
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                new_average = \
                    (1.0 - decay) * param.data + decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def assign(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                self.original[name] = param.data.clone()
                param.data = self.shadow[name]

    def resume(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                assert name in self.shadow
                param.data = self.original[name]


class BertAdam(Optimizer):
    """Implements BERT version of Adam algorithm with weight decay fix.
    Params:
        lr: learning rate
        warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
        t_total: total number of training steps for the learning
            rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
        schedule: schedule to use for the warmup (see above).
            Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below).
            If `None` or `'none'`, learning rate is always kept constant.
            Default : `'warmup_linear'`
        b1: Adams b1. Default: 0.9
        b2: Adams b2. Default: 0.999
        e: Adams epsilon. Default: 1e-6
        weight_decay: Weight decay. Default: 0.01
        max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
    """
    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        # initialize schedule object
        if not isinstance(schedule, _LRSchedule):
            schedule_type = SCHEDULES[schedule]
            schedule = schedule_type(warmup=warmup, t_total=t_total)
        else:
            if warmup != -1 or t_total != -1:
                logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
                               "Please specify custom warmup and t_total in _LRSchedule object.")
        defaults = dict(lr=lr, schedule=schedule,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(BertAdam, self).__init__(params, defaults)

    def get_lr(self):
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])
                lr.append(lr_scheduled)
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['next_m'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['next_v'] = torch.zeros_like(p.data)

                next_m, next_v = state['next_m'], state['next_v']
                beta1, beta2 = group['b1'], group['b2']

                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                # In-place operations to update the averages at the same time
                next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
                next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                update = next_m / (next_v.sqrt() + group['e'])

                # Just adding the square of the weights to the loss function is *not*
                # the correct way of using L2 regularization/weight decay with Adam,
                # since that will interact with the m and v parameters in strange ways.
                #
                # Instead we want to decay the weights in a manner that doesn't interact
                # with the m/v parameters. This is equivalent to adding the square
                # of the weights to the loss with plain (non-momentum) SGD.
                if group['weight_decay'] > 0.0:
                    update += group['weight_decay'] * p.data

                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])

                update_with_lr = lr_scheduled * update
                p.data.add_(-update_with_lr)

                state['step'] += 1

                # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
                # No bias correction
                # bias_correction1 = 1 - beta1 ** state['step']
                # bias_correction2 = 1 - beta2 ** state['step']

        return loss
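A minimal sketch of how BertAdam and the warmup_linear schedule defined above can be driven. The toy nn.Linear model, learning rate, warmup fraction, and step count are illustrative assumptions, and the import assumes the repository root is on PYTHONPATH; the training scripts in this commit wire the optimizer up differently.

import torch
import torch.nn as nn
from baselines.crossmodal_moment_localization.optimization import BertAdam

model = nn.Linear(8, 2)
n_steps = 100
optimizer = BertAdam(model.parameters(), lr=1e-4, warmup=0.1, t_total=n_steps,
                     schedule="warmup_linear", weight_decay=0.01, max_grad_norm=1.0)
for step in range(n_steps):
    loss = model(torch.randn(4, 8)).pow(2).mean()  # dummy objective
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # lr ramps up over the first 10% of steps, then decays linearly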
baselines/crossmodal_moment_localization/scripts/eval.sh
ADDED
@@ -0,0 +1,14 @@
#!/usr/bin/env bash
# run at project root dir
# Usage:
# bash baselines/crossmodal_moment_localization/scripts/eval.sh ANY_OTHER_PYTHON_ARGS
eval_split_name=$1
submission_path=$2
save_path=$3
gt_path=data/tvr_${eval_split_name}_release.jsonl

python standalone_eval/eval.py \
--gt_path ${gt_path} \
--submission_path ${submission_path} \
--save_path ${save_path} \
${@:4}
baselines/crossmodal_moment_localization/scripts/inference.sh
ADDED
@@ -0,0 +1,18 @@
#!/usr/bin/env bash
# run at project root dir
# Usage:
# bash baselines/crossmodal_moment_localization/scripts/inference.sh ANY_OTHER_PYTHON_ARGS
model_dir=$1
eval_split_name=$2
eval_path=data/tvr_${eval_split_name}_release.jsonl
tasks=()
tasks+=(VCMR)
tasks+=(SVMR)
tasks+=(VR)
echo "tasks ${tasks[@]}"
python baselines/crossmodal_moment_localization/inference.py \
--model_dir ${model_dir} \
--tasks ${tasks[@]} \
--eval_split_name ${eval_split_name} \
--eval_path ${eval_path} \
${@:3}
baselines/crossmodal_moment_localization/scripts/inference_with_external.sh
ADDED
@@ -0,0 +1,40 @@
#!/usr/bin/env bash
# run at project root dir
# Usage:
# bash baselines/crossmodal_moment_localization/scripts/inference_with_external.sh
#model_dir=$1
# DO not use NMS, since it gives worse results
eval_model=$1  # [xml, xml_tef]
eval_split_name=$2
external_model=mee  # [mee, mcn, cal]
eval_path=data/tvr_${eval_split_name}_release.jsonl
project_root=./baselines

# setup eval model
if [[ ${eval_model} == xml ]]; then
    eval_model_dir=tvr-video_sub-resnet_i3d_no_norm_v-2019_11_03_12_22_19
elif [[ ${eval_model} == xml_tef ]]; then
    eval_model_dir=tvr-video_sub_tef-resnet_i3d_no_norm_v-2019_11_03_12_53_01
fi

# setup external
if [[ ${external_model} == mee ]]; then
    external_model_dir=tvr-video_sub-res-2019_11_06_00_33_39
    external_inference_vr_res_path=${project_root}/mixture_embedding_experts/results/${external_model_dir}/inference_tvr_${eval_split_name}_None_predictions_VR.json
fi

tasks=(VR)
tasks+=(SVMR)
tasks+=(VCMR)
echo "tasks ${tasks[@]}"
python baselines/crossmodal_moment_localization/inference.py \
--model_dir ${eval_model_dir} \
--tasks ${tasks[@]} \
--eval_split_name ${eval_split_name} \
--eval_path ${eval_path} \
--external_inference_vr_res_path ${external_inference_vr_res_path} \
--eval_id ${external_model_dir} \
${@:3}

#--use_intermediate \  # temporary removed
baselines/crossmodal_moment_localization/scripts/train.sh
ADDED
@@ -0,0 +1,70 @@
#!/usr/bin/env bash
# run at project root dir
# Usage:
# bash baselines/crossmodal_moment_localization/scripts/train.sh tvr all ANY_OTHER_PYTHON_ARGS
# use --eval_tasks_at_training ["VR", "SVMR", "VCMR"] --stop_task ["VR", "SVMR", "VCMR"] for
# use --lw_neg_q 0 --lw_neg_ctx 0 for training SVMR/SVMR only
# use --lw_st_ed 0 for training with VR only
dset_name=$1  # see case below
ctx_mode=$2  # [video, sub, tef, video_sub, video_tef, sub_tef, video_sub_tef]
vid_feat_type=$3  # [resnet, i3d, resnet_i3d]
feature_root=data/tvr_feature_release
results_root=baselines/crossmodal_moment_localization/results
vid_feat_size=2048
extra_args=()

if [[ ${ctx_mode} == *"sub"* ]] || [[ ${ctx_mode} == "sub" ]]; then
    if [[ ${dset_name} != "tvr" ]]; then
        echo "The use of subtitles is only supported in tvr."
        exit 1
    fi
fi


case ${dset_name} in
    tvr)
        train_path=data/tvr_train_release.jsonl
        corpus_path=data/tvr_video2dur_idx.json
        desc_bert_path=${feature_root}/bert_feature/query_only/tvr_query_pretrained_w_query.h5
        if [[ ${vid_feat_type} == "i3d" ]]; then
            echo "Using I3D feature with shape 1024"
            vid_feat_path=${feature_root}/video_feature/tvr_i3d_rgb600_avg_cl-1.5.h5
            vid_feat_size=1024
        elif [[ ${vid_feat_type} == "resnet" ]]; then
            echo "Using ResNet feature with shape 2048"
            vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_cl-1.5.h5
            vid_feat_size=2048
        elif [[ ${vid_feat_type} == "resnet_i3d" ]]; then
            echo "Using concatenated ResNet and I3D feature with shape 2048+1024"
            vid_feat_path=${feature_root}/video_feature/tvr_resnet152_rgb_max_i3d_rgb600_avg_cat_cl-1.5.h5
            vid_feat_size=3072
            extra_args+=(--no_norm_vfeat)  # since they are already normalized.
        fi
        eval_split_name=val
        nms_thd=-1
        extra_args+=(--eval_path)
        extra_args+=(data/tvr_val_release.jsonl)
        clip_length=1.5
        extra_args+=(--max_ctx_l)
        extra_args+=(100)  # max_ctx_l = 100 for clip_length = 1.5, only ~109/21825 has more than 100.
        extra_args+=(--max_pred_l)
        extra_args+=(16)
        if [[ ${ctx_mode} == *"sub"* ]] || [[ ${ctx_mode} == "sub" ]]; then
            echo "Running with sub."
            desc_bert_path=${feature_root}/bert_feature/sub_query/tvr_query_pretrained_w_sub_query.h5  # overwrite
            sub_bert_path=${feature_root}/bert_feature/sub_query/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5
            sub_feat_size=768
            extra_args+=(--sub_feat_size)
            extra_args+=(${sub_feat_size})
            extra_args+=(--sub_bert_path)
            extra_args+=(${sub_bert_path})
        fi
        ;;
    *)
        echo -n "Unknown argument"
        ;;
esac

echo "Start training with dataset [${dset_name}] in Context Mode [${ctx_mode}]"
echo "Extra args ${extra_args[@]}"
echo " python baselines/crossmodal_moment_localization/train.py --dset_name=${dset_name} --eval_split_name=${eval_split_name} --nms_thd=${nms_thd} --results_root=${results_root} --train_path=${train_path} --desc_bert_path=${desc_bert_path} --corpus_path=${corpus_path} --vid_feat_path=${vid_feat_path} --clip_length=${clip_length} --vid_feat_size=${vid_feat_size} --ctx_mode=${ctx_mode} ${extra_args[@]} ${@:4}"
baselines/crossmodal_moment_localization/start_end_dataset.py
ADDED
@@ -0,0 +1,393 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Dataset for clip model
|
3 |
+
"""
|
4 |
+
import logging
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
import numpy as np
|
8 |
+
import h5py
|
9 |
+
import time
|
10 |
+
import math
|
11 |
+
import random
|
12 |
+
from tqdm import tqdm
|
13 |
+
from utils.basic_utils import load_json, load_json, l2_normalize_np_array, flat_list_of_lists, merge_dicts
|
14 |
+
from utils.tensor_utils import pad_sequences_1d
|
15 |
+
from baselines.clip_alignment_with_language.local_utils.compute_proposal_upper_bound import \
|
16 |
+
get_didemo_agreed_ts
|
17 |
+
import pandas as pd
|
18 |
+
|
19 |
+
logger = logging.getLogger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class StartEndDataset(Dataset):
|
23 |
+
"""
|
24 |
+
Args:
|
25 |
+
dset_name, str, ["tvr"]
|
26 |
+
ctx_mode: str,
|
27 |
+
Return:
|
28 |
+
a dict: {
|
29 |
+
"meta": {
|
30 |
+
"query_id": int,
|
31 |
+
"desc": str,
|
32 |
+
"vid_name": str,
|
33 |
+
"duration": float,
|
34 |
+
"ts": [st (float), ed (float)], seconds, ground_truth timestamps
|
35 |
+
}
|
36 |
+
"model_inputs": {
|
37 |
+
"query_feat": torch.tensor, (L, D_q)
|
38 |
+
"video_feat": torch.tensor, (n_clip_in_moment, D_video)
|
39 |
+
"sub_feat": torch.tensor, (n_clip_in_moment, D_sub)
|
40 |
+
"st_ed_indices": torch.LongTensor, (2, )
|
41 |
+
}
|
42 |
+
}
|
43 |
+
"""
|
44 |
+
def __init__(self, dset_name, data_path, desc_bert_path_or_handler, sub_bert_path_or_handler,
|
45 |
+
max_desc_len, max_ctx_len,
|
46 |
+
vid_feat_path_or_handler, clip_length, ctx_mode="video",
|
47 |
+
normalize_vfeat=True, normalize_tfeat=True, h5driver=None, data_ratio=1.0):
|
48 |
+
self.dset_name = dset_name
|
49 |
+
self.data_path = data_path
|
50 |
+
self.data_ratio = data_ratio
|
51 |
+
|
52 |
+
self.desc_bert_path_or_handler = desc_bert_path_or_handler
|
53 |
+
self.max_desc_len = max_desc_len
|
54 |
+
|
55 |
+
self.sub_bert_path_or_handler = sub_bert_path_or_handler
|
56 |
+
self.max_ctx_len = max_ctx_len
|
57 |
+
self.vid_feat_path_or_handler = vid_feat_path_or_handler
|
58 |
+
self.clip_length = clip_length
|
59 |
+
self.ctx_mode = ctx_mode
|
60 |
+
|
61 |
+
# prepare desc data
|
62 |
+
self.data = self.expand_annotations(load_json(data_path))
|
63 |
+
|
64 |
+
if self.data_ratio != 1:
|
65 |
+
n_examples = int(len(self.data) * data_ratio)
|
66 |
+
self.data = self.data[:n_examples]
|
67 |
+
logger.info("Using {}% of the data: {} examples".format(data_ratio * 100, n_examples))
|
68 |
+
|
69 |
+
self.use_video = "video" in self.ctx_mode
|
70 |
+
self.use_sub = "sub" in self.ctx_mode
|
71 |
+
self.use_tef = "tef" in self.ctx_mode
|
72 |
+
|
73 |
+
if self.use_video:
|
74 |
+
if isinstance(vid_feat_path_or_handler, h5py.File):
|
75 |
+
self.vid_feat_h5 = vid_feat_path_or_handler
|
76 |
+
else: # str path
|
77 |
+
self.vid_feat_h5 = h5py.File(vid_feat_path_or_handler, "r", driver=h5driver)
|
78 |
+
|
79 |
+
if isinstance(desc_bert_path_or_handler, h5py.File):
|
80 |
+
self.desc_bert_h5 = desc_bert_path_or_handler
|
81 |
+
else:
|
82 |
+
self.desc_bert_h5 = h5py.File(desc_bert_path_or_handler, "r", driver=h5driver)
|
83 |
+
|
84 |
+
if self.use_sub:
|
85 |
+
if isinstance(sub_bert_path_or_handler, h5py.File):
|
86 |
+
self.sub_bert_h5 = sub_bert_path_or_handler
|
87 |
+
else: # str path
|
88 |
+
self.sub_bert_h5 = h5py.File(sub_bert_path_or_handler, "r", driver=h5driver)
|
89 |
+
|
90 |
+
self.normalize_vfeat = normalize_vfeat
|
91 |
+
self.normalize_tfeat = normalize_tfeat
|
92 |
+
|
93 |
+
def __len__(self):
|
94 |
+
return len(self.data)
|
95 |
+
|
96 |
+
def expand_annotations(self, annotations):
|
97 |
+
new_annotations = []
|
98 |
+
for i in annotations:
|
99 |
+
query = i["query"]
|
100 |
+
query_id = i["query_id"]
|
101 |
+
for moment in i["relevant_moment"]:
|
102 |
+
moment.update({'query': query, 'query_id': query_id})
|
103 |
+
new_annotations.append(moment)
|
104 |
+
return new_annotations
|
105 |
+
|
106 |
+
def __getitem__(self, index):
|
107 |
+
raw_data = self.data[index]
|
108 |
+
|
109 |
+
# initialize with basic data
|
110 |
+
meta = dict(
|
111 |
+
query_id=raw_data["query_id"],
|
112 |
+
desc=raw_data["query"],
|
113 |
+
vid_name=raw_data["video_name"],
|
114 |
+
duration=raw_data["duration"],
|
115 |
+
ts=raw_data["timestamp"] ,
|
116 |
+
)
|
117 |
+
model_inputs = dict()
|
118 |
+
model_inputs["query_feat"] = self.get_query_feat_by_query_id(meta["query_id"])
|
119 |
+
|
120 |
+
ctx_l = 0
|
121 |
+
if self.use_video:
|
122 |
+
video_feat = self.vid_feat_h5[meta["vid_name"]][:self.max_ctx_len] # (N_clip, D)
|
123 |
+
if self.normalize_vfeat:
|
124 |
+
video_feat = l2_normalize_np_array(video_feat)
|
125 |
+
model_inputs["video_feat"] = torch.from_numpy(video_feat)
|
126 |
+
ctx_l = len(video_feat)
|
127 |
+
else:
|
128 |
+
model_inputs["video_feat"] = torch.zeros((2, 2))
|
129 |
+
|
130 |
+
if self.use_sub: # no need for ctx feature, as the features are already contextulized
|
131 |
+
sub_feat = self.sub_bert_h5[meta["vid_name"]][:self.max_ctx_len] # (N_clips, D_t)
|
132 |
+
if self.normalize_tfeat:
|
133 |
+
sub_feat = l2_normalize_np_array(sub_feat)
|
134 |
+
model_inputs["sub_feat"] = torch.from_numpy(sub_feat)
|
135 |
+
ctx_l = len(sub_feat)
|
136 |
+
else:
|
137 |
+
model_inputs["sub_feat"] = torch.zeros((2, 2))
|
138 |
+
|
139 |
+
if self.use_tef:
|
140 |
+
# note the tef features here are normalized clip indices (1.5 secs), instead of the original time (1 sec)
|
141 |
+
ctx_l = meta["duration"] // self.clip_length + 1 if ctx_l == 0 else ctx_l
|
142 |
+
tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
|
143 |
+
tef_ed = tef_st + 1.0 / ctx_l
|
144 |
+
tef = torch.stack([tef_st, tef_ed], dim=1) # (N_clips, 2)
|
145 |
+
model_inputs["tef_feat"] = tef
|
146 |
+
else:
|
147 |
+
model_inputs["tef_feat"] = torch.zeros((2, 2))
|
148 |
+
|
149 |
+
if self.use_video and self.use_tef:
|
150 |
+
model_inputs["video_feat"] = torch.cat(
|
151 |
+
[model_inputs["video_feat"], model_inputs["tef_feat"]], dim=1) # (N_clips, D+2)
|
152 |
+
if self.use_sub and self.use_tef:
|
153 |
+
model_inputs["sub_feat"] = torch.cat(
|
154 |
+
[model_inputs["sub_feat"], model_inputs["tef_feat"]], dim=1) # (N_clips, D_t+2)
|
155 |
+
|
156 |
+
model_inputs["st_ed_indices"] = self.get_st_ed_label(meta["ts"], max_idx=ctx_l-1)
|
157 |
+
return dict(meta=meta, model_inputs=model_inputs)
|
158 |
+
|
159 |
+
def get_st_ed_label(self, ts, max_idx):
|
160 |
+
"""
|
161 |
+
Args:
|
162 |
+
ts: [st (float), ed (float)] in seconds, ed > st
|
163 |
+
max_idx: length of the video
|
164 |
+
|
165 |
+
Returns:
|
166 |
+
[st_idx, ed_idx]: int,
|
167 |
+
|
168 |
+
Given ts = [3.2, 7.6], st_idx = 2, ed_idx = 6,
|
169 |
+
clips should be indexed as [2: 6), the translated back ts should be [3:9].
|
170 |
+
# TODO which one is better, [2: 5] or [2: 6)
|
171 |
+
"""
|
172 |
+
st_idx = min(math.floor(ts[0] / self.clip_length), max_idx)
|
173 |
+
ed_idx = min(math.ceil(ts[1] / self.clip_length), max_idx)
|
174 |
+
return torch.LongTensor([st_idx, ed_idx])
|
175 |
+
|
176 |
+
    def get_query_feat_by_query_id(self, query_id):
        query_feat = self.desc_bert_h5[str(query_id)][:self.max_desc_len]
        if self.normalize_tfeat:
            query_feat = l2_normalize_np_array(query_feat)
        return torch.from_numpy(query_feat)

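    # Note (shapes assumed for illustration): desc_bert_h5 appears to store token-level query
    # embeddings, so the returned query_feat is a (num_tokens, D_q) matrix truncated to
    # max_desc_len rows, with D_q = 768 if BERT-base features were used; the text-side
    # normalize_tfeat flag controls whether each token vector is L2-normalized.
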
class StartEndEvalDataset(Dataset):
    """
    init_data_mode: `video_query` or `video_only` or `query_only`,
        it indicates which data to load when initializing the Dataset object.
    data_mode: `context` or `query`, it indicates which data to return for self.__getitem__()
    desc_bert_path_or_handler: h5py.File object or str path
    vid_feat_path_or_handler: h5py.File object or str path
    eval_proposal_bsz: the proposals for a single video will be sorted by length and batched here with
        max batch size to be eval_proposal_bsz. A single video might have multiple batches of proposals.
    load_gt_video: load GroundTruth Video, useful when evaluating single video moment retrieval.
    data_ratio: percentage of query data to use.
    """
    def __init__(self, data_path=None,
                 desc_bert_path_or_handler=None, max_desc_len=None, max_ctx_len=None,
                 sub_bert_path_or_handler=None, vid_feat_path_or_handler=None,
                 corpus_path=None, clip_length=None,
                 ctx_mode="video", data_mode="context",
                 h5driver=None, data_ratio=1.0, normalize_vfeat=True, normalize_tfeat=True):
        self.ctx_mode = ctx_mode
        self.load_gt_video = False
        self.data_ratio = data_ratio  # only affects query data
        self.normalize_vfeat = normalize_vfeat
        self.normalize_tfeat = normalize_tfeat

        self.data_mode = None
        self.set_data_mode(data_mode)

        self.max_desc_len = max_desc_len
        self.max_ctx_len = max_ctx_len
        self.data_path = data_path

        self.annotations = load_json(data_path)
        self.ground_truth = self.get_relevant_moment_gt()

        if isinstance(desc_bert_path_or_handler, h5py.File):
            self.desc_bert_h5 = desc_bert_path_or_handler
        else:
            self.desc_bert_h5 = h5py.File(desc_bert_path_or_handler, "r", driver=h5driver)

        video_data = load_json(corpus_path)
        self.video_data = [{"vid_name": k, "duration": v} for k, v in video_data.items()]
        self.video2idx = {k: v for k, v in video_data.items()}
        self.clip_length = clip_length

        self.use_video = "video" in self.ctx_mode
        self.use_sub = "sub" in self.ctx_mode
        self.use_tef = "tef" in self.ctx_mode

        if self.use_video:
            if isinstance(vid_feat_path_or_handler, h5py.File):
                self.vid_feat_h5 = vid_feat_path_or_handler
            else:  # str path
                self.vid_feat_h5 = h5py.File(vid_feat_path_or_handler, "r", driver=h5driver)

        if self.use_sub:
            if isinstance(sub_bert_path_or_handler, h5py.File):
                self.sub_bert_h5 = sub_bert_path_or_handler
            else:  # str path
                self.sub_bert_h5 = h5py.File(sub_bert_path_or_handler, "r", driver=h5driver)

    def get_relevant_moment_gt(self):
        gt_all = {}
        for data in self.annotations:
            gt_all[data["query_id"]] = data["relevant_moment"]
        return gt_all

    def set_data_mode(self, data_mode):
        """context or query"""
        assert data_mode in ["context", "query"]
        self.data_mode = data_mode

    # def load_gt_vid_name_for_query(self, load_gt_video):
    #     """load_gt_video: bool, affects the returned value of self._get_item_query"""
    #     if load_gt_video:
    #         assert "vid_name" in self.query_data[0]
    #     self.load_gt_video = load_gt_video

    def __len__(self):
        if self.data_mode == "context":
            return len(self.video_data)
        else:
            return len(self.annotations)

    def __getitem__(self, index):
        if self.data_mode == "context":
            return self._get_item_context(index)
        else:
            return self._get_item_query(index)

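    # Sketch of the intended two-pass evaluation flow (not enforced by this class): first
    # iterate in "context" mode to encode every video/subtitle in the corpus, then switch
    # modes and iterate again to encode the queries, e.g.
    #   dataset.set_data_mode("context")  # __getitem__ -> _get_item_context
    #   ...encode corpus...
    #   dataset.set_data_mode("query")    # __getitem__ -> _get_item_query
    #   ...encode queries and retrieve...
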
    def get_query_feat_by_query_id(self, query_id):
        query_feat = self.desc_bert_h5[str(query_id)][:self.max_desc_len]
        if self.normalize_tfeat:
            query_feat = l2_normalize_np_array(query_feat)
        return torch.from_numpy(query_feat)

    def _get_item_query(self, index):
        """Need to batch"""
        raw_data = self.annotations[index]

        meta = dict(
            query_id=raw_data["query_id"],
            desc=raw_data["query"],
            vid_name=raw_data["video_name"] if self.load_gt_video else None
        )

        model_inputs = dict()
        model_inputs["query_feat"] = self.get_query_feat_by_query_id(meta["query_id"])
        return dict(meta=meta, model_inputs=model_inputs)

    def get_st_ed_label(self, ts, max_idx):
        """
        Args:
            ts: [st (float), ed (float)] in seconds, ed > st
            max_idx: length of the video

        Returns:
            [st_idx, ed_idx]: int,

        Given ts = [3.2, 7.6], st_idx = 2, ed_idx = 6,
        clips should be indexed as [2: 6), the translated back ts should be [3:9].
        Given ts = [5, 9], st_idx = 3, ed_idx = 6,
        clips should be indexed as [3: 6), the translated back ts should be [4.5:9].
        # TODO which one is better, [2: 5] or [2: 6)
        """
        # TODO ed_idx -= 1, should also modify relevant code in inference.py
        st_idx = min(math.floor(ts[0] / self.clip_length), max_idx)
        ed_idx = min(math.ceil(ts[1] / self.clip_length) - 1, max_idx)  # st_idx could be the same as ed_idx
        return torch.LongTensor([st_idx, ed_idx])

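    # Note: unlike StartEndDataset.get_st_ed_label, the "- 1" above makes ed_idx inclusive,
    # so the docstring examples shift by one under this implementation. With an assumed
    # clip_length = 1.5 s, ts = [3.2, 7.6] gives st_idx = 2 and ed_idx = ceil(7.6 / 1.5) - 1 = 5,
    # i.e. the inclusive clip range [2, 5] rather than the half-open [2, 6).
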
    def _get_item_context(self, index):
        """No need to batch, since it has already been batched here"""
        raw_data = self.video_data[index]

        # initialize with basic data
        meta = dict(
            vid_name=raw_data["vid_name"],
            duration=raw_data["duration"],
        )

        model_inputs = dict()
        ctx_l = 0

        if self.use_video:
            video_feat = self.vid_feat_h5[meta["vid_name"]][:self.max_ctx_len]  # (N_clip, D)
            if self.normalize_vfeat:
                video_feat = l2_normalize_np_array(video_feat)
            model_inputs["video_feat"] = torch.from_numpy(video_feat)
            ctx_l = len(video_feat)
        else:
            model_inputs["video_feat"] = torch.zeros((2, 2))

        if self.use_sub:  # no need for ctx feature, as the features are already contextualized
            sub_feat = self.sub_bert_h5[meta["vid_name"]][:self.max_ctx_len]  # (N_clips, D_t)
            if self.normalize_tfeat:
                sub_feat = l2_normalize_np_array(sub_feat)
            model_inputs["sub_feat"] = torch.from_numpy(sub_feat)
            ctx_l = len(sub_feat)
        else:
            model_inputs["sub_feat"] = torch.zeros((2, 2))

        if self.use_tef:
            ctx_l = meta["duration"] // self.clip_length + 1 if ctx_l == 0 else ctx_l
            tef_st = torch.arange(0, ctx_l, 1.0) / ctx_l
            tef_ed = tef_st + 1.0 / ctx_l
            tef = torch.stack([tef_st, tef_ed], dim=1)  # (N_clips, 2)
            model_inputs["tef_feat"] = tef
        else:
            model_inputs["tef_feat"] = torch.zeros((2, 2))

        if self.use_video and self.use_tef:
            model_inputs["video_feat"] = torch.cat(
                [model_inputs["video_feat"], model_inputs["tef_feat"]], dim=1)  # (N_clips, D+2)
        if self.use_sub and self.use_tef:
            model_inputs["sub_feat"] = torch.cat(
                [model_inputs["sub_feat"], model_inputs["tef_feat"]], dim=1)  # (N_clips, D_t+2)
        return dict(meta=meta, model_inputs=model_inputs)


def start_end_collate(batch):
    batch_meta = [e["meta"] for e in batch]  # seems no need to collate?

    model_inputs_keys = batch[0]["model_inputs"].keys()
    batched_data = dict()
    for k in model_inputs_keys:
        if "feat" in k:
            batched_data[k] = pad_sequences_1d(
                [e["model_inputs"][k] for e in batch], dtype=torch.float32, fixed_length=None)

    if "st_ed_indices" in model_inputs_keys:
        batched_data["st_ed_indices"] = torch.stack(
            [e["model_inputs"]["st_ed_indices"] for e in batch], dim=0)
    return batch_meta, batched_data

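# What start_end_collate produces (a sketch; pad_sequences_1d is defined elsewhere in this
# repo, and its return value is inferred from how prepare_batch_inputs below unpacks it):
# every "*feat" entry is assumed to be a (padded_tensor, mask) pair, e.g. for a batch of 2
# items with 12 and 9 clips and a video+tef context of dim D + 2:
#   batched_data["video_feat"][0].shape == (2, 12, D + 2)   # padded features
#   batched_data["video_feat"][1].shape == (2, 12)          # 1/0 validity mask
# prepare_batch_inputs then re-exposes the mask under the matching "*mask" key on `device`.
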
def prepare_batch_inputs(batched_model_inputs, device, non_blocking=False):
    model_inputs = {}
    for k, v in batched_model_inputs.items():
        if "feat" in k:
            model_inputs[k] = v[0].to(device, non_blocking=non_blocking)
            model_inputs[k.replace("feat", "mask")] = v[1].to(device, non_blocking=non_blocking)
        else:
            model_inputs[k] = v.to(device, non_blocking=non_blocking)
    return model_inputs

if __name__ == '__main__':
    from baselines.crossmodal_moment_localization.config import BaseOptions
    options = BaseOptions().parse()
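For orientation, here is a minimal sketch of how the training dataset, collate function, and device transfer above are meant to be wired together. StartEndDataset's constructor appears earlier in this file (outside this excerpt); the keyword arguments and file paths below mirror StartEndEvalDataset and are placeholders for illustration, not values taken from this repository's configs.

    from torch.utils.data import DataLoader

    # hypothetical paths and settings, assumed to match the constructor defined above
    train_set = StartEndDataset(
        data_path="train.json", desc_bert_path_or_handler="desc_bert.h5",
        sub_bert_path_or_handler="sub_bert.h5", vid_feat_path_or_handler="vid_feat.h5",
        max_desc_len=30, max_ctx_len=100, clip_length=1.5, ctx_mode="video_sub_tef")
    train_loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=start_end_collate)

    for batch_meta, batched_inputs in train_loader:
        # move features to the GPU and split each (feat, mask) pair into "*feat" / "*mask" tensors
        model_inputs = prepare_batch_inputs(batched_inputs, device="cuda", non_blocking=True)
        # model_inputs now holds query/video/sub features, their masks, and st_ed_indices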