|
import json |
|
|
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
from torchvision.datasets import ImageFolder |
|
|
|
|
|
class COCOFlickrDataset(Dataset):
    """Dataset of (image, caption) pairs described by a JSON annotation file.

    The annotation file must be a JSON object with an ``"annotations"`` list;
    each entry is read for its ``"image_id"`` and ``"caption"`` keys.

    Image files are looked up inside ``image_dir_path``:

    * Flickr layout (``is_flickr=True``): ``<image_id>.jpg``
    * COCO layout (default): ``<prefix><image_id as 12 zero-padded digits>.jpg``

    Args:
        image_dir_path: Directory that contains the image files.
        annotations_path: Path to the JSON annotation file.
        transform: Optional callable applied to the opened image. When
            ``None`` the image is returned as opened.
        is_flickr: Select the Flickr file-naming layout instead of COCO's.
        prefix: Optional string placed before the zero-padded image id in
            the COCO layout. ``None`` means no prefix.
    """

    def __init__(
        self,
        image_dir_path,
        annotations_path,
        transform=None,
        is_flickr=False,
        prefix=None,
    ):
        self.image_dir_path = image_dir_path
        # Use a context manager so the annotation file handle is closed
        # deterministically instead of being left to the garbage collector.
        with open(annotations_path, encoding="utf-8") as annotations_file:
            self.annotations = json.load(annotations_file)["annotations"]
        self.is_flickr = is_flickr
        self.transform = transform
        self.prefix = prefix

    def __len__(self):
        """Return the number of annotation entries."""
        return len(self.annotations)

    def get_img_path(self, idx):
        """Return the image file path for the annotation at position ``idx``."""
        image_id = self.annotations[idx]["image_id"]
        if self.is_flickr:
            return f"{self.image_dir_path}/{image_id}.jpg"
        # A missing prefix must contribute nothing to the file name; formatting
        # ``None`` directly would yield a literal "None" in the path.
        prefix = "" if self.prefix is None else self.prefix
        return f"{self.image_dir_path}/{prefix}{image_id:012d}.jpg"

    def __getitem__(self, idx):
        """Return ``(image, caption)`` for the annotation at position ``idx``.

        The image is passed through ``self.transform`` when one was supplied.
        """
        image = Image.open(self.get_img_path(idx))
        caption = self.annotations[idx]["caption"]
        # ``transform`` defaults to None, so only call it when it is set.
        if self.transform is not None:
            image = self.transform(image)
        return image, caption
|
|
|
|
|
class ImageNetDataset(ImageFolder):
    """ImageNet1k dataset, implemented as a thin layer over ``ImageFolder``."""

    def __init__(self, root, **kwargs):
        """Initialise the parent class with ``root`` and any extra keyword arguments."""
        super(ImageNetDataset, self).__init__(root=root, **kwargs)

    def __getitem__(self, idx):
        """Return the ``(sample, target)`` pair the parent class yields for ``idx``."""
        # Unpack and re-pack so the result is always a fresh two-element tuple.
        image, label = super(ImageNetDataset, self).__getitem__(idx)
        return image, label
|
|