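# Web file-viewer residue from where this file was copied (presumably: user "ynhe",
# commit 16dc4f2 "init", "raw" / "history" / "blame" links, file size 63.9 kB).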
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
CG-DETR model and criterion classes.
"""
import torch
import torch.nn.functional as F
from torch import nn
from third_party.cgdetr.cg_detr.span_utils import generalized_temporal_iou, span_cxw_to_xx
from third_party.cgdetr.cg_detr.matcher import build_matcher
from third_party.cgdetr.cg_detr.transformer import build_transformer, TransformerEncoderLayer, TransformerEncoder
from third_party.cgdetr.cg_detr.position_encoding import build_position_encoding
from third_party.cgdetr.cg_detr.misc import accuracy
import numpy as np
import copy
def inverse_sigmoid(x, eps=1e-3):
    """Numerically guarded logit, the inverse of ``torch.sigmoid``.

    The input is first clamped to [0, 1]. Numerator and denominator are then floored at
    ``eps``, so that 0 and 1 map to finite values instead of -inf / +inf.

    Args:
        x: tensor of probabilities.
        eps: lower bound applied to both ``x`` and ``1 - x`` before taking the ratio.

    Returns:
        ``log(x / (1 - x))`` computed elementwise on the clamped input.
    """
    probs = torch.clamp(x, 0, 1)
    numerator = torch.clamp(probs, min=eps)
    denominator = torch.clamp(1 - probs, min=eps)
    return torch.log(numerator / denominator)
def init_weights(module):
    """Initialise a single module in place; meant to be passed to ``nn.Module.apply``.

    - ``nn.Linear`` / ``nn.Embedding``: weight drawn from N(0, 0.02**2).
      A Linear's bias, if present, is zeroed.
    - ``nn.LayerNorm``: weight set to 1 and bias set to 0.
    - Any other module is left unchanged.
    """
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
        has_bias = isinstance(module, nn.Linear) and module.bias is not None
        if has_bias:
            module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
def find_nth(vid, underline, n):
    """Return the index of the ``n``-th occurrence of ``underline`` in ``vid``.

    Occurrences are searched without overlap. For ``n <= 1`` the first occurrence
    is returned. If there are fewer than ``n`` occurrences (or none), ``len(vid)``
    is returned, so the result is always usable as a slice end.
    """
    position = vid.find(underline)
    remaining = n
    while position != -1 and remaining > 1:
        position = vid.find(underline, position + len(underline))
        remaining -= 1
    return len(vid) if position == -1 else position
def element_wise_list_equal(listA, listB):
    """Compare two sequences position by position.

    Returns a list of plain ``bool`` values, one per pair produced by ``zip``.
    The longer input is therefore truncated to the length of the shorter one.
    """
    return [bool(a == b) for a, b in zip(listA, listB)]
class CGDETR(nn.Module):
    """CG-DETR model.

    Pipeline:
      - Project video (optionally concatenated with audio) and text features to ``hidden_dim``.
      - Encode them jointly, together with learned dummy tokens, through ``transformer``.
      - Predict, per query, 2-way logits (``class_embed``) and a span (``span_embed``).
      - Predict a saliency score for every video clip.
    """
    def __init__(self, transformer, position_embed, txt_position_embed, txt_dim, vid_dim,
                 num_queries, input_dropout, aux_loss=False,
                 contrastive_align_loss=False, contrastive_hdim=64,
                 max_v_l=75, span_loss_type="l1", use_txt_pos=False, n_input_proj=2, aud_dim=0, args=None):
        """ Initializes the model.
        Parameters:
            transformer: torch module of the transformer architecture. See transformer.py.
                Its ``d_model`` attribute sets ``hidden_dim``.
            position_embed: torch module of the position_embedding, See position_encoding.py
            txt_position_embed: position_embedding for text (only used when use_txt_pos is True)
            txt_dim: int, text query input dimension
            vid_dim: int, video feature input dimension
            num_queries: number of object queries, ie detection slot. This is the maximal number of objects
                         CG-DETR can detect in a single video.
            input_dropout: dropout rate of every LinearLayer in the input projections.
            aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
            contrastive_align_loss: If true, perform span - tokens contrastive learning
            contrastive_hdim: dimension used for projecting the embeddings before computing contrastive loss
            max_v_l: int, maximum #clips in videos
            span_loss_type: str, one of [l1, ce]
                l1: (center-x, width) regression.
                ce: (st_idx, ed_idx) classification.
            use_txt_pos: if False, the text positional embeddings are all zeros.
            n_input_proj: number of LinearLayer stages in each input projection (expected 1..3);
                the last stage that is kept has no ReLU.
            aud_dim: int, audio feature dimension concatenated to the video features (0 = no audio).
            args: config namespace. Although it defaults to None it is required: this method reads
                args.total_prompts, args.num_dummies, args.dim_feedforward, args.dummy_layers
                and args.sent_layers.
        """
        super().__init__()
        self.args=args
        self.num_queries = num_queries
        self.transformer = transformer
        self.position_embed = position_embed
        self.txt_position_embed = txt_position_embed
        hidden_dim = transformer.d_model
        self.span_loss_type = span_loss_type
        self.max_v_l = max_v_l
        # l1: 2 values (center, width) per query; otherwise start/end logits over max_v_l clips each.
        span_pred_dim = 2 if span_loss_type == "l1" else max_v_l * 2
        self.span_embed = MLP(hidden_dim, hidden_dim, span_pred_dim, 3)
        # 2-way logits; SetCriterion below trains index 0 as foreground and index 1 as background.
        self.class_embed = nn.Linear(hidden_dim, 2)
        # Modality embeddings: forward() adds index 1 to video tokens and index 0 to text tokens.
        self.token_type_embeddings = nn.Embedding(2, hidden_dim)
        self.token_type_embeddings.apply(init_weights)
        self.use_txt_pos = use_txt_pos
        self.n_input_proj = n_input_proj
        # One learned 2-d vector per query; its weight is passed to the transformer as the query embedding.
        self.query_embed = nn.Embedding(num_queries, 2)
        # ReLU on every projection stage except the last one kept by the [:n_input_proj] slice.
        relu_args = [True] * 3
        relu_args[n_input_proj-1] = False
        self.input_txt_proj = nn.Sequential(*[
            LinearLayer(txt_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.input_vid_proj = nn.Sequential(*[
            LinearLayer(vid_dim + aud_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[0]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[1]),
            LinearLayer(hidden_dim, hidden_dim, layer_norm=True, dropout=input_dropout, relu=relu_args[2])
        ][:n_input_proj])
        self.contrastive_align_loss = contrastive_align_loss
        if contrastive_align_loss:
            self.contrastive_align_projection_query = nn.Linear(hidden_dim, contrastive_hdim)
            self.contrastive_align_projection_txt = nn.Linear(hidden_dim, contrastive_hdim)
            self.contrastive_align_projection_vid = nn.Linear(hidden_dim, contrastive_hdim)
        # Saliency = <proj1(clip memory), proj2(global memory)> / sqrt(hidden_dim), computed in forward().
        self.saliency_proj1 = nn.Linear(hidden_dim, hidden_dim)
        self.saliency_proj2 = nn.Linear(hidden_dim, hidden_dim)
        self.aux_loss = aux_loss
        self.hidden_dim = hidden_dim
        # Learned tokens and their positional embeddings:
        #   - global tokens (args.total_prompts of them), passed to the transformer as gtoken/gpos;
        #   - a moment token, prepended to the clips in training;
        #   - args.num_dummies dummy tokens, prepended to the text;
        #   - a sentence token, read out by the sentence encoder.
        self.global_rep_token = torch.nn.Parameter(torch.randn(args.total_prompts, hidden_dim))
        self.global_rep_pos = torch.nn.Parameter(torch.randn(1, hidden_dim))
        self.moment_rep_token = torch.nn.Parameter(torch.randn(hidden_dim))
        self.moment_rep_pos = torch.nn.Parameter(torch.randn(hidden_dim))
        self.dummy_rep_token = torch.nn.Parameter(torch.randn(args.num_dummies, hidden_dim))
        self.dummy_rep_pos = torch.nn.Parameter(torch.randn(args.num_dummies, hidden_dim))
        normalize_before = False
        self.sent_rep_token = torch.nn.Parameter(torch.randn(hidden_dim))
        self.sent_rep_pos = torch.nn.Parameter(torch.randn(hidden_dim))
        # NOTE(review): txt_proj_linear is not referenced in forward() below.
        self.txt_proj_linear = LinearLayer(txt_dim, hidden_dim, layer_norm=True)
        # Encoder over [dummies; text] (args.dummy_layers layers); forward() keeps only its dummy outputs.
        input_txt_sa_proj = TransformerEncoderLayer(hidden_dim, 8, self.args.dim_feedforward, 0.1, "prelu", normalize_before)
        txtproj_encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
        self.txtproj_encoder = TransformerEncoder(input_txt_sa_proj, args.dummy_layers, txtproj_encoder_norm)
        # Sentence encoder (args.sent_layers layers), applied to [SENT; text] and [SENT; dummies] in training.
        scls_encoder_layer = TransformerEncoderLayer(hidden_dim, 8, self.args.dim_feedforward, 0.1, "prelu", normalize_before)
        scls_encoder_norm = nn.LayerNorm(hidden_dim) if normalize_before else None
        self.scls_encoder = TransformerEncoder(scls_encoder_layer, args.sent_layers, scls_encoder_norm)
    def forward(self, src_txt, src_txt_mask, src_vid, src_vid_mask, vid=None, qid=None, src_aud=None, src_aud_mask=None, targets=None, prompt_token=None):
        """Run CG-DETR on a batch of video/text pairs.

        Args:
            src_txt: [batch_size, L_txt, D_txt] text features.
            src_txt_mask: [batch_size, L_txt], 1 for valid tokens and 0 for padding
                (inverted into a padding mask before the transformer).
            src_vid: [batch_size, L_vid, D_vid] clip features.
            src_vid_mask: [batch_size, L_vid], 1 for valid clips and 0 for padding.
            vid: optional list of video ids. When given, every video is also paired with the
                next sample's text (batch rolled by one) to build negative pairs. Pairs whose
                ids are equal are excluded from that negative pass.
            qid: not used in this method.
            src_aud: optional audio features, concatenated to src_vid along the feature dim.
            src_aud_mask: not used in this method.
            targets: dict containing "relevant_clips" during training; None at inference.
            prompt_token: per-sample prompt features inserted right after the valid text tokens.
                NOTE(review): defaults to None but is indexed unconditionally, so it is required.
                Exactly one mask entry is added per sample, so prompt_token[bs] is presumably
                (1, D_txt) — confirm.
        Returns:
            dict with, among others:
            - "pred_logits": last decoder layer's 2-way logits per query.
            - "pred_spans": last layer's spans per query; passed through sigmoid when
              span_loss_type == "l1".
            - "saliency_scores": per-clip saliency scores.
            - "saliency_scores_neg" / "t2vattnvalues_neg": negative-pair outputs. These are None
              when ``vid`` is None or no rolled pair is a true negative.
            - "real_neg_mask": None when ``vid`` is None.
            - "t2vattnvalues": per-clip attention mass on text tokens, clamped to [0, 1].
            - training-only entries ("sentence_txt", "sentence_dummy", "moment2txt_similarity",
              "nmoment2txt_similarity", "moment_mask", "src_vid"), which are None at inference.
            - "aux_outputs": per intermediate decoder layer, only when aux_loss is True.
        NOTE(review): several tensors are hard-cast to torch.bfloat16 below, so this forward
        assumes the model runs in bfloat16 — confirm before using another dtype.
        """
        ## For discovering real negative samples
        device = src_txt_mask.device
        # The original dataset-specific trimming of the ids (for dset 'hl') is disabled here;
        # ids are used as-is to detect duplicate videos within the batch.
        if vid is not None:
            ori_vid = [v for v in vid]
        if src_aud is not None:
            src_vid = torch.cat([src_vid, src_aud], dim=2)
        # --------------------------------
        # Insert the prompt token right after each sample's valid text tokens and mark it valid.
        # Using the mask sum as the insertion index assumes valid tokens form a prefix.
        src_txt_list = []
        src_txt_mask_list = []
        for bs in range(src_txt.shape[0]):
            idx = int(src_txt_mask[bs].sum().item())
            src_txt_list.append(torch.cat((src_txt[bs, :idx, :], prompt_token[bs], src_txt[bs, idx:, :]), dim=0))
            src_txt_mask_list.append(torch.cat((src_txt_mask[bs, :idx], torch.ones(1, dtype=torch.bfloat16).to(device), src_txt_mask[bs, idx:]), dim=0))
        src_txt = torch.stack(src_txt_list, dim=0)
        src_txt_mask = torch.stack(src_txt_mask_list, dim=0)
        # --------------------------------
        src_vid = self.input_vid_proj(src_vid) # (bsz, L_vid, vid_dim + aud_dim) -> (bsz, L_vid, hidden_dim)
        src_txt = self.input_txt_proj(src_txt) # (bsz, L_txt, txt_dim) -> (bsz, L_txt, hidden_dim)
        # Modality embeddings: type 1 for video clips, type 0 for text tokens.
        src_vid = src_vid + self.token_type_embeddings(torch.full_like(src_vid_mask.long(), 1))
        src_txt = src_txt + self.token_type_embeddings(torch.zeros_like(src_txt_mask.long()))
        pos_vid = self.position_embed(src_vid, src_vid_mask).type(torch.bfloat16) # (bsz, L_vid, d)
        pos_txt = self.txt_position_embed(src_txt) if self.use_txt_pos else torch.zeros_like(src_txt).type(torch.bfloat16) # (bsz, L_txt, d)
        ### insert dummy token in front of txt
        txt_dummy = self.dummy_rep_token.reshape([1, self.args.num_dummies, self.hidden_dim]).repeat(src_txt.shape[0], 1, 1) # (bsz, num_dummies, d)
        src_txt_dummy = torch.cat([txt_dummy, src_txt], dim=1) # (bsz, num_dummies + L_txt, d)
        mask_txt = torch.tensor([[True] * self.args.num_dummies]).to(src_txt_mask.device).repeat(src_txt_mask.shape[0], 1)
        src_txt_mask_dummy = torch.cat([mask_txt, src_txt_mask], dim=1) # (bsz, num_dummies + L_txt)
        pos_dummy = self.dummy_rep_pos.reshape([1, self.args.num_dummies, self.hidden_dim]).repeat(pos_txt.shape[0], 1, 1).type(torch.bfloat16)
        pos_txt_dummy = torch.cat([pos_dummy, pos_txt], dim=1)
        src_txt_dummy = src_txt_dummy.permute(1, 0, 2) # (L, batch_size, d)
        pos_txt_dummy = pos_txt_dummy.permute(1, 0, 2) # (L, batch_size, d)
        # Contextualise the dummy tokens with the text; only the first num_dummies outputs are kept.
        memory = self.txtproj_encoder(src_txt_dummy, src_key_padding_mask=~(src_txt_mask_dummy.bool()), pos=pos_txt_dummy) # (L, batch_size, d)
        dummy_token = memory[:self.args.num_dummies].permute(1, 0, 2)
        pos_txt_dummy = pos_txt_dummy.permute(1, 0, 2) # back to (batch_size, L, d)
        # Rebuild [dummies; text] with the contextualised dummies.
        src_txt_dummy = torch.cat([dummy_token, src_txt], dim=1)
        mask_txt_dummy = torch.tensor([[True]*self.args.num_dummies]).to(src_txt_mask.device).repeat(src_txt_mask.shape[0], 1)
        src_txt_mask_dummy = torch.cat([mask_txt_dummy, src_txt_mask], dim=1)
        # Input : Concat video, dummy, txt
        src = torch.cat([src_vid, src_txt_dummy], dim=1) # (bsz, L_vid + num_dummies + L_txt, d)
        mask = torch.cat([src_vid_mask, src_txt_mask_dummy], dim=1).bool() # (bsz, L_vid + num_dummies + L_txt)
        pos = torch.cat([pos_vid, pos_txt_dummy], dim=1)
        ### sentence token: [SENT; text]
        smask_ = torch.tensor([[True]]).to(mask.device).repeat(src_txt_mask.shape[0], 1)
        smask = torch.cat([smask_, src_txt_mask.bool()], dim=1)
        ssrc_ = self.sent_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_txt.shape[0], 1, 1)
        ssrc = torch.cat([ssrc_, src_txt], dim=1)
        spos_ = self.sent_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_txt.shape[0], 1, 1)
        spos = torch.cat([spos_, pos_txt], dim=1)
        ### dummy sentence token: [SENT; dummies]
        smaskd = torch.cat([smask_, mask_txt_dummy.bool()], dim=1)
        ssrcd = torch.cat([ssrc_, dummy_token], dim=1)
        sposd = torch.cat([spos_, pos_dummy], dim=1)
        if targets is not None: # train
            # Moment token attends to valid clips with relevant_clips > 0 (after clamping to [0, 1]).
            mmask_ = torch.tensor([[True]]).to(mask.device).repeat(src_vid_mask.shape[0], 1)
            mmask = torch.cat([mmask_, src_vid_mask.bool()], dim=1) # [bsz, L_vid+1]
            moment_mask_ = torch.clamp(targets["relevant_clips"], 0, 1).bool()
            moment_mask = torch.cat([mmask_, moment_mask_], dim=1) # [bsz, L_vid+1]
            mmask = mmask * moment_mask
            msrc_ = self.moment_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_vid.shape[0], 1, 1)
            msrc = torch.cat([msrc_, src_vid], dim=1)
            mpos_ = self.moment_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_vid.shape[0], 1, 1)
            mpos = torch.cat([mpos_, pos_vid], dim=1)
            ### for Not moment token ####
            # Same construction over the complement (valid clips outside the moment);
            # it reuses moment_rep_token / moment_rep_pos.
            nmmask_ = torch.tensor([[True]]).to(mask.device).repeat(src_vid_mask.shape[0], 1)
            nmmask = torch.cat([nmmask_, src_vid_mask.bool()], dim=1)
            nmoment_mask_ = ~(torch.clamp(targets["relevant_clips"], 0, 1).bool())
            nmoment_mask = torch.cat([nmmask_, nmoment_mask_], dim=1)
            nmmask = nmmask * nmoment_mask
            nmsrc_ = self.moment_rep_token.reshape([1, 1, self.hidden_dim]).repeat(src_vid.shape[0], 1, 1)
            nmsrc = torch.cat([nmsrc_, src_vid], dim=1)
            nmpos_ = self.moment_rep_pos.reshape([1, 1, self.hidden_dim]).repeat(pos_vid.shape[0], 1, 1)
            nmpos = torch.cat([nmpos_, pos_vid], dim=1)
            ###########
        else:
            moment_mask_ = None
        # for t2vidavg sal token: per-sample mean of the projected clips in the first
        # src_vid_mask.sum() positions (assumes valid clips form a prefix), detached from the graph.
        vidsrc_ = torch.zeros((len(src_vid), 1, self.hidden_dim), dtype=torch.bfloat16).to(device)
        for i in range(len(src_vid)):
            vidsrc_[i] = src_vid[i][:src_vid_mask.sum(1)[i].long()].mean(0).clone().detach()
        video_length = src_vid.shape[1]
        if targets is not None: ## train
            # Encode [SENT; text] and [SENT; dummies] with the shared sentence encoder.
            ssrc = ssrc.permute(1, 0, 2) # (L, batch_size, d)
            spos = spos.permute(1, 0, 2) # (L, batch_size, d)
            smemory = self.scls_encoder(ssrc, src_key_padding_mask=~smask, pos=spos) # (L, batch_size, d)
            sentence_txt, smemory_words = smemory[0], smemory[1:] # sentence_txt : (batch_size, d)
            ssrcd = ssrcd.permute(1, 0, 2) # (L, batch_size, d)
            sposd = sposd.permute(1, 0, 2) # (L, batch_size, d)
            smemoryd = self.scls_encoder(ssrcd, src_key_padding_mask=~smaskd, pos=sposd) # (L, batch_size, d)
            sentence_dummy, smemory_words_dummy = smemoryd[0], smemoryd[1:]
            txt_dummy_proj = torch.cat([smemory_words_dummy, smemory_words], dim=0)
            hs, reference, memory, memory_global, attn_weights, memory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames = self.transformer(src, ~mask, self.query_embed.weight, pos, video_length=video_length, moment_idx=targets["relevant_clips"], msrc=msrc, mpos=mpos, mmask=~mmask, nmsrc=nmsrc, nmpos=nmpos, nmmask=~nmmask,
                                                                                                                                                  ctxtoken=vidsrc_, gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask.sum(1).long())
            # Batched dot products between the moment / non-moment frame memories and the
            # encoded [dummy; word] tokens (both permuted to batch-first).
            moment2txt_similarity = torch.matmul(mmemory_frames.permute(1, 0, 2), txt_dummy_proj.permute(1, 2, 0))
            nmoment2txt_similarity = torch.matmul(nmmemory_frames.permute(1, 0, 2), txt_dummy_proj.permute(1, 2, 0))
        else: ## inference
            sentence_dummy, sentence_txt, moment2txt_similarity, nmoment2txt_similarity = None, None, None, None
            hs, reference, memory, memory_global, attn_weights, memory_moment, nmmemory_moment, mmemory_frames, nmmemory_frames = self.transformer(src, ~mask, self.query_embed.weight, pos, video_length=video_length,
                                                                                                                                                  ctxtoken=vidsrc_, gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask.sum(1).long())
        outputs_class = self.class_embed(hs) # (#layers, batch_size, #queries, #classes)
        # Span predictions are offsets added to the reference points in logit space.
        reference_before_sigmoid = inverse_sigmoid(reference)
        tmp = self.span_embed(hs)
        outputs_coord = tmp + reference_before_sigmoid
        if self.span_loss_type == "l1":
            outputs_coord = outputs_coord.sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_spans': outputs_coord[-1]}
        # Split the encoder memory at L_vid: video part first, then (presumably) dummy + text tokens.
        txt_mem = memory[:, src_vid.shape[1]:] # (bsz, L_txt, d)
        vid_mem = memory[:, :src_vid.shape[1]] # (bsz, L_vid, d)
        if self.contrastive_align_loss:
            proj_queries = F.normalize(self.contrastive_align_projection_query(hs), p=2, dim=-1)
            proj_txt_mem = F.normalize(self.contrastive_align_projection_txt(txt_mem), p=2, dim=-1)
            proj_vid_mem = F.normalize(self.contrastive_align_projection_vid(vid_mem), p=2, dim=-1)
            out.update(dict(
                proj_queries=proj_queries[-1],
                proj_txt_mem=proj_txt_mem,
                proj_vid_mem=proj_vid_mem
            ))
        if vid is not None: ## video ids available: build mismatched (negative) video/text pairs
            ### Neg Pairs ###
            # Sample i's video is paired with sample i+1's text (batch rolled by one).
            # real_neg_mask is True where the two video ids differ, i.e. the pair is a true negative.
            neg_vid = ori_vid[1:] + ori_vid[:1]
            real_neg_mask = torch.Tensor(element_wise_list_equal(ori_vid, neg_vid)).to(src_txt_dummy.device)
            real_neg_mask = real_neg_mask.type(torch.bfloat16)
            real_neg_mask = real_neg_mask == False
            if real_neg_mask.sum() != 0:
                src_txt_dummy_neg = torch.cat([src_txt_dummy[1:], src_txt_dummy[0:1]], dim=0)
                src_txt_mask_dummy_neg = torch.cat([src_txt_mask_dummy[1:], src_txt_mask_dummy[0:1]], dim=0)
                src_dummy_neg = torch.cat([src_vid, src_txt_dummy_neg], dim=1)
                mask_dummy_neg = torch.cat([src_vid_mask, src_txt_mask_dummy_neg], dim=1).bool()
                pos_neg = pos.clone() # since it does not use actual content
                # Keep only the true-negative pairs.
                mask_dummy_neg = mask_dummy_neg[real_neg_mask]
                src_dummy_neg = src_dummy_neg[real_neg_mask]
                pos_neg = pos_neg[real_neg_mask]
                src_txt_mask_dummy_neg = src_txt_mask_dummy_neg[real_neg_mask]
                _, _, memory_neg, memory_global_neg, attn_weights_neg, _, _, _, _ = self.transformer(src_dummy_neg, ~mask_dummy_neg, self.query_embed.weight, pos_neg, video_length=video_length,
                                                                                                      ctxtoken=vidsrc_[real_neg_mask], gtoken=self.global_rep_token, gpos=self.global_rep_pos, vlen=src_vid_mask[real_neg_mask].sum(1).long())
                vid_mem_neg = memory_neg[:, :src_vid.shape[1]]
                out["saliency_scores_neg"] = (torch.sum(self.saliency_proj1(vid_mem_neg) * self.saliency_proj2(memory_global_neg).unsqueeze(1), dim=-1) / np.sqrt(self.hidden_dim))
                out["src_txt_mask_neg"] = src_txt_mask_dummy_neg
                out["t2vattnvalues_neg"] = (attn_weights_neg[:, :, self.args.num_dummies:] * (src_txt_mask_dummy_neg[:, self.args.num_dummies:].unsqueeze(1).repeat(1, video_length, 1))).sum(2)
                out["t2vattnvalues_neg"] = torch.clamp(out["t2vattnvalues_neg"], 0, 1)
            else:
                out["saliency_scores_neg"] = None
                out["t2vattnvalues_neg"] = None
            out["real_neg_mask"] = real_neg_mask
        else:
            out["saliency_scores_neg"] = None
            out["t2vattnvalues_neg"] = None
            out["real_neg_mask"] = None
        # Saliency: scaled dot product of projected clip memory and projected global memory.
        out["saliency_scores"] = (torch.sum(self.saliency_proj1(vid_mem) * self.saliency_proj2(memory_global).unsqueeze(1), dim=-1) / np.sqrt(self.hidden_dim))
        out["memory_moment"] = memory_moment
        out["nmmemory_moment"] = nmmemory_moment
        ## sentence token embedded with text / dummy (None at inference)
        out["sentence_txt"] = sentence_txt
        out["sentence_dummy"] = sentence_dummy
        out["moment2txt_similarity"] = moment2txt_similarity
        out["nmoment2txt_similarity"] = nmoment2txt_similarity
        out["cate_attn_weights"] = attn_weights
        out["moment_mask"] = moment_mask_
        out["txt_mask"] = src_txt_mask_dummy
        # Attention each clip puts on valid text tokens (dummy positions dropped, padding zeroed),
        # summed and clamped to [0, 1]. Assumes attn_weights is (batch, L_vid, num_dummies + L_txt) — TODO confirm.
        out["t2vattnvalues"] = (attn_weights[:,:,self.args.num_dummies:] * (src_txt_mask.unsqueeze(1).repeat(1, video_length, 1))).sum(2)
        out["t2vattnvalues"] = torch.clamp(out["t2vattnvalues"], 0, 1)
        out["dummy_tokens"] = dummy_token
        out["global_rep_tokens"] = self.global_rep_token
        if targets is not None:
            # Per clip: moment-frame memory inside the GT moment, non-moment-frame memory outside it.
            out["src_vid"] = mmemory_frames.permute(1, 0, 2) * moment_mask_.unsqueeze(2) + nmmemory_frames.permute(1, 0, 2) * (~(moment_mask_.unsqueeze(2).bool())).bfloat16()
        else:
            out["src_vid"] = None
        out["video_mask"] = src_vid_mask
        if self.aux_loss:
            # assert proj_queries and proj_txt_mem
            out['aux_outputs'] = [
                {'pred_logits': a, 'pred_spans': b} for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
            if self.contrastive_align_loss:
                assert proj_queries is not None
                for idx, d in enumerate(proj_queries[:-1]):
                    out['aux_outputs'][idx].update(dict(proj_queries=d, proj_txt_mem=proj_txt_mem))
        return out
class SetCriterion(nn.Module):
""" This class computes the loss for DETR.
The process happens in two steps:
1) we compute hungarian assignment between ground truth boxes and the outputs of the model
2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
"""
def __init__(self, matcher, weight_dict, eos_coef, losses, temperature, span_loss_type, max_v_l,
             saliency_margin=1, use_matcher=True, args=None):
    """Configure the DETR-style set criterion.

    Parameters:
        matcher: module that computes the matching between targets and proposals.
        weight_dict: maps each loss name to its relative weight.
        eos_coef: class weight given to the background ("no object") class.
        losses: names of the losses to apply (see get_loss).
        temperature: float, temperature for the NCE loss.
        span_loss_type: str, "l1" or "ce".
        max_v_l: int, maximum number of clips per video.
        saliency_margin: float, margin of the pairwise saliency hinge loss.
        use_matcher: stored as-is (the original code marks it "for tvsum").
        args: config namespace, stored on the instance.
    """
    super().__init__()
    # Plain configuration.
    self.args = args
    self.matcher = matcher
    self.weight_dict = weight_dict
    self.losses = losses
    self.temperature = temperature
    self.span_loss_type = span_loss_type
    self.max_v_l = max_v_l
    self.saliency_margin = saliency_margin

    # Binary classification convention: index 0 = foreground, index 1 = background.
    self.foreground_label = 0
    self.background_label = 1
    self.eos_coef = eos_coef
    # Background is down-weighted by eos_coef. Registering it as a buffer makes it follow .to(device).
    class_weights = torch.ones(2)
    class_weights[self.background_label] = eos_coef
    self.register_buffer('empty_weight', class_weights)

    self.use_matcher = use_matcher  # for tvsum

    # moment-sentence contrastive and other auxiliary criteria
    self.criterion = torch.nn.CrossEntropyLoss()
    self.l2_criterion = torch.nn.MSELoss()
    self.kld_criterion = torch.nn.KLDivLoss(reduction='none')
    self.bce_criterion = nn.BCELoss(reduction='none')
def loss_spans(self, outputs, targets, indices):
    """Compute the span losses for the matched predictions.

    Parameters:
        outputs: dict containing 'pred_spans', indexed per (batch, query).
        targets: dict whose "span_labels" entry is a list (one per sample) of dicts with the
            key 'spans', a tensor of dim [nb_tgt_spans, 2].
            - For span_loss_type == "l1", spans are (center_x, width), normalized.
            - For "ce", spans are (start_idx, end_idx) clip indices.
        indices: list of (src_idx, tgt_idx) pairs from the matcher, one per sample.

    Returns:
        dict with scalar 'loss_span' and 'loss_giou'.
        - 'loss_giou' is 0 for the "ce" loss type.
        - If no span is matched in the whole batch, both losses are 0 instead of the NaN that
          mean() over an empty tensor would produce. They stay attached to the graph.
    """
    assert 'pred_spans' in outputs
    targets = targets["span_labels"]
    idx = self._get_src_permutation_idx(indices)
    src_spans = outputs['pred_spans'][idx]  # (#spans, 2) for l1, (#spans, max_v_l * 2) for ce
    tgt_spans = torch.cat([t['spans'][i] for t, (_, i) in zip(targets, indices)], dim=0)  # (#spans, 2)
    if tgt_spans.numel() == 0:
        # Nothing matched: return graph-connected zeros rather than NaN from an empty mean.
        zero = src_spans.sum() * 0.0
        return {'loss_span': zero, 'loss_giou': zero.clone()}
    if self.span_loss_type == "l1":
        loss_span = F.l1_loss(src_spans, tgt_spans, reduction='none')
        loss_giou = 1 - torch.diag(generalized_temporal_iou(span_cxw_to_xx(src_spans), span_cxw_to_xx(tgt_spans)))
    else:  # ce
        n_spans = src_spans.shape[0]
        # (#spans, 2 * max_v_l) -> (#spans, max_v_l, 2): the class dim is the clip index,
        # with one column each for start and end.
        src_spans = src_spans.view(n_spans, 2, self.max_v_l).transpose(1, 2)
        loss_span = F.cross_entropy(src_spans, tgt_spans, reduction='none')
        loss_giou = loss_span.new_zeros([1])
    losses = {}
    losses['loss_span'] = loss_span.mean()
    losses['loss_giou'] = loss_giou.mean()
    return losses
def loss_labels(self, outputs, targets, indices, log=True):
    """Foreground/background classification loss (NLL) over all queries.

    Queries selected by the matcher (``indices``) are labelled ``self.foreground_label``.
    Every other query is labelled ``self.background_label``. ``targets`` is not read here.

    Returns:
        dict with:
        - 'loss_label': plain mean of the per-query cross-entropy, weighted by
          ``self.empty_weight`` per class.
        - 'class_error' (only when ``log`` is true): 100 minus ``accuracy`` of the matched
          queries' logits against the foreground label.
    """
    assert 'pred_logits' in outputs
    logits = outputs['pred_logits']  # (batch_size, #queries, #classes=2)
    # Tuple (batch_idx, query_idx) of equal-length 1D tensors, one entry per matched query.
    matched = self._get_src_permutation_idx(indices)

    labels = logits.new_full(logits.shape[:2], self.background_label, dtype=torch.int64)  # (batch_size, #queries)
    labels[matched] = self.foreground_label

    per_query = F.cross_entropy(logits.transpose(1, 2), labels, weight=self.empty_weight, reduction="none")
    result = {'loss_label': per_query.mean()}
    if log:
        # TODO: class_error is a metric piggy-backing on the loss dict; it could be separated.
        result['class_error'] = 100 - accuracy(logits[matched], self.foreground_label)[0]
    return result
def loss_saliency(self, outputs, targets, indices, log=True):
"""higher scores for positive clips"""
if "saliency_pos_labels" not in targets:
return {"loss_saliency": 0}
# Neg pair loss
if outputs["saliency_scores_neg"] is not None: ## When batch size is not 1 (negative pair exists)
vid_token_mask = outputs["video_mask"]
real_neg_mask = outputs["real_neg_mask"]
saliency_scores_neg = outputs["saliency_scores_neg"].clone() # (N, L)
loss_neg_pair = (- torch.log(1. - torch.sigmoid(saliency_scores_neg)) * (vid_token_mask[real_neg_mask])).sum(dim=1).mean()
saliency_scores = outputs["saliency_scores"].clone() # (N, L)
saliency_contrast_label = targets["saliency_all_labels"]
# real neg
realneg_saliency_scores = torch.cat([saliency_scores[real_neg_mask], saliency_scores_neg], dim=1)
realneg_saliency_contrast_label = torch.cat([saliency_contrast_label[real_neg_mask], torch.zeros_like(saliency_contrast_label)[real_neg_mask]], dim=1)
realneg_vid_token_mask = vid_token_mask[real_neg_mask].repeat([1, 2])
realneg_saliency_scores = realneg_vid_token_mask * realneg_saliency_scores + (1. - realneg_vid_token_mask) * -1e+3
tau = 0.5
loss_rank_contrastive = 0.
for rand_idx in range(1, 12):
drop_mask = ~(realneg_saliency_contrast_label > 100) # no drop
pos_mask = (realneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
if torch.sum(pos_mask) == 0: # no positive sample
continue
else:
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
# drop higher ranks
cur_saliency_scores = realneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
# numerical stability
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
# softmax
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
mean_log_prob_pos = (pos_mask * log_prob * realneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
loss = - mean_log_prob_pos * batch_drop_mask
loss_rank_contrastive = loss_rank_contrastive + loss.mean()
loss_rank_contrastive = loss_rank_contrastive / 12
false_neg_mask = ~(real_neg_mask)
if false_neg_mask.sum() != 0:
if false_neg_mask.sum() == 1:
falseneg_saliency_scores = saliency_scores[false_neg_mask].unsqueeze(0)
falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask].unsqueeze(0)
falseneg_vid_token_mask = vid_token_mask[false_neg_mask].unsqueeze(0)
falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
else:
falseneg_saliency_scores = saliency_scores[false_neg_mask]
falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask]
falseneg_vid_token_mask = vid_token_mask[false_neg_mask]
falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
tau = 0.5
falseneg_loss_rank_contrastive = 0.
for rand_idx in range(1, 12):
drop_mask = ~(falseneg_saliency_contrast_label > 100) # no drop
pos_mask = (falseneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
if torch.sum(pos_mask) == 0: # no positive sample
continue
else:
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
# drop higher ranks
cur_saliency_scores = falseneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
# numerical stability
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
# softmax
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
mean_log_prob_pos = (pos_mask * log_prob * falseneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
loss = - mean_log_prob_pos * batch_drop_mask
falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive + loss.mean()
falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive / 12
loss_rank_contrastive += falseneg_loss_rank_contrastive
saliency_scores = outputs["saliency_scores"] # (N, L)
pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
num_pairs = pos_indices.shape[1] # typically 2 or 4
batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
pos_scores = torch.stack(
[saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
neg_scores = torch.stack(
[saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
/ (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
# if self.args.dset_name in ['youtube_uni']:
# loss_saliency = loss_saliency + loss_rank_contrastive + loss_neg_pair * 0.
# else:
loss_saliency = loss_saliency + loss_rank_contrastive + loss_neg_pair
########### Saliency loss to t2v attn weights ##############
"""higher scores for positive clips"""
vid_token_mask = outputs["video_mask"]
# Neg pair loss
if outputs["t2vattnvalues_neg"] is not None:
saliency_scores_neg = outputs["t2vattnvalues_neg"].clone() # (N, L)
loss_neg_pair_attn = (- torch.log(1. - saliency_scores_neg) * (vid_token_mask[real_neg_mask])).sum(dim=1).mean()
saliency_scores = outputs["t2vattnvalues"].clone() # (N, L)
saliency_contrast_label = targets["saliency_all_labels"]
# real neg
realneg_saliency_scores = torch.cat([saliency_scores[real_neg_mask], saliency_scores_neg], dim=1)
realneg_saliency_contrast_label = torch.cat(
[saliency_contrast_label[real_neg_mask], torch.zeros_like(saliency_contrast_label)[real_neg_mask]], dim=1)
realneg_vid_token_mask = vid_token_mask[real_neg_mask].repeat([1, 2])
realneg_saliency_scores = realneg_vid_token_mask * realneg_saliency_scores + (
1. - realneg_vid_token_mask) * -1e+3
tau = 0.5
loss_rank_contrastive_attn = 0.
for rand_idx in range(1, 12):
drop_mask = ~(realneg_saliency_contrast_label > 100) # no drop
pos_mask = (realneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
if torch.sum(pos_mask) == 0: # no positive sample
continue
else:
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
# drop higher ranks
cur_saliency_scores = realneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
# numerical stability
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
# softmax
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
mean_log_prob_pos = (pos_mask * log_prob * realneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
loss = - mean_log_prob_pos * batch_drop_mask
loss_rank_contrastive_attn = loss_rank_contrastive_attn + loss.mean()
loss_rank_contrastive_attn = loss_rank_contrastive_attn / 12
false_neg_mask = ~(real_neg_mask)
if false_neg_mask.sum() != 0:
if false_neg_mask.sum() == 1:
falseneg_saliency_scores = saliency_scores[false_neg_mask].unsqueeze(0)
falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask].unsqueeze(0)
falseneg_vid_token_mask = vid_token_mask[false_neg_mask].unsqueeze(0)
falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
else:
falseneg_saliency_scores = saliency_scores[false_neg_mask]
falseneg_saliency_contrast_label = saliency_contrast_label[false_neg_mask]
falseneg_vid_token_mask = vid_token_mask[false_neg_mask]
falseneg_saliency_scores = falseneg_vid_token_mask * falseneg_saliency_scores + (1. - falseneg_vid_token_mask) * -1e+3
tau = 0.5
falseneg_loss_rank_contrastive = 0.
for rand_idx in range(1, 12):
drop_mask = ~(falseneg_saliency_contrast_label > 100) # no drop
pos_mask = (falseneg_saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
if torch.sum(pos_mask) == 0: # no positive sample
continue
else:
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
# drop higher ranks
cur_saliency_scores = falseneg_saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
# numerical stability
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
# softmax
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
mean_log_prob_pos = (pos_mask * log_prob * falseneg_vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
loss = - mean_log_prob_pos * batch_drop_mask
falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive + loss.mean()
falseneg_loss_rank_contrastive = falseneg_loss_rank_contrastive / 12
loss_rank_contrastive += falseneg_loss_rank_contrastive
saliency_scores = outputs["t2vattnvalues"] # (N, L)
pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
num_pairs = pos_indices.shape[1] # typically 2 or 4
batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
pos_scores = torch.stack(
[saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
neg_scores = torch.stack(
[saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
loss_saliency_attn = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
/ (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
saliency_binary_label = torch.clamp(targets["saliency_all_labels"], 0, 1)
logits = saliency_scores.reshape(-1)
labels_x = saliency_binary_label.reshape(-1)
BCEcriterion = nn.BCELoss()
bceloss = BCEcriterion(logits, labels_x)
# if self.args.dset_name in ['youtube_uni']:
# loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_neg_pair_attn * 0 + loss_saliency_attn
# else:
loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_neg_pair_attn + loss_saliency_attn
loss_saliency += (loss_saliency_attn * self.args.lw_wattn)
else: ## when batch size == 1
vid_token_mask = outputs["video_mask"]
saliency_scores = outputs["saliency_scores"].clone() # (N, L)
saliency_contrast_label = targets["saliency_all_labels"]
saliency_scores = vid_token_mask * saliency_scores + (1. - vid_token_mask) * -1e+3
tau = 0.5
loss_rank_contrastive = 0.
for rand_idx in range(1, 12):
drop_mask = ~(saliency_contrast_label > 100) # no drop
pos_mask = (saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
if torch.sum(pos_mask) == 0: # no positive sample
continue
else:
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
# drop higher ranks
cur_saliency_scores = saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
# numerical stability
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
# softmax
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
loss = - mean_log_prob_pos * batch_drop_mask
loss_rank_contrastive = loss_rank_contrastive + loss.mean()
loss_rank_contrastive = loss_rank_contrastive / 12
saliency_scores = outputs["saliency_scores"] # (N, L)
pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
num_pairs = pos_indices.shape[1] # typically 2 or 4
batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
pos_scores = torch.stack(
[saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
neg_scores = torch.stack(
[saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
loss_saliency = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
/ (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
loss_saliency = loss_saliency + loss_rank_contrastive
########### Saliency loss to t2v attn weights ##############
"""higher scores for positive clips"""
vid_token_mask = outputs["video_mask"]
saliency_scores = outputs["t2vattnvalues"].clone() # (N, L)
saliency_contrast_label = targets["saliency_all_labels"]
saliency_scores = vid_token_mask * saliency_scores + (1. - vid_token_mask) * -1e+3
tau = 0.5
loss_rank_contrastive = 0.
for rand_idx in range(1, 12):
drop_mask = ~(saliency_contrast_label > 100) # no drop
pos_mask = (saliency_contrast_label >= rand_idx) # positive when equal or higher than rand_idx
if torch.sum(pos_mask) == 0: # no positive sample
continue
else:
batch_drop_mask = torch.sum(pos_mask, dim=1) > 0 # negative sample indicator
# drop higher ranks
cur_saliency_scores = saliency_scores * drop_mask / tau + ~drop_mask * -1e+3
# numerical stability
logits = cur_saliency_scores - torch.max(cur_saliency_scores, dim=1, keepdim=True)[0]
# softmax
exp_logits = torch.exp(logits)
log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-6)
mean_log_prob_pos = (pos_mask * log_prob * vid_token_mask).sum(1) / (pos_mask.sum(1) + 1e-6)
loss = - mean_log_prob_pos * batch_drop_mask
loss_rank_contrastive = loss_rank_contrastive + loss.mean()
loss_rank_contrastive_attn = loss_rank_contrastive / 12
saliency_scores = outputs["t2vattnvalues"] # (N, L)
pos_indices = targets["saliency_pos_labels"] # (N, #pairs)
neg_indices = targets["saliency_neg_labels"] # (N, #pairs)
num_pairs = pos_indices.shape[1] # typically 2 or 4
batch_indices = torch.arange(len(saliency_scores)).to(saliency_scores.device)
pos_scores = torch.stack(
[saliency_scores[batch_indices, pos_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
neg_scores = torch.stack(
[saliency_scores[batch_indices, neg_indices[:, col_idx]] for col_idx in range(num_pairs)], dim=1)
loss_saliency_attn = torch.clamp(self.saliency_margin + neg_scores - pos_scores, min=0).sum() \
/ (len(pos_scores) * num_pairs) * 2 # * 2 to keep the loss the same scale
saliency_binary_label = torch.clamp(targets["saliency_all_labels"], 0, 1)
logits = saliency_scores.reshape(-1)
labels_x = saliency_binary_label.reshape(-1)
BCEcriterion = nn.BCELoss()
bceloss = BCEcriterion(logits, labels_x)
loss_saliency_attn = loss_rank_contrastive_attn + bceloss + loss_saliency_attn
loss_saliency += (loss_saliency_attn * self.args.lw_wattn)
return {"loss_saliency": loss_saliency}
def loss_contrastive_moment_sentence(self, outputs, targets, indices, log=True):
    """Moment/sentence alignment loss, returned as {"loss_ms_align": value}.

    If outputs["memory_moment"] is None the loss is the float 0. Otherwise it is
    the sum of three terms:
      1. self.criterion over cosine similarities of [moment tokens | non-moment
         tokens] against the sentence tokens, with targets taken from an
         identity/zeros label matrix via max(dim=1).
      2. The same construction against the dummy sentence tokens, with targets
         taken from an inverted-identity/ones label matrix via max(dim=1).
      3. self.bce_criterion between sigmoid(moment_token . src_vid) and
         targets["relevant_clips"] clamped to [0, 1], weighted elementwise by
         outputs["video_mask"] and averaged.

    `indices` and `log` are accepted for interface compatibility and unused.
    """
    moment_token = outputs["memory_moment"]
    if moment_token is None:
        return {"loss_ms_align": 0.}

    moment_n = F.normalize(moment_token, dim=1)
    non_moment_n = F.normalize(outputs["nmmemory_moment"], dim=1)
    sentence_n = F.normalize(outputs["sentence_txt"].squeeze(1), dim=1)
    dummy_n = F.normalize(outputs["sentence_dummy"].squeeze(1), dim=1)
    device = sentence_n.device

    # Term 1: columns are [moment vs sentence | non-moment vs sentence].
    non_moment_vs_sentence = torch.matmul(non_moment_n, sentence_n.T)
    sentence_scores = torch.cat(
        [torch.matmul(moment_n, sentence_n.T), non_moment_vs_sentence], dim=1)
    n_rows = sentence_scores.shape[0]
    identity = torch.eye(n_rows).to(device)
    sentence_target = torch.cat(
        [identity, torch.zeros_like(non_moment_vs_sentence).to(device)], dim=1).max(dim=1)[1]
    loss_ms_align = self.criterion(sentence_scores, sentence_target)

    # Term 2: same layout against the dummy sentence tokens, inverted labels.
    dummy_scores = torch.cat(
        [torch.matmul(moment_n, dummy_n.T), torch.matmul(non_moment_n, dummy_n.T)], dim=1)
    dummy_target = torch.cat(
        [(~identity.bool()).float(), torch.ones_like(non_moment_vs_sentence).to(device)],
        dim=1).max(dim=1)[1]
    loss_ms_align += self.criterion(dummy_scores, dummy_target)

    # Term 3: per-clip relevance of the moment token, masked by video_mask.
    clip_target = torch.clamp(targets["relevant_clips"], 0, 1)
    clip_prob = torch.sigmoid(
        torch.matmul(moment_token.unsqueeze(1), outputs['src_vid'].permute(0, 2, 1)))
    per_clip = self.bce_criterion(clip_prob.reshape(-1), clip_target.reshape(-1))
    loss_ms_align += (per_clip * outputs['video_mask'].reshape(-1)).mean()
    return {"loss_ms_align": loss_ms_align}
#
def loss_moment2txt_sim_distill(self, outputs, targets, indices, log=True):
    """Distill the moment-to-text similarity distribution into the category attention weights.

    Returns {"loss_distill": value}. The value is the float 0 if
    outputs["moment2txt_similarity"] is None. Otherwise:

    - the teacher is softmax(moment2txt_similarity), detached;
    - the student is log(cate_attn_weights + 1e-6);
    - the loss is self.kld_criterion(student, teacher), averaged over the last
      dim per clip row;
    - each row is weighted by outputs["moment_mask"] and the sum is divided by
      the number of unmasked rows.

    The denominator is clamped to at least 1, so an all-zero mask yields 0
    instead of NaN. `targets`, `indices` and `log` are unused.
    """
    moment2txt_similarity = outputs["moment2txt_similarity"]  # bsz L_clip L_txt
    if moment2txt_similarity is None:
        return {"loss_distill": 0.}
    attn_weights = outputs["cate_attn_weights"]  # bsz L_clip L_txt
    b, L_vid, _ = attn_weights.size()
    # moment_mask is reshaped to one weight per (sample, clip) row.
    moment_mask = outputs["moment_mask"].int().reshape(-1)
    student_log_prob = torch.log(attn_weights + 1e-6).reshape(b * L_vid, -1)
    teacher_prob = torch.softmax(moment2txt_similarity, dim=-1).clone().detach().reshape(b * L_vid, -1)
    per_row = self.kld_criterion(student_log_prob, teacher_prob).mean(1) * moment_mask
    # Clamp avoids 0/0 = NaN when no clip is selected by the mask.
    loss_distill = per_row.sum() / moment_mask.sum().clamp(min=1)
    return {"loss_distill": loss_distill}
def loss_orthogonal_dummy(self, outputs, targets, indices, log=True):
    """Encourage orthogonality among dummy tokens and among global representative tokens.

    Returns {"loss_orthogonal_dummy": value}. The value is the mean absolute
    off-diagonal cosine similarity of outputs["dummy_tokens"] (skipped, i.e. 0,
    when size(1) == 1), plus the same quantity for outputs["global_rep_tokens"].
    The right-hand operand of each similarity product is detached, as in the
    original. `targets`, `indices` and `log` are unused.
    """
    dummy_tokens = outputs["dummy_tokens"]  # indexed as 3-D (bsz, n_dum, dim) below — confirm with model
    if dummy_tokens.size(1) != 1:
        dummy_tokens_norm = dummy_tokens / dummy_tokens.norm(dim=2)[:, :, None]
        dummy_tokens_sim = torch.matmul(dummy_tokens_norm, dummy_tokens_norm.permute(0, 2, 1).detach())
        # Zero self-similarity for every sample's square matrix in one in-place call.
        torch.diagonal(dummy_tokens_sim, dim1=1, dim2=2).zero_()
        loss_dummy_ortho = dummy_tokens_sim.abs().mean()
    else:
        loss_dummy_ortho = 0.
    global_tokens = outputs["global_rep_tokens"]  # 2-D (n_tokens, dim) per the indexing below
    global_tokens_norm = global_tokens / global_tokens.norm(dim=1)[:, None]
    global_tokens_sim = torch.matmul(global_tokens_norm, global_tokens_norm.permute(1, 0).detach())
    # A single call zeroes the whole diagonal (the original repeated it n times).
    global_tokens_sim.fill_diagonal_(0)
    loss_dummy_ortho += global_tokens_sim.abs().mean()
    return {"loss_orthogonal_dummy": loss_dummy_ortho}
def loss_contrastive_align(self, outputs, targets, indices, log=True):
    """Encourage higher scores between matched query spans and the input text.

    Each query's score is the sum of its dot products with all text tokens,
    divided by self.temperature. Per sample, the loss is
    logsumexp(scores) minus the mean score of the matched queries, where the
    matched queries are taken from `indices` via self._get_src_permutation_idx.
    Returns {"loss_contrastive_align": batch mean}.
    """
    query_embed = outputs["proj_queries"]  # (bsz, #queries, d)
    token_embed = outputs["proj_txt_mem"]  # (bsz, #tokens, d) text tokens
    pair_scores = torch.einsum("bmd,bnd->bmn", query_embed, token_embed)  # (bsz, #queries, #tokens)
    scores = pair_scores.sum(2) / self.temperature  # (bsz, #queries)

    is_matched = torch.zeros_like(scores, dtype=torch.bool)
    is_matched[self._get_src_permutation_idx(indices)] = True

    matched_sum = scores.masked_fill(~is_matched, 0).sum(1)  # (bsz, )
    matched_count = is_matched.sum(1)  # (bsz, )
    loss_nce = scores.logsumexp(1) - matched_sum / matched_count  # (bsz, )
    return {"loss_contrastive_align": loss_nce.mean()}
def loss_contrastive_align_vid_txt(self, outputs, targets, indices, log=True):
    """Encourage higher scores between matched query spans and the input text.

    This method used to be a line-for-line copy of `loss_contrastive_align`
    (same inputs, same formula, same "loss_contrastive_align" key). It now
    delegates to that method so the two can never drift apart.
    """
    return self.loss_contrastive_align(outputs, targets, indices, log=log)
def _get_src_permutation_idx(self, indices):
# permute predictions following indices
batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
return batch_idx, src_idx # two 1D tensors of the same length
def _get_tgt_permutation_idx(self, indices):
# permute targets following indices
batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
tgt_idx = torch.cat([tgt for (_, tgt) in indices])
return batch_idx, tgt_idx
def get_loss(self, loss, outputs, targets, indices, **kwargs):
    """Dispatch to the loss method registered under the short name `loss`.

    Raises AssertionError for an unknown name. Extra keyword arguments are
    forwarded unchanged to the selected loss method.
    """
    dispatch = dict(
        spans=self.loss_spans,
        labels=self.loss_labels,
        contrastive_align=self.loss_contrastive_align,
        saliency=self.loss_saliency,
        ms_align=self.loss_contrastive_moment_sentence,
        distill=self.loss_moment2txt_sim_distill,
        orthogonal_dummy=self.loss_orthogonal_dummy,
    )
    assert loss in dispatch, f'do you really want to compute {loss} loss?'
    loss_fn = dispatch[loss]
    return loss_fn(outputs, targets, indices, **kwargs)
def forward(self, outputs, targets):
    """Compute every requested loss for the final layer and, if present, each auxiliary layer.

    Parameters:
        outputs: dict of model outputs; an optional 'aux_outputs' entry holds a
            list of per-layer output dicts.
        targets: ground truth passed through unchanged to the matcher and losses.

    With a matcher, the final layer gets every loss in self.losses. Each
    auxiliary layer gets those same losses except the top-layer-only ones, with
    keys suffixed "_{layer}". Without a matcher, only "saliency" is computed and
    the auxiliary layers contribute nothing.
    """
    # These losses are only computed for the top layer, never for aux layers.
    top_layer_only = {"saliency", "ms_align", "distill", "orthogonal_dummy"}

    outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
    if self.use_matcher:
        # list(tuples): each tuple is (pred_span_indices, tgt_span_indices)
        indices = self.matcher(outputs_without_aux, targets)
        requested = self.losses
    else:
        # HL-only setting: no matching, saliency only.
        indices, requested = None, ["saliency"]

    losses = {}
    for name in requested:
        losses.update(self.get_loss(name, outputs, targets, indices))

    if 'aux_outputs' not in outputs:
        return losses
    for layer_idx, aux_outputs in enumerate(outputs['aux_outputs']):
        if not self.use_matcher:
            # Every candidate aux loss would be top-layer-only, so nothing is added.
            continue
        aux_indices = self.matcher(aux_outputs, targets)
        for name in self.losses:
            if name in top_layer_only:
                continue
            layer_losses = self.get_loss(name, aux_outputs, targets, aux_indices)
            losses.update({key + f'_{layer_idx}': value for key, value in layer_losses.items()})
    return losses
class MLP(nn.Module):
    """Feed-forward stack of `num_layers` Linear layers with ReLU between them.

    Hidden layers have width `hidden_dim`. No activation is applied after the
    final layer.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden_widths = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden_widths
        out_dims = hidden_widths + [output_dim]
        self.layers = nn.ModuleList([nn.Linear(n_in, n_out) for n_in, n_out in zip(in_dims, out_dims)])

    def forward(self, x):
        *hidden_layers, output_layer = self.layers
        for layer in hidden_layers:
            x = F.relu(layer(x))
        return output_layer(x)
class LinearLayer(nn.Module):
    """Linear layer with optional LayerNorm (before), dropout, and ReLU (after).

    Pipeline: [LayerNorm] -> Dropout -> Linear -> [ReLU]. The submodule names
    `LayerNorm` and `net` (Sequential of Dropout, Linear) are kept so
    state_dict keys stay stable.
    """

    def __init__(self, input_dim, output_dim, layer_norm=True, dropout=0.1, relu=True):
        super(LinearLayer, self).__init__()
        self.relu = relu
        self.layer_norm = layer_norm
        if layer_norm:
            self.LayerNorm = nn.LayerNorm(input_dim)
        self.net = nn.Sequential(nn.Dropout(dropout), nn.Linear(input_dim, output_dim))

    def forward(self, x):
        """(N, L, D) -> (N, L, output_dim)"""
        normed = self.LayerNorm(x) if self.layer_norm else x
        projected = self.net(normed)
        return F.relu(projected, inplace=True) if self.relu else projected
class CGDETRConfig:
    """Flat container of CG-DETR hyper-parameters.

    Every constructor argument is stored verbatim as an attribute of the same
    name, in parameter order (e.g. CGDETRConfig().hidden_dim == 256). Nothing
    is validated or derived.
    """

    def __init__(self, dset_name='charadesSTA', eval_split_name='val', data_ratio=1.0,
                 results_root='results', exp_id=None, max_es_cnt=200, eval_epoch=5,
                 grad_clip=0.1, eval_untrained=False, resume_all=False, start_epoch=None,
                 max_q_l=-1, max_v_l=-1, clip_length=1, max_windows=5, train_path=None,
                 eval_path=None, no_norm_vfeat=False, no_norm_tfeat=False, v_feat_dirs=None,
                 t_feat_dir=None, v_feat_dim=770, t_feat_dim=4096, ctx_mode='video_tef',
                 position_embedding='sine', enc_layers=3, dec_layers=3, t2v_layers=2,
                 sent_layers=1, moment_layers=1, dummy_layers=2, dim_feedforward=1024,
                 hidden_dim=256, input_dropout=0.5, dropout=0.1, txt_drop_ratio=0,
                 use_txt_pos=False, nheads=8, num_queries=10, num_dummies=45,
                 total_prompts=10, num_prompts=1, pre_norm=False, n_input_proj=2,
                 contrastive_hdim=64, temperature=0.07, saliency_margin=0.2, aux_loss=True,
                 span_loss_type='l1', contrastive_align_loss=False, set_cost_span=10,
                 set_cost_giou=1, set_cost_class=4, lw_saliency=4, lw_wattn=1.0,
                 lw_ms_align=1.0, lw_distill=1.0, span_loss_coef=10, giou_loss_coef=1,
                 label_loss_coef=4, eos_coef=0.1, contrastive_align_loss_coef=0.02,
                 no_sort_results=False, max_before_nms=10, max_after_nms=10,
                 conf_thd=0.0, nms_thd=-1):
        # Snapshot the arguments before any other local exists, so the copy holds
        # exactly the parameters (plus `self`, dropped below), in signature order.
        settings = dict(locals())
        del settings["self"]
        for name, value in settings.items():
            setattr(self, name, value)
def build_cgdetr_model(args=None):
    """Build the CG-DETR model together with its SetCriterion.

    Args:
        args: optional config object exposing the CGDETRConfig attributes. When
            omitted, a default CGDETRConfig() is used, which matches the
            previous hard-coded behaviour.

    Returns:
        (model, criterion)
    """
    if args is None:
        args = CGDETRConfig()
    transformer = build_transformer(args)
    position_embedding, txt_position_embedding = build_position_encoding(args)
    # NOTE(review): audio features (aud_dim / a_feat_dir) are not wired up here.
    model = CGDETR(
        transformer,
        position_embedding,
        txt_position_embedding,
        txt_dim=args.t_feat_dim,
        vid_dim=args.v_feat_dim,
        num_queries=args.num_queries,
        input_dropout=args.input_dropout,
        aux_loss=args.aux_loss,
        contrastive_align_loss=args.contrastive_align_loss,
        contrastive_hdim=args.contrastive_hdim,
        span_loss_type=args.span_loss_type,
        use_txt_pos=args.use_txt_pos,
        n_input_proj=args.n_input_proj,
        args=args
    )

    matcher = build_matcher(args)
    weight_dict = {"loss_span": args.span_loss_coef,
                   "loss_giou": args.giou_loss_coef,
                   "loss_label": args.label_loss_coef,
                   "loss_saliency": args.lw_saliency,
                   "loss_ms_align": args.lw_ms_align,
                   "loss_distill": args.lw_distill,
                   # The orthogonal-dummy loss reuses lw_distill; the config has no separate weight for it.
                   "loss_orthogonal_dummy": args.lw_distill}
    if args.contrastive_align_loss:
        weight_dict["loss_contrastive_align"] = args.contrastive_align_loss_coef

    if args.aux_loss:
        # One suffixed copy of every weight (except saliency) per intermediate decoder layer.
        aux_weight_dict = {}
        for i in range(args.dec_layers - 1):
            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items() if k != "loss_saliency"})
        weight_dict.update(aux_weight_dict)

    losses = ['spans', 'labels', 'saliency', 'ms_align', 'distill', 'orthogonal_dummy']
    if args.contrastive_align_loss:
        losses += ["contrastive_align"]

    # NOTE: highlight-detection-only datasets (e.g. youtube_uni, tvsum) previously
    # disabled the matcher; this builder always enables it.
    use_matcher = True
    criterion = SetCriterion(
        matcher=matcher, weight_dict=weight_dict, losses=losses,
        eos_coef=args.eos_coef, temperature=args.temperature,
        span_loss_type=args.span_loss_type, max_v_l=args.max_v_l,
        saliency_margin=args.saliency_margin, use_matcher=use_matcher, args=args
    )
    return model, criterion