AlzbetaStrompova commited on
Commit
75a65be
1 Parent(s): 19e9ab7

minor changes

Browse files
app.py CHANGED
@@ -1,32 +1,36 @@
1
- import json
2
  import gradio as gr
3
  from website_script import load, run
4
 
5
  tokenizer, model, gazetteers_for_matching = load()
6
 
7
  examples = [
8
- ["Masarykova univerzita se nachází v Brně .", None],
9
- ["Barack Obama navštívil Prahu minulý týden .", None],
10
- ["Angela Merkelová se setkala s francouzským prezidentem v Paříži .", None],
11
- ["Nobelova cena za fyziku byla udělena týmu vědců z MIT .", None]
12
- ]
 
 
 
 
 
13
 
14
 
15
  def ner(text, file_names):
 
16
  result = run(tokenizer, model, gazetteers_for_matching, text, file_names)
17
  return {"text": text, "entities": result}
18
 
19
  with gr.Blocks(css="./style.css", theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky")) as demo:
20
  gr.Interface(ner,
21
- gr.Textbox(lines=10, placeholder="Enter sentence here..."),
22
- # gr.HighlightedText(show_legend=True, color_map={"PER": "red", "ORG": "green", "LOC": "blue"}),
23
- gr.HighlightedText(show_legend=True, color_map={"PER": "#f57d7d", "ORG": "#2cf562", "LOC": "#86aafc"}, elem_id="highlighted_text"),
24
  examples=examples,
25
  title="NerROB-czech",
26
- description="This is an implementation of a Named Entity Recognition model for the Czech language using gazetteers.",
27
  allow_flagging="never",
28
  additional_inputs=gr.File(label="Upload a JSON file containing gazetteers", file_count="multiple", file_types=[".json"]),
29
  )
30
 
31
  if __name__ == "__main__":
32
- demo.launch()
 
 
1
  import gradio as gr
2
  from website_script import load, run
3
 
4
  tokenizer, model, gazetteers_for_matching = load()
5
 
6
  examples = [
7
+ ["Masarykova univerzita se nachází v Brně.", None],
8
+ ["Barack Obama navštívil Prahu minulý týden.", None],
9
+ ["Angela Merkelová se setkala s francouzským prezidentem v Paříži.", None],
10
+ ["Nobelova cena za fyziku byla udělena týmu vědců z MIT.", None],
11
+ ["Eiffelova věž je ikonickou památkou v Paříži.", None],
12
+ ["Bill Gates, spoluzakladatel společnosti Microsoft, oznámil nový grant pro výzkum umělé inteligence.", None],
13
+ ["Britská královna Alžběta II. navštívila Kanadu v rámci svého posledního zahraničního turné, během kterého zdůraznila důležitost spolupráce a přátelství mezi oběma národy.", None],
14
+ ["Francouzský prezident Emmanuel Macron oznámil nový plán na podporu start-upů a inovací ve Francii, který zahrnuje investice ve výši několika miliard eur.", None],
15
+ ["Světová zdravotnická organizace spustila nový program na boj proti malárii v subsaharské Africe, který zahrnuje rozdělování sítí proti komárům a očkování milionů lidí.", None]
16
+ ]
17
 
18
 
19
  def ner(text, file_names):
20
+ text = text.replace(".", " .")
21
  result = run(tokenizer, model, gazetteers_for_matching, text, file_names)
22
  return {"text": text, "entities": result}
23
 
24
  with gr.Blocks(css="./style.css", theme=gr.themes.Default(primary_hue="blue", secondary_hue="sky")) as demo:
25
  gr.Interface(ner,
26
+ gr.Textbox(lines=5, placeholder="Enter sentence here..."),
27
+ gr.HighlightedText(show_legend=True, color_map={"PER": "#f7a7a3", "ORG": "#77fc6a", "LOC": "#87CEFF"}),
 
28
  examples=examples,
29
  title="NerROB-czech",
30
+ description="This is an implementation of a Named Entity Recognition model for the Czech language using gazetteers.",
31
  allow_flagging="never",
32
  additional_inputs=gr.File(label="Upload a JSON file containing gazetteers", file_count="multiple", file_types=[".json"]),
33
  )
34
 
35
  if __name__ == "__main__":
36
+ demo.launch()
data_manipulation/creation_gazetteers.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import itertools
5
+
6
+ import pandas as pd
7
+ from simplemma import lemmatize
8
+ from names_dataset import NameDataset
9
+
10
+
11
+ def load_json(path):
12
+ """
13
+ Load gazetteers from a file
14
+ :param path: path to the gazetteer file
15
+ :return: a dict of gazetteers
16
+ """
17
+ with open(path, 'r') as file:
18
+ data = json.load(file)
19
+ return data
20
+
21
+
22
+ def save_json(data, path):
23
+ """
24
+ Save gazetteers to a file
25
+ :param path: path to the gazetteer file
26
+ :param gazetteers: a dict of gazetteers
27
+ """
28
+ with open(path, 'w') as file:
29
+ json.dump(data, file, indent=4)
30
+
31
+ def merge_gazetteers(*gazetteers):
32
+ """
33
+ Merge multiple gazetteer dictionaries into a single gazetteer dictionary.
34
+
35
+ Returns:
36
+ dict: A merged gazetteer dictionary containing all the keys and values from the input gazetteers.
37
+ """
38
+ # Initialize a new dictionary to store merged results
39
+ merged_gazetteers = {}
40
+ # Iterate over each dictionary provided
41
+ for gaz in gazetteers:
42
+ # Iterate over each key and set in the current dictionary
43
+ for key, value_set in gaz.items():
44
+ if key in merged_gazetteers:
45
+ # If the key already exists in the result, union the sets
46
+ merged_gazetteers[key] |= value_set
47
+ else:
48
+ # Otherwise, initialize the key with the set from the current dictionary
49
+ merged_gazetteers[key] = value_set.copy() # Use copy to avoid mutating the original sets
50
+ return merged_gazetteers
51
+
52
+
53
+ ####################################################################################################
54
+ ### PREPROCESSING OF GAZETTEERS ###################################################################
55
+ ####################################################################################################
56
+
57
+ def remove_all_brackets(text):
58
+ return re.sub(r'[\(\{\[].*?[\)\}\]]', '', text)
59
+
60
+
61
+ def lemmatizing(x):
62
+ if x == "":
63
+ return ""
64
+ return lemmatize(x, lang="cs")
65
+
66
+
67
+ def multi_lemmatizing(x):
68
+ words = x.split(" ")
69
+ phrase = ""
70
+ for word in words:
71
+ phrase += lemmatizing(word) + " "
72
+ return phrase.strip()
73
+
74
+
75
+ def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
76
+ reverse_dictionary = {}
77
+ for key, values in dictionary.items():
78
+ for value in values:
79
+ reverse_dictionary[value] = key
80
+ if apply_lemmatizing:
81
+ temp = lemmatizing(value)
82
+ if temp != value:
83
+ reverse_dictionary[temp] = key
84
+ return reverse_dictionary
85
+
86
+
87
+ def split_gazetteers_for_single_token_match(gazetteers):
88
+ result = {}
89
+ for k, v in gazetteers.items():
90
+ result[k] = set([x for xs in [vv.split(" ") for vv in v] for x in xs])
91
+ result[k] = {x for x in result[k] if len(x) > 2}
92
+ return result
93
+
94
+
95
+ def preprocess_gazetteers(gazetteers, config):
96
+ if config["remove_brackets"]:
97
+ for k, values in gazetteers.items():
98
+ gazetteers[k] = {remove_all_brackets(vv).strip() for vv in values if len(remove_all_brackets(vv).strip()) > 2}
99
+ if config["split_person"]:
100
+ gazetteers["per"].update(set([x for x in list(itertools.chain(*[v.split(" ") for v in gazetteers["per"]])) if len(x) > 2]))
101
+ if config["techniq_for_matching"] == "single":
102
+ gazetteers = split_gazetteers_for_single_token_match(gazetteers)
103
+ if config["lemmatize"]:
104
+ for k, values in gazetteers.items():
105
+ gazetteers[k] = set(list(itertools.chain(*[(vv, lemmatizing(vv)) for vv in values if len(vv) > 2])))
106
+ elif config["lemmatize"]:
107
+ for k, values in gazetteers.items():
108
+ gazetteers[k] = set(list(itertools.chain(*[(value, multi_lemmatizing(value)) for value in values if len(value) > 2])))
109
+
110
+ if config["remove_numeric"]:
111
+ for k, values in gazetteers.items():
112
+ gazetteers[k] = {vv for vv in values if not vv.isnumeric()}
113
+ for k, values in gazetteers.items():
114
+ gazetteers[k] = list(values)
115
+ return gazetteers
data_manipulation/dataset_funcions.py CHANGED
@@ -1,27 +1,10 @@
1
  import os
2
  import re
3
- import json
4
  from tqdm import tqdm
5
 
6
  from datasets import Dataset, DatasetDict
7
 
8
- def load_gazetteers(path):
9
- """
10
- Load gazetteers from a file
11
- :param path: path to the gazetteer file
12
- :return: a dict of gazetteers
13
- """
14
- with open(path, 'r') as f:
15
- gazetteers = json.load(f)
16
- for k, v in gazetteers.items():
17
- gazetteers[k] = set(v)
18
- return gazetteers
19
-
20
- def create_dataset(label_mapper:dict, args):
21
- if args.dataset == "cnec":
22
- return create_cnec_dataset(label_mapper, args)
23
- return load_wikiann_testing_dataset(args)
24
-
25
 
26
  ####################################################################################################
27
  ### GAZETTEERS EMBEDDINGS ##########################################################################
@@ -43,26 +26,36 @@ def find_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
43
  i += 1
44
  return matches
45
 
46
- def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
47
- return matches
48
 
49
- def find_combination_single_multi_token_matches(tokens, looking_tokens, gazetteers, matches):
 
 
 
 
 
 
 
 
 
 
50
  return matches
51
 
52
- def gazetteer_matching(words, gazetteers_for_matching):
53
- single_token_match = False
54
- ending_ova = False
55
- apply_lemmatizing = False
56
-
57
 
58
- if single_token_match:
59
- matches = {}
 
 
60
 
 
 
 
 
 
61
  else: # multi_token_match
62
  matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
63
- # if apply_lemmatizing: TODO
64
- # lemmatize_tokens = [lemmatizing(t) for t in words]
65
- # matches = find_multi_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
66
 
67
  result = []
68
  for word in words:
@@ -70,72 +63,18 @@ def gazetteer_matching(words, gazetteers_for_matching):
70
  per, org, loc = 0, 0, 0
71
  for res in mid_res:
72
  if mid_res[0][0].count(" ") == res[0].count(" "):
73
- if res[1] == "per":
74
- per = 1
75
- elif res[1] == "org":
76
- org = 1
77
- elif res[1] == "loc":
78
- loc = 1
79
  if ending_ova and word.endswith("ová") and word[0].isupper():
80
- per = 1
81
  result.append([per, org, loc])
82
  return result
83
 
84
 
85
- ####################################################################################################
86
- ### GAZETTEERS EXPANSION TRAIN DATASET #############################################################
87
- ####################################################################################################
88
-
89
- def expand_train_dataset_with_gazetteers(train, args):
90
- if args.apply_extended_embeddings:
91
- gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
92
- gazetteers = load_gazetteers(args.train_gazetteers_path)
93
- count_gazetteers = {}
94
- id_ = train[-1]["id"]
95
- dataset = []
96
- for row in train:
97
- dataset.append({"id": row['id'], 'tokens': row['tokens'].copy(),
98
- 'ner_tags': row['ner_tags'].copy(), 'gazetteers': row['gazetteers'].copy()})
99
- for k in gazetteers.keys():
100
- count_gazetteers[k] = 0
101
- for index in range(args.gazetteers_counter):
102
- for row in tqdm(train, desc=f"loop {index} from {args.gazetteers_counter}"):
103
- i = 0
104
- temp_1 = row["ner_tags"].copy()
105
- temp_2 = row["tokens"].copy()
106
- if temp_1.count(0) == len(temp_1):
107
- continue
108
- while i < len(temp_1):
109
- tag = temp_1[i]
110
- if tag % 2 == 1:
111
- tags = temp_1[:i]
112
- tokens = temp_2[:i]
113
- i += 1
114
- assert len(gazetteers[tag]) > count_gazetteers[tag]
115
- new = gazetteers[tag][count_gazetteers[tag]].split(" ")
116
- count_gazetteers[tag] += 1
117
- while i < len(temp_1):
118
- if temp_1[i] != tag + 1:
119
- break
120
- i += 1
121
- tags.append(tag)
122
- tags.extend([tag + 1] * (len(new) - 1))
123
- tags.extend(temp_1[i:])
124
-
125
- tokens.extend(new)
126
- tokens.extend(temp_2[i:])
127
- temp_1 = tags
128
- temp_2 = tokens
129
- else:
130
- i += 1
131
- id_ += 1
132
- if args.apply_extended_embeddings:
133
- matching = gazetteer_matching(temp_2, gazetteers_for_matching, args)
134
- dataset.append({"id": id_, 'tokens': temp_2, 'ner_tags': temp_1, "gazetteers": matching})
135
- dataset.append({"id": id_, 'tokens': temp_2, 'ner_tags': temp_1})
136
- return dataset
137
-
138
-
139
  ####################################################################################################
140
  ### CNEC DATASET ###################################################################################
141
  ####################################################################################################
@@ -144,7 +83,6 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
144
  label_mapper: cnec labels to int
145
  """
146
  # Open and read the XML file as plain text
147
- assert os.path.isfile(xml_file_path)
148
  id_ = 0
149
  with open(xml_file_path, "r", encoding="utf-8") as xml_file:
150
  plain_text = xml_file.read()
@@ -156,14 +94,13 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
156
  ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
157
  data = []
158
  if args.apply_extended_embeddings:
159
- gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
160
- from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
161
  temp = []
162
  for i in gazetteers_for_matching.keys():
163
  temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
164
  gazetteers_for_matching = temp
165
 
166
- for sentence in tqdm(sentences):
167
  entity_mapping = []
168
  while "<ne type=" in sentence: # while because there are nested entities
169
  nes = re.findall(ne_pattern, sentence)
@@ -215,7 +152,7 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
215
  if tags_per_word == [] or tags_per_word == [0]:
216
  continue
217
  if args.apply_extended_embeddings:
218
- matching = gazetteer_matching(words, gazetteers_for_matching)
219
  data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
220
  "sentence": " ".join(words), "gazetteers": matching})
221
  else:
@@ -223,104 +160,78 @@ def get_dataset_from_cnec(label_mapper:dict, xml_file_path, args):
223
  id_ += 1
224
  return data
225
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- def create_dataset2(label_mapper:dict, gazetteers_path):
228
- path = "/nlp/projekty/gazetteer_ner/cnec2.0/data/xml"
229
- dataset = DatasetDict()
230
- for part, file_name in zip(["train", "validation", "test"],["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
231
- file_path = os.path.join(path, file_name)
232
- ##
233
- id_ = 0
234
- with open(file_path, "r", encoding="utf-8") as xml_file:
235
- plain_text = xml_file.read()
236
- plain_text = plain_text[5:-5] # remove unnessery characters
237
- plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
238
- plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
239
- plain_text = re.sub(r'[ ]+', ' ', plain_text)
240
- sentences = plain_text.split("\n")
241
- ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
242
- data = []
243
- if True:
244
- gazetteers_for_matching = load_gazetteers(gazetteers_path)
245
- from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
246
- temp = []
247
- for i in gazetteers_for_matching.keys():
248
- temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
249
- gazetteers_for_matching = temp
250
-
251
- for sentence in tqdm(sentences):
252
- entity_mapping = []
253
- while "<ne type=" in sentence: # while because there are nested entities
254
- nes = re.findall(ne_pattern, sentence)
255
- for label, entity in nes:
256
- pattern = f'<ne type="{label}">{entity}</ne>'
257
- index = sentence.index(pattern)
258
- temp_index = index
259
- sentence = sentence.replace(pattern, entity, 1)
260
- temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
261
- temp_index -= sentence[:index].count("</ne>") * len("</ne>")
262
- temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
263
- index = temp_index
264
- entity_mapping.append((entity, label, index, index + len(entity)))
265
-
266
- entities = []
267
- for entity, label, start, end in entity_mapping:
268
- for tag in label_mapper.keys():
269
- if label.lower().startswith(tag):
270
- entities.append((label_mapper[tag], entity, start, end))
271
- break
272
- entities.sort(key=lambda x: len(x[1]), reverse=True)
273
-
274
- words = re.split(r'\s+', sentence)
275
- tags_per_word = []
276
- sentence_counter = -1
277
- for word in words:
278
- sentence_counter += len(word) + 1
279
- if len(entities) == 0:
280
- tags_per_word.append(0) # tag representing no label for no word
281
- for index_entity in range(len(entities)):
282
- if not(sentence_counter - len(word) >= entities[index_entity][2] and
283
- sentence_counter <= entities[index_entity][3] and
284
- word in entities[index_entity][1]):
285
- if index_entity == len(entities) - 1:
286
- tags_per_word.append(0) # tag representing no label for word
287
- continue
288
-
289
- if True:
290
- if sentence_counter - len(word) == entities[index_entity][2]:
291
- tags_per_word.append(entities[index_entity][0] * 2 - 1) # beggining of entity
292
- else:
293
- tags_per_word.append(entities[index_entity][0] * 2) # inside of entity
294
- else:
295
- tags_per_word.append(entities[index_entity][0])
296
  break
 
297
 
298
- if tags_per_word == [] or tags_per_word == [0]:
299
- continue
300
- if True:
301
- matching = gazetteer_matching(words, gazetteers_for_matching)
302
- data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
303
- "sentence": " ".join(words), "gazetteers": matching})
304
- else:
305
- data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
306
- id_ += 1
 
 
 
 
 
307
 
 
 
 
 
308
 
309
- ##
310
- dataset[part] = Dataset.from_list(data)
311
- return dataset
 
 
 
312
 
313
 
314
  def create_cnec_dataset(label_mapper:dict, args):
315
-
316
- assert os.path.isdir(args.cnec_dataset_dir_path)
317
  dataset = DatasetDict()
318
  for part, file_name in zip(["train", "validation", "test"],["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
319
  file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
320
- assert os.path.isfile(file_path)
321
  temp_dataset = get_dataset_from_cnec(label_mapper, file_path, args)
322
- if args.expand_train_data:
323
- temp_dataset = expand_train_dataset_with_gazetteers(temp_dataset, args)
324
  dataset[part] = Dataset.from_list(temp_dataset)
325
  return dataset
326
 
@@ -328,16 +239,19 @@ def create_cnec_dataset(label_mapper:dict, args):
328
  ### WIKIANN DATASET ################################################################################
329
  ####################################################################################################
330
  def load_wikiann_testing_dataset(args):
331
- if args.apply_gazetteers_info:
332
- gazetteers_for_matching = load_gazetteers(args.extended_embeddings_gazetteers_path)
333
- assert os.path.isfile(args.wikiann_dataset_path)
 
 
 
334
  dataset = []
335
  index = 0
336
  sentences = load_tagged_sentences(args.wikiann_dataset_path)
337
  for sentence in sentences:
338
  words = [word for word, _ in sentence]
339
  tags = [tag for _, tag in sentence]
340
- if args.apply_gazetteers_info:
341
  matching = gazetteer_matching(words, gazetteers_for_matching, args)
342
  dataset.append({"id": index, 'tokens': words, 'ner_tags': tags, "gazetteers": matching})
343
  else:
@@ -345,9 +259,10 @@ def load_wikiann_testing_dataset(args):
345
  index += 1
346
 
347
  test = Dataset.from_list(dataset)
348
- # dataset = DatasetDict({"train": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
349
- # "validation": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]), "test": test})
350
- dataset = DatasetDict({"test": test})
 
351
  return dataset
352
 
353
 
@@ -400,26 +315,24 @@ def align_labels_with_tokens(labels, word_ids):
400
  new_labels.append(label)
401
  return new_labels
402
 
 
403
  def align_gazetteers_with_tokens(gazetteers, word_ids):
404
- new_g = []
405
  current_word = None
406
  for word_id in word_ids:
407
  if word_id != current_word:
408
  # Start of a new word!
409
  current_word = word_id
410
  gazetteer = [0,0,0] if word_id is None else gazetteers[word_id]
411
- new_g.append(gazetteer)
412
  elif word_id is None:
413
  # Special token
414
- new_g.append([0,0,0])
415
  else:
416
  # Same word as previous token
417
  gazetteer = gazetteers[word_id]
418
- # # If the label is B-XXX we change it to I-XXX
419
- # if gazetteer % 2 == 1:
420
- # gazetteer += 1
421
- new_g.append(gazetteer)
422
- return new_g
423
 
424
 
425
  def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
@@ -434,25 +347,24 @@ def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=T
434
  new_labels.append(align_labels_with_tokens(labels, word_ids))
435
  tokenized_inputs["labels"] = new_labels
436
  if apply_extended_embeddings:
437
- g = examples["gazetteers"]
438
- new_g = []
439
- for i, g in enumerate(g):
440
  word_ids = tokenized_inputs.word_ids(i)
441
- new_g.append(align_gazetteers_with_tokens(g, word_ids))
442
- p, o, l = [], [], []
443
- for i in new_g:
444
- p.append([x[0] for x in i])
445
- o.append([x[1] for x in i])
446
- l.append([x[2] for x in i])
447
- tokenized_inputs["per"] = p
448
- tokenized_inputs["org"] = o
449
- tokenized_inputs["loc"] = l
450
  return tokenized_inputs
451
 
452
-
453
  dataset = raw_dataset.map(
454
  tokenize_and_align_labels,
455
  batched=True,
456
- remove_columns=raw_dataset["train"].column_names,
457
  )
458
  return dataset
 
1
  import os
2
  import re
 
3
  from tqdm import tqdm
4
 
5
  from datasets import Dataset, DatasetDict
6
 
7
+ from data_manipulation.creation_gazetteers import build_reverse_dictionary, lemmatizing, load_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  ####################################################################################################
10
  ### GAZETTEERS EMBEDDINGS ##########################################################################
 
26
  i += 1
27
  return matches
28
 
 
 
29
 
30
+ def find_single_token_matches(tokens, looking_tokens, gazetteers, matches):
31
+ n = len(tokens)
32
+ assert n == len(looking_tokens)
33
+ for index in range(n):
34
+ word = looking_tokens[index]
35
+ if len(word) < 3:
36
+ continue
37
+ for gazetteer in gazetteers:
38
+ if word in gazetteer:
39
+ match_type = gazetteer[word]
40
+ matches.setdefault(tokens[index], []).append((word, match_type))
41
  return matches
42
 
 
 
 
 
 
43
 
44
+ def gazetteer_matching(words, gazetteers_for_matching, args=None):
45
+ ending_ova = True
46
+ method_for_gazetteers_matching = "single"
47
+ apply_lemmatizing = True
48
 
49
+ if method_for_gazetteers_matching == "single":
50
+ matches = find_single_token_matches(words, words, gazetteers_for_matching, {})
51
+ if apply_lemmatizing:
52
+ lemmatize_tokens = [lemmatizing(t) for t in words]
53
+ matches = find_single_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
54
  else: # multi_token_match
55
  matches = find_multi_token_matches(words, words, gazetteers_for_matching, {})
56
+ if apply_lemmatizing:
57
+ lemmatize_tokens = [lemmatizing(t) for t in words]
58
+ matches = find_multi_token_matches(words, lemmatize_tokens, gazetteers_for_matching, matches)
59
 
60
  result = []
61
  for word in words:
 
63
  per, org, loc = 0, 0, 0
64
  for res in mid_res:
65
  if mid_res[0][0].count(" ") == res[0].count(" "):
66
+ if res[1] == "PER":
67
+ per = 5
68
+ elif res[1] == "ORG":
69
+ org = 5
70
+ elif res[1] == "LOC":
71
+ loc = 5
72
  if ending_ova and word.endswith("ová") and word[0].isupper():
73
+ per = 5
74
  result.append([per, org, loc])
75
  return result
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  ####################################################################################################
79
  ### CNEC DATASET ###################################################################################
80
  ####################################################################################################
 
83
  label_mapper: cnec labels to int
84
  """
85
  # Open and read the XML file as plain text
 
86
  id_ = 0
87
  with open(xml_file_path, "r", encoding="utf-8") as xml_file:
88
  plain_text = xml_file.read()
 
94
  ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
95
  data = []
96
  if args.apply_extended_embeddings:
97
+ gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
 
98
  temp = []
99
  for i in gazetteers_for_matching.keys():
100
  temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
101
  gazetteers_for_matching = temp
102
 
103
+ for sentence in tqdm(sentences):
104
  entity_mapping = []
105
  while "<ne type=" in sentence: # while because there are nested entities
106
  nes = re.findall(ne_pattern, sentence)
 
152
  if tags_per_word == [] or tags_per_word == [0]:
153
  continue
154
  if args.apply_extended_embeddings:
155
+ matching = gazetteer_matching(words, gazetteers_for_matching, args)
156
  data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word,
157
  "sentence": " ".join(words), "gazetteers": matching})
