import os

from transformers import BertTokenizerFast

class MiniSunTokenizer:
    """Thin wrapper around BertTokenizerFast for the MiniSun model."""

    def __init__(self, vocab_file=None):
        if vocab_file:
            # Build a cased WordPiece tokenizer from the supplied vocabulary file.
            self.tokenizer = BertTokenizerFast(vocab_file=vocab_file, do_lower_case=False)
        else:
            # No custom vocabulary: fall back to the stock (lowercasing)
            # bert-base-uncased tokenizer.
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

        # Special tokens, following the standard BERT conventions.
        self.pad_token = '[PAD]'
        self.unk_token = '[UNK]'
        self.cls_token = '[CLS]'
        self.sep_token = '[SEP]'
        self.mask_token = '[MASK]'

    def tokenize(self, text):
        """Split text into WordPiece tokens without adding special tokens."""
        return self.tokenizer.tokenize(text)

    def encode(self, text, max_length=512, padding=True, truncation=True):
        """Encode text into TF tensors: (input_ids, attention_mask)."""
        encoded = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=max_length,
            padding='max_length' if padding else False,
            truncation=truncation,
            return_attention_mask=True,
            return_tensors='tf',  # requires TensorFlow to be installed
        )
        return encoded['input_ids'], encoded['attention_mask']

    def decode(self, token_ids):
        """Convert token ids back to text, dropping special tokens like [CLS]."""
        return self.tokenizer.decode(token_ids, skip_special_tokens=True)

    def save_pretrained(self, save_directory):
        """Write the vocabulary and tokenizer config files to save_directory."""
        os.makedirs(save_directory, exist_ok=True)
        self.tokenizer.save_pretrained(save_directory)
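

# A minimal save/reload sketch: save_pretrained delegates to the wrapped
# BertTokenizerFast, so the saved directory can be reloaded with the stock
# Hugging Face loader. The helper name and the 'minisun_tokenizer' default
# path are illustrative assumptions, not part of the original class.
def demo_save_and_reload(tok, path='minisun_tokenizer'):
    tok.save_pretrained(path)
    return BertTokenizerFast.from_pretrained(path)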


if __name__ == '__main__':
    tokenizer = MiniSunTokenizer()

    text = "Hello, this is a test sentence for MiniSun model."
    input_ids, attention_mask = tokenizer.encode(text, max_length=20)

    print("Input IDs:", input_ids)
    print("Attention Mask:", attention_mask)