# NerRoB-czech / data_manipulation / dataset_funcions.py
import os
import re
from tqdm import tqdm
from datasets import Dataset, DatasetDict
from data_manipulation.creation_gazetteers import build_reverse_dictionary, lemmatizing, load_json
####################################################################################################
### GAZETTEERS EMBEDDINGS ##########################################################################
####################################################################################################
def find_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
    """Find matches of up to five consecutive tokens against the gazetteers.

    tokens are the original words, looking_tokens the (possibly lemmatized) words
    used for lookup; matches maps each original token to a list of
    (matched phrase, entity type) pairs and is returned updated.
    """
    i = 0
    n = len(tokens)
    assert n == len(looking_tokens)
    while i < n:
        for length in range(min(5, n - i), 0, -1):  # assuming maximum entity length is 5 tokens
            phrase = ' '.join(looking_tokens[i:i + length])
            for gazetteer in gazetteers:
                if phrase in gazetteer:
                    match_type = gazetteer[phrase]
                    for index in range(i, i + length):
                        matches.setdefault(tokens[index], []).append((phrase, match_type))
        i += 1
    return matches
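# Illustrative sketch (not part of the original module): how the matching helper is meant
# to be called. The gazetteers are reverse dictionaries mapping a surface form to its
# entity type; the dictionary and sentence below are made-up toy examples.
def _demo_find_multi_token_matches():
    toy_gazetteer = {"Ústí nad Labem": "LOC", "Karlova univerzita": "ORG"}
    tokens = ["Bydlí", "v", "Ústí", "nad", "Labem"]
    # every token covered by the phrase "Ústí nad Labem" receives a ("Ústí nad Labem", "LOC") match
    return find_multi_token_matches(tokens, tokens, [toy_gazetteer], {})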
def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
    """Find single-token matches against the gazetteers; tokens shorter than three characters are skipped."""
    n = len(tokens)
    assert n == len(looking_tokens)
    for index in range(n):
        word = looking_tokens[index]
        if len(word) < 3:
            continue
        for gazetteer in gazetteers:
            if word in gazetteer:
                match_type = gazetteer[word]
                matches.setdefault(tokens[index], []).append((word, match_type))
    return matches
def gazetteer_matching(words, gazetteers_for_matching, args=None):
    """Return one [PER, ORG, LOC] score vector per word based on gazetteer matches."""
    ending_ova = True
    method_for_gazetteers_matching = "single"
    apply_lemmatizing = True

    if method_for_gazetteers_matching == "single":
        matches = find_single_token_matches(words, words, gazetteers_for_matching, {})
        if apply_lemmatizing:
            lemmatize_tokens = [lemmatizing(t) for t in words]
            matches = find_single_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
    else:  # multi-token matching
        matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
        if apply_lemmatizing:
            lemmatize_tokens = [lemmatizing(t) for t in words]
            matches = find_multi_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)

    result = []
    for word in words:
        # sort the matches for this word by phrase length (number of spaces), longest first
        mid_res = sorted(matches.get(word, []), key=lambda x: x[0].count(" "), reverse=True)
        per, org, loc = 0, 0, 0
        for res in mid_res:
            # only matches as long as the longest match for this word are counted
            if mid_res[0][0].count(" ") == res[0].count(" "):
                if res[1] == "PER":
                    per = 5
                elif res[1] == "ORG":
                    org = 5
                elif res[1] == "LOC":
                    loc = 5
        # heuristic for Czech female surnames ending in "ová"
        if ending_ova and word.endswith("ová") and word[0].isupper():
            per = 5
        result.append([per, org, loc])
    return result
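# Illustrative sketch (not part of the original module): gazetteer_matching with a toy
# reverse gazetteer. Each input word yields a [PER, ORG, LOC] vector.
def _demo_gazetteer_matching():
    toy_gazetteers = [{"Novák": "PER", "Praha": "LOC"}]
    # "Novák" -> [5, 0, 0] via the gazetteer, "Svobodová" -> [5, 0, 0] via the "ová" heuristic,
    # "Praha" -> [0, 0, 5]; words without any match, such as "dnes", get [0, 0, 0]
    return gazetteer_matching(["Novák", "Svobodová", "Praha", "dnes"], toy_gazetteers)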
####################################################################################################
### CNEC DATASET ###################################################################################
####################################################################################################
def get_dataset_from_cnec(label_mapper: dict, xml_file_path, args):
    """
    Parse one CNEC XML file into a list of examples.

    label_mapper: maps CNEC labels (lowercased prefixes) to integer tags
    """
    # Open and read the XML file as plain text
    id_ = 0
    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
        plain_text = xml_file.read()

    plain_text = plain_text[5:-5]  # remove unnecessary characters
    plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
    plain_text = re.sub(r'[ ]+', ' ', plain_text)
    sentences = plain_text.split("\n")

    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
    data = []
    if args.apply_extended_embeddings:
        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
        temp = []
        for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp

    for sentence in tqdm(sentences):
        entity_mapping = []
        while "<ne type=" in sentence:  # while, because there are nested entities
            nes = re.findall(ne_pattern, sentence)
            for label, entity in nes:
                pattern = f'<ne type="{label}">{entity}</ne>'
                index = sentence.index(pattern)
                temp_index = index
                sentence = sentence.replace(pattern, entity, 1)
                # correct the character offset for the markup and extra spaces that precede the entity
                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
                index = temp_index
                entity_mapping.append((entity, label, index, index + len(entity)))

        entities = []
        for entity, label, start, end in entity_mapping:
            for tag in label_mapper.keys():
                if label.lower().startswith(tag):
                    entities.append((label_mapper[tag], entity, start, end))
                    break
        entities.sort(key=lambda x: len(x[1]), reverse=True)

        words = re.split(r'\s+', sentence)
        tags_per_word = []
        sentence_counter = -1
        for word in words:
            sentence_counter += len(word) + 1
            if len(entities) == 0:
                tags_per_word.append(0)  # tag representing no label for the word
            for index_entity in range(len(entities)):
                if not (sentence_counter - len(word) >= entities[index_entity][2] and
                        sentence_counter <= entities[index_entity][3] and
                        word in entities[index_entity][1]):
                    if index_entity == len(entities) - 1:
                        tags_per_word.append(0)  # tag representing no label for the word
                    continue
                if args.division_to_BI_tags:
                    if sentence_counter - len(word) == entities[index_entity][2]:
                        tags_per_word.append(entities[index_entity][0] * 2 - 1)  # beginning of entity
                    else:
                        tags_per_word.append(entities[index_entity][0] * 2)  # inside of entity
                else:
                    tags_per_word.append(entities[index_entity][0])
                break

        if args.contain_only_label_sentences and tags_per_word.count(0) == len(tags_per_word):
            continue
        if tags_per_word == [] or tags_per_word == [0]:
            continue
        if args.apply_extended_embeddings:
            matching = gazetteer_matching(words, gazetteers_for_matching, args)
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                         "sentence": " ".join(words), "gazetteers": matching})
        else:
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
        id_ += 1
    return data
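# Illustrative sketch (not part of the original module): the <ne> markup that the parser
# above strips. The sentence below is a made-up example in the CNEC style; re.findall with
# ne_pattern returns the (type, surface form) pairs, here [('P', 'Jan Novák'), ('gu', 'Praze')].
def _demo_cnec_markup():
    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
    sentence = '<ne type="P">Jan Novák</ne> bydlí v <ne type="gu">Praze</ne> .'
    return re.findall(ne_pattern, sentence)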
def get_default_dataset_from_cnec(label_mapper: dict, xml_file_path):
    """
    Parse one CNEC XML file into a list of examples, always using BI tags and no gazetteers.

    label_mapper: maps CNEC labels (lowercased prefixes) to integer tags
    """
    # Open and read the XML file as plain text
    id_ = 0
    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
        plain_text = xml_file.read()

    plain_text = plain_text[5:-5]  # remove unnecessary characters
    plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
    plain_text = re.sub(r'[ ]+', ' ', plain_text)
    sentences = plain_text.split("\n")

    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
    data = []
    for sentence in tqdm(sentences):
        entity_mapping = []
        while "<ne type=" in sentence:  # while, because there are nested entities
            nes = re.findall(ne_pattern, sentence)
            for label, entity in nes:
                pattern = f'<ne type="{label}">{entity}</ne>'
                index = sentence.index(pattern)
                temp_index = index
                sentence = sentence.replace(pattern, entity, 1)
                # correct the character offset for the markup and extra spaces that precede the entity
                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
                index = temp_index
                entity_mapping.append((entity, label, index, index + len(entity)))

        entities = []
        for entity, label, start, end in entity_mapping:
            for tag in label_mapper.keys():
                if label.lower().startswith(tag):
                    entities.append((label_mapper[tag], entity, start, end))
                    break
        entities.sort(key=lambda x: len(x[1]), reverse=True)

        words = re.split(r'\s+', sentence)
        tags_per_word = []
        sentence_counter = -1
        for word in words:
            sentence_counter += len(word) + 1
            if len(entities) == 0:
                tags_per_word.append(0)  # tag representing no label for the word
            for index_entity in range(len(entities)):
                if not (sentence_counter - len(word) >= entities[index_entity][2] and
                        sentence_counter <= entities[index_entity][3] and
                        word in entities[index_entity][1]):
                    if index_entity == len(entities) - 1:
                        tags_per_word.append(0)  # tag representing no label for the word
                    continue
                if sentence_counter - len(word) == entities[index_entity][2]:
                    tags_per_word.append(entities[index_entity][0] * 2 - 1)  # beginning of entity
                else:
                    tags_per_word.append(entities[index_entity][0] * 2)  # inside of entity
                break  # stop after the first matching entity, as in get_dataset_from_cnec

        if tags_per_word == [] or tags_per_word == [0]:
            continue
        data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
        id_ += 1
    return data
def create_cnec_dataset(label_mapper: dict, args):
    """Build a DatasetDict with train/validation/test splits from the CNEC XML files."""
    dataset = DatasetDict()
    for part, file_name in zip(["train", "validation", "test"],
                               ["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
        file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
        temp_dataset = get_dataset_from_cnec(label_mapper, file_path, args)
        dataset[part] = Dataset.from_list(temp_dataset)
    return dataset
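# Illustrative sketch (not part of the original module): how create_cnec_dataset might be
# called. The label mapping and the argparse.Namespace fields below are assumptions made up
# for the example; the real values come from the project's configuration.
def _demo_create_cnec_dataset():
    from argparse import Namespace
    label_mapper = {"p": 1, "i": 2, "g": 3}  # hypothetical mapping of CNEC prefixes to tag ids
    args = Namespace(
        cnec_dataset_dir_path="data/cnec2.0",  # hypothetical path to the CNEC XML files
        apply_extended_embeddings=False,
        division_to_BI_tags=True,
        contain_only_label_sentences=False,
    )
    return create_cnec_dataset(label_mapper, args)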
####################################################################################################
### WIKIANN DATASET ################################################################################
####################################################################################################
def load_wikiann_testing_dataset(args):
    """Load the WikiAnn file as a test split; train and validation contain only a dummy example."""
    if args.apply_extended_embeddings:
        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
        temp = []
        for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp

    dataset = []
    index = 0
    sentences = load_tagged_sentences(args.wikiann_dataset_path)
    for sentence in sentences:
        words = [word for word, _ in sentence]
        tags = [tag for _, tag in sentence]
        if args.apply_extended_embeddings:
            matching = gazetteer_matching(words, gazetteers_for_matching, args)
            dataset.append({"id": index, 'tokens': words, 'ner_tags': tags, "gazetteers": matching})
        else:
            dataset.append({"id": index, 'tokens': words, 'ner_tags': tags})
        index += 1

    test = Dataset.from_list(dataset)
    dataset = DatasetDict({"train": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
                           "validation": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
                           "test": test})
    # dataset = DatasetDict({"test": test})
    return dataset
def load_tagged_sentences(file_path):
    sentences = []         # List to hold all sentences
    current_sentence = []  # List to hold current sentence tokens and tags
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()  # Remove any extra whitespace from the line
            if line:
                # Split the line into token and tag
                token_tag_pair = line.split()
                if len(token_tag_pair) == 2:
                    # Add the token and tag tuple to the current sentence
                    current_sentence.append((token_tag_pair[0].split(':')[1], token_tag_pair[1]))
            else:
                # If line is empty and current sentence is not, add it to sentences
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []  # Reset for the next sentence
    # Add the last sentence if the file doesn't end with a blank line
    if current_sentence:
        sentences.append(current_sentence)
    return sentences
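# Illustrative sketch (not part of the original module): the WikiAnn-style format that
# load_tagged_sentences expects, i.e. one "lang:token tag" pair per line with blank lines
# between sentences. The file content below is a made-up example written to a temporary file.
def _demo_load_tagged_sentences():
    import tempfile
    content = "cs:Jan B-PER\ncs:Novák I-PER\ncs:bydlí O\ncs:v O\ncs:Praze B-LOC\n\n"
    with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False, encoding="utf-8") as handle:
        handle.write(content)
        path = handle.name
    # expected result: [[("Jan", "B-PER"), ("Novák", "I-PER"), ("bydlí", "O"), ("v", "O"), ("Praze", "B-LOC")]]
    return load_tagged_sentences(path)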
####################################################################################################
### TOKENIZE DATASET ###############################################################################
####################################################################################################
def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            # If the label is B-XXX we change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)
    return new_labels
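# Illustrative sketch (not part of the original module): aligning word-level labels with
# subword tokens. word_ids mimics what a HuggingFace fast tokenizer returns, with None for
# special tokens and repeated ids for words split into several subwords.
def _demo_align_labels_with_tokens():
    labels = [1, 2, 0]                   # word-level tags, e.g. B-PER, I-PER, O
    word_ids = [None, 0, 0, 1, 2, None]  # [CLS], two subwords of word 0, word 1, word 2, [SEP]
    # expected result: [-100, 1, 2, 2, 0, -100] (the repeated B tag becomes an I tag)
    return align_labels_with_tokens(labels, word_ids)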
def align_gazetteers_with_tokens(gazetteers, word_ids):
    aligned_gazetteers = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            gazetteer = [0, 0, 0] if word_id is None else gazetteers[word_id]
            aligned_gazetteers.append(gazetteer)
        elif word_id is None:
            # Special token
            aligned_gazetteers.append([0, 0, 0])
        else:
            # Same word as previous token
            gazetteer = gazetteers[word_id]
            aligned_gazetteers.append(gazetteer)
    return aligned_gazetteers
def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
    """Tokenize the dataset and align NER tags (and gazetteer vectors) with the subword tokens."""
    def tokenize_and_align_labels(examples):
        tokenized_inputs = tokenizer(
            examples["tokens"], truncation=True, is_split_into_words=True
        )
        all_labels = examples["ner_tags"]
        new_labels = []
        for i, labels in enumerate(all_labels):
            word_ids = tokenized_inputs.word_ids(i)
            new_labels.append(align_labels_with_tokens(labels, word_ids))
        tokenized_inputs["labels"] = new_labels

        if apply_extended_embeddings:
            matches = examples["gazetteers"]
            aligned_matches = []
            for i, match in enumerate(matches):
                word_ids = tokenized_inputs.word_ids(i)
                aligned_matches.append(align_gazetteers_with_tokens(match, word_ids))
            per, org, loc = [], [], []
            for i in aligned_matches:
                per.append([x[0] for x in i])
                org.append([x[1] for x in i])
                loc.append([x[2] for x in i])
            tokenized_inputs["per"] = per
            tokenized_inputs["org"] = org
            tokenized_inputs["loc"] = loc
        return tokenized_inputs

    dataset = raw_dataset.map(
        tokenize_and_align_labels,
        batched=True,
        # remove_columns=raw_dataset["train"].column_names
    )
    return dataset
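# Illustrative sketch (not part of the original module): typical use of the pipeline above.
# The model name, label mapping, and Namespace fields are assumptions made for the example;
# any HuggingFace fast tokenizer works, because word_ids() is only available on fast tokenizers.
def _demo_create_tokenized_dataset():
    from argparse import Namespace
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("ufal/robeczech-base")  # assumed Czech model
    label_mapper = {"p": 1, "i": 2, "g": 3}  # hypothetical CNEC prefix -> tag id mapping
    args = Namespace(cnec_dataset_dir_path="data/cnec2.0",  # hypothetical path
                     apply_extended_embeddings=False,
                     division_to_BI_tags=True,
                     contain_only_label_sentences=False)
    raw_dataset = create_cnec_dataset(label_mapper, args)
    return create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=False)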