# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


@MODELS.register_module()
class VQAGenerationHead(BaseModule):
"""Generation head for multi-modal pre-trained task, adapted by BLIP.
Normally used for qa generation task (open-set)
Args:
decoder (dict): Decoder for decoding answers.
inference_method (str): Inference method. One of 'rank', 'generate'.
- If 'rank', the model will return answers with the highest
probability from the answer list.
- If 'generate', the model will generate answers.
- Only for test, not for train / val.
num_beams (int): Number of beams for beam search. 1 means no beam
search. Only support when inference_method=='generate'.
Defaults to 3.
num_ans_candidates (int): Number of answer candidates, used to filter
out answers with low probability. Only support when
inference_method=='rank'. Defaults to 128.
loss (dict or nn.Module): Config of loss or module of loss. Defaults to
``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
answer_list_path (str, optional): Path to `answer_list.json`
(json file of a answer list). Required when
inference_method=='rank'.
TODO: `mmcls.LabelSmoothLoss` has not support `ignore_index` param.
Now using `nn.CrossEntropyLoss`, without label_smoothing, in order to
maintain compatibility with torch < 1.10.0
"""
def __init__(
self,
decoder: dict,
inference_method: str = 'generate',
num_beams: int = 3,
num_ans_candidates: int = 128,
loss: Union[dict, nn.Module] = nn.CrossEntropyLoss(
reduction='none', ignore_index=-100),
init_cfg: Optional[dict] = None,
answer_list_path: Optional[str] = None,
) -> None:
super(VQAGenerationHead, self).__init__(init_cfg=init_cfg)
self.decoder = MODELS.build(decoder)
if inference_method == 'generate':
            assert isinstance(num_beams, int), \
                'for VQA `generate` mode, `num_beams` must be an int.'
self.num_beams = num_beams
self.num_ans_candidates = None
self.answer_list = None
elif inference_method == 'rank':
            assert isinstance(num_ans_candidates, int), \
                'for VQA `rank` mode, `num_ans_candidates` must be an int.'
assert isinstance(answer_list_path, str), \
'for VQA `rank` mode, `answer_list_path` must be set as ' \
'the path to `answer_list.json`.'
self.num_beams = None
self.answer_list = mmengine.load(answer_list_path)
if isinstance(self.answer_list, dict):
self.answer_list = list(self.answer_list.keys())
assert isinstance(self.answer_list, list) and all(
isinstance(item, str) for item in self.answer_list), \
'for VQA `rank` mode, `answer_list.json` must be a list of str'
self.num_ans_candidates = min(num_ans_candidates,
len(self.answer_list))
else:
raise AssertionError(
'for VQA, `inference_method` must be "generate" or "rank", '
'got {}.'.format(inference_method))
self.inference_method = inference_method
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
        self.loss_module = loss

    def forward(self, feats: dict):
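        """Forward the answer tokens through the decoder, conditioned on the
        question states, and return token-level prediction logits."""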
prediction_logits = self.decoder(
feats['answer_input_ids'],
attention_mask=feats['answer_attention_mask'],
encoder_hidden_states=feats['question_states'],
encoder_attention_mask=feats['question_atts'],
labels=feats['answer_targets'],
return_dict=True,
return_logits=True, # directly return logits, not computing loss
reduction='none',
)
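        # NOTE: with ``return_logits=True`` the decoder is expected to return
        # prediction scores already shifted by one position (aligned with the
        # next token), hence the label-only shift in ``loss()`` below.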
        return prediction_logits

    def loss(self, feats: dict, data_samples=None):
"""Calculate losses from the extracted features.
Args:
feats (dict): The features extracted from the backbone.
data_samples (List[BaseDataElement]): The annotation data of
every samples.
Returns:
dict[str, Tensor]: a dictionary of loss components
"""
shifted_prediction_scores = self(feats)
labels = feats['answer_targets']
        # we are doing next-token prediction;
        # shift prediction scores and input ids by one
        labels = labels[:, 1:].contiguous()
lm_loss = self.loss_module(
shifted_prediction_scores.view(-1,
self.decoder.med_config.vocab_size),
labels.view(-1))
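        # sum the per-token losses into one loss value per answer sequence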
lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1)
# compute weighted loss
losses = dict()
loss = feats['answer_weight'] * lm_loss
loss = loss.sum() / feats['batch_size']
losses['vqa_loss'] = loss
        return losses

    def predict_rank(self, feats: dict, data_samples=None):
        """Predict answers by ranking a closed-set answer list."""
question_states = feats['multimodal_embeds']
question_atts = feats['question_atts']
answer_candidates = feats['answer_candidates']
assert answer_candidates is not None
answer_ids = answer_candidates.input_ids
answer_atts = answer_candidates.attention_mask
num_ques = question_states.size(0)
start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token
start_output = self.decoder(
start_ids,
encoder_hidden_states=question_states,
encoder_attention_mask=question_atts,
return_dict=True,
reduction='none',
)
logits = start_output.logits[:, 0, :] # first token's logit
# topk_probs: top-k probability
# topk_ids: [num_question, k]
answer_first_token = answer_ids[:, 1]
prob_first_token = F.softmax(
logits, dim=1).index_select(
dim=1, index=answer_first_token)
topk_probs, topk_ids = prob_first_token.topk(
self.num_ans_candidates, dim=1)
# answer input: [num_question*k, answer_len]
input_ids = []
input_atts = []
for b, topk_id in enumerate(topk_ids):
input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
input_ids = torch.cat(input_ids, dim=0)
input_atts = torch.cat(input_atts, dim=0)
targets_ids = input_ids.masked_fill(input_ids == feats['pad_token_id'],
-100)
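
        # ``tile`` repeats every row ``n_tile`` times consecutively along
        # ``dim``, i.e. the same result as ``torch.repeat_interleave``.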
def tile(x, dim, n_tile):
init_dim = x.size(dim)
repeat_idx = [1] * x.dim()
repeat_idx[dim] = n_tile
x = x.repeat(*(repeat_idx))
order_index = torch.LongTensor(
np.concatenate([
init_dim * np.arange(n_tile) + i for i in range(init_dim)
]))
return torch.index_select(x, dim, order_index.to(x.device))
# repeat encoder's output for top-k answers
question_states = tile(question_states, 0, self.num_ans_candidates)
question_atts = tile(question_atts, 0, self.num_ans_candidates)
output = self.decoder(
input_ids,
attention_mask=input_atts,
encoder_hidden_states=question_states,
encoder_attention_mask=question_atts,
labels=targets_ids,
return_dict=True,
reduction='none',
)
log_probs_sum = -output.loss
log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates)
max_topk_ids = log_probs_sum.argmax(dim=1)
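        # ``max_topk_ids >= 0`` is always true; the indexing gathers, for each
        # question, the candidate id sitting in its best-scoring top-k slot.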
max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]
answers = [self.answer_list[max_id] for max_id in max_ids]
        return answers

    def predict_generate(self, feats: dict, data_samples=None):
"""Predict answers in a generation manner."""
device = feats['multimodal_embeds'].device
question_states = feats['multimodal_embeds']
question_atts = torch.ones(
question_states.size()[:-1], dtype=torch.long).to(device)
model_kwargs = {
'encoder_hidden_states': question_states,
'encoder_attention_mask': question_atts
}
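        # the cross-attention states above are passed to every decoding step
        # of ``generate`` through ``**model_kwargs``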
bos_ids = torch.full((feats['multimodal_embeds'].shape[0], 1),
fill_value=feats['bos_token_id'],
device=device)
outputs = self.decoder.generate(
input_ids=bos_ids,
max_length=10,
min_length=1,
num_beams=self.num_beams,
eos_token_id=feats['sep_token_id'],
pad_token_id=feats['pad_token_id'],
**model_kwargs)
        return outputs

    def predict(self, feats: dict, data_samples=None):
"""Predict results from the extracted features."""
if self.inference_method == 'generate':
return self.predict_generate(feats, data_samples)
elif self.inference_method == 'rank':
return self.predict_rank(feats, data_samples)