158
  else:
 
160
  id_ += 1
161
  return data
162
 
163
+ def get_default_dataset_from_cnec(label_mapper:dict, xml_file_path):
164
+ """
165
+ label_mapper: cnec labels to int
166
+ """
167
+ # Open and read the XML file as plain text
168
+ id_ = 0
169
+ with open(xml_file_path, "r", encoding="utf-8") as xml_file:
170
+ plain_text = xml_file.read()
171
+ plain_text = plain_text[5:-5] # remove unnessery characters
172
+ plain_text = re.sub(r'([a-zA-Z.])<ne', r'\1 <ne', plain_text)
173
+ plain_text = re.sub(r'</ne>([a-zA-Z.])', r'</ne> \1', plain_text)
174
+ plain_text = re.sub(r'[ ]+', ' ', plain_text)
175
+ sentences = plain_text.split("\n")
176
+ ne_pattern = r'<ne type="([a-zA-Z?_-]{1,5})">([^<]+)</ne>'
177
+ data = []
178
 
179
+ for sentence in tqdm(sentences):
180
+ entity_mapping = []
181
+ while "<ne type=" in sentence: # while because there are nested entities
182
+ nes = re.findall(ne_pattern, sentence)
183
+ for label, entity in nes:
184
+ pattern = f'<ne type="{label}">{entity}</ne>'
185
+ index = sentence.index(pattern)
186
+ temp_index = index
187
+ sentence = sentence.replace(pattern, entity, 1)
188
+ temp_index -= sum([len(f'<ne type="{tag}">') for tag in re.findall(r'<ne type="([a-zA-Z?_-]{1,5})">', sentence[:index])])
189
+ temp_index -= sentence[:index].count("</ne>") * len("</ne>")
190
+ temp_index -= (re.sub(r'<ne type="([a-zA-Z?_-]{1,5})">', "", sentence[:index]).replace("</ne>", "")).count(" ")
191
+ index = temp_index
192
+ entity_mapping.append((entity, label, index, index + len(entity)))
193
+
194
+ entities = []
195
+ for entity, label, start, end in entity_mapping:
196
+ for tag in label_mapper.keys():
197
+ if label.lower().startswith(tag):
198
+ entities.append((label_mapper[tag], entity, start, end))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  break
200
+ entities.sort(key=lambda x: len(x[1]), reverse=True)
201
 
