# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import mmengine
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModule

from mmpretrain.registry import MODELS


class VQAGenerationHead(BaseModule):
    """Generation head for multi-modal pre-trained task, adopted by BLIP.

    Normally used for open-set question answering (answer generation) tasks.

    Args:
        decoder (dict): Decoder for decoding answers.
        inference_method (str): Inference method. One of 'rank', 'generate'.
            - If 'rank', the model will return answers with the highest
              probability from the answer list.
            - If 'generate', the model will generate answers.
            - Only used at test time, not for train / val.
        num_beams (int): Number of beams for beam search. 1 means no beam
            search. Only supported when ``inference_method=='generate'``.
            Defaults to 3.
        num_ans_candidates (int): Number of answer candidates, used to filter
            out answers with low probability. Only supported when
            ``inference_method=='rank'``. Defaults to 128.
        loss (dict or nn.Module): Config of loss or module of loss. Defaults
            to ``nn.CrossEntropyLoss(reduction='none', ignore_index=-100)``.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
        answer_list_path (str, optional): Path to ``answer_list.json``
            (a json file containing the answer list). Required when
            ``inference_method=='rank'``.

    TODO: ``mmcls.LabelSmoothLoss`` does not support the ``ignore_index``
    param yet. ``nn.CrossEntropyLoss`` is used without label smoothing for
    now, to maintain compatibility with torch < 1.10.0.
    """

    def __init__(
        self,
        decoder: dict,
        inference_method: str = 'generate',
        num_beams: int = 3,
        num_ans_candidates: int = 128,
        loss: Union[dict, nn.Module] = nn.CrossEntropyLoss(
            reduction='none', ignore_index=-100),
        init_cfg: Optional[dict] = None,
        answer_list_path: Optional[str] = None,
    ) -> None:
        super(VQAGenerationHead, self).__init__(init_cfg=init_cfg)
        self.decoder = MODELS.build(decoder)

        if inference_method == 'generate':
            assert isinstance(num_beams, int), \
                'for VQA `generate` mode, `num_beams` must be an int.'
            self.num_beams = num_beams
            self.num_ans_candidates = None
            self.answer_list = None

        elif inference_method == 'rank':
            assert isinstance(num_ans_candidates, int), \
                'for VQA `rank` mode, `num_ans_candidates` must be an int.'
            assert isinstance(answer_list_path, str), \
                'for VQA `rank` mode, `answer_list_path` must be set as ' \
                'the path to `answer_list.json`.'
            self.num_beams = None
            self.answer_list = mmengine.load(answer_list_path)
            if isinstance(self.answer_list, dict):
                self.answer_list = list(self.answer_list.keys())
            assert isinstance(self.answer_list, list) and all(
                isinstance(item, str) for item in self.answer_list), \
                'for VQA `rank` mode, `answer_list.json` must be a list of str'
            self.num_ans_candidates = min(num_ans_candidates,
                                          len(self.answer_list))

        else:
            raise AssertionError(
                'for VQA, `inference_method` must be "generate" or "rank", '
                'got {}.'.format(inference_method))

        self.inference_method = inference_method
        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module = loss
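
    # Hypothetical construction example (illustrative only; the decoder type
    # and config values are assumptions, not taken from this file):
    #   head = VQAGenerationHead(
    #       decoder=dict(type='XBertLMHeadDecoder', med_config=...),
    #       inference_method='rank',
    #       num_ans_candidates=128,
    #       answer_list_path='path/to/answer_list.json')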

    def forward(self, feats: dict):
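        """Get next-token prediction logits from the extracted features.

        With ``return_logits=True`` the decoder returns the prediction
        scores directly (already shifted for next-token prediction)
        instead of computing the loss itself.
        """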
        prediction_logits = self.decoder(
            feats['answer_input_ids'],
            attention_mask=feats['answer_attention_mask'],
            encoder_hidden_states=feats['question_states'],
            encoder_attention_mask=feats['question_atts'],
            labels=feats['answer_targets'],
            return_dict=True,
            return_logits=True,  # directly return logits, not computing loss
            reduction='none',
        )
        return prediction_logits

    def loss(self, feats: dict, data_samples=None):
        """Calculate losses from the extracted features.

        Args:
            feats (dict): The features extracted from the backbone.
            data_samples (List[BaseDataElement]): The annotation data of
                each sample.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        shifted_prediction_scores = self(feats)
        labels = feats['answer_targets']
        # we are doing next-token prediction;
        # shift prediction scores and input ids by one
        labels = labels[:, 1:].contiguous()
        lm_loss = self.loss_module(
            shifted_prediction_scores.view(
                -1, self.decoder.med_config.vocab_size),
            labels.view(-1))
        lm_loss = lm_loss.view(shifted_prediction_scores.size(0), -1).sum(1)

        # compute weighted loss
        losses = dict()
        loss = feats['answer_weight'] * lm_loss
        loss = loss.sum() / feats['batch_size']
        losses['vqa_loss'] = loss
        return losses

    def predict_rank(self, feats: dict, data_samples=None):
        """Predict answers by ranking a closed-set answer list."""
        question_states = feats['multimodal_embeds']
        question_atts = feats['question_atts']
        answer_candidates = feats['answer_candidates']
        assert answer_candidates is not None

        answer_ids = answer_candidates.input_ids
        answer_atts = answer_candidates.attention_mask
        num_ques = question_states.size(0)
        start_ids = answer_ids[0, 0].repeat(num_ques, 1)  # bos token

        start_output = self.decoder(
            start_ids,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            return_dict=True,
            reduction='none',
        )
        logits = start_output.logits[:, 0, :]  # first token's logit

        # topk_probs: top-k probability
        # topk_ids: [num_question, k]
        answer_first_token = answer_ids[:, 1]
        prob_first_token = F.softmax(
            logits, dim=1).index_select(
                dim=1, index=answer_first_token)
        topk_probs, topk_ids = prob_first_token.topk(
            self.num_ans_candidates, dim=1)

        # answer input: [num_question*k, answer_len]
        input_ids = []
        input_atts = []
        for topk_id in topk_ids:
            input_ids.append(answer_ids.index_select(dim=0, index=topk_id))
            input_atts.append(answer_atts.index_select(dim=0, index=topk_id))
        input_ids = torch.cat(input_ids, dim=0)
        input_atts = torch.cat(input_atts, dim=0)
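        # Replace padding with -100, the loss's ignore index, so pad tokens
        # do not contribute to each candidate's log-likelihood.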
        targets_ids = input_ids.masked_fill(
            input_ids == feats['pad_token_id'], -100)

        def tile(x, dim, n_tile):
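            # Interleaves `n_tile` copies of each slice along `dim`
            # (a, b -> a, a, b, b); equivalent to
            # ``torch.repeat_interleave(x, n_tile, dim)``.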
            init_dim = x.size(dim)
            repeat_idx = [1] * x.dim()
            repeat_idx[dim] = n_tile
            x = x.repeat(*(repeat_idx))
            order_index = torch.LongTensor(
                np.concatenate([
                    init_dim * np.arange(n_tile) + i for i in range(init_dim)
                ]))
            return torch.index_select(x, dim, order_index.to(x.device))

        # repeat encoder's output for top-k answers
        question_states = tile(question_states, 0, self.num_ans_candidates)
        question_atts = tile(question_atts, 0, self.num_ans_candidates)

        output = self.decoder(
            input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
            labels=targets_ids,
            return_dict=True,
            reduction='none',
        )

        log_probs_sum = -output.loss
        log_probs_sum = log_probs_sum.view(num_ques, self.num_ans_candidates)

        max_topk_ids = log_probs_sum.argmax(dim=1)
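        # `max_topk_ids >= 0` is an all-True mask, so this advanced indexing
        # is equivalent to ``topk_ids[torch.arange(num_ques), max_topk_ids]``:
        # pick, for each question, the candidate with the best full score.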
        max_ids = topk_ids[max_topk_ids >= 0, max_topk_ids]

        answers = [self.answer_list[max_id] for max_id in max_ids]
        return answers

    def predict_generate(self, feats: dict, data_samples=None):
        """Predict answers in a generation manner."""
        device = feats['multimodal_embeds'].device
        question_states = feats['multimodal_embeds']
        question_atts = torch.ones(
            question_states.size()[:-1], dtype=torch.long).to(device)
        model_kwargs = {
            'encoder_hidden_states': question_states,
            'encoder_attention_mask': question_atts
        }
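        # `generate` forwards these kwargs to the decoder, so every decoding
        # step (and every beam) cross-attends to the fused question states.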
        bos_ids = torch.full((feats['multimodal_embeds'].shape[0], 1),
                             fill_value=feats['bos_token_id'],
                             device=device)
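        # Decode from a single bos token per question; VQA answers are short,
        # hence the small `max_length`. The sep token acts as eos here.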
        outputs = self.decoder.generate(
            input_ids=bos_ids,
            max_length=10,
            min_length=1,
            num_beams=self.num_beams,
            eos_token_id=feats['sep_token_id'],
            pad_token_id=feats['pad_token_id'],
            **model_kwargs)
        return outputs

    def predict(self, feats: dict, data_samples=None):
        """Predict results from the extracted features."""
        if self.inference_method == 'generate':
            return self.predict_generate(feats, data_samples)
        elif self.inference_method == 'rank':
            return self.predict_rank(feats, data_samples)
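

if __name__ == '__main__':
    # A minimal, self-contained sketch (illustrative only, with made-up
    # tensors and no decoder) of the first-stage filtering performed in
    # `predict_rank`: keep the k candidate answers whose *first* token is
    # most probable, before fully re-scoring them with the decoder.
    num_questions, vocab_size, k = 2, 10, 2
    first_token_logits = torch.randn(num_questions, vocab_size)
    # hypothetical first token id of each of three candidate answers
    answer_first_token = torch.tensor([3, 5, 7])
    prob_first_token = F.softmax(
        first_token_logits, dim=1).index_select(
            dim=1, index=answer_first_token)
    topk_probs, topk_ids = prob_first_token.topk(k, dim=1)
    print(topk_ids)  # per-question indices into the candidate answer list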