Files changed (1) hide show
  1. tasks/mm_tasks/snli_ve.py +197 -0
tasks/mm_tasks/snli_ve.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 The OFA-Sys Team.
2
+ # All rights reserved.
3
+ # This source code is licensed under the Apache 2.0 license
4
+ # found in the LICENSE file in the root directory.
5
+
6
+ import json
7
+ import logging
8
+ import math
9
+ from dataclasses import dataclass, field
10
+ from typing import Optional
11
+
12
+ import torch
13
+ from fairseq import metrics
14
+ from fairseq.tasks import register_task
15
+
16
+ from tasks.ofa_task import OFAConfig, OFATask
17
+ from data.mm_data.snli_ve_dataset import SnliVeDataset
18
+ from data.file_dataset import FileDataset
19
+ from data import data_utils
20
+ from utils.trie import Trie
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
@dataclass
class SnliVeConfig(OFAConfig):
    """Command-line/config options for the ``snli_ve`` task.

    Extends ``OFAConfig`` with the four task-specific fields declared below.
    """

    # JSON-encoded mapping from answer string to integer label; the task
    # decodes it with ``json.loads`` at construction time.
    ans2label_dict: Optional[str] = field(
        default='{"no": 0, "yes":1, "maybe": 2}', metadata={"help": 'answer to label dict'}
    )
    # Forwarded unchanged to the dataset as ``add_caption``.
    add_caption: bool = field(default=False, metadata={"help": "add caption to encoder"})
    # Number of candidate answers scored per decoder pass during validation.
    valid_batch_size: int = field(default=20, metadata={"help": "valid batch size per step"})
    # Forwarded unchanged to the dataset as ``prompt_type``.
    prompt_type: Optional[str] = field(default=None, metadata={"help": "prompt_type"})
