# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class VQAGenerationHead(BaseModule):
    """Generation head for multi-modal pre-trained task, adapted by BLIP.
    Normally used for qa generation task (open-set)

    Args:
        decoder (dict): Decoder for decoding answers.
        inference_method (str): Inference method. One of 'rank', 'generate'.
            - If 'rank', the model will return answers with the highest
                probability from the answer list.
            - If 'generate', the model will generate answers.
            - Only used for testing, not for training / validation.
        num_beams (int): Number of beams for beam search. 1 means no beam
            search. Only supported when inference_method=='generate'.
            Defaults to 3.
        num_ans_candidates (int): Number of answer candidates, used to filter
            out answers with low probability. Only supported when
            inference_method=='rank'. Defaults to 128.
        loss (dict or nn.Module): Config of loss or module of loss. Defaults to
            ``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
        answer_list_path (str, optional): Path to `answer_list.json`
            (a JSON file containing the answer list). Required when
            inference_method=='rank'.
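
    Examples:
        An illustrative config sketch. The ``decoder`` entry below is only a
        placeholder and must be replaced by a text decoder config registered
        in ``MODELS``::

            dict(
                type='VQAGenerationHead',
                decoder=dict(...),  # placeholder decoder config
                inference_method='generate',
                num_beams=3)

        For the closed-set setting, use ``inference_method='rank'`` and
        provide ``answer_list_path``.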


    TODO: `mmcls.LabelSmoothLoss` does not support the `ignore_index` param
    yet. `nn.CrossEntropyLoss` without `label_smoothing` is used instead, in
    order to maintain compatibility with torch < 1.10.0.
    """

    def __init__(
        self,
        decoder: dict,
        inference_method: str = 'generate',
        num_beams: int = 3,
        num_ans_candidates: int = 128,
        loss: Union[dict, nn.Module] = nn.CrossEntropyLoss(
            reduction='none', ignore_index=-100),
        init_cfg: Optional[dict] = None,
        answer_list_path: Optional[str] = None,
    ) -> None:

        super(VQAGenerationHead, self).__init__(init_cfg=init_cfg)
        self.decoder = MODELS.build(decoder)

        if inference_method == 'generate':
            assert isinstance(num_beams, int), \
                'for VQA `generate` mode, `num_beams` must be an int.'
            self.num_beams = num_beams
            self.num_ans_candidates = None
            self.answer_list = None

        elif inference_method == 'rank':
            assert isinstance(num_ans_candidates, int), \
                'for VQA `rank` mode, `num_ans_candidates` must be an int.'
            assert isinstance(answer_list_path, str), \
                'for VQA `rank` mode, `answer_list_path` must be set as ' \
                'the path to `answer_list.json`.'
            self.num_beams = None
            self.answer_list = mmengine.load(answer_list_path)
            if isinstance(self.answer_list, dict):
                self.answer_list = list(self.answer_list.keys())
            assert isinstance(self.answer_list, list) and all(
                isinstance(item, str) for item in self.answer_list), \
                'for VQA `rank` mode, `answer_list.json` must be a list of str'
            self.num_ans_candidates = min(num_ans_candidates,
                                          len(self.answer_list))

        else:
            raise AssertionError(
                'for VQA, `inference_method` must be "generate" or "rank", '
                'got {}.'.format(inference_method))

        self.inference_method = inference_method
        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module = loss

    def forward(self, feats: dict):
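        """Forward the answer tokens through the decoder to get logits.

        Args:
            feats (dict): The features extracted from the backbone. Requires
                the keys ``answer_input_ids``, ``answer_attention_mask``,
                ``question_states``, ``question_atts`` and
                ``answer_targets``.

        Returns:
            torch.Tensor: The prediction logits of the decoder, consumed by
            :meth:`loss` as next-token prediction scores.
        """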
        prediction_logits = self.decoder(
            feats['answer_input_ids'],
            attention_mask=feats['answer_attention_mask'],
            encoder_hidden_states=feats['question_states'],
            encoder_attention_mask=feats['question_atts'],
            labels=feats['answer_targets'],
            return_dict=True,
            return_logits=True,  # directly return logits, not computing loss
            reduction='none',
        )
        return prediction_logits

    def loss(self, feats: dict, data_samples=None):
        """Calculate losses from the extracted features.

        Args:
            feats (dict): The features extracted from the backbone.
            data_samples (List[BaseDataElement]): The annotation data of
                every sample.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        shifted_prediction_scores = self(feats)
        labels = feats['answer_targets']
        lm_loss = None

        # we are doing next-token prediction;
        # shift prediction scores and input ids by one
        labels = labels[:, 1:].contiguous()
        lm_loss = self.loss_module(
            shifted_prediction_scores.view(-1,
                                           self.decoder.med_config.vocab_size),
            labels.view(-1))
        lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1)
        # compute the weighted loss: sum each answer's token-level losses,
        # scale by its per-answer weight and normalize by the batch size
        losses = dict()
        loss = feats['answer_weight'] * lm_loss
        loss = loss.sum() / feats['batch_size']
        losses['vqa_loss'] = loss

        return losses

    def predict_rank(self, feats: dict, data_samples=None):
        """Predict rank in a close-set answer list."""
        question_states = feats['multimodal_embeds']
        question_atts = feats['question_atts']
        answer_candidates = feats['answer_candidates']
        assert answer_candidates is not None
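
        # Pre-selection stage: run the decoder on only the bos token and keep
        # the ``num_ans_candidates`` answers whose first token is assigned the
        # highest probability.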

        answer_ids = answer_candidates.input_ids
        answer_atts = answer_candidates.attention_mask
        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.decoder(
            start_ids,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            return_dict=True,
            reduction='none',
        )
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(
            logits, dim=1).index_select(
                dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(
            self.num_ans_candidates, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for b, topk_id in enumerate(topk_ids):
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)
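
        # Re-ranking stage: score the selected candidate answers in full with
        # the decoder's LM loss; the candidate with the highest summed
        # log-probability is returned for each question.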

        targets_ids = input_ids.masked_fill(input_ids == feats['pad_token_id'],
                                            -100)

        def tile(x, dim, n_tile):
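            # Repeat every slice of ``x`` along ``dim`` ``n_tile`` times in a
            # row, i.e. the same result as
            # ``torch.repeat_interleave(x, n_tile, dim=dim)``.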
            init_dim = x.size(dim)
            repeat_idx = [1] * x.dim()
            repeat_idx[dim] = n_tile
            x = x.repeat(*(repeat_idx))
            order_index = torch.LongTensor(
                np.concatenate([
                    init_dim * np.arange(n_tile) + i for i in range(init_dim)
                ]))
            return torch.index_select(x, dim, order_index.to(x.device))

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, self.num_ans_candidates)
        question_atts = tile(question_atts, 0, self.num_ans_candidates)

        output = self.decoder(
            input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=targets_ids,
            return_dict=True,
            reduction='none',
        )

        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates)

        max_topk_ids = log_probs_sum.argmax(dim=1)
        max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]

        answers = [self.answer_list[max_id] for max_id in max_ids]

        return answers

    def predict_generate(self, feats: dict, data_samples=None):
        """Predict answers in a generation manner."""
        device = feats['multimodal_embeds'].device
        question_states = feats['multimodal_embeds']
        question_atts = torch.ones(
            question_states.size()[:-1], dtype=torch.long).to(device)
        model_kwargs = {
            'encoder_hidden_states': question_states,
            'encoder_attention_mask': question_atts
        }

        bos_ids = torch.full((feats['multimodal_embeds'].shape[0], 1),
                             fill_value=feats['bos_token_id'],
                             device=device)
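
        # Beam-search decode a short answer for each question, using the
        # multimodal question states as cross-attention context and stopping
        # at the sep token.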

        outputs = self.decoder.generate(
            input_ids=bos_ids,
            max_length=10,
            min_length=1,
            num_beams=self.num_beams,
            eos_token_id=feats['sep_token_id'],
            pad_token_id=feats['pad_token_id'],
            **model_kwargs)

        return outputs

    def predict(self, feats: dict, data_samples=None):
        """Predict results from the extracted features."""
        if self.inference_method == 'generate':
            return self.predict_generate(feats, data_samples)
        elif self.inference_method == 'rank':
            return self.predict_rank(feats, data_samples)