202
+ words = re.split(r'\s+', sentence)
203
+ tags_per_word = []
204
+ sentence_counter = -1
205
+ for word in words:
206
+ sentence_counter += len(word) + 1
207
+ if len(entities) == 0:
208
+ tags_per_word.append(0) # tag representing no label for no word
209
+ for index_entity in range(len(entities)):
210
+ if not(sentence_counter - len(word) >= entities[index_entity][2] and
211
+ sentence_counter <= entities[index_entity][3] and
212
+ word in entities[index_entity][1]):
213
+ if index_entity == len(entities) - 1:
214
+ tags_per_word.append(0) # tag representing no label for word
215
+ continue
216
 
217
+ if sentence_counter - len(word) == entities[index_entity][2]:
218
+ tags_per_word.append(entities[index_entity][0] * 2 - 1) # beggining of entity
219
+ else:
220
+ tags_per_word.append(entities[index_entity][0] * 2) # inside of entity
221
 
222
+ if tags_per_word == [] or tags_per_word == [0]:
223
+ continue
224
+
225
+ data.append({"id": id_, 'tokens': words, 'ner_tags': tags_per_word, "sentence": " ".join(words)})
226
+ id_ += 1
227
+ return data
228
 
229
 
230
  def create_cnec_dataset(label_mapper:dict, args):
 
 
