import torch
from transformers import LogitsProcessor

# Marker strings used to build the image-token sequence that the processor
# forces the model to emit.
# NOTE(review): all three markers are empty strings in this file, so the
# sequence passed to ``tokenizer.encode`` below is the empty string. This
# looks like the literal marker text was lost somewhere -- confirm the
# intended values against the tokenizer's vocabulary.
BOI_TOKEN = ''
EOI_TOKEN = ''
IMG_TOKEN = ''


class AutoImageTokenGenerationProcessor(LogitsProcessor):
    """Logits processor that walks generation through a fixed token chain.

    At construction time the string ``BOI_TOKEN + IMG_TOKEN.format(0) + ... +
    IMG_TOKEN.format(num_img_gen_tokens - 1) + EOI_TOKEN`` is encoded with the
    given tokenizer. On each call, for every batch row:

    * if the last generated id is any id of that chain except the final one,
      the score of the id that follows it in the chain is raised above every
      other score of that row;
    * otherwise the scores of every chain id except the first are set to 0.0.
    """

    def __init__(self, tokenizer, num_img_gen_tokens=64) -> None:
        """Encode the token chain and precompute the lookup structures.

        Args:
            tokenizer: Object exposing ``encode(text, add_special_tokens=...)``
                that returns a sequence of token ids.
            num_img_gen_tokens: Number of ``IMG_TOKEN`` entries placed between
                ``BOI_TOKEN`` and ``EOI_TOKEN``.
        """
        super().__init__()
        pieces = [BOI_TOKEN]
        pieces.extend(IMG_TOKEN.format(int(item)) for item in range(num_img_gen_tokens))
        pieces.append(EOI_TOKEN)
        self.img_ids_list = tokenizer.encode(''.join(pieces), add_special_tokens=False)

        # Map each chain id (except the last) to the id that follows it.
        # ``setdefault`` keeps the first occurrence when an id repeats, which
        # matches what ``list.index`` would return.
        self._next_img_id = {}
        for current_id, following_id in zip(self.img_ids_list[:-1], self.img_ids_list[1:]):
            self._next_img_id.setdefault(current_id, following_id)

        # Ids whose scores are overwritten when the last token is not part of
        # the chain; built once here instead of once per batch row per step.
        self._reset_index = torch.tensor(self.img_ids_list[1:], dtype=torch.long)

    def __call__(self, input_ids, scores):
        """Adjust ``scores`` in place according to each row's last token.

        Args:
            input_ids: Tensor of already generated ids; the last entry along
                dim 1 of each row is inspected (assumes a 2-D
                ``(batch, seq_len)`` layout -- TODO confirm against callers).
            scores: Tensor of next-token scores indexed by batch row first and
                by token id last; modified in place.

        Returns:
            The same ``scores`` tensor that was passed in.
        """
        # One transfer for the whole batch instead of one ``.item()`` per row.
        last_ids = input_ids[:, -1].tolist()
        reset_index = self._reset_index.to(scores.device)
        next_img_id = self._next_img_id

        for row, last_id in enumerate(last_ids):
            if last_id in next_img_id:
                # Push the chain successor 10 above the current row maximum so
                # that it has the highest score of the row.
                scores[row, ..., next_img_id[last_id]] = scores[row, ...].max() + 10.
            else:
                # NOTE(review): 0.0 is an ordinary score value, not -inf, so
                # this does not necessarily rule these ids out -- confirm the
                # intent before changing it.
                scores[row, ..., reset_index] = 0.0

        return scores