Commit 137cb44
Parent(s): 8e0fd47

    lets try to change the pipeline

Files changed:
- lang_detect.py  +2 -682
- push_to_hf.py   +0 -145
lang_detect.py  CHANGED

@@ -13,680 +13,6 @@ from nltk.tree import Tree
 import torch.nn.functional as F
 import re, string
 
-stop_words = set(nltk.corpus.stopwords.words("english"))
-DEBUG = False
-punctuation = (
-    string.punctuation
-    + "«»—…“”"
-    + "—."
-    + "–"
-    + "’"
-    + "‘"
-    + "´"
-    + "•"
-    + "°"
-    + "»"
-    + "“"
-    + "”"
-    + "–"
-    + "—"
-    + "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
-)
-
-# List of additional "strange" punctuation marks
-# additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
-
-
-WHITESPACE_RULES = {
-    "fr": {
-        "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"],
-        "pct_no_ws_after": ["(", "[", "{"],
-        "pct_no_ws_before_after": ["'", "-"],
-        "pct_number": [".", ","],
-    },
-    "de": {
-        "pct_no_ws_before": [
-            ".",
-            ",",
-            ")",
-            "]",
-            "}",
-            "°",
-            "...",
-            "?",
-            "!",
-            ":",
-            ";",
-            ".-",
-            "%",
-        ],
-        "pct_no_ws_after": ["(", "[", "{"],
-        "pct_no_ws_before_after": ["'", "-"],
-        "pct_number": [".", ","],
-    },
-    "other": {
-        "pct_no_ws_before": [
-            ".",
-            ",",
-            ")",
-            "]",
-            "}",
-            "°",
-            "...",
-            "?",
-            "!",
-            ":",
-            ";",
-            ".-",
-            "%",
-        ],
-        "pct_no_ws_after": ["(", "[", "{"],
-        "pct_no_ws_before_after": ["'", "-"],
-        "pct_number": [".", ","],
-    },
-}
-
-
-def tokenize(text: str, language: str = "other") -> list[str]:
-    """Apply whitespace rules to the given text and language, separating it into tokens.
-
-    Args:
-        text (str): The input text to separate into a list of tokens.
-        language (str): Language of the text.
-
-    Returns:
-        list[str]: List of tokens with punctuation as separate tokens.
-    """
-    # text = add_spaces_around_punctuation(text)
-    if not text:
-        return []
-
-    if language not in WHITESPACE_RULES:
-        # Default behavior for languages without specific rules:
-        # tokenize using standard whitespace splitting
-        language = "other"
-
-    wsrules = WHITESPACE_RULES[language]
-    tokenized_text = []
-    current_token = ""
-
-    for char in text:
-        if char in wsrules["pct_no_ws_before_after"]:
-            if current_token:
-                tokenized_text.append(current_token)
-            tokenized_text.append(char)
-            current_token = ""
-        elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]:
-            if current_token:
-                tokenized_text.append(current_token)
-            tokenized_text.append(char)
-            current_token = ""
-        elif char.isspace():
-            if current_token:
-                tokenized_text.append(current_token)
-            current_token = ""
-        else:
-            current_token += char
-
-    if current_token:
-        tokenized_text.append(current_token)
-
-    return tokenized_text
-
-
-def normalize_text(text):
-    # Remove spaces and tabs for the search but keep newline characters
-    return re.sub(r"[ \t]+", "", text)
-
-
-def find_entity_indices(article_text, search_text):
-    # Normalize texts by removing spaces and tabs
-    normalized_article = normalize_text(article_text)
-    normalized_search = normalize_text(search_text)
-
-    # Initialize a list to hold all start and end indices
-    indices = []
-
-    # Find all occurrences of the search text in the normalized article text
-    start_index = 0
-    while True:
-        start_index = normalized_article.find(normalized_search, start_index)
-        if start_index == -1:
-            break
-
-        # Calculate the actual start and end indices in the original article text
-        original_chars = 0
-        original_start_index = 0
-        for i in range(start_index):
-            while article_text[original_start_index] in (" ", "\t"):
-                original_start_index += 1
-            if article_text[original_start_index] not in (" ", "\t", "\n"):
-                original_chars += 1
-            original_start_index += 1
-
-        original_end_index = original_start_index
-        search_chars = 0
-        while search_chars < len(normalized_search):
-            if article_text[original_end_index] not in (" ", "\t", "\n"):
-                search_chars += 1
-            original_end_index += 1  # Increment to include the last character
-
-        # Append the found indices to the list
-        if article_text[original_start_index] == " ":
-            original_start_index += 1
-        indices.append((original_start_index, original_end_index))
-
-        # Move start_index to the next position to continue searching
-        start_index += 1
-
-    return indices
-
-
-def get_entities(tokens, tags, confidences, text):
-
-    tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
-    pos_tags = [pos for token, pos in pos_tag(tokens)]
-
-    for i in range(1, len(tags)):
-        # If a 'B-' tag is followed by another 'B-' without an 'O' in between, change the second to 'I-'
-        if tags[i].startswith("B-") and tags[i - 1].startswith("I-"):
-            tags[i] = "I-" + tags[i][2:]  # Change 'B-' to 'I-' for the same entity type
-
-    conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
-    ne_tree = conlltags2tree(conlltags)
-
-    entities = []
-    idx: int = 0
-    already_done = []
-    for subtree in ne_tree:
-        # skipping 'O' tags
-        if isinstance(subtree, Tree):
-            original_label = subtree.label()
-            original_string = " ".join([token for token, pos in subtree.leaves()])
-
-            for indices in find_entity_indices(text, original_string):
-                entity_start_position = indices[0]
-                entity_end_position = indices[1]
-                if (
-                    "_".join(
-                        [original_label, original_string, str(entity_start_position)]
-                    )
-                    in already_done
-                ):
-                    continue
-                else:
-                    already_done.append(
-                        "_".join(
-                            [
-                                original_label,
-                                original_string,
-                                str(entity_start_position),
-                            ]
-                        )
-                    )
-                if len(text[entity_start_position:entity_end_position].strip()) < len(
-                    text[entity_start_position:entity_end_position]
-                ):
-                    entity_start_position = (
-                        entity_start_position
-                        + len(text[entity_start_position:entity_end_position])
-                        - len(text[entity_start_position:entity_end_position].strip())
-                    )
-
-                entities.append(
-                    {
-                        "type": original_label,
-                        "confidence_ner": round(
-                            np.average(confidences[idx : idx + len(subtree)]) * 100, 2
-                        ),
-                        "index": (idx, idx + len(subtree)),
-                        "surface": text[
-                            entity_start_position:entity_end_position
-                        ],  # original_string,
-                        "lOffset": entity_start_position,
-                        "rOffset": entity_end_position,
-                    }
-                )
-
-            idx += len(subtree)
-
-            # Update the current character position
-            # We add the length of the original string + 1 (for the space)
-        else:
-            token, pos = subtree
-            # If it's not a named entity, we still need to update the character
-            # position
-            idx += 1
-
-    return entities
-
-
-def realign(
-    text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
-):
-    preds_list, words_list, confidence_list = [], [], []
-    word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
-    for idx, word in enumerate(text_sentence):
-        beginning_index = word_ids.index(idx)
-        try:
-            preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
-            confidence_list.append(max(softmax_scores[beginning_index]))
-        except Exception as ex:  # the sentence was longer then max_length
-            preds_list.append("O")
-            confidence_list.append(0.0)
-        words_list.append(word)
-
-    return words_list, preds_list, confidence_list
-
-
-def add_spaces_around_punctuation(text):
-    # Add a space before and after all punctuation
-    all_punctuation = string.punctuation + punctuation
-    return re.sub(r"([{}])".format(re.escape(all_punctuation)), r" \1 ", text)
-
-
-def attach_comp_to_closest(entities):
-    # Define valid entity types that can receive a "comp.function" or "comp.name" attachment
-    valid_entity_types = {"org", "pers", "org.ent", "pers.ind"}
-
-    # Separate "comp.function" and "comp.name" entities from other entities
-    comp_entities = [ent for ent in entities if ent["type"].startswith("comp")]
-    other_entities = [ent for ent in entities if not ent["type"].startswith("comp")]
-
-    for comp_entity in comp_entities:
-        closest_entity = None
-        min_distance = float("inf")
-
-        # Find the closest non-"comp" entity that is valid for attaching
-        for other_entity in other_entities:
-            # Calculate distance between the comp entity and the other entity
-            if comp_entity["lOffset"] > other_entity["rOffset"]:
-                distance = comp_entity["lOffset"] - other_entity["rOffset"]
-            elif comp_entity["rOffset"] < other_entity["lOffset"]:
-                distance = other_entity["lOffset"] - comp_entity["rOffset"]
-            else:
-                distance = 0  # They overlap or touch
-
-            # Ensure the entity type is valid and check for minimal distance
-            if (
-                distance < min_distance
-                and other_entity["type"].split(".")[0] in valid_entity_types
-            ):
-                min_distance = distance
-                closest_entity = other_entity
-
-        # Attach the "comp.function" or "comp.name" if a valid entity is found
-        if closest_entity:
-            suffix = comp_entity["type"].split(".")[
-                -1
-            ]  # Extract the suffix (e.g., 'name', 'function')
-            closest_entity[suffix] = comp_entity["surface"]  # Attach the text
-
-    return other_entities
-
-
-def conflicting_context(comp_entity, target_entity):
-    """
-    Determines if there is a conflict between the comp_entity and the target entity.
-    Prevents incorrect name and function attachments by using a rule-based approach.
-    """
-    # Case 1: Check for correct function attachment to person or organization entities
-    if comp_entity["type"].startswith("comp.function"):
-        if not ("pers" in target_entity["type"] or "org" in target_entity["type"]):
-            return True  # Conflict: Function should only attach to persons or organizations
-
-    # Case 2: Avoid attaching comp.* entities to non-person, non-organization types (like locations)
-    if "loc" in target_entity["type"]:
-        return True  # Conflict: comp.* entities should not attach to locations or similar types
-
-    return False  # No conflict
-
-
-def extract_name_from_text(text, partial_name):
-    """
-    Extracts the full name from the entity's text based on the partial name.
-    This function assumes that the full name starts with capitalized letters and does not
-    include any words that come after the partial name.
-    """
-    # Split the text and partial name into words
-    words = tokenize(text)
-    partial_words = partial_name.split()
-
-    if DEBUG:
-        print("text:", text)
-    if DEBUG:
-        print("partial_name:", partial_name)
-
-    # Find the position of the partial name in the word list
-    for i, word in enumerate(words):
-        if DEBUG:
-            print(words, "---", words[i : i + len(partial_words)])
-        if words[i : i + len(partial_words)] == partial_words:
-            # Initialize full name with the partial name
-            full_name = partial_words[:]
-
-            if DEBUG:
-                print("full_name:", full_name)
-
-            # Check previous words and only add capitalized words (skip lowercase words)
-            j = i - 1
-            while j >= 0 and words[j][0].isupper():
-                full_name.insert(0, words[j])
-                j -= 1
-            if DEBUG:
-                print("full_name:", full_name)
-
-            # Return only the full name up to the partial name (ignore words after the name)
-            return " ".join(full_name).strip()  # Join the words to form the full name
-
-    # If not found, return the original text (as a fallback)
-    return text.strip()
-
-
-def repair_names_in_entities(entities):
-    """
-    This function repairs the names in the entities by extracting the full name
-    from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached.
-    """
-    for entity in entities:
-        if "name" in entity and "pers" in entity["type"]:
-            name = entity["name"]
-            text = entity["surface"]
-
-            # Check if the attached name is part of the entity's text
-            if name in text:
-                # Extract the full name from the text by splitting around the attached name
-                full_name = extract_name_from_text(entity["surface"], name)
-                entity["name"] = (
-                    full_name  # Replace the partial name with the full name
-                )
-            # if "name" not in entity:
-            #     entity["name"] = entity["surface"]
-
-    return entities
-
-
-def clean_coarse_entities(entities):
-    """
-    This function removes entities that are not useful for the NEL process.
-    """
-    # Define a set of entity types that are considered useful for NEL
-    useful_types = {
-        "pers",  # Person
-        "loc",  # Location
-        "org",  # Organization
-        "date",  # Product
-        "time",  # Time
-    }
-
-    # Filter out entities that are not in the useful_types set unless they are comp.* entities
-    cleaned_entities = [
-        entity
-        for entity in entities
-        if entity["type"] in useful_types or "comp" in entity["type"]
-    ]
-
-    return cleaned_entities
-
-
-def postprocess_entities(entities):
-    # Step 1: Filter entities with the same text, keeping the one with the most dots in the 'entity' field
-    entity_map = {}
-
-    # Loop over the entities and prioritize the one with the most dots
-    for entity in entities:
-        entity_text = entity["surface"]
-        num_dots = entity["type"].count(".")
-
-        # If the entity text is new, or this entity has more dots, update the map
-        if (
-            entity_text not in entity_map
-            or entity_map[entity_text]["type"].count(".") < num_dots
-        ):
-            entity_map[entity_text] = entity
-
-    # Collect the filtered entities from the map
-    filtered_entities = list(entity_map.values())
-
-    # Step 2: Attach "comp.function" entities to the closest other entities
-    filtered_entities = attach_comp_to_closest(filtered_entities)
-    if DEBUG:
-        print("After attach_comp_to_closest:", filtered_entities, "\n")
-    filtered_entities = repair_names_in_entities(filtered_entities)
-    if DEBUG:
-        print("After repair_names_in_entities:", filtered_entities, "\n")
-
-    # Step 3: Remove entities that are not useful for NEL
-    # filtered_entities = clean_coarse_entities(filtered_entities)
-
-    # filtered_entities = remove_blacklisted_entities(filtered_entities)
-
-    return filtered_entities
-
-
-def remove_included_entities(entities):
-    # Loop through entities and remove those whose text is included in another with the same label
-    final_entities = []
-    for i, entity in enumerate(entities):
-        is_included = False
-        for other_entity in entities:
-            if entity["surface"] != other_entity["surface"]:
-                if "comp" in other_entity["type"]:
-                    # Check if entity's text is a substring of another entity's text
-                    if entity["surface"] in other_entity["surface"]:
-                        is_included = True
-                        break
-                elif (
-                    entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
-                    or other_entity["type"].split(".")[0]
-                    in entity["type"].split(".")[0]
-                ):
-                    if entity["surface"] in other_entity["surface"]:
-                        is_included = True
-        if not is_included:
-            final_entities.append(entity)
-    return final_entities
-
-
-def refine_entities_with_coarse(all_entities, coarse_entities):
-    """
-    Looks through all entities and refines them based on the coarse entities.
-    If a surface match is found in the coarse entities and the types match,
-    the entity's confidence_ner and type are updated based on the coarse entity.
-    """
-    # Create a dictionary for coarse entities based on surface and type for quick lookup
-    coarse_lookup = {}
-    for coarse_entity in coarse_entities:
-        key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0])
-        coarse_lookup[key] = coarse_entity
-
-    # Iterate through all entities and compare with the coarse entities
-    for entity in all_entities:
-        key = (
-            entity["surface"],
-            entity["type"].split(".")[0],
-        )  # Use the coarse type for comparison
-
-        if key in coarse_lookup:
-            coarse_entity = coarse_lookup[key]
-            # If a match is found, update the confidence_ner and type in the entity
-            if entity["confidence_ner"] < coarse_entity["confidence_ner"]:
-                entity["confidence_ner"] = coarse_entity["confidence_ner"]
-                entity["type"] = coarse_entity[
-                    "type"
-                ]  # Update the type if the confidence is higher
-
-    # No need to append to refined_entities, we're modifying in place
-    for entity in all_entities:
-        entity["type"] = entity["type"].split(".")[0]
-    return all_entities
-
-
-def remove_trailing_stopwords(entities):
-    """
-    This function removes stopwords and punctuation from both the beginning and end of each entity's text
-    and repairs the lOffset and rOffset accordingly.
-    """
-    if DEBUG:
-        print(f"Initial entities: {len(entities)}")
-    new_entities = []
-    for entity in entities:
-        if "comp" not in entity["type"]:
-            entity_text = entity["surface"]
-            original_len = len(entity_text)
-
-            # Initial offsets
-            lOffset = entity.get("lOffset", 0)
-            rOffset = entity.get("rOffset", original_len)
-
-            # Remove stopwords and punctuation from the beginning
-            i = 0
-            while entity_text and (
-                entity_text.split()[0].lower() in stop_words
-                or entity_text[0] in punctuation
-            ):
-                if entity_text.split()[0].lower() in stop_words:
-                    stopword_len = (
-                        len(entity_text.split()[0]) + 1
-                    )  # Adjust length for stopword and following space
-                    entity_text = entity_text[stopword_len:]  # Remove leading stopword
-                    lOffset += stopword_len  # Adjust the left offset
-                    if DEBUG:
-                        print(
-                            f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
-                        )
-                elif entity_text[0] in punctuation:
-                    entity_text = entity_text[1:]  # Remove leading punctuation
-                    lOffset += 1  # Adjust the left offset
-                    if DEBUG:
-                        print(
-                            f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
-                        )
-                i += 1
-
-            i = 0
-            # Remove stopwords and punctuation from the end
-            iteration = 0
-            max_iterations = len(entity_text)  # Prevent infinite loops
-
-            while entity_text and iteration < max_iterations:
-                # Check if the last word is a stopword or the last character is punctuation
-                last_word = entity_text.split()[-1] if entity_text.split() else ""
-                last_char = entity_text[-1]
-
-                if last_word.lower() in stop_words:
-                    # Remove trailing stopword and adjust rOffset
-                    stopword_len = len(last_word) + 1  # Include space before stopword
-                    entity_text = entity_text[:-stopword_len].rstrip()
-                    rOffset -= stopword_len
-                    if DEBUG:
-                        print(
-                            f"Removed trailing stopword from entity: {entity_text} (rOffset={rOffset})"
-                        )
-
-                elif last_char in punctuation:
-                    # Remove trailing punctuation and adjust rOffset
-                    entity_text = entity_text[:-1].rstrip()
-                    rOffset -= 1
-                    if DEBUG:
-                        print(
-                            f"Removed trailing punctuation from entity: {entity_text} (rOffset={rOffset})"
-                        )
-                else:
-                    # Exit loop if neither stopwords nor punctuation are found
-                    break
-
-                iteration += 1
-                # print(f"ITERATION: {iteration} [{entity['surface']}] for {entity_text}")
-
-            if len(entity_text.strip()) == 1:
-                entities.remove(entity)
-                if DEBUG:
-                    print(f"Skipping entity: {entity_text}")
-                continue
-            # Skip certain entities based on rules
-            if entity_text in string.punctuation:
-                if DEBUG:
-                    print(f"Skipping entity: {entity_text}")
-                entities.remove(entity)
-                continue
-            # check now if its in stopwords
-            if entity_text.lower() in stop_words:
-                if DEBUG:
-                    print(f"Skipping entity: {entity_text}")
-                entities.remove(entity)
-                continue
-            # check now if the entire entity is a list of stopwords:
-            if all([word.lower() in stop_words for word in entity_text.split()]):
-                if DEBUG:
-                    print(f"Skipping entity: {entity_text}")
-                entities.remove(entity)
-                continue
-            # Check if the entire entity is made up of stopwords characters
-            if all(
-                [char.lower() in stop_words for char in entity_text if char.isalpha()]
-            ):
-                if DEBUG:
-                    print(
-                        f"Skipping entity: {entity_text} (all characters are stopwords)"
-                    )
-                entities.remove(entity)
-                continue
-            # check now if all entity is in a list of punctuation
-            if all([word in string.punctuation for word in entity_text.split()]):
-                if DEBUG:
-                    print(
-                        f"Skipping entity: {entity_text} (all characters are punctuation)"
-                    )
-                entities.remove(entity)
-                continue
-            if all(
-                [
-                    char.lower() in string.punctuation
-                    for char in entity_text
-                    if char.isalpha()
-                ]
-            ):
-                if DEBUG:
-                    print(
-                        f"Skipping entity: {entity_text} (all characters are punctuation)"
-                    )
-                entities.remove(entity)
-                continue
-
-            # if it's a number and "time" no in it, then continue
-            if entity_text.isdigit() and "time" not in entity["type"]:
-                if DEBUG:
-                    print(f"Skipping entity: {entity_text}")
-                entities.remove(entity)
-                continue
-
-            if entity_text.startswith(" "):
-                entity_text = entity_text[1:]
-                # update lOffset, rOffset
-                lOffset += 1
-            if entity_text.endswith(" "):
-                entity_text = entity_text[:-1]
-                # update lOffset, rOffset
-                rOffset -= 1
-
-            # Update the entity surface and offsets
-            entity["surface"] = entity_text
-            entity["lOffset"] = lOffset
-            entity["rOffset"] = rOffset
-
-            # Remove the entity if the surface is empty after cleaning
-            if len(entity["surface"].strip()) == 0:
-                if DEBUG:
-                    print(f"Deleted entity: {entity['surface']}")
-                entities.remove(entity)
-            else:
-                new_entities.append(entity)
-
-    if DEBUG:
-        print(f"Remained entities: {len(new_entities)}")
-    return new_entities
-
 
 class MultitaskTokenClassificationPipeline(Pipeline):
 
@@ -703,15 +29,9 @@ class MultitaskTokenClassificationPipeline(Pipeline):
 
     def preprocess(self, text, **kwargs):
 
-        tokenized_inputs = self.tokenizer(
-            text, padding="max_length", truncation=True, max_length=512
-        )
-
-        text_sentence = tokenize(add_spaces_around_punctuation(text))
-        return tokenized_inputs, text_sentence, text
+        return text
 
-    def _forward(self, inputs):
-        inputs, text_sentences, text = inputs
+    def _forward(self, text):
 
         return text
 
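With this change, preprocess and _forward in MultitaskTokenClassificationPipeline become pass-throughs and all of the tokenization and entity post-processing helpers above are dropped. For orientation, a minimal sketch of the shape such a stripped-down transformers Pipeline subclass takes is shown below; the class name and the identity postprocess step are illustrative assumptions, not code from this commit.

# Minimal sketch of a pass-through transformers Pipeline subclass.
# Only the transformers Pipeline API itself is assumed; the class name is illustrative.
from transformers import Pipeline


class PassthroughTokenClassificationPipeline(Pipeline):
    """Every stage simply forwards its input, mirroring the trimmed-down methods."""

    def _sanitize_parameters(self, **kwargs):
        # No extra keyword arguments are routed to the three stages.
        return {}, {}, {}

    def preprocess(self, text, **kwargs):
        # Reduced to an identity step, as in this commit.
        return text

    def _forward(self, text):
        # Model inference would normally run here; kept as a pass-through.
        return text

    def postprocess(self, outputs, **kwargs):
        # Assumed identity post-processing (not shown in the diff).
        return outputs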
push_to_hf.py  DELETED

@@ -1,145 +0,0 @@
-import os
-import shutil
-import argparse
-from transformers import (
-    AutoTokenizer,
-    AutoConfig,
-    AutoModelForTokenClassification,
-    BertConfig,
-)
-from huggingface_hub import HfApi, Repository
-
-# import json
-from .configuration_stacked import ImpressoConfig
-from .modeling_stacked import ExtendedMultitaskModelForTokenClassification
-import subprocess
-
-
-def get_latest_checkpoint(checkpoint_dir):
-    checkpoints = [
-        d
-        for d in os.listdir(checkpoint_dir)
-        if os.path.isdir(os.path.join(checkpoint_dir, d))
-        and d.startswith("checkpoint-")
-    ]
-    checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]), reverse=True)
-    return os.path.join(checkpoint_dir, checkpoints[0])
-
-
-def get_info(label_map):
-    num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
-    return num_token_labels_dict
-
-
-def push_model_to_hub(checkpoint_dir, repo_name, script_path):
-    checkpoint_path = get_latest_checkpoint(checkpoint_dir)
-    config = ImpressoConfig.from_pretrained(checkpoint_path)
-    config.pretrained_config = AutoConfig.from_pretrained(config.name_or_path)
-    config.save_pretrained("stacked_bert")
-    config = ImpressoConfig.from_pretrained("stacked_bert")
-
-    model = ExtendedMultitaskModelForTokenClassification.from_pretrained(
-        checkpoint_path, config=config
-    )
-    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
-    local_repo_path = "./repo"
-    repo_url = HfApi().create_repo(repo_id=repo_name, exist_ok=True)
-    repo = Repository(local_dir=local_repo_path, clone_from=repo_url)
-
-    try:
-        # Try to pull the latest changes from the remote repository using subprocess
-        subprocess.run(["git", "pull"], check=True, cwd=local_repo_path)
-    except subprocess.CalledProcessError as e:
-        # If fast-forward is not possible, reset the local branch to match the remote branch
-        subprocess.run(
-            ["git", "reset", "--hard", "origin/main"],
-            check=True,
-            cwd=local_repo_path,
-        )
-
-    # Copy all Python files to the local repository directory
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    for filename in os.listdir(current_dir):
-        if filename.endswith(".py") or filename.endswith(".json"):
-            shutil.copy(
-                os.path.join(current_dir, filename),
-                os.path.join(local_repo_path, filename),
-            )
-
-    ImpressoConfig.register_for_auto_class()
-    AutoConfig.register("stacked_bert", ImpressoConfig)
-    AutoModelForTokenClassification.register(
-        ImpressoConfig, ExtendedMultitaskModelForTokenClassification
-    )
-    ExtendedMultitaskModelForTokenClassification.register_for_auto_class(
-        "AutoModelForTokenClassification"
-    )
-
-    model.save_pretrained(local_repo_path)
-    tokenizer.save_pretrained(local_repo_path)
-
-    # Add, commit and push the changes to the repository
-    subprocess.run(["git", "add", "."], check=True, cwd=local_repo_path)
-    subprocess.run(
-        ["git", "commit", "-m", "Initial commit including model and configuration"],
-        check=True,
-        cwd=local_repo_path,
-    )
-    subprocess.run(["git", "push"], check=True, cwd=local_repo_path)
-
-    # Push the model to the hub (this includes the README template)
-    model.push_to_hub(repo_name)
-    tokenizer.push_to_hub(repo_name)
-
-    print(f"Model and repo pushed to: {repo_url}")
-
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Push NER model to Hugging Face Hub")
-    parser.add_argument(
-        "--model_type",
-        type=str,
-        required=True,
-        help="Type of the model (e.g., stacked-bert)",
-    )
-    parser.add_argument(
-        "--language",
-        type=str,
-        required=True,
-        help="Language of the model (e.g., multilingual)",
-    )
-    parser.add_argument(
-        "--checkpoint_dir",
-        type=str,
-        required=True,
-        help="Directory containing checkpoint folders",
-    )
-    parser.add_argument(
-        "--script_path", type=str, required=True, help="Path to the models.py script"
-    )
-    args = parser.parse_args()
-    repo_name = f"impresso-project/ner-{args.model_type}-{args.language}"
-    push_model_to_hub(args.checkpoint_dir, repo_name, args.script_path)
-    # PIPELINE_REGISTRY.register_pipeline(
-    #     "generic-ner",
-    #     pipeline_class=MultitaskTokenClassificationPipeline,
-    #     pt_model=ExtendedMultitaskModelForTokenClassification,
-    # )
-    # model.config.custom_pipelines = {
-    #     "generic-ner": {
-    #         "impl": "generic_ner.MultitaskTokenClassificationPipeline",
-    #         "pt": ["ExtendedMultitaskModelForTokenClassification"],
-    #         "tf": [],
-    #     }
-    # }
-    # classifier = pipeline(
-    #     "generic-ner", model=model, tokenizer=tokenizer, label_map=label_map
-    # )
-    # from pprint import pprint
-    #
-    # pprint(
-    #     classifier(
-    #         "1. Le public est averti que Charlotte née Bourgoin, femme-de Joseph Digiez, et Maurice Bourgoin, enfant mineur représenté par le sieur Jaques Charles Gicot son curateur, ont été admis par arrêt du Conseil d'Etat du 5 décembre 1797, à solliciter une renonciation générale et absolue aux biens et aux dettes présentes et futures de Jean-Baptiste Bourgoin leur père."
-    #     )
-    # )
-    # repo.push_to_hub(commit_message="Initial commit of the trained NER model with code")
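The deleted push_to_hf.py handled checkpoint selection, auto-class registration, and the upload to the Hub. If an upload step is still needed after this commit, the core flow it implemented can be condensed to the standard transformers / huggingface_hub calls sketched below; the checkpoint path is a placeholder, the repo id is only an example of the impresso-project/ner-{model_type}-{language} pattern from the deleted script, and the stock Auto classes stand in here for the project's custom ExtendedMultitaskModelForTokenClassification.

# Condensed sketch (not part of this commit) of the upload flow the deleted script covered.
# Placeholders: the checkpoint path and repo id below are examples only.
from huggingface_hub import HfApi
from transformers import AutoModelForTokenClassification, AutoTokenizer

checkpoint_path = "checkpoints/checkpoint-1000"  # placeholder checkpoint directory
repo_name = "impresso-project/ner-stacked-bert-multilingual"  # example repo id

# Load model and tokenizer from the training checkpoint.
model = AutoModelForTokenClassification.from_pretrained(checkpoint_path)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)

# Create the target repo if it does not exist, then push weights and tokenizer files.
HfApi().create_repo(repo_id=repo_name, exist_ok=True)
model.push_to_hub(repo_name)
tokenizer.push_to_hub(repo_name)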