huamnifierWithSimpleGrammer / gector /
Initial commit for Gradio app with GECToR
history blame
6.64 kB
"""Tweaked version of corresponding AllenNLP file"""
import logging
from collections import defaultdict
from typing import Dict, List, Callable
from allennlp.common.util import pad_sequence_to_length
from import TokenIndexer
from import Token
from import Vocabulary
from overrides import overrides
from transformers import AutoTokenizer
from utils.helpers import START_TOKEN
from gector.tokenization import tokenize_batch
import copy
logger = logging.getLogger(__name__)
# TODO(joelgrus): Figure out how to generate token_type_ids out of this token indexer.
class TokenizerIndexer(TokenIndexer[int]):
    """
    A token indexer that does the wordpiece-tokenization (e.g. for BERT embeddings).
    If you are using one of the pretrained BERT models, you'll want to use the
    ``PretrainedBertIndexer`` subclass rather than this base class.

    Parameters
    ----------
    tokenizer : ``Callable[[str], List[str]]``
        A function that does the actual tokenization.
    max_pieces : int, optional (default: 512)
        The BERT embedder uses positional embeddings and so has a corresponding
        maximum length for its input ids. Any inputs longer than this will
        either be truncated (default), or be split apart and batched using a
        sliding window.
        NOTE(review): this wording is inherited from AllenNLP; in this class the
        value is only stored and handed to ``tokenize_batch`` -- confirm how
        over-long inputs are actually handled there.
    max_pieces_per_token : int, optional (default: 3)
        Per-token wordpiece limit; stored and handed to ``tokenize_batch``.
    token_min_padding_length : ``int``, optional (default=``0``)
        See :class:`TokenIndexer`.
    """

    def __init__(self,
                 tokenizer: Callable[[str], List[str]],
                 max_pieces: int = 512,
                 max_pieces_per_token: int = 3,
                 token_min_padding_length: int = 0) -> None:
        # Forward the padding setting to the base class; previously the
        # argument was accepted but silently dropped and the parent was never
        # initialised.
        super().__init__(token_min_padding_length)

        # The BERT code itself does a two-step tokenization:
        #    sentence -> [words], and then word -> [wordpieces]
        # In AllenNLP, the first step is implemented as the ``BertBasicWordSplitter``,
        # and this token indexer handles the second.
        self.tokenizer = tokenizer
        self.max_pieces_per_token = max_pieces_per_token
        self.max_pieces = max_pieces
        self.max_pieces_per_sentence = 80

    def tokens_to_indices(self, tokens: List[Token],
                          vocabulary: Vocabulary,
                          index_name: str) -> Dict[str, List[int]]:
        """
        Index one sentence by running it through ``tokenize_batch`` as a
        single-element batch.

        ``vocabulary`` and ``index_name`` are part of the ``TokenIndexer``
        interface but are not used here.

        Returns a dict with the same keys ``tokenize_batch`` produces, each
        mapped to the entry for this (only) sentence of the batch.
        """
        # pylint: disable=unused-argument
        text = [token.text for token in tokens]
        # ``tokenize_batch`` works on batches, so wrap the sentence in a list.
        batch_tokens = [text]

        # NOTE(review): every argument after ``self.tokenizer`` was lost in the
        # available source. They are reconstructed from the values this class
        # builds/stores and does not use anywhere else -- confirm the keyword
        # names against ``gector.tokenization.tokenize_batch``.
        output_fast = tokenize_batch(self.tokenizer,
                                     batch_tokens,
                                     max_bpe_length=self.max_pieces,
                                     max_bpe_pieces=self.max_pieces_per_token)
        # Unwrap the batch dimension: keep element 0 of every returned field.
        output_fast = {k: v[0] for k, v in output_fast.items()}
        return output_fast

    def count_vocab_items(self, token: Token, counter: Dict[str, Dict[str, int]]) -> None:
        """Intentionally a no-op; nothing is added to ``counter``."""
        # If we only use pretrained models, we don't need to do anything here.
        pass  # pylint: disable=unnecessary-pass

    def get_padding_token(self) -> int:
        """Return the padding value, which is always ``0``."""
        return 0

    def get_padding_lengths(self, token: int) -> Dict[str, int]:  # pylint: disable=unused-argument
        """Return an empty dict; ``token`` is ignored."""
        return {}

    def pad_token_sequence(self,
                           tokens: Dict[str, List[int]],
                           desired_num_tokens: Dict[str, int],
                           padding_lengths: Dict[str, int]) -> Dict[str, List[int]]:  # pylint: disable=unused-argument
        """
        Pad every list in ``tokens`` to the length given for the same key in
        ``desired_num_tokens`` using ``pad_sequence_to_length``.

        ``padding_lengths`` is not used. A key present in ``tokens`` but absent
        from ``desired_num_tokens`` raises ``KeyError``.
        """
        return {key: pad_sequence_to_length(val, desired_num_tokens[key])
                for key, val in tokens.items()}

    def get_keys(self, index_name: str) -> List[str]:
        """
        We need to override this because the indexer generates multiple keys.
        """
        # pylint: disable=no-self-use
        return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"]
class PretrainedBertIndexer(TokenizerIndexer):
# pylint: disable=line-too-long
A ``TokenIndexer`` corresponding to a pretrained BERT model.
pretrained_model: ``str``
Either the name of the pretrained model to use (e.g. 'bert-base-uncased'),
or the path to the .txt file with its vocabulary.
If the name is a key in the tokenizer library's list of pretrained models,
the corresponding path will be used; otherwise it will be interpreted as a path or URL.
do_lowercase: ``bool``, optional (default = True)
Whether to lowercase the tokens before converting to wordpiece ids.
max_pieces: int, optional (default: 512)
The BERT embedder uses positional embeddings and so has a corresponding
maximum length for its input ids. Any inputs longer than this will
either be truncated (default), or be split apart and batched using a
sliding window.
def __init__(self,
pretrained_model: str,
do_lowercase: bool = True,
max_pieces: int = 512,
max_pieces_per_token: int = 5,
special_tokens_fix: int = 0) -> None:
if pretrained_model.endswith("-cased") and do_lowercase:
logger.warning("Your BERT model appears to be cased, "
"but your indexer is lowercasing tokens.")
elif pretrained_model.endswith("-uncased") and not do_lowercase:
logger.warning("Your BERT model appears to be uncased, "
"but your indexer is not lowercasing tokens.")
model_name = copy.deepcopy(pretrained_model)
model_tokenizer = AutoTokenizer.from_pretrained(
model_name, do_lower_case=do_lowercase, do_basic_tokenize=False, use_fast=True)
# to adjust all tokenizers
if hasattr(model_tokenizer, 'encoder'):
model_tokenizer.vocab = model_tokenizer.encoder
if hasattr(model_tokenizer, 'sp_model'):
model_tokenizer.vocab = defaultdict(lambda: 1)
for i in range(model_tokenizer.sp_model.get_piece_size()):
model_tokenizer.vocab[model_tokenizer.sp_model.id_to_piece(i)] = i
if special_tokens_fix:
model_tokenizer.vocab[START_TOKEN] = len(model_tokenizer) - 1