231
  dataset = DatasetDict()
232
  for part, file_name in zip(["train", "validation", "test"],["named_ent_train.xml", "named_ent_etest.xml", "named_ent_dtest.xml"]):
233
  file_path = os.path.join(args.cnec_dataset_dir_path, file_name)
 
234
  temp_dataset = get_dataset_from_cnec(label_mapper, file_path, args)
 
 
235
  dataset[part] = Dataset.from_list(temp_dataset)
236
  return dataset
237
 
 
239
  ### WIKIANN DATASET ################################################################################
240
  ####################################################################################################
241
  def load_wikiann_testing_dataset(args):
242
+ if args.apply_extended_embeddings:
243
+ gazetteers_for_matching = load_json(args.extended_embeddings_gazetteers_path)
244
+ temp = []
245
+ for i in gazetteers_for_matching.keys():
246
+ temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
247
+ gazetteers_for_matching = temp
248
  dataset = []
249
  index = 0
250
  sentences = load_tagged_sentences(args.wikiann_dataset_path)
251
  for sentence in sentences:
252
  words = [word for word, _ in sentence]
253
  tags = [tag for _, tag in sentence]
254
+ if args.apply_extended_embeddings:
255
  matching = gazetteer_matching(words, gazetteers_for_matching, args)
256
  dataset.append({"id": index, 'tokens': words, 'ner_tags': tags, "gazetteers": matching})
