import pprint from tqdm import tqdm, trange import numpy as np import os from collections import OrderedDict, defaultdict from utils.basic_utils import AverageMeter import torch import torch.nn.functional as F import torch.backends.cudnn as cudnn from torch.utils.data import DataLoader from cg_detr.config import TestOptions from cg_detr.model import build_model from cg_detr.span_utils import span_cxw_to_xx from cg_detr.start_end_dataset import StartEndDataset, start_end_collate, prepare_batch_inputs from cg_detr.postprocessing_cg_detr import PostProcessorDETR from standalone_eval.eval import eval_submission from utils.basic_utils import save_jsonl, save_json from utils.temporal_nms import temporal_nms import logging logger = logging.getLogger(__name__) logging.basicConfig(format="%(asctime)s.%(msecs)03d:%(levelname)s:%(name)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO) def post_processing_mr_nms(mr_res, nms_thd, max_before_nms, max_after_nms): mr_res_after_nms = [] for e in mr_res: e["pred_relevant_windows"] = temporal_nms( e["pred_relevant_windows"][:max_before_nms], nms_thd=nms_thd, max_after_nms=max_after_nms ) mr_res_after_nms.append(e) return mr_res_after_nms def eval_epoch_post_processing(submission, opt, gt_data, save_submission_filename): # IOU_THDS = (0.5, 0.7) logger.info("Saving/Evaluating before nms results") submission_path = os.path.join(opt.results_dir, save_submission_filename) save_jsonl(submission, submission_path) if opt.eval_split_name in ["val"]: # since test_public has no GT metrics = eval_submission( submission, gt_data, verbose=opt.debug, match_number=not opt.debug ) save_metrics_path = submission_path.replace(".jsonl", "_metrics.json") save_json(metrics, save_metrics_path, save_pretty=True, sort_keys=False) latest_file_paths = [submission_path, save_metrics_path] else: metrics = None latest_file_paths = [submission_path, ] if opt.nms_thd != -1: logger.info("[MR] Performing nms with nms_thd {}".format(opt.nms_thd)) submission_after_nms = post_processing_mr_nms( submission, nms_thd=opt.nms_thd, max_before_nms=opt.max_before_nms, max_after_nms=opt.max_after_nms ) logger.info("Saving/Evaluating nms results") submission_nms_path = submission_path.replace(".jsonl", "_nms_thd_{}.jsonl".format(opt.nms_thd)) save_jsonl(submission_after_nms, submission_nms_path) if opt.eval_split_name == "val": metrics_nms = eval_submission( submission_after_nms, gt_data, verbose=opt.debug, match_number=not opt.debug ) save_metrics_nms_path = submission_nms_path.replace(".jsonl", "_metrics.json") save_json(metrics_nms, save_metrics_nms_path, save_pretty=True, sort_keys=False) latest_file_paths += [submission_nms_path, save_metrics_nms_path] else: metrics_nms = None latest_file_paths = [submission_nms_path, ] else: metrics_nms = None return metrics, metrics_nms, latest_file_paths # for HL @torch.no_grad() def compute_hl_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None): model.eval() if criterion: assert eval_loader.dataset.load_labels criterion.eval() loss_meters = defaultdict(AverageMeter) write_tb = tb_writer is not None and epoch_i is not None mr_res = [] topk = 5 # top-5 map video_ap_collected = [] for batch in tqdm(eval_loader, desc="compute st ed scores"): query_meta = batch[0] model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory) outputs = model(**model_inputs) # loss meters # if criterion: # loss_dict = criterion(outputs, targets) # weight_dict = criterion.weight_dict # print(loss_dict) # print(weight_dict) # print('#######') # {'loss_saliency': tensor(18.1374, device='cuda:0')} # {'loss_span': 10, 'loss_giou': 1, 'loss_label': 4, 'loss_saliency': 1.0, 'loss_ms_align': 1.0, # 'loss_distill': 1.0, 'loss_span_0': 10, 'loss_giou_0': 1, 'loss_label_0': 4, 'loss_ms_align_0': 1.0, # 'loss_distill_0': 1.0} # losses=0. # print(loss_dict.keys(), weight_dict.keys()) # losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) # loss_dict["loss_overall"] = float(losses) # for logging only # print(loss_dict.items()) # # print(weight_dict.items()) # for k, v in loss_dict.items(): # loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) preds = outputs['saliency_scores'].clone().detach() for meta, pred in zip(query_meta, preds): pred = pred label = meta['label'] # raw label video_ap = [] # Follow the UMT code "https://github.com/TencentARC/UMT/blob/main/datasets/tvsum.py" if opt.dset_name in ["tvsum"]: for i in range(20): pred=pred.cpu() cur_pred = pred[:len(label)] inds = torch.argsort(cur_pred, descending=True, dim=-1) # video_id = self.get_video_id(idx) cur_label = torch.Tensor(label)[:, i] cur_label = torch.where(cur_label > cur_label.median(), 1.0, .0) cur_label = cur_label[inds].tolist()[:topk] # if (num_gt := sum(cur_label)) == 0: num_gt = sum(cur_label) if num_gt == 0: video_ap.append(0) continue hits = ap = rec = 0 prc = 1 for j, gt in enumerate(cur_label): hits += gt _rec = hits / num_gt _prc = hits / (j + 1) ap += (_rec - rec) * (prc + _prc) / 2 rec, prc = _rec, _prc video_ap.append(ap) elif opt.dset_name in ["youtube_uni"]: cur_pred = pred[:len(label)] # if opt.dset_name == "tvsum_sfc": cur_pred = cur_pred.cpu() inds = torch.argsort(cur_pred, descending=True, dim=-1) cur_label = torch.Tensor(label).squeeze()[inds].tolist() num_gt = sum(cur_label) if num_gt == 0: video_ap.append(0) continue hits = ap = rec = 0 prc = 1 for j, gt in enumerate(cur_label): hits += gt _rec = hits / num_gt _prc = hits / (j + 1) ap += (_rec - rec) * (prc + _prc) / 2 rec, prc = _rec, _prc video_ap.append(float(ap)) else: print("No such dataset") exit(-1) video_ap_collected.append(video_ap) mean_ap = np.mean(video_ap_collected) submmission = dict(mAP=round(mean_ap, 5)) # tensorboard writer if write_tb and criterion: for k, v in loss_meters.items(): tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1) return submmission, loss_meters @torch.no_grad() def compute_mr_results(model, eval_loader, opt, epoch_i=None, criterion=None, tb_writer=None): model.eval() if criterion: assert eval_loader.dataset.load_labels criterion.eval() loss_meters = defaultdict(AverageMeter) write_tb = tb_writer is not None and epoch_i is not None mr_res = [] for batch in tqdm(eval_loader, desc="compute st ed scores"): query_meta = batch[0] model_inputs, targets = prepare_batch_inputs(batch[1], opt.device, non_blocking=opt.pin_memory) outputs = model(**model_inputs) prob = F.softmax(outputs["pred_logits"], -1) # (batch_size, #queries, #classes=2) if opt.span_loss_type == "l1": scores = prob[..., 0] # * (batch_size, #queries) foreground label is 0, we directly take it pred_spans = outputs["pred_spans"] # (bsz, #queries, 2) _saliency_scores = outputs["saliency_scores"].half() # (bsz, L) saliency_scores = [] valid_vid_lengths = model_inputs["src_vid_mask"].sum(1).cpu().tolist() for j in range(len(valid_vid_lengths)): saliency_scores.append(_saliency_scores[j, :int(valid_vid_lengths[j])].tolist()) else: bsz, n_queries = outputs["pred_spans"].shape[:2] # # (bsz, #queries, max_v_l *2) pred_spans_logits = outputs["pred_spans"].view(bsz, n_queries, 2, opt.max_v_l) pred_span_scores, pred_spans = F.softmax(pred_spans_logits, dim=-1).max(-1) # 2 * (bsz, #queries, 2) scores = torch.prod(pred_span_scores, 2) # (bsz, #queries) pred_spans[:, 1] += 1 pred_spans *= opt.clip_length # compose predictions for idx, (meta, spans, score) in enumerate(zip(query_meta, pred_spans.cpu(), scores.cpu())): if opt.span_loss_type == "l1": spans = span_cxw_to_xx(spans) * meta["duration"] spans = torch.clamp(spans, 0, meta["duration"]) # # (#queries, 3), [st(float), ed(float), score(float)] cur_ranked_preds = torch.cat([spans, score[:, None]], dim=1).tolist() if not opt.no_sort_results: cur_ranked_preds = sorted(cur_ranked_preds, key=lambda x: x[2], reverse=True) cur_ranked_preds = [[float(f"{e:.4f}") for e in row] for row in cur_ranked_preds] cur_query_pred = dict( qid=meta["qid"], query=meta["query"], vid=meta["vid"], pred_relevant_windows=cur_ranked_preds, pred_saliency_scores=saliency_scores[idx] ) mr_res.append(cur_query_pred) if criterion: loss_dict = criterion(outputs, targets) weight_dict = criterion.weight_dict losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) loss_dict["loss_overall"] = float(losses) # for logging only for k, v in loss_dict.items(): loss_meters[k].update(float(v) * weight_dict[k] if k in weight_dict else float(v)) if opt.debug: break if write_tb and criterion: for k, v in loss_meters.items(): tb_writer.add_scalar("Eval/{}".format(k), v.avg, epoch_i + 1) if opt.dset_name in ['hl']: post_processor = PostProcessorDETR( clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150, min_w_l=2, max_w_l=150, move_window_method="left", process_func_names=("clip_ts", "round_multiple") ) elif opt.dset_name in ['charadesSTA']: if opt.v_feat_dim == 4096: # vgg post_processor = PostProcessorDETR( clip_length=opt.clip_length, min_ts_val=0, max_ts_val=360, min_w_l=12, max_w_l=360, move_window_method="left", process_func_names=("clip_ts", "round_multiple") ) else: post_processor = PostProcessorDETR( clip_length=opt.clip_length, min_ts_val=0, max_ts_val=150, min_w_l=2, max_w_l=60, move_window_method="left", process_func_names=("clip_ts", "round_multiple") ) else: post_processor = PostProcessorDETR( clip_length=opt.clip_length, min_ts_val=0, max_ts_val=50000, min_w_l=0, max_w_l=50000, move_window_method="left", process_func_names=(["round_multiple"]) ) mr_res = post_processor(mr_res) return mr_res, loss_meters def get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer): """compute and save query and video proposal embeddings""" eval_res, eval_loss_meters = compute_mr_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) # list(dict) return eval_res, eval_loss_meters def eval_epoch(model, eval_dataset, opt, save_submission_filename, epoch_i=None, criterion=None, tb_writer=None): logger.info("Generate submissions") model.eval() if criterion is not None and eval_dataset.load_labels: criterion.eval() else: criterion = None if opt.dset_name == 'tacos': shuffle = True else: shuffle = False eval_loader = DataLoader( eval_dataset, collate_fn=start_end_collate, batch_size=opt.eval_bsz, num_workers=opt.num_workers, shuffle=shuffle, pin_memory=opt.pin_memory ) # tvsum if opt.dset_name in ['tvsum', 'youtube_uni']: metrics, eval_loss_meters = compute_hl_results(model, eval_loader, opt, epoch_i, criterion, tb_writer) # to match original save format submission = [ {"brief": metrics} ] submission_path = os.path.join(opt.results_dir, "latest_metric.jsonl") save_jsonl(submission, submission_path) return submission[0], submission[0], eval_loss_meters, [submission_path] else: submission, eval_loss_meters = get_eval_res(model, eval_loader, opt, epoch_i, criterion, tb_writer) if opt.dset_name in ['charadesSTA', 'tacos', 'nlq']: new_submission = [] for s in submission: s.pop('pred_saliency_scores', None) new_submission.append(s) submission = new_submission if opt.no_sort_results: save_submission_filename = save_submission_filename.replace(".jsonl", "_unsorted.jsonl") metrics, metrics_nms, latest_file_paths = eval_epoch_post_processing( submission, opt, eval_dataset.data, save_submission_filename) return metrics, metrics_nms, eval_loss_meters, latest_file_paths def setup_model(opt): """setup model/optimizer/scheduler and load checkpoints when needed""" logger.info("setup model/optimizer/scheduler") model, criterion = build_model(opt) if opt.device.type == "cuda": logger.info("CUDA enabled.") model.to(opt.device) criterion.to(opt.device) param_dicts = [{"params": [p for n, p in model.named_parameters() if p.requires_grad]}] optimizer = torch.optim.AdamW(param_dicts, lr=opt.lr, weight_decay=opt.wd) lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, opt.lr_drop) if opt.resume is not None: logger.info(f"Load checkpoint from {opt.resume}") checkpoint = torch.load(opt.resume, map_location="cpu") from collections import OrderedDict new_state_dict = OrderedDict() if 'pt' in opt.resume[:-4]: if 'asr' in opt.resume[:25]: model.load_state_dict(checkpoint["model"]) else: for k, v in checkpoint["model"].items(): name = k[7:] # remove `module.` new_state_dict[name] = v # model.load_state_dict(checkpoint["model"]) model.load_state_dict(new_state_dict) else: model.load_state_dict(checkpoint["model"]) if opt.resume_all: optimizer.load_state_dict(checkpoint['optimizer']) lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) opt.start_epoch = checkpoint['epoch'] + 1 logger.info(f"Loaded model saved at epoch {checkpoint['epoch']} from checkpoint: {opt.resume}") else: logger.warning("If you intend to evaluate the model, please specify --resume with ckpt path") return model, criterion, optimizer, lr_scheduler def start_inference(train_opt=None, split=None, splitfile=None): if train_opt is not None: opt = TestOptions().parse(train_opt.a_feat_dir) else: opt = TestOptions().parse() if split is not None: opt.eval_split_name = split if splitfile is not None: opt.eval_path = splitfile print(opt.eval_split_name) print(opt.eval_path) logger.info("Setup config, data and model...") cudnn.benchmark = True cudnn.deterministic = False assert opt.eval_path is not None if opt.eval_split_name == 'val': loadlabel = True else: loadlabel = False eval_dataset = StartEndDataset( dset_name=opt.dset_name, data_path=opt.eval_path, v_feat_dirs=opt.v_feat_dirs, q_feat_dir=opt.t_feat_dir, q_feat_type="last_hidden_state", max_q_l=opt.max_q_l, max_v_l=opt.max_v_l, ctx_mode=opt.ctx_mode, data_ratio=opt.data_ratio, normalize_v=not opt.no_norm_vfeat, normalize_t=not opt.no_norm_tfeat, clip_len=opt.clip_length, max_windows=opt.max_windows, load_labels=loadlabel, # opt.eval_split_name == "val", span_loss_type=opt.span_loss_type, txt_drop_ratio=0, dset_domain=opt.dset_domain, ) model, criterion, _, _ = setup_model(opt) save_submission_filename = "hl_{}_submission.jsonl".format( opt.eval_split_name) # save_submission_filename = "inference_{}_{}_{}_preds.jsonl".format( # opt.dset_name, opt.eval_split_name, opt.eval_id) logger.info("Starting inference...") with torch.no_grad(): metrics_no_nms, metrics_nms, eval_loss_meters, latest_file_paths = \ eval_epoch(model, eval_dataset, opt, save_submission_filename, criterion=criterion) if opt.eval_split_name == 'val': logger.info("metrics_no_nms {}".format(pprint.pformat(metrics_no_nms["brief"], indent=4))) if metrics_nms is not None: logger.info("metrics_nms {}".format(pprint.pformat(metrics_nms["brief"], indent=4))) from sys import argv if __name__ == '__main__': _,_,_,_,split,_,splitfile = argv start_inference(split=split, splitfile=splitfile)