import einops
import torch
import torchvision
from PIL import Image


class DummyVisionTokenizer:
    """Identity "tokenizer" for images whose pixels already are discrete tokens.

    Exposes the tokenizer-style attributes (vocab_size, special-token ids)
    and methods (__call__, decode, batch_decode) that the rest of the
    pipeline expects.
    """

    def __init__(self, vocab_size, image_size,
                 add_mask_token=True,
                 add_special_tokens=True):
        self.pad_token_id = None
        self.pad_token = None
        self.vocab_size = vocab_size
        if add_mask_token:
            # The mask token takes the first id past the base vocabulary.
            self.mask_token = vocab_size
            self.mask_token_id = vocab_size
            self.vocab_size += 1  # mask token
        if add_special_tokens:
            # Allocate bos/eos after any mask token so the ids don't collide
            # with mask_token_id when both flags are set.
            self.bos_token_id = self.vocab_size
            self.bos_token = self.vocab_size
            self.eos_token_id = self.vocab_size + 1
            self.eos_token = self.vocab_size + 1
            self.vocab_size += 2  # bos_token, eos_token
        self.image_size = image_size

    def __call__(self, x):
        # Pixels are already token ids, so tokenization is the identity.
        return x

    def batch_decode(self, x):
        # (batch, c*h*w) token ids -> (batch, c, h, w) images.
        return einops.rearrange(x, "b (c h w) -> b c h w", c=3,
                                h=self.image_size)

    def decode(self, x):
        # (c*h*w,) token ids -> (c, h, w) image.
        return einops.rearrange(x, "(c h w) -> c h w", c=3,
                                h=self.image_size)
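
# Illustrative id layout (hypothetical values, not from the original file):
# with vocab_size=256 and the default flags, the mask token gets id 256,
# bos gets 257, eos gets 258, and self.vocab_size ends up as 259.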


class DiscreteCIFAR10(torchvision.datasets.CIFAR10):
    """CIFAR-10 with each image flattened into a sequence of discrete pixel ids."""

    def __init__(self, root, train, **kwargs):
        super().__init__(root=root, train=train, **kwargs)
        # Overrides any transform passed through kwargs: resize, apply a
        # random horizontal flip, convert to a float tensor in [0, 1], and
        # flatten (c, h, w) into a 1D sequence.
        self.transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(32),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Lambda(
                    lambda x: einops.rearrange(x, "c h w -> (c h w)")),
            ]
        )

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict: 'input_ids' (flattened pixel token ids), 'labels' (index of
            the target class) and 'attention_mask' (all ones).
        """
        img, target = self.data[index], self.targets[index]

        # Doing this so that it is consistent with all other datasets:
        # return a PIL Image.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)
        # ToTensor scaled pixels to [0, 1]; map them back to integer token
        # ids in [0, 255]. round() guards against float error before the
        # truncating cast to long.
        img = (img * 255).round().to(torch.long)

        if self.target_transform is not None:
            target = self.target_transform(target)

        attention_mask = torch.ones_like(img)
        return {'input_ids': img, 'labels': target,
                'attention_mask': attention_mask}
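

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file); the root
    # path and the vocab/image sizes below are assumptions. download=True
    # fetches CIFAR-10 into ./data on first use.
    dataset = DiscreteCIFAR10(root="./data", train=True, download=True)
    example = dataset[0]
    print(example['input_ids'].shape)       # torch.Size([3072])
    print(example['attention_mask'].shape)  # torch.Size([3072])

    tokenizer = DummyVisionTokenizer(vocab_size=256, image_size=32)
    image = tokenizer.decode(example['input_ids'])
    print(image.shape)  # torch.Size([3, 32, 32])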