File size: 1,181 Bytes
674d663
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import torch
from transformers import LogitsProcessor

# Token string placed first in the image token sequence built by
# AutoImageTokenGenerationProcessor (presumably "begin of image").
BOI_TOKEN = '<img>'
# Token string placed last in that sequence (presumably "end of image").
EOI_TOKEN = '</img>'
# Template for one image token; the integer index is zero-padded to 5 digits,
# e.g. IMG_TOKEN.format(3) == '<img_00003>'.
IMG_TOKEN = '<img_{:05d}>'


class AutoImageTokenGenerationProcessor(LogitsProcessor):
    """Logits processor that walks generation through a fixed image-token sequence.

    The sequence is ``BOI_TOKEN, IMG_TOKEN(0), ..., IMG_TOKEN(n - 1), EOI_TOKEN``
    encoded with the supplied tokenizer. On every call, for each batch row:

    * if the last generated token is any element of the sequence except the
      final one, the score of the element that follows it is raised above
      every other score in that row, so it is the arg-max choice;
    * otherwise the scores of every sequence element except the first one
      are overwritten with ``0.0``.

    Attributes:
        img_ids_list: token ids of the encoded sequence, in order.
    """

    def __init__(self, tokenizer, num_img_gen_tokens=64) -> None:
        """Encode the image-token sequence once.

        Args:
            tokenizer: object exposing ``encode(text, add_special_tokens=False)``
                that returns a list of token ids.
            num_img_gen_tokens: number of ``IMG_TOKEN`` entries placed between
                ``BOI_TOKEN`` and ``EOI_TOKEN``.
        """
        super().__init__()
        img_token_strs = [IMG_TOKEN.format(idx) for idx in range(num_img_gen_tokens)]
        img_all_token_str = ''.join([BOI_TOKEN, *img_token_strs, EOI_TOKEN])
        self.img_ids_list = tokenizer.encode(img_all_token_str, add_special_tokens=False)

    def __call__(self, input_ids, scores):
        """Adjust ``scores`` in place according to each row's last token.

        Args:
            input_ids: tensor of generated ids, indexed as ``[row, position]``.
            scores: next-token scores, indexed as ``[row, ..., token_id]``
                (assumes the usual ``(batch, vocab)`` layout - TODO confirm).

        Returns:
            The same ``scores`` tensor, modified in place.
        """
        img_ids = self.img_ids_list

        # Successor lookup, built once per call instead of a list scan per row.
        # ``setdefault`` keeps the first occurrence of a repeated id, which is
        # what ``list.index`` would have returned.
        successor = {}
        for token_id, following_id in zip(img_ids, img_ids[1:]):
            successor.setdefault(token_id, following_id)

        # Every sequence id except the first, built once per call directly on
        # the device of ``scores`` rather than once per row on the CPU.
        non_leading_ids = torch.tensor(img_ids[1:], dtype=torch.long, device=scores.device)

        # One transfer for the last token of every row instead of ``.item()`` per row.
        last_token_ids = input_ids[:, -1].tolist()

        for row, last_id in enumerate(last_token_ids):
            forced_id = successor.get(last_id)
            if forced_id is not None:
                # Inside the sequence: lift the next element above the row max.
                scores[row, ..., forced_id] = scores[row, ...].max() + 10.
            else:
                # Outside the sequence (or right after its final element).
                # NOTE(review): the scores are set to 0.0, not -inf; whether
                # that actually discourages these tokens depends on the values
                # of the other scores - confirm this is the intended behaviour.
                scores[row, ..., non_leading_ids] = 0.0

        return scores