import os
import re

from tqdm import tqdm
from datasets import Dataset, DatasetDict

from data_manipulation.creation_gazetteers import build_reverse_dictionary, lemmatizing, load_json

####################################################################################################
### GAZETTEERS EMBEDDINGS ##########################################################################
####################################################################################################
def find_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
    i = 0
    n = len(tokens)
    assert n == len(looking_tokens)
    while i < n:
        for length in range(min(5, n - i), 0, -1):  # assuming a maximum entity length of 5 tokens
            phrase = ' '.join(looking_tokens[i:i + length])
            for gazetteer in gazetteers:
                if phrase in gazetteer:
                    match_type = gazetteer[phrase]
                    for index in range(i, i + length):
                        matches.setdefault(tokens[index], []).append((phrase, match_type))
        i += 1
    return matches
def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
    n = len(tokens)
    assert n == len(looking_tokens)
    for index in range(n):
        word = looking_tokens[index]
        if len(word) < 3:  # skip words shorter than three characters
            continue
        for gazetteer in gazetteers:
            if word in gazetteer:
                match_type = gazetteer[word]
                matches.setdefault(tokens[index], []).append((word, match_type))
    return matches
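# Illustrative note (not part of the pipeline itself): both helpers accumulate a dict mapping each
# original token to a list of (matched phrase, entity type) pairs, e.g. with a hypothetical gazetteer:
#   find_single_token_matches(["Jan", "Praha"], ["Jan", "Praha"], [{"Praha": "LOC"}], {})
#   # -> {"Praha": [("Praha", "LOC")]}
# Because the dict is keyed by the token text, repeated tokens share their matches.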
def gazetteer_matching(words, gazetteers_for_matching, args=None):
    """Return one [per, org, loc] score triple (0 or 5) per word based on gazetteer matches."""
    ending_ova = True
    method_for_gazetteers_matching = "single"
    apply_lemmatizing = True
    if method_for_gazetteers_matching == "single":
        matches = find_single_token_matches(words, words, gazetteers_for_matching, {})
        if apply_lemmatizing:
            lemmatize_tokens = [lemmatizing(t) for t in words]
            matches = find_single_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
    else:  # multi-token matching
        matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
        if apply_lemmatizing:
            lemmatize_tokens = [lemmatizing(t) for t in words]
            matches = find_multi_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
    result = []
    for word in words:
        mid_res = sorted(matches.get(word, []), key=lambda x: x[0].count(" "), reverse=True)
        per, org, loc = 0, 0, 0
        for res in mid_res:
            # only keep matches that span as many tokens as the longest match for this word
            if mid_res[0][0].count(" ") == res[0].count(" "):
                if res[1] == "PER":
                    per = 5
                elif res[1] == "ORG":
                    org = 5
                elif res[1] == "LOC":
                    loc = 5
        # heuristic for Czech female surnames: capitalised words ending in "ová" are scored as PER
        if ending_ova and word.endswith("ová") and word[0].isupper():
            per = 5
        result.append([per, org, loc])
    return result
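# Illustrative example (hypothetical gazetteer; output assumes lemmatization adds no further hits):
#   gazetteer_matching(["Paní", "Nováková", "navštívila", "Prahu"], [{"Prahu": "LOC"}])
#   # -> [[0, 0, 0], [5, 0, 0], [0, 0, 0], [0, 0, 5]]
# Each token gets a [per, org, loc] triple; 5 marks a gazetteer hit (or the "ová"-suffix heuristic).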
####################################################################################################
### CNEC DATASET ###################################################################################
####################################################################################################
def get_dataset_from_cnec(label_mapper: dict, xml_file_path, args):
    """
    label_mapper: maps lowercase CNEC label prefixes to integer class ids
    """
    # Open and read the XML file as plain text
    id_ = 0
    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
        plain_text = xml_file.read()
    plain_text = plain_text[5:-5]  # strip unnecessary characters at the start and end of the file
    plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
    plain_text = re.sub(r'[ ]+', ' ', plain_text)
    sentences = plain_text.split("\n")
    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
    data = []
    if args.apply_extended_embeddings:
        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
        temp = []
        for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp
    for sentence in tqdm(sentences):
        entity_mapping = []
        while "<ne type=" in sentence:  # while, because entities can be nested
            nes = re.findall(ne_pattern, sentence)
            for label, entity in nes:
                pattern = f'<ne type="{label}">{entity}</ne>'
                index = sentence.index(pattern)
                temp_index = index
                sentence = sentence.replace(pattern, entity, 1)
                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
                index = temp_index
                entity_mapping.append((entity, label, index, index + len(entity)))
        entities = []
        for entity, label, start, end in entity_mapping:
            for tag in label_mapper.keys():
                if label.lower().startswith(tag):
                    entities.append((label_mapper[tag], entity, start, end))
                    break
        entities.sort(key=lambda x: len(x[1]), reverse=True)
        words = re.split(r'\s+', sentence)
        tags_per_word = []
        sentence_counter = -1
        for word in words:
            sentence_counter += len(word) + 1
            if len(entities) == 0:
                tags_per_word.append(0)  # tag 0 = no entity label for this word
            for index_entity in range(len(entities)):
                if not (sentence_counter - len(word) >= entities[index_entity][2] and
                        sentence_counter <= entities[index_entity][3] and
                        word in entities[index_entity][1]):
                    if index_entity == len(entities) - 1:
                        tags_per_word.append(0)  # tag 0 = no entity label for this word
                    continue
                if args.division_to_BI_tags:
                    if sentence_counter - len(word) == entities[index_entity][2]:
                        tags_per_word.append(entities[index_entity][0] * 2 - 1)  # beginning of entity
                    else:
                        tags_per_word.append(entities[index_entity][0] * 2)  # inside of entity
                else:
                    tags_per_word.append(entities[index_entity][0])
                break
        if args.contain_only_label_sentences and tags_per_word.count(0) == len(tags_per_word):
            continue
        if tags_per_word == [] or tags_per_word == [0]:
            continue
        if args.apply_extended_embeddings:
            matching = gazetteer_matching(words, gazetteers_for_matching, args)
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                         "sentence": " ".join(words), "gazetteers": matching})
        else:
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
        id_ += 1
    return data
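# Illustrative shape of one record produced by get_dataset_from_cnec (values are made up):
#   {"id": 0, "tokens": ["Václav", "Havel", "přijel"], "ner_tags": [1, 2, 0],
#    "sentence": "Václav Havel přijel", "gazetteers": [[5, 0, 0], [5, 0, 0], [0, 0, 0]]}
# The "gazetteers" key appears only when args.apply_extended_embeddings is enabled, and odd/even
# tag ids encode B-/I- when args.division_to_BI_tags is set.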
def get_default_dataset_from_cnec(label_mapper: dict, xml_file_path):
    """
    label_mapper: maps lowercase CNEC label prefixes to integer class ids
    """
    # Open and read the XML file as plain text
    id_ = 0
    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
        plain_text = xml_file.read()
    plain_text = plain_text[5:-5]  # strip unnecessary characters at the start and end of the file
    plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
    plain_text = re.sub(r'[ ]+', ' ', plain_text)
    sentences = plain_text.split("\n")
    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
    data = []
    for sentence in tqdm(sentences):
        entity_mapping = []
        while "<ne type=" in sentence:  # while, because entities can be nested
            nes = re.findall(ne_pattern, sentence)
            for label, entity in nes:
                pattern = f'<ne type="{label}">{entity}</ne>'
                index = sentence.index(pattern)
                temp_index = index
                sentence = sentence.replace(pattern, entity, 1)
                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
                index = temp_index
                entity_mapping.append((entity, label, index, index + len(entity)))
        entities = []
        for entity, label, start, end in entity_mapping:
            for tag in label_mapper.keys():
                if label.lower().startswith(tag):
                    entities.append((label_mapper[tag], entity, start, end))
                    break
        entities.sort(key=lambda x: len(x[1]), reverse=True)
        words = re.split(r'\s+', sentence)
        tags_per_word = []
        sentence_counter = -1
        for word in words:
            sentence_counter += len(word) + 1
            if len(entities) == 0:
                tags_per_word.append(0)  # tag 0 = no entity label for this word
            for index_entity in range(len(entities)):
                if not (sentence_counter - len(word) >= entities[index_entity][2] and
                        sentence_counter <= entities[index_entity][3] and
                        word in entities[index_entity][1]):
                    if index_entity == len(entities) - 1:
                        tags_per_word.append(0)  # tag 0 = no entity label for this word
                    continue
                if sentence_counter - len(word) == entities[index_entity][2]:
                    tags_per_word.append(entities[index_entity][0] * 2 - 1)  # beginning of entity
                else:
                    tags_per_word.append(entities[index_entity][0] * 2)  # inside of entity
                break  # stop after the first matching entity
        if tags_per_word == [] or tags_per_word == [0]:
            continue
        data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
        id_ += 1
    return data
def create_cnec_dataset(label_mapper: dict, args):
    dataset = DatasetDict()
    for part, file_name in zip(["train", "validation", "test"],
                               ["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
        file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
        temp_dataset = get_dataset_from_cnec(label_mapper, file_path, args)
        dataset[part] = Dataset.from_list(temp_dataset)
    return dataset
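# Example usage (illustrative; the label_mapper below is hypothetical, it only has to map
# lowercase CNEC type prefixes to integer class ids):
#   label_mapper = {"p": 1, "g": 2, "i": 3}
#   cnec = create_cnec_dataset(label_mapper, args)
#   # -> DatasetDict with "train", "validation" and "test" splits built from the three XML files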
####################################################################################################
### WIKIANN DATASET ################################################################################
####################################################################################################
def load_wikiann_testing_dataset(args):
    if args.apply_extended_embeddings:
        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
        temp = []
        for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp
    dataset = []
    index = 0
    sentences = load_tagged_sentences(args.wikiann_dataset_path)
    for sentence in sentences:
        words = [word for word, _ in sentence]
        tags = [tag for _, tag in sentence]
        if args.apply_extended_embeddings:
            matching = gazetteer_matching(words, gazetteers_for_matching, args)
            dataset.append({"id": index, 'tokens': words, 'ner_tags': tags, "gazetteers": matching})
        else:
            dataset.append({"id": index, 'tokens': words, 'ner_tags': tags})
        index += 1
    test = Dataset.from_list(dataset)
    # dummy train/validation splits (one empty example each) so the DatasetDict has all three parts
    dataset = DatasetDict({"train": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
                           "validation": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
                           "test": test})
    # dataset = DatasetDict({"test": test})
    return dataset
def load_tagged_sentences(file_path):
    sentences = []  # List to hold all sentences
    current_sentence = []  # List to hold current sentence tokens and tags
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()  # Remove any extra whitespace from the line
            if line:
                # Split the line into token and tag
                token_tag_pair = line.split()
                if len(token_tag_pair) == 2:
                    # Add the token and tag tuple to the current sentence
                    current_sentence.append((token_tag_pair[0].split(':')[1], token_tag_pair[1]))
            else:
                # If the line is empty and the current sentence is not, add it to sentences
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []  # Reset for the next sentence
    # Add the last sentence if the file doesn't end with a blank line
    if current_sentence:
        sentences.append(current_sentence)
    return sentences
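# Expected input layout (illustrative, WikiAnn-style): one "<lang>:<token> <tag>" pair per line and
# a blank line between sentences, e.g.
#   cs:Praha  B-LOC
#   cs:je     O
#
#   cs:Brno   B-LOC
# The language prefix before ":" is stripped; the tags are kept exactly as they appear in the file.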
####################################################################################################
### TOKENIZE DATASET ###############################################################################
####################################################################################################
def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as previous token
            label = labels[word_id]
            # If the label is B-XXX we change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)
    return new_labels
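# Worked example (illustrative): with word-level labels [1, 2, 0] (odd = B-, even = I- in this
# tag scheme) and sub-token word_ids [None, 0, 0, 1, 2, None], the function returns
# [-100, 1, 2, 2, 0, -100]: special tokens get -100 and a repeated word id turns a B- tag into I-.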
def align_gazetteers_with_tokens(gazetteers, word_ids):
    aligned_gazetteers = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            gazetteer = [0, 0, 0] if word_id is None else gazetteers[word_id]
            aligned_gazetteers.append(gazetteer)
        elif word_id is None:
            # Special token
            aligned_gazetteers.append([0, 0, 0])
        else:
            # Same word as previous token
            gazetteer = gazetteers[word_id]
            aligned_gazetteers.append(gazetteer)
    return aligned_gazetteers
def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
    def tokenize_and_align_labels(examples):
        tokenized_inputs = tokenizer(
            examples["tokens"], truncation=True, is_split_into_words=True
        )
        all_labels = examples["ner_tags"]
        new_labels = []
        for i, labels in enumerate(all_labels):
            word_ids = tokenized_inputs.word_ids(i)
            new_labels.append(align_labels_with_tokens(labels, word_ids))
        tokenized_inputs["labels"] = new_labels
        if apply_extended_embeddings:
            matches = examples["gazetteers"]
            aligned_matches = []
            for i, match in enumerate(matches):
                word_ids = tokenized_inputs.word_ids(i)
                aligned_matches.append(align_gazetteers_with_tokens(match, word_ids))
            per, org, loc = [], [], []
            for i in aligned_matches:
                per.append([x[0] for x in i])
                org.append([x[1] for x in i])
                loc.append([x[2] for x in i])
            tokenized_inputs["per"] = per
            tokenized_inputs["org"] = org
            tokenized_inputs["loc"] = loc
        return tokenized_inputs

    dataset = raw_dataset.map(
        tokenize_and_align_labels,
        batched=True,
        # remove_columns=raw_dataset["train"].column_names
    )
    return dataset
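# Example usage (illustrative; any Hugging Face *fast* tokenizer works, since word_ids() is only
# available for fast tokenizers; the checkpoint below is just an example):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
#   tokenized = create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=False)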