File size: 15,267 Bytes
155a79d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415

import unicodedata
import regex as re
from datasets import load_dataset
import time
import os

def get_stats(ids, stats=None):
    """
    Calcule la fréquence des paires d'ids consécutifs.
    Conserve la même logique que la version originale car cette fonction est indépendante 
    des spécificités de la langue.
    """
    stats = {} if stats is None else stats
    for pair in zip(ids, ids[1:]): 
        stats[pair] = stats.get(pair, 0) + 1
    return stats

def merge(ids, pair, idx):
    """
    Fusionne les paires d'ids identifiées.
    Conserve la même logique que la version originale car cette fonction gère
    uniquement la fusion des tokens déjà identifiés.
    """
    newids = []
    i = 0
    while i < len(ids):
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def replace_control_characters(s: str) -> str:
    """
    Remplace les caractères de contrôle, avec une attention particulière aux 
    caractères spéciaux du baoulé.
    """
    chars = []
    for ch in s:
        # Gestion spéciale des caractères baoulé
        if ch in ['ɛ', 'ɔ', 'ŋ', 'ɲ']:  # Caractères spéciaux baoulé
            chars.append(ch)
        # Gestion standard des caractères de contrôle
        elif unicodedata.category(ch)[0] != "C":
            chars.append(ch)
        else:
            chars.append(f"\u{ord(ch):04x}")
    return "".join(chars)

def render_token(t: bytes) -> str:
    """
    Décode les tokens en gérant les caractères spéciaux du baoulé.
    """
    try:
        # Tentative de décodage standard
        s = t.decode('utf-8', errors='replace')
        # Gestion des caractères spéciaux baoulé
        s = replace_control_characters(s)
        return s
    except UnicodeDecodeError:
        # En cas d'échec, retourne le caractère de remplacement
        return '�'



class BaouleTokenizer:
    def __init__(self):
        # Initialisation des attributs obligatoires
        self.special_chars = {
            'ɛ': 256,
            'ɔ': 257,
            'ŋ': 258,
            'ɲ': 259
        }
        
        self.pattern = r"(?i:'n|gb|kp|ny|[ɛɔ]n)|(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^

