NerRoB-czech / data_manipulation /dataset_funcions.py
AlzbetaStrompova
Initial commit
7e6964a
raw
history blame
20.5 kB
import os
import re
import json
from tqdm import tqdm
from datasets import Dataset, DatasetDict
def load_gazetteers(path):
    """
    Load gazetteers from a JSON file.

    The file is always decoded as UTF-8 so that non-ASCII entries (e.g. Czech
    diacritics) do not depend on the platform's default encoding. Every value
    is converted to a set, which removes duplicates and gives O(1) membership.

    :param path: path to the gazetteer JSON file (an object whose values are
        lists of entries)
    :return: dict mapping each key of the JSON object to a set of its entries
    """
    with open(path, "r", encoding="utf-8") as f:
        gazetteers = json.load(f)
    return {key: set(entries) for key, entries in gazetteers.items()}
def create_dataset(label_mapper: dict, args):
    """
    Build the dataset selected by ``args.dataset``.

    :param label_mapper: CNEC label prefix -> integer class id (only used for CNEC)
    :param args: namespace; ``args.dataset == "cnec"`` selects the CNEC loader,
        any other value selects the WikiANN test-set loader
    :return: the DatasetDict produced by the selected loader
    """
    if args.dataset != "cnec":
        return load_wikiann_testing_dataset(args)
    return create_cnec_dataset(label_mapper, args)
####################################################################################################
### GAZETTEERS EMBEDDINGS ##########################################################################
####################################################################################################
def find_multi_token_matches(tokens, looking_tokens, gazetteers, matches, max_entity_length=5):
    """
    Record every gazetteer phrase of up to ``max_entity_length`` tokens in a sentence.

    For every start position, windows of ``looking_tokens`` from the longest
    allowed length down to 1 are joined with single spaces and looked up in each
    gazetteer. For each hit, the pair ``(phrase, gazetteer[phrase])`` is appended
    to ``matches[token]`` for every token of the window. Here ``token`` is taken
    from ``tokens`` at the same position.

    :param tokens: surface tokens; they are the keys of ``matches``
    :param looking_tokens: tokens used to build the phrases. They must have the
        same length as ``tokens`` and may differ from them (e.g. lemmas).
    :param gazetteers: iterable of dicts mapping a phrase to its match type
    :param matches: dict that is updated in place and returned
    :param max_entity_length: longest phrase length (in tokens) that is tried;
        defaults to 5, the previously hard-coded limit
    :return: ``matches``
    :raises ValueError: if ``tokens`` and ``looking_tokens`` differ in length
    """
    n = len(tokens)
    if n != len(looking_tokens):
        raise ValueError("tokens and looking_tokens must have the same length")
    for start in range(n):
        for length in range(min(max_entity_length, n - start), 0, -1):
            phrase = " ".join(looking_tokens[start:start + length])
            for gazetteer in gazetteers:
                if phrase in gazetteer:
                    entry = (phrase, gazetteer[phrase])
                    for token in tokens[start:start + length]:
                        matches.setdefault(token, []).append(entry)
    return matches
def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
    """
    Placeholder for single-token gazetteer matching.

    Not implemented yet: ``tokens``, ``looking_tokens`` and ``gazetteers`` are
    ignored and the very same ``matches`` object is handed back unchanged.
    """
    unchanged = matches
    return unchanged
def find_combination_single_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
    """
    Placeholder for combined single- and multi-token gazetteer matching.

    Not implemented yet: ``tokens``, ``looking_tokens`` and ``gazetteers`` are
    ignored and the very same ``matches`` object is handed back unchanged.
    """
    unchanged = matches
    return unchanged
