# NerRoB-czech / website_script.py
import torch
from transformers import AutoTokenizer
from extended_embeddings.token_classification import ExtendedEmbeddigsRobertaForTokenClassification
from data_manipulation.dataset_funcions import load_gazetteers, gazetteer_matching, align_gazetteers_with_tokens
from data_manipulation.preprocess_gazetteers import build_reverse_dictionary


def load():
    """Load the NER model, its tokenizer, and the gazetteer lookup tables."""
    model_name = "ufal/robeczech-base"
    model_path = "bettystr/NerRoB-czech"
    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    # Load the gazetteers and build one reverse lookup dictionary per
    # category, so matching can go from surface form to entity label.
    gazetteers_path = "gazz2.json"
    gazetteers_for_matching = load_gazetteers(gazetteers_path)
    gazetteers_for_matching = [
        build_reverse_dictionary({key: value})
        for key, value in gazetteers_for_matching.items()
    ]
    return tokenizer, model, gazetteers_for_matching


def run(tokenizer, model, gazetteers_for_matching, text):
    """Tag `text` with NER labels and return them as a space-separated string."""
    tokenized_inputs = tokenizer(
        text, truncation=True, is_split_into_words=False
    )

    # Match the raw text against the gazetteers, then align the matches
    # with the subword tokens produced by the tokenizer.
    matches = gazetteer_matching(text, gazetteers_for_matching)
    word_ids = tokenized_inputs.word_ids()
    aligned = [align_gazetteers_with_tokens(matches, word_ids)]

    # Split the aligned gazetteer features into the three per-category
    # channels the model expects: person, organization, location.
    p = [[x[0] for x in sentence] for sentence in aligned]
    o = [[x[1] for x in sentence] for sentence in aligned]
    l = [[x[2] for x in sentence] for sentence in aligned]

    input_ids = torch.tensor(tokenized_inputs["input_ids"], device="cpu").unsqueeze(0)
    attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device="cpu").unsqueeze(0)
    per = torch.tensor(p, device="cpu")
    org = torch.tensor(o, device="cpu")
    loc = torch.tensor(l, device="cpu")

    # Forward pass with the extra gazetteer channels; gradients are not
    # needed for inference.
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits
    predictions = torch.argmax(output, dim=2).tolist()

    # Map predicted label ids back to tag strings (e.g. B-PER, I-LOC, O).
    predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions]
    return " ".join(predicted_tags[0])