257
  else:
 
259
  index += 1
260
 
261
  test = Dataset.from_list(dataset)
262
+ dataset = DatasetDict({"train": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
263
+ "validation": Dataset.from_list([{"id": 1, 'tokens': [], 'ner_tags': [], "gazetteers": []}]),
264
+ "test": test})
265
+ # dataset = DatasetDict({"test": test})
266
  return dataset
267
 
268
 
 
315
  new_labels.append(label)
316
  return new_labels
317
 
318
+
319
  def align_gazetteers_with_tokens(gazetteers, word_ids):
320
+ aligned_gazetteers = []
321
  current_word = None
322
  for word_id in word_ids:
323
  if word_id != current_word:
324
  # Start of a new word!
325
  current_word = word_id
326
  gazetteer = [0,0,0] if word_id is None else gazetteers[word_id]
327
+ aligned_gazetteers.append(gazetteer)
328
  elif word_id is None:
329
  # Special token
330
+ aligned_gazetteers.append([0,0,0])
331
  else:
332
  # Same word as previous token
333
  gazetteer = gazetteers[word_id]
334
+ aligned_gazetteers.append(gazetteer)
335
+ return aligned_gazetteers
 
 
 
336
 
337
 
338
  def create_tokenized_dataset(raw_dataset, tokenizer, apply_extended_embeddings=True):
 
