|
import torch |
|
import numpy as np |
|
from transformers import RobertaTokenizer, BatchEncoding, RobertaTokenizerFast |
|
import warnings |
|
|
|
|
|
def get_tokenizer(parent_class):
    """Build a tokenizer class on top of ``parent_class`` that can emit ``task_type_ids``.

    Args:
        parent_class: Tokenizer class to derive from. It must provide
            ``__call__``, ``_batch_encode_plus`` and ``_encode_plus``.

    Returns:
        The generated ``TokenizerClass`` (the class itself, not an instance).
    """

    class TokenizerClass(parent_class):
        """Tokenizer that adds ``task_type_ids`` to the batch encoding.

        The task_type_ids are used to pass instruction information to the model.
        A task_type should either be an integer or a sequence of integers with
        the same length as the batch size. When ``task_type`` is omitted (or
        ``None``) the parent's encoding is returned untouched.
        """

        def __init__(self, *args, **kwargs):
            """Forward all arguments unchanged to the parent tokenizer."""
            super().__init__(*args, **kwargs)

        def __call__(self, *args, task_type=None, **kwargs):
            """Tokenize via the parent class, optionally adding ``task_type_ids``."""
            # ``task_type`` is captured by the keyword-only parameter, so it is
            # never forwarded to the parent implementation.
            batch_encoding = super().__call__(*args, **kwargs)
            return self._with_task_type_ids(
                batch_encoding, task_type, kwargs.get('return_tensors')
            )

        def _batch_encode_plus(self, *args, task_type=None, **kwargs):
            """Parent ``_batch_encode_plus`` plus optional ``task_type_ids``."""
            batch_encoding = super()._batch_encode_plus(*args, **kwargs)
            return self._with_task_type_ids(
                batch_encoding, task_type, kwargs.get('return_tensors')
            )

        def _encode_plus(self, *args, task_type=None, **kwargs):
            """Parent ``_encode_plus`` plus optional ``task_type_ids``."""
            batch_encoding = super()._encode_plus(*args, **kwargs)
            return self._with_task_type_ids(
                batch_encoding, task_type, kwargs.get('return_tensors')
            )

        def _with_task_type_ids(self, batch_encoding, task_type, return_tensors):
            """Return ``batch_encoding`` extended with ``task_type_ids`` when requested.

            Shared by the three encoding entry points so the wrapping logic
            lives in exactly one place.
            """
            if task_type is None:
                return batch_encoding
            return BatchEncoding(
                {
                    'task_type_ids': self._get_task_type_ids(batch_encoding, task_type),
                    **batch_encoding,
                },
                tensor_type=return_tensors,
            )

        @staticmethod
        def _get_task_type_ids(batch_encoding: 'BatchEncoding', task_type):
            """Create ``task_type_ids`` shaped like ``batch_encoding['input_ids']``.

            Args:
                batch_encoding: Mapping holding ``'input_ids'`` as a torch
                    tensor, a (nested) list, or a numpy array.
                task_type: A scalar applied to every position, or a sequence
                    with one value per batch row.

            Returns:
                The ids in the same container kind as ``input_ids`` (tensor,
                list or ``np.ndarray``). Any other container kind yields a
                torch tensor and triggers a warning.

            Raises:
                ValueError: If ``input_ids`` cannot be converted to a
                    rectangular tensor (e.g. ragged, unpadded sequences).
                AssertionError: If a non-scalar ``task_type`` does not have one
                    entry per batch row.
            """

            def apply_task_type(m, x):
                # ``as_tensor`` avoids the copy-construct warning that
                # ``torch.tensor`` emits when ``x`` is already a tensor.
                x = torch.as_tensor(x)
                assert (
                    x.dim() == 0 or x.shape[0] == m.shape[0]
                ), 'The shape of task_type does not match the size of the batch.'
                # A scalar broadcasts everywhere; a per-row vector is expanded
                # along the trailing axis so every token of a row gets its value.
                return m * x if x.dim() == 0 else m * x[:, None]

            input_ids = batch_encoding['input_ids']
            is_tensor = isinstance(input_ids, torch.Tensor)

            if is_tensor:
                shape = input_ids.shape
            else:
                try:
                    shape = torch.tensor(input_ids).shape
                except (ValueError, TypeError, RuntimeError) as err:
                    raise ValueError(
                        "Unable to create tensor, you should probably "
                        "activate truncation and/or padding with "
                        "'padding=True' 'truncation=True' to have batched "
                        "tensors with the same length."
                    ) from err

            task_type_ids = apply_task_type(
                torch.ones(shape, dtype=torch.long), task_type
            )

            if is_tensor:
                return task_type_ids
            if isinstance(input_ids, list):
                return task_type_ids.tolist()
            # ``np.ndarray`` is the array *type*; ``np.array`` is only a factory
            # function and makes ``isinstance`` raise ``TypeError``.
            if isinstance(input_ids, np.ndarray):
                return task_type_ids.numpy()

            warnings.warn(
                'input_ids is not a torch tensor, numpy array, or list. Returning torch tensor'
            )
            return task_type_ids

    return TokenizerClass
|
|
|
|
|
# Public tokenizer classes: the task-type-aware subclass is generated once for
# RobertaTokenizer and once for RobertaTokenizerFast, in that order.
JinaTokenizer, JinaTokenizerFast = map(
    get_tokenizer, (RobertaTokenizer, RobertaTokenizerFast)
)
|
|