def gazetteer_matching(words, gazetteers_for_matching, args=None, *, ending_ova=False):
    """
    Compute ``[per, org, loc]`` gazetteer flags for every word of a sentence.

    Multi-token matching is run over ``words`` via find_multi_token_matches. For
    each word, only the longest matched phrases containing it (the most spaces)
    are considered. A phrase whose type is "per", "org" or "loc" sets that flag
    to 1. Matches are keyed by the word text, so repeated words in a sentence
    share their matches.

    :param words: tokens of one sentence
    :param gazetteers_for_matching: iterable of dicts mapping a phrase to its
        entity type, e.g. the reverse dictionaries built in get_dataset_from_cnec
    :param args: ignored; accepted because some callers pass their argument
        namespace as a third positional argument
    :param ending_ova: if True, capitalized words ending in "ová" are also
        flagged as "per"
    :return: list with one ``[per, org, loc]`` list of 0/1 ints per word
    """
    matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
    # TODO: single-token matching and matching on lemmatized tokens, e.g.
    # find_multi_token_matches(words, lemmas, gazetteers_for_matching, matches).
    result = []
    for word in words:
        word_matches = matches.get(word, [])
        flags = {"per": 0, "org": 0, "loc": 0}
        if word_matches:
            longest = max(phrase.count(" ") for phrase, _ in word_matches)
            for phrase, match_type in word_matches:
                # Only the longest phrases decide the flags; shorter sub-matches are ignored.
                if phrase.count(" ") == longest and match_type in ("per", "org", "loc"):
                    flags[match_type] = 1
        if ending_ova and word.endswith("ová") and word[0].isupper():
            flags["per"] = 1
        result.append([flags["per"], flags["org"], flags["loc"]])
    return result
####################################################################################################
### GAZETTEERS EXPANSION TRAIN DATASET #############################################################
####################################################################################################
def _load_replacement_gazetteers(path):
    """
    Load the train gazetteer JSON as ``{tag id: [entries in file order]}``.

    load_gazetteers is not usable here for two reasons. It turns the entry lists
    into sets, which cannot be indexed. It also keeps the JSON keys as strings,
    while ``ner_tags`` are ints.
    """
    with open(path, "r", encoding="utf-8") as f:
        raw = json.load(f)
    # JSON object keys are always strings; numeric ones are converted so they can
    # be looked up with integer B-tag ids (assumes the file is keyed that way — TODO confirm).
    return {
        int(key) if key.isdecimal() else key: list(entries)
        for key, entries in raw.items()
    }


def _substitute_entities(tags, tokens, replacements, used):
    """
    Return copies of ``tags`` and ``tokens`` with every entity swapped for a gazetteer entry.

    An entity starts at an odd (B-) tag and extends over the directly following
    ``tag + 1`` (I-) tags. It is replaced by the next unused entry of
    ``replacements[tag]``, split on single spaces. ``used`` counts the consumed
    entries per tag and is updated in place.

    :raises ValueError: if ``replacements`` has no unused entry left for a tag
    """
    tags = list(tags)
    tokens = list(tokens)
    i = 0
    while i < len(tags):
        tag = tags[i]
        if tag % 2 != 1:
            i += 1
            continue
        end = i + 1
        while end < len(tags) and tags[end] == tag + 1:
            end += 1
        entries = replacements.get(tag)
        position = used.get(tag, 0)
        if entries is None or position >= len(entries):
            raise ValueError(f"train gazetteer has no unused entry left for tag {tag}")
        new_words = entries[position].split(" ")
        used[tag] = position + 1
        tags[i:end] = [tag] + [tag + 1] * (len(new_words) - 1)
        tokens[i:end] = new_words
        # Resume right after the inserted words. Resuming at the old span end
        # would skip tokens whenever the replacement is shorter than the entity.
        i += len(new_words)
    return tags, tokens


