from transformers import AutoTokenizer


class DaedalusTokenizer:
    """Thin wrapper around a Hugging Face tokenizer with project defaults."""

    def __init__(self, config):
        # AutoTokenizer is a factory class and cannot be subclassed or
        # instantiated directly, so we load the underlying tokenizer and
        # delegate to it. Assumes the config exposes the checkpoint name
        # as `model_name_or_path`.
        self.config = config
        self.tokenizer = AutoTokenizer.from_pretrained(config.model_name_or_path)

    def encode(self, text):
        # Pad and truncate every sequence to the configured maximum length.
        return self.tokenizer(
            text,
            max_length=self.config.max_seq_length,
            padding="max_length",
            truncation=True,
        )

    def decode(self, ids):
        # Delegate to the wrapped tokenizer (the original called itself,
        # which recursed forever).
        return self.tokenizer.decode(ids, skip_special_tokens=True)
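

# A minimal usage sketch, assuming a simple config object; the checkpoint
# name ("bert-base-uncased") and the attribute names below are illustrative
# assumptions, not part of the original code.
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(model_name_or_path="bert-base-uncased", max_seq_length=128)
    tok = DaedalusTokenizer(cfg)

    encoded = tok.encode("Daedalus built the labyrinth.")
    print(encoded["input_ids"][:10])         # padded/truncated token ids
    print(tok.decode(encoded["input_ids"]))  # round-trips back to text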