import json
import os
from collections import Counter
from collections import OrderedDict
from typing import List, Optional

import torch
from ordered_set import OrderedSet
from transformers import AutoTokenizer

from common.utils import download, unzip_file


def get_tokenizer(tokenizer_name: str):
    """Automatically get a tokenizer by name.

    Args:
        tokenizer_name (str): "word_tokenizer" or the name of any pretrained tokenizer on Hugging Face.

    Returns:
        Any: tokenizer object
    """
    if tokenizer_name == "word_tokenizer":
        return WordTokenizer(tokenizer_name)
    else:
        return AutoTokenizer.from_pretrained(tokenizer_name)


def get_tokenizer_class(tokenizer_name: str):
    """Automatically get a tokenizer class by name.

    Args:
        tokenizer_name (str): "word_tokenizer" or the name of any pretrained tokenizer on Hugging Face.

    Returns:
        Any: tokenizer class, or the AutoTokenizer.from_pretrained factory for pretrained tokenizers
    """
    if tokenizer_name == "word_tokenizer":
        return WordTokenizer
    else:
        return AutoTokenizer.from_pretrained
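
# Usage sketch (shown as a comment so nothing runs on import). "bert-base-uncased"
# is only an illustrative Hugging Face checkpoint name, not something this module requires.
#
#   word_tok = get_tokenizer("word_tokenizer")        # -> WordTokenizer instance
#   bert_tok = get_tokenizer("bert-base-uncased")     # -> any pretrained tokenizer
#   tok_cls = get_tokenizer_class("word_tokenizer")   # -> the WordTokenizer class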

# Input-shape flags for WordTokenizer.__call__: a batch of sequences vs. a single sequence.
BATCH_STATE = 1
INSTANCE_STATE = 2


class WordTokenizer(object):

    def __init__(self, name):
        self.__name = name
        self.index2instance = OrderedSet()
        self.instance2index = OrderedDict()
        # Counter object recording how often each element
        # occurs in the raw text.
        self.counter = Counter()

        self.__sign_pad = "[PAD]"
        self.add_instance(self.__sign_pad)
        self.__sign_unk = "[UNK]"
        self.add_instance(self.__sign_unk)

    @property
    def padding_side(self):
        return "right"

    @property
    def all_special_ids(self):
        return [self.unk_token_id, self.pad_token_id]

    @property
    def name_or_path(self):
        return self.__name

    @property
    def vocab_size(self):
        return len(self.instance2index)

    @property
    def pad_token_id(self):
        return self.instance2index[self.__sign_pad]

    @property
    def unk_token_id(self):
        return self.instance2index[self.__sign_unk]

    def add_instance(self, instance):
        """Add instances to the alphabet.

        1. Any iterable data structure whose elements are of str type is supported.
        2. Added instances are counted, which influences the serialization of unknown instances.

        Args:
            instance: a single instance or a list of instances.
        """
        if isinstance(instance, (list, tuple)):
            for element in instance:
                self.add_instance(element)
            return

        # We only support elements of str type.
        assert isinstance(instance, str)

        # Count the frequency of instances.
        # self.counter[instance] += 1

        if instance not in self.index2instance:
            self.instance2index[instance] = len(self.index2instance)
            self.index2instance.append(instance)
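
    # Vocabulary-building sketch (hypothetical data, shown as a comment): training
    # tokens must be registered via add_instance() before the tokenizer is called,
    # otherwise every word is mapped to [UNK].
    #
    #   tokenizer = WordTokenizer("word_tokenizer")
    #   for tokens in train_token_lists:          # e.g. [["play", "a", "song"], ...]
    #       tokenizer.add_instance(tokens)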

    def __call__(self, instance,
                 return_tensors="pt",
                 is_split_into_words=True,
                 padding=True,
                 add_special_tokens=False,
                 truncation=True,
                 max_length=512,
                 **config):
        """Convert raw text to ids, mirroring the Hugging Face tokenizer call interface."""
        if isinstance(instance, (list, tuple)) and isinstance(instance[0], str) and is_split_into_words:
            # a single pre-tokenized sequence
            res = self.get_index(instance)
            state = INSTANCE_STATE
        elif isinstance(instance, str) and not is_split_into_words:
            # a single raw string, split on spaces
            res = self.get_index(instance.split(" "))
            state = INSTANCE_STATE
        elif not is_split_into_words and isinstance(instance, (list, tuple)):
            # a batch of raw strings
            res = [self.get_index(ins.split(" ")) for ins in instance]
            state = BATCH_STATE
        else:
            # a batch of pre-tokenized sequences
            res = [self.get_index(ins) for ins in instance]
            state = BATCH_STATE

        if truncation:
            if state == BATCH_STATE:
                res = [r[:max_length] for r in res]
            else:
                res = res[:max_length]

        pad_id = self.get_index(self.__sign_pad)
        if padding and state == BATCH_STATE:
            # pad every sequence in the batch to the length of the longest one
            max_len = max(len(r) for r in res)
            for i in range(len(res)):
                res[i] = res[i] + [pad_id] * (max_len - len(res[i]))

        if return_tensors == "pt":
            input_ids = torch.tensor(res, dtype=torch.long)
            attention_mask = (input_ids != pad_id).long()
        elif state == BATCH_STATE:
            input_ids = res
            attention_mask = [[1 if r != pad_id else 0 for r in batch] for batch in res]
        else:
            input_ids = res
            attention_mask = [1 if r != pad_id else 0 for r in res]

        # token_type_ids carry no meaning for a plain word tokenizer,
        # so the attention mask is reused to keep the field populated.
        return TokenizedData(input_ids, token_type_ids=attention_mask, attention_mask=attention_mask)
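
    # Call sketch (assumes a vocabulary was built via add_instance() as sketched above):
    #
    #   batch = tokenizer([["play", "a", "song"], ["list", "playlists"]],
    #                     return_tensors="pt", is_split_into_words=True)
    #   # batch.input_ids      -> LongTensor of shape [2, 3], short rows padded with [PAD]
    #   # batch.attention_mask -> 1 for real tokens, 0 for padding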

    def get_index(self, instance):
        """Serialize the given instance(s) and return the index(es).

        Unknown words are mapped to the index of the "[UNK]" sign.

        Args:
            instance (Any): a single instance or a list of instances.

        Returns:
            Any: the serialization of the queried instance(s).
        """
        if isinstance(instance, (list, tuple)):
            return [self.get_index(elem) for elem in instance]

        assert isinstance(instance, str)

        try:
            return self.instance2index[instance]
        except KeyError:
            return self.instance2index[self.__sign_unk]

    def decode(self, index):
        """Get the corresponding instance(s) of the queried index(es).

        Throws an exception if the index is invalid.

        Args:
            index (int): the queried index, possibly iterable.

        Returns:
            the corresponding instance(s).
        """
        if isinstance(index, list):
            return [self.decode(elem) for elem in index]
        if isinstance(index, torch.Tensor):
            index = index.tolist()
            return self.decode(index)
        return self.index2instance[index]

    def decode_batch(self, index, **kargs):
        """Get the corresponding instances of a batch of queried indexes.

        Throws an exception if an index is invalid.

        Args:
            index: the queried indexes, possibly iterable.

        Returns:
            the corresponding instances.
        """
        return self.decode(index)

    def save(self, path):
        """Save the content of the alphabet to a file.

        The alphabet is written as a single JSON object holding the
        tokenizer name and the token-to-index map.

        Args:
            path (str): the path to save the object to.
        """
        with open(path, 'w', encoding="utf8") as fw:
            fw.write(json.dumps({"name": self.__name, "token_map": self.instance2index}))

    @staticmethod
    def from_file(path):
        """Load a WordTokenizer from a JSON file produced by save()."""
        with open(path, 'r', encoding="utf8") as fw:
            obj = json.load(fw)
            tokenizer = WordTokenizer(obj["name"])
            tokenizer.instance2index = OrderedDict(obj["token_map"])
            # tokenizer.counter = len(tokenizer.instance2index)
            tokenizer.index2instance = OrderedSet(tokenizer.instance2index.keys())
            return tokenizer
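
    # Save / load round-trip sketch (hypothetical path, shown as a comment):
    #
    #   tokenizer.save("save/word_tokenizer.json")
    #   restored = WordTokenizer.from_file("save/word_tokenizer.json")
    #   assert restored.vocab_size == tokenizer.vocab_size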

    def __len__(self):
        return len(self.index2instance)

    def __str__(self):
        return 'Alphabet {} contains about {} words: \n\t{}'.format(self.name_or_path, len(self), self.index2instance)

    def convert_tokens_to_ids(self, tokens):
        """Convert a token sequence to an input ids sequence.

        Unknown tokens are mapped to the id of "[UNK]".

        Args:
            tokens (Any): token sequence

        Returns:
            Any: input ids sequence
        """
        if isinstance(tokens, (list, tuple)):
            return [self.convert_tokens_to_ids(x) for x in tokens]
        try:
            return self.instance2index[tokens]
        except KeyError:
            return self.instance2index[self.__sign_unk]


class TokenizedData():
    """Tokenized output data with input_ids, token_type_ids and attention_mask."""

    def __init__(self, input_ids, token_type_ids, attention_mask):
        self.input_ids = input_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask

    def word_ids(self, index: int) -> List[Optional[int]]:
        """Get the word id list of one sequence in the batch.

        Args:
            index (int): index of the sequence in the batch

        Returns:
            List[Optional[int]]: the word id for every token position, or None for padding
        """
        return [j if self.attention_mask[index][j] != 0 else None for j, _ in enumerate(self.input_ids[index])]

    def word_to_tokens(self, index, word_id, **kwargs):
        """Map a word to its token span (words and tokens are one-to-one for the word tokenizer).

        Args:
            index (int): unused
            word_id (int): word index in the sequence
        """
        return (word_id, word_id + 1)

    def to(self, device):
        """Move all tensors to the given device.

        Args:
            device (str): e.g. "cpu" or "cuda"
        """
        self.input_ids = self.input_ids.to(device)
        self.token_type_ids = self.token_type_ids.to(device)
        self.attention_mask = self.attention_mask.to(device)
        return self


def load_embedding(tokenizer: WordTokenizer, glove_name: str):
    """Load GloVe embeddings from the Stanford server or the local cache.

    Args:
        tokenizer (WordTokenizer): non-pretrained tokenizer
        glove_name (str): name of the GloVe vector file, e.g. "glove.6B.300d.txt"

    Returns:
        Any: word embedding matrix of shape [vocab_size, dim]
    """
    save_path = "save/" + glove_name + ".zip"
    if not os.path.exists(save_path):
        download("http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip#" + glove_name, save_path)
        unzip_file(save_path, "save/" + glove_name)
    # the embedding dimension is encoded in the file name, e.g. "300d" -> 300
    dim = int(glove_name.split(".")[-2][:-1])
    # words without a pretrained vector keep a random one; the padding token gets zeros
    embedding_list = torch.rand((tokenizer.vocab_size, dim))
    embedding_list[tokenizer.pad_token_id] = torch.zeros((1, dim))
    with open("save/" + glove_name + "/" + glove_name, "r", encoding="utf8") as f:
        for line in f:
            datas = line.split(" ")
            word = datas[0]
            embedding = torch.Tensor([float(datas[i + 1]) for i in range(len(datas) - 1)])
            tokenized = tokenizer.convert_tokens_to_ids(word)
            if isinstance(tokenized, int) and tokenized != tokenizer.unk_token_id:
                embedding_list[tokenized] = embedding
    return embedding_list
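
# Usage sketch: wiring the loaded matrix into an nn.Embedding layer.
# "glove.6B.300d.txt" is an illustrative file name from the GloVe 6B archive;
# shown as a comment so no download is triggered on import.
#
#   import torch.nn as nn
#   tokenizer = get_tokenizer("word_tokenizer")
#   # ... register training tokens via tokenizer.add_instance(...) first ...
#   weight = load_embedding(tokenizer, "glove.6B.300d.txt")
#   embedding = nn.Embedding.from_pretrained(weight, freeze=False,
#                                            padding_idx=tokenizer.pad_token_id)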