def expand_train_dataset_with_gazetteers(train, args):
    """
    Augment a training split by substituting train-gazetteer entries for gold entities.

    The original rows are copied first. Then the following is repeated
    ``args.gazetteers_counter`` times: every row with at least one non-zero tag
    is rewritten by _substitute_entities and appended under a fresh id, counting
    up from the last original id.

    :param train: list of row dicts with "id", "tokens" and "ner_tags" (and
        "gazetteers" when extended embeddings are used), e.g. the output of
        get_dataset_from_cnec
    :param args: namespace read for ``train_gazetteers_path``,
        ``gazetteers_counter``, ``apply_extended_embeddings`` and, when that
        flag is set, ``extended_embeddings_gazetteers_path``
    :return: the copied original rows followed by the generated rows
    :raises ValueError: if the train gazetteer runs out of entries for a tag
    """
    if not train:
        return []
    gazetteers_for_matching = None
    if args.apply_extended_embeddings:
        # Same preparation as get_dataset_from_cnec: one reverse dictionary per
        # gazetteer, which is what find_multi_token_matches indexes by phrase.
        from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
        raw_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
        gazetteers_for_matching = [
            build_reverse_dictionary({key: values}) for key, values in raw_matching.items()
        ]
    replacements = _load_replacement_gazetteers(args.train_gazetteers_path)
    used = dict.fromkeys(replacements, 0)

    dataset = []
    for row in train:
        copied = {"id": row["id"], "tokens": row["tokens"].copy(), "ner_tags": row["ner_tags"].copy()}
        if "gazetteers" in row:
            copied["gazetteers"] = row["gazetteers"].copy()
        dataset.append(copied)

    next_id = train[-1]["id"]
    for loop_index in range(args.gazetteers_counter):
        for row in tqdm(train, desc=f"loop {loop_index} from {args.gazetteers_counter}"):
            if all(tag == 0 for tag in row["ner_tags"]):
                continue  # nothing to substitute in a sentence without entities
            tags, tokens = _substitute_entities(row["ner_tags"], row["tokens"], replacements, used)
            next_id += 1
            new_row = {"id": next_id, "tokens": tokens, "ner_tags": tags}
            if gazetteers_for_matching is not None:
                new_row["gazetteers"] = gazetteer_matching(tokens, gazetteers_for_matching)
            dataset.append(new_row)
    return dataset
