File size: 3,687 Bytes
ebf5d87
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import torch
import torch.nn as nn
from baselines.mixture_embedding_experts.model_components import NetVLAD, MaxMarginRankingLoss, GatedEmbeddingUnit
from easydict import EasyDict as edict

# Default MEE hyper-parameters. ``ctx_mode`` is matched by substring: a value
# containing "video" enables the video stream, one containing "sub" enables the
# subtitle stream. ``margin`` is passed to MaxMarginRankingLoss.
mee_base_cfg = edict({
    "ctx_mode": "video",
    "text_input_size": 768,
    "vid_input_size": 1024,
    "output_size": 256,
    "margin": 0.2,
})


class MEE(nn.Module):
    """Mixture of Embedding Experts (MEE) text-to-context retrieval model.

    A NetVLAD-pooled query is projected, per enabled context stream ("video"
    and/or "sub", chosen by substring match on ``config.ctx_mode``), through a
    GatedEmbeddingUnit. It is then scored against the GatedEmbeddingUnit-encoded
    context with a dot product. When both streams are enabled, the two score
    matrices are mixed with query-dependent softmax weights.

    Args:
        config: object exposing ``ctx_mode`` (str), ``text_input_size``,
            ``vid_input_size``, ``output_size`` and ``margin``; see
            ``mee_base_cfg`` for an example.

    Raises:
        ValueError: if ``config.ctx_mode`` enables neither "video" nor "sub".
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Substring match, so a mode string containing both words enables both streams.
        self.use_video = "video" in config.ctx_mode
        self.use_sub = "sub" in config.ctx_mode
        if not (self.use_video or self.use_sub):
            # Without any stream the scores would degenerate to the int 0, which
            # is not a score matrix; fail early with a clear message instead.
            raise ValueError(
                f"config.ctx_mode must contain 'video' and/or 'sub', got {config.ctx_mode!r}")

        # Shared query encoder: pools the (N, L, D_q) query tokens into one vector per query.
        self.query_pooling = NetVLAD(feature_size=config.text_input_size, cluster_size=2)

        if self.use_sub:
            self.sub_query_gu = GatedEmbeddingUnit(input_dimension=self.query_pooling.out_dim,
                                                   output_dimension=config.output_size)
            self.sub_gu = GatedEmbeddingUnit(input_dimension=config.text_input_size,
                                             output_dimension=config.output_size)

        if self.use_video:
            self.video_query_gu = GatedEmbeddingUnit(input_dimension=self.query_pooling.out_dim,
                                                     output_dimension=config.output_size)
            self.video_gu = GatedEmbeddingUnit(input_dimension=config.vid_input_size,
                                               output_dimension=config.output_size)

        if self.use_video and self.use_sub:
            # Per-query gate logits over the two streams (column 0: video, column 1: sub).
            self.moe_fc = nn.Linear(self.query_pooling.out_dim, 2)

        self.max_margin_loss = MaxMarginRankingLoss(margin=config.margin)

    def forward(self, query_feat, query_mask, video_feat, sub_feat):
        """Compute the ranking loss for a batch of N queries and N contexts.

        Args:
            query_feat: (N, L, D_q) query token features.
            query_mask: (N, L). NOTE(review): not used -- NetVLAD pools over all
                L positions, padding included; confirm this is intended.
            video_feat: (N, Dv) video features, used only when "video" is enabled.
            sub_feat: (N, Dt) subtitle features, used only when "sub" is enabled.

        Returns:
            ``MaxMarginRankingLoss`` applied to the (N, N) score matrix whose
            entry (i, j) scores query i against context j (presumably the
            diagonal holds the matching pairs -- confirm against the loss).
        """
        pooled_query = self.query_pooling(query_feat)  # (N, D_pool), D_pool = query_pooling.out_dim
        encoded_video, encoded_sub = self.encode_context(video_feat, sub_feat)
        confusion_matrix = self.get_score_from_pooled_query_with_encoded_ctx(pooled_query, encoded_video, encoded_sub)
        return self.max_margin_loss(confusion_matrix)

    def encode_context(self, video_feat, sub_feat):
        """Encode each enabled context stream with its GatedEmbeddingUnit.

        Args:
            video_feat: (N, Dv) video features.
            sub_feat: (N, Dt) subtitle features.

        Returns:
            Tuple ``(encoded_video, encoded_sub)``; a disabled stream yields None.
        """
        encoded_video = self.video_gu(video_feat) if self.use_video else None
        encoded_sub = self.sub_gu(sub_feat) if self.use_sub else None
        return encoded_video, encoded_sub

    def compute_single_stream_scores_with_encoded_ctx(self, pooled_query, encoded_ctx, module_name="video"):
        """Dot-product scores between queries and one encoded context stream.

        Args:
            pooled_query: (Nq, D_pool) pooled queries.
            encoded_ctx: (Nc, D) encoded contexts of the stream.
            module_name: stream name; selects ``<module_name>_query_gu``.

        Returns:
            (Nq, Nc) score matrix.
        """
        encoded_query = getattr(self, module_name + "_query_gu")(pooled_query)  # (Nq, D)
        return torch.einsum("md,nd->mn", encoded_query, encoded_ctx)  # (Nq, Nc)

    def get_score_from_pooled_query_with_encoded_ctx(self, pooled_query, encoded_video, encoded_sub):
        """Score every pooled query against every encoded context; Nq may differ from Nc.

        Args:
            pooled_query: (Nq, D_pool) pooled queries.
            encoded_video: (Nc, Dc) encoded video, or None if "video" is disabled.
            encoded_sub: (Nc, Dc) encoded subtitles, or None if "sub" is disabled.

        Returns:
            (Nq, Nc) score matrix. With both streams enabled it is, per query, a
            convex combination of the two stream score matrices with weights
            ``softmax(moe_fc(pooled_query))``.
        """
        video_confusion_matrix = self.compute_single_stream_scores_with_encoded_ctx(
            pooled_query, encoded_video, module_name="video") if self.use_video else 0
        sub_confusion_matrix = self.compute_single_stream_scores_with_encoded_ctx(
            pooled_query, encoded_sub, module_name="sub") if self.use_sub else 0

        if self.use_video and self.use_sub:
            # Softmax makes the mixture weights positive and sum to 1 (as in MEE).
            # Raw linear outputs are unbounded and signed, which would let the
            # model rescale or flip each query's scores to game the ranking margin.
            stream_weights = self.moe_fc(pooled_query).softmax(dim=1)  # (Nq, 2)
            return stream_weights[:, 0:1] * video_confusion_matrix + stream_weights[:, 1:2] * sub_confusion_matrix
        # Exactly one stream is enabled; the disabled one contributes the scalar 0.
        return video_confusion_matrix + sub_confusion_matrix