from torch.utils.data import Dataset
from PIL import Image
import torch
import json
import h5py
import bisect

# Default maximum number of target-caption tokens kept during training.
CAPTION_LENGTH = 25
# Prompt used when no retrieved captions are supplied; postprocess_preds also
# splits generated text on this string to drop the prompt.
SIMPLE_PREFIX = "This image shows "


def prep_strings(text, tokenizer, template=None, retrieved_caps=None, k=None,
                 is_test=False, max_length=None, caption_length=CAPTION_LENGTH):
    """Build decoder input ids (and, for training, label ids) for one caption.

    The prompt prefix is ``template`` with its ``||`` placeholder replaced by
    the first ``k`` retrieved captions (joined by blank lines and terminated by
    ``'.'``), or ``SIMPLE_PREFIX`` when ``retrieved_caps`` is None.

    Args:
        text: Target caption. Ignored when ``is_test`` is True.
        tokenizer: Tokenizer providing ``encode``, ``pad_token_id`` and
            ``eos_token_id``.
        template: Prompt template containing ``||``; required when
            ``retrieved_caps`` is given.
        retrieved_caps: Sequence of retrieved caption strings, or None.
        k: Number of retrieved captions to use (None uses all of them).
        is_test: If True, return only the unpadded prefix ids.
        max_length: Length that training inputs and labels are padded to.
            If None, no padding is applied.
        caption_length: Maximum number of target-caption tokens kept in
            training mode.

    Returns:
        ``input_ids`` when ``is_test`` is True, otherwise
        ``(input_ids, label_ids)``.
    """
    if retrieved_caps is not None:
        infix = '\n\n'.join(retrieved_caps[:k]) + '.'
        prefix = template.replace('||', infix)
    else:
        prefix = SIMPLE_PREFIX

    prefix_ids = tokenizer.encode(prefix)
    if is_test:
        # At inference only the prompt is fed; no truncation or padding.
        return prefix_ids

    text_ids = tokenizer.encode(text, add_special_tokens=False)[:caption_length]
    input_ids = prefix_ids + text_ids
    # Ignore the prefix in the loss (minus one, as the first subtoken of the
    # prefix is never predicted); the target is the caption followed by EOS.
    label_ids = [-100] * (len(prefix_ids) - 1) + text_ids + [tokenizer.eos_token_id]

    if max_length is not None:
        # NOTE(review): if the prompt is longer than max_length (e.g. unusually
        # long retrieved captions) no padding is added and the sequence stays
        # longer than max_length.
        input_ids += [tokenizer.pad_token_id] * (max_length - len(input_ids))
        label_ids += [-100] * (max_length - len(label_ids))
    return input_ids, label_ids


def postprocess_preds(pred, tokenizer):
    """Strip the prompt and special tokens from a decoded generation.

    Everything up to the last occurrence of ``SIMPLE_PREFIX`` is dropped, pad
    tokens are removed, and a leading BOS / trailing EOS token is stripped.
    Special tokens that are None or empty are skipped.
    """
    pred = pred.split(SIMPLE_PREFIX)[-1]
    if tokenizer.pad_token:
        pred = pred.replace(tokenizer.pad_token, '')
    bos = tokenizer.bos_token
    if bos and pred.startswith(bos):
        pred = pred[len(bos):]
    eos = tokenizer.eos_token
    # The truthiness check matters: with eos == '' the slice pred[:-0] would
    # return an empty string.
    if eos and pred.endswith(eos):
        pred = pred[:-len(eos)]
    return pred


class TrainDataset(Dataset):
    """Training dataset pairing precomputed image features with captions.

    Rows of ``df`` are read through the ``'text'`` and ``'cocoid'`` columns
    (plus ``'caps'`` when ``rag`` is True). ``cocoid`` is the key into the
    HDF5 feature file at ``features_path``.
    """

    def __init__(self, df, features_path, tokenizer, rag=False, template_path=None, k=None,
                 max_caption_length=25):
        """
        Args:
            df: Table of training samples (see class docstring).
            features_path: Path to an HDF5 file of precomputed encoder outputs.
            tokenizer: Tokenizer passed through to ``prep_strings``.
            rag: If True, prompts include ``k`` retrieved captions.
            template_path: Path to the prompt template (required if ``rag``).
            k: Number of retrieved captions per prompt (required if ``rag``).
            max_caption_length: Maximum number of tokens kept per target
                caption; also used to size the padded sequence length.

        Raises:
            ValueError: if ``rag`` is True and ``k`` is None.
        """
        self.df = df
        self.tokenizer = tokenizer
        self.features = h5py.File(features_path, 'r')
        self.max_caption_length = max_caption_length
        self.k = k
        self.rag = rag

        if rag:
            if k is None:
                raise ValueError("k must be provided when rag=True")
            with open(template_path, encoding='utf-8') as f:
                self.template = f.read().strip() + ' '
            self.max_target_length = (max_caption_length  # target caption
                                      + max_caption_length * k  # retrieved captions
                                      + len(tokenizer.encode(self.template))  # template
                                      + len(tokenizer.encode('\n\n')) * (k - 1)  # separators between captions
                                      )
        else:
            # SIMPLE_PREFIX prompt plus the truncated caption; the labels
            # (prefix - 1 ignored ids, caption, EOS) have the same length.
            self.max_target_length = len(tokenizer.encode(SIMPLE_PREFIX)) + max_caption_length

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        text = self.df['text'][idx]
        if self.rag:
            caps = self.df['caps'][idx]
            decoder_input_ids, labels = prep_strings(
                text, self.tokenizer, template=self.template, retrieved_caps=caps,
                k=self.k, max_length=self.max_target_length,
                caption_length=self.max_caption_length)
        else:
            decoder_input_ids, labels = prep_strings(
                text, self.tokenizer, max_length=self.max_target_length,
                caption_length=self.max_caption_length)

        # load precomputed features
        encoder_outputs = self.features[self.df['cocoid'][idx]][()]
        encoding = {"encoder_outputs": torch.tensor(encoder_outputs),
                    "decoder_input_ids": torch.tensor(decoder_input_ids),
                    "labels": torch.tensor(labels)}
        return encoding


def _load_json(path):
    """Read and parse a UTF-8 JSON file, closing the handle afterwards."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def load_data_for_training(annot_path, caps_path=None):
    """Collect one training sample per reference sentence.

    Args:
        annot_path: JSON file whose ``'images'`` entries carry ``'filename'``,
            ``'cocoid'``, ``'split'`` and ``'sentences'`` (each with
            ``'tokens'``).
        caps_path: Optional JSON file mapping ``str(cocoid)`` to that image's
            retrieved captions.

    Returns:
        ``{'train': [...], 'val': [...]}``. The ``'train'`` list holds the
        ``'train'`` and ``'restval'`` splits; other splits are skipped. Each
        sample is a dict with ``'file_name'``, ``'cocoid'``, ``'caps'`` (None
        without ``caps_path``) and ``'text'`` (the space-joined sentence
        tokens).
    """
    annotations = _load_json(annot_path)['images']
    retrieved_caps = _load_json(caps_path) if caps_path is not None else None

    data = {'train': [], 'val': []}
    for item in annotations:
        if item['split'] in ('train', 'restval'):
            target = data['train']
        elif item['split'] == 'val':
            target = data['val']
        else:
            continue

        file_name = item['filename'].split('_')[-1]
        cocoid = str(item['cocoid'])
        caps = retrieved_caps[cocoid] if retrieved_caps is not None else None
        target.extend({'file_name': file_name,
                       'cocoid': cocoid,
                       'caps': caps,
                       'text': ' '.join(sentence['tokens'])}
                      for sentence in item['sentences'])
    return data


def load_data_for_inference(annot_path, caps_path=None):
    """Collect one inference record per image in the test and val splits.

    Args:
        annot_path: JSON file whose ``'images'`` entries carry ``'filename'``,
            ``'cocoid'`` and ``'split'``.
        caps_path: Optional JSON file mapping ``str(cocoid)`` to that image's
            retrieved captions.

    Returns:
        ``{'test': [...], 'val': [...]}`` of dicts with ``'file_name'``,
        ``'caps'`` (None without ``caps_path``) and ``'image_id'``.
    """
    annotations = _load_json(annot_path)['images']
    retrieved_caps = _load_json(caps_path) if caps_path is not None else None

    data = {'test': [], 'val': []}
    for item in annotations:
        if item['split'] not in data:
            continue
        file_name = item['filename'].split('_')[-1]
        image_id = str(item['cocoid'])
        caps = retrieved_caps[image_id] if retrieved_caps is not None else None
        data[item['split']].append({'file_name': file_name, 'caps': caps, 'image_id': image_id})
    return data