baoule_tokenizer / baoule_tokenizer.py
Adjoumani's picture
Upload baoule_tokenizer.py with huggingface_hub
155a79d verified
raw
history blame
15.3 kB
import unicodedata
import regex as re
from datasets import load_dataset
import time
import os
def get_stats(ids, stats=None):
"""
Calcule la fréquence des paires d'ids consécutifs.
Conserve la même logique que la version originale car cette fonction est indépendante
des spécificités de la langue.
"""
stats = {} if stats is None else stats
for pair in zip(ids, ids[1:]):
stats[pair] = stats.get(pair, 0) + 1
return stats
def merge(ids, pair, idx):
"""
Fusionne les paires d'ids identifiées.
Conserve la même logique que la version originale car cette fonction gère
uniquement la fusion des tokens déjà identifiés.
"""
newids = []
i = 0
while i < len(ids):
if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
newids.append(idx)
i += 2
else:
newids.append(ids[i])
i += 1
return newids
def replace_control_characters(s: str) -> str:
"""
Remplace les caractères de contrôle, avec une attention particulière aux
caractères spéciaux du baoulé.
"""
chars = []
for ch in s:
# Gestion spéciale des caractères baoulé
if ch in ['ɛ', 'ɔ', 'ŋ', 'ɲ']: # Caractères spéciaux baoulé
chars.append(ch)
# Gestion standard des caractères de contrôle
elif unicodedata.category(ch)[0] != "C":
chars.append(ch)
else:
chars.append(f"\u{ord(ch):04x}")
return "".join(chars)
def render_token(t: bytes) -> str:
"""
Décode les tokens en gérant les caractères spéciaux du baoulé.
"""
try:
# Tentative de décodage standard
s = t.decode('utf-8', errors='replace')
# Gestion des caractères spéciaux baoulé
s = replace_control_characters(s)
return s
except UnicodeDecodeError:
# En cas d'échec, retourne le caractère de remplacement
return '�'
class BaouleTokenizer:
def __init__(self):
# Initialisation des attributs obligatoires
self.special_chars = {
'ɛ': 256,
'ɔ': 257,
'ŋ': 258,
'ɲ': 259
}
self.pattern = r"(?i:'n|gb|kp|ny|[ɛɔ]n)|(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^
\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[
]*|\s*[
]+|\s+(?!\S)|\s+"
self.compiled_pattern = re.compile(self.pattern)
self.special_tokens = {
#'<|baoule|>': 1101,
#'<|french|>': 1102,
#'<|end|>': 1103,
#'<|unknown|>': 1104,
#'<|pad|>': 1105,
# or
'<|begin_of_text|>': 1101,
'<|end_of_text|>': 1102,
'<|start_header_id|>': 1103,
'<|end_header_id|>': 1104,
'<|eot_id|>': 1105
}
self.merges = {}
self.vocab = self._build_vocab()
def train(self, dataset, vocab_size):
assert vocab_size >= 260 # 256 ASCII + 4 caractères spéciaux baoulé minimum
# Extraction des textes baoulé du dataset HuggingFace
text_chunks = []
for item in dataset['train']:
chunks = re.findall(self.compiled_pattern, item['baoule'])
text_chunks.extend(chunks)
# Conversion en ids avec gestion spéciale des caractères baoulé
ids = []
for chunk in text_chunks:
chunk_ids = []
i = 0
while i < len(chunk):
# Vérification des digraphes
if i < len(chunk) - 1:
digraph = chunk[i:i+2]
if digraph in ['gb', 'kp', 'ny']:
chunk_ids.append(ord(digraph[0]))
chunk_ids.append(ord(digraph[1]))
i += 2
continue
# Gestion des caractères spéciaux baoulé
char = chunk[i]
if char in self.special_chars:
chunk_ids.append(self.special_chars[char])
else:
# Encodage UTF-8 standard pour les autres caractères
chunk_ids.extend(list(char.encode("utf-8")))
i += 1
ids.append(chunk_ids)
# Calcul des fusions
num_merges = vocab_size - (260 + len(self.special_tokens))
merges = {}
vocab = {idx: bytes([idx]) for idx in range(256)}
vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})
for i in range(num_merges):
stats = {}
for chunk_ids in ids:
get_stats(chunk_ids, stats)
if not stats:
break
pair = max(stats, key=stats.get)
idx = 260 + len(self.special_tokens) + i
ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
merges[pair] = idx
vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
self.merges = merges
self.vocab = vocab
def _build_vocab(self):
# Vocabulaire de base incluant ASCII et caractères spéciaux baoulé
vocab = {idx: bytes([idx]) for idx in range(256)}
vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})
# Ajout des fusions
for (p0, p1), idx in self.merges.items():
vocab[idx] = vocab[p0] + vocab[p1]
# Ajout des tokens spéciaux
for special, idx in self.special_tokens.items():
vocab[idx] = special.encode("utf-8")
return vocab
def save(self, file_prefix):
# Sauvegarde du modèle
model_file = file_prefix + ".model"
with open(model_file, 'w') as f:
f.write("baoule tokenizer v1.0
")
f.write(f"{self.pattern}
")
# Sauvegarde des caractères spéciaux baoulé
f.write(f"{len(self.special_chars)}
")
for char, idx in self.special_chars.items():
f.write(f"{char} {idx}
")
# Sauvegarde des tokens spéciaux
f.write(f"{len(self.special_tokens)}
")
for token, idx in self.special_tokens.items():
f.write(f"{token} {idx}
")
# Sauvegarde des fusions
for idx1, idx2 in self.merges:
f.write(f"{idx1} {idx2}
")
# Sauvegarde du vocabulaire
vocab_file = file_prefix + ".vocab"
inverted_merges = {idx: pair for pair, idx in self.merges.items()}
with open(vocab_file, "w", encoding="utf-8") as f:
for idx, token in self.vocab.items():
s = render_token(token)
if idx in inverted_merges:
idx0, idx1 = inverted_merges[idx]
s0 = render_token(self.vocab[idx0])
s1 = render_token(self.vocab[idx1])
f.write(f"[{s0}][{s1}] -> [{s}] {idx}
")
else:
f.write(f"[{s}] {idx}
")
def load(self, model_file):
merges = {}
special_tokens = {}
special_chars = {}
with open(model_file, 'r', encoding="utf-8") as f:
version = f.readline().strip()
self.pattern = f.readline().strip()
self.compiled_pattern = re.compile(self.pattern)
# Lecture des caractères spéciaux baoulé
num_special_chars = int(f.readline().strip())
for _ in range(num_special_chars):
char, char_idx = f.readline().strip().split()
special_chars[char] = int(char_idx)
# Lecture des tokens spéciaux
num_special = int(f.readline().strip())
for _ in range(num_special):
special, special_idx = f.readline().strip().split()
special_tokens[special] = int(special_idx)
# Création du vocabulaire de base
base_vocab = {}
# Ajouter les caractères ASCII
for i in range(256):
base_vocab[i] = bytes([i])
# Ajouter les caractères spéciaux
for char, idx in special_chars.items():
base_vocab[idx] = char.encode('utf-8')
# Ajouter les tokens spéciaux
for token, idx in special_tokens.items():
base_vocab[idx] = token.encode('utf-8')
# Lecture des fusions
for line in f:
try:
idx1, idx2 = map(int, line.strip().split())
if idx1 not in base_vocab or idx2 not in base_vocab:
print(f"Warning: skipping fusion for indices {idx1}, {idx2} - not found in vocabulary")
continue
next_idx = len(base_vocab)
merges[(idx1, idx2)] = next_idx
base_vocab[next_idx] = base_vocab[idx1] + base_vocab[idx2]
except Exception as e:
print(f"Error processing line: {line.strip()}")
print(f"Current vocabulary keys: {sorted(base_vocab.keys())}")
raise e
self.merges = merges
self.special_tokens = special_tokens
self.special_chars = special_chars
self.vocab = base_vocab
return self
def encode(self, text):
"""
Encode le texte baoulé en liste d'identifiants entiers.
Gère les caractères spéciaux baoulé et les digraphes.
"""
# Pattern pour identifier les tokens spéciaux
special_pattern = "(" + "|".join(re.escape(k) for k in self.special_tokens) + ")"
special_chunks = re.split(special_pattern, text)
ids = []
for part in special_chunks:
# Gestion des tokens spéciaux
if part in self.special_tokens:
ids.append(self.special_tokens[part])
elif part: # Ignorer les parties vides
# Découpage du texte en chunks selon le pattern
text_chunks = re.findall(self.compiled_pattern, part)
for chunk in text_chunks:
chunk_ids = []
i = 0
# Traitement caractère par caractère avec gestion des digraphes
while i < len(chunk):
# Vérification des digraphes baoulé (gb, kp, ny)
if i < len(chunk) - 1:
digraph = chunk[i:i+2]
if digraph.lower() in ['gb', 'kp', 'ny']:
chunk_ids.extend([ord(digraph[0]), ord(digraph[1])])
i += 2
continue
# Vérification des voyelles nasales
if i < len(chunk) - 1 and chunk[i+1] == 'n':
current_char = chunk[i]
if current_char in 'aɛiɔu':
# Traiter la voyelle nasale comme une unité
nasal_vowel = chunk[i:i+2]
chunk_ids.extend(list(nasal_vowel.encode('utf-8')))
i += 2
continue
# Gestion des caractères spéciaux baoulé
current_char = chunk[i]
if current_char in self.special_chars:
chunk_ids.append(self.special_chars[current_char])
else:
# Encodage UTF-8 standard pour les autres caractères
chunk_ids.extend(list(current_char.encode('utf-8')))
i += 1
# Application des fusions (byte-pair encoding)
while len(chunk_ids) >= 2:
stats = get_stats(chunk_ids)
pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
if pair not in self.merges:
break
idx = self.merges[pair]
chunk_ids = merge(chunk_ids, pair, idx)
ids.extend(chunk_ids)
return ids
def decode(self, ids):
"""
Décode une liste d'identifiants en texte baoulé.
Gère la reconstruction des caractères spéciaux et des digraphes.
"""
part_bytes = []
inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}
inverse_special_chars = {v: k for k, v in self.special_chars.items()}
i = 0
while i < len(ids):
current_id = ids[i]
# Gestion des tokens spéciaux
if current_id in inverse_special_tokens:
part_bytes.append(inverse_special_tokens[current_id].encode('utf-8'))
i += 1
continue
# Gestion des caractères spéciaux baoulé
if current_id in inverse_special_chars:
part_bytes.append(inverse_special_chars[current_id].encode('utf-8'))
i += 1
continue
# Gestion du vocabulaire standard
if current_id in self.vocab:
# Vérification des digraphes potentiels
if i < len(ids) - 1:
next_id = ids[i + 1]
current_bytes = self.vocab[current_id]
if next_id in self.vocab:
next_bytes = self.vocab[next_id]
combined = current_bytes + next_bytes
# Vérification si c'est un digraphe baoulé
try:
combined_str = combined.decode('utf-8')
if combined_str.lower() in ['gb', 'kp', 'ny']:
part_bytes.append(combined)
i += 2
continue
except UnicodeDecodeError:
pass
part_bytes.append(self.vocab[current_id])
i += 1
else:
raise ValueError(f"ID de token invalide: {current_id}")
# Reconstruction du texte final
text_bytes = b''.join(part_bytes)
text = text_bytes.decode('utf-8', errors='replace')
return text
# Chargement du dataset depuis HuggingFace
dataset = load_dataset("Adjoumani/translations_french_baoule_V1")
# Configuration
vocab_size = 512
output_dir = "./models"
os.makedirs(output_dir, exist_ok=True)
# Initialisation et entraînement
tokenizer = BaouleTokenizer()
start_time = time.time()
tokenizer.train(dataset, vocab_size)
end_time = time.time()
# Sauvegarde
tokenizer.save(f"{output_dir}/baoule_tokenizer")
print(f"Durée d'entraînement: {end_time-start_time:.2f} secondes")
print(f"Modèle sauvegardé dans: {output_dir}/baoule_tokenizer")