import unicodedata
import regex as re
from datasets import load_dataset
import time
import os


def get_stats(ids, stats=None):
    """
    Count the frequency of consecutive id pairs.
    Keeps the same logic as the original version, since this function is
    independent of any language-specific details.
    """
    stats = {} if stats is None else stats
    for pair in zip(ids, ids[1:]):
        stats[pair] = stats.get(pair, 0) + 1
    return stats
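
# Illustrative example: adjacent pairs are counted left to right, e.g.
# get_stats([1, 2, 3, 1, 2]) -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}.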


def merge(ids, pair, idx):
    """
    Merge every occurrence of the identified id pair into a single new id.
    Keeps the same logic as the original version, since this function only
    handles the fusion of tokens that have already been identified.
    """
    newids = []
    i = 0
    while i < len(ids):
        if ids[i] == pair[0] and i < len(ids) - 1 and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids
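
# Illustrative example: every (1, 2) pair collapses into the new id, e.g.
# merge([1, 2, 3, 1, 2], (1, 2), 4) -> [4, 3, 4].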


def replace_control_characters(s: str) -> str:
    """
    Replace control characters, with special care for the Baoulé-specific
    characters.
    """
    chars = []
    for ch in s:
        # Baoulé-specific characters pass through unchanged
        if ch in ['ɛ', 'ɔ', 'ŋ', 'ɲ']:
            chars.append(ch)
        # keep anything that is not a Unicode control character
        elif unicodedata.category(ch)[0] != "C":
            chars.append(ch)
        else:
            # escape control characters so they stay visible in the vocab file
            chars.append(f"\\u{ord(ch):04x}")
    return "".join(chars)


def render_token(t: bytes) -> str:
    """
    Decode tokens into a printable string, handling the Baoulé-specific
    characters.
    """
    try:
        # decode with replacement so invalid byte sequences become U+FFFD
        s = t.decode('utf-8', errors='replace')
        # escape any remaining control characters for display
        s = replace_control_characters(s)
        return s
    except UnicodeDecodeError:
        return '�'


class BaouleTokenizer:
    def __init__(self):
        self.special_chars = {
            'ɛ': 256,
            'ɔ': 257,
            'ŋ': 258,
            'ɲ': 259
        }
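        # Note: these ids sit directly above the raw byte range 0-255, so the
        # Baoulé-specific characters get dedicated single tokens instead of
        # being split into their UTF-8 bytes.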

        self.pattern = r"(?i:'n|gb|kp|ny|[ɛɔ]n)|(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
        self.compiled_pattern = re.compile(self.pattern)
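        # The pattern first captures Baoulé digraphs (gb, kp, ny) and nasal
        # vowels (ɛn, ɔn) as their own chunks, then falls back to the usual
        # letter / digit / punctuation / whitespace splits.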

        self.special_tokens = {
            '<|begin_of_text|>': 1101,
            '<|end_of_text|>': 1102,
            '<|start_header_id|>': 1103,
            '<|end_header_id|>': 1104,
            '<|eot_id|>': 1105
        }

        self.merges = {}
        self.vocab = self._build_vocab()

    def train(self, dataset, vocab_size):
        assert vocab_size >= 260

        # split every Baoulé sentence of the training set into text chunks
        text_chunks = []
        for item in dataset['train']:
            chunks = re.findall(self.compiled_pattern, item['baoule'])
            text_chunks.extend(chunks)

        # pre-tokenize each chunk: digraphs as two bytes, special characters
        # as their reserved ids, everything else as raw UTF-8 bytes
        ids = []
        for chunk in text_chunks:
            chunk_ids = []
            i = 0
            while i < len(chunk):
                if i < len(chunk) - 1:
                    digraph = chunk[i:i+2]
                    if digraph in ['gb', 'kp', 'ny']:
                        chunk_ids.append(ord(digraph[0]))
                        chunk_ids.append(ord(digraph[1]))
                        i += 2
                        continue
                char = chunk[i]
                if char in self.special_chars:
                    chunk_ids.append(self.special_chars[char])
                else:
                    chunk_ids.extend(list(char.encode("utf-8")))
                i += 1
            ids.append(chunk_ids)

        # iteratively merge the most frequent pair of ids
        num_merges = vocab_size - (260 + len(self.special_tokens))
        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})

        for i in range(num_merges):
            stats = {}
            for chunk_ids in ids:
                get_stats(chunk_ids, stats)
            if not stats:
                break
            pair = max(stats, key=stats.get)
            idx = 260 + len(self.special_tokens) + i
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

        self.merges = merges
        self.vocab = vocab
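        # After training, self.vocab maps raw bytes (0-255), the Baoulé special
        # characters (256-259) and the learned merges to their byte sequences;
        # the chat special tokens keep the fixed ids defined in __init__ and
        # are handled separately during encoding and decoding.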

    def _build_vocab(self):
        vocab = {idx: bytes([idx]) for idx in range(256)}
        vocab.update({idx: char.encode('utf-8') for char, idx in self.special_chars.items()})

        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]

        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")

        return vocab

    def save(self, file_prefix):
        # write the model file (read back by load); UTF-8 so the Baoulé
        # characters are written consistently across platforms
        model_file = file_prefix + ".model"
        with open(model_file, 'w', encoding="utf-8") as f:
            f.write("baoule tokenizer v1.0\n")
            f.write(f"{self.pattern}\n")

            f.write(f"{len(self.special_chars)}\n")
            for char, idx in self.special_chars.items():
                f.write(f"{char} {idx}\n")

            f.write(f"{len(self.special_tokens)}\n")
            for token, idx in self.special_tokens.items():
                f.write(f"{token} {idx}\n")

            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")

        # write a human-readable vocab file for inspection only
        vocab_file = file_prefix + ".vocab"
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        merges = {}
        special_tokens = {}
        special_chars = {}

        with open(model_file, 'r', encoding="utf-8") as f:
            version = f.readline().strip()
            self.pattern = f.readline().strip()
            self.compiled_pattern = re.compile(self.pattern)

            # read the Baoulé special characters and their ids
            num_special_chars = int(f.readline().strip())
            for _ in range(num_special_chars):
                char, char_idx = f.readline().strip().split()
                special_chars[char] = int(char_idx)

            # read the special tokens and their ids
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)

            # rebuild the base vocabulary: raw bytes, special characters and
            # special tokens
            base_vocab = {}
            for i in range(256):
                base_vocab[i] = bytes([i])
            for char, idx in special_chars.items():
                base_vocab[idx] = char.encode('utf-8')
            for token, idx in special_tokens.items():
                base_vocab[idx] = token.encode('utf-8')

            # replay the merges in file order
            for line in f:
                try:
                    idx1, idx2 = map(int, line.strip().split())
                    if idx1 not in base_vocab or idx2 not in base_vocab:
                        print(f"Warning: skipping merge for indices {idx1}, {idx2} - not found in vocabulary")
                        continue
                    next_idx = len(base_vocab)
                    merges[(idx1, idx2)] = next_idx
                    base_vocab[next_idx] = base_vocab[idx1] + base_vocab[idx2]
                except Exception as e:
                    print(f"Error processing line: {line.strip()}")
                    print(f"Current vocabulary keys: {sorted(base_vocab.keys())}")
                    raise e

        self.merges = merges
        self.special_tokens = special_tokens
        self.special_chars = special_chars
        self.vocab = base_vocab

        return self
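
    # Note: load() re-derives each merge id as len(base_vocab) at the moment
    # the merge line is read. With 256 bytes, 4 special characters and 5
    # special tokens this starts at 265, matching the ids assigned in train().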

    def encode(self, text):
        """
        Encode Baoulé text into a list of integer ids.
        Handles the Baoulé special characters and digraphs.
        """
        # split off the special tokens first so they are never merged
        special_pattern = "(" + "|".join(re.escape(k) for k in self.special_tokens) + ")"
        special_chunks = re.split(special_pattern, text)

        ids = []

        for part in special_chunks:
            if part in self.special_tokens:
                ids.append(self.special_tokens[part])
            elif part:
                text_chunks = re.findall(self.compiled_pattern, part)

                for chunk in text_chunks:
                    chunk_ids = []
                    i = 0

                    while i < len(chunk):
                        # digraphs (gb, kp, ny) are encoded as their two bytes
                        if i < len(chunk) - 1:
                            digraph = chunk[i:i+2]
                            if digraph.lower() in ['gb', 'kp', 'ny']:
                                chunk_ids.extend([ord(digraph[0]), ord(digraph[1])])
                                i += 2
                                continue

                        # nasal vowels (vowel + 'n') are encoded as UTF-8 bytes
                        if i < len(chunk) - 1 and chunk[i+1] == 'n':
                            current_char = chunk[i]
                            if current_char in 'aɛiɔu':
                                nasal_vowel = chunk[i:i+2]
                                chunk_ids.extend(list(nasal_vowel.encode('utf-8')))
                                i += 2
                                continue

                        # special characters get their reserved ids,
                        # everything else its raw UTF-8 bytes
                        current_char = chunk[i]
                        if current_char in self.special_chars:
                            chunk_ids.append(self.special_chars[current_char])
                        else:
                            chunk_ids.extend(list(current_char.encode('utf-8')))
                        i += 1

                    # apply the learned merges, lowest merge index first
                    while len(chunk_ids) >= 2:
                        stats = get_stats(chunk_ids)
                        pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))

                        if pair not in self.merges:
                            break

                        idx = self.merges[pair]
                        chunk_ids = merge(chunk_ids, pair, idx)

                    ids.extend(chunk_ids)

        return ids

    def decode(self, ids):
        """
        Decode a list of ids back into Baoulé text.
        Handles the reconstruction of special characters and digraphs.
        """
        part_bytes = []
        inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}
        inverse_special_chars = {v: k for k, v in self.special_chars.items()}

        i = 0
        while i < len(ids):
            current_id = ids[i]

            # special tokens and special characters map straight back to text
            if current_id in inverse_special_tokens:
                part_bytes.append(inverse_special_tokens[current_id].encode('utf-8'))
                i += 1
                continue

            if current_id in inverse_special_chars:
                part_bytes.append(inverse_special_chars[current_id].encode('utf-8'))
                i += 1
                continue

            if current_id in self.vocab:
                # if two consecutive ids form a digraph, emit them together
                if i < len(ids) - 1:
                    next_id = ids[i + 1]
                    current_bytes = self.vocab[current_id]
                    if next_id in self.vocab:
                        next_bytes = self.vocab[next_id]
                        combined = current_bytes + next_bytes

                        try:
                            combined_str = combined.decode('utf-8')
                            if combined_str.lower() in ['gb', 'kp', 'ny']:
                                part_bytes.append(combined)
                                i += 2
                                continue
                        except UnicodeDecodeError:
                            pass

                part_bytes.append(self.vocab[current_id])
                i += 1
            else:
                raise ValueError(f"Invalid token id: {current_id}")

        text_bytes = b''.join(part_bytes)
        text = text_bytes.decode('utf-8', errors='replace')

        return text


dataset = load_dataset("Adjoumani/translations_french_baoule_V1")

vocab_size = 512
output_dir = "./models"
os.makedirs(output_dir, exist_ok=True)

tokenizer = BaouleTokenizer()
start_time = time.time()
tokenizer.train(dataset, vocab_size)
end_time = time.time()

tokenizer.save(f"{output_dir}/baoule_tokenizer")

print(f"Training time: {end_time - start_time:.2f} seconds")
print(f"Model saved to: {output_dir}/baoule_tokenizer")