import os
import json
import random
import glob
from collections import defaultdict, Counter

import torch
from torchvision import transforms
from torchvision.datasets.folder import default_loader
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.data.transforms import RandomResizedCropAndInterpolation
from timm.data import create_transform

import utils
from glossary import normalize_word
from randaug import RandomAugment


class BaseDataset(torch.utils.data.Dataset):
    def __init__(
        self, data_path, split, transform,
        tokenizer, num_max_bpe_tokens, task=None,
    ):
        index_files = self.get_index_files(split, task=task)
        self.tokenizer = tokenizer
        self.num_max_bpe_tokens = num_max_bpe_tokens
        self.data_path = data_path
        items = []
        self.index_files = index_files

        offset = 0
        for _index_file in index_files:
            index_file = os.path.join(data_path, _index_file)
            with open(index_file, mode="r", encoding="utf-8") as reader:
                for line in reader:
                    data = json.loads(line)
                    items.append(data)
                print("Load %d image-text pairs from %s." % (len(items) - offset, index_file))
                offset = len(items)
        self.items = items
        self.bos_token_id = tokenizer.bos_token_id
        self.eos_token_id = tokenizer.eos_token_id
        self.pad_token_id = tokenizer.pad_token_id
        self.loader = default_loader
        self.transform = transform
        self.split = split

    @staticmethod
    def get_index_files(split, task=None):
        raise NotImplementedError()

    def _get_image(self, image_path: str):
        image_path = os.path.join(self.data_path, image_path)
        image = self.loader(image_path)
        return self.transform(image)

    def _get_text_segment(self, text_segment, max_len=None):
        if isinstance(text_segment, str):
            tokens = self.tokenizer.tokenize(text_segment)
        else:
            tokens = text_segment[:]
        if len(tokens) == 0:
            raise RuntimeError("The text segment must contain at least one token!")
        if max_len is None:
            max_len = self.num_max_bpe_tokens

        # Reserve two positions for the BOS/EOS tokens added below.
        if len(tokens) > max_len - 2:
            tokens = tokens[:max_len - 2]

        tokens = [self.bos_token_id] + tokens[:] + [self.eos_token_id]
        num_tokens = len(tokens)
        padding_mask = [0] * num_tokens + [1] * (max_len - num_tokens)
        return tokens + [self.pad_token_id] * (max_len - num_tokens), padding_mask, num_tokens

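    # Illustration (not from the original code): with num_max_bpe_tokens=8 and a
    # pre-tokenized segment [51, 52, 53], _get_text_segment returns
    #     tokens       -> [bos, 51, 52, 53, eos, pad, pad, pad]
    #     padding_mask -> [0, 0, 0, 0, 0, 1, 1, 1]
    #     num_tokens   -> 5
    # where bos/eos/pad are the tokenizer's special-token ids.
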
    def _get_image_text_example(self, index: int, data: dict):
        item = self.items[index]
        img_path = item["image_path"]
        img = self._get_image(img_path)
        data["image"] = img

        text_segment = item["text_segment"]
        language_tokens, padding_mask, _ = self._get_text_segment(text_segment)
        data["language_tokens"] = language_tokens
        data["padding_mask"] = padding_mask

    def __getitem__(self, index: int):
        data = dict()
        self._get_image_text_example(index, data)
        return data

    def __len__(self) -> int:
        return len(self.items)

    def __repr__(self) -> str:
        head = "Dataset " + self.__class__.__name__
        body = '{' + "\n  Number of items: %s," % self.__len__()
        body += "\n  data root = %s," % self.data_path
        body += "\n  split = %s," % self.split
        body += "\n  dataset index files = %s" % str(self.index_files)
        body += "\n  num max bpe tokens = %s" % self.num_max_bpe_tokens
        body += "\n  transforms = ["
        for t in self.transform.transforms:
            body += "\n    %s" % str(t)
        body += "\n  ]"
        body += "\n}"

        return head + body


def _write_data_into_jsonl(items, jsonl_file):
    with open(jsonl_file, mode="w", encoding="utf-8") as writer:
        for data in items:
            writer.write(json.dumps(data, indent=None))
            writer.write('\n')
    print("Wrote %s with %d items!" % (jsonl_file, len(items)))


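# Each line in an index file written above is one JSON object. For the COCO
# retrieval/captioning indices built below, a line looks roughly like
#     {"image_path": "train2014/xxx.jpg", "text_segment": [247, 89, ...], "image_id": 12345}
# where "text_segment" holds pre-computed token ids (or null when no caption
# is available, e.g. for test images).

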
def _make_retrieval_coco_karpathy_dataset_index(
    data_path,
    tokenizer,
    split=("train", "restval"),
    split_name="train",
):
    coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json")
    items = []
    image_counter = set()
    print("read %s" % coco_karpathy_split_json_file)
    with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader:
        data = json.loads(reader.read())
        for item in data["images"]:
            if item["split"] in split:
                image_path = os.path.join(item["filepath"], item["filename"])
                for sent in item["sentences"]:
                    tokens = tokenizer.tokenize(sent["raw"])
                    token_ids = tokenizer.convert_tokens_to_ids(tokens)
                    items.append({
                        "image_path": image_path,
                        "text_segment": token_ids,
                        "image_id": len(image_counter),
                    })
                if image_path not in image_counter:
                    image_counter.add(image_path)
    print("Found %d images and %d image-text pairs for karpathy dataset %s split!" %
          (len(image_counter), len(items), split_name))
    index_file = os.path.join(data_path, "coco_retrieval.%s.jsonl" % split_name)
    _write_data_into_jsonl(items, index_file)


def _make_captioning_coco_karpathy_dataset_index(
    data_path,
    tokenizer,
    split=("train", "restval"),
    split_name="train",
):
    coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json")
    items = []
    image_counter = set()
    print("read %s" % coco_karpathy_split_json_file)
    with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader:
        data = json.loads(reader.read())
        for item in data["images"]:
            if item["split"] in split:
                image_path = os.path.join(item["filepath"], item["filename"])
                if item["split"] in ["train", "restval"]:
                    for sent in item["sentences"]:
                        tokens = tokenizer.tokenize(sent["raw"])
                        token_ids = tokenizer.convert_tokens_to_ids(tokens)
                        items.append({
                            "image_path": image_path,
                            "text_segment": token_ids,
                            "image_id": item["cocoid"],
                        })
                else:
                    items.append({
                        "image_path": image_path,
                        "text_segment": None,
                        "image_id": item["cocoid"],
                    })
                if image_path not in image_counter:
                    image_counter.add(image_path)
    print("Found %d images and %d image-text pairs for karpathy dataset %s split!" %
          (len(image_counter), len(items), split_name))
    index_file = os.path.join(data_path, "coco_captioning.%s.jsonl" % split_name)
    _write_data_into_jsonl(items, index_file)


def _make_nocaps_dataset_index(
    data_path,
    split="val",
):
    if split == "val":
        json_file = "nocaps_val_4500_captions.json"
    elif split == "test":
        json_file = "nocaps_test_image_info.json"
    else:
        raise RuntimeError("split %s is not found!" % split)
    nocaps_split_json_file = os.path.join(data_path, json_file)
    items = []
    image_counter = set()
    print("read %s" % nocaps_split_json_file)
    with open(nocaps_split_json_file, mode="r", encoding="utf-8") as reader:
        data = json.loads(reader.read())
        for item in data["images"]:
            image_path = os.path.join(split, item["file_name"])
            items.append({
                "image_path": image_path,
                "text_segment": None,
                "image_id": item["id"],
            })

            if image_path not in image_counter:
                image_counter.add(image_path)

    print("Found %d images and %d image-text pairs for nocaps dataset %s split!" %
          (len(image_counter), len(items), split))
    index_file = os.path.join(data_path, "nocaps.%s.jsonl" % split)
    _write_data_into_jsonl(items, index_file)


class NLVR2Dataset(BaseDataset):
    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return ("nlvr2.train.index.jsonl", )
        elif split == "val":
            return ("nlvr2.dev.index.jsonl", )
        elif split == "test":
            return ("nlvr2.test-P.index.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def __getitem__(self, index: int):
        data = super().__getitem__(index)
        item = self.items[index]
        img_path = item["image2_path"]
        img = self._get_image(img_path)
        data["image2"] = img
        data["label"] = self.items[index]["label"]
        return data

    @staticmethod
    def __preprocess_json(prefix, json_file, tokenizer, index_file):
        items = []
        with open(json_file, mode="r", encoding="utf-8") as reader:
            for line in reader:
                data = json.loads(line)
                path = os.path.join(prefix, str(data["directory"])) if "directory" in data else prefix
                path = os.path.join(path, "-".join(data["identifier"].split("-")[:-1]))
                tokens = tokenizer.tokenize(data["sentence"])
                token_ids = tokenizer.convert_tokens_to_ids(tokens)
                items.append({
                    "image_path": path + "-img0.png",
                    "image2_path": path + "-img1.png",
                    "text_segment": token_ids,
                    "label": 1 if data["label"] == "True" else 0,
                    "identifier": data["identifier"],
                })
        _write_data_into_jsonl(items, index_file)

    @classmethod
    def make_dataset_index(cls, data_path, tokenizer, nlvr_repo_path):
        cls.__preprocess_json(
            prefix="images/train", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/train.json"),
            tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("train")[0]),
        )
        cls.__preprocess_json(
            prefix="dev", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/dev.json"),
            tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("val")[0]),
        )
        cls.__preprocess_json(
            prefix="test1", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/test1.json"),
            tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("test")[0]),
        )


class ImageNetDataset(BaseDataset):
    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return ("imagenet.train.index.jsonl", )
        elif split == "val":
            return ("imagenet.val.index.jsonl", )
        elif split == "test":
            return ("imagenet.val.index.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def __getitem__(self, index: int):
        data = dict()
        item = self.items[index]
        img_path = item["image_path"]
        img = self._get_image(img_path)
        data["image"] = img
        data["label"] = item["label"]
        return data

    @staticmethod
    def _find_classes(dir):
        """
        Finds the class folders in a dataset.

        Args:
            dir (string): Root directory path.

        Returns:
            tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.

        Ensures:
            No class is a subdirectory of another.
        """
        classes = [d.name for d in os.scandir(dir) if d.is_dir()]
        classes.sort()
        class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
        return classes, class_to_idx

    @staticmethod
    def _make_imagenet_index(data_path, index_path, data_path_prefix, class_to_idx, split):
        items = []
        index_file = os.path.join(index_path, f"imagenet.{split}.index.jsonl")
        for target_class in sorted(class_to_idx.keys()):
            class_index = class_to_idx[target_class]
            target_dir = os.path.join(data_path, target_class)
            if not os.path.isdir(target_dir):
                continue
            for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
                for fname in sorted(fnames):
                    path = os.path.join(root, fname)
                    # Store paths relative to the shared data root.
                    path = path.replace(data_path_prefix, "")
                    items.append({
                        "image_path": path,
                        "label": class_index,
                    })

        _write_data_into_jsonl(items, index_file)

    @classmethod
    def make_dataset_index(cls, train_data_path, val_data_path, index_path):
        # Character-wise common prefix of the two paths, i.e. the shared data
        # root that is stripped from every stored image path.
        data_path_prefix = os.path.commonprefix([train_data_path, val_data_path])
        classes, class_to_idx = cls._find_classes(train_data_path)
        cls._make_imagenet_index(
            data_path=train_data_path, index_path=index_path, data_path_prefix=data_path_prefix,
            class_to_idx=class_to_idx, split="train",
        )
        cls._make_imagenet_index(
            data_path=val_data_path, index_path=index_path, data_path_prefix=data_path_prefix,
            class_to_idx=class_to_idx, split="val",
        )


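# Example usage (hypothetical paths; not part of the original module):
#     ImageNetDataset.make_dataset_index(
#         train_data_path="/datasets/imagenet/train",
#         val_data_path="/datasets/imagenet/val",
#         index_path="/datasets/imagenet",
#     )
# writes imagenet.train.index.jsonl and imagenet.val.index.jsonl under index_path.

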
class VQAv2Dataset(BaseDataset):
    def __init__(self, data_path, **kwargs):
        super().__init__(data_path=data_path, **kwargs)
        ans2label_file = os.path.join(data_path, "answer2label.txt")
        ans2label = {}
        label2ans = []
        with open(ans2label_file, mode="r", encoding="utf-8") as reader:
            for i, line in enumerate(reader):
                data = json.loads(line)
                ans = data["answer"]
                label = data["label"]
                label = int(label)
                assert label == i
                ans2label[ans] = i
                label2ans.append(ans)

        self.ans2label = ans2label
        self.label2ans = label2ans

    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return ("vqa.train.jsonl", "vqa.trainable_val.jsonl")
        elif split == "val":
            return ("vqa.rest_val.jsonl", )
        elif split == "test":
            return ("vqa.test.jsonl", )
        elif split == "test-dev":
            return ("vqa.test-dev.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def __getitem__(self, index: int):
        data = super().__getitem__(index)
        if "labels" in self.items[index] and len(self.items[index]["labels"]) > 0:
            labels = [0.] * len(self.label2ans)
            for l, s in zip(self.items[index]["labels"], self.items[index]["scores"]):
                labels[l] = s
            data["labels"] = torch.FloatTensor(labels)
        else:
            data["qid"] = self.items[index]["qid"]
        return data

    @staticmethod
    def get_score(occurrences):
        if occurrences == 0:
            return 0.0
        elif occurrences == 1:
            return 0.3
        elif occurrences == 2:
            return 0.6
        elif occurrences == 3:
            return 0.9
        else:
            return 1.0

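    # get_score mirrors the soft-label convention widely used for VQAv2
    # fine-tuning: an answer given by k of the ten annotators receives a target
    # of roughly min(1, k * 0.3), saturating at 1.0 for four or more occurrences.
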
    @classmethod
    def make_dataset_index(cls, data_path, tokenizer, annotation_data_path):
        with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_train2014_questions.json"), "r") as fp:
            questions_train2014 = json.load(fp)["questions"]
        with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_val2014_questions.json"), "r") as fp:
            questions_val2014 = json.load(fp)["questions"]
        with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test2015_questions.json"), "r") as fp:
            questions_test2015 = json.load(fp)["questions"]
        with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test-dev2015_questions.json"), "r") as fp:
            questions_test_dev2015 = json.load(fp)["questions"]

        with open(os.path.join(annotation_data_path, "v2_mscoco_train2014_annotations.json"), "r") as fp:
            annotations_train2014 = json.load(fp)["annotations"]
        with open(os.path.join(annotation_data_path, "v2_mscoco_val2014_annotations.json"), "r") as fp:
            annotations_val2014 = json.load(fp)["annotations"]

        annotations = dict()

        for split, questions in zip(
            ["train", "val", "test", "test-dev"],
            [questions_train2014, questions_val2014, questions_test2015, questions_test_dev2015],
        ):
            _annot = defaultdict(dict)
            for q in questions:
                question_text = q["question"]
                tokens = tokenizer.tokenize(question_text)
                token_ids = tokenizer.convert_tokens_to_ids(tokens)

                assert q["question_id"] not in _annot[q["image_id"]]
                _annot[q["image_id"]][q["question_id"]] = {
                    "question": question_text,
                    "token_ids": token_ids,
                }

            annotations[split] = _annot

        all_major_answers = list()

        for split, annots in zip(
            ["train", "val"], [annotations_train2014, annotations_val2014],
        ):
            for q in annots:
                all_major_answers.append(q["multiple_choice_answer"])

        all_major_answers = [normalize_word(word) for word in all_major_answers]
        # Keep only answers that occur at least nine times to form the answer vocabulary.
        counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9}
        ans2label = {k: i for i, k in enumerate(counter.keys())}
        label2ans = list(counter.keys())

        for split, annots in zip(
            ["train", "val"], [annotations_train2014, annotations_val2014],
        ):
            _annot = annotations[split]
            for q in annots:
                answers = q["answers"]
                answer_count = {}
                for answer in answers:
                    answer_ = answer["answer"]
                    answer_count[answer_] = answer_count.get(answer_, 0) + 1

                labels = []
                scores = []
                for answer in answer_count:
                    if answer not in ans2label:
                        continue
                    labels.append(ans2label[answer])
                    score = cls.get_score(answer_count[answer])
                    scores.append(score)

                assert "labels" not in _annot[q["image_id"]][q["question_id"]]
                assert "question" in _annot[q["image_id"]][q["question_id"]]
                _annot[q["image_id"]][q["question_id"]]["labels"] = labels
                _annot[q["image_id"]][q["question_id"]]["scores"] = scores

        for split in ["train", "val"]:
            filtered_annot = dict()
            for ik, iv in annotations[split].items():
                new_q = dict()
                for qk, qv in iv.items():
                    if len(qv["labels"]) != 0:
                        new_q[qk] = qv
                if len(new_q) != 0:
                    filtered_annot[ik] = new_q
            annotations[split] = filtered_annot

        split2items = {}
        for split in ["train", "val", "test", "test-dev"]:
            annot = annotations[split]
            split_name = {
                "train": "train2014",
                "val": "val2014",
                "test": "test2015",
                "test-dev": "test2015",
            }[split]
            paths = list(glob.glob(f"{data_path}/{split_name}/*.jpg"))
            random.shuffle(paths)
            annot_paths = [path for path in paths
                           if int(path.split("/")[-1].split("_")[-1][:-4]) in annot]

            if len(paths) == len(annot_paths):
                print("all images have question annotations")
            else:
                print("not all images have question annotations")
            print(len(paths), len(annot_paths), len(annot))

            items = []
            for path in annot_paths:
                iid = int(path.split("/")[-1].split("_")[-1][:-4])
                _annot = annotations[split][iid]
                for qid in _annot:
                    q = _annot[qid]
                    if split in ["train", "val"]:
                        labels = q["labels"]
                        scores = q["scores"]
                    else:
                        labels, scores = [], []

                    items.append({
                        "image_path": os.path.join(split_name, path.split('/')[-1]),
                        "text_segment": q["token_ids"],
                        "labels": labels,
                        "scores": scores,
                        "qid": qid,
                    })
            split2items[split] = items

            _write_data_into_jsonl(items=items, jsonl_file=os.path.join(data_path, "vqa.%s.jsonl" % split))

        # Hold out 1000 val images ("rest_val") for evaluation and make the rest
        # of the val split available for training ("trainable_val").
        val_image2items = defaultdict(list)
        for item in split2items["val"]:
            val_image2items[item["image_path"]].append(item)

        print("Contains %d images and %d pairs for val set!" % (len(val_image2items), len(split2items["val"])))

        val_images = list(val_image2items.keys())
        random.shuffle(val_images)
        trainable_val = []
        rest_val = []
        for i, image_id in enumerate(val_images):
            if i < 1000:
                rest_val += val_image2items[image_id]
            else:
                trainable_val += val_image2items[image_id]

        _write_data_into_jsonl(items=trainable_val, jsonl_file=os.path.join(data_path, "vqa.trainable_val.jsonl"))
        _write_data_into_jsonl(items=rest_val, jsonl_file=os.path.join(data_path, "vqa.rest_val.jsonl"))

        with open(os.path.join(data_path, "answer2label.txt"), mode="w", encoding="utf-8") as writer:
            for ans in ans2label:
                to_json = {
                    "answer": ans,
                    "label": ans2label[ans]
                }
                writer.write("%s\n" % json.dumps(to_json))


class RetrievalDataset(BaseDataset):
    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return (f"{task}.train.jsonl", )
        elif split == "val":
            return (f"{task}.val.jsonl", )
        elif split == "test":
            return (f"{task}.test.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def __getitem__(self, index: int):
        data = super().__getitem__(index)
        data["image_id"] = self.items[index]["image_id"]
        return data

    @staticmethod
    def make_flickr30k_dataset_index(data_path, tokenizer, karpathy_path):
        with open(os.path.join(karpathy_path, "dataset_flickr30k.json"), "r") as reader:
            captions = json.loads(reader.read())

        captions = captions["images"]
        split2items = defaultdict(list)
        split2images = defaultdict(set)

        for each_item in captions:
            image_path = os.path.join("flickr30k-images", each_item["filename"])
            split = each_item["split"]

            for text_segment in each_item["sentences"]:
                tokens = tokenizer.tokenize(text_segment["raw"])
                token_ids = tokenizer.convert_tokens_to_ids(tokens)

                split2items[split].append({
                    "image_path": image_path,
                    "text_segment": token_ids,
                    "image_id": len(split2images[split]),
                })

            assert each_item["filename"] not in split2images[split]
            split2images[split].add(each_item["filename"])

        for split in split2items:
            print("%d images and %d image-text pairs!" % (len(split2images[split]), len(split2items[split])))
            _write_data_into_jsonl(split2items[split], os.path.join(data_path, "flickr30k.%s.jsonl" % split))

    @staticmethod
    def make_coco_dataset_index(data_path, tokenizer):
        _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train")
        _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val")
        _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test")


class CaptioningDataset(BaseDataset):

    def __init__(self, data_path, split, transform,
                 tokenizer, num_max_bpe_tokens, task, mask_prob):
        super().__init__(
            data_path=data_path, split=split,
            transform=transform, tokenizer=tokenizer,
            num_max_bpe_tokens=num_max_bpe_tokens, task=task,
        )
        self.mask_token_id = tokenizer.mask_token_id
        self.language_vocab_size = tokenizer.vocab_size
        self.mask_prob = mask_prob

    @staticmethod
    def get_index_files(split, task=None):
        if split == "train":
            return ("coco_captioning.train.jsonl", )
        elif split == "val":
            return (f"{task}.val.jsonl", )
        elif split == "test":
            return (f"{task}.test.jsonl", )
        else:
            raise RuntimeError("split %s is not found!" % split)

    def _get_mask_token(self, token):
        # BERT-style replacement: 80% [MASK], 10% keep the original token,
        # 10% a random token id (skipping the lowest special-token ids).
        p = random.random()
        if p < 0.8:
            return self.mask_token_id
        elif p < 0.9:
            return token
        else:
            return random.randint(3, self.language_vocab_size - 1)

    def _masking_on_text_tokens(self, tokens, num_tokens, mask_prob):
        bool_masked_pos = [0] * len(tokens)
        # Mask round(num_tokens * mask_prob) positions, at least 1 and at most
        # num_tokens - 1; position 0 (the BOS token) is never masked.
        to_mask = min(int(num_tokens * mask_prob + 0.5), num_tokens - 1)
        to_mask = max(to_mask, 1)
        num_masked_tokens = 0
        while num_masked_tokens < to_mask:
            i = random.randint(1, num_tokens - 1)
            if bool_masked_pos[i] == 0:
                bool_masked_pos[i] = 1
                tokens[i] = self._get_mask_token(tokens[i])
                num_masked_tokens += 1

        return tokens, bool_masked_pos

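    # Illustration (not from the original code): with mask_prob=0.4 and a padded
    # sequence [bos, 51, 52, 53, eos, pad, pad, pad] (num_tokens=5),
    # to_mask = min(int(5 * 0.4 + 0.5), 4) = 2, so two distinct positions in
    # 1..4 are replaced via _get_mask_token and bool_masked_pos might come back
    # as [0, 1, 0, 0, 1, 0, 0, 0].
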
    def __getitem__(self, index: int):
        data = dict()
        item = self.items[index]
        img_path = item["image_path"]
        img = self._get_image(img_path)
        data["image"] = img
        data["image_id"] = item["image_id"]

        text_segment = item["text_segment"]
        if text_segment is not None:
            language_tokens, padding_mask, num_tokens = self._get_text_segment(text_segment)
            masked_tokens = language_tokens[:]
            masked_tokens, language_masked_pos = \
                self._masking_on_text_tokens(masked_tokens, num_tokens, self.mask_prob)
            data["language_tokens"] = language_tokens
            data["masked_tokens"] = masked_tokens
            data["language_masked_pos"] = language_masked_pos
            data["padding_mask"] = padding_mask
        return data

    @staticmethod
    def make_coco_captioning_dataset_index(data_path, tokenizer):
        _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train")
        _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val")
        _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test")

    @staticmethod
    def make_nocaps_captioning_dataset_index(data_path):
        _make_nocaps_dataset_index(data_path, split="val")
        _make_nocaps_dataset_index(data_path, split="test")


task2dataset = {
    "nlvr2": NLVR2Dataset,
    "vqav2": VQAv2Dataset,
    "flickr30k": RetrievalDataset,
    "coco_retrieval": RetrievalDataset,
    "coco_captioning": CaptioningDataset,
    "nocaps": CaptioningDataset,
    "imagenet": ImageNetDataset,
}


def create_dataloader(dataset, is_train, batch_size, num_workers, pin_mem, dist_eval=False):
    if is_train or dist_eval:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()

        if not is_train and dist_eval and len(dataset) % num_tasks != 0:
            print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
                  'This will slightly alter validation results as extra duplicate entries are added to achieve '
                  'equal num of samples per-process.')

        sampler = torch.utils.data.DistributedSampler(
            dataset, num_replicas=num_tasks, rank=global_rank, shuffle=is_train
        )
    else:
        sampler = torch.utils.data.SequentialSampler(dataset)

    return torch.utils.data.DataLoader(
        dataset, sampler=sampler,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=pin_mem,
        drop_last=is_train,
        collate_fn=utils.merge_batch_tensors_by_dict_key,
    )


def build_transform(is_train, args):
    if args.task in ["imagenet"]:
        return build_imagenet_transform(is_train, args)

    if is_train:
        t = [
            RandomResizedCropAndInterpolation(args.input_size, scale=(0.5, 1.0), interpolation=args.train_interpolation),
            transforms.RandomHorizontalFlip(),
        ]
        if args.randaug:
            t.append(
                RandomAugment(
                    2, 7, isPIL=True,
                    augs=[
                        'Identity', 'AutoContrast', 'Equalize', 'Brightness', 'Sharpness',
                        'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate',
                    ]))
        t += [
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
        ]
        t = transforms.Compose(t)
    else:
        t = transforms.Compose([
            transforms.Resize((args.input_size, args.input_size), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD)
        ])

    return t


def build_imagenet_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
            mean=IMAGENET_DEFAULT_MEAN,
            std=IMAGENET_DEFAULT_STD,
        )
        if not resize_im:
            # For small images (<= 32px), replace the random resized crop with a
            # padded random crop.
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        if args.crop_pct is None:
            args.crop_pct = 1.0
        size = int(args.input_size / args.crop_pct)
        t.append(
            transforms.Resize(size, interpolation=3),  # 3 == PIL.Image.BICUBIC
        )
        t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)


def get_sentencepiece_model_for_beit3(args):
    from transformers import XLMRobertaTokenizer
    return XLMRobertaTokenizer(args.sentencepiece_model)


def create_dataset_by_split(args, split, is_train=True):
    transform = build_transform(is_train=is_train, args=args)
    dataset_class = task2dataset[args.task]
    tokenizer = get_sentencepiece_model_for_beit3(args)

    opt_kwargs = {}
    if args.task in ["coco_captioning", "nocaps"]:
        opt_kwargs["mask_prob"] = args.captioning_mask_prob

    dataset = dataset_class(
        data_path=args.data_path, split=split,
        transform=transform, tokenizer=tokenizer,
        num_max_bpe_tokens=args.num_max_bpe_tokens,
        task=args.task, **opt_kwargs,
    )
    if is_train:
        batch_size = args.batch_size
    elif hasattr(args, "eval_batch_size") and args.eval_batch_size is not None:
        batch_size = args.eval_batch_size
    else:
        batch_size = int(args.batch_size * 1.5)

    return create_dataloader(
        dataset, is_train=is_train, batch_size=batch_size,
        num_workers=args.num_workers, pin_mem=args.pin_mem, dist_eval=args.dist_eval,
    )


def create_downstream_dataset(args, is_eval=False):
    if is_eval:
        return create_dataset_by_split(args, split="test", is_train=False)
    else:
        return (
            create_dataset_by_split(args, split="train", is_train=True),
            create_dataset_by_split(args, split="val", is_train=True),
        )
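

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module. The namespace below
    # only illustrates which fields the loaders expect; the paths are
    # hypothetical, and real runs pass the argparse args from the training
    # scripts (typically under torch.distributed).
    from types import SimpleNamespace

    args = SimpleNamespace(
        task="imagenet",                           # any key of task2dataset
        data_path="/datasets/imagenet",            # hypothetical index/data root
        sentencepiece_model="/models/beit3.spm",   # hypothetical tokenizer file
        num_max_bpe_tokens=64,
        input_size=224, train_interpolation="bicubic",
        color_jitter=0.4, aa="rand-m9-mstd0.5-inc1",
        reprob=0.25, remode="pixel", recount=1, crop_pct=None,
        randaug=True, captioning_mask_prob=0.6,
        batch_size=32, eval_batch_size=None,
        num_workers=4, pin_mem=True, dist_eval=False,
    )
    train_loader, val_loader = create_downstream_dataset(args)
    print(train_loader.dataset)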