|
|
|
from transformers import BertTokenizerFast |
|
import os |
|
import tensorflow as tf |
|
|
|
class MiniSunTokenizer:
    """Convenience wrapper around a ``BertTokenizerFast`` instance.

    The wrapper exposes a small ``encode`` / ``decode`` / ``save_pretrained``
    surface and can be called directly (``tokenizer(text)``) as a shortcut
    for ``encode``.
    """

    def __init__(self, vocab_file=None):
        """Build the underlying tokenizer.

        Args:
            vocab_file: Optional path to a vocabulary file. When given, a
                ``BertTokenizerFast`` is built from it with
                ``do_lower_case=False``. When omitted (or falsy), the
                ``'bert-base-uncased'`` pretrained tokenizer is loaded.
        """
        if vocab_file:
            self.tokenizer = BertTokenizerFast(vocab_file=vocab_file, do_lower_case=False)
        else:
            self.tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

        # Special-token strings exposed as plain attributes. They are fixed
        # literals and are not read back from ``self.tokenizer``.
        self.pad_token = '[PAD]'
        self.unk_token = '[UNK]'
        self.cls_token = '[CLS]'
        self.sep_token = '[SEP]'
        self.mask_token = '[MASK]'
        # NOTE(review): '[EOS]' is only stored here; nothing in this class
        # registers it with the underlying tokenizer. Confirm it exists in
        # the vocabulary before relying on it.
        self.eos_token = '[EOS]'

    def encode(self, text, max_length=512, padding=True, truncation=True):
        """Encode a single string or a list of strings.

        A ``list`` is treated as a batch; any other value is encoded as a
        single input.

        Args:
            text: A string, or a list of strings.
            max_length: Maximum sequence length passed to the tokenizer.
            padding: If truthy, pad every sequence to ``max_length``;
                otherwise do not pad.
            truncation: Forwarded to the tokenizer's ``truncation`` argument.

        Returns:
            A dict with the keys ``'input_ids'`` and ``'attention_mask'``,
            taken from the tokenizer output (requested as NumPy tensors).
        """
        if isinstance(text, list):
            return self._encode_batch(text, max_length, padding, truncation)
        return self._encode_single(text, max_length, padding, truncation)

    @staticmethod
    def _encoding_kwargs(max_length, padding, truncation):
        """Return the keyword arguments shared by single and batch encoding."""
        return {
            'add_special_tokens': True,
            'max_length': max_length,
            'padding': 'max_length' if padding else False,
            'truncation': truncation,
            'return_attention_mask': True,
            'return_tensors': 'np',
        }

    @staticmethod
    def _select_outputs(encoded):
        """Keep only the ``input_ids`` and ``attention_mask`` entries."""
        return {
            'input_ids': encoded['input_ids'],
            'attention_mask': encoded['attention_mask'],
        }

    def _encode_single(self, text, max_length=512, padding=True, truncation=True):
        """Encode one input with the tokenizer's ``encode_plus``."""
        encoded = self.tokenizer.encode_plus(
            text,
            **self._encoding_kwargs(max_length, padding, truncation)
        )
        return self._select_outputs(encoded)

    def _encode_batch(self, texts, max_length=512, padding=True, truncation=True):
        """Encode a list of inputs with the tokenizer's ``batch_encode_plus``."""
        encoded_batch = self.tokenizer.batch_encode_plus(
            texts,
            **self._encoding_kwargs(max_length, padding, truncation)
        )
        return self._select_outputs(encoded_batch)

    @staticmethod
    def _is_batch_of_ids(token_ids):
        """Tell whether ``token_ids`` holds several id sequences (is 2-D)."""
        ndim = getattr(token_ids, 'ndim', None)
        if ndim is not None:
            # Array/tensor-like input: trust its reported rank.
            return ndim >= 2
        if isinstance(token_ids, (list, tuple)) and len(token_ids) > 0:
            first = token_ids[0]
            return isinstance(first, (list, tuple)) or getattr(first, 'ndim', 0) >= 1
        return False

    def decode(self, token_ids):
        """Decode token ids back to text, skipping special tokens.

        ``encode`` requests NumPy tensors, so the ``input_ids`` it returns
        are batch-shaped. To make that output decodable, batch-shaped input
        is decoded one row at a time.

        Args:
            token_ids: A flat sequence of ids, or a 2-D collection (nested
                list/tuple, or an object whose ``ndim`` is at least 2).

        Returns:
            A string for flat input; a list of strings (one per row) for
            batch-shaped input.
        """
        if self._is_batch_of_ids(token_ids):
            return [
                self.tokenizer.decode(row, skip_special_tokens=True)
                for row in token_ids
            ]
        return self.tokenizer.decode(token_ids, skip_special_tokens=True)

    def save_pretrained(self, save_directory):
        """Save the underlying tokenizer into ``save_directory``.

        The directory is created first if it does not exist.
        """
        os.makedirs(save_directory, exist_ok=True)
        self.tokenizer.save_pretrained(save_directory)

    def __call__(self, text, *args, **kwargs):
        """Allow ``tokenizer(text)`` as a shortcut for ``encode``.

        All positional and keyword arguments are forwarded to ``encode``,
        which decides between single and batch encoding.
        """
        return self.encode(text, *args, **kwargs)
|
|
|
|
|
|
|
# Module-level default instance. It is built with no vocab_file, so the
# 'bert-base-uncased' tokenizer is loaded whenever this module is imported.
tokenizer = MiniSunTokenizer()