347
  new_labels.append(align_labels_with_tokens(labels, word_ids))
348
  tokenized_inputs["labels"] = new_labels
349
  if apply_extended_embeddings:
350
+ matches = examples["gazetteers"]
351
+ aligned_matches = []
352
+ for i, match in enumerate(matches):
353
  word_ids = tokenized_inputs.word_ids(i)
354
+ aligned_matches.append(align_gazetteers_with_tokens(match, word_ids))
355
+ per, org, loc = [], [], []
356
+ for i in aligned_matches:
357
+ per.append([x[0] for x in i])
358
+ org.append([x[1] for x in i])
359
+ loc.append([x[2] for x in i])
360
+ tokenized_inputs["per"] = per
361
+ tokenized_inputs["org"] = org
362
+ tokenized_inputs["loc"] = loc
363
  return tokenized_inputs
364
 
 
365
  dataset = raw_dataset.map(
366
  tokenize_and_align_labels,
367
  batched=True,
368
+ # remove_columns=raw_dataset["train"].column_names
369
  )
370
  return dataset
data_manipulation/preprocess_gazetteers.py DELETED
@@ -1,54 +0,0 @@
1
- import re
2
-
3
- from simplemma import lemmatize
4
-
5
-
6
- def flatten(xss):
7
- return [x for xs in xss for x in xs]
8
-
9
-
10
- def remove_all_brackets(text):
11
- return re.sub(r'[\(\{\[].*?[\)\}\]]', '', text)
12
-
13
-
14
- def lemmatizing(x):
15
- if x == "":
16
- return ""
17
- return lemmatize(x, lang="cs")
18
-
19
-
20
- def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
21
- reverse_dictionary = {}
22
- for key, values in dictionary.items():
23
- for value in values:
24
- reverse_dictionary[value] = key
25
- if apply_lemmatizing:
26
- temp = lemmatizing(value)
27
- if temp != value:
28
- reverse_dictionary[temp] = key
29
- return reverse_dictionary
30
-
31
-
32
- def split_gazetteers_for_single_token_match(gazetteers):
33
- result = {}
34
- for k, v in gazetteers.items():
35
- result[k] = set(flatten([vv.split(" ") for vv in v]))
36
- result[k] = {x for x in result[k] if len(x) > 2}
37
- return result
38
-
39
-
40
- def preprocess_gazetteers(gazetteers, config):
41
- if config["split_person"]:
42
- gazetteers["PER"].update(set([x for x in flatten([v.split(" ") for v in gazetteers["PER"]]) if len(x) > 2]))
43
- if config["lemmatize"]:
44
- for k, v in gazetteers.items():
45
- gazetteers[k] = set(flatten([(vv, lemmatizing(vv)) for vv in v if len(vv) > 2]))
46
- if config["remove_brackets"]:
47
- for k, v in gazetteers.items():
48
- gazetteers[k] = {remove_all_brackets(vv).strip() for vv in v if len(remove_all_brackets(vv).strip()) > 2}
49
- if config["remove_numeric"]:
50
- for k, v in gazetteers.items():
51
- gazetteers[k] = {vv for vv in v if not vv.isnumeric()}
52
- if config["techniq_for_matching"] != "single":
53
- gazetteers = split_gazetteers_for_single_token_match(gazetteers)
54
- return gazetteers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
extended_embeddings/__init__.py DELETED
File without changes
extended_embeddings/{token_classification.py → extended_embedding_token_classification.py} RENAMED
@@ -1,4 +1,4 @@
1
- from typing import List, Optional, Tuple, Union
2
 
3
  import torch
4
  from torch import nn
@@ -12,11 +12,20 @@ _CONFIG_FOR_DOC = "RobertaConfig"
12
 
13
 
14
  class ExtendedEmbeddigsRobertaForTokenClassification(RobertaForTokenClassification):
 
 
 
 
 
 
 
 
 
15
  def __init__(self, config):
16
  super().__init__(config)
17
  self.num_labels = config.num_labels
18
 
19
- self.roberta = ExtendedEmbeddigsRobertaModel(config, add_pooling_layer=False)
20
  classifier_dropout = (
21
  config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
22
  )
@@ -92,4 +101,5 @@ class ExtendedEmbeddigsRobertaForTokenClassification(RobertaForTokenClassificati
92
  logits=logits,
93
  hidden_states=outputs.hidden_states,
94
  attentions=outputs.attentions,
95
- )
 
 
1
+ from typing import Optional, Tuple, Union
2
 
3
  import torch
4
  from torch import nn
 
12
 
13
 
14
  class ExtendedEmbeddigsRobertaForTokenClassification(RobertaForTokenClassification):
15
+ """
16
+ A RobertaForTokenClassification for token classification tasks with extended embeddings.
17
+
18
+ This RobertaForTokenClassification extends the functionality of the `RobertaForTokenClassification` class
19
+ by adding support for additional features such as `per`, `org`, and `loc`.
20
+
21
+ Part of the code copied from: transformers.models.bert.modeling_roberta.RobertaForTokenClassification
22
+
23
+ """
24
  def __init__(self, config):
25
  super().__init__(config)
26
  self.num_labels = config.num_labels
27
 
28
+ self.roberta = ExtendedEmbeddigsRobertaModel(config)
29
  classifier_dropout = (
30
  config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
31
  )
 
101
  logits=logits,
102
  hidden_states=outputs.hidden_states,
103
  attentions=outputs.attentions,
