File size: 9,135 Bytes
345ee20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
# Based on MultiMAE (M-MAE)
# https://github.com/EPFL-VILAB/MultiMAE/blob/main/multimae/input_adapters.py
import os
import pdb
from typing import Dict, List, Optional, Tuple, Union
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from einops import rearrange, repeat
from core.utils import NestedTensor
from dict_recursive_update import recursive_update
from timm.models.layers import drop_path, to_2tuple, trunc_normal_

class TextAdapter(nn.Module):
    """Input adapter that produces text tokens for the downstream encoder.

    The adapter runs in one of two modes, selected by ``image_caption``:

    * ``image_caption=False``: pre-computed description embeddings are loaded
      from ``./<description_dict_name>.pth`` at construction time and are
      looked up / broadcast according to ``input_var['label']`` in ``forward``.
    * ``image_caption=True``: a placeholder; ``forward`` emits an all-zero
      tensor of width ``embed_dim`` shaped after ``input_var['input_id']``.

    :param patch_size: placeholder, kept so the kwargs match the other adapters
    :param in_chans: placeholder, kept so the kwargs match the other adapters
    :param stride_level: placeholder, kept so the kwargs match the other adapters
    :param task_sp_list: list of task specific names for DDP communication. Default: ()
    :param modality_share_list: stored on the module; not used inside this class
    :param embed_dim: width of the emitted tokens
    :param description_dict_name: name (or list of names) of the ``.pth`` file(s)
        holding the description embeddings; several files are concatenated on dim 0
    :param one_way_semantics: only consider the positive semantics, i.e. the
        entry at value index 1 of every description
    :param skeleton_action: whether the label is a skeleton action label (e.g.
        ntu60); every sample then contains ``people_cnt`` people, so the
        effective batch size is ``B * people_cnt``
    :param skeleton_action_one_hot_label: whether skeleton labels are one-hot
        (multi-dimensional) instead of class indices
    :param people_cnt: number of people per skeleton sample
    :param image_caption: switch to the caption placeholder mode
    :param max_tokens: number of caption tokens, only used for ``patch_shape``
    """

    def __init__(self,
                 patch_size=[None, None],  # placeholder, achieve similar kwargs with other adapters
                 in_chans=None,  # placeholder, achieve similar kwargs with other adapters
                 stride_level=None,  # placeholder, achieve similar kwargs with other adapters
                 task_sp_list=(),
                 modality_share_list=(),
                 embed_dim=768,
                 description_dict_name='rap2_attr_name',
                 one_way_semantics=True,
                 skeleton_action=False,
                 skeleton_action_one_hot_label=False,
                 people_cnt=2,
                 image_caption=False,
                 max_tokens=30,
                 ):
        super(TextAdapter, self).__init__()

        self.patch_size = patch_size
        self.in_chans = in_chans
        self.stride_level = stride_level
        self.task_sp_list = task_sp_list
        self.modality_share_list = modality_share_list
        self.one_way_semantics = one_way_semantics
        self.skeleton_action = skeleton_action
        self.skeleton_action_one_hot_label = skeleton_action_one_hot_label
        self.people_cnt = people_cnt
        self.image_caption = image_caption
        self.embed_dim = embed_dim

        if not self.image_caption:
            text_vectors = self._load_text_vectors(description_dict_name)
            # Frozen parameter: it follows the module across devices and is
            # stored in the state dict, but is never trained.
            self.text_vectors = nn.Parameter(text_vectors.detach().clone(), requires_grad=False)
            num_tokens = self.text_vectors.shape[0]
            self.patch_shape = [1, num_tokens]
            if self.skeleton_action and not self.skeleton_action_one_hot_label:
                # class-index skeleton labels yield a single token per person
                self.patch_shape = [1, 1]
        else:
            self.patch_shape = [1, max_tokens]

    @staticmethod
    def _load_text_vectors(description_dict_name):
        """Load the description embeddings; a list of names is concatenated on dim 0."""
        # NOTE(review): torch.load unpickles the file, so the .pth files must
        # come from a trusted source.
        if isinstance(description_dict_name, list):
            loaded = [torch.load(f'./{this_dict_name}.pth') for this_dict_name in description_dict_name]
            return torch.cat(loaded, dim=0)
        return torch.load(f'./{description_dict_name}.pth')

    def forward(self, input_var):
        """Attach the text tokens to ``input_var``.

        :param input_var: dict holding ``'label'`` (label mode) or
            ``'input_id'`` (caption mode)
        :return: the same dict, updated in place with
            ``'adapter_output_text'``: {'tokens': text_embedding,
                                        'Bs': B (times people_cnt for skeleton action),
                                        'N_H': 1,
                                        'N_W': text_embedding.shape[1],
                                        'attn_mask': None}
        """
        output = {}

        if not self.image_caption:
            label = input_var['label']  # bs, num_cls of attribute
            B = label.shape[0]
            if len(label.shape) > 1:
                if self.one_way_semantics:
                    # Only the positive semantics matter, so an all-one index is
                    # used in place of the real label values.
                    positive = torch.ones_like(label, dtype=torch.long)
                    attr_index = torch.arange(len(self.text_vectors), device=self.text_vectors.device)
                    text_embedding = self.text_vectors[attr_index, positive]  # bs, num_cls of attr, bert_dim
                else:
                    text_embedding = self.text_vectors[None, :, :].repeat(B, 1, 1)  # bs, cls-num, dim
                if self.skeleton_action:
                    assert self.skeleton_action_one_hot_label
                    # Every sample holds people_cnt people: replicate the
                    # tokens per person and fold the people into the batch.
                    text_embedding = text_embedding[:, None, ...]
                    text_embedding = text_embedding.repeat(1, self.people_cnt, 1, 1)
                    text_embedding = text_embedding.reshape(
                        B * self.people_cnt, text_embedding.shape[-2], text_embedding.shape[-1])
            else:
                # label example: tensor([ 8, 33, 39, 34,  6, 14, 1 ]), a single classification label
                assert self.one_way_semantics
                positive_vectors = self.text_vectors[:, 1]  # num_cls, bert_dim
                text_embedding = positive_vectors[label.long()][:, None, :]  # bs, 1, bert_dim
                if self.skeleton_action:
                    # Use people_cnt (not a literal 2) so the token batch
                    # always agrees with the 'Bs' value reported below.
                    text_embedding = text_embedding.repeat(1, self.people_cnt, 1)  # bs, people_cnt, bert_dim
                    text_embedding = text_embedding.reshape(
                        B * self.people_cnt, 1, text_embedding.shape[-1])  # bs*people_cnt, 1, bert_dim
        else:
            token_id = input_var['input_id']
            B = token_id.shape[0]
            # In training mode the sequence is one token shorter than the ids.
            seq_len = token_id[:, :-1].shape[1] if self.training else token_id.shape[1]
            # Allocate next to the ids instead of forcing CUDA, so the adapter
            # also works on CPU-only hosts and on non-default GPUs.
            text_embedding = torch.zeros(B, seq_len, self.embed_dim, device=token_id.device)

        if text_embedding.shape[-1] != self.embed_dim:
            # Width mismatch between the stored embeddings and embed_dim: fall
            # back to a zero placeholder of the expected width.
            text_embedding = torch.zeros(text_embedding.shape[0], text_embedding.shape[1], self.embed_dim,
                                         device=text_embedding.device)
        output['adapter_output_text'] = {'tokens': text_embedding,
                                         'Bs': B if not self.skeleton_action else B * self.people_cnt,
                                         'N_H': 1,  # the text is treated as a 2-D sequence of shape [1, num_tokens]
                                         'N_W': text_embedding.shape[1],
                                         'attn_mask': None}

        input_var.update(output)

        return input_var

