import einops
import torch
import torchvision
from PIL import Image


class DummyVisionTokenizer:
  """Identity tokenizer that treats flattened pixel values as token ids."""

  def __init__(self, vocab_size, image_size,
               add_mask_token=True,
               add_special_tokens=True):
    self.pad_token_id = None
    self.pad_token = None
    if add_mask_token:
      self.mask_token = vocab_size
      self.mask_token_id = vocab_size
      self.vocab_size = vocab_size + 1  # mask token
    else:
      self.vocab_size = vocab_size
    if add_special_tokens:
      # Allocate bos/eos after any mask token so the ids do not collide.
      self.bos_token_id = self.vocab_size
      self.bos_token = self.vocab_size
      self.eos_token_id = self.vocab_size + 1
      self.eos_token = self.vocab_size + 1
      self.vocab_size = self.vocab_size + 2  # bos_token, eos_token
    self.image_size = image_size

  def __call__(self, x):
    return x

  def batch_decode(self, x):
    # Reshape a batch of flattened pixel sequences back into images.
    return einops.rearrange(x, "b (c h w) -> b c h w",
                            c=3, h=self.image_size)

  def decode(self, x):
    # Reshape a single flattened pixel sequence back into an image.
    return einops.rearrange(x, "(c h w) -> c h w",
                            c=3, h=self.image_size)


class DiscreteCIFAR10(torchvision.datasets.CIFAR10):
  def __init__(self, root, train, **kwargs):
    super().__init__(root=root, train=train,
                     **kwargs)
    self.transform = torchvision.transforms.Compose(
      [
        torchvision.transforms.Resize(32),
        torchvision.transforms.RandomHorizontalFlip(),
        torchvision.transforms.ToTensor(),
        # Flatten each (c, h, w) image into a 1-D sequence of pixel values.
        torchvision.transforms.Lambda(
          lambda x: einops.rearrange(x, "c h w -> (c h w)")),
      ]
    )

  def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
    img, target = self.data[index], self.targets[index]

    # doing this so that it is consistent with all other datasets
    # to return a PIL Image
    img = Image.fromarray(img)

    if self.transform is not None:
      img = self.transform(img)
    # ToTensor scales pixels to [0, 1]; round back to integer token ids in [0, 255]
    # (plain truncation can be off by one due to floating-point error).
    img = (img * 255).round().to(torch.long)

    if self.target_transform is not None:
      target = self.target_transform(target)

    attention_mask = torch.ones_like(img)

    return {'input_ids': img, 'labels': target,
            'attention_mask': attention_mask}