emanuelaboros committed
Commit 137cb44 · 1 Parent(s): 8e0fd47

lets try to change the pipeline

Files changed (2)
  1. lang_detect.py +2 -682
  2. push_to_hf.py +0 -145
lang_detect.py CHANGED
@@ -13,680 +13,6 @@ from nltk.tree import Tree
13
  import torch.nn.functional as F
14
  import re, string
15
 
16
- stop_words = set(nltk.corpus.stopwords.words("english"))
17
- DEBUG = False
18
- punctuation = (
19
- string.punctuation
20
- + "«»—…“”"
21
- + "—."
22
- + "–"
23
- + "’"
24
- + "‘"
25
- + "´"
26
- + "•"
27
- + "°"
28
- + "»"
29
- + "“"
30
- + "”"
31
- + "–"
32
- + "—"
33
- + "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
34
- )
35
-
36
- # List of additional "strange" punctuation marks
37
- # additional_punctuation = "‘’“”„«»•–—―‣◦…§¶†‡‰′″〈〉"
38
-
39
-
40
- WHITESPACE_RULES = {
41
- "fr": {
42
- "pct_no_ws_before": [".", ",", ")", "]", "}", "°", "...", ".-", "%"],
43
- "pct_no_ws_after": ["(", "[", "{"],
44
- "pct_no_ws_before_after": ["'", "-"],
45
- "pct_number": [".", ","],
46
- },
47
- "de": {
48
- "pct_no_ws_before": [
49
- ".",
50
- ",",
51
- ")",
52
- "]",
53
- "}",
54
- "°",
55
- "...",
56
- "?",
57
- "!",
58
- ":",
59
- ";",
60
- ".-",
61
- "%",
62
- ],
63
- "pct_no_ws_after": ["(", "[", "{"],
64
- "pct_no_ws_before_after": ["'", "-"],
65
- "pct_number": [".", ","],
66
- },
67
- "other": {
68
- "pct_no_ws_before": [
69
- ".",
70
- ",",
71
- ")",
72
- "]",
73
- "}",
74
- "°",
75
- "...",
76
- "?",
77
- "!",
78
- ":",
79
- ";",
80
- ".-",
81
- "%",
82
- ],
83
- "pct_no_ws_after": ["(", "[", "{"],
84
- "pct_no_ws_before_after": ["'", "-"],
85
- "pct_number": [".", ","],
86
- },
87
- }
88
-
89
-
90
- def tokenize(text: str, language: str = "other") -> list[str]:
91
- """Apply whitespace rules to the given text and language, separating it into tokens.
92
-
93
- Args:
94
- text (str): The input text to separate into a list of tokens.
95
- language (str): Language of the text.
96
-
97
- Returns:
98
- list[str]: List of tokens with punctuation as separate tokens.
99
- """
100
- # text = add_spaces_around_punctuation(text)
101
- if not text:
102
- return []
103
-
104
- if language not in WHITESPACE_RULES:
105
- # Default behavior for languages without specific rules:
106
- # tokenize using standard whitespace splitting
107
- language = "other"
108
-
109
- wsrules = WHITESPACE_RULES[language]
110
- tokenized_text = []
111
- current_token = ""
112
-
113
- for char in text:
114
- if char in wsrules["pct_no_ws_before_after"]:
115
- if current_token:
116
- tokenized_text.append(current_token)
117
- tokenized_text.append(char)
118
- current_token = ""
119
- elif char in wsrules["pct_no_ws_before"] or char in wsrules["pct_no_ws_after"]:
120
- if current_token:
121
- tokenized_text.append(current_token)
122
- tokenized_text.append(char)
123
- current_token = ""
124
- elif char.isspace():
125
- if current_token:
126
- tokenized_text.append(current_token)
127
- current_token = ""
128
- else:
129
- current_token += char
130
-
131
- if current_token:
132
- tokenized_text.append(current_token)
133
-
134
- return tokenized_text
135
-
136
-
137
- def normalize_text(text):
138
- # Remove spaces and tabs for the search but keep newline characters
139
- return re.sub(r"[ \t]+", "", text)
140
-
141
-
142
- def find_entity_indices(article_text, search_text):
143
- # Normalize texts by removing spaces and tabs
144
- normalized_article = normalize_text(article_text)
145
- normalized_search = normalize_text(search_text)
146
-
147
- # Initialize a list to hold all start and end indices
148
- indices = []
149
-
150
- # Find all occurrences of the search text in the normalized article text
151
- start_index = 0
152
- while True:
153
- start_index = normalized_article.find(normalized_search, start_index)
154
- if start_index == -1:
155
- break
156
-
157
- # Calculate the actual start and end indices in the original article text
158
- original_chars = 0
159
- original_start_index = 0
160
- for i in range(start_index):
161
- while article_text[original_start_index] in (" ", "\t"):
162
- original_start_index += 1
163
- if article_text[original_start_index] not in (" ", "\t", "\n"):
164
- original_chars += 1
165
- original_start_index += 1
166
-
167
- original_end_index = original_start_index
168
- search_chars = 0
169
- while search_chars < len(normalized_search):
170
- if article_text[original_end_index] not in (" ", "\t", "\n"):
171
- search_chars += 1
172
- original_end_index += 1 # Increment to include the last character
173
-
174
- # Append the found indices to the list
175
- if article_text[original_start_index] == " ":
176
- original_start_index += 1
177
- indices.append((original_start_index, original_end_index))
178
-
179
- # Move start_index to the next position to continue searching
180
- start_index += 1
181
-
182
- return indices
183
-
184
-
185
- def get_entities(tokens, tags, confidences, text):
186
-
187
- tags = [tag.replace("S-", "B-").replace("E-", "I-") for tag in tags]
188
- pos_tags = [pos for token, pos in pos_tag(tokens)]
189
-
190
- for i in range(1, len(tags)):
191
- # If a 'B-' tag is followed by another 'B-' without an 'O' in between, change the second to 'I-'
192
- if tags[i].startswith("B-") and tags[i - 1].startswith("I-"):
193
- tags[i] = "I-" + tags[i][2:] # Change 'B-' to 'I-' for the same entity type
194
-
195
- conlltags = [(token, pos, tg) for token, pos, tg in zip(tokens, pos_tags, tags)]
196
- ne_tree = conlltags2tree(conlltags)
197
-
198
- entities = []
199
- idx: int = 0
200
- already_done = []
201
- for subtree in ne_tree:
202
- # skipping 'O' tags
203
- if isinstance(subtree, Tree):
204
- original_label = subtree.label()
205
- original_string = " ".join([token for token, pos in subtree.leaves()])
206
-
207
- for indices in find_entity_indices(text, original_string):
208
- entity_start_position = indices[0]
209
- entity_end_position = indices[1]
210
- if (
211
- "_".join(
212
- [original_label, original_string, str(entity_start_position)]
213
- )
214
- in already_done
215
- ):
216
- continue
217
- else:
218
- already_done.append(
219
- "_".join(
220
- [
221
- original_label,
222
- original_string,
223
- str(entity_start_position),
224
- ]
225
- )
226
- )
227
- if len(text[entity_start_position:entity_end_position].strip()) < len(
228
- text[entity_start_position:entity_end_position]
229
- ):
230
- entity_start_position = (
231
- entity_start_position
232
- + len(text[entity_start_position:entity_end_position])
233
- - len(text[entity_start_position:entity_end_position].strip())
234
- )
235
-
236
- entities.append(
237
- {
238
- "type": original_label,
239
- "confidence_ner": round(
240
- np.average(confidences[idx : idx + len(subtree)]) * 100, 2
241
- ),
242
- "index": (idx, idx + len(subtree)),
243
- "surface": text[
244
- entity_start_position:entity_end_position
245
- ], # original_string,
246
- "lOffset": entity_start_position,
247
- "rOffset": entity_end_position,
248
- }
249
- )
250
-
251
- idx += len(subtree)
252
-
253
- # Update the current character position
254
- # We add the length of the original string + 1 (for the space)
255
- else:
256
- token, pos = subtree
257
- # If it's not a named entity, we still need to update the character
258
- # position
259
- idx += 1
260
-
261
- return entities
262
-
263
-
264
- def realign(
265
- text_sentence, out_label_preds, softmax_scores, tokenizer, reverted_label_map
266
- ):
267
- preds_list, words_list, confidence_list = [], [], []
268
- word_ids = tokenizer(text_sentence, is_split_into_words=True).word_ids()
269
- for idx, word in enumerate(text_sentence):
270
- beginning_index = word_ids.index(idx)
271
- try:
272
- preds_list.append(reverted_label_map[out_label_preds[beginning_index]])
273
- confidence_list.append(max(softmax_scores[beginning_index]))
274
- except Exception as ex: # the sentence was longer then max_length
275
- preds_list.append("O")
276
- confidence_list.append(0.0)
277
- words_list.append(word)
278
-
279
- return words_list, preds_list, confidence_list
280
-
281
-
282
- def add_spaces_around_punctuation(text):
283
- # Add a space before and after all punctuation
284
- all_punctuation = string.punctuation + punctuation
285
- return re.sub(r"([{}])".format(re.escape(all_punctuation)), r" \1 ", text)
286
-
287
-
288
- def attach_comp_to_closest(entities):
289
- # Define valid entity types that can receive a "comp.function" or "comp.name" attachment
290
- valid_entity_types = {"org", "pers", "org.ent", "pers.ind"}
291
-
292
- # Separate "comp.function" and "comp.name" entities from other entities
293
- comp_entities = [ent for ent in entities if ent["type"].startswith("comp")]
294
- other_entities = [ent for ent in entities if not ent["type"].startswith("comp")]
295
-
296
- for comp_entity in comp_entities:
297
- closest_entity = None
298
- min_distance = float("inf")
299
-
300
- # Find the closest non-"comp" entity that is valid for attaching
301
- for other_entity in other_entities:
302
- # Calculate distance between the comp entity and the other entity
303
- if comp_entity["lOffset"] > other_entity["rOffset"]:
304
- distance = comp_entity["lOffset"] - other_entity["rOffset"]
305
- elif comp_entity["rOffset"] < other_entity["lOffset"]:
306
- distance = other_entity["lOffset"] - comp_entity["rOffset"]
307
- else:
308
- distance = 0 # They overlap or touch
309
-
310
- # Ensure the entity type is valid and check for minimal distance
311
- if (
312
- distance < min_distance
313
- and other_entity["type"].split(".")[0] in valid_entity_types
314
- ):
315
- min_distance = distance
316
- closest_entity = other_entity
317
-
318
- # Attach the "comp.function" or "comp.name" if a valid entity is found
319
- if closest_entity:
320
- suffix = comp_entity["type"].split(".")[
321
- -1
322
- ] # Extract the suffix (e.g., 'name', 'function')
323
- closest_entity[suffix] = comp_entity["surface"] # Attach the text
324
-
325
- return other_entities
326
-
327
-
328
- def conflicting_context(comp_entity, target_entity):
329
- """
330
- Determines if there is a conflict between the comp_entity and the target entity.
331
- Prevents incorrect name and function attachments by using a rule-based approach.
332
- """
333
- # Case 1: Check for correct function attachment to person or organization entities
334
- if comp_entity["type"].startswith("comp.function"):
335
- if not ("pers" in target_entity["type"] or "org" in target_entity["type"]):
336
- return True # Conflict: Function should only attach to persons or organizations
337
-
338
- # Case 2: Avoid attaching comp.* entities to non-person, non-organization types (like locations)
339
- if "loc" in target_entity["type"]:
340
- return True # Conflict: comp.* entities should not attach to locations or similar types
341
-
342
- return False # No conflict
343
-
344
-
345
- def extract_name_from_text(text, partial_name):
346
- """
347
- Extracts the full name from the entity's text based on the partial name.
348
- This function assumes that the full name starts with capitalized letters and does not
349
- include any words that come after the partial name.
350
- """
351
- # Split the text and partial name into words
352
- words = tokenize(text)
353
- partial_words = partial_name.split()
354
-
355
- if DEBUG:
356
- print("text:", text)
357
- if DEBUG:
358
- print("partial_name:", partial_name)
359
-
360
- # Find the position of the partial name in the word list
361
- for i, word in enumerate(words):
362
- if DEBUG:
363
- print(words, "---", words[i : i + len(partial_words)])
364
- if words[i : i + len(partial_words)] == partial_words:
365
- # Initialize full name with the partial name
366
- full_name = partial_words[:]
367
-
368
- if DEBUG:
369
- print("full_name:", full_name)
370
-
371
- # Check previous words and only add capitalized words (skip lowercase words)
372
- j = i - 1
373
- while j >= 0 and words[j][0].isupper():
374
- full_name.insert(0, words[j])
375
- j -= 1
376
- if DEBUG:
377
- print("full_name:", full_name)
378
-
379
- # Return only the full name up to the partial name (ignore words after the name)
380
- return " ".join(full_name).strip() # Join the words to form the full name
381
-
382
- # If not found, return the original text (as a fallback)
383
- return text.strip()
384
-
385
-
386
- def repair_names_in_entities(entities):
387
- """
388
- This function repairs the names in the entities by extracting the full name
389
- from the text of the entity if a partial name (e.g., 'Washington') is incorrectly attached.
390
- """
391
- for entity in entities:
392
- if "name" in entity and "pers" in entity["type"]:
393
- name = entity["name"]
394
- text = entity["surface"]
395
-
396
- # Check if the attached name is part of the entity's text
397
- if name in text:
398
- # Extract the full name from the text by splitting around the attached name
399
- full_name = extract_name_from_text(entity["surface"], name)
400
- entity["name"] = (
401
- full_name # Replace the partial name with the full name
402
- )
403
- # if "name" not in entity:
404
- # entity["name"] = entity["surface"]
405
-
406
- return entities
407
-
408
-
409
- def clean_coarse_entities(entities):
410
- """
411
- This function removes entities that are not useful for the NEL process.
412
- """
413
- # Define a set of entity types that are considered useful for NEL
414
- useful_types = {
415
- "pers", # Person
416
- "loc", # Location
417
- "org", # Organization
418
- "date", # Product
419
- "time", # Time
420
- }
421
-
422
- # Filter out entities that are not in the useful_types set unless they are comp.* entities
423
- cleaned_entities = [
424
- entity
425
- for entity in entities
426
- if entity["type"] in useful_types or "comp" in entity["type"]
427
- ]
428
-
429
- return cleaned_entities
430
-
431
-
432
- def postprocess_entities(entities):
433
- # Step 1: Filter entities with the same text, keeping the one with the most dots in the 'entity' field
434
- entity_map = {}
435
-
436
- # Loop over the entities and prioritize the one with the most dots
437
- for entity in entities:
438
- entity_text = entity["surface"]
439
- num_dots = entity["type"].count(".")
440
-
441
- # If the entity text is new, or this entity has more dots, update the map
442
- if (
443
- entity_text not in entity_map
444
- or entity_map[entity_text]["type"].count(".") < num_dots
445
- ):
446
- entity_map[entity_text] = entity
447
-
448
- # Collect the filtered entities from the map
449
- filtered_entities = list(entity_map.values())
450
-
451
- # Step 2: Attach "comp.function" entities to the closest other entities
452
- filtered_entities = attach_comp_to_closest(filtered_entities)
453
- if DEBUG:
454
- print("After attach_comp_to_closest:", filtered_entities, "\n")
455
- filtered_entities = repair_names_in_entities(filtered_entities)
456
- if DEBUG:
457
- print("After repair_names_in_entities:", filtered_entities, "\n")
458
-
459
- # Step 3: Remove entities that are not useful for NEL
460
- # filtered_entities = clean_coarse_entities(filtered_entities)
461
-
462
- # filtered_entities = remove_blacklisted_entities(filtered_entities)
463
-
464
- return filtered_entities
465
-
466
-
467
- def remove_included_entities(entities):
468
- # Loop through entities and remove those whose text is included in another with the same label
469
- final_entities = []
470
- for i, entity in enumerate(entities):
471
- is_included = False
472
- for other_entity in entities:
473
- if entity["surface"] != other_entity["surface"]:
474
- if "comp" in other_entity["type"]:
475
- # Check if entity's text is a substring of another entity's text
476
- if entity["surface"] in other_entity["surface"]:
477
- is_included = True
478
- break
479
- elif (
480
- entity["type"].split(".")[0] in other_entity["type"].split(".")[0]
481
- or other_entity["type"].split(".")[0]
482
- in entity["type"].split(".")[0]
483
- ):
484
- if entity["surface"] in other_entity["surface"]:
485
- is_included = True
486
- if not is_included:
487
- final_entities.append(entity)
488
- return final_entities
489
-
490
-
491
- def refine_entities_with_coarse(all_entities, coarse_entities):
492
- """
493
- Looks through all entities and refines them based on the coarse entities.
494
- If a surface match is found in the coarse entities and the types match,
495
- the entity's confidence_ner and type are updated based on the coarse entity.
496
- """
497
- # Create a dictionary for coarse entities based on surface and type for quick lookup
498
- coarse_lookup = {}
499
- for coarse_entity in coarse_entities:
500
- key = (coarse_entity["surface"], coarse_entity["type"].split(".")[0])
501
- coarse_lookup[key] = coarse_entity
502
-
503
- # Iterate through all entities and compare with the coarse entities
504
- for entity in all_entities:
505
- key = (
506
- entity["surface"],
507
- entity["type"].split(".")[0],
508
- ) # Use the coarse type for comparison
509
-
510
- if key in coarse_lookup:
511
- coarse_entity = coarse_lookup[key]
512
- # If a match is found, update the confidence_ner and type in the entity
513
- if entity["confidence_ner"] < coarse_entity["confidence_ner"]:
514
- entity["confidence_ner"] = coarse_entity["confidence_ner"]
515
- entity["type"] = coarse_entity[
516
- "type"
517
- ] # Update the type if the confidence is higher
518
-
519
- # No need to append to refined_entities, we're modifying in place
520
- for entity in all_entities:
521
- entity["type"] = entity["type"].split(".")[0]
522
- return all_entities
523
-
524
-
525
- def remove_trailing_stopwords(entities):
526
- """
527
- This function removes stopwords and punctuation from both the beginning and end of each entity's text
528
- and repairs the lOffset and rOffset accordingly.
529
- """
530
- if DEBUG:
531
- print(f"Initial entities: {len(entities)}")
532
- new_entities = []
533
- for entity in entities:
534
- if "comp" not in entity["type"]:
535
- entity_text = entity["surface"]
536
- original_len = len(entity_text)
537
-
538
- # Initial offsets
539
- lOffset = entity.get("lOffset", 0)
540
- rOffset = entity.get("rOffset", original_len)
541
-
542
- # Remove stopwords and punctuation from the beginning
543
- i = 0
544
- while entity_text and (
545
- entity_text.split()[0].lower() in stop_words
546
- or entity_text[0] in punctuation
547
- ):
548
- if entity_text.split()[0].lower() in stop_words:
549
- stopword_len = (
550
- len(entity_text.split()[0]) + 1
551
- ) # Adjust length for stopword and following space
552
- entity_text = entity_text[stopword_len:] # Remove leading stopword
553
- lOffset += stopword_len # Adjust the left offset
554
- if DEBUG:
555
- print(
556
- f"Removed leading stopword from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
557
- )
558
- elif entity_text[0] in punctuation:
559
- entity_text = entity_text[1:] # Remove leading punctuation
560
- lOffset += 1 # Adjust the left offset
561
- if DEBUG:
562
- print(
563
- f"Removed leading punctuation from entity: {entity['surface']} --> {entity_text} ({entity['type']}"
564
- )
565
- i += 1
566
-
567
- i = 0
568
- # Remove stopwords and punctuation from the end
569
- iteration = 0
570
- max_iterations = len(entity_text) # Prevent infinite loops
571
-
572
- while entity_text and iteration < max_iterations:
573
- # Check if the last word is a stopword or the last character is punctuation
574
- last_word = entity_text.split()[-1] if entity_text.split() else ""
575
- last_char = entity_text[-1]
576
-
577
- if last_word.lower() in stop_words:
578
- # Remove trailing stopword and adjust rOffset
579
- stopword_len = len(last_word) + 1 # Include space before stopword
580
- entity_text = entity_text[:-stopword_len].rstrip()
581
- rOffset -= stopword_len
582
- if DEBUG:
583
- print(
584
- f"Removed trailing stopword from entity: {entity_text} (rOffset={rOffset})"
585
- )
586
-
587
- elif last_char in punctuation:
588
- # Remove trailing punctuation and adjust rOffset
589
- entity_text = entity_text[:-1].rstrip()
590
- rOffset -= 1
591
- if DEBUG:
592
- print(
593
- f"Removed trailing punctuation from entity: {entity_text} (rOffset={rOffset})"
594
- )
595
- else:
596
- # Exit loop if neither stopwords nor punctuation are found
597
- break
598
-
599
- iteration += 1
600
- # print(f"ITERATION: {iteration} [{entity['surface']}] for {entity_text}")
601
-
602
- if len(entity_text.strip()) == 1:
603
- entities.remove(entity)
604
- if DEBUG:
605
- print(f"Skipping entity: {entity_text}")
606
- continue
607
- # Skip certain entities based on rules
608
- if entity_text in string.punctuation:
609
- if DEBUG:
610
- print(f"Skipping entity: {entity_text}")
611
- entities.remove(entity)
612
- continue
613
- # check now if its in stopwords
614
- if entity_text.lower() in stop_words:
615
- if DEBUG:
616
- print(f"Skipping entity: {entity_text}")
617
- entities.remove(entity)
618
- continue
619
- # check now if the entire entity is a list of stopwords:
620
- if all([word.lower() in stop_words for word in entity_text.split()]):
621
- if DEBUG:
622
- print(f"Skipping entity: {entity_text}")
623
- entities.remove(entity)
624
- continue
625
- # Check if the entire entity is made up of stopwords characters
626
- if all(
627
- [char.lower() in stop_words for char in entity_text if char.isalpha()]
628
- ):
629
- if DEBUG:
630
- print(
631
- f"Skipping entity: {entity_text} (all characters are stopwords)"
632
- )
633
- entities.remove(entity)
634
- continue
635
- # check now if all entity is in a list of punctuation
636
- if all([word in string.punctuation for word in entity_text.split()]):
637
- if DEBUG:
638
- print(
639
- f"Skipping entity: {entity_text} (all characters are punctuation)"
640
- )
641
- entities.remove(entity)
642
- continue
643
- if all(
644
- [
645
- char.lower() in string.punctuation
646
- for char in entity_text
647
- if char.isalpha()
648
- ]
649
- ):
650
- if DEBUG:
651
- print(
652
- f"Skipping entity: {entity_text} (all characters are punctuation)"
653
- )
654
- entities.remove(entity)
655
- continue
656
-
657
- # if it's a number and "time" no in it, then continue
658
- if entity_text.isdigit() and "time" not in entity["type"]:
659
- if DEBUG:
660
- print(f"Skipping entity: {entity_text}")
661
- entities.remove(entity)
662
- continue
663
-
664
- if entity_text.startswith(" "):
665
- entity_text = entity_text[1:]
666
- # update lOffset, rOffset
667
- lOffset += 1
668
- if entity_text.endswith(" "):
669
- entity_text = entity_text[:-1]
670
- # update lOffset, rOffset
671
- rOffset -= 1
672
-
673
- # Update the entity surface and offsets
674
- entity["surface"] = entity_text
675
- entity["lOffset"] = lOffset
676
- entity["rOffset"] = rOffset
677
-
678
- # Remove the entity if the surface is empty after cleaning
679
- if len(entity["surface"].strip()) == 0:
680
- if DEBUG:
681
- print(f"Deleted entity: {entity['surface']}")
682
- entities.remove(entity)
683
- else:
684
- new_entities.append(entity)
685
-
686
- if DEBUG:
687
- print(f"Remained entities: {len(new_entities)}")
688
- return new_entities
689
-
690
 
691
  class MultitaskTokenClassificationPipeline(Pipeline):
692
 
@@ -703,15 +29,9 @@ class MultitaskTokenClassificationPipeline(Pipeline):

     def preprocess(self, text, **kwargs):

-        tokenized_inputs = self.tokenizer(
-            text, padding="max_length", truncation=True, max_length=512
-        )
-
-        text_sentence = tokenize(add_spaces_around_punctuation(text))
-        return tokenized_inputs, text_sentence, text
+        return text

-    def _forward(self, inputs):
-        inputs, text_sentences, text = inputs
+    def _forward(self, text):

         return text

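For orientation, a minimal sketch (not part of the commit) of what the simplified pipeline steps amount to after this change: preprocess and _forward become pass-throughs that hand the raw text on unchanged. The class name PassThroughSketch and the sample string are illustrative assumptions; only the method bodies mirror the "+" lines of the hunk above.

# Illustrative sketch only; mirrors the new preprocess/_forward behaviour.
class PassThroughSketch:
    def preprocess(self, text, **kwargs):
        # The former tokenizer call and whitespace-rule tokenization are gone.
        return text

    def _forward(self, text):
        # No more unpacking of (tokenized_inputs, text_sentence, text).
        return text


if __name__ == "__main__":
    sketch = PassThroughSketch()
    sample = "Le public est averti que Charlotte née Bourgoin ..."
    # The text passes through both steps untouched.
    assert sketch._forward(sketch.preprocess(sample)) == sample
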
push_to_hf.py DELETED
@@ -1,145 +0,0 @@
1
- import os
2
- import shutil
3
- import argparse
4
- from transformers import (
5
- AutoTokenizer,
6
- AutoConfig,
7
- AutoModelForTokenClassification,
8
- BertConfig,
9
- )
10
- from huggingface_hub import HfApi, Repository
11
-
12
- # import json
13
- from .configuration_stacked import ImpressoConfig
14
- from .modeling_stacked import ExtendedMultitaskModelForTokenClassification
15
- import subprocess
16
-
17
-
18
- def get_latest_checkpoint(checkpoint_dir):
19
- checkpoints = [
20
- d
21
- for d in os.listdir(checkpoint_dir)
22
- if os.path.isdir(os.path.join(checkpoint_dir, d))
23
- and d.startswith("checkpoint-")
24
- ]
25
- checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[-1]), reverse=True)
26
- return os.path.join(checkpoint_dir, checkpoints[0])
27
-
28
-
29
- def get_info(label_map):
30
- num_token_labels_dict = {task: len(labels) for task, labels in label_map.items()}
31
- return num_token_labels_dict
32
-
33
-
34
- def push_model_to_hub(checkpoint_dir, repo_name, script_path):
35
- checkpoint_path = get_latest_checkpoint(checkpoint_dir)
36
- config = ImpressoConfig.from_pretrained(checkpoint_path)
37
- config.pretrained_config = AutoConfig.from_pretrained(config.name_or_path)
38
- config.save_pretrained("stacked_bert")
39
- config = ImpressoConfig.from_pretrained("stacked_bert")
40
-
41
- model = ExtendedMultitaskModelForTokenClassification.from_pretrained(
42
- checkpoint_path, config=config
43
- )
44
- tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
45
- local_repo_path = "./repo"
46
- repo_url = HfApi().create_repo(repo_id=repo_name, exist_ok=True)
47
- repo = Repository(local_dir=local_repo_path, clone_from=repo_url)
48
-
49
- try:
50
- # Try to pull the latest changes from the remote repository using subprocess
51
- subprocess.run(["git", "pull"], check=True, cwd=local_repo_path)
52
- except subprocess.CalledProcessError as e:
53
- # If fast-forward is not possible, reset the local branch to match the remote branch
54
- subprocess.run(
55
- ["git", "reset", "--hard", "origin/main"],
56
- check=True,
57
- cwd=local_repo_path,
58
- )
59
-
60
- # Copy all Python files to the local repository directory
61
- current_dir = os.path.dirname(os.path.abspath(__file__))
62
- for filename in os.listdir(current_dir):
63
- if filename.endswith(".py") or filename.endswith(".json"):
64
- shutil.copy(
65
- os.path.join(current_dir, filename),
66
- os.path.join(local_repo_path, filename),
67
- )
68
-
69
- ImpressoConfig.register_for_auto_class()
70
- AutoConfig.register("stacked_bert", ImpressoConfig)
71
- AutoModelForTokenClassification.register(
72
- ImpressoConfig, ExtendedMultitaskModelForTokenClassification
73
- )
74
- ExtendedMultitaskModelForTokenClassification.register_for_auto_class(
75
- "AutoModelForTokenClassification"
76
- )
77
-
78
- model.save_pretrained(local_repo_path)
79
- tokenizer.save_pretrained(local_repo_path)
80
-
81
- # Add, commit and push the changes to the repository
82
- subprocess.run(["git", "add", "."], check=True, cwd=local_repo_path)
83
- subprocess.run(
84
- ["git", "commit", "-m", "Initial commit including model and configuration"],
85
- check=True,
86
- cwd=local_repo_path,
87
- )
88
- subprocess.run(["git", "push"], check=True, cwd=local_repo_path)
89
-
90
- # Push the model to the hub (this includes the README template)
91
- model.push_to_hub(repo_name)
92
- tokenizer.push_to_hub(repo_name)
93
-
94
- print(f"Model and repo pushed to: {repo_url}")
95
-
96
-
97
- if __name__ == "__main__":
98
- parser = argparse.ArgumentParser(description="Push NER model to Hugging Face Hub")
99
- parser.add_argument(
100
- "--model_type",
101
- type=str,
102
- required=True,
103
- help="Type of the model (e.g., stacked-bert)",
104
- )
105
- parser.add_argument(
106
- "--language",
107
- type=str,
108
- required=True,
109
- help="Language of the model (e.g., multilingual)",
110
- )
111
- parser.add_argument(
112
- "--checkpoint_dir",
113
- type=str,
114
- required=True,
115
- help="Directory containing checkpoint folders",
116
- )
117
- parser.add_argument(
118
- "--script_path", type=str, required=True, help="Path to the models.py script"
119
- )
120
- args = parser.parse_args()
121
- repo_name = f"impresso-project/ner-{args.model_type}-{args.language}"
122
- push_model_to_hub(args.checkpoint_dir, repo_name, args.script_path)
123
- # PIPELINE_REGISTRY.register_pipeline(
124
- # "generic-ner",
125
- # pipeline_class=MultitaskTokenClassificationPipeline,
126
- # pt_model=ExtendedMultitaskModelForTokenClassification,
127
- # )
128
- # model.config.custom_pipelines = {
129
- # "generic-ner": {
130
- # "impl": "generic_ner.MultitaskTokenClassificationPipeline",
131
- # "pt": ["ExtendedMultitaskModelForTokenClassification"],
132
- # "tf": [],
133
- # }
134
- # }
135
- # classifier = pipeline(
136
- # "generic-ner", model=model, tokenizer=tokenizer, label_map=label_map
137
- # )
138
- # from pprint import pprint
139
- #
140
- # pprint(
141
- # classifier(
142
- # "1. Le public est averti que Charlotte née Bourgoin, femme-de Joseph Digiez, et Maurice Bourgoin, enfant mineur représenté par le sieur Jaques Charles Gicot son curateur, ont été admis par arrêt du Conseil d'Etat du 5 décembre 1797, à solliciter une renonciation générale et absolue aux biens et aux dettes présentes et futures de Jean-Baptiste Bourgoin leur père."
143
- # )
144
- # )
145
- # repo.push_to_hub(commit_message="Initial commit of the trained NER model with code")
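
For reference, a small illustrative snippet (not part of this commit) showing how the now-deleted push_to_hf.py composed the Hub repository name from its CLI arguments; the argument values below are placeholders taken from the script's own help texts.

# Placeholder values from the --model_type / --language help texts above.
model_type = "stacked-bert"
language = "multilingual"

# Same f-string as in the deleted __main__ block.
repo_name = f"impresso-project/ner-{model_type}-{language}"
print(repo_name)  # impresso-project/ner-stacked-bert-multilingual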