def extract_bert_features(rap2_attr_name, bert_tokenizer, bert_model, dim=768):
    """Extract the bert features for all sentences in rap2_attr_name.

    :param rap2_attr_name: the dictionary of all attribute names, laid out as
        ``{attr_idx: {val_idx: sentence}}``; both indices are used to index
        the output tensor, so they must be non-negative integers and
        ``attr_idx`` must be smaller than ``len(rap2_attr_name)``
    :param bert_tokenizer: tokenizer of bert, called on the full sentence list
    :param bert_model: bert model, called with the tokenizer output as kwargs;
        element ``[1]`` of its output is used as the sentence representation
    :param dim: feature width of the sentence representation
    :return: rap2_attr_vectors: tensor of shape
        ``[len(rap2_attr_name), max_val_idx + 1, dim]`` holding the feature of
        every sentence at ``[attr_idx, val_idx]``; missing slots stay zero
    """
    # Record the destination slot of every sentence while collecting it, so
    # the write-back below never depends on the dictionaries' key order.
    sentences = []
    slots = []
    for attr_idx, attr_dict in rap2_attr_name.items():
        for val_idx, attr_desc in attr_dict.items():
            sentences.append(attr_desc)
            slots.append((attr_idx, val_idx))

    # Size the value axis from the largest key actually present, so sparse
    # value indices still fit.
    max_val_idx = max((val_idx for _, val_idx in slots), default=0)
    rap2_attr_vectors = torch.zeros([len(rap2_attr_name), max_val_idx + 1, dim])
    if not sentences:
        # nothing to encode; skip the tokenizer/model calls entirely
        return rap2_attr_vectors

    # Encode the whole batch of sentences in a single forward pass.
    input_ids = bert_tokenizer(sentences, return_tensors='pt', padding=True, truncation=True)
    with torch.no_grad():
        last_hidden_states = bert_model(**input_ids)

    # use the pooler_output as the representation of the sentence
    pooled = last_hidden_states[1]
    for row, (attr_idx, val_idx) in enumerate(slots):
        rap2_attr_vectors[attr_idx, val_idx] = pooled[row]
    return rap2_attr_vectors

def text_adapter(pretrained=True, **kwargs):
    """Factory for :class:`TextAdapter`.

    :param pretrained: accepted for signature parity with other factories;
        it is not used, there is no pretrained model for the text adapter
    :param kwargs: constructor arguments, merged into an empty config with
        ``recursive_update`` and forwarded to ``TextAdapter``
    :return: the constructed text adapter
    """
    adapter_kwargs = {}
    # recursive_update fills adapter_kwargs in place from the caller's kwargs
    recursive_update(adapter_kwargs, kwargs)
    return TextAdapter(**adapter_kwargs)