import os
import re
import json

from tqdm import tqdm
from datasets import Dataset, DatasetDict


def load_gazetteers(path):
    """
    Load gazetteers from a JSON file.
    :param path: path to the gazetteer file
    :return: a dict mapping each entity type to a set of gazetteer phrases
    """
    with open(path, 'r') as f:
        gazetteers = json.load(f)
    for k, v in gazetteers.items():
        gazetteers[k] = set(v)
    return gazetteers
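
# Illustrative layout of the gazetteer file this loader expects (the contents below are made up,
# not taken from the project's real data):
#   {"per": ["Jan Novák"], "org": ["Masarykova univerzita"], "loc": ["Brno", "Praha"]}
# load_gazetteers would then return {"per": {"Jan Novák"}, "org": {"Masarykova univerzita"}, "loc": {"Brno", "Praha"}}.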


def create_dataset(label_mapper: dict, args):
    if args.dataset == "cnec":
        return create_cnec_dataset(label_mapper, args)
    return load_wikiann_testing_dataset(args)


####################################################################################################
### GAZETTEERS EMBEDDINGS ##########################################################################
####################################################################################################
def find_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
    """Scan `looking_tokens` for phrases of up to 5 tokens found in the gazetteers and record a
    (phrase, entity type) match for every token the phrase covers."""
    i = 0
    n = len(tokens)
    assert n == len(looking_tokens)
    while i < n:
        for length in range(min(5, n - i), 0, -1):  # assuming the maximum entity length is 5 tokens
            phrase = ' '.join(looking_tokens[i:i + length])
            for gazetteer in gazetteers:
                if phrase in gazetteer:
                    match_type = gazetteer[phrase]
                    for index in range(i, i + length):
                        matches.setdefault(tokens[index], []).append((phrase, match_type))
        i += 1
    return matches


def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
    # Not implemented: single-token matching is currently a no-op.
    return matches


def find_combination_single_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
    # Not implemented: combined single/multi-token matching is currently a no-op.
    return matches


def gazetteer_matching(words, gazetteers_for_matching):
    """Return one [per, org, loc] indicator vector per word, based on gazetteer matches."""
    # Hard-coded switches for the matching strategy.
    single_token_match = False
    ending_ova = False
    apply_lemmatizing = False
    if single_token_match:
        matches = {}
    else:  # multi_token_match
        matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
    # if apply_lemmatizing: TODO
    #     lemmatize_tokens = [lemmatizing(t) for t in words]
    #     matches = find_multi_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
    result = []
    for word in words:
        # Prefer the longest matched phrases (most spaces) for this word.
        mid_res = sorted(matches.get(word, []), key=lambda x: x[0].count(" "), reverse=True)
        per, org, loc = 0, 0, 0
        for res in mid_res:
            if mid_res[0][0].count(" ") == res[0].count(" "):
                if res[1] == "per":
                    per = 1
                elif res[1] == "org":
                    org = 1
                elif res[1] == "loc":
                    loc = 1
        if ending_ova and word.endswith("ová") and word[0].isupper():
            per = 1
        result.append([per, org, loc])
    return result
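
# Illustrative behaviour, assuming gazetteers_for_matching is a list of phrase -> type dictionaries
# (as produced via build_reverse_dictionary elsewhere in this module); the sentence is made up:
#   gaz = [{"Jan Novák": "per", "Praha": "loc"}]
#   gazetteer_matching(["Jan", "Novák", "visited", "Praha"], gaz)
#   -> [[1, 0, 0], [1, 0, 0], [0, 0, 0], [0, 0, 1]]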


####################################################################################################
### GAZETTEERS EXPANSION TRAIN DATASET #############################################################
####################################################################################################
def expand_train_dataset_with_gazetteers(train, args):
    """Augment the training data: every entity span is replaced with a not-yet-used gazetteer
    phrase, producing args.gazetteers_counter extra copies of each labelled sentence."""
    if args.apply_extended_embeddings:
        # Convert {type: phrases} into phrase -> type dictionaries, mirroring the CNEC path,
        # so that gazetteer_matching can look phrases up directly.
        from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
        gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
        gazetteers_for_matching = [build_reverse_dictionary({k: v})
                                   for k, v in gazetteers_for_matching.items()]
    # The replacement gazetteers are indexed by the integer B-tag and by position below, so load the
    # raw JSON lists here (load_gazetteers would return unordered sets keyed by strings); this
    # assumes the file maps tag ids to lists of phrases.
    with open(args.train_gazetteers_path, 'r') as f:
        gazetteers = {int(k): v for k, v in json.load(f).items()}
    count_gazetteers = {}
    id_ = train[-1]["id"]
    dataset = []
    for row in train:
        dataset.append({"id": row['id'], 'tokens': row['tokens'].copy(),
                        'ner_tags': row['ner_tags'].copy(), 'gazetteers': row['gazetteers'].copy()})
    for k in gazetteers.keys():
        count_gazetteers[k] = 0
    for index in range(args.gazetteers_counter):
        for row in tqdm(train, desc=f"loop {index} from {args.gazetteers_counter}"):
            i = 0
            temp_1 = row["ner_tags"].copy()
            temp_2 = row["tokens"].copy()
            if temp_1.count(0) == len(temp_1):
                continue  # nothing to replace in a sentence without entities
            while i < len(temp_1):
                tag = temp_1[i]
                if tag % 2 == 1:  # B- tag: replace the whole span with a fresh gazetteer phrase
                    tags = temp_1[:i]
                    tokens = temp_2[:i]
                    i += 1
                    assert len(gazetteers[tag]) > count_gazetteers[tag]
                    new = gazetteers[tag][count_gazetteers[tag]].split(" ")
                    count_gazetteers[tag] += 1
                    while i < len(temp_1):
                        if temp_1[i] != tag + 1:
                            break
                        i += 1
                    tags.append(tag)
                    tags.extend([tag + 1] * (len(new) - 1))
                    tags.extend(temp_1[i:])
                    tokens.extend(new)
                    tokens.extend(temp_2[i:])
                    temp_1 = tags
                    temp_2 = tokens
                else:
                    i += 1
            id_ += 1
            if args.apply_extended_embeddings:
                matching = gazetteer_matching(temp_2, gazetteers_for_matching)
                dataset.append({"id": id_, 'tokens': temp_2, 'ner_tags': temp_1, "gazetteers": matching})
            else:
                dataset.append({"id": id_, 'tokens': temp_2, 'ner_tags': temp_1})
    return dataset
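
# Sketch of the augmentation (sentence and gazetteer entry are made up): a training row with tokens
# ["Jan", "Novák", "přijel"] and B/I tags [1, 2, 0] has its entity span replaced by the next unused
# phrase from gazetteers[1], e.g. "Petr Svoboda", yielding tokens ["Petr", "Svoboda", "přijel"] with
# tags [1, 2, 0]; this is repeated args.gazetteers_counter times for every labelled sentence.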


####################################################################################################
### CNEC DATASET ###################################################################################
####################################################################################################
def get_dataset_from_cnec(label_mapper: dict, xml_file_path, args):
    """
    label_mapper: maps CNEC entity labels to integer tag ids
    """
    # Open and read the XML file as plain text
    assert os.path.isfile(xml_file_path)
    id_ = 0
    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
        plain_text = xml_file.read()
    plain_text = plain_text[5:-5]  # remove unnecessary characters
    plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
    plain_text = re.sub(r'[ ]+', ' ', plain_text)
    sentences = plain_text.split("\n")
    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
    data = []
    if args.apply_extended_embeddings:
        gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
        from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
        temp = []
        for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
        gazetteers_for_matching = temp
    for sentence in tqdm(sentences):
        entity_mapping = []
        while "<ne type=" in sentence:  # while, because entities can be nested
            nes = re.findall(ne_pattern, sentence)
            for label, entity in nes:
                pattern = f'<ne type="{label}">{entity}</ne>'
                index = sentence.index(pattern)
                temp_index = index
                sentence = sentence.replace(pattern, entity, 1)
                # Convert the character offset in the annotated text into an offset in the plain text.
                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
                index = temp_index
                entity_mapping.append((entity, label, index, index + len(entity)))
        entities = []
        for entity, label, start, end in entity_mapping:
            for tag in label_mapper.keys():
                if label.lower().startswith(tag):
                    entities.append((label_mapper[tag], entity, start, end))
                    break
        entities.sort(key=lambda x: len(x[1]), reverse=True)
        words = re.split(r'\s+', sentence)
        tags_per_word = []
        sentence_counter = -1
        for word in words:
            sentence_counter += len(word) + 1
            if len(entities) == 0:
                tags_per_word.append(0)  # tag representing no entity label for this word
            for index_entity in range(len(entities)):
                if not (sentence_counter - len(word) >= entities[index_entity][2] and
                        sentence_counter <= entities[index_entity][3] and
                        word in entities[index_entity][1]):
                    if index_entity == len(entities) - 1:
                        tags_per_word.append(0)  # tag representing no entity label for this word
                    continue
                if args.division_to_BI_tags:
                    if sentence_counter - len(word) == entities[index_entity][2]:
                        tags_per_word.append(entities[index_entity][0] * 2 - 1)  # beginning of entity
                    else:
                        tags_per_word.append(entities[index_entity][0] * 2)  # inside of entity
                else:
                    tags_per_word.append(entities[index_entity][0])
                break
        if args.contain_only_label_sentences and tags_per_word.count(0) == len(tags_per_word):
            continue
        if tags_per_word == [] or tags_per_word == [0]:
            continue
        if args.apply_extended_embeddings:
            matching = gazetteer_matching(words, gazetteers_for_matching)
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                         "sentence": " ".join(words), "gazetteers": matching})
        else:
            data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
        id_ += 1
    return data
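
# The parser above expects CNEC-style inline annotations, one sentence per line; an illustrative
# line (the exact type codes depend on the CNEC version and on label_mapper) could look like:
#   <ne type="P">Václav Havel</ne> navštívil <ne type="gu">Brno</ne> .
# Nested annotations are unwrapped one layer per pass of the inner while loop.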


def create_dataset2(label_mapper: dict, gazetteers_path):
    """Standalone variant of create_cnec_dataset with a hard-coded CNEC 2.0 path; gazetteer
    matching and BI tagging are always enabled and sentences are never filtered by label."""
    from types import SimpleNamespace
    args = SimpleNamespace(apply_extended_embeddings=True,
                           extended_embeddings_gazetteers_path=gazetteers_path,
                           division_to_BI_tags=True,
                           contain_only_label_sentences=False)
    path = "/nlp/projekty/gazetteer_ner/cnec2.0/data/xml"
    dataset = DatasetDict()
    for part, file_name in zip(["train", "validation", "test"],
                               ["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
        file_path = os.path.join(path, file_name)
        dataset[part] = Dataset.from_list(get_dataset_from_cnec(label_mapper, file_path, args))
    return dataset


def create_cnec_dataset(label_mapper: dict, args):
    assert os.path.isdir(args.cnec_dataset_dir_path)
    dataset = DatasetDict()
    for part, file_name in zip(["train", "validation", "test"],
                               ["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
        file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
        assert os.path.isfile(file_path)
        temp_dataset = get_dataset_from_cnec(label_mapper, file_path, args)
        if args.expand_train_data:
            # Note: when this flag is set, the gazetteer expansion is applied to every split, not only to train.
            temp_dataset = expand_train_dataset_with_gazetteers(temp_dataset, args)
        dataset[part] = Dataset.from_list(temp_dataset)
    return dataset


####################################################################################################
### WIKIANN DATASET ################################################################################
####################################################################################################
def load_wikiann_testing_dataset(args):
    if args.apply_gazetteers_info:
        # Convert {type: phrases} into phrase -> type dictionaries, mirroring the CNEC path,
        # so that gazetteer_matching can look phrases up directly.
        from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
        gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
        gazetteers_for_matching = [build_reverse_dictionary({k: v})
                                   for k, v in gazetteers_for_matching.items()]
    assert os.path.isfile(args.wikiann_dataset_path)
    dataset = []
    index = 0
    sentences = load_tagged_sentences(args.wikiann_dataset_path)
    for sentence in sentences:
        words = [word for word, _ in sentence]
        tags = [tag for _, tag in sentence]
        if args.apply_gazetteers_info:
            matching = gazetteer_matching(words, gazetteers_for_matching)
            dataset.append({"id": index, 'tokens': words, 'ner_tags': tags, "gazetteers": matching})
        else:
            dataset.append({"id": index, 'tokens': words, 'ner_tags': tags})
        index += 1
    test = Dataset.from_list(dataset)
    # dataset = DatasetDict({"train": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
    #                        "validation": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]), "test": test})
    dataset = DatasetDict({"test": test})
    return dataset


def load_tagged_sentences(file_path):
    sentences = []  # list to hold all sentences
    current_sentence = []  # list to hold the current sentence's (token, tag) pairs
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            line = line.strip()  # remove any extra whitespace from the line
            if line:
                # Split the line into token and tag
                token_tag_pair = line.split()
                if len(token_tag_pair) == 2:
                    # Strip the language prefix ("lang:token") and store the (token, tag) tuple
                    current_sentence.append((token_tag_pair[0].split(':')[1], token_tag_pair[1]))
            else:
                # A blank line ends the current sentence
                if current_sentence:
                    sentences.append(current_sentence)
                    current_sentence = []  # reset for the next sentence
    # Add the last sentence if the file doesn't end with a blank line
    if current_sentence:
        sentences.append(current_sentence)
    return sentences
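
# Expected input: WikiAnn-style lines of "lang:token TAG" with blank lines between sentences,
# e.g. (illustrative):
#   cs:Praha   B-LOC
#   cs:je      O
#   cs:krásná  O
# which load_tagged_sentences turns into [[("Praha", "B-LOC"), ("je", "O"), ("krásná", "O")]].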


####################################################################################################
### TOKENIZE DATASET ###############################################################################
####################################################################################################
def align_labels_with_tokens(labels, word_ids):
    new_labels = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            label = -100 if word_id is None else labels[word_id]
            new_labels.append(label)
        elif word_id is None:
            # Special token
            new_labels.append(-100)
        else:
            # Same word as the previous token
            label = labels[word_id]
            # If the label is B-XXX we change it to I-XXX
            if label % 2 == 1:
                label += 1
            new_labels.append(label)
    return new_labels
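
# Illustrative alignment (word_ids as returned by a fast Hugging Face tokenizer for one example):
#   align_labels_with_tokens([1, 0, 3], [None, 0, 0, 1, 2, None])
#   -> [-100, 1, 2, 0, 3, -100]   # the second sub-token of word 0 turns B (1) into I (2)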


def align_gazetteers_with_tokens(gazetteers, word_ids):
    new_g = []
    current_word = None
    for word_id in word_ids:
        if word_id != current_word:
            # Start of a new word!
            current_word = word_id
            gazetteer = [0, 0, 0] if word_id is None else gazetteers[word_id]
            new_g.append(gazetteer)
        elif word_id is None:
            # Special token
            new_g.append([0, 0, 0])
        else:
            # Same word as the previous token: repeat its gazetteer vector
            gazetteer = gazetteers[word_id]
            new_g.append(gazetteer)
    return new_g
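
# Illustrative alignment, mirroring the label case:
#   align_gazetteers_with_tokens([[1, 0, 0], [0, 0, 1]], [None, 0, 0, 1, None])
#   -> [[0, 0, 0], [1, 0, 0], [1, 0, 0], [0, 0, 1], [0, 0, 0]]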


def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
    def tokenize_and_align_labels(examples):
        tokenized_inputs = tokenizer(
            examples["tokens"], truncation=True, is_split_into_words=True
        )
        all_labels = examples["ner_tags"]
        new_labels = []
        for i, labels in enumerate(all_labels):
            word_ids = tokenized_inputs.word_ids(i)
            new_labels.append(align_labels_with_tokens(labels, word_ids))
        tokenized_inputs["labels"] = new_labels

        if apply_extended_embeddings:
            all_gazetteers = examples["gazetteers"]
            new_gazetteers = []
            for i, gazetteers in enumerate(all_gazetteers):
                word_ids = tokenized_inputs.word_ids(i)
                new_gazetteers.append(align_gazetteers_with_tokens(gazetteers, word_ids))
            # Split the per-token [per, org, loc] vectors into three parallel columns.
            p, o, l = [], [], []
            for gazetteers in new_gazetteers:
                p.append([x[0] for x in gazetteers])
                o.append([x[1] for x in gazetteers])
                l.append([x[2] for x in gazetteers])
            tokenized_inputs["per"] = p
            tokenized_inputs["org"] = o
            tokenized_inputs["loc"] = l
        return tokenized_inputs

    dataset = raw_dataset.map(
        tokenize_and_align_labels,
        batched=True,
        remove_columns=raw_dataset["train"].column_names,
    )
    return dataset
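
# Minimal usage sketch (the model name and flag values are assumptions, not fixed by this module):
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("ufal/robeczech-base")
#   raw = create_dataset(label_mapper, args)          # DatasetDict with "train"/"validation"/"test"
#   tokenized = create_tokenized_dataset(raw, tokenizer,
#                                        apply_extended_embeddings=args.apply_extended_embeddings)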