43
+
44
+
45
@register_task("snli_ve", dataclass=SnliVeConfig)
class SnliVeTask(OFATask):
    """Task registered as ``snli_ve``.

    Validation scores every candidate answer from ``ans2label_dict`` with the
    decoder and picks the highest-scoring one per sample.
    """

    def __init__(self, cfg: SnliVeConfig, src_dict, tgt_dict):
        """Initialise the base task and decode the answer->label JSON mapping."""
        super().__init__(cfg, src_dict, tgt_dict)
        self.ans2label_dict = json.loads(self.cfg.ans2label_dict)

    def load_dataset(self, split, epoch=1, combine=False, **kwargs):
        """Build the ``SnliVeDataset`` for ``split`` and store it in ``self.datasets``.

        ``cfg.data`` is a comma-separated list of paths.  The last path is
        used for every non-train split; the train split cycles through the
        remaining paths by epoch.  If only one path is given it is used for
        all splits.

        NOTE(review): reads ``self.constraint_trie``, which is only assigned
        in ``build_model`` -- callers must build the model first.
        """
        paths = self.cfg.data.split(',')

        if split == 'train':
            # With several paths the last one is held out for evaluation.
            # With a single path fall back to it instead of computing
            # ``% 0`` (the original raised ZeroDivisionError here).
            train_paths = paths[:-1] or paths
            file_path = train_paths[(epoch - 1) % len(train_paths)]
        else:
            file_path = paths[-1]
        dataset = FileDataset(file_path, self.cfg.selected_cols)

        self.datasets[split] = SnliVeDataset(
            split,
            dataset,
            self.bpe,
            self.src_dict,
            self.tgt_dict,
            max_src_length=self.cfg.max_src_length,
            max_tgt_length=self.cfg.max_tgt_length,
            patch_image_size=self.cfg.patch_image_size,
            add_caption=self.cfg.add_caption,
            constraint_trie=self.constraint_trie,
            imagenet_default_mean_and_std=self.cfg.imagenet_default_mean_and_std,
            prompt_type=self.cfg.prompt_type
        )

    def build_model(self, cfg):
        """Build the model and precompute the answer-scoring structures.

        Side effects: sets ``self.index2ans``, ``self.constraint_trie``,
        ``self.valid_answers_list`` and ``self.valid_constraint_masks_list``.

        Returns:
            The model produced by the parent ``build_model``.
        """
        model = super().build_model(cfg)
        bos = self.tgt_dict.bos()
        eos = self.tgt_dict.eos()

        answer_item_list = []
        self.index2ans = {}
        self.constraint_trie = Trie(eos)
        for i, answer in enumerate(self.ans2label_dict):
            answer_item = self.tgt_dict.encode_line(
                line=self.bpe.encode(' ' + answer),
                add_if_not_exist=False,
                append_eos=False
            ).long()
            answer_item_list.append(answer_item)
            self.index2ans[i] = answer
            self.constraint_trie.insert([bos] + answer_item.tolist() + [eos])

        # One boolean mask per answer: row ``i`` marks the vocabulary entries
        # the trie returns as continuations of the first ``i`` answer tokens.
        constraint_mask_list = []
        for answer_item in answer_item_list:
            constraint_mask = torch.zeros((len(answer_item) + 1, len(self.tgt_dict))).bool()
            for i in range(len(answer_item) + 1):
                # Query with the same BOS id used on insertion (tgt_dict);
                # the original queried with src_dict.bos(), which is only
                # correct when both dictionaries share that index.
                constraint_prefix_token = [bos] + answer_item[:i].tolist()
                constraint_nodes = self.constraint_trie.get_next_layer(constraint_prefix_token)
                constraint_mask[i][constraint_nodes] = True
            constraint_mask_list.append(constraint_mask)

        # Chunk the candidates so validation scores at most
        # ``valid_batch_size`` answers per decoder pass.
        self.valid_answers_list = []
        self.valid_constraint_masks_list = []
        step = self.cfg.valid_batch_size
        for start in range(0, len(answer_item_list), step):
            self.valid_answers_list.append(answer_item_list[start:start + step])
            self.valid_constraint_masks_list.append(constraint_mask_list[start:start + step])

        return model

    def build_generator(
        self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None,
    ):
        """Build the parent's generator and attach ``self.constraint_trie`` to it."""
        seq_generator = super().build_generator(models, args, seq_gen_cls, extra_gen_cls_kwargs, prefix_allowed_tokens_fn)
        seq_generator.constraint_trie = self.constraint_trie

        return seq_generator

    def valid_step(self, sample, model, criterion, **extra_kwargs):
        """Run the parent validation step, then score all candidate answers.

        For each sample every candidate answer is appended to the sample's
        decoder prompt, scored by summed token log-probability under the
        constraint mask, and the arg-max answer is looked up in the sample's
        ``ref_dict``.

        Returns:
            ``(loss, sample_size, logging_output)`` from the parent step, with
            ``_snli_score_sum`` and ``_snli_cnt`` added to ``logging_output``.
        """
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)

        model.eval()
        with torch.no_grad():
            net_input = sample["net_input"]
            encoder_out = model.encoder(
                net_input["src_tokens"],
                src_lengths=net_input["src_lengths"],
                patch_images=net_input["patch_images"],
                patch_masks=net_input["patch_masks"]
            )
            device = net_input["src_tokens"].device
            eos_item = torch.tensor([self.src_dict.eos()])
            pad = self.src_dict.pad()
            valid_result = []
            for valid_answers, valid_constraint_masks in zip(self.valid_answers_list, self.valid_constraint_masks_list):
                valid_size = len(valid_answers)
                # Targets drop the first prompt token and append EOS;
                # decoder inputs keep the full prompt (shifted-by-one pair).
                valid_tgt_items = [
                    torch.cat([torch.tensor(decoder_prompt[1:]), valid_answer, eos_item])
                    for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
                ]
                valid_prev_items = [
                    torch.cat([torch.tensor(decoder_prompt), valid_answer])
                    for decoder_prompt in sample["decoder_prompts"] for valid_answer in valid_answers
                ]
                # Prompt positions get all-False rows; those positions are
                # zeroed out of the score further below.
                valid_constraint_mask_items = [
                    torch.cat([torch.zeros(len(decoder_prompt)-1, valid_constraint_mask.size(1)).bool(), valid_constraint_mask], dim=0)
                    for decoder_prompt in sample["decoder_prompts"] for valid_constraint_mask in valid_constraint_masks
                ]
                valid_tgt = data_utils.collate_tokens(valid_tgt_items, pad_idx=pad, left_pad=False).to(device)
                valid_prev_output = data_utils.collate_tokens(valid_prev_items, pad_idx=pad, left_pad=False).to(device)
                # Distinct name: the original rebound the loop variable here.
                constraint_masks = data_utils.collate_tokens(valid_constraint_mask_items, pad_idx=pad, left_pad=False).to(device)

                # Repeat each encoder state once per candidate answer.
                # NOTE(review): dims suggest encoder_out is (src_len, batch, dim)
                # and the other two are batch-first -- confirm against the encoder.
                new_encoder_out = {}
                new_encoder_out["encoder_out"] = [
                    encoder_out["encoder_out"][0].repeat_interleave(valid_size, dim=1)
                ]
                new_encoder_out["encoder_padding_mask"] = [
                    encoder_out["encoder_padding_mask"][0].repeat_interleave(valid_size, dim=0)
                ]
                new_encoder_out["position_embeddings"] = [
                    encoder_out["position_embeddings"][0].repeat_interleave(valid_size, dim=0)
                ]

                decoder_out = model.decoder(valid_prev_output, encoder_out=new_encoder_out)
                decoder_out[0].masked_fill_(~constraint_masks, -math.inf)
                lprobs = model.get_normalized_probs(decoder_out, log_probs=True)
                scores = lprobs.gather(dim=-1, index=valid_tgt.unsqueeze(-1)).squeeze(-1)
                # Ignore padding and fully-masked (prompt) positions.
                scores = scores.masked_fill(valid_tgt.eq(self.tgt_dict.pad()), 0)
                scores = scores.masked_fill((~constraint_masks).all(2), 0)
                scores = scores.sum(1)
                scores = scores.view(-1, valid_size)
                valid_result.append(scores)

        # Concatenate the per-chunk scores so each row covers all candidates.
        valid_result = torch.cat(valid_result, dim=-1)
        predicts = valid_result.argmax(1).tolist()
        hyps = [self.index2ans[predict_index] for predict_index in predicts]
        scores = [ref_dict.get(hyp, 0) for ref_dict, hyp in zip(sample['ref_dict'], hyps)]
        logging_output["_snli_score_sum"] = sum(scores)
        logging_output["_snli_cnt"] = len(scores)

        return loss, sample_size, logging_output

    def reduce_metrics(self, logging_outputs, criterion):
        """Aggregate logging outputs and register the derived ``snli_score`` metric."""
        super().reduce_metrics(logging_outputs, criterion)

        def sum_logs(key):
            # Sum ``key`` over all logging outputs; tensors are moved to CPU.
            result = sum(log.get(key, 0) for log in logging_outputs)
            if torch.is_tensor(result):
                result = result.cpu()
            return result

        def compute_score(meters):
            # Accumulated score sum divided by the number of scored samples.
            score = meters["_snli_score_sum"].sum / meters["_snli_cnt"].sum
            score = score if isinstance(score, float) else score.item()
            return round(score, 4)

        cnt = sum_logs("_snli_cnt")
        if cnt > 0:
            metrics.log_scalar("_snli_score_sum", sum_logs("_snli_score_sum"))
            metrics.log_scalar("_snli_cnt", cnt)
            metrics.log_derived("snli_score", compute_score)