emanuelaboros committed on
Commit 075c00d · 0 Parent(s):

Initial commit for ner-stacked-bert-multilingual

Files changed (2)
  1. .gitattributes +35 -0
  2. generic_ner.py +791 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
generic_ner.py ADDED
@@ -0,0 +1,791 @@
1
+ import logging
2
+ from transformers import Pipeline
3
+ import numpy as np
4
+ import torch
5
+ import nltk
6
+
7
+ # Download the NLTK resources needed for POS tagging and stopword filtering
8
+ nltk.download("averaged_perceptron_tagger")
9
+ nltk.download("averaged_perceptron_tagger_eng")
10
+ nltk.download("stopwords")
11
+ from nltk.chunk import conlltags2tree
12
+ from nltk import pos_tag
13
+ from nltk.tree import Tree
14
+ import torch.nn.functional as F
15
+ import re, string
16
+
17
+ stop_words = set(nltk.corpus.stopwords.words("english"))
18
+ DEBUG = False
19
+ punctuation = (
20
+ string.punctuation
21
+ + "«»—…“”"
22
+ + "—."
23
+ + "–"
24
+ + "’"
25
+ + "‘"
26
+ + "´"
27
+ + "•"
28
+ + "°"
29
+ + "»"
30
+ + "“"
31
+ + "”"
32
+ + "–"
33
+ + "—"
34
+ + "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
35
+ )
36
+
37
+ # List of additional "strange" punctuation marks
38
+ # additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
39
+
40
+
41
+ WHITESPACE_RULES = {
42
+ "fr": {
43
+ "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"],
44
+ "pct_no_ws_after": ["(", "[", "{"],
45
+ "pct_no_ws_before_after": ["'", "-"],
46
+ "pct_number": [".", ","],
47
+ },
48
+ "de": {
49
+ "pct_no_ws_before": [
50
+ ".",
51
+ ",",
52
+ ")",
53
+ "]",
54
+ "}",
55
+ "°",
56
+ "...",
57
+ "?",
58
+ "!",
59
+ ":",
60
+ ";",
61
+ ".-",
62
+ "%",
63
+ ],
64
+ "pct_no_ws_after": ["(", "[", "{"],
65
+ "pct_no_ws_before_after": ["'", "-"],
66
+ "pct_number": [".", ","],
67
+ },
68
+ "other": {
69
+ "pct_no_ws_before": [
70
+ ".",
71
+ ",",
72
+ ")",
73
+ "]",
74
+ "}",
75
+ "°",
76
+ "...",
77
+ "?",
78
+ "!",
79
+ ":",
80
+ ";",
81
+ ".-",
82
+ "%",
83
+ ],
84
+ "pct_no_ws_after": ["(", "[", "{"],
85
+ "pct_no_ws_before_after": ["'", "-"],
86
+ "pct_number": [".", ","],
87
+ },
88
+ }
89
+
90
+
91
+ def tokenize(text: str, language: str = "other") -> list[str]:
92
+ """Apply whitespace rules to the given text and language, separating it into tokens.
93
+
94
+ Args:
95
+ text (str): The input text to separate into a list of tokens.
96
+ language (str): Language of the text.
97
+
98
+ Returns:
99
+ list[str]: List of tokens with punctuation as separate tokens.
100
+ """
101
+ # text = add_spaces_around_punctuation(text)
102
+ if not text:
103
+ return []
104
+
105
+ if language not in WHITESPACE_RULES:
106
+ # Default behavior for languages without specific rules:
107
+ # tokenize using standard whitespace splitting
108
+ language = "other"
109
+
110
+ wsrules = WHITESPACE_RULES[language]
111
+ tokenized_text = []
112
+ current_token = ""
113
+
114
+ for char in text:
115
+ if char in wsrules["pct_no_ws_before_after"]:
116
+ if current_token:
117
+ tokenized_text.append(current_token)
118
+ tokenized_text.append(char)
119
+ current_token = ""
120
+ elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]:
121
+ if current_token:
122
+ tokenized_text.append(current_token)
123
+ tokenized_text.append(char)
124
+ current_token = ""
125
+ elif char.isspace():
126
+ if current_token:
127
+ tokenized_text.append(current_token)
128
+ current_token = ""
129
+ else:
130
+ current_token += char
131
+
132
+ if current_token:
133
+ tokenized_text.append(current_token)
134
+
135
+ return tokenized_text
136
+
137
+
138
+ def normalize_text(text):
139
+ # Remove spaces and tabs for the search but keep newline characters
140
+ return re.sub(r"[ \t]+", "", text)
141
+
142
+
143
+ def find_entity_indices(article_text, search_text):
144
+ # Normalize texts by removing spaces and tabs
145
+ normalized_article = normalize_text(article_text)
146
+ normalized_search = normalize_text(search_text)
147
+
148
+ # Initialize a list to hold all start and end indices
149
+ indices = []
150
+
151
+ # Find all occurrences of the search text in the normalized article text
152
+ start_index = 0
153
+ while True:
154
+ start_index = normalized_article.find(normalized_search, start_index)
155
+ if start_index == -1:
156
+ break
157
+
158
+ # Calculate the actual start and end indices in the original article text
159
+ original_chars = 0
160
+ original_start_index = 0
161
+ for i in range(start_index):
162
+ while article_text[original_start_index] in (" ", "\t"):
163
+ original_start_index += 1
164
+ if article_text[original_start_index] not in (" ", "\t", "\n"):
165
+ original_chars += 1
166
+ original_start_index += 1
167
+
168
+ original_end_index = original_start_index
169
+ search_chars = 0
170
+ while search_chars < len(normalized_search):
171
+ if article_text[original_end_index] not in (" ", "\t", "\n"):
172
+ search_chars += 1
173
+ original_end_index += 1 # Increment to include the last character
174
+
175
+ # Append the found indices to the list
176
+ if article_text[original_start_index] == " ":
177
+ original_start_index += 1
178
+ indices.append((original_start_index, original_end_index))
179
+
180
+ # Move start_index to the next position to continue searching
181
+ start_index += 1
182
+
183
+ return indices
184
+
185
+
186
+ def get_entities(tokens, tags, confidences, text):
187
+
188
+ tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
189
+ pos_tags = [pos for token, pos in pos_tag(tokens)]
190
+
191
+ for i in range(1, len(tags)):
192
+ # If a 'B-' tag directly follows an 'I-' tag, treat it as a continuation and change it to 'I-'
193
+ if tags[i].startswith("B-") and tags[i - 1].startswith("I-"):
194
+ tags[i] = "I-" + tags[i][2:] # Change 'B-' to 'I-' for the same entity type
195
+
196
+ conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
197
+ ne_tree = conlltags2tree(conlltags)
198
+
199
+ entities = []
200
+ idx: int = 0
201
+ already_done = []
202
+ for subtree in ne_tree:
203
+ # skipping 'O' tags
204
+ if isinstance(subtree, Tree):
205
+ original_label = subtree.label()
206
+ original_string = " ".join([token for token, pos in subtree.leaves()])
207
+
208
+ for indices in find_entity_indices(text, original_string):
209
+ entity_start_position = indices[0]
210
+ entity_end_position = indices[1]
211
+ if (
212
+ "_".join(
213
+ [original_label, original_string, str(entity_start_position)]
214
+ )
215
+ in already_done
216
+ ):
217
+ continue
218
+ else:
219
+ already_done.append(
220
+ "_".join(
221
+ [
222
+ original_label,
223
+ original_string,
224
+ str(entity_start_position),
225
+ ]
226
+ )
227
+ )
228
+ if len(text[entity_start_position:entity_end_position].strip()) < len(
229
+ text[entity_start_position:entity_end_position]
230
+ ):
231
+ entity_start_position = (
232
+ entity_start_position
233
+ + len(text[entity_start_position:entity_end_position])
234
+ - len(text[entity_start_position:entity_end_position].strip())
235
+ )
236
+
237
+ entities.append(
238
+ {
239
+ "type": original_label,
240
+ "confidence_ner": round(
241
+ np.average(confidences[idx : idx + len(subtree)]) * 100, 2
242
+ ),
243
+ "index": (idx, idx + len(subtree)),
244
+ "surface": text[
245
+ entity_start_position:entity_end_position
246
+ ], # original_string,
247
+ "lOffset": entity_start_position,
248
+ "rOffset": entity_end_position,
249
+ }
250
+ )
251
+
252
+ idx += len(subtree)
253
+
254
+ # Update the current character position
255
+ # We add the length of the original string + 1 (for the space)
256
+ else:
257
+ token, pos = subtree
258
+ # If it's not a named entity, we still need to update the character
259
+ # position
260
+ idx += 1
261
+
262
+ return entities
263
+
264
+
265
+ def realign(
266
+ text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
267
+ ):
268
+ preds_list, words_list, confidence_list = [], [], []
269
+ word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
270
+ for idx, word in enumerate(text_sentence):
271
+ beginning_index = word_ids.index(idx)
272
+ try:
273
+ preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
274
+ confidence_list.append(max(softmax_scores[beginning_index]))
275
+ except Exception as ex: # the sentence was longer than max_length
276
+ preds_list.append("O")
277
+ confidence_list.append(0.0)
278
+ words_list.append(word)
279
+
280
+ return words_list, preds_list, confidence_list
281
+
282
+
283
+ def add_spaces_around_punctuation(text):
284
+ # Add a space before and after all punctuation
285
+ all_punctuation = string.punctuation + punctuation
286
+ return re.sub(r"([{}])".format(re.escape(all_punctuation)), r" \1 ", text)
287
+
288
+
289
+ def attach_comp_to_closest(entities):
290
+ # Define valid entity types that can receive a "comp.function" or "comp.name" attachment
291
+ valid_entity_types = {"org", "pers", "org.ent", "pers.ind"}
292
+
293
+ # Separate "comp.function" and "comp.name" entities from other entities
294
+ comp_entities = [ent for ent in entities if ent["type"].startswith("comp")]
295
+ other_entities = [ent for ent in entities if not ent["type"].startswith("comp")]
296
+
297
+ for comp_entity in comp_entities:
298
+ closest_entity = None
299
+ min_distance = float("inf")
300
+
301
+ # Find the closest non-"comp" entity that is valid for attaching
302
+ for other_entity in other_entities:
303
+ # Calculate distance between the comp entity and the other entity
304
+ if comp_entity["lOffset"] > other_entity["rOffset"]:
305
+ distance = comp_entity["lOffset"] - other_entity["rOffset"]
306
+ elif comp_entity["rOffset"] < other_entity["lOffset"]:
307
+ distance = other_entity["lOffset"] - comp_entity["rOffset"]
308
+ else:
309
+ distance = 0 # They overlap or touch
310
+
311
+ # Ensure the entity type is valid and check for minimal distance
312
+ if (
313
+ distance < min_distance
314
+ and other_entity["type"].split(".")[0] in valid_entity_types
315
+ ):
316
+ min_distance = distance
317
+ closest_entity = other_entity
318
+
319
+ # Attach the "comp.function" or "comp.name" if a valid entity is found
320
+ if closest_entity:
321
+ suffix = comp_entity["type"].split(".")[
322
+ -1
323
+ ] # Extract the suffix (e.g., 'name', 'function')
324
+ closest_entity[suffix] = comp_entity["surface"] # Attach the text
325
+
326
+ return other_entities
327
+
328
+
329
+ def conflicting_context(comp_entity, target_entity):
330
+ """
331
+ Determines if there is a conflict between the comp_entity and the target entity.
332
+ Prevents incorrect name and function attachments by using a rule-based approach.
333
+ """
334
+ # Case 1: Check for correct function attachment to person or organization entities
335
+ if comp_entity["type"].startswith("comp.function"):
336
+ if not ("pers" in target_entity["type"] or "org" in target_entity["type"]):
337
+ return True # Conflict: Function should only attach to persons or organizations
338
+
339
+ # Case 2: Avoid attaching comp.* entities to non-person, non-organization types (like locations)
340
+ if "loc" in target_entity["type"]:
341
+ return True # Conflict: comp.* entities should not attach to locations or similar types
342
+
343
+ return False # No conflict
344
+
345
+
346
+ def extract_name_from_text(text, partial_name):
347
+ """
348
+ Extracts the full name from the entity's text based on the partial name.
349
+ This function assumes that the full name starts with capitalized letters and does not
350
+ include any words that come after the partial name.
351
+ """
352
+ # Split the text and partial name into words
353
+ words = tokenize(text)
354
+ partial_words = partial_name.split()
355
+
356
+ if DEBUG:
357
+ print("text:", text)
358
+ if DEBUG:
359
+ print("partial_name:", partial_name)
360
+
361
+ # Find the position of the partial name in the word list
362
+ for i, word in enumerate(words):
363
+ if DEBUG:
364
+ print(words, "---", words[i : i + len(partial_words)])
365
+ if words[i : i + len(partial_words)] == partial_words:
366
+ # Initialize full name with the partial name
367
+ full_name = partial_words[:]
368
+
369
+ if DEBUG:
370
+ print("full_name:", full_name)
371
+
372
+ # Check previous words and only add capitalized words (skip lowercase words)
373
+ j = i - 1
374
+ while j >= 0 and words[j][0].isupper():
375
+ full_name.insert(0, words[j])
376
+ j -= 1
377
+ if DEBUG:
378
+ print("full_name:", full_name)
379
+
380
+ # Return only the full name up to the partial name (ignore words after the name)
381
+ return " ".join(full_name).strip() # Join the words to form the full name
382
+
383
+ # If not found, return the original text (as a fallback)
384
+ return text.strip()
385
+
386
+
387
+ def repair_names_in_entities(entities):
388
+ """
389
+ This function repairs the names in the entities by extracting the full name
390
+ from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached.
391
+ """
392
+ for entity in entities:
393
+ if "name" in entity and "pers" in entity["type"]:
394
+ name = entity["name"]
395
+ text = entity["surface"]
396
+
397
+ # Check if the attached name is part of the entity's text
398
+ if name in text:
399
+ # Extract the full name from the text by splitting around the attached name
400
+ full_name = extract_name_from_text(entity["surface"], name)
401
+ entity["name"] = (
402
+ full_name # Replace the partial name with the full name
403
+ )
404
+ # if "name" not in entity:
405
+ # entity["name"] = entity["surface"]
406
+
407
+ return entities
408
+
409
+
410
+ def clean_coarse_entities(entities):
411
+ """
412
+ This function removes entities that are not useful for the NEL process.
413
+ """
414
+ # Define a set of entity types that are considered useful for NEL
415
+ useful_types = {
416
+ "pers", # Person
417
+ "loc", # Location
418
+ "org", # Organization
419
+ "date", # Product
420
+ "time", # Time
421
+ }
422
+
423
+ # Filter out entities that are not in the useful_types set unless they are comp.* entities
424
+ cleaned_entities = [
425
+ entity
426
+ for entity in entities
427
+ if entity["type"] in useful_types or "comp" in entity["type"]
428
+ ]
429
+
430
+ return cleaned_entities
431
+
432
+
433
+ def postprocess_entities(entities):
434
+ # Step 1: Filter entities with the same text, keeping the one with the most dots in the 'entity' field
435
+ entity_map = {}
436
+
437
+ # Loop over the entities and prioritize the one with the most dots
438
+ for entity in entities:
439
+ entity_text = entity["surface"]
440
+ num_dots = entity["type"].count(".")
441
+
442
+ # If the entity text is new, or this entity has more dots, update the map
443
+ if (
444
+ entity_text not in entity_map
445
+ or entity_map[entity_text]["type"].count(".") < num_dots
446
+ ):
447
+ entity_map[entity_text] = entity
448
+
449
+ # Collect the filtered entities from the map
450
+ filtered_entities = list(entity_map.values())
451
+
452
+ # Step 2: Attach "comp.function" entities to the closest other entities
453
+ filtered_entities = attach_comp_to_closest(filtered_entities)
454
+ if DEBUG:
455
+ print("After attach_comp_to_closest:", filtered_entities, "\n")
456
+ filtered_entities = repair_names_in_entities(filtered_entities)
457
+ if DEBUG:
458
+ print("After repair_names_in_entities:", filtered_entities, "\n")
459
+
460
+ # Step 3: Remove entities that are not useful for NEL
461
+ # filtered_entities = clean_coarse_entities(filtered_entities)
462
+
463
+ # filtered_entities = remove_blacklisted_entities(filtered_entities)
464
+
465
+ return filtered_entities
466
+
467
+
468
+ def remove_included_entities(entities):
469
+ # Loop through entities and remove those whose text is included in another with the same label
470
+ final_entities = []
471
+ for i, entity in enumerate(entities):
472
+ is_included = False
473
+ for other_entity in entities:
474
+ if entity["surface"] != other_entity["surface"]:
475
+ if "comp" in other_entity["type"]:
476
+ # Check if entity's text is a substring of another entity's text
477
+ if entity["surface"] in other_entity["surface"]:
478
+ is_included = True
479
+ break
480
+ elif (
481
+ entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
482
+ or other_entity["type"].split(".")[0]
483
+ in entity["type"].split(".")[0]
484
+ ):
485
+ if entity["surface"] in other_entity["surface"]:
486
+ is_included = True
487
+ if not is_included:
488
+ final_entities.append(entity)
489
+ return final_entities
490
+
491
+
492
+ def refine_entities_with_coarse(all_entities, coarse_entities):
493
+ """
494
+ Looks through all entities and refines them based on the coarse entities.
495
+ If a surface match is found in the coarse entities and the types match,
496
+ the entity's confidence_ner and type are updated based on the coarse entity.
497
+ """
498
+ # Create a dictionary for coarse entities based on surface and type for quick lookup
499
+ coarse_lookup = {}
500
+ for coarse_entity in coarse_entities:
501
+ key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0])
502
+ coarse_lookup[key] = coarse_entity
503
+
504
+ # Iterate through all entities and compare with the coarse entities
505
+ for entity in all_entities:
506
+ key = (
507
+ entity["surface"],
508
+ entity["type"].split(".")[0],
509
+ ) # Use the coarse type for comparison
510
+
511
+ if key in coarse_lookup:
512
+ coarse_entity = coarse_lookup[key]
513
+ # If a match is found, update the confidence_ner and type in the entity
514
+ if entity["confidence_ner"] < coarse_entity["confidence_ner"]:
515
+ entity["confidence_ner"] = coarse_entity["confidence_ner"]
516
+ entity["type"] = coarse_entity[
517
+ "type"
518
+ ] # Update the type if the confidence is higher
519
+
520
+ # No need to append to refined_entities, we're modifying in place
521
+ for entity in all_entities:
522
+ entity["type"] = entity["type"].split(".")[0]
523
+ return all_entities
524
+
525
+
526
+ def remove_trailing_stopwords(entities):
527
+ """
528
+ This function removes stopwords and punctuation from both the beginning and end of each entity's text
529
+ and repairs the lOffset and rOffset accordingly.
530
+ """
531
+ if DEBUG:
532
+ print(f"Initial entities: {len(entities)}")
533
+ new_entities = []
534
+ for entity in entities:
535
+ if "comp" not in entity["type"]:
536
+ entity_text = entity["surface"]
537
+ original_len = len(entity_text)
538
+
539
+ # Initial offsets
540
+ lOffset = entity.get("lOffset", 0)
541
+ rOffset = entity.get("rOffset", original_len)
542
+
543
+ # Remove stopwords and punctuation from the beginning
544
+ i = 0
545
+ while entity_text and (
546
+ entity_text.split()[0].lower() in stop_words
547
+ or entity_text[0] in punctuation
548
+ ):
549
+ if entity_text.split()[0].lower() in stop_words:
550
+ stopword_len = (
551
+ len(entity_text.split()[0]) + 1
552
+ ) # Adjust length for stopword and following space
553
+ entity_text = entity_text[stopword_len:] # Remove leading stopword
554
+ lOffset += stopword_len # Adjust the left offset
555
+ if DEBUG:
556
+ print(
557
+ f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
558
+ )
559
+ elif entity_text[0] in punctuation:
560
+ entity_text = entity_text[1:] # Remove leading punctuation
561
+ lOffset += 1 # Adjust the left offset
562
+ if DEBUG:
563
+ print(
564
+ f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
565
+ )
566
+ i += 1
567
+
568
+ i = 0
569
+ # Remove stopwords and punctuation from the end
570
+ iteration = 0
571
+ max_iterations = len(entity_text) # Prevent infinite loops
572
+
573
+ while entity_text and iteration < max_iterations:
574
+ # Check if the last word is a stopword or the last character is punctuation
575
+ last_word = entity_text.split()[-1] if entity_text.split() else ""
576
+ last_char = entity_text[-1]
577
+
578
+ if last_word.lower() in stop_words:
579
+ # Remove trailing stopword and adjust rOffset
580
+ stopword_len = len(last_word) + 1 # Include space before stopword
581
+ entity_text = entity_text[:-stopword_len].rstrip()
582
+ rOffset -= stopword_len
583
+ if DEBUG:
584
+ print(
585
+ f"Removed trailing stopword from entity: {entity_text} (rOffset={rOffset})"
586
+ )
587
+
588
+ elif last_char in punctuation:
589
+ # Remove trailing punctuation and adjust rOffset
590
+ entity_text = entity_text[:-1].rstrip()
591
+ rOffset -= 1
592
+ if DEBUG:
593
+ print(
594
+ f"Removed trailing punctuation from entity: {entity_text} (rOffset={rOffset})"
595
+ )
596
+ else:
597
+ # Exit loop if neither stopwords nor punctuation are found
598
+ break
599
+
600
+ iteration += 1
601
+ # print(f"ITERATION: {iteration} [{entity['surface']}] for {entity_text}")
602
+
603
+ if len(entity_text.strip()) == 1:
604
+ entities.remove(entity)
605
+ if DEBUG:
606
+ print(f"Skipping entity: {entity_text}")
607
+ continue
608
+ # Skip certain entities based on rules
609
+ if entity_text in string.punctuation:
610
+ if DEBUG:
611
+ print(f"Skipping entity: {entity_text}")
612
+ entities.remove(entity)
613
+ continue
614
+ # Check now whether it's in the stopword list
615
+ if entity_text.lower() in stop_words:
616
+ if DEBUG:
617
+ print(f"Skipping entity: {entity_text}")
618
+ entities.remove(entity)
619
+ continue
620
+ # check now if the entire entity is a list of stopwords:
621
+ if all([word.lower() in stop_words for word in entity_text.split()]):
622
+ if DEBUG:
623
+ print(f"Skipping entity: {entity_text}")
624
+ entities.remove(entity)
625
+ continue
626
+ # Check if the entire entity is made up of stopwords characters
627
+ if all(
628
+ [char.lower() in stop_words for char in entity_text if char.isalpha()]
629
+ ):
630
+ if DEBUG:
631
+ print(
632
+ f"Skipping entity: {entity_text} (all characters are stopwords)"
633
+ )
634
+ entities.remove(entity)
635
+ continue
636
+ # Check now whether the entire entity consists of punctuation tokens
637
+ if all([word in string.punctuation for word in entity_text.split()]):
638
+ if DEBUG:
639
+ print(
640
+ f"Skipping entity: {entity_text} (all characters are punctuation)"
641
+ )
642
+ entities.remove(entity)
643
+ continue
644
+ if all(
645
+ [
646
+ char.lower() in string.punctuation
647
+ for char in entity_text
648
+ if char.isalpha()
649
+ ]
650
+ ):
651
+ if DEBUG:
652
+ print(
653
+ f"Skipping entity: {entity_text} (all characters are punctuation)"
654
+ )
655
+ entities.remove(entity)
656
+ continue
657
+
658
+ # If the surface is purely numeric and the entity type is not a time type, drop it
659
+ if entity_text.isdigit() and "time" not in entity["type"]:
660
+ if DEBUG:
661
+ print(f"Skipping entity: {entity_text}")
662
+ entities.remove(entity)
663
+ continue
664
+
665
+ if entity_text.startswith(" "):
666
+ entity_text = entity_text[1:]
667
+ # update lOffset, rOffset
668
+ lOffset += 1
669
+ if entity_text.endswith(" "):
670
+ entity_text = entity_text[:-1]
671
+ # update lOffset, rOffset
672
+ rOffset -= 1
673
+
674
+ # Update the entity surface and offsets
675
+ entity["surface"] = entity_text
676
+ entity["lOffset"] = lOffset
677
+ entity["rOffset"] = rOffset
678
+
679
+ # Remove the entity if the surface is empty after cleaning
680
+ if len(entity["surface"].strip()) == 0:
681
+ if DEBUG:
682
+ print(f"Deleted entity: {entity['surface']}")
683
+ entities.remove(entity)
684
+ else:
685
+ new_entities.append(entity)
686
+
687
+ if DEBUG:
688
+ print(f"Remained entities: {len(new_entities)}")
689
+ return new_entities
690
+
691
+
692
+ class MultitaskTokenClassificationPipeline(Pipeline):
693
+
694
+ def _sanitize_parameters(self, **kwargs):
695
+ preprocess_kwargs = {}
696
+ if "text" in kwargs:
697
+ preprocess_kwargs["text"] = kwargs["text"]
698
+ self.label_map = self.model.config.label_map
699
+ self.id2label = {
700
+ task: {id_: label for label, id_ in labels.items()}
701
+ for task, labels in self.label_map.items()
702
+ }
703
+ return preprocess_kwargs, {}, {}
704
+
705
+ def preprocess(self, text, **kwargs):
706
+
707
+ tokenized_inputs = self.tokenizer(
708
+ text, padding="max_length", truncation=True, max_length=512
709
+ )
710
+
711
+ text_sentence = tokenize(add_spaces_around_punctuation(text))
712
+ return tokenized_inputs, text_sentence, text
713
+
714
+ def _forward(self, inputs):
715
+ inputs, text_sentences, text = inputs
716
+ input_ids = torch.tensor([inputs["input_ids"]], dtype=torch.long).to(
717
+ self.model.device
718
+ )
719
+ attention_mask = torch.tensor([inputs["attention_mask"]], dtype=torch.long).to(
720
+ self.model.device
721
+ )
722
+ with torch.no_grad():
723
+ outputs = self.model(input_ids, attention_mask)
724
+ return outputs, text_sentences, text
725
+
726
+ def is_within(self, entity1, entity2):
727
+ """Check if entity1 is fully within the bounds of entity2."""
728
+ return (
729
+ entity1["lOffset"] >= entity2["lOffset"]
730
+ and entity1["rOffset"] <= entity2["rOffset"]
731
+ )
732
+
733
+ def postprocess(self, outputs, **kwargs):
734
+ """
735
+ Postprocess the outputs of the model
736
+ :param outputs:
737
+ :param kwargs:
738
+ :return:
739
+ """
740
+ tokens_result, text_sentence, text = outputs
741
+
742
+ predictions = {}
743
+ confidence_scores = {}
744
+ for task, logits in tokens_result.logits.items():
745
+ predictions[task] = torch.argmax(logits, dim=-1).tolist()[0]
746
+ confidence_scores[task] = F.softmax(logits, dim=-1).tolist()[0]
747
+
748
+ entities = {}
749
+ for task in predictions.keys():
750
+ words_list, preds_list, confidence_list = realign(
751
+ text_sentence,
752
+ predictions[task],
753
+ confidence_scores[task],
754
+ self.tokenizer,
755
+ self.id2label[task],
756
+ )
757
+
758
+ entities[task] = get_entities(words_list, preds_list, confidence_list, text)
759
+
760
+ # add titles to comp entities
761
+ # from pprint import pprint
762
+
763
+ # print("Before:")
764
+ # pprint(entities)
765
+
766
+ all_entities = []
767
+ coarse_entities = []
768
+ for key in entities:
769
+ if key in ["NE-COARSE-LIT"]:
770
+ coarse_entities = entities[key]
771
+ all_entities.extend(entities[key])
772
+
773
+ if DEBUG:
774
+ print(all_entities)
775
+ # print("After remove_included_entities:")
776
+ all_entities = remove_included_entities(all_entities)
777
+ if DEBUG:
778
+ print("After remove_included_entities:", all_entities)
779
+ all_entities = remove_trailing_stopwords(all_entities)
780
+ if DEBUG:
781
+ print("After remove_trailing_stopwords:", all_entities)
782
+ all_entities = postprocess_entities(all_entities)
783
+ if DEBUG:
784
+ print("After postprocess_entities:", all_entities)
785
+ all_entities = refine_entities_with_coarse(all_entities, coarse_entities)
786
+ if DEBUG:
787
+ print("After refine_entities_with_coarse:", all_entities)
788
+ # print("After attach_comp_to_closest:")
789
+ # pprint(all_entities)
790
+ # print("\n")
791
+ return all_entities
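
Usage note (not part of this commit): a minimal sketch of how the MultitaskTokenClassificationPipeline defined in generic_ner.py might be invoked once the repository wires it up as a custom pipeline. The hub id, the task name "generic-ner", and the trust_remote_code setting are assumptions about the repository configuration, not facts established by this diff.

import torch
from transformers import pipeline

# Assumption: the model is published under this hub id and its config registers
# generic_ner.MultitaskTokenClassificationPipeline under a task named "generic-ner".
MODEL_ID = "emanuelaboros/ner-stacked-bert-multilingual"  # hypothetical hub id

ner = pipeline(
    "generic-ner",            # assumed custom task name
    model=MODEL_ID,
    trust_remote_code=True,   # needed so generic_ner.py from the repo is downloaded and executed
    device=0 if torch.cuda.is_available() else -1,
)

text = "Albert Einstein taught at the University of Bern in Switzerland."
for entity in ner(text):
    # postprocess() returns dicts with type, surface, lOffset, rOffset and confidence_ner
    print(entity["type"], entity["surface"], entity["lOffset"], entity["rOffset"], entity["confidence_ner"])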