import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import BertModel
from tensor2tensor.data_generators import text_encoder
class LatinBERT(nn.Module):
    def __init__(self, bertPath, tokenizerPath):
        super().__init__()
        self.tokenizer = LatinTokenizer(tokenizerPath)
        self.model = BertModel.from_pretrained(bertPath)
        self.model.eval()

    @torch.no_grad()
    def forward(self, sentences):
        if not isinstance(sentences, list):
            sentences = [sentences]
        # 512 here is the tokenizer's maximum batch size (sentences per batch)
        tokens_ids, masks, transforms = self.tokenizer.tokenize(sentences, 512)
        # truncate to BERT's maximum sequence length of 512 tokens
        if tokens_ids.shape[-1] > 512:
            tokens_ids = torch.narrow(tokens_ids, -1, 0, 512)
            masks = torch.narrow(masks, -1, 0, 512)
        tokens_ids = tokens_ids.reshape((-1, tokens_ids.shape[-1]))
        masks = masks.reshape((-1, masks.shape[-1]))
        # pass the attention mask so padding positions are ignored
        outputs = self.model(tokens_ids, attention_mask=masks)
        # pooled [CLS] representation, L2-normalized per sentence
        embeddings = F.normalize(outputs.pooler_output, p=2, dim=1).cpu()
        return embeddings

    @property
    def dim(self):
        return 768
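
# The tokenizer also returns a word-to-wordpiece "transform" matrix that
# LatinBERT above ignores: row i holds weights 1/k over the k wordpiece
# positions of word i, so multiplying it with the token-level outputs averages
# each word's wordpieces. The subclass below is an illustrative sketch of that
# idea (the class name is hypothetical, it is not part of the original code,
# and it assumes the input fits into a single batch of at most 512 tokens).
class LatinBERTWordEmbeddings(LatinBERT):
    @torch.no_grad()
    def forward(self, sentences):
        if not isinstance(sentences, list):
            sentences = [sentences]
        tokens_ids, masks, transforms = self.tokenizer.tokenize(sentences, 512)
        tokens_ids = tokens_ids.reshape((-1, tokens_ids.shape[-1]))
        masks = masks.reshape((-1, masks.shape[-1]))
        transforms = transforms.reshape((-1, transforms.shape[-2], transforms.shape[-1]))
        outputs = self.model(tokens_ids, attention_mask=masks)
        # (batch, words, wordpieces) @ (batch, wordpieces, hidden) -> (batch, words, hidden)
        return torch.bmm(transforms, outputs.last_hidden_state)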
class LatinTokenizer:
def __init__(self, model):
self.vocab = dict()
self.reverseVocab = dict()
self.encoder = text_encoder.SubwordTextEncoder(model)
self.vocab["[PAD]"] = 0
self.vocab["[UNK]"] = 1
self.vocab["[CLS]"] = 2
self.vocab["[SEP]"] = 3
self.vocab["[MASK]"] = 4
for key in self.encoder._subtoken_string_to_id:
self.vocab[key] = self.encoder._subtoken_string_to_id[key] + 5
self.reverseVocab[self.encoder._subtoken_string_to_id[key] + 5] = key
    def convert_tokens_to_ids(self, tokens):
        # [PAD]/[UNK]/[CLS]/[SEP]/[MASK] map to ids 0-4; everything else is
        # looked up in the subword vocabulary built in __init__
        special = {"[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3, "[MASK]": 4}
        return [special[t] if t in special else self.vocab[t] for t in tokens]
    def tokenize(self, sentences, max_batch):
        # sentences may be raw strings: split them into words and wrap each
        # sentence with the [CLS]/[SEP] special tokens the BERT model expects
        prepared = []
        for sentence in sentences:
            if isinstance(sentence, str):
                sentence = sentence.split(" ")
            prepared.append(["[CLS]"] + list(sentence) + ["[SEP]"])
        sentences = prepared

        # find the longest sentence (in wordpieces) across the whole input
        maxLen = 0
        for sentence in sentences:
            length = 0
            for word in sentence:
                toks = self._tokenize(word)
                length += len(toks)
            if length > maxLen:
                maxLen = length

        all_data = []
        all_masks = []
        all_labels = []
        all_transforms = []
        for sentence in sentences:
            tok_ids = []
            input_mask = []
            labels = []
            transform = []

            # wordpiece-tokenize every word first so we know the total
            # number of subword tokens n for this sentence
            all_toks = []
            n = 0
            for idx, word in enumerate(sentence):
                toks = self._tokenize(word)
                all_toks.append(toks)
                n += len(toks)

            # build one transform row per word: row idx spreads weight
            # 1/len(toks) over the wordpiece positions belonging to that word
            cur = 0
            for idx, word in enumerate(sentence):
                toks = all_toks[idx]
                ind = list(np.zeros(n))
                for j in range(cur, cur + len(toks)):
                    ind[j] = 1. / len(toks)
                cur += len(toks)
                transform.append(ind)
                tok_ids.extend(self.convert_tokens_to_ids(toks))
                input_mask.extend(np.ones(len(toks)))
                labels.append(1)

            all_data.append(tok_ids)
            all_masks.append(input_mask)
            all_labels.append(labels)
            all_transforms.append(transform)
lengths = np.array([len(l) for l in all_data])
# Note sequence must be ordered from shortest to longest so current_batch will work
ordering = np.argsort(lengths)
ordered_data = [None for i in range(len(all_data))]
ordered_masks = [None for i in range(len(all_data))]
ordered_labels = [None for i in range(len(all_data))]
ordered_transforms = [None for i in range(len(all_data))]
for i, ind in enumerate(ordering):
ordered_data[i] = all_data[ind]
ordered_masks[i] = all_masks[ind]
ordered_labels[i] = all_labels[ind]
ordered_transforms[i] = all_transforms[ind]
        batched_data = []
        batched_mask = []
        batched_labels = []
        batched_transforms = []

        i = 0
        current_batch = max_batch
        while i < len(ordered_data):
            batch_data = ordered_data[i:i + current_batch]
            batch_mask = ordered_masks[i:i + current_batch]
            batch_labels = ordered_labels[i:i + current_batch]
            batch_transforms = ordered_transforms[i:i + current_batch]

            # pad every sentence in the batch to the longest one
            max_len = max([len(sent) for sent in batch_data])
            max_label = max([len(label) for label in batch_labels])
            for j in range(len(batch_data)):
                blen = len(batch_data[j])
                blab = len(batch_labels[j])
                for k in range(blen, max_len):
                    batch_data[j].append(0)
                    batch_mask[j].append(0)
                    for z in range(len(batch_transforms[j])):
                        batch_transforms[j][z].append(0)
                for k in range(blab, max_label):
                    batch_labels[j].append(-100)
                for k in range(len(batch_transforms[j]), max_label):
                    batch_transforms[j].append(list(np.zeros(max_len)))

            batched_data.append(batch_data)
            batched_mask.append(batch_mask)
            batched_labels.append(batch_labels)
            batched_transforms.append(batch_transforms)

            i += current_batch

            # adjust batch size; sentences are ordered from shortest to longest,
            # so decrease the batch size as they get longer
            if max_len > 100:
                current_batch = 12
            if max_len > 200:
                current_batch = 6

        # note: stacking into single tensors assumes all batches share the same
        # padded length, which holds whenever everything fits into one batch
        return (torch.LongTensor(batched_data).squeeze(),
                torch.FloatTensor(batched_mask).squeeze(),
                torch.FloatTensor(batched_transforms).squeeze())
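
    # Illustrative note (not from the original code): for the two demo
    # sentences in main() below, which fit into a single batch, the three
    # squeezed tensors returned above have these shapes:
    #   token ids  : LongTensor  (2, max_len)             padded wordpiece ids
    #   masks      : FloatTensor (2, max_len)             1 for real tokens, 0 for padding
    #   transforms : FloatTensor (2, max_words, max_len)  word-averaging weights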
    def _tokenize(self, text):
        # wordpiece-tokenize a single word (or space-separated string):
        # special tokens pass through unchanged, everything else goes through
        # the subword encoder and is mapped back to its token string
        tokens = text.split(" ")
        wp_tokens = []
        for token in tokens:
            if token in {"[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"}:
                wp_tokens.append(token)
            else:
                wp_toks = self.encoder.encode(token)
                for wp in wp_toks:
                    wp_tokens.append(self.reverseVocab[wp + 5])
        return wp_tokens
def main():
model = LatinBERT("../../latinBert/latin_bert/models/latin_bert", tokenizerPath="./tokenizer/latin.subword.encoder")
sents = ["arma virumque cano", "arma gravi numero violentaque bella parabam"]
output = model(sents)
print("end", output.shape)
if __name__ == "__main__":
main()
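
# Because LatinBERT L2-normalizes its pooled embeddings, cosine similarity
# between two sentences reduces to a dot product. A minimal usage sketch,
# reusing the model and sentences from main() above:
#
#     emb = model(sents)           # shape (2, 768), rows are unit-length
#     similarity = emb @ emb.T     # (2, 2) matrix of cosine similarities
#     print(similarity)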