####################################################################################################
### CNEC DATASET ###################################################################################
####################################################################################################
def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
    """
    Parse one CNEC XML file into a list of token-classification rows.

    The file is read as plain text (not with an XML parser). 5 characters are
    cut from each end of it, and every remaining line is treated as one sentence
    with inline ``<ne type="LABEL">text</ne>`` annotations, which may be nested.

    :param label_mapper: maps label prefixes to integer class ids. A CNEC label
        gets the id of the first key (in dict order) that its lower-cased form
        starts with; labels matching no key are dropped.
    :param xml_file_path: path to the CNEC XML file (asserted to exist)
    :param args: namespace read for ``apply_extended_embeddings``,
        ``extended_embeddings_gazetteers_path``, ``division_to_BI_tags`` and
        ``contain_only_label_sentences``
    :return: list of dicts with keys "id", "tokens", "ner_tags", "sentence" and,
        when ``args.apply_extended_embeddings`` is set, "gazetteers".
        Tag 0 means "no entity". With ``division_to_BI_tags``, class id ``c``
        becomes ``2*c - 1`` on a word starting at the entity's start offset and
        ``2*c`` on the other entity words; otherwise every entity word gets ``c``.
    """
    # Open and read the XML file as plain text
    assert os.path.isfile(xml_file_path)
    id_ = 0
    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
        plain_text = xml_file.read()
    plain_text = plain_text[5:-5] # drop 5 chars at each end (presumably the enclosing document tags — TODO confirm)
    # Ensure whitespace separates entity markup from an adjacent letter or dot,
    # then collapse runs of spaces.
    plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
    plain_text = re.sub(r'[ ]+', ' ', plain_text)
    sentences = plain_text.split("\n")
    # Matches only innermost entities: the content may not contain '<'.
    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
    data = []
    if args.apply_extended_embeddings:
        gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
        from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
        # One reverse dictionary per gazetteer key; gazetteer_matching looks phrases up in each.
        temp = []
        for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp
    for sentence in tqdm(sentences):
        entity_mapping = []  # (entity text, CNEC label, start offset, end offset)
        while "<ne type=" in sentence: # while because there are nested entities
            nes = re.findall(ne_pattern, sentence)
            for label, entity in nes:
                pattern = f'<ne type="{label}">{entity}</ne>'
                index = sentence.index(pattern)
                temp_index = index
                # Unwrap the first occurrence (the one at `index`); the text before `index` is unchanged.
                sentence = sentence.replace(pattern, entity, 1)
                # Convert `index` into an offset without markup: subtract the opening
                # and closing tags still present before it...
                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
                # ...and the spaces in the markup-free prefix.
                # NOTE(review): the word loop below measures offsets *including* one
                # space per word (sentence_counter += len(word) + 1), while this line
                # removes spaces from the entity offset — verify the two offset
                # systems agree for entities that are not at the sentence start.
                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
                index = temp_index
                entity_mapping.append((entity, label, index, index + len(entity)))
        # Map CNEC labels to class ids by prefix; labels with no matching prefix are dropped.
        entities = []
        for entity, label, start, end in entity_mapping:
            for tag in label_mapper.keys():
                if label.lower().startswith(tag):
                    entities.append((label_mapper[tag], entity, start, end))
                    break
        # Longest entity text first: the word loop stops at the first fitting
        # entity, so for nested entities the longer (outer) one wins.
        entities.sort(key=lambda x: len(x[1]), reverse=True)
        words = re.split(r'\s+', sentence)
        tags_per_word = []
        sentence_counter = -1  # after each word: offset just past that word
        for word in words:
            sentence_counter += len(word) + 1
            if len(entities) == 0:
                tags_per_word.append(0) # no entities in this sentence: tag 0
            for index_entity in range(len(entities)):
                # The word belongs to an entity if its offsets lie within the entity's
                # offsets and its text is a substring of the entity text.
                if not(sentence_counter - len(word) >= entities[index_entity][2] and
                    sentence_counter <= entities[index_entity][3] and
                    word in entities[index_entity][1]):
                    if index_entity == len(entities) - 1:
                        tags_per_word.append(0) # no entity matched this word: tag 0
                    continue
                if args.division_to_BI_tags:
                    if sentence_counter - len(word) == entities[index_entity][2]:
                        tags_per_word.append(entities[index_entity][0] * 2 - 1) # beginning of entity
                    else:
                        tags_per_word.append(entities[index_entity][0] * 2) # inside of entity
                else:
                    tags_per_word.append(entities[index_entity][0])
                break
        if args.contain_only_label_sentences and tags_per_word.count(0) == len(tags_per_word):
            continue  # optionally keep only sentences that contain an entity
        if tags_per_word == [] or tags_per_word == [0]:
            continue  # skip empty / single untagged-token lines (ids stay consecutive)
        if args.apply_extended_embeddings:
            matching = gazetteer_matching(words, gazetteers_for_matching)
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                         "sentence": " ".join(words), "gazetteers": matching})
        else:
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
        id_ += 1
    return data
