import json
import copy
import pickle

import torch
from simplemma import lemmatize
from transformers import AutoTokenizer

from extended_embeddings.extended_embedding_token_classification import (
    ExtendedEmbeddigsRobertaForTokenClassification,
)
from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens


# code originally from data_manipulation.creation_gazetteers
def lemmatizing(x):
    """Lemmatize a single Czech token; empty strings are returned unchanged."""
    if x == "":
        return ""
    return lemmatize(x, lang="cs")


# code originally from data_manipulation.creation_gazetteers
def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    """
    Invert a gazetteer dict {label: [values]} into a reverse dict {value: label}.
    If apply_lemmatizing is True, lemmatized variants of the values are added as well.
    """
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary


def load_json(path):
    """
    Load gazetteers from a JSON file.

    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, "r") as file:
        data = json.load(file)
    return data


def load_pickle(path):
    """
    Load pickled gazetteers from a file.

    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, "rb") as file:
        data = pickle.load(file)
    return data


def load():
    """
    Load the tokenizer, model, and gazetteers for named entity recognition.

    Returns:
        tokenizer (AutoTokenizer): The tokenizer for tokenizing input text.
        model (ExtendedEmbeddigsRobertaForTokenClassification): The pre-trained model for named entity recognition.
        gazetteers_for_matching (list): A list of reverse-lookup gazetteer dictionaries for matching named entities.
    """
    model_name = "ufal/robeczech-base"
    model_path = "bettystr/NerRoB-czech"
    gazetteers_path = "gazetteers.pkl"

    model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(model_path).to("cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.eval()

    # Turn each {label: [values]} gazetteer into a reverse dictionary {value: label}.
    gazetteers_for_matching = load_pickle(gazetteers_path)
    temp = []
    for label, values in gazetteers_for_matching.items():
        temp.append(build_reverse_dictionary({label: values}))
    gazetteers_for_matching = temp
    return tokenizer, model, gazetteers_for_matching


def add_additional_gazetteers(gazetteers_for_matching, file_names):
    """
    Add user-supplied gazetteers to the existing reverse dictionaries.

    Args:
        gazetteers_for_matching (list): The list of reverse-lookup gazetteer dictionaries to be updated.
        file_names (list): The list of JSON file names containing additional gazetteers.

    Returns:
        list: The updated list of gazetteer dictionaries (the input list is left untouched).
    """
    if not file_names:
        return gazetteers_for_matching
    temp = [copy.deepcopy(dictionary) for dictionary in gazetteers_for_matching]
    for file_name in file_names:
        with open(file_name, "r") as file:
            data = json.load(file)
        for key, value_lst in data.items():
            key = key.upper()
            # Only extend the dictionary that already handles this entity type.
            for dictionary in temp:
                if key in dictionary.values():
                    for value in value_lst:
                        dictionary[value] = key
    return temp


def run(tokenizer, model, gazetteers, text, file_names=None):
    """
    Run the named entity recognition (NER) model on the given text.

    Args:
        tokenizer (Tokenizer): The tokenizer used to tokenize the input text.
        model (Model): The NER model used for prediction.
        gazetteers (list): A list of gazetteers used for matching entities in the text.
        text (str): The input text to perform NER on.
        file_names (list, optional): A list of file names to be used as additional gazetteers.

    Returns:
        list: A list of dictionaries representing the predicted entities in the text.
        Each dictionary contains the following keys:
            - "start" (int): The starting position of the entity in the text.
- "end" (int): The ending position of the entity in the text. - "entity" (str): The type of the entity. - "score" (float): The confidence score of the entity prediction. - "word" (str): The actual word representing the entity. - "count" (int): The number of tokens in the entity. """ gazetteers_for_matching = add_additional_gazetteers(gazetteers, file_names) tokenized_inputs = tokenizer( text, truncation=True, is_split_into_words=False, return_offsets_mapping=True ) matches = gazetteer_matching(text, gazetteers_for_matching) new_g = [] word_ids = tokenized_inputs.word_ids() new_g.append(align_gazetteers_with_tokens(matches, word_ids)) p, o, l = [], [], [] for i in new_g: p.append([x[0] for x in i]) o.append([x[1] for x in i]) l.append([x[2] for x in i]) input_ids = torch.tensor(tokenized_inputs["input_ids"], device="cpu").unsqueeze(0) attention_mask = torch.tensor(tokenized_inputs["attention_mask"], device="cpu").unsqueeze(0) per = torch.tensor(p, device="cpu") org = torch.tensor(o, device="cpu") loc = torch.tensor(l, device="cpu") output = model(input_ids=input_ids, attention_mask=attention_mask, per=per, org=org, loc=loc).logits predictions = torch.argmax(output, dim=2).tolist() predicted_tags = [[model.config.id2label[idx] for idx in sentence] for sentence in predictions] softmax = torch.nn.Softmax(dim=2) scores = softmax(output).squeeze(0).tolist() result = [] temp = { "start": 0, "end": 0, "entity": "O", "score": 0, "word": "", "count": 0 } for pos, entity, score in zip(tokenized_inputs.offset_mapping, predicted_tags[0], scores): if pos[0] == pos[1] or entity == "O": continue if "I-" + temp["entity"] == entity: # same entity temp["word"] += text[temp["end"]:pos[0]] + text[pos[0]:pos[1]] temp["end"] = pos[1] temp["count"] += 1 temp["score"] += max(score) else: # new entity if temp["count"] > 0: temp["score"] += max(score) temp["score"] /= temp.pop("count") result.append(temp) temp = { "start": pos[0], "end": pos[1], "entity": entity[2:], "score": 0, "word": text[pos[0]:pos[1]], "count": 1 } if temp["count"] > 0: temp["score"] += max(score) temp["score"] /= temp.pop("count") result.append(temp) return result