Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoTokenizer | |
from extended_embeddings.token_classification import ExtendedEmbeddigsRobertaForTokenClassification | |
from data_manipulation.dataset_funcions import load_gazetteers, gazetteer_matching, align_gazetteers_with_tokens | |
from data_manipulation.preprocess_gazetteers import build_reverse_dictionary | |
def load(
    model_name="ufal/robeczech-base",
    model_path="bettystr/NerRoB-czech",
    gazetteers_path="gazz2.json",
    device="cpu",
):
    """Load the tokenizer, the NER model and the gazetteer lookups used by ``run``.

    Args:
        model_name: Hub id (or local path) passed to ``AutoTokenizer.from_pretrained``.
        model_path: Hub id (or local path) of the
            ``ExtendedEmbeddigsRobertaForTokenClassification`` checkpoint.
        gazetteers_path: File passed to ``load_gazetteers``.
        device: Device the model is moved to (defaults to CPU, as before).

    Returns:
        ``(tokenizer, model, gazetteers_for_matching)``. ``gazetteers_for_matching``
        is a list with one entry per key of the loaded gazetteers, each built by
        ``build_reverse_dictionary`` from the single-key dict ``{key: value}``.
    """
    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Only used for inference: switch the model to evaluation mode once here.
    model.eval()

    gazetteers = load_gazetteers(gazetteers_path)
    # One reverse dictionary per gazetteer key, in the order the keys were loaded.
    gazetteers_for_matching = [
        build_reverse_dictionary({key: entries})
        for key, entries in gazetteers.items()
    ]
    return tokenizer, model, gazetteers_for_matching
def run(tokenizer, model, gazetteers_for_matching, text):
    """Tag ``text`` with the gazetteer-augmented NER model.

    Args:
        tokenizer: Tokenizer returned by ``load``. It must support ``word_ids()``,
            i.e. be a "fast" tokenizer.
        model: Model returned by ``load``; expected to already be in eval mode.
        gazetteers_for_matching: Gazetteer lookups returned by ``load``.
        text: Raw input string (not pre-split into words).

    Returns:
        Space-separated label names from ``model.config.id2label``, one per
        position of the tokenized input (special tokens included).
    """
    # Build inputs on the model's device (CPU for a model created by load()).
    device = next(model.parameters()).device

    tokenized_inputs = tokenizer(text, truncation=True, is_split_into_words=False)
    word_ids = tokenized_inputs.word_ids()

    matches = gazetteer_matching(text, gazetteers_for_matching)
    # NOTE(review): presumably one (per, org, loc) triple per token position,
    # aligned via word_ids - confirm against align_gazetteers_with_tokens.
    aligned = align_gazetteers_with_tokens(matches, word_ids)
    per_values = [item[0] for item in aligned]
    org_values = [item[1] for item in aligned]
    loc_values = [item[2] for item in aligned]

    # Every tensor gets a leading batch dimension of size 1.
    input_ids = torch.tensor(tokenized_inputs["input_ids"], device=device).unsqueeze(0)
    attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device=device).unsqueeze(0)
    per = torch.tensor(per_values, device=device).unsqueeze(0)
    org = torch.tensor(org_values, device=device).unsqueeze(0)
    loc = torch.tensor(loc_values, device=device).unsqueeze(0)

    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            per=per,
            org=org,
            loc=loc,
        ).logits

    # Most likely label id at each position of the single batch item.
    predicted_ids = torch.argmax(logits, dim=2)[0].tolist()
    return " ".join(model.config.id2label[idx] for idx in predicted_ids)