import einops
import torch
import torchvision
from PIL import Image


class DummyVisionTokenizer:
    """Identity tokenizer for images whose pixel values already act as
    discrete token ids. It exposes the small tokenizer-style interface
    (special token ids, decode/batch_decode) that callers expect."""

    def __init__(self, vocab_size, image_size,
                 add_mask_token=True,
                 add_special_tokens=True):
        # Images are fixed-length sequences, so no padding token is needed.
        self.pad_token_id = None
        self.pad_token = None
        if add_mask_token:
            self.mask_token = vocab_size
            self.mask_token_id = vocab_size
            self.vocab_size = vocab_size + 1
        else:
            self.vocab_size = vocab_size
        if add_special_tokens:
            # BOS/EOS ids follow the (possibly mask-extended) vocabulary so
            # they never collide with the mask token id above.
            self.bos_token_id = self.vocab_size
            self.bos_token = self.vocab_size
            self.eos_token_id = self.vocab_size + 1
            self.eos_token = self.vocab_size + 1
            self.vocab_size = self.vocab_size + 2
        self.image_size = image_size

    def __call__(self, x):
        # Pixel values are already token ids, so tokenization is the identity.
        return x

    def batch_decode(self, x):
        # (batch, c * h * w) token ids -> (batch, c, h, w) images.
        return einops.rearrange(x, "b (c h w) -> b c h w", c=3,
                                h=self.image_size)

    def decode(self, x):
        # (c * h * w) token ids -> (c, h, w) image.
        return einops.rearrange(x, "(c h w) -> c h w", c=3,
                                h=self.image_size)
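

# Example round trip for the tokenizer above (illustrative values only:
# CIFAR-10-sized 32x32 RGB images with 256 possible pixel values):
#
#   tokenizer = DummyVisionTokenizer(vocab_size=256, image_size=32)
#   tokens = torch.randint(0, 256, (4, 3 * 32 * 32))
#   images = tokenizer.batch_decode(tokens)  # -> shape (4, 3, 32, 32)

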
class DiscreteCIFAR10(torchvision.datasets.CIFAR10):
    def __init__(self, root, train, **kwargs):
        super().__init__(root=root, train=train, **kwargs)
        # Flatten each augmented 32x32 RGB image into a 1D pixel sequence.
        self.transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(32),
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Lambda(
                    lambda x: einops.rearrange(x, "c h w -> (c h w)")),
            ]
        )

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            dict: 'input_ids' (flattened pixel token ids), 'labels' (index of
            the target class), and 'attention_mask' (all ones, since images
            are fixed-length and never padded).
        """
        img, target = self.data[index], self.targets[index]

        # Convert to a PIL Image so the torchvision transforms can be applied.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)
            # ToTensor rescales pixels to [0, 1]; map them back to integer
            # token ids in [0, 255].
            img = (img * 255).to(torch.long)

        if self.target_transform is not None:
            target = self.target_transform(target)

        attention_mask = torch.ones_like(img)

        return {'input_ids': img, 'labels': target,
                'attention_mask': attention_mask}
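

if __name__ == "__main__":
    # Minimal smoke test (a sketch: the data root and download flag below are
    # illustrative choices, not required by the classes above).
    dataset = DiscreteCIFAR10(root="./data", train=True, download=True)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4)
    batch = next(iter(loader))
    print(batch["input_ids"].shape)       # torch.Size([4, 3072])
    print(batch["input_ids"].dtype)       # torch.int64
    print(batch["labels"].shape)          # torch.Size([4])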