AlzbetaStrompova committed
Commit 75a65be · 1 Parent(s): 19e9ab7
minor changes
Files changed:
- app.py +15 -11
- data_manipulation/creation_gazetteers.py +115 -0
- data_manipulation/dataset_funcions.py +124 -212
- data_manipulation/preprocess_gazetteers.py +0 -54
- extended_embeddings/__init__.py +0 -0
- extended_embeddings/{token_classification.py → extended_embedding_token_classification.py} +13 -3
- extended_embeddings/extended_embeddings_data_collator.py +77 -0
- extended_embeddings/extended_embeddings_model.py +12 -39
- flagged/log.csv +0 -8
- requirements.txt +1 -0
- style.css +6 -5
- upload_model.ipynb +3 -3
- website_script.py +32 -4
app.py
CHANGED
@@ -1,32 +1,36 @@
-import json
 import gradio as gr
 from website_script import load, run

 tokenizer, model, gazetteers_for_matching = load()

 examples = [
-    ["Masarykova univerzita se nachází v
-    ["Barack Obama navštívil Prahu minulý týden
-    ["Angela Merkelová se setkala s francouzským prezidentem v Paříži
-    ["Nobelova cena za fyziku byla udělena týmu vědců z MIT
-]
+    ["Masarykova univerzita se nachází v Brně.", None],
+    ["Barack Obama navštívil Prahu minulý týden.", None],
+    ["Angela Merkelová se setkala s francouzským prezidentem v Paříži.", None],
+    ["Nobelova cena za fyziku byla udělena týmu vědců z MIT.", None],
+    ["Eiffelova věž je ikonickou památkou v Paříži.", None],
+    ["Bill Gates, spoluzakladatel společnosti Microsoft, oznámil nový grant pro výzkum umělé inteligence.", None],
+    ["Britská královna Alžběta II. navštívila Kanadu v rámci svého posledního zahraničního turné, během kterého zdůraznila důležitost spolupráce a přátelství mezi oběma národy.", None],
+    ["Francouzský prezident Emmanuel Macron oznámil nový plán na podporu start-upů a inovací ve Francii, který zahrnuje investice ve výši několika miliard eur.", None],
+    ["Světová zdravotnická organizace spustila nový program na boj proti malárii v subsaharské Africe, který zahrnuje rozdělování sítí proti komárům a očkování milionů lidí.", None]
+]


 def ner(text, file_names):
+    text = text.replace(".", " .")
     result = run(tokenizer, model, gazetteers_for_matching, text, file_names)
     return {"text": text, "entities": result}

 with gr.Blocks(css="./style.css", theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky")) as demo:
     gr.Interface(ner,
-        gr.Textbox(lines=
-        gr.HighlightedText(show_legend=True, color_map={"PER": "#f57d7d", "ORG": "#2cf562", "LOC": "#86aafc"}, elem_id="highlighted_text"),
+        gr.Textbox(lines=5, placeholder="Enter sentence here..."),
+        gr.HighlightedText(show_legend=True, color_map={"PER": "#f7a7a3", "ORG": "#77fc6a", "LOC": "#87CEFF"}),
         examples=examples,
         title="NerROB-czech",
+        description="This is an implementation of a Named Entity Recognition model for the Czech language using gazetteers.",
         allow_flagging="never",
         additional_inputs=gr.File(label="Upload a JSON file containing gazetteers", file_count="multiple", file_types=[".json"]),
     )

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
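For orientation, the dictionary that ner() returns is the value shape gr.HighlightedText renders. A minimal sketch of such a value, assuming run() yields character-offset spans with entity/start/end keys (the exact keys are produced inside run() and are not part of this diff; the offsets below are illustrative only):

# Hypothetical return value of ner("Masarykova univerzita se nachází v Brně .", None)
example_value = {
    "text": "Masarykova univerzita se nachází v Brně .",
    "entities": [
        {"entity": "ORG", "start": 0, "end": 21},   # "Masarykova univerzita"
        {"entity": "LOC", "start": 35, "end": 39},  # "Brně"
    ],
}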
data_manipulation/creation_gazetteers.py
ADDED
@@ -0,0 +1,115 @@
import os
import re
import json
import itertools

import pandas as pd
from simplemma import lemmatize
from names_dataset import NameDataset


def load_json(path):
    """
    Load gazetteers from a file
    :param path: path to the gazetteer file
    :return: a dict of gazetteers
    """
    with open(path, 'r') as file:
        data = json.load(file)
    return data


def save_json(data, path):
    """
    Save gazetteers to a file
    :param path: path to the gazetteer file
    :param gazetteers: a dict of gazetteers
    """
    with open(path, 'w') as file:
        json.dump(data, file, indent=4)

def merge_gazetteers(*gazetteers):
    """
    Merge multiple gazetteer dictionaries into a single gazetteer dictionary.

    Returns:
        dict: A merged gazetteer dictionary containing all the keys and values from the input gazetteers.
    """
    # Initialize a new dictionary to store merged results
    merged_gazetteers = {}
    # Iterate over each dictionary provided
    for gaz in gazetteers:
        # Iterate over each key and set in the current dictionary
        for key, value_set in gaz.items():
            if key in merged_gazetteers:
                # If the key already exists in the result, union the sets
                merged_gazetteers[key] |= value_set
            else:
                # Otherwise, initialize the key with the set from the current dictionary
                merged_gazetteers[key] = value_set.copy()  # Use copy to avoid mutating the original sets
    return merged_gazetteers


####################################################################################################
### PREPROCESSING OF GAZETTEERS ###################################################################
####################################################################################################

def remove_all_brackets(text):
    return re.sub(r'[\(\{\[].*?[\)\}\]]', '', text)


def lemmatizing(x):
    if x == "":
        return ""
    return lemmatize(x, lang="cs")


def multi_lemmatizing(x):
    words = x.split(" ")
    phrase = ""
    for word in words:
        phrase += lemmatizing(word) + " "
    return phrase.strip()


def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary


def split_gazetteers_for_single_token_match(gazetteers):
    result = {}
    for k, v in gazetteers.items():
        result[k] = set([x for xs in [vv.split(" ") for vv in v] for x in xs])
        result[k] = {x for x in result[k] if len(x) > 2}
    return result


def preprocess_gazetteers(gazetteers, config):
    if config["remove_brackets"]:
        for k, values in gazetteers.items():
            gazetteers[k] = {remove_all_brackets(vv).strip() for vv in values if len(remove_all_brackets(vv).strip()) > 2}
    if config["split_person"]:
        gazetteers["per"].update(set([x for x in list(itertools.chain(*[v.split(" ") for v in gazetteers["per"]])) if len(x) > 2]))
    if config["techniq_for_matching"] == "single":
        gazetteers = split_gazetteers_for_single_token_match(gazetteers)
        if config["lemmatize"]:
            for k, values in gazetteers.items():
                gazetteers[k] = set(list(itertools.chain(*[(vv, lemmatizing(vv)) for vv in values if len(vv) > 2])))
    elif config["lemmatize"]:
        for k, values in gazetteers.items():
            gazetteers[k] = set(list(itertools.chain(*[(value, multi_lemmatizing(value)) for value in values if len(value) > 2])))

    if config["remove_numeric"]:
        for k, values in gazetteers.items():
            gazetteers[k] = {vv for vv in values if not vv.isnumeric()}
    for k, values in gazetteers.items():
        gazetteers[k] = list(values)
    return gazetteers
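As a quick illustration of how the new helpers compose (not part of the commit; toy gazetteer data, lower-case keys as used by preprocess_gazetteers):

# Toy gazetteers; real ones are loaded from JSON via load_json().
gaz_a = {"per": {"Karel Čapek"}, "loc": {"Brno"}}
gaz_b = {"per": {"Milan Kundera"}, "org": {"Masarykova univerzita"}}

merged = merge_gazetteers(gaz_a, gaz_b)
# {'per': {'Karel Čapek', 'Milan Kundera'}, 'loc': {'Brno'}, 'org': {'Masarykova univerzita'}}

reverse = build_reverse_dictionary(merged)
# maps each surface form back to its label, e.g. reverse["Brno"] == "loc"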
data_manipulation/dataset_funcions.py
CHANGED
@@ -1,27 +1,10 @@
 import os
 import re
-import json
 from tqdm import tqdm

 from datasets import Dataset, DatasetDict

-    """
-    Load gazetteers from a file
-    :param path: path to the gazetteer file
-    :return: a dict of gazetteers
-    """
-    with open(path, 'r') as f:
-        gazetteers = json.load(f)
-    for k, v in gazetteers.items():
-        gazetteers[k] = set(v)
-    return gazetteers
-
-def create_dataset(label_mapper:dict, args):
-    if args.dataset == "cnec":
-        return create_cnec_dataset(label_mapper, args)
-    return load_wikiann_testing_dataset(args)
-
+from data_manipulation.creation_gazetteers import build_reverse_dictionary, lemmatizing, load_json

 ####################################################################################################
 ### GAZETTEERS EMBEDDINGS ##########################################################################
@@ -43,26 +26,36 @@ def find_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
         i += 1
     return matches

-def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
-    return matches
+
+def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
+    n = len(tokens)
+    assert n == len(looking_tokens)
+    for index in range(n):
+        word = looking_tokens[index]
+        if len(word) < 3:
+            continue
+        for gazetteer in gazetteers:
+            if word in gazetteer:
+                match_type = gazetteer[word]
+                matches.setdefault(tokens[index], []).append((word, match_type))
     return matches

-def gazetteer_matching(words, gazetteers_for_matching):
-    single_token_match = False
-    ending_ova = False
-    apply_lemmatizing = False
+
+def gazetteer_matching(words, gazetteers_for_matching, args=None):
+    ending_ova = True
+    method_for_gazetteers_matching = "single"
+    apply_lemmatizing = True

+    if method_for_gazetteers_matching == "single":
+        matches = find_single_token_matches(words, words, gazetteers_for_matching, {})
+        if apply_lemmatizing:
+            lemmatize_tokens = [lemmatizing(t) for t in words]
+            matches = find_single_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
     else: # multi_token_match
         matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
+        if apply_lemmatizing:
+            lemmatize_tokens = [lemmatizing(t) for t in words]
+            matches = find_multi_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)

     result = []
     for word in words:
@@ -70,72 +63,18 @@ def gazetteer_matching(words, gazetteers_for_matching):
         per, org, loc = 0, 0, 0
         for res in mid_res:
             if mid_res[0][0].count(" ") == res[0].count(" "):
-                if res[1] == "
-                    per =
-                elif res[1] == "
-                    org =
-                elif res[1] == "
-                    loc =
+                if res[1] == "PER":
+                    per = 5
+                elif res[1] == "ORG":
+                    org = 5
+                elif res[1] == "LOC":
+                    loc = 5
         if ending_ova and word.endswith("ová") and word[0].isupper():
-            per =
+            per = 5
         result.append([per, org, loc])
     return result


-####################################################################################################
-### GAZETTEERS EXPANSION TRAIN DATASET #############################################################
-####################################################################################################
-
-def expand_train_dataset_with_gazetteers(train, args):
-    if args.apply_extended_embeddings:
-        gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
-    gazetteers = load_gazetteers(args.train_gazetteers_path)
-    count_gazetteers = {}
-    id_ = train[-1]["id"]
-    dataset = []
-    for row in train:
-        dataset.append({"id": row['id'], 'tokens': row['tokens'].copy(),
-                        'ner_tags': row['ner_tags'].copy(), 'gazetteers': row['gazetteers'].copy()})
-    for k in gazetteers.keys():
-        count_gazetteers[k] = 0
-    for index in range(args.gazetteers_counter):
-        for row in tqdm(train, desc=f"loop {index} from {args.gazetteers_counter}"):
-            i = 0
-            temp_1 = row["ner_tags"].copy()
-            temp_2 = row["tokens"].copy()
-            if temp_1.count(0) == len(temp_1):
-                continue
-            while i < len(temp_1):
-                tag = temp_1[i]
-                if tag % 2 == 1:
-                    tags = temp_1[:i]
-                    tokens = temp_2[:i]
-                    i += 1
-                    assert len(gazetteers[tag]) > count_gazetteers[tag]
-                    new = gazetteers[tag][count_gazetteers[tag]].split(" ")
-                    count_gazetteers[tag] += 1
-                    while i < len(temp_1):
-                        if temp_1[i] != tag + 1:
-                            break
-                        i += 1
-                    tags.append(tag)
-                    tags.extend([tag + 1] * (len(new) - 1))
-                    tags.extend(temp_1[i:])
-
-                    tokens.extend(new)
-                    tokens.extend(temp_2[i:])
-                    temp_1 = tags
-                    temp_2 = tokens
-                else:
-                    i += 1
-            id_ += 1
-            if args.apply_extended_embeddings:
-                matching = gazetteer_matching(temp_2, gazetteers_for_matching, args)
-                dataset.append({"id": id_, 'tokens': temp_2, 'ner_tags': temp_1, "gazetteers": matching})
-            dataset.append({"id": id_, 'tokens': temp_2, 'ner_tags': temp_1})
-    return dataset
-
-
 ####################################################################################################
 ### CNEC DATASET ###################################################################################
 ####################################################################################################
@@ -144,7 +83,6 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
     label_mapper: cnec labels to int
     """
     # Open and read the XML file as plain text
-    assert os.path.isfile(xml_file_path)
     id_ = 0
     with open(xml_file_path, "r", encoding="utf-8") as xml_file:
         plain_text = xml_file.read()
@@ -156,14 +94,13 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
     ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
     data = []
     if args.apply_extended_embeddings:
-        gazetteers_for_matching =
-        from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
+        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
         temp = []
         for i in gazetteers_for_matching.keys():
            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
         gazetteers_for_matching = temp

-    for sentence in tqdm(sentences):
+    for sentence in tqdm(sentences):
         entity_mapping = []
         while "<ne type=" in sentence: # while because there are nested entities
             nes = re.findall(ne_pattern, sentence)
@@ -215,7 +152,7 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
         if tags_per_word == [] or tags_per_word == [0]:
             continue
         if args.apply_extended_embeddings:
-            matching = gazetteer_matching(words, gazetteers_for_matching)
+            matching = gazetteer_matching(words, gazetteers_for_matching, args)
             data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
                          "sentence": " ".join(words), "gazetteers": matching})
         else:
@@ -223,104 +160,78 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
         id_ += 1
     return data

-        for i in gazetteers_for_matching.keys():
-            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
-        gazetteers_for_matching = temp
-
-    for sentence in tqdm(sentences):
-        entity_mapping = []
-        while "<ne type=" in sentence: # while because there are nested entities
-            nes = re.findall(ne_pattern, sentence)
-            for label, entity in nes:
-                pattern = f'<ne type="{label}">{entity}</ne>'
-                index = sentence.index(pattern)
-                temp_index = index
-                sentence = sentence.replace(pattern, entity, 1)
-                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
-                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
-                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
-                index = temp_index
-                entity_mapping.append((entity, label, index, index + len(entity)))
-
-        entities = []
-        for entity, label, start, end in entity_mapping:
-            for tag in label_mapper.keys():
-                if label.lower().startswith(tag):
-                    entities.append((label_mapper[tag], entity, start, end))
-                    break
-        entities.sort(key=lambda x: len(x[1]), reverse=True)
-
-        words = re.split(r'\s+', sentence)
-        tags_per_word = []
-        sentence_counter = -1
-        for word in words:
-            sentence_counter += len(word) + 1
-            if len(entities) == 0:
-                tags_per_word.append(0) # tag representing no label for no word
-            for index_entity in range(len(entities)):
-                if not(sentence_counter - len(word) >= entities[index_entity][2] and
-                       sentence_counter <= entities[index_entity][3] and
-                       word in entities[index_entity][1]):
-                    if index_entity == len(entities) - 1:
-                        tags_per_word.append(0) # tag representing no label for word
-                    continue
-
-                if True:
-                    if sentence_counter - len(word) == entities[index_entity][2]:
-                        tags_per_word.append(entities[index_entity][0] * 2 - 1) # beggining of entity
-                    else:
-                        tags_per_word.append(entities[index_entity][0] * 2) # inside of entity
-                else:
-                    tags_per_word.append(entities[index_entity][0])
-                break
+def get_default_dataset_from_cnec(label_mapper:dict, xml_file_path):
+    """
+    label_mapper: cnec labels to int
+    """
+    # Open and read the XML file as plain text
+    id_ = 0
+    with open(xml_file_path, "r", encoding="utf-8") as xml_file:
+        plain_text = xml_file.read()
+    plain_text = plain_text[5:-5] # remove unnessery characters
+    plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
+    plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
+    plain_text = re.sub(r'[ ]+', ' ', plain_text)
+    sentences = plain_text.split("\n")
+    ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
+    data = []
+
+    for sentence in tqdm(sentences):
+        entity_mapping = []
+        while "<ne type=" in sentence: # while because there are nested entities
+            nes = re.findall(ne_pattern, sentence)
+            for label, entity in nes:
+                pattern = f'<ne type="{label}">{entity}</ne>'
+                index = sentence.index(pattern)
+                temp_index = index
+                sentence = sentence.replace(pattern, entity, 1)
+                temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
+                temp_index -= sentence[:index].count("</ne>") * len("</ne>")
+                temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
+                index = temp_index
+                entity_mapping.append((entity, label, index, index + len(entity)))
+
+        entities = []
+        for entity, label, start, end in entity_mapping:
+            for tag in label_mapper.keys():
+                if label.lower().startswith(tag):
+                    entities.append((label_mapper[tag], entity, start, end))
+                    break
+        entities.sort(key=lambda x: len(x[1]), reverse=True)
+
+        words = re.split(r'\s+', sentence)
+        tags_per_word = []
+        sentence_counter = -1
+        for word in words:
+            sentence_counter += len(word) + 1
+            if len(entities) == 0:
+                tags_per_word.append(0) # tag representing no label for no word
+            for index_entity in range(len(entities)):
+                if not(sentence_counter - len(word) >= entities[index_entity][2] and
+                       sentence_counter <= entities[index_entity][3] and
+                       word in entities[index_entity][1]):
+                    if index_entity == len(entities) - 1:
+                        tags_per_word.append(0) # tag representing no label for word
+                    continue

+                if sentence_counter - len(word) == entities[index_entity][2]:
+                    tags_per_word.append(entities[index_entity][0] * 2 - 1) # beggining of entity
+                else:
+                    tags_per_word.append(entities[index_entity][0] * 2) # inside of entity

+        if tags_per_word == [] or tags_per_word == [0]:
+            continue
+
+        data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
+        id_ += 1
+    return data


 def create_cnec_dataset(label_mapper:dict, args):
-
-    assert os.path.isdir(args.cnec_dataset_dir_path)
     dataset = DatasetDict()
     for part, file_name in zip(["train", "validation", "test"],["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
         file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
-        assert os.path.isfile(file_path)
         temp_dataset = get_dataset_from_cnec(label_mapper, file_path, args)
-        if args.expand_train_data:
-            temp_dataset = expand_train_dataset_with_gazetteers(temp_dataset, args)
         dataset[part] = Dataset.from_list(temp_dataset)
     return dataset

@@ -328,16 +239,19 @@ def create_cnec_dataset(label_mapper:dict, args):
 ### WIKIANN DATASET ################################################################################
 ####################################################################################################
 def load_wikiann_testing_dataset(args):
-    if args.
-        gazetteers_for_matching =
+    if args.apply_extended_embeddings:
+        gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
+        temp = []
+        for i in gazetteers_for_matching.keys():
+            temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
+        gazetteers_for_matching = temp
     dataset = []
     index = 0
     sentences = load_tagged_sentences(args.wikiann_dataset_path)
     for sentence in sentences:
         words = [word for word, _ in sentence]
         tags = [tag for _, tag in sentence]
-        if args.
+        if args.apply_extended_embeddings:
             matching = gazetteer_matching(words, gazetteers_for_matching, args)
             dataset.append({"id": index, 'tokens': words, 'ner_tags': tags, "gazetteers": matching})
         else:
@@ -345,9 +259,10 @@ def load_wikiann_testing_dataset(args):
         index += 1

     test = Dataset.from_list(dataset)
+    dataset = DatasetDict({"train": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
+                           "validation": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
+                           "test": test})
+    # dataset = DatasetDict({"test": test})
     return dataset


@@ -400,26 +315,24 @@ def align_labels_with_tokens(labels, word_ids):
         new_labels.append(label)
     return new_labels

+
 def align_gazetteers_with_tokens(gazetteers, word_ids):
+    aligned_gazetteers = []
     current_word = None
     for word_id in word_ids:
         if word_id != current_word:
             # Start of a new word!
             current_word = word_id
             gazetteer = [0,0,0] if word_id is None else gazetteers[word_id]
+            aligned_gazetteers.append(gazetteer)
         elif word_id is None:
             # Special token
+            aligned_gazetteers.append([0,0,0])
         else:
             # Same word as previous token
             gazetteer = gazetteers[word_id]
-            # gazetteer += 1
-            new_g.append(gazetteer)
-    return new_g
+            aligned_gazetteers.append(gazetteer)
+    return aligned_gazetteers


 def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
@@ -434,25 +347,24 @@ def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
         new_labels.append(align_labels_with_tokens(labels, word_ids))
         tokenized_inputs["labels"] = new_labels
         if apply_extended_embeddings:
+            matches = examples["gazetteers"]
+            aligned_matches = []
+            for i, match in enumerate(matches):
                 word_ids = tokenized_inputs.word_ids(i)
+                aligned_matches.append(align_gazetteers_with_tokens(match, word_ids))
+            per, org, loc = [], [], []
+            for i in aligned_matches:
+                per.append([x[0] for x in i])
+                org.append([x[1] for x in i])
+                loc.append([x[2] for x in i])
+            tokenized_inputs["per"] = per
+            tokenized_inputs["org"] = org
+            tokenized_inputs["loc"] = loc
         return tokenized_inputs

     dataset = raw_dataset.map(
         tokenize_and_align_labels,
         batched=True,
-        remove_columns=raw_dataset["train"].column_names
+        # remove_columns=raw_dataset["train"].column_names
     )
     return dataset
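A short sketch of what the rewritten matcher produces (not part of the commit; the JSON path is hypothetical, and the gazetteer keys are assumed to be the upper-case labels that gazetteer_matching compares against):

# Build one reverse dictionary per label, the way get_dataset_from_cnec and
# load_wikiann_testing_dataset now do, then score a tokenized sentence.
gazetteers = load_json("gazetteers/example_gazetteers.json")  # hypothetical file
gazetteers_for_matching = [build_reverse_dictionary({k: gazetteers[k]}) for k in gazetteers]

words = ["Angela", "Merkelová", "navštívila", "Prahu", "."]
rows = gazetteer_matching(words, gazetteers_for_matching)
# rows holds one [per, org, loc] triple per word, with 5 marking a match;
# "Merkelová" is scored per = 5 by the ending-"ová" heuristic even without a
# gazetteer entry.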
data_manipulation/preprocess_gazetteers.py
DELETED
@@ -1,54 +0,0 @@
import re

from simplemma import lemmatize


def flatten(xss):
    return [x for xs in xss for x in xs]


def remove_all_brackets(text):
    return re.sub(r'[\(\{\[].*?[\)\}\]]', '', text)


def lemmatizing(x):
    if x == "":
        return ""
    return lemmatize(x, lang="cs")


def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
    reverse_dictionary = {}
    for key, values in dictionary.items():
        for value in values:
            reverse_dictionary[value] = key
            if apply_lemmatizing:
                temp = lemmatizing(value)
                if temp != value:
                    reverse_dictionary[temp] = key
    return reverse_dictionary


def split_gazetteers_for_single_token_match(gazetteers):
    result = {}
    for k, v in gazetteers.items():
        result[k] = set(flatten([vv.split(" ") for vv in v]))
        result[k] = {x for x in result[k] if len(x) > 2}
    return result


def preprocess_gazetteers(gazetteers, config):
    if config["split_person"]:
        gazetteers["PER"].update(set([x for x in flatten([v.split(" ") for v in gazetteers["PER"]]) if len(x) > 2]))
    if config["lemmatize"]:
        for k, v in gazetteers.items():
            gazetteers[k] = set(flatten([(vv, lemmatizing(vv)) for vv in v if len(vv) > 2]))
    if config["remove_brackets"]:
        for k, v in gazetteers.items():
            gazetteers[k] = {remove_all_brackets(vv).strip() for vv in v if len(remove_all_brackets(vv).strip()) > 2}
    if config["remove_numeric"]:
        for k, v in gazetteers.items():
            gazetteers[k] = {vv for vv in v if not vv.isnumeric()}
    if config["techniq_for_matching"] != "single":
        gazetteers = split_gazetteers_for_single_token_match(gazetteers)
    return gazetteers
extended_embeddings/__init__.py
DELETED
File without changes
extended_embeddings/{token_classification.py → extended_embedding_token_classification.py}
RENAMED
@@ -1,4 +1,4 @@
-from typing import
+from typing import Optional, Tuple, Union

 import torch
 from torch import nn
@@ -12,11 +12,20 @@ _CONFIG_FOR_DOC = "RobertaConfig"


 class ExtendedEmbeddigsRobertaForTokenClassification(RobertaForTokenClassification):
+    """
+    A RobertaForTokenClassification for token classification tasks with extended embeddings.
+
+    This RobertaForTokenClassification extends the functionality of the `RobertaForTokenClassification` class
+    by adding support for additional features such as `per`, `org`, and `loc`.
+
+    Part of the code copied from: transformers.models.bert.modeling_roberta.RobertaForTokenClassification
+
+    """
     def __init__(self, config):
         super().__init__(config)
         self.num_labels = config.num_labels

-        self.roberta = ExtendedEmbeddigsRobertaModel(config
+        self.roberta = ExtendedEmbeddigsRobertaModel(config)
         classifier_dropout = (
             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
         )
@@ -92,4 +101,5 @@ class ExtendedEmbeddigsRobertaForTokenClassification(RobertaForTokenClassificati
             logits=logits,
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
-        )
+        )
+
extended_embeddings/extended_embeddings_data_collator.py
ADDED
@@ -0,0 +1,77 @@
import torch
from transformers import DataCollatorForTokenClassification
from transformers.data.data_collator import pad_without_fast_tokenizer_warning


class ExtendedEmbeddingsDataCollatorForTokenClassification(DataCollatorForTokenClassification):
    """
    A data collator for token classification tasks with extended embeddings.

    This data collator extends the functionality of the `DataCollatorForTokenClassification` class
    by adding support for additional features such as `per`, `org`, and `loc`.

    Part of the code copied from: transformers.data.data_collator.DataCollatorForTokenClassification
    """

    def torch_call(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        per = [feature["per"] for feature in features] if "per" in features[0].keys() else None
        org = [feature["org"] for feature in features] if "org" in features[0].keys() else None
        loc = [feature["loc"] for feature in features] if "loc" in features[0].keys() else None

        no_labels_features = [{k: v for k, v in feature.items() if k not in [label_name, "per", "org", "loc"]} for feature in features]

        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
            batch["per"] = [
                to_list(p) + [0] * (sequence_length - len(p)) for p in per
            ]
            batch["org"] = [
                to_list(o) + [0] * (sequence_length - len(o)) for o in org
            ]
            batch["loc"] = [
                to_list(l) + [0] * (sequence_length - len(l)) for l in loc
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]
            batch["per"] = [
                [0] * (sequence_length - len(p)) + self.to_list(p) for p in per
            ]
            batch["org"] = [
                [0] * (sequence_length - len(o)) + self.to_list(o) for o in org
            ]
            batch["loc"] = [
                [0] * (sequence_length - len(l)) + self.to_list(l) for l in loc
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        batch["per"] = torch.tensor(batch["per"], dtype=torch.int64)
        batch["org"] = torch.tensor(batch["org"], dtype=torch.int64)
        batch["loc"] = torch.tensor(batch["loc"], dtype=torch.int64)
        return batch
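A minimal sketch of what the new collator adds over the stock DataCollatorForTokenClassification: the per/org/loc feature lists are padded to the batch sequence length alongside the labels and returned as int64 tensors. The checkpoint name below is only a stand-in for the example; the Space uses its own fine-tuned model.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")  # stand-in tokenizer
collator = ExtendedEmbeddingsDataCollatorForTokenClassification(tokenizer=tokenizer)

features = [
    {"input_ids": [0, 1542, 2], "labels": [-100, 1, -100],
     "per": [0, 5, 0], "org": [0, 0, 0], "loc": [0, 0, 0]},
    {"input_ids": [0, 877, 98, 2], "labels": [-100, 3, 4, -100],
     "per": [0, 0, 0, 0], "org": [0, 5, 5, 0], "loc": [0, 0, 0, 0]},
]
batch = collator(features)
# batch["labels"], batch["per"], batch["org"] and batch["loc"] all come back as
# (2, 4) tensors, padded with -100 and 0 respectively.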
extended_embeddings/extended_embeddings_model.py
CHANGED
@@ -1,53 +1,27 @@
-from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaEncoder, RobertaEmbeddings
-from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
 from typing import List, Optional, Tuple, Union
+
 import torch
+from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
+from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaEncoder, RobertaEmbeddings

-# Copied from transformers.models.bert.modeling_bert.BertPooler
-class ExtendedEmbeddigsRobertaPooler(nn.Module):
-    def __init__(self, config):
-        super().__init__()
-        size_of_gazetters_part = int((len(config.id2label.keys()) - 1) // 2)
-        self.dense = nn.Linear(config.hidden_size + size_of_gazetters_part, config.hidden_size + size_of_gazetters_part)
-        self.activation = nn.Tanh()
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        # We "pool" the model by simply taking the hidden state corresponding
-        # to the first token.
-        first_token_tensor = hidden_states[:, 0]
-        pooled_output = self.dense(first_token_tensor)
-        pooled_output = self.activation(pooled_output)
-        return pooled_output

 class ExtendedEmbeddigsRobertaModel(RobertaModel):
     """
+    A RobertaModel for token classification tasks with extended embeddings.

-    all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
-    Kaiser and Illia Polosukhin.
+    This RobertaModel extends the functionality of the `RobertaModel` class
+    by adding support for additional features such as `per`, `org`, and `loc`.

-    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
-    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
-
-    .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
+    Part of the code copied from: transformers.models.bert.modeling_roberta.RobertaModel

     """
-
-    # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
-    def __init__(self, config, add_pooling_layer=True):
+    def __init__(self, config):
         super().__init__(config)
         self.config = config

         self.embeddings = RobertaEmbeddings(config)
         self.encoder = RobertaEncoder(config)
-
-        self.pooler = ExtendedEmbeddigsRobertaPooler(config)
-
+        self.pooler = None
         # Initialize weights and apply final processing
         self.post_init()

@@ -57,10 +31,9 @@ class ExtendedEmbeddigsRobertaModel(RobertaModel):
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
-        loc: Optional[torch.Tensor] = None, # change
+        per: Optional[torch.Tensor] = None,
+        org: Optional[torch.Tensor] = None,
+        loc: Optional[torch.Tensor] = None,
         head_mask: Optional[torch.Tensor] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
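Assuming the token-classification wrapper forwards the same keyword arguments to this backbone (which its use of ExtendedEmbeddigsRobertaModel implies, though its forward signature is not shown in this diff), batches from the collator sketch above would reach the model roughly as follows; the checkpoint name and label count are placeholders:

model = ExtendedEmbeddigsRobertaForTokenClassification.from_pretrained(
    "xlm-roberta-base", num_labels=7  # placeholder checkpoint and label count
)
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    per=batch["per"],
    org=batch["org"],
    loc=batch["loc"],
    labels=batch["labels"],
)
loss, logits = outputs.loss, outputs.logits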
flagged/log.csv
DELETED
@@ -1,8 +0,0 @@
text,output,flag,username,timestamp
Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 02:29:01.157209
Barack Obama navštívil Prahu minulý týden .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Barack Obama"", ""class_or_confidence"": ""OSV""}, {""token"": "" nav\u0161t\u00edvil "", ""class_or_confidence"": null}, {""token"": ""Prahu"", ""class_or_confidence"": ""LOC""}, {""token"": "" minul\u00fd t\u00fdden ."", ""class_or_confidence"": null}]",,,2024-05-06 02:31:57.950478
Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 02:51:30.197653
Barack Obama navštívil Prahu minulý týden .,,,,2024-05-06 10:58:33.085992
Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 11:00:17.762652
Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 11:00:20.057269
,,,,,2024-05-09 22:59:12.114264
requirements.txt
CHANGED
@@ -5,3 +5,4 @@ torch
 simplemma
 gradio
 pandas
+name-datasets
style.css
CHANGED
@@ -6,10 +6,6 @@ footer {
     color-scheme: light dark;
 }

-.container .svelte-ju12zg {
-    color: light-dark(black, white);
-}
-
 .text.svelte-ju12zg {
     padding: 0;
     margin: 0;
@@ -23,4 +19,9 @@
 .textspan.svelte-ju12zg.no-cat {
     margin: 0;
     padding: 0;
-}
+}
+
+.category-label.svelte-ju12zg {
+    background-color: light-dark(white, black,);
+
+}
upload_model.ipynb
CHANGED
@@ -2,13 +2,13 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 1,
    "metadata": {},
    "outputs": [
     {
      "data": {
       "application/vnd.jupyter.widget-view+json": {
-       "model_id": "
+       "model_id": "556291d727474e0a82723d6459722b16",
        "version_major": 2,
        "version_minor": 0
       },
@@ -28,7 +28,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count":
+   "execution_count": 2,
    "metadata": {},
    "outputs": [
     {
website_script.py
CHANGED
@@ -2,11 +2,39 @@ import json
 import copy

 import torch
+from simplemma import lemmatize
 from transformers import AutoTokenizer

-from extended_embeddings.
-from data_manipulation.dataset_funcions import
+from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
+from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens
+
+# code originaly from data_manipulation.creation_gazetteers
+def lemmatizing(x):
+    if x == "":
+        return ""
+    return lemmatize(x, lang="cs")
+
+# code originaly from data_manipulation.creation_gazetteers
+def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
+    reverse_dictionary = {}
+    for key, values in dictionary.items():
+        for value in values:
+            reverse_dictionary[value] = key
+            if apply_lemmatizing:
+                temp = lemmatizing(value)
+                if temp != value:
+                    reverse_dictionary[temp] = key
+    return reverse_dictionary
+
+def load_json(path):
+    """
+    Load gazetteers from a file
+    :param path: path to the gazetteer file
+    :return: a dict of gazetteers
+    """
+    with open(path, 'r') as file:
+        data = json.load(file)
+    return data


 def load():
@@ -18,7 +46,7 @@ def load():
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model.eval()

-    gazetteers_for_matching =
+    gazetteers_for_matching = load_json(gazetteers_path)
     temp = []
     for i in gazetteers_for_matching.keys():
         temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
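Taken together with app.py, the load/run pair can be exercised directly; a minimal sketch (None mirrors the examples in app.py, which pass no extra gazetteer files):

tokenizer, model, gazetteers_for_matching = load()
text = "Masarykova univerzita se nachází v Brně ."
entities = run(tokenizer, model, gazetteers_for_matching, text, None)
# `entities` is the span list that gr.HighlightedText renders in app.py.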