def create_dataset2(label_mapper:dict, gazetteers_path, path="/nlp/projekty/gazetteer_ner/cnec2.0/data/xml"):
    """
    Build a CNEC 2.0 DatasetDict with BI tags and gazetteer features.

    Runs get_dataset_from_cnec on each split with BI tags and extended
    (gazetteer) embeddings enabled and without filtering label-free sentences.
    This replaces a verbatim copy of that parser with those flags hard-coded.

    :param label_mapper: CNEC label prefix -> integer class id
    :param gazetteers_path: path to the gazetteer JSON used for matching features
    :param path: directory containing the CNEC XML files; defaults to the
        previously hard-coded location
    :return: DatasetDict with "train", "validation" and "test" splits
    """
    from types import SimpleNamespace

    options = SimpleNamespace(
        apply_extended_embeddings=True,
        extended_embeddings_gazetteers_path=gazetteers_path,
        division_to_BI_tags=True,
        contain_only_label_sentences=False,
    )
    dataset = DatasetDict()
    for part, file_name in zip(["train", "validation", "test"],
                               ["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
        file_path = os.path.join(path, file_name)
        dataset[part] = Dataset.from_list(get_dataset_from_cnec(label_mapper, file_path, options))
    return dataset
def create_cnec_dataset(label_mapper:dict, args):
    """
    Build the CNEC DatasetDict with "train", "validation" and "test" splits.

    The splits are read from named_ent_train.xml, named_ent_etest.xml and
    named_ent_dtest.xml in ``args.cnec_dataset_dir_path``. When
    ``args.expand_train_data`` is set, gazetteer-based augmentation is applied
    to the train split only, so no synthetic sentences end up in the
    evaluation splits.

    :param label_mapper: CNEC label prefix -> integer class id
    :param args: namespace forwarded to get_dataset_from_cnec and
        expand_train_dataset_with_gazetteers
    :return: DatasetDict with the three splits
    :raises FileNotFoundError: if the directory or one of the split files is missing
    """
    if not os.path.isdir(args.cnec_dataset_dir_path):
        raise FileNotFoundError(f"CNEC directory not found: {args.cnec_dataset_dir_path}")
    dataset = DatasetDict()
    for part, file_name in zip(["train", "validation", "test"],
                               ["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
        file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"CNEC file not found: {file_path}")
        rows = get_dataset_from_cnec(label_mapper, file_path, args)
        if args.expand_train_data and part == "train":
            rows = expand_train_dataset_with_gazetteers(rows, args)
        dataset[part] = Dataset.from_list(rows)
    return dataset
####################################################################################################
### WIKIANN DATASET ################################################################################
####################################################################################################
def load_wikiann_testing_dataset(args):
    """
    Load a WikiANN-style tagged file as a DatasetDict with a single "test" split.

    Rows have "id", "tokens" and "ner_tags". ner_tags are the raw tag strings
    read by load_tagged_sentences. When ``args.apply_gazetteers_info`` is set,
    rows also have "gazetteers" with ``[per, org, loc]`` flags per token.

    :param args: namespace read for ``wikiann_dataset_path``,
        ``apply_gazetteers_info`` and, when that flag is set,
        ``extended_embeddings_gazetteers_path``
    :return: ``DatasetDict({"test": ...})``
    :raises FileNotFoundError: if ``args.wikiann_dataset_path`` does not exist
    """
    gazetteers_for_matching = None
    if args.apply_gazetteers_info:
        # Same preparation as get_dataset_from_cnec: a list of reverse
        # dictionaries, which is what find_multi_token_matches indexes by phrase.
        from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
        raw = load_gazetteers(args.extended_embeddings_gazetteers_path)
        gazetteers_for_matching = [build_reverse_dictionary({key: values}) for key, values in raw.items()]
    if not os.path.isfile(args.wikiann_dataset_path):
        raise FileNotFoundError(f"WikiANN file not found: {args.wikiann_dataset_path}")
    rows = []
    for index, sentence in enumerate(load_tagged_sentences(args.wikiann_dataset_path)):
        words = [word for word, _ in sentence]
        tags = [tag for _, tag in sentence]
        row = {"id": index, "tokens": words, "ner_tags": tags}
        if gazetteers_for_matching is not None:
            row["gazetteers"] = gazetteer_matching(words, gazetteers_for_matching)
        rows.append(row)
    return DatasetDict({"test": Dataset.from_list(rows)})
def load_tagged_sentences(file_path):
    """
    Read a whitespace-separated "<prefix>:<token> <tag>" file into sentences.

    Blank lines separate sentences. Non-blank lines that do not split into
    exactly two fields are ignored. Only the part before the first ':' is
    stripped from the token, so tokens that contain ':' themselves (URLs, times)
    stay intact. A token without ':' is kept whole.

    :param file_path: path to the UTF-8 encoded file
    :return: list of sentences, each a list of ``(token, tag)`` tuples
    """
    sentences = []
    current_sentence = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                # Sentence boundary: flush the collected tokens, if any.
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []
                continue
            token_tag_pair = line.split()
            if len(token_tag_pair) == 2:
                raw_token, tag = token_tag_pair
                current_sentence.append((raw_token.split(":", 1)[-1], tag))
    # The file may not end with a blank line.
    if current_sentence:
        sentences.append(current_sentence)
    return sentences
####################################################################################################
### TOKENIZE DATASET ###############################################################################
####################################################################################################
def align_labels_with_tokens(labels, word_ids):
    """
    Expand word-level labels to sub-word tokens.

    Tokens with ``word_id is None`` (special tokens) get -100. The first
    sub-token of a word keeps the word's label. Further sub-tokens of the same
    word get the label turned from B-XXX (odd id) into I-XXX (id + 1).

    :param labels: one integer label per word
    :param word_ids: word index per token, ``None`` for special tokens
    :return: list with one label per token
    """
    aligned = []
    previous = None
    for word_id in word_ids:
        if word_id is None:
            aligned.append(-100)
        elif word_id == previous:
            label = labels[word_id]
            aligned.append(label + 1 if label % 2 == 1 else label)
        else:
            aligned.append(labels[word_id])
        previous = word_id
    return aligned
def align_gazetteers_with_tokens(gazetteers, word_ids):
    """
    Expand word-level gazetteer flags to sub-word tokens.

    Every token of a word gets that word's flag list (the same list object,
    not a copy). Tokens with ``word_id is None`` (special tokens) get a fresh
    ``[0, 0, 0]``.

    :param gazetteers: one ``[per, org, loc]`` list per word
    :param word_ids: word index per token, ``None`` for special tokens
    :return: list with one flag list per token
    """
    return [[0, 0, 0] if word_id is None else gazetteers[word_id] for word_id in word_ids]
def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
    """
    Tokenize a DatasetDict and align labels (and gazetteer flags) to sub-word tokens.

    Adds "labels" (see align_labels_with_tokens). When
    ``apply_extended_embeddings`` is set, it also adds "per", "org" and "loc"
    columns with per-token 0/1 flags (see align_gazetteers_with_tokens). The
    original columns are removed.

    :param raw_dataset: DatasetDict whose rows have "tokens", "ner_tags" and,
        when ``apply_extended_embeddings`` is set, "gazetteers"
    :param tokenizer: tokenizer called with ``is_split_into_words=True``; its
        output must provide ``word_ids(i)``
    :param apply_extended_embeddings: whether to emit the per/org/loc columns
    :return: the mapped DatasetDict
    """
    def tokenize_and_align_labels(examples):
        tokenized_inputs = tokenizer(
            examples["tokens"], truncation=True, is_split_into_words=True
        )
        tokenized_inputs["labels"] = [
            align_labels_with_tokens(labels, tokenized_inputs.word_ids(i))
            for i, labels in enumerate(examples["ner_tags"])
        ]
        if apply_extended_embeddings:
            aligned = [
                align_gazetteers_with_tokens(flags, tokenized_inputs.word_ids(i))
                for i, flags in enumerate(examples["gazetteers"])
            ]
            tokenized_inputs["per"] = [[token_flags[0] for token_flags in seq] for seq in aligned]
            tokenized_inputs["org"] = [[token_flags[1] for token_flags in seq] for seq in aligned]
            tokenized_inputs["loc"] = [[token_flags[2] for token_flags in seq] for seq in aligned]
        return tokenized_inputs

    # Column names come from "train" when present, otherwise from the first split
    # (e.g. the WikiANN DatasetDict only contains "test").
    reference_split = "train" if "train" in raw_dataset else next(iter(raw_dataset))
    return raw_dataset.map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=raw_dataset[reference_split].column_names,
    )