104
+ )
105
+
extended_embeddings/extended_embeddings_data_collator.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import DataCollatorForTokenClassification
3
+ from transformers.data.data_collator import pad_without_fast_tokenizer_warning
4
+
5
+
6
+ class ExtendedEmbeddingsDataCollatorForTokenClassification(DataCollatorForTokenClassification):
7
+ """
8
+ A data collator for token classification tasks with extended embeddings.
9
+
10
+ This data collator extends the functionality of the `DataCollatorForTokenClassification` class
11
+ by adding support for additional features such as `per`, `org`, and `loc`.
12
+
13
+ Part of the code copied from: transformers.data.data_collator.DataCollatorForTokenClassification
14
+ """
15
+
16
+ def torch_call(self, features):
17
+ label_name = "label" if "label" in features[0].keys() else "labels"
18
+ labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
19
+ per = [feature["per"] for feature in features] if "per" in features[0].keys() else None
20
+ org = [feature["org"] for feature in features] if "org" in features[0].keys() else None
21
+ loc = [feature["loc"] for feature in features] if "loc" in features[0].keys() else None
22
+
23
+ no_labels_features = [{k: v for k, v in feature.items() if k not in [label_name, "per", "org", "loc"]} for feature in features]
24
+
25
+ batch = pad_without_fast_tokenizer_warning(
26
+ self.tokenizer,
27
+ no_labels_features,
28
+ padding=self.padding,
29
+ max_length=self.max_length,
30
+ pad_to_multiple_of=self.pad_to_multiple_of,
31
+ return_tensors="pt",
32
+ )
33
+
34
+ if labels is None:
35
+ return batch
36
+
37
+ sequence_length = batch["input_ids"].shape[1]
38
+ padding_side = self.tokenizer.padding_side
39
+
40
+ def to_list(tensor_or_iterable):
41
+ if isinstance(tensor_or_iterable, torch.Tensor):
42
+ return tensor_or_iterable.tolist()
43
+ return list(tensor_or_iterable)
44
+
45
+ if padding_side == "right":
46
+ batch[label_name] = [
47
+ to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
48
+ ]
49
+ batch["per"] = [
50
+ to_list(p) + [0] * (sequence_length - len(p)) for p in per
51
+ ]
52
+ batch["org"] = [
53
+ to_list(o) + [0] * (sequence_length - len(o)) for o in org
54
+ ]
55
+ batch["loc"] = [
56
+ to_list(l) + [0] * (sequence_length - len(l)) for l in loc
57
+ ]
58
+ else:
59
+ batch[label_name] = [
60
+ [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
61
+ ]
62
+ batch["per"] = [
63
+ [0] * (sequence_length - len(p)) + self.to_list(p) for p in per
64
+ ]
65
+ batch["org"] = [
66
+ [0] * (sequence_length - len(o)) + self.to_list(o) for o in org
67
+ ]
68
+ batch["loc"] = [
69
+ [0] * (sequence_length - len(l)) + self.to_list(l) for l in loc
70
+ ]
71
+
72
+ batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
73
+ batch["per"] = torch.tensor(batch["per"], dtype=torch.int64)
74
+ batch["org"] = torch.tensor(batch["org"], dtype=torch.int64)
75
+ batch["loc"] = torch.tensor(batch["loc"], dtype=torch.int64)
76
+ return batch
77
+
extended_embeddings/extended_embeddings_model.py CHANGED
@@ -1,53 +1,27 @@
1
- from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaEncoder, RobertaEmbeddings
2
- from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
3
  from typing import List, Optional, Tuple, Union
 
4
  import torch
5
- from torch.nn import functional as F
6
- from torch import nn
7
 
8
- # Copied from transformers.models.bert.modeling_bert.BertPooler
9
- class ExtendedEmbeddigsRobertaPooler(nn.Module):
10
- def __init__(self, config):
11
- super().__init__()
12
- size_of_gazetters_part = int((len(config.id2label.keys()) - 1) // 2)
13
- self.dense = nn.Linear(config.hidden_size + size_of_gazetters_part, config.hidden_size + size_of_gazetters_part)
14
- self.activation = nn.Tanh()
15
-
16
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
17
- # We "pool" the model by simply taking the hidden state corresponding
18
- # to the first token.
19
- first_token_tensor = hidden_states[:, 0]
20
- pooled_output = self.dense(first_token_tensor)
21
- pooled_output = self.activation(pooled_output)
22
- return pooled_output
23
 
24
  class ExtendedEmbeddigsRobertaModel(RobertaModel):
25
  """
 
26
 
27
- The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
28
- cross-attention is added between the self-attention layers, following the architecture described in *Attention is
29
- all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
30
- Kaiser and Illia Polosukhin.
31
 
32
- To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
33
- to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
34
- `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
35
-
36
- .. _*Attention is all you need*: https://arxiv.org/abs/1706.03762
37
 
38
  """
39
-
40
- # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
41
- def __init__(self, config, add_pooling_layer=True):
42
  super().__init__(config)
43
  self.config = config
44
 
45
  self.embeddings = RobertaEmbeddings(config)
46
  self.encoder = RobertaEncoder(config)
47
- # self.gazetteers = GazetteersNetwork() # change
48
-
49
- self.pooler = ExtendedEmbeddigsRobertaPooler(config)
50
-
51
  # Initialize weights and apply final processing
52
  self.post_init()
53
 
@@ -57,10 +31,9 @@ class ExtendedEmbeddigsRobertaModel(RobertaModel):
57
  attention_mask: Optional[torch.Tensor] = None,
58
  token_type_ids: Optional[torch.Tensor] = None,
59
  position_ids: Optional[torch.Tensor] = None,
60
- # gazetteers_ids: Optional[torch.Tensor] = None, # change
61
- per: Optional[torch.Tensor] = None, # change
62
- org: Optional[torch.Tensor] = None, # change
63
- loc: Optional[torch.Tensor] = None, # change
64
  head_mask: Optional[torch.Tensor] = None,
65
  inputs_embeds: Optional[torch.Tensor] = None,
66
  encoder_hidden_states: Optional[torch.Tensor] = None,
 
 
 
1
  from typing import List, Optional, Tuple, Union
2
+
3
  import torch
4
+ from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
5
+ from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaEncoder, RobertaEmbeddings
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  class ExtendedEmbeddigsRobertaModel(RobertaModel):
9
  """
10
+ A RobertaModel for token classification tasks with extended embeddings.
11
 
12
+ This RobertaModel extends the functionality of the `RobertaModel` class
13
+ by adding support for additional features such as `per`, `org`, and `loc`.
 
 
14
 
15
+ Part of the code copied from: transformers.models.bert.modeling_roberta.RobertaModel
 
 
 
 
16
 
17
  """
18
+ def __init__(self, config):
 
 
19
  super().__init__(config)
20
  self.config = config
21
 
22
  self.embeddings = RobertaEmbeddings(config)
23
  self.encoder = RobertaEncoder(config)
24
+ self.pooler = None
 
 
 
25
  # Initialize weights and apply final processing
26
  self.post_init()
27
 
 
31
  attention_mask: Optional[torch.Tensor] = None,
32
  token_type_ids: Optional[torch.Tensor] = None,
33
  position_ids: Optional[torch.Tensor] = None,
34
+ per: Optional[torch.Tensor] = None,
35
+ org: Optional[torch.Tensor] = None,
36
+ loc: Optional[torch.Tensor] = None,
 
37
  head_mask: Optional[torch.Tensor] = None,
38
  inputs_embeds: Optional[torch.Tensor] = None,
39
  encoder_hidden_states: Optional[torch.Tensor] = None,
flagged/log.csv DELETED
@@ -1,8 +0,0 @@
1
- text,output,flag,username,timestamp
2
- Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 02:29:01.157209
3
- Barack Obama navštívil Prahu minulý týden .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Barack Obama"", ""class_or_confidence"": ""OSV""}, {""token"": "" nav\u0161t\u00edvil "", ""class_or_confidence"": null}, {""token"": ""Prahu"", ""class_or_confidence"": ""LOC""}, {""token"": "" minul\u00fd t\u00fdden ."", ""class_or_confidence"": null}]",,,2024-05-06 02:31:57.950478
4
- Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 02:51:30.197653
5
- Barack Obama navštívil Prahu minulý týden .,,,,2024-05-06 10:58:33.085992
6
- Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 11:00:17.762652
7
- Masarykova univerzita se nachází v Brně .,"[{""token"": """", ""class_or_confidence"": null}, {""token"": ""Masarykova univerzita"", ""class_or_confidence"": ""ORG""}, {""token"": "" se nach\u00e1z\u00ed v "", ""class_or_confidence"": null}, {""token"": ""Brn\u011b"", ""class_or_confidence"": ""LOC""}, {""token"": "" ."", ""class_or_confidence"": null}]",,,2024-05-06 11:00:20.057269
8
- ,,,,,2024-05-09 22:59:12.114264
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -5,3 +5,4 @@ torch
5
  simplemma
6
  gradio
7
  pandas
 
 
5
  simplemma
6
  gradio
7
  pandas
8
+ name-datasets
style.css CHANGED
@@ -6,10 +6,6 @@ footer {
6
  color-scheme: light dark;
7
  }
8
 
9
- .container .svelte-ju12zg {
10
- color: light-dark(black, white);
11
- }
12
-
13
  .text.svelte-ju12zg {
14
  padding: 0;
15
  margin: 0;
@@ -23,4 +19,9 @@ footer {
23
  .textspan.svelte-ju12zg.no-cat {
24
  margin: 0;
25
  padding: 0;
26
- }
 
 
 
 
 
 
6
  color-scheme: light dark;
7
  }
8
 
 
 
 
 
9
  .text.svelte-ju12zg {
10
  padding: 0;
11
  margin: 0;
 
19
  .textspan.svelte-ju12zg.no-cat {
20
  margin: 0;
21
  padding: 0;
22
+ }
23
+
24
+ .category-label.svelte-ju12zg {
25
+ background-color: light-dark(white, black,);
26
+
27
+ }
upload_model.ipynb CHANGED
@@ -2,13 +2,13 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 2,
6
  "metadata": {},
7
  "outputs": [
8
  {
9
  "data": {
10
  "application/vnd.jupyter.widget-view+json": {
11
- "model_id": "65fea98bf7924f4fb4947d8e2dda2f4d",
12
  "version_major": 2,
13
  "version_minor": 0
14
  },
@@ -28,7 +28,7 @@
28
  },
29
  {
30
  "cell_type": "code",
31
- "execution_count": 3,
32
  "metadata": {},
33
  "outputs": [
34
  {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "metadata": {},
7
  "outputs": [
8
  {
9
  "data": {
10
  "application/vnd.jupyter.widget-view+json": {
11
+ "model_id": "556291d727474e0a82723d6459722b16",
12
  "version_major": 2,
13
  "version_minor": 0
14
  },
 
28
  },
29
  {
30
  "cell_type": "code",
31
+ "execution_count": 2,
32
  "metadata": {},
33
  "outputs": [
34
  {
website_script.py CHANGED
@@ -2,11 +2,39 @@ import json
2
  import copy
3
 
4
  import torch
 
5
  from transformers import AutoTokenizer
6
 
7
- from extended_embeddings.token_classification import ExtendedEmbeddigsRobertaForTokenClassification
8
- from data_manipulation.dataset_funcions import load_gazetteers, gazetteer_matching, align_gazetteers_with_tokens
9
- from data_manipulation.preprocess_gazetteers import build_reverse_dictionary
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def load():
@@ -18,7 +46,7 @@ def load():
18
  tokenizer = AutoTokenizer.from_pretrained(model_name)
19
  model.eval()
20
 
21
- gazetteers_for_matching = load_gazetteers(gazetteers_path)
22
  temp = []
23
  for i in gazetteers_for_matching.keys():
24
  temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))
 
2
  import copy
3
 
4
  import torch
5
+ from simplemma import lemmatize
6
  from transformers import AutoTokenizer
7
 
8
+ from extended_embeddings.extended_embedding_token_classification import ExtendedEmbeddigsRobertaForTokenClassification
9
+ from data_manipulation.dataset_funcions import gazetteer_matching, align_gazetteers_with_tokens
10
+
11
+ # code originaly from data_manipulation.creation_gazetteers
12
+ def lemmatizing(x):
13
+ if x == "":
14
+ return ""
15
+ return lemmatize(x, lang="cs")
16
+
17
+ # code originaly from data_manipulation.creation_gazetteers
18
+ def build_reverse_dictionary(dictionary, apply_lemmatizing=False):
19
+ reverse_dictionary = {}
20
+ for key, values in dictionary.items():
21
+ for value in values:
22
+ reverse_dictionary[value] = key
23
+ if apply_lemmatizing:
24
+ temp = lemmatizing(value)
25
+ if temp != value:
26
+ reverse_dictionary[temp] = key
27
+ return reverse_dictionary
28
+
29
+ def load_json(path):
30
+ """
31
+ Load gazetteers from a file
32
+ :param path: path to the gazetteer file
33
+ :return: a dict of gazetteers
34
+ """
35
+ with open(path, 'r') as file:
36
+ data = json.load(file)
37
+ return data
38
 
39
 
40
  def load():
 
46
  tokenizer = AutoTokenizer.from_pretrained(model_name)
47
  model.eval()
48
 
49
+ gazetteers_for_matching = load_json(gazetteers_path)
50
  temp = []
51
  for i in gazetteers_for_matching.keys():
52
  temp.append(build_reverse_dictionary({i: gazetteers_for_matching[i]}))