|
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from transformers import BertModel
from tensor2tensor.data_generators import text_encoder
|
|
|
|
|
class LatinBERT(nn.Module): |
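    """Sentence encoder built on a pretrained Latin BERT checkpoint.

    Sentences are subword-tokenized with a tensor2tensor SubwordTextEncoder
    and run through a Hugging Face BertModel; the L2-normalized pooler
    output is returned as the sentence embedding.
    """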
|
|
|
def __init__(self, bertPath, tokenizerPath): |
|
super().__init__() |
|
self.tokenizer = LatinTokenizer(tokenizerPath) |
|
self.model = BertModel.from_pretrained(bertPath) |
|
self.model.eval() |
|
|
|
    @torch.no_grad()
    def __call__(self, sentences):
        if not isinstance(sentences, list):
            sentences = [sentences]

        # tokenize() sorts sentences by length for efficient padding and
        # returns the permutation it used, so input order can be restored
        # below. The transforms (word-level averaging weights) are unused
        # here, since we take sentence-level embeddings from the pooler.
        tokens_ids, masks, transforms, ordering = self.tokenizer.tokenize(sentences, 512)

        # BERT's position embeddings cover at most 512 positions; truncate.
        if tokens_ids.shape[-1] > 512:
            tokens_ids = torch.narrow(tokens_ids, -1, 0, 512)
            masks = torch.narrow(masks, -1, 0, 512)

        tokens_ids = tokens_ids.reshape((-1, tokens_ids.shape[-1]))
        masks = masks.reshape((-1, masks.shape[-1]))

        # Pass the attention mask so padding positions are ignored.
        outputs = self.model(tokens_ids, attention_mask=masks)
        embeddings = F.normalize(outputs.pooler_output, p=2, dim=-1).cpu()

        # Undo the length-based sort so row i corresponds to sentences[i].
        unsort = torch.from_numpy(np.argsort(ordering))
        return embeddings[unsort]
|
|
|
    @property
    def dim(self):
        # Hidden size of the underlying encoder (768 for this BERT-base checkpoint).
        return self.model.config.hidden_size
|
|
|
|
|
class LatinTokenizer: |
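    """Wraps a tensor2tensor SubwordTextEncoder in a BERT-style vocabulary,
    reserving ids 0-4 for [PAD], [UNK], [CLS], [SEP], and [MASK]."""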
|
    def __init__(self, encoder_path):
        # `encoder_path` points at a tensor2tensor SubwordTextEncoder vocab file.
        self.vocab = {}
        self.reverseVocab = {}
        self.encoder = text_encoder.SubwordTextEncoder(encoder_path)

        # Reserve the first five ids for BERT's special tokens and shift the
        # subword vocabulary up by five to make room for them.
        self.vocab["[PAD]"] = 0
        self.vocab["[UNK]"] = 1
        self.vocab["[CLS]"] = 2
        self.vocab["[SEP]"] = 3
        self.vocab["[MASK]"] = 4

        for key, idx in self.encoder._subtoken_string_to_id.items():
            self.vocab[key] = idx + 5
            self.reverseVocab[idx + 5] = key
|
|
|
    def convert_tokens_to_ids(self, tokens):
        # Special tokens are already present in self.vocab, so a single lookup
        # covers everything; unseen strings fall back to [UNK].
        return [self.vocab.get(token, self.vocab["[UNK]"]) for token in tokens]
|
|
|
    def tokenize(self, sentences, max_batch):
        # Convert sentences into padded id/mask tensors, plus per-word
        # "transform" matrices that average subword vectors back to word
        # level. `max_batch` caps the number of sentences per batch and is
        # reduced below for long sequences to bound memory.
|
|
|
        all_data = []
        all_masks = []
        all_labels = []
        all_transforms = []
|
|
|
        for sentence in sentences:
            # Accept either a raw string or a pre-split list of words.
            words = sentence.split(" ") if isinstance(sentence, str) else list(sentence)
            # BERT checkpoints are trained with [CLS] ... [SEP] framing, and
            # the pooler output consumed upstream reads from the [CLS] position.
            words = ["[CLS]"] + words + ["[SEP]"]

            tok_ids = []
            input_mask = []
            labels = []
            transform = []

            # First pass: subword-tokenize every word and count the pieces.
            all_toks = []
            n = 0
            for word in words:
                toks = self._tokenize(word)
                all_toks.append(toks)
                n += len(toks)

            # Second pass: build a (num_words x num_subwords) averaging matrix
            # mapping subword embeddings back to word-level embeddings.
            cur = 0
            for toks in all_toks:
                ind = list(np.zeros(n))
                for j in range(cur, cur + len(toks)):
                    ind[j] = 1. / len(toks)
                cur += len(toks)
                transform.append(ind)

                tok_ids.extend(self.convert_tokens_to_ids(toks))
                input_mask.extend(np.ones(len(toks)))
                labels.append(1)  # one slot per word; sizes the padding below

            all_data.append(tok_ids)
            all_masks.append(input_mask)
            all_labels.append(labels)
            all_transforms.append(transform)
|
|
|
        # Sort sentences by length so each batch pads to a similar width.
        lengths = np.array([len(l) for l in all_data])
        ordering = np.argsort(lengths)
|
|
|
        ordered_data = [all_data[ind] for ind in ordering]
        ordered_masks = [all_masks[ind] for ind in ordering]
        ordered_labels = [all_labels[ind] for ind in ordering]
        ordered_transforms = [all_transforms[ind] for ind in ordering]
|
|
|
        batched_data = []
        batched_mask = []
        batched_labels = []
        batched_transforms = []

        i = 0
        current_batch = max_batch
|
|
|
        while i < len(ordered_data):
            batch_data = ordered_data[i:i + current_batch]
            batch_mask = ordered_masks[i:i + current_batch]
            batch_labels = ordered_labels[i:i + current_batch]
            batch_transforms = ordered_transforms[i:i + current_batch]

            max_len = max(len(sent) for sent in batch_data)
            max_label = max(len(label) for label in batch_labels)

            for j in range(len(batch_data)):
                blen = len(batch_data[j])
                blab = len(batch_labels[j])

                # Right-pad ids and masks to the longest sequence in the
                # batch, and widen each transform row to match.
                for k in range(blen, max_len):
                    batch_data[j].append(0)
                    batch_mask[j].append(0)
                    for z in range(len(batch_transforms[j])):
                        batch_transforms[j][z].append(0)

                # Pad labels with -100, the conventional ignore index.
                for k in range(blab, max_label):
                    batch_labels[j].append(-100)

                # Add all-zero transform rows for the missing word positions.
                for k in range(len(batch_transforms[j]), max_label):
                    batch_transforms[j].append(np.zeros(max_len))

            batched_data.append(batch_data)
            batched_mask.append(batch_mask)
            batched_labels.append(batch_labels)
            batched_transforms.append(batch_transforms)

            i += current_batch

            # Shrink the batch size for longer sequences to bound memory use.
            if max_len > 100:
                current_batch = 12
            if max_len > 200:
                current_batch = 6

        # Stacking below assumes everything fit into a single batch: batches
        # are padded independently, so multiple batches may have ragged
        # widths. `ordering` is returned so callers can restore the original
        # sentence order after the length-based sort.
        return (torch.LongTensor(batched_data).squeeze(0),
                torch.FloatTensor(batched_mask).squeeze(0),
                torch.FloatTensor(batched_transforms).squeeze(0),
                ordering)
|
|
|
|
|
|
def _tokenize(self, text): |
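        # Split on spaces, keep special tokens verbatim, and subword-encode
        # everything else; raw encoder ids are mapped back to their string
        # pieces via reverseVocab (shifted by +5, see __init__).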
|
tokens = text.split(" ") |
|
wp_tokens = [] |
|
for token in tokens: |
|
|
|
if token in {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}: |
|
wp_tokens.append(token) |
|
else: |
|
|
|
wp_toks = self.encoder.encode(token) |
|
|
|
for wp in wp_toks: |
|
wp_tokens.append(self.reverseVocab[wp + 5]) |
|
|
|
return wp_tokens |
|
|
|
def main(): |
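    # bertPath points at a local Latin BERT checkpoint directory,
    # tokenizerPath at its subword vocabulary file; adjust to your checkout.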
|
model = LatinBERT("../../latinBert/latin_bert/models/latin_bert", tokenizerPath="./tokenizer/latin.subword.encoder") |
|
|
|
sents = ["arma virumque cano", "arma gravi numero violentaque bella parabam"] |
|
|
|
|
|
output = model(sents) |
|
print("end", output.shape) |
|
|
|
if __name__ == "__main__": |
|
main() |