Liangrj5
commited on
Commit
·
dae63ab
1
Parent(s):
8897497
correct ndcg-iou
Browse files- .gitignore +0 -3
- README.md +44 -24
- infer.py +6 -2
- infer_top20.sh +6 -9
- modules/ReLoCLNet.py +2 -1
- modules/dataset_init.py +3 -0
- modules/dataset_tvrr.py +5 -5
- modules/infer_lib.py +23 -19
- modules/ndcg_iou.py +1 -1
- results/ReLoCLNet/top01/20240704_170921_top01.log +0 -0
- results/ReLoCLNet/top01/best_model.pt +3 -0
- results/ReLoCLNet/top20/20240704_170928_top20.log +0 -0
- results/ReLoCLNet/top20/best_model.pt +3 -0
- results/ReLoCLNet/top40/20240704_170937_top40.log +0 -0
- results/ReLoCLNet/top40/best_model.pt +3 -0
- run_top01.sh +17 -0
- run_top20.sh +10 -7
- run_top40.sh +17 -0
- train.py +19 -16
- utils/run_utils.py +37 -11
- utils/setup.py +3 -2
.gitignore
CHANGED
@@ -7,8 +7,6 @@ __pycache__/
|
|
7 |
*.so
|
8 |
|
9 |
unused
|
10 |
-
|
11 |
-
results
|
12 |
# Distribution / packaging
|
13 |
.Python
|
14 |
build/
|
@@ -59,7 +57,6 @@ coverage.xml
|
|
59 |
*.pot
|
60 |
|
61 |
# Django stuff:
|
62 |
-
*.log
|
63 |
local_settings.py
|
64 |
db.sqlite3
|
65 |
db.sqlite3-journal
|
|
|
7 |
*.so
|
8 |
|
9 |
unused
|
|
|
|
|
10 |
# Distribution / packaging
|
11 |
.Python
|
12 |
build/
|
|
|
57 |
*.pot
|
58 |
|
59 |
# Django stuff:
|
|
|
60 |
local_settings.py
|
61 |
db.sqlite3
|
62 |
db.sqlite3-journal
|
README.md
CHANGED
@@ -1,14 +1,9 @@
|
|
1 |
-
---
|
2 |
-
license: cc
|
3 |
-
language:
|
4 |
-
- en
|
5 |
-
---
|
6 |
# Video Moment Retrieval in Practical Setting: A Dataset of Ranked Moments for Imprecise Queries
|
7 |
|
8 |
The benchmark and dataset for the paper "Video Moment Retrieval in Practical Settings: A Dataset of Ranked Moments for Imprecise Queries" is coming soon.
|
9 |
|
10 |
-
We recommend cloning the code, data, and feature files from the Hugging Face repository at [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking).
|
11 |
-
|
12 |
![TVR_Ranking_overview](./figures/taskComparisonV.png)
|
13 |
|
14 |
|
@@ -57,30 +52,55 @@ tar -xf tvr_feature_release.tar.gz -C data/TVR_Ranking/feature
|
|
57 |
# modify the data path first
|
58 |
sh run_top20.sh
|
59 |
```
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
-
##
|
62 |
-
|
63 |
-
The baseline performance of $NDGC@
|
64 |
Top $N$ moments were comprised of a pseudo training set by the query-caption similarity.
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
| |
|
69 |
-
|
|
70 |
-
|
|
71 |
-
| |
|
72 |
-
|
|
73 |
-
|
|
74 |
-
| |
|
75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
|
77 |
|
78 |
-
### 4. Inferring
|
79 |
-
[ToDo] The checkpoint can all be accessed from Hugging Face [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking).
|
80 |
|
81 |
|
82 |
## Citation
|
83 |
If you feel this project helpful to your research, please cite our work.
|
84 |
```
|
85 |
|
86 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
# Video Moment Retrieval in Practical Setting: A Dataset of Ranked Moments for Imprecise Queries
|
2 |
|
3 |
The benchmark and dataset for the paper "Video Moment Retrieval in Practical Settings: A Dataset of Ranked Moments for Imprecise Queries" is coming soon.
|
4 |
|
5 |
+
We recommend cloning the code, data, and feature files from the Hugging Face repository at [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking). This repository only includes the code for ReLoCLNet. You can download the other baseline models from [XML](https://huggingface.co/LiangRenjie/XML_RVMR) and [CONQUER](https://huggingface.co/LiangRenjie/CONQUER_RVMR).
|
6 |
+
|
7 |
![TVR_Ranking_overview](./figures/taskComparisonV.png)
|
8 |
|
9 |
|
|
|
52 |
# modify the data path first
|
53 |
sh run_top20.sh
|
54 |
```
|
55 |
+
### 5. Inferring
|
56 |
+
The checkpoint can all be accessed from Hugging Face [TVR-Ranking](https://huggingface.co/axgroup/TVR-Ranking).
|
57 |
+
```shell
|
58 |
+
sh infer_top20.sh
|
59 |
+
```
|
60 |
|
61 |
+
## Experiment Results
|
62 |
+
### Baseline
|
63 |
+
The baseline performance of $NDGC@40$ was shown as follows.
|
64 |
Top $N$ moments were comprised of a pseudo training set by the query-caption similarity.
|
65 |
+
|
66 |
+
| **Model** | **Train Set Top N** | **IoU=0.3** | |**IoU=0.5** | |**IoU=0.7** | |
|
67 |
+
|----------------|---------------------|--------------|--------------|--------------|--------------|--------------|--------------|
|
68 |
+
| | | **Val** | **Test** | **Val** | **Test** | **Val** | **Test** |
|
69 |
+
| **XML** | 1 | 0.1077 | 0.1016 | 0.0775 | 0.0727 | 0.0273 | 0.0294 |
|
70 |
+
| | 20 | 0.2580 | 0.2512 | 0.1874 | 0.1853 | 0.0705 | 0.0753 |
|
71 |
+
| | 40 | 0.2408 | 0.2432 | 0.1740 | 0.1791 | 0.0666 | 0.0720 |
|
72 |
+
| **ReLoCLNet** | 1 | 0.1533 | 0.1489 | 0.1321 | 0.1304 | 0.0878 | 0.0869 |
|
73 |
+
| | 20 | 0.4039 | 0.4031 | 0.3656 | 0.3648 | 0.2542 | 0.2567 |
|
74 |
+
| | 40 | 0.4725 | 0.4735 | 0.4337 | 0.4337 | 0.3015 | 0.3079 |
|
75 |
+
|
76 |
+
|
77 |
+
### ReLoCLNet Performance
|
78 |
+
|
79 |
+
| **Model** | **Train Set Top N** | **IoU=0.3** | |**IoU=0.5** | |**IoU=0.7** | |
|
80 |
+
|------------|---------------------|--------------|--------------|--------------|--------------|--------------|--------------|
|
81 |
+
| | | **Val** | **Test** | **Val** | **Test** | **Val** | **Test** |
|
82 |
+
| **NDCG@10** | | | | | | | |
|
83 |
+
| ReLoCLNet | 1 | 0.1575 | 0.1525 | 0.1358 | 0.1349 | 0.0908 | 0.0916 |
|
84 |
+
| ReLoCLNet | 20 | 0.3751 | 0.3751 | 0.3407 | 0.3397 | 0.2316 | 0.2338 |
|
85 |
+
| ReLoCLNet | 40 | 0.4339 | 0.4353 | 0.3984 | 0.3986 | 0.2693 | 0.2807 |
|
86 |
+
| **NDCG@20** | | | | | | | |
|
87 |
+
| ReLoCLNet | 1 | 0.1504 | 0.1439 | 0.1303 | 0.1269 | 0.0866 | 0.0849 |
|
88 |
+
| ReLoCLNet | 20 | 0.3815 | 0.3792 | 0.3462 | 0.3427 | 0.2381 | 0.2386 |
|
89 |
+
| ReLoCLNet | 40 | 0.4418 | 0.4439 | 0.4060 | 0.4059 | 0.2787 | 0.2877 |
|
90 |
+
| **NDCG@40** | | | | | | | |
|
91 |
+
| ReLoCLNet | 1 | 0.1533 | 0.1489 | 0.1321 | 0.1304 | 0.0878 | 0.0869 |
|
92 |
+
| ReLoCLNet | 20 | 0.4039 | 0.4031 | 0.3656 | 0.3648 | 0.2542 | 0.2567 |
|
93 |
+
| ReLoCLNet | 40 | 0.4725 | 0.4735 | 0.4337 | 0.4337 | 0.3015 | 0.3079 |
|
94 |
+
|
95 |
+
|
96 |
+
|
97 |
+
|
98 |
|
99 |
|
|
|
|
|
100 |
|
101 |
|
102 |
## Citation
|
103 |
If you feel this project helpful to your research, please cite our work.
|
104 |
```
|
105 |
|
106 |
+
```
|
infer.py
CHANGED
@@ -5,9 +5,9 @@ from tqdm import tqdm
|
|
5 |
from modules.dataset_init import prepare_dataset
|
6 |
from modules.infer_lib import grab_corpus_feature, eval_epoch
|
7 |
|
8 |
-
from utils.basic_utils import
|
9 |
from utils.setup import set_seed, get_args
|
10 |
-
from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou
|
11 |
|
12 |
def main():
|
13 |
opt = get_args()
|
@@ -21,7 +21,11 @@ def main():
|
|
21 |
|
22 |
model = prepare_model(opt, logger)
|
23 |
# optimizer = prepare_optimizer(model, opt, len(train_loader) * opt.n_epoch)
|
|
|
|
|
|
|
24 |
|
|
|
25 |
corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)
|
26 |
val_ndcg_iou = eval_epoch(model, corpus_feature, val_loader, val_gt, opt, corpus_video_list)
|
27 |
test_ndcg_iou = eval_epoch(model, corpus_feature, test_loader, test_gt, opt, corpus_video_list)
|
|
|
5 |
from modules.dataset_init import prepare_dataset
|
6 |
from modules.infer_lib import grab_corpus_feature, eval_epoch
|
7 |
|
8 |
+
from utils.basic_utils import get_logger
|
9 |
from utils.setup import set_seed, get_args
|
10 |
+
from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou, resume_model
|
11 |
|
12 |
def main():
|
13 |
opt = get_args()
|
|
|
21 |
|
22 |
model = prepare_model(opt, logger)
|
23 |
# optimizer = prepare_optimizer(model, opt, len(train_loader) * opt.n_epoch)
|
24 |
+
# start_epoch = 0
|
25 |
+
# model, optimizer, start_epoch = resume_model(logger, opt, model, optimizer, start_epoch)
|
26 |
+
model, _, _ = resume_model(logger, opt, model)
|
27 |
|
28 |
+
model.eval()
|
29 |
corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)
|
30 |
val_ndcg_iou = eval_epoch(model, corpus_feature, val_loader, val_gt, opt, corpus_video_list)
|
31 |
test_ndcg_iou = eval_epoch(model, corpus_feature, test_loader, test_gt, opt, corpus_video_list)
|
infer_top20.sh
CHANGED
@@ -1,17 +1,14 @@
|
|
1 |
python infer.py \
|
2 |
--results_path results/tvr_ranking \
|
3 |
-
--checkpoint results/tvr_ranking/best_model.pt \
|
4 |
--train_path data/TVR_Ranking/train_top20.json \
|
5 |
--val_path data/TVR_Ranking/val.json \
|
6 |
--test_path data/TVR_Ranking/test.json \
|
7 |
--corpus_path data/TVR_Ranking/video_corpus.json \
|
8 |
-
--desc_bert_path /
|
9 |
-
--video_feat_path /
|
10 |
-
--sub_bert_path /
|
11 |
-
--
|
|
|
12 |
|
13 |
# qsub -I -l select=1:ngpus=1 -P gs_slab -q slab_gpu8
|
14 |
-
# cd /home/renjie.liang/11_TVR-Ranking/ReLoCLNet; conda activate py11; sh infer_top20.sh
|
15 |
-
# --hard_negative_start_epoch 0 \
|
16 |
-
# --no_norm_vfeat \
|
17 |
-
# --use_hard_negative
|
|
|
1 |
python infer.py \
|
2 |
--results_path results/tvr_ranking \
|
|
|
3 |
--train_path data/TVR_Ranking/train_top20.json \
|
4 |
--val_path data/TVR_Ranking/val.json \
|
5 |
--test_path data/TVR_Ranking/test.json \
|
6 |
--corpus_path data/TVR_Ranking/video_corpus.json \
|
7 |
+
--desc_bert_path data/features/query_bert.h5 \
|
8 |
+
--video_feat_path data/features/tvr_i3d_rgb600_avg_cl-1.5.h5 \
|
9 |
+
--sub_bert_path data/features/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
|
10 |
+
--checkpoint results/tvr_ranking/top20/best_model.pt \
|
11 |
+
--exp_id top20_infer
|
12 |
|
13 |
# qsub -I -l select=1:ngpus=1 -P gs_slab -q slab_gpu8
|
14 |
+
# cd /home/renjie.liang/11_TVR-Ranking/ReLoCLNet; conda activate py11; sh infer_top20.sh
|
|
|
|
|
|
modules/ReLoCLNet.py
CHANGED
@@ -201,7 +201,8 @@ class ReLoCLNet(nn.Module):
|
|
201 |
feat = input_proj_layer(feat)
|
202 |
feat = pos_embed_layer(feat)
|
203 |
mask = mask.unsqueeze(1) # (N, 1, L), torch.FloatTensor
|
204 |
-
|
|
|
205 |
|
206 |
def get_modularized_queries(self, encoded_query, query_mask, return_modular_att=False):
|
207 |
"""
|
|
|
201 |
feat = input_proj_layer(feat)
|
202 |
feat = pos_embed_layer(feat)
|
203 |
mask = mask.unsqueeze(1) # (N, 1, L), torch.FloatTensor
|
204 |
+
feat = encoder_layer(feat, mask) # (N, L, D_hidden)
|
205 |
+
return feat
|
206 |
|
207 |
def get_modularized_queries(self, encoded_query, query_mask, return_modular_att=False):
|
208 |
"""
|
modules/dataset_init.py
CHANGED
@@ -41,6 +41,9 @@ def collate_fn(batch, task):
|
|
41 |
batch_data["sub_feat"] = sub_feat_mask[0]
|
42 |
batch_data["sub_mask"] = sub_feat_mask[1]
|
43 |
|
|
|
|
|
|
|
44 |
if task == "eval":
|
45 |
query_feat_mask = pad_sequences_1d([e["query_feat"] for e in batch], dtype=torch.float32, fixed_length=None)
|
46 |
batch_data["query_feat"] = query_feat_mask[0]
|
|
|
41 |
batch_data["sub_feat"] = sub_feat_mask[0]
|
42 |
batch_data["sub_mask"] = sub_feat_mask[1]
|
43 |
|
44 |
+
# batch_data["video_name"] = [e["video_name"] for e in batch]
|
45 |
+
|
46 |
+
|
47 |
if task == "eval":
|
48 |
query_feat_mask = pad_sequences_1d([e["query_feat"] for e in batch], dtype=torch.float32, fixed_length=None)
|
49 |
batch_data["query_feat"] = query_feat_mask[0]
|
modules/dataset_tvrr.py
CHANGED
@@ -23,7 +23,7 @@ class TrainDataset(Dataset):
|
|
23 |
# prepare desc data
|
24 |
self.use_video = "video" in ctx_mode
|
25 |
self.use_sub = "sub" in ctx_mode
|
26 |
-
|
27 |
self.desc_bert_h5 = h5py.File(desc_bert_path, "r")
|
28 |
if self.use_video:
|
29 |
self.vid_feat_h5 = h5py.File(video_feat_path, "r")
|
@@ -56,6 +56,7 @@ class TrainDataset(Dataset):
|
|
56 |
query_id=raw_data["query_id"]
|
57 |
video_name=raw_data["video_name"]
|
58 |
timestamp = raw_data["timestamp"]
|
|
|
59 |
|
60 |
model_inputs = dict()
|
61 |
model_inputs["simi"] = raw_data["similarity"]
|
@@ -80,12 +81,8 @@ class TrainDataset(Dataset):
|
|
80 |
else:
|
81 |
model_inputs["sub_feat"] = torch.zeros((2, 2))
|
82 |
|
83 |
-
# print(ctx_l)
|
84 |
-
# print(timestamp)
|
85 |
model_inputs["st_ed_indices"] = self.get_st_ed_label(timestamp, max_idx=ctx_l - 1)
|
86 |
-
# print(model_inputs["st_ed_indices"])
|
87 |
return model_inputs
|
88 |
-
# return dict(meta=meta, model_inputs=model_inputs)
|
89 |
|
90 |
def get_st_ed_label(self, ts, max_idx):
|
91 |
"""
|
@@ -175,6 +172,7 @@ class CorpusEvalDataset(Dataset):
|
|
175 |
|
176 |
self.use_video = "video" in ctx_mode
|
177 |
self.use_sub = "sub" in ctx_mode
|
|
|
178 |
if self.use_video:
|
179 |
self.vid_feat_h5 = h5py.File(video_feat_path, "r")
|
180 |
if self.use_sub:
|
@@ -187,6 +185,8 @@ class CorpusEvalDataset(Dataset):
|
|
187 |
"""No need to batch, since it has already been batched here"""
|
188 |
raw_data = self.video_data[index]
|
189 |
# initialize with basic data
|
|
|
|
|
190 |
meta = dict(vid_name=raw_data["vid_name"], duration=raw_data["duration"])
|
191 |
model_inputs = dict()
|
192 |
|
|
|
23 |
# prepare desc data
|
24 |
self.use_video = "video" in ctx_mode
|
25 |
self.use_sub = "sub" in ctx_mode
|
26 |
+
|
27 |
self.desc_bert_h5 = h5py.File(desc_bert_path, "r")
|
28 |
if self.use_video:
|
29 |
self.vid_feat_h5 = h5py.File(video_feat_path, "r")
|
|
|
56 |
query_id=raw_data["query_id"]
|
57 |
video_name=raw_data["video_name"]
|
58 |
timestamp = raw_data["timestamp"]
|
59 |
+
duration = raw_data["duration"]
|
60 |
|
61 |
model_inputs = dict()
|
62 |
model_inputs["simi"] = raw_data["similarity"]
|
|
|
81 |
else:
|
82 |
model_inputs["sub_feat"] = torch.zeros((2, 2))
|
83 |
|
|
|
|
|
84 |
model_inputs["st_ed_indices"] = self.get_st_ed_label(timestamp, max_idx=ctx_l - 1)
|
|
|
85 |
return model_inputs
|
|
|
86 |
|
87 |
def get_st_ed_label(self, ts, max_idx):
|
88 |
"""
|
|
|
172 |
|
173 |
self.use_video = "video" in ctx_mode
|
174 |
self.use_sub = "sub" in ctx_mode
|
175 |
+
|
176 |
if self.use_video:
|
177 |
self.vid_feat_h5 = h5py.File(video_feat_path, "r")
|
178 |
if self.use_sub:
|
|
|
185 |
"""No need to batch, since it has already been batched here"""
|
186 |
raw_data = self.video_data[index]
|
187 |
# initialize with basic data
|
188 |
+
duration = raw_data["duration"]
|
189 |
+
video_name = raw_data["vid_name"]
|
190 |
meta = dict(vid_name=raw_data["vid_name"], duration=raw_data["duration"])
|
191 |
model_inputs = dict()
|
192 |
|
modules/infer_lib.py
CHANGED
@@ -10,16 +10,18 @@ def grab_corpus_feature(model, corpus_loader, device):
|
|
10 |
model.eval()
|
11 |
all_video_feat, all_video_mask = [], []
|
12 |
all_sub_feat, all_sub_mask = [], []
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
17 |
batch_input["sub_feat"], batch_input["sub_mask"])
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
|
24 |
all_video_feat = torch.cat(all_video_feat, dim=0)
|
25 |
all_video_mask = torch.cat(all_video_mask, dim=0)
|
@@ -41,7 +43,7 @@ def eval_epoch(model, corpus_feature, eval_loader, eval_gt, opt, corpus_video_li
|
|
41 |
all_video_mask = corpus_feature["all_video_mask"].to(device)
|
42 |
all_sub_feat = corpus_feature["all_sub_feat"].to(device)
|
43 |
all_sub_mask = corpus_feature["all_sub_mask"].to(device)
|
44 |
-
all_query_score, all_end_prob, all_start_prob = [], [], []
|
45 |
for batch_input in tqdm(eval_loader, desc="Compute Query Scores: ", total=len(eval_loader)):
|
46 |
batch_input = {k: v.to(device) for k, v in batch_input.items()}
|
47 |
query_scores, start_probs, end_probs = model.get_pred_from_raw_query(
|
@@ -56,34 +58,36 @@ def eval_epoch(model, corpus_feature, eval_loader, eval_gt, opt, corpus_video_li
|
|
56 |
start_probs = F.softmax(start_probs, dim=-1)
|
57 |
end_probs = F.softmax(end_probs, dim=-1)
|
58 |
|
59 |
-
query_scores, start_probs,
|
60 |
|
61 |
all_query_id.append(batch_input["query_id"].detach().cpu())
|
62 |
all_query_score.append(query_scores.detach().cpu())
|
63 |
all_start_prob.append(start_probs.detach().cpu())
|
64 |
all_end_prob.append(end_probs.detach().cpu())
|
65 |
-
|
|
|
66 |
all_query_id = torch.cat(all_query_id, dim=0)
|
67 |
all_query_id = all_query_id.tolist()
|
68 |
|
69 |
all_query_score = torch.cat(all_query_score, dim=0)
|
70 |
all_start_prob = torch.cat(all_start_prob, dim=0)
|
71 |
all_end_prob = torch.cat(all_end_prob, dim=0)
|
72 |
-
average_ndcg = calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob,
|
73 |
return average_ndcg
|
74 |
|
75 |
-
def calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob,
|
76 |
topn_moment = max(opt.ndcg_topk)
|
77 |
|
78 |
all_2D_map = torch.einsum("qvm,qv,qvn->qvmn", all_start_prob, all_query_score, all_end_prob)
|
79 |
map_mask = generate_min_max_length_mask(all_2D_map.shape, min_l=opt.min_pred_l, max_l=opt.max_pred_l)
|
80 |
all_2D_map = all_2D_map * map_mask
|
81 |
all_pred = {}
|
82 |
-
for
|
83 |
-
query_id = all_query_id[
|
84 |
-
score_map = all_2D_map[
|
85 |
top_score, top_idx = topk_3d(score_map, topn_moment)
|
86 |
-
|
|
|
87 |
pre_start_time = [i[1].item() * opt.clip_length for i in top_idx]
|
88 |
pre_end_time = [i[2].item() * opt.clip_length for i in top_idx]
|
89 |
|
@@ -94,7 +98,7 @@ def calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_en
|
|
94 |
"timestamp": [s, e],
|
95 |
"model_scores": score
|
96 |
})
|
97 |
-
print(pred_result)
|
98 |
all_pred[query_id] = pred_result
|
99 |
|
100 |
average_ndcg = calculate_ndcg_iou(eval_gt, all_pred, opt.iou_threshold, opt.ndcg_topk)
|
|
|
10 |
model.eval()
|
11 |
all_video_feat, all_video_mask = [], []
|
12 |
all_sub_feat, all_sub_mask = [], []
|
13 |
+
|
14 |
+
# all_video_name = []
|
15 |
+
with torch.no_grad():
|
16 |
+
for batch_input in tqdm(corpus_loader, desc="Compute Corpus Feature: ", total=len(corpus_loader)):
|
17 |
+
batch_input = {k: v.to(device) for k, v in batch_input.items()}
|
18 |
+
_video_feat, _sub_feat = model.encode_context(batch_input["video_feat"], batch_input["video_mask"],
|
19 |
batch_input["sub_feat"], batch_input["sub_mask"])
|
20 |
+
|
21 |
+
all_video_feat.append(_video_feat.detach().cpu())
|
22 |
+
all_video_mask.append(batch_input["video_mask"].detach().cpu())
|
23 |
+
all_sub_feat.append(_sub_feat.detach().cpu())
|
24 |
+
all_sub_mask.append(batch_input["sub_mask"].detach().cpu())
|
25 |
|
26 |
all_video_feat = torch.cat(all_video_feat, dim=0)
|
27 |
all_video_mask = torch.cat(all_video_mask, dim=0)
|
|
|
43 |
all_video_mask = corpus_feature["all_video_mask"].to(device)
|
44 |
all_sub_feat = corpus_feature["all_sub_feat"].to(device)
|
45 |
all_sub_mask = corpus_feature["all_sub_mask"].to(device)
|
46 |
+
all_query_score, all_end_prob, all_start_prob, all_top_video_name = [], [], [], []
|
47 |
for batch_input in tqdm(eval_loader, desc="Compute Query Scores: ", total=len(eval_loader)):
|
48 |
batch_input = {k: v.to(device) for k, v in batch_input.items()}
|
49 |
query_scores, start_probs, end_probs = model.get_pred_from_raw_query(
|
|
|
58 |
start_probs = F.softmax(start_probs, dim=-1)
|
59 |
end_probs = F.softmax(end_probs, dim=-1)
|
60 |
|
61 |
+
query_scores, start_probs, end_probs, video_name_top = extract_topk_elements(query_scores, start_probs, end_probs, corpus_video_list, topn_video)
|
62 |
|
63 |
all_query_id.append(batch_input["query_id"].detach().cpu())
|
64 |
all_query_score.append(query_scores.detach().cpu())
|
65 |
all_start_prob.append(start_probs.detach().cpu())
|
66 |
all_end_prob.append(end_probs.detach().cpu())
|
67 |
+
all_top_video_name.extend(video_name_top)
|
68 |
+
|
69 |
all_query_id = torch.cat(all_query_id, dim=0)
|
70 |
all_query_id = all_query_id.tolist()
|
71 |
|
72 |
all_query_score = torch.cat(all_query_score, dim=0)
|
73 |
all_start_prob = torch.cat(all_start_prob, dim=0)
|
74 |
all_end_prob = torch.cat(all_end_prob, dim=0)
|
75 |
+
average_ndcg = calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob, all_top_video_name, eval_gt, opt)
|
76 |
return average_ndcg
|
77 |
|
78 |
+
def calculate_average_ndcg(all_query_id, all_start_prob, all_query_score, all_end_prob, all_top_video_name, eval_gt, opt):
|
79 |
topn_moment = max(opt.ndcg_topk)
|
80 |
|
81 |
all_2D_map = torch.einsum("qvm,qv,qvn->qvmn", all_start_prob, all_query_score, all_end_prob)
|
82 |
map_mask = generate_min_max_length_mask(all_2D_map.shape, min_l=opt.min_pred_l, max_l=opt.max_pred_l)
|
83 |
all_2D_map = all_2D_map * map_mask
|
84 |
all_pred = {}
|
85 |
+
for idx in trange(len(all_2D_map), desc="Collect Predictions: "):
|
86 |
+
query_id = all_query_id[idx]
|
87 |
+
score_map = all_2D_map[idx]
|
88 |
top_score, top_idx = topk_3d(score_map, topn_moment)
|
89 |
+
top_video_name = all_top_video_name[idx]
|
90 |
+
pred_videos = [top_video_name[i[0]] for i in top_idx]
|
91 |
pre_start_time = [i[1].item() * opt.clip_length for i in top_idx]
|
92 |
pre_end_time = [i[2].item() * opt.clip_length for i in top_idx]
|
93 |
|
|
|
98 |
"timestamp": [s, e],
|
99 |
"model_scores": score
|
100 |
})
|
101 |
+
# print(pred_result)
|
102 |
all_pred[query_id] = pred_result
|
103 |
|
104 |
average_ndcg = calculate_ndcg_iou(eval_gt, all_pred, opt.iou_threshold, opt.ndcg_topk)
|
modules/ndcg_iou.py
CHANGED
@@ -25,7 +25,7 @@ def calculate_ndcg(pred_scores, true_scores):
|
|
25 |
def calculate_ndcg_iou(all_gt, all_pred, TS, KS):
|
26 |
performance = defaultdict(lambda: defaultdict(list))
|
27 |
performance_avg = defaultdict(lambda: defaultdict(float))
|
28 |
-
for k in all_pred.keys():
|
29 |
one_pred = all_pred[k]
|
30 |
one_gt = all_gt[k]
|
31 |
|
|
|
25 |
def calculate_ndcg_iou(all_gt, all_pred, TS, KS):
|
26 |
performance = defaultdict(lambda: defaultdict(list))
|
27 |
performance_avg = defaultdict(lambda: defaultdict(float))
|
28 |
+
for k in tqdm(all_pred.keys(), desc="Calculate NDCG"):
|
29 |
one_pred = all_pred[k]
|
30 |
one_gt = all_gt[k]
|
31 |
|
results/ReLoCLNet/top01/20240704_170921_top01.log
ADDED
The diff for this file is too large to render.
See raw diff
|
|
results/ReLoCLNet/top01/best_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:10d055d1d72eac6aef2937d38422ec3ff5760aba541b1f9de1b3a6127925550b
|
3 |
+
size 83802857
|
results/ReLoCLNet/top20/20240704_170928_top20.log
ADDED
The diff for this file is too large to render.
See raw diff
|
|
results/ReLoCLNet/top20/best_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c40bf96c586b463cf324a42101c92a2e7b5e5692a0caa25d234b977403dc8bd
|
3 |
+
size 83802857
|
results/ReLoCLNet/top40/20240704_170937_top40.log
ADDED
The diff for this file is too large to render.
See raw diff
|
|
results/ReLoCLNet/top40/best_model.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:05a50f5bd49514b72669755b5d08e39906f9bc2bdedb4b1b53a519ec66b5e980
|
3 |
+
size 83802857
|
run_top01.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python train.py \
|
2 |
+
--results_path results/tvr_ranking \
|
3 |
+
--train_path data/TVR_Ranking/train_top01.json \
|
4 |
+
--val_path data/TVR_Ranking/val.json \
|
5 |
+
--test_path data/TVR_Ranking/test.json \
|
6 |
+
--corpus_path data/TVR_Ranking/video_corpus.json \
|
7 |
+
--desc_bert_path data/features/query_bert.h5 \
|
8 |
+
--video_feat_path data/features/tvr_i3d_rgb600_avg_cl-1.5.h5 \
|
9 |
+
--sub_bert_path data/features/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
|
10 |
+
--n_epoch 4000 \
|
11 |
+
--eval_num_per_epoch 0.05 \
|
12 |
+
--seed 2024 \
|
13 |
+
--bsz 512 \
|
14 |
+
--exp_id top01
|
15 |
+
|
16 |
+
# qsub -I -l select=1:ngpus=1 -P gs_slab -q gpu8
|
17 |
+
# cd /home/renjie.liang/11_TVR-Ranking/ReLoCLNet; conda activate py11; sh run_top01.sh
|
run_top20.sh
CHANGED
@@ -4,11 +4,14 @@ python train.py \
|
|
4 |
--val_path data/TVR_Ranking/val.json \
|
5 |
--test_path data/TVR_Ranking/test.json \
|
6 |
--corpus_path data/TVR_Ranking/video_corpus.json \
|
7 |
-
--desc_bert_path data/
|
8 |
-
--video_feat_path data/
|
9 |
-
--sub_bert_path data/
|
10 |
-
--n_epoch
|
11 |
-
--eval_num_per_epoch
|
12 |
-
--seed
|
13 |
-
--
|
|
|
14 |
|
|
|
|
|
|
4 |
--val_path data/TVR_Ranking/val.json \
|
5 |
--test_path data/TVR_Ranking/test.json \
|
6 |
--corpus_path data/TVR_Ranking/video_corpus.json \
|
7 |
+
--desc_bert_path data/features/query_bert.h5 \
|
8 |
+
--video_feat_path data/features/tvr_i3d_rgb600_avg_cl-1.5.h5 \
|
9 |
+
--sub_bert_path data/features/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
|
10 |
+
--n_epoch 200 \
|
11 |
+
--eval_num_per_epoch 1 \
|
12 |
+
--seed 2024 \
|
13 |
+
--bsz 512 \
|
14 |
+
--exp_id top20
|
15 |
|
16 |
+
# qsub -I -l select=1:ngpus=1 -P gs_slab -q gpu8
|
17 |
+
# cd /home/renjie.liang/11_TVR-Ranking/ReLoCLNet; conda activate py11; sh run_top20.sh
|
run_top40.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
python train.py \
|
2 |
+
--results_path results/tvr_ranking \
|
3 |
+
--train_path data/TVR_Ranking/train_top40.json \
|
4 |
+
--val_path data/TVR_Ranking/val.json \
|
5 |
+
--test_path data/TVR_Ranking/test.json \
|
6 |
+
--corpus_path data/TVR_Ranking/video_corpus.json \
|
7 |
+
--desc_bert_path data/features/query_bert.h5 \
|
8 |
+
--video_feat_path data/features/tvr_i3d_rgb600_avg_cl-1.5.h5 \
|
9 |
+
--sub_bert_path data/features/tvr_sub_pretrained_w_sub_query_max_cl-1.5.h5 \
|
10 |
+
--n_epoch 100 \
|
11 |
+
--eval_num_per_epoch 2 \
|
12 |
+
--seed 2024 \
|
13 |
+
--bsz 512 \
|
14 |
+
--exp_id top40
|
15 |
+
|
16 |
+
# qsub -I -l select=1:ngpus=1 -P gs_slab -q gpu8
|
17 |
+
# cd /home/renjie.liang/11_TVR-Ranking/ReLoCLNet; conda activate py11; sh run_top40.sh
|
train.py
CHANGED
@@ -7,7 +7,7 @@ from modules.infer_lib import grab_corpus_feature, eval_epoch
|
|
7 |
|
8 |
from utils.basic_utils import AverageMeter, get_logger
|
9 |
from utils.setup import set_seed, get_args
|
10 |
-
from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou
|
11 |
|
12 |
def main():
|
13 |
opt = get_args()
|
@@ -20,32 +20,39 @@ def main():
|
|
20 |
|
21 |
|
22 |
train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt = prepare_dataset(opt)
|
23 |
-
|
24 |
model = prepare_model(opt, logger)
|
25 |
optimizer = prepare_optimizer(model, opt, len(train_loader) * opt.n_epoch)
|
26 |
|
|
|
|
|
|
|
|
|
27 |
eval_step = len(train_loader) // opt.eval_num_per_epoch
|
28 |
best_val_ndcg = 0
|
29 |
-
for
|
30 |
-
logger.info(f"TRAIN EPOCH: {
|
31 |
model.train()
|
32 |
-
if opt.hard_negative_start_epoch != -1 and
|
33 |
model.set_hard_negative(True, opt.hard_pool_size)
|
34 |
-
|
35 |
model.train()
|
|
|
36 |
for step, batch_input in tqdm(enumerate(train_loader), desc="Training", total=len(train_loader)):
|
37 |
-
step
|
38 |
batch_input = {k: v.to(opt.device) for k, v in batch_input.items()}
|
39 |
loss = model(**batch_input)
|
40 |
optimizer.zero_grad()
|
41 |
loss.backward()
|
42 |
# nn.utils.clip_grad_norm_(model.parameters())
|
43 |
optimizer.step()
|
44 |
-
|
45 |
if step % opt.log_step == 0:
|
46 |
-
logger.info(f"EPOCH {
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
49 |
corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)
|
50 |
val_ndcg_iou = eval_epoch(model, corpus_feature, val_loader, val_gt, opt, corpus_video_list)
|
51 |
test_ndcg_iou = eval_epoch(model, corpus_feature, test_loader, test_gt, opt, corpus_video_list)
|
@@ -58,12 +65,8 @@ def main():
|
|
58 |
logger_ndcg_iou(val_ndcg_iou, logger, "BEST VAL")
|
59 |
logger_ndcg_iou(test_ndcg_iou, logger, "BEST TEST")
|
60 |
|
61 |
-
checkpoint = {"model": model.state_dict(), "model_cfg": model.config, "epoch": epoch_i}
|
62 |
-
|
63 |
bestmodel_path = os.path.join(opt.results_path, "best_model.pt")
|
64 |
-
|
65 |
-
logger.info(f"Save checkpoint at {bestmodel_path}")
|
66 |
-
logger.info("")
|
67 |
|
68 |
if __name__ == '__main__':
|
69 |
main()
|
|
|
7 |
|
8 |
from utils.basic_utils import AverageMeter, get_logger
|
9 |
from utils.setup import set_seed, get_args
|
10 |
+
from utils.run_utils import prepare_optimizer, prepare_model, logger_ndcg_iou, save_model, resume_model
|
11 |
|
12 |
def main():
|
13 |
opt = get_args()
|
|
|
20 |
|
21 |
|
22 |
train_loader, corpus_loader, corpus_video_list, val_loader, test_loader, val_gt, test_gt = prepare_dataset(opt)
|
23 |
+
|
24 |
model = prepare_model(opt, logger)
|
25 |
optimizer = prepare_optimizer(model, opt, len(train_loader) * opt.n_epoch)
|
26 |
|
27 |
+
start_epoch = 0
|
28 |
+
if opt.checkpoint is not None:
|
29 |
+
model, optimizer, start_epoch = resume_model(logger, opt, model, optimizer, start_epoch)
|
30 |
+
|
31 |
eval_step = len(train_loader) // opt.eval_num_per_epoch
|
32 |
best_val_ndcg = 0
|
33 |
+
for epoch in range(start_epoch, opt.n_epoch):
|
34 |
+
logger.info(f"TRAIN EPOCH: {epoch}|{opt.n_epoch}")
|
35 |
model.train()
|
36 |
+
if opt.hard_negative_start_epoch != -1 and epoch >= opt.hard_negative_start_epoch:
|
37 |
model.set_hard_negative(True, opt.hard_pool_size)
|
|
|
38 |
model.train()
|
39 |
+
|
40 |
for step, batch_input in tqdm(enumerate(train_loader), desc="Training", total=len(train_loader)):
|
41 |
+
global_step = epoch * len(train_loader) + step + 1
|
42 |
batch_input = {k: v.to(opt.device) for k, v in batch_input.items()}
|
43 |
loss = model(**batch_input)
|
44 |
optimizer.zero_grad()
|
45 |
loss.backward()
|
46 |
# nn.utils.clip_grad_norm_(model.parameters())
|
47 |
optimizer.step()
|
48 |
+
|
49 |
if step % opt.log_step == 0:
|
50 |
+
logger.info(f"EPOCH {epoch}/{opt.n_epoch} | STEP: {step}|{len(train_loader)} | Loss: {loss.item():.6f}")
|
51 |
+
for i in range(torch.cuda.device_count()):
|
52 |
+
print(f"Memory Allocated on GPU {i}: {torch.cuda.memory_allocated(i) / 1024**3:.2f} GB")
|
53 |
+
print(f"Memory Cached on GPU {i}: {torch.cuda.memory_reserved(i) / 1024**3:.2f} GB")
|
54 |
+
print("-------------------------")
|
55 |
+
if global_step % eval_step == 0 or step == len(train_loader):
|
56 |
corpus_feature = grab_corpus_feature(model, corpus_loader, opt.device)
|
57 |
val_ndcg_iou = eval_epoch(model, corpus_feature, val_loader, val_gt, opt, corpus_video_list)
|
58 |
test_ndcg_iou = eval_epoch(model, corpus_feature, test_loader, test_gt, opt, corpus_video_list)
|
|
|
65 |
logger_ndcg_iou(val_ndcg_iou, logger, "BEST VAL")
|
66 |
logger_ndcg_iou(test_ndcg_iou, logger, "BEST TEST")
|
67 |
|
|
|
|
|
68 |
bestmodel_path = os.path.join(opt.results_path, "best_model.pt")
|
69 |
+
save_model(model, optimizer, epoch, bestmodel_path, logger)
|
|
|
|
|
70 |
|
71 |
if __name__ == '__main__':
|
72 |
main()
|
utils/run_utils.py
CHANGED
@@ -2,6 +2,7 @@ import torch
|
|
2 |
from modules.ReLoCLNet import ReLoCLNet
|
3 |
from modules.optimization import BertAdam
|
4 |
import numpy as np
|
|
|
5 |
|
6 |
def count_parameters(model, verbose=True):
|
7 |
"""Count number of parameters in PyTorch model,
|
@@ -21,18 +22,27 @@ def count_parameters(model, verbose=True):
|
|
21 |
def prepare_model(opt, logger):
|
22 |
model = ReLoCLNet(opt)
|
23 |
count_parameters(model)
|
24 |
-
|
25 |
-
if opt.checkpoint is not None:
|
26 |
-
checkpoint = torch.load(opt.checkpoint, map_location=opt.device)
|
27 |
-
model.load_state_dict(checkpoint['model'])
|
28 |
-
logger.info(f"Loading checkpoint from {opt.checkpoint}")
|
29 |
-
|
30 |
-
# Prepare optimizer (unchanged)
|
31 |
if opt.device.type == "cuda":
|
32 |
logger.info("CUDA enabled.")
|
33 |
model.to(opt.device)
|
34 |
return model
|
35 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def prepare_optimizer(model, opt, total_train_steps):
|
37 |
|
38 |
param_optimizer = list(model.named_parameters())
|
@@ -43,10 +53,20 @@ def prepare_optimizer(model, opt, total_train_steps):
|
|
43 |
|
44 |
optimizer = BertAdam(optimizer_grouped_parameters, lr=opt.lr, weight_decay=opt.wd, warmup=opt.lr_warmup_proportion,
|
45 |
t_total=total_train_steps, schedule="warmup_linear")
|
46 |
-
|
47 |
return optimizer
|
48 |
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
def topk_3d(tensor, k):
|
51 |
"""
|
52 |
Find the top k values and their corresponding indices in a 3D tensor.
|
@@ -94,7 +114,7 @@ def generate_min_max_length_mask(array_shape, min_l, max_l):
|
|
94 |
return final_prob_mask # with valid bit to be 1
|
95 |
|
96 |
|
97 |
-
def extract_topk_elements(query_scores, start_probs, end_probs, k):
|
98 |
|
99 |
# Step 1: Find the top k values and their indices in query_scores
|
100 |
topk_values, topk_indices = torch.topk(query_scores, k)
|
@@ -102,8 +122,14 @@ def extract_topk_elements(query_scores, start_probs, end_probs, k):
|
|
102 |
# Step 2: Use these indices to select the corresponding elements from start_probs and end_probs
|
103 |
selected_start_probs = torch.stack([start_probs[i, indices] for i, indices in enumerate(topk_indices)], dim=0)
|
104 |
selected_end_probs = torch.stack([end_probs[i, indices] for i, indices in enumerate(topk_indices)], dim=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
-
return topk_values, selected_start_probs, selected_end_probs
|
107 |
|
108 |
def logger_ndcg_iou(val_ndcg_iou, logger, suffix):
|
109 |
for K, vs in val_ndcg_iou.items():
|
|
|
2 |
from modules.ReLoCLNet import ReLoCLNet
|
3 |
from modules.optimization import BertAdam
|
4 |
import numpy as np
|
5 |
+
import copy
|
6 |
|
7 |
def count_parameters(model, verbose=True):
|
8 |
"""Count number of parameters in PyTorch model,
|
|
|
22 |
def prepare_model(opt, logger):
|
23 |
model = ReLoCLNet(opt)
|
24 |
count_parameters(model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
if opt.device.type == "cuda":
|
26 |
logger.info("CUDA enabled.")
|
27 |
model.to(opt.device)
|
28 |
return model
|
29 |
|
30 |
+
def resume_model(logger, opt, model=None, optimizer=None, start_epoch=None):
|
31 |
+
checkpoint = torch.load(opt.checkpoint, map_location=opt.device)
|
32 |
+
if model is not None:
|
33 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
34 |
+
logger.info(f"Loading model from {opt.checkpoint} at epoch {checkpoint['epoch']}")
|
35 |
+
|
36 |
+
if optimizer is not None:
|
37 |
+
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
38 |
+
logger.info(f"Loading optimizer from {opt.checkpoint} at epoch {checkpoint['epoch']}")
|
39 |
+
|
40 |
+
if start_epoch is not None:
|
41 |
+
start_epoch = checkpoint['epoch']
|
42 |
+
logger.info(f"Loading start_epoch from {opt.checkpoint} at epoch {checkpoint['epoch']}")
|
43 |
+
|
44 |
+
return model, optimizer, start_epoch,
|
45 |
+
|
46 |
def prepare_optimizer(model, opt, total_train_steps):
|
47 |
|
48 |
param_optimizer = list(model.named_parameters())
|
|
|
53 |
|
54 |
optimizer = BertAdam(optimizer_grouped_parameters, lr=opt.lr, weight_decay=opt.wd, warmup=opt.lr_warmup_proportion,
|
55 |
t_total=total_train_steps, schedule="warmup_linear")
|
|
|
56 |
return optimizer
|
57 |
|
58 |
+
def save_model(model, optimizer, epoch, path, logger):
|
59 |
+
data = {
|
60 |
+
'epoch': epoch,
|
61 |
+
'model_cfg': model.config,
|
62 |
+
'model_state_dict': model.state_dict(),
|
63 |
+
'optimizer_state_dict': optimizer.state_dict(),
|
64 |
+
}
|
65 |
+
torch.save(data, path)
|
66 |
+
logger.info(f"Save checkpoint at {path}")
|
67 |
+
logger.info("")
|
68 |
+
|
69 |
+
|
70 |
def topk_3d(tensor, k):
|
71 |
"""
|
72 |
Find the top k values and their corresponding indices in a 3D tensor.
|
|
|
114 |
return final_prob_mask # with valid bit to be 1
|
115 |
|
116 |
|
117 |
+
def extract_topk_elements(query_scores, start_probs, end_probs, video_names, k):
|
118 |
|
119 |
# Step 1: Find the top k values and their indices in query_scores
|
120 |
topk_values, topk_indices = torch.topk(query_scores, k)
|
|
|
122 |
# Step 2: Use these indices to select the corresponding elements from start_probs and end_probs
|
123 |
selected_start_probs = torch.stack([start_probs[i, indices] for i, indices in enumerate(topk_indices)], dim=0)
|
124 |
selected_end_probs = torch.stack([end_probs[i, indices] for i, indices in enumerate(topk_indices)], dim=0)
|
125 |
+
|
126 |
+
selected_video_name = []
|
127 |
+
for i in range(topk_indices.shape[0]):
|
128 |
+
vn = copy.deepcopy(video_names)
|
129 |
+
tmp = [vn[idx] for idx in topk_indices[i]]
|
130 |
+
selected_video_name.append(tmp)
|
131 |
|
132 |
+
return topk_values, selected_start_probs, selected_end_probs, selected_video_name
|
133 |
|
134 |
def logger_ndcg_iou(val_ndcg_iou, logger, suffix):
|
135 |
for K, vs in val_ndcg_iou.items():
|
utils/setup.py
CHANGED
@@ -84,12 +84,13 @@ def get_args():
|
|
84 |
parser.add_argument("--ndcg_topk", type=int, nargs='+', default=[10, 20, 40], help="List of NDCG top k values")
|
85 |
args = parser.parse_args()
|
86 |
|
87 |
-
|
88 |
os.makedirs(args.results_path, exist_ok=True)
|
|
|
89 |
if args.hard_negative_start_epoch != -1:
|
90 |
if args.hard_pool_size > args.bsz:
|
91 |
print("[WARNING] hard_pool_size is larger than bsz")
|
92 |
-
|
93 |
return args
|
94 |
|
95 |
|
|
|
84 |
parser.add_argument("--ndcg_topk", type=int, nargs='+', default=[10, 20, 40], help="List of NDCG top k values")
|
85 |
args = parser.parse_args()
|
86 |
|
87 |
+
args.results_path = os.path.join(args.results_path, args.exp_id)
|
88 |
os.makedirs(args.results_path, exist_ok=True)
|
89 |
+
|
90 |
if args.hard_negative_start_epoch != -1:
|
91 |
if args.hard_pool_size > args.bsz:
|
92 |
print("[WARNING] hard_pool_size is larger than bsz")
|
93 |
+
|
94 |
return args
|
95 |
|
96 |
|