# Provenance (copied from the Hugging Face file viewer):
#   OFA-Generic_Interface/data/mm_data/refcoco_dataset.py (6.11 kB)
#   Commit 8c90e7d by guidel: "Duplicate from OFA-Sys/OFA-Generic_Interface".
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from io import BytesIO
import logging
import warnings
import numpy as np
import torch
import base64
import utils.transforms as T
from PIL import Image, ImageFile
from data import data_utils
from data.ofa_dataset import OFADataset
# Let Pillow decode images whose byte stream ends early instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# Disable Pillow's decompression-bomb pixel limit.
# NOTE(review): Pillow reads this limit from the Image module; the ImageFile
# attribute appears to be a no-op and is kept only to preserve prior behavior.
Image.MAX_IMAGE_PIXELS = ImageFile.MAX_IMAGE_PIXELS = None

logger = logging.getLogger(__name__)

# Silence Pillow's "(Possibly) corrupt EXIF data" UserWarnings.
warnings.filterwarnings(
    action="ignore",
    message="(Possibly )?corrupt EXIF data",
    category=UserWarning,
)

# Per-channel RGB mean/std, selected when a dataset is built with
# imagenet_default_mean_and_std=True.
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
def collate(samples, pad_idx, eos_idx):
    """Merge a list of RefCOCO examples into a single padded mini-batch.

    Args:
        samples: list of example dicts as produced by
            ``RefcocoDataset.__getitem__``. Keys read here: ``id``,
            ``source``, ``patch_image``, ``patch_mask``, ``w_resize_ratio``,
            ``h_resize_ratio``, ``region_coord``, and optionally ``target``
            and ``prev_output_tokens``.
        pad_idx: padding token id. Non-pad tokens are counted for the
            lengths, and the id is passed to ``data_utils.collate_tokens``.
        eos_idx: end-of-sequence token id, passed to
            ``data_utils.collate_tokens``.

    Returns:
        ``{}`` when ``samples`` is empty. Otherwise a dict with keys ``id``,
        ``nsentences``, ``ntokens``, ``net_input`` (``src_tokens``,
        ``src_lengths``, ``patch_images``, ``patch_masks``,
        ``prev_output_tokens``), ``target``, ``w_resize_ratios``,
        ``h_resize_ratios`` and ``region_coords``.
    """
    if not samples:
        return {}

    def merge(key):
        # Batch the token tensors stored under ``key`` across all samples.
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=eos_idx,
        )

    # Named ``ids`` so the builtin ``id`` is not shadowed; the batch key stays "id".
    ids = np.array([s["id"] for s in samples])
    src_tokens = merge("source")
    # Source length = number of non-padding tokens in each example.
    src_lengths = torch.LongTensor([s["source"].ne(pad_idx).long().sum() for s in samples])

    patch_images = torch.stack([s["patch_image"] for s in samples], dim=0)
    # patch_mask tensors are concatenated (not stacked) along dim 0.
    patch_masks = torch.cat([s["patch_mask"] for s in samples])
    w_resize_ratios = torch.stack([s["w_resize_ratio"] for s in samples], dim=0)
    h_resize_ratios = torch.stack([s["h_resize_ratio"] for s in samples], dim=0)
    region_coords = torch.stack([s["region_coord"] for s in samples], dim=0)

    prev_output_tokens = None
    target = None
    # Whether targets exist is decided by the first sample only.
    if samples[0].get("target", None) is not None:
        target = merge("target")
        tgt_lengths = torch.LongTensor([s["target"].ne(pad_idx).long().sum() for s in samples])
        # With targets, ntokens counts non-padding target tokens.
        ntokens = tgt_lengths.sum().item()
        if samples[0].get("prev_output_tokens", None) is not None:
            prev_output_tokens = merge("prev_output_tokens")
    else:
        # Without targets, ntokens falls back to the non-padding source tokens.
        ntokens = src_lengths.sum().item()

    batch = {
        "id": ids,
        "nsentences": len(samples),
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
            "patch_images": patch_images,
            "patch_masks": patch_masks,
            "prev_output_tokens": prev_output_tokens
        },
        "target": target,
        "w_resize_ratios": w_resize_ratios,
        "h_resize_ratios": h_resize_ratios,
        "region_coords": region_coords
    }
    return batch
class RefcocoDataset(OFADataset):
    """Referring-expression grounding dataset (RefCOCO-style) for OFA.

    Each record of ``dataset`` is unpacked as
    ``(uniq_id, base64_image, text, "x0,y0,x1,y1")``. The source prompt asks
    which region the text describes. The target is the transformed box,
    written as four ``<bin_k>`` location tokens.
    """

    def __init__(
        self,
        split,
        dataset,
        bpe,
        src_dict,
        tgt_dict=None,
        max_src_length=80,
        max_tgt_length=30,
        patch_image_size=512,
        imagenet_default_mean_and_std=False,
        num_bins=1000,
        max_image_size=512
    ):
        """Build the dataset and its image/box transform.

        Args:
            split, dataset, bpe, src_dict, tgt_dict: forwarded unchanged to
                ``OFADataset.__init__``.
            max_src_length: length limit passed to ``pre_caption`` for the
                referring expression.
            max_tgt_length: stored on the instance; not read by this class.
            patch_image_size: size passed to ``T.RandomResize``; also used as
                its ``max_size``.
            imagenet_default_mean_and_std: if True, normalize with the
                ImageNet mean/std; otherwise use 0.5 for every channel.
            num_bins: number of location tokens. Assuming box coordinates are
                normalized to [0, 1], they map to ``<bin_0>`` ..
                ``<bin_{num_bins-1}>``.
            max_image_size: passed to ``T.Normalize``. Presumably the scale
                used to normalize box coordinates — confirm in utils.transforms.
        """
        super().__init__(split, dataset, bpe, src_dict, tgt_dict)
        self.max_src_length = max_src_length
        self.max_tgt_length = max_tgt_length
        self.patch_image_size = patch_image_size
        self.num_bins = num_bins

        if imagenet_default_mean_and_std:
            mean = IMAGENET_DEFAULT_MEAN
            std = IMAGENET_DEFAULT_STD
        else:
            mean = [0.5, 0.5, 0.5]
            std = [0.5, 0.5, 0.5]

        # for positioning: transforms the image and its box target together
        self.positioning_transform = T.Compose([
            T.RandomResize([patch_image_size], max_size=patch_image_size),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std, max_image_size=max_image_size)
        ])

    def __getitem__(self, index):
        """Return one training example for the record at ``index``.

        The example holds the encoded prompt, the transformed image patch,
        the quantized target box tokens, the resize ratios, and the original
        box as a float tensor ``region_coord``.
        """
        uniq_id, base64_str, text, region_coord = self.dataset[index]

        # Decode the image. convert() returns an independent RGB image, so the
        # lazily-opened source image can be closed straight away.
        with Image.open(BytesIO(base64.urlsafe_b64decode(base64_str))) as raw_image:
            image = raw_image.convert("RGB")
        w, h = image.size

        # Parse "x0,y0,x1,y1" once. Values are absolute coordinates in the
        # decoded image (presumably pixels).
        x0, y0, x1, y1 = map(float, region_coord.strip().split(','))
        region = torch.tensor([x0, y0, x1, y1])
        boxes_target = {
            "boxes": torch.tensor([[x0, y0, x1, y1]]),
            "labels": np.array([0]),
            "area": torch.tensor([(x1 - x0) * (y1 - y0)]),
            "size": torch.tensor([h, w]),
        }

        patch_image, patch_boxes = self.positioning_transform(image, boxes_target)
        resize_h, resize_w = patch_boxes["size"][0], patch_boxes["size"][1]
        patch_mask = torch.tensor([True])

        # Quantize the transformed box (x0, y0, x1, y1) into location tokens.
        quant_bins = [
            "<bin_{}>".format(int((coord * (self.num_bins - 1)).round()))
            for coord in patch_boxes["boxes"][0][:4]
        ]
        tgt_text = " ".join(quant_bins)

        src_caption = self.pre_caption(text, self.max_src_length)
        src_item = self.encode_text(' which region does the text " {} " describe?'.format(src_caption))
        # Location tokens are already vocabulary symbols, so BPE is skipped.
        tgt_item = self.encode_text(tgt_text, use_bpe=False)

        src_item = torch.cat([self.bos_item, src_item, self.eos_item])
        # target = tokens + EOS; decoder input = BOS + tokens (shifted right by one).
        target_item = torch.cat([tgt_item, self.eos_item])
        prev_output_item = torch.cat([self.bos_item, tgt_item])

        example = {
            "id": uniq_id,
            "source": src_item,
            "patch_image": patch_image,
            "patch_mask": patch_mask,
            "target": target_item,
            "prev_output_tokens": prev_output_item,
            "w_resize_ratio": resize_w / w,
            "h_resize_ratio": resize_h / h,
            "region_coord": region
        }
        return example

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.

        Args:
            samples (List[dict]): samples to collate, as returned by
                ``__getitem__``.
            pad_to_length: accepted but ignored here (presumably kept to
                match the base-class signature).

        Returns:
            dict: ``{}`` for an empty list; otherwise a mini-batch with the
            following keys: ``id``, ``nsentences``, ``ntokens``, ``net_input``
            (``src_tokens``, ``src_lengths``, ``patch_images``,
            ``patch_masks``, ``prev_output_tokens``), ``target``,
            ``w_resize_ratios``, ``h_resize_ratios``, ``region_coords``.
        """
        return collate(samples, pad_idx=self.pad, eos_idx=self.eos)