"""Tweaked version of corresponding AllenNLP file""" | |
import logging | |
from collections import defaultdict | |
from typing import Dict, List, Callable | |
from allennlp.common.util import pad_sequence_to_length | |
from allennlp.data.token_indexers.token_indexer import TokenIndexer | |
from allennlp.data.tokenizers.token import Token | |
from allennlp.data.vocabulary import Vocabulary | |
from overrides import overrides | |
from transformers import AutoTokenizer | |
from utils.helpers import START_TOKEN | |
from gector.tokenization import tokenize_batch | |
import copy | |
logger = logging.getLogger(__name__) | |


# TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.
class TokenizerIndexer(TokenIndexer[int]):
    """
    A token indexer that does the wordpiece tokenization (e.g. for BERT embeddings).
    If you are using one of the pretrained BERT models, you'll want to use the
    ``PretrainedBertIndexer`` subclass rather than this base class.

    Parameters
    ----------
    tokenizer : ``Callable[[str], List[str]]``
        A function that does the actual tokenization.
    max_pieces : ``int``, optional (default=``512``)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Any inputs longer than this will
        either be truncated (default), or be split apart and batched using a
        sliding window.
    max_pieces_per_token : ``int``, optional (default=``3``)
        The maximum number of BPE pieces kept for a single input token.
    token_min_padding_length : ``int``, optional (default=``0``)
        See :class:`TokenIndexer`.
    """

    def __init__(self,
                 tokenizer: Callable[[str], List[str]],
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 3,
                 token_min_padding_length: int = 0) -> None:
        super().__init__(token_min_padding_length)

        # The BERT code itself does a two-step tokenization:
        #    sentence -> [words], and then word -> [wordpieces]
        # In AllenNLP, the first step is implemented as the ``BertBasicWordSplitter``,
        # and this token indexer handles the second.
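        # Illustration (added; not part of the original file): for a BERT-style
        # wordpiece vocabulary the two steps look roughly like
        #     "the prices are unaffordable ."
        #     -> ["the", "prices", "are", "unaffordable", "."]              (word split)
        #     -> ["the", "prices", "are", "un", "##afford", "##able", "."]  (wordpieces)
        # The exact pieces depend on the pretrained vocabulary in use.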
        self.tokenizer = tokenizer
        self.max_pieces_per_token = max_pieces_per_token
        self.max_pieces = max_pieces
        # Hard-coded cap on the number of BPE pieces allowed per sentence.
        self.max_pieces_per_sentence = 80

    @overrides
    def tokens_to_indices(self, tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[int]]:
        text = [token.text for token in tokens]
        batch_tokens = [text]

        output_fast = tokenize_batch(self.tokenizer,
                                     batch_tokens,
                                     max_bpe_length=self.max_pieces,
                                     max_bpe_pieces=self.max_pieces_per_token)
        # tokenize_batch works on a batch of sentences; we passed a single
        # sentence, so unwrap the first element of every returned list.
        output_fast = {k: v[0] for k, v in output_fast.items()}
        return output_fast

    @overrides
    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]):
        # If we only use pretrained models, we don't need to do anything here.
        pass

    @overrides
    def get_padding_token(self) -> int:
        return 0

    @overrides
    def get_padding_lengths(self, token: int) -> Dict[str, int]:  # pylint: disable=unused-argument
        return {}

    @overrides
    def pad_token_sequence(self,
                           tokens: Dict[str, List[int]],
                           desired_num_tokens: Dict[str, int],
                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
        return {key: pad_sequence_to_length(val, desired_num_tokens[key])
                for key, val in tokens.items()}
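
    # Illustration (added; not part of the original file): AllenNLP's
    # ``pad_sequence_to_length`` pads on the right with zeros by default and
    # truncates anything longer than the desired length, e.g.
    #     pad_sequence_to_length([101, 7592, 102], 5) -> [101, 7592, 102, 0, 0]
    #     pad_sequence_to_length([101, 7592, 102], 2) -> [101, 7592]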

    @overrides
    def get_keys(self, index_name: str) -> List[str]:
        """
        We need to override this because the indexer generates multiple keys.
        """
        # pylint: disable=no-self-use
        return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"]
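
# Illustration (added; not part of the original file): with index_name="bert",
# get_keys() returns ["bert", "bert-offsets", "bert-type-ids", "mask"]; these
# become the keys of the tensor dictionary AllenNLP builds for the text field.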


class PretrainedBertIndexer(TokenizerIndexer):
    # pylint: disable=line-too-long
    """
    A ``TokenIndexer`` corresponding to a pretrained BERT model.

    Parameters
    ----------
    pretrained_model : ``str``
        Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
        or the path to the .txt file with its vocabulary.
        If the name is a key in the list of pretrained models at
        https://github.com/huggingface/pytorch-pretrained-BERT/blob/master/pytorch_pretrained_bert/tokenization.py#L33
        the corresponding path will be used; otherwise it will be interpreted as a path or URL.
    do_lowercase : ``bool``, optional (default=``True``)
        Whether to lowercase the tokens before converting to wordpiece ids.
    max_pieces : ``int``, optional (default=``512``)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Any inputs longer than this will
        either be truncated (default), or be split apart and batched using a
        sliding window.
    max_pieces_per_token : ``int``, optional (default=``5``)
        The maximum number of BPE pieces kept for a single input token.
    special_tokens_fix : ``int``, optional (default=``0``)
        If nonzero, ``START_TOKEN`` is added to the tokenizer and to its vocabulary.
    """

    def __init__(self,
                 pretrained_model: str,
                 do_lowercase: bool = True,
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 5,
                 special_tokens_fix: int = 0) -> None:
        if pretrained_model.endswith("-cased") and do_lowercase:
            logger.warning("Your BERT model appears to be cased, "
                           "but your indexer is lowercasing tokens.")
        elif pretrained_model.endswith("-uncased") and not do_lowercase:
            logger.warning("Your BERT model appears to be uncased, "
                           "but your indexer is not lowercasing tokens.")

        model_name = copy.deepcopy(pretrained_model)
        model_tokenizer = AutoTokenizer.from_pretrained(
            model_name, do_lower_case=do_lowercase, do_basic_tokenize=False, use_fast=True)

        # Normalize the vocab attribute across tokenizer families:
        # BPE-style tokenizers (e.g. GPT-2/RoBERTa) keep their vocabulary in ``encoder``,
        if hasattr(model_tokenizer, 'encoder'):
            model_tokenizer.vocab = model_tokenizer.encoder
        # while sentencepiece-based tokenizers expose ``sp_model``; rebuild a
        # piece -> id mapping, defaulting unknown pieces to id 1.
        if hasattr(model_tokenizer, 'sp_model'):
            model_tokenizer.vocab = defaultdict(lambda: 1)
            for i in range(model_tokenizer.sp_model.get_piece_size()):
                model_tokenizer.vocab[model_tokenizer.sp_model.id_to_piece(i)] = i

        if special_tokens_fix:
            # Register START_TOKEN as an extra token and give it the last id.
            model_tokenizer.add_tokens([START_TOKEN])
            model_tokenizer.vocab[START_TOKEN] = len(model_tokenizer) - 1

        super().__init__(tokenizer=model_tokenizer,
                         max_pieces=max_pieces,
                         max_pieces_per_token=max_pieces_per_token)
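

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original module).
# Assumes the "bert-base-uncased" tokenizer can be loaded through
# transformers.AutoTokenizer and that gector.tokenization.tokenize_batch is
# importable, as it is above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Build an indexer for an uncased BERT model with the default settings.
    indexer = PretrainedBertIndexer(pretrained_model="bert-base-uncased")
    # Index a single pre-tokenized sentence.
    tokens = [Token(word) for word in ["the", "cat", "sat", "on", "the", "mat", "."]]
    indices = indexer.tokens_to_indices(tokens, Vocabulary(), "bert")
    # The returned dictionary holds per-sentence id lists; the exact key names
    # come from tokenize_batch, so print them rather than assuming.
    for key, values in indices.items():
        print(key, values)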