\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[

]*|\s*[

]+|\s+(?!\S)|\s+"
        self.compiled_pattern = re.compile(self.pattern)
        
        self.special_tokens = {
            #'<|baoule|>': 1101,
            #'<|french|>': 1102,
            #'<|end|>': 1103,
            #'<|unknown|>': 1104,
            #'<|pad|>': 1105,
            # or
            '<|begin_of_text|>': 1101,
            '<|end_of_text|>': 1102,
            '<|start_header_id|>': 1103,
            '<|end_header_id|>': 1104,
            '<|eot_id|>': 1105
        }
        
        self.merges = {}
        self.vocab = self._build_vocab()

    def train(self, dataset, vocab_size):
        assert vocab_size >= 260  # 256 ASCII + 4 caractères spéciaux baoulé minimum

        # Extraction des textes baoulé du dataset HuggingFace
        text_chunks = []
        for item in dataset['train']:
            chunks = re.findall(self.compiled_pattern, item['baoule'])
            text_chunks.extend(chunks)

        # Conversion en ids avec gestion spéciale des caractères baoulé
        ids = []
        for chunk in text_chunks:
            chunk_ids = []
            i = 0
            while i < len(chunk):
                # Vérification des digraphes
                if i < len(chunk) - 1:
                    digraph = chunk[i:i+2]
                    if digraph in ['gb', 'kp', 'ny']:
                        chunk_ids.append(ord(digraph[0]))
                        chunk_ids.append(ord(digraph[1]))
                        i += 2
                        continue
                
                # Gestion des caractères spéciaux baoulé
                char = chunk[i]
                if char in self.special_chars:
                    chunk_ids.append(self.special_chars[char])
                else:
                    # Encodage UTF-8 standard pour les autres caractères
                    chunk_ids.extend(list(char.encode("utf-8")))
                i += 1
            ids.append(chunk_ids)

        # Calcul des fusions
        num_merges = vocab_size - (260 + len(self.special_tokens))
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})

        for i in range(num_merges):
            stats = {}
            for chunk_ids in ids:
                get_stats(chunk_ids, stats)
            if not stats:
                break
            pair = max(stats, key=stats.get)
            idx = 260 + len(self.special_tokens) + i
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

        self.merges = merges
        self.vocab = vocab

    def _build_vocab(self):
        # Vocabulaire de base incluant ASCII et caractères spéciaux baoulé
        vocab = {idx: bytes([idx]) for idx in range(256)}
        vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})
        
        # Ajout des fusions
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        
        # Ajout des tokens spéciaux
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")

        return vocab

    def save(self, file_prefix):
        # Sauvegarde du modèle
        model_file = file_prefix + ".model"
        with open(model_file, 'w') as f:
            f.write("baoule tokenizer v1.0
")
            f.write(f"{self.pattern}
")
            
            # Sauvegarde des caractères spéciaux baoulé
            f.write(f"{len(self.special_chars)}
")
            for char, idx in self.special_chars.items():
                f.write(f"{char} {idx}
")
            
            # Sauvegarde des tokens spéciaux
            f.write(f"{len(self.special_tokens)}
")
            for token, idx in self.special_tokens.items():
                f.write(f"{token} {idx}
")
            
            # Sauvegarde des fusions
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}
")

        # Sauvegarde du vocabulaire
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}
")
                else:
                    f.write(f"[{s}] {idx}
")
    def load(self, model_file):
      merges = {}
      special_tokens = {}
      special_chars = {}
      
      with open(model_file, 'r', encoding="utf-8") as f:
          version = f.readline().strip()
          self.pattern = f.readline().strip()
          self.compiled_pattern = re.compile(self.pattern)
          
          # Lecture des caractères spéciaux baoulé
          num_special_chars = int(f.readline().strip())
          for _ in range(num_special_chars):
              char, char_idx = f.readline().strip().split()
              special_chars[char] = int(char_idx)
          
          # Lecture des tokens spéciaux
          num_special = int(f.readline().strip())
          for _ in range(num_special):
              special, special_idx = f.readline().strip().split()
              special_tokens[special] = int(special_idx)
          
          # Création du vocabulaire de base
          base_vocab = {}
          # Ajouter les caractères ASCII
          for i in range(256):
              base_vocab[i] = bytes([i])
          # Ajouter les caractères spéciaux
          for char, idx in special_chars.items():
              base_vocab[idx] = char.encode('utf-8')
          # Ajouter les tokens spéciaux
          for token, idx in special_tokens.items():
              base_vocab[idx] = token.encode('utf-8')
              
          # Lecture des fusions
          for line in f:
              try:
                  idx1, idx2 = map(int, line.strip().split())
                  if idx1 not in base_vocab or idx2 not in base_vocab:
                      print(f"Warning: skipping fusion for indices {idx1}, {idx2} - not found in vocabulary")
                      continue
                  next_idx = len(base_vocab)
                  merges[(idx1, idx2)] = next_idx
                  base_vocab[next_idx] = base_vocab[idx1] + base_vocab[idx2]
              except Exception as e:
                  print(f"Error processing line: {line.strip()}")
                  print(f"Current vocabulary keys: {sorted(base_vocab.keys())}")
                  raise e

      self.merges = merges
      self.special_tokens = special_tokens
      self.special_chars = special_chars
      self.vocab = base_vocab
      
      return self

    def encode(self, text):
        """
        Encode le texte baoulé en liste d'identifiants entiers.
        Gère les caractères spéciaux baoulé et les digraphes.
        """
        # Pattern pour identifier les tokens spéciaux
        special_pattern = "(" + "|".join(re.escape(k) for k in self.special_tokens) + ")"
        special_chunks = re.split(special_pattern, text)
        
        ids = []

        for part in special_chunks:
            # Gestion des tokens spéciaux
            if part in self.special_tokens:
                ids.append(self.special_tokens[part])
            elif part:  # Ignorer les parties vides
                # Découpage du texte en chunks selon le pattern
                text_chunks = re.findall(self.compiled_pattern, part)

                for chunk in text_chunks:
                    chunk_ids = []
                    i = 0
                    
                    # Traitement caractère par caractère avec gestion des digraphes
                    while i < len(chunk):
                        # Vérification des digraphes baoulé (gb, kp, ny)
                        if i < len(chunk) - 1:
                            digraph = chunk[i:i+2]
                            if digraph.lower() in ['gb', 'kp', 'ny']:
                                chunk_ids.extend([ord(digraph[0]), ord(digraph[1])])
                                i += 2
                                continue
                        
                        # Vérification des voyelles nasales
                        if i < len(chunk) - 1 and chunk[i+1] == 'n':
                            current_char = chunk[i]
                            if current_char in 'aɛiɔu':
                                # Traiter la voyelle nasale comme une unité
                                nasal_vowel = chunk[i:i+2]
                                chunk_ids.extend(list(nasal_vowel.encode('utf-8')))
                                i += 2
                                continue
                        
                        # Gestion des caractères spéciaux baoulé
                        current_char = chunk[i]
                        if current_char in self.special_chars:
                            chunk_ids.append(self.special_chars[current_char])
                        else:
                            # Encodage UTF-8 standard pour les autres caractères
                            chunk_ids.extend(list(current_char.encode('utf-8')))
                        i += 1

                    # Application des fusions (byte-pair encoding)
                    while len(chunk_ids) >= 2:
                        stats = get_stats(chunk_ids)
                        pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
                        
                        if pair not in self.merges:
                            break
                            
                        idx = self.merges[pair]
                        chunk_ids = merge(chunk_ids, pair, idx)
                    
                    ids.extend(chunk_ids)
        
        return ids

    def decode(self, ids):
        """
        Décode une liste d'identifiants en texte baoulé.
        Gère la reconstruction des caractères spéciaux et des digraphes.
        """
        part_bytes = []
        inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}
        inverse_special_chars = {v: k for k, v in self.special_chars.items()}

        i = 0
        while i < len(ids):
            current_id = ids[i]
            
            # Gestion des tokens spéciaux
            if current_id in inverse_special_tokens:
                part_bytes.append(inverse_special_tokens[current_id].encode('utf-8'))
                i += 1
                continue
                
            # Gestion des caractères spéciaux baoulé
            if current_id in inverse_special_chars:
                part_bytes.append(inverse_special_chars[current_id].encode('utf-8'))
                i += 1
                continue
                
            # Gestion du vocabulaire standard
            if current_id in self.vocab:
                # Vérification des digraphes potentiels
                if i < len(ids) - 1:
                    next_id = ids[i + 1]
                    current_bytes = self.vocab[current_id]
                    if next_id in self.vocab:
                        next_bytes = self.vocab[next_id]
                        combined = current_bytes + next_bytes
                        # Vérification si c'est un digraphe baoulé
                        try:
                            combined_str = combined.decode('utf-8')
                            if combined_str.lower() in ['gb', 'kp', 'ny']:
                                part_bytes.append(combined)
                                i += 2
                                continue
                        except UnicodeDecodeError:
                            pass
                
                part_bytes.append(self.vocab[current_id])
                i += 1
            else:
                raise ValueError(f"ID de token invalide: {current_id}")

        # Reconstruction du texte final
        text_bytes = b''.join(part_bytes)
        text = text_bytes.decode('utf-8', errors='replace')
        
        return text




# Chargement du dataset depuis HuggingFace
dataset = load_dataset("Adjoumani/translations_french_baoule_V1")

# Configuration
vocab_size = 512
output_dir = "./models"
os.makedirs(output_dir, exist_ok=True)

# Initialisation et entraînement
tokenizer = BaouleTokenizer()
start_time = time.time()
tokenizer.train(dataset, vocab_size)
end_time = time.time()

# Sauvegarde
tokenizer.save(f"{output_dir}/baoule_tokenizer")

print(f"Durée d'entraînement: {end_time-start_time:.2f} secondes")
print(f"Modèle sauvegardé dans: {output_dir}/baoule_tokenizer")