# mmpretrain/models/heads/vqa_head.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS


@MODELS.register_module()
class VQAGenerationHead(BaseModule):
"""Generation head for multi-modal pre-trained task, adapted by BLIP.
Normally used for qa generation task (open-set)
Args:
decoder (dict): Decoder for decoding answers.
inference_method (str): Inference method. One of 'rank', 'generate'.
- If 'rank', the model will return answers with the highest
probability from the answer list.
- If 'generate', the model will generate answers.
- Only for test, not for train / val.
num_beams (int): Number of beams for beam search. 1 means no beam
search. Only support when inference_method=='generate'.
Defaults to 3.
num_ans_candidates (int): Number of answer candidates, used to filter
out answers with low probability. Only support when
inference_method=='rank'. Defaults to 128.
loss (dict or nn.Module): Config of loss or module of loss. Defaults to
``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``.
init_cfg (dict, optional): the config to control the initialization.
Defaults to None.
        answer_list_path (str, optional): Path to `answer_list.json` (a JSON
            file containing the list of candidate answers). Required when
            inference_method=='rank'.

    TODO: `mmcls.LabelSmoothLoss` does not support the `ignore_index` param
    yet, so `nn.CrossEntropyLoss` without label smoothing is used here to
    maintain compatibility with torch < 1.10.0.
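
    Example:
        A minimal config sketch. The decoder type, its `med_config` value
        and the answer-list path below are illustrative assumptions, not
        values taken from this file:

        >>> head_cfg = dict(
        ...     type='VQAGenerationHead',
        ...     decoder=dict(type='XBertLMHeadDecoder', med_config=...),
        ...     inference_method='rank',
        ...     num_ans_candidates=128,
        ...     answer_list_path='data/vqa/answer_list.json')
        >>> head = MODELS.build(head_cfg)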
"""

    def __init__(
self,
decoder: dict,
inference_method: str = 'generate',
num_beams: int = 3,
num_ans_candidates: int = 128,
loss: Union[dict, nn.Module] = nn.CrossEntropyLoss(
reduction='none', ignore_index=-100),
init_cfg: Optional[dict] = None,
answer_list_path: Optional[str] = None,
) -> None:
super(VQAGenerationHead, self).__init__(init_cfg=init_cfg)
self.decoder = MODELS.build(decoder)
if inference_method == 'generate':
assert isinstance(num_beams, int), \
                'for VQA `generate` mode, `num_beams` must be an int.'
self.num_beams = num_beams
self.num_ans_candidates = None
self.answer_list = None
elif inference_method == 'rank':
assert isinstance(num_ans_candidates, int), \
                'for VQA `rank` mode, `num_ans_candidates` must be an int.'
assert isinstance(answer_list_path, str), \
'for VQA `rank` mode, `answer_list_path` must be set as ' \
'the path to `answer_list.json`.'
self.num_beams = None
self.answer_list = mmengine.load(answer_list_path)
if isinstance(self.answer_list, dict):
self.answer_list = list(self.answer_list.keys())
assert isinstance(self.answer_list, list) and all(
isinstance(item, str) for item in self.answer_list), \
'for VQA `rank` mode, `answer_list.json` must be a list of str'
self.num_ans_candidates = min(num_ans_candidates,
len(self.answer_list))
else:
raise AssertionError(
'for VQA, `inference_method` must be "generate" or "rank", '
'got {}.'.format(inference_method))
self.inference_method = inference_method
if not isinstance(loss, nn.Module):
loss = MODELS.build(loss)
self.loss_module = loss

    def forward(self, feats: dict):
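        """Compute prediction logits for the answer token sequence."""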
prediction_logits = self.decoder(
feats['answer_input_ids'],
attention_mask=feats['answer_attention_mask'],
encoder_hidden_states=feats['question_states'],
encoder_attention_mask=feats['question_atts'],
labels=feats['answer_targets'],
return_dict=True,
return_logits=True, # directly return logits, not computing loss
reduction='none',
)
return prediction_logits

    def loss(self, feats: dict, data_samples=None):
        """Calculate losses from the extracted features.

        Args:
            feats (dict): The features extracted from the backbone.
            data_samples (List[BaseDataElement]): The annotation data of
                every sample.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
shifted_prediction_scores = self(feats)
labels = feats['answer_targets']
# we are doing next-token prediction;
# shift prediction scores and input ids by one
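        # (with `return_logits=True` the decoder is assumed to return scores
        # already truncated to positions [:-1], so only the labels need
        # shifting here)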
labels = labels[:, 1:].contiguous()
lm_loss = self.loss_module(
shifted_prediction_scores.view(-1,
self.decoder.med_config.vocab_size),
labels.view(-1))
lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1)
# compute weighted loss
losses = dict()
loss = feats['answer_weight'] * lm_loss
loss = loss.sum() / feats['batch_size']
losses['vqa_loss'] = loss
return losses

    def predict_rank(self, feats: dict, data_samples=None):
        """Predict the answer by ranking a closed-set answer list."""
question_states = feats['multimodal_embeds']
question_atts = feats['question_atts']
answer_candidates = feats['answer_candidates']
assert answer_candidates is not None
answer_ids = answer_candidates.input_ids
answer_atts = answer_candidates.attention_mask
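        # two-stage ranking, following BLIP: (1) score only the first answer
        # token of every candidate and keep the `num_ans_candidates` most
        # likely ones; (2) fully decode those candidates and re-rank them by
        # their language-modeling likelihood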
num_ques = question_states.size(0)
start_ids = answer_ids[0, 0].repeat(num_ques, 1) # bos token
start_output = self.decoder(
start_ids,
encoder_hidden_states=question_states,
encoder_attention_mask=question_atts,
return_dict=True,
reduction='none',
)
logits = start_output.logits[:, 0, :] # first token's logit
# topk_probs: top-k probability
# topk_ids: [num_question, k]
answer_first_token = answer_ids[:, 1]
prob_first_token = F.softmax(
logits, dim=1).index_select(
dim=1, index=answer_first_token)
topk_probs, topk_ids = prob_first_token.topk(
self.num_ans_candidates, dim=1)
# answer input: [num_question*k, answer_len]
input_ids = []
input_atts = []
for b, topk_id in enumerate(topk_ids):
input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
input_ids = torch.cat(input_ids, dim=0)
input_atts = torch.cat(input_atts, dim=0)
targets_ids = input_ids.masked_fill(input_ids == feats['pad_token_id'],
-100)

        def tile(x, dim, n_tile):
init_dim = x.size(dim)
repeat_idx = [1] * x.dim()
repeat_idx[dim] = n_tile
x = x.repeat(*(repeat_idx))
order_index = torch.LongTensor(
np.concatenate([
init_dim * np.arange(n_tile) + i for i in range(init_dim)
]))
return torch.index_select(x, dim, order_index.to(x.device))
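
        # note: `tile(x, dim, n)` matches `x.repeat_interleave(n, dim=dim)`
        # in modern torch; kept as-is for backward compatibility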
# repeat encoder's output for top-k answers
question_states = tile(question_states, 0, self.num_ans_candidates)
question_atts = tile(question_atts, 0, self.num_ans_candidates)
output = self.decoder(
input_ids,
attention_mask=input_atts,
encoder_hidden_states=question_states,
encoder_attention_mask=question_atts,
labels=targets_ids,
return_dict=True,
reduction='none',
)
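        # `reduction='none'` is assumed to make the decoder return one summed
        # token loss per sequence, so its negation is each candidate answer's
        # log-likelihood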
log_probs_sum = -output.loss
log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates)
max_topk_ids = log_probs_sum.argmax(dim=1)
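        # mixed boolean/integer indexing: `max_topk_ids >= 0` is always True,
        # so this equals topk_ids[torch.arange(num_ques), max_topk_ids]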
max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]
answers = [self.answer_list[max_id] for max_id in max_ids]
return answers

    def predict_generate(self, feats: dict, data_samples=None):
        """Predict answers by autoregressive generation."""
device = feats['multimodal_embeds'].device
question_states = feats['multimodal_embeds']
question_atts = torch.ones(
question_states.size()[:-1], dtype=torch.long).to(device)
model_kwargs = {
'encoder_hidden_states': question_states,
'encoder_attention_mask': question_atts
}
bos_ids = torch.full((feats['multimodal_embeds'].shape[0], 1),
fill_value=feats['bos_token_id'],
device=device)
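        # beam-search decoding; the small hard-coded max_length reflects that
        # VQA answers are typically only a few tokens long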
outputs = self.decoder.generate(
input_ids=bos_ids,
max_length=10,
min_length=1,
num_beams=self.num_beams,
eos_token_id=feats['sep_token_id'],
pad_token_id=feats['pad_token_id'],
**model_kwargs)
return outputs

    def predict(self, feats: dict, data_samples=None):
"""Predict results from the extracted features."""
if self.inference_method == 'generate':
return self.predict_generate(feats, data_samples)
elif self.inference_method == 'rank':
return self.predict_rank(feats, data_samples)
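

# A minimal usage sketch for the `generate` mode. The feature keys follow the
# code above, but the surrounding pipeline (tokenizer, fused question-image
# embeddings) is an assumption, not part of this file:
#
#   feats = dict(
#       multimodal_embeds=fused_feats,       # [B, L, C] encoder states
#       bos_token_id=tokenizer.bos_token_id,
#       sep_token_id=tokenizer.sep_token_id,
#       pad_token_id=tokenizer.pad_token_id)
#   answer_token_ids = head.predict(feats)   # decoded answer token ids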