import json
from collections import Counter
from pathlib import Path
from typing import List, Optional

from matplotlib import pyplot as plt
from tqdm import tqdm


class TrieNode:
    """Node in the prefix tree (trie) for fast token matching"""

    def __init__(self):
        self.children = {}     # char -> TrieNode
        self.is_token = False  # True if the path from the root spells a complete token
        self.token = None      # the token string, set only when is_token is True


class BytePairEncoder:
    def __init__(self, text: str):
        # Seed the vocabulary with the unique characters of the training text.
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

        # Encode the text as a list of token ids (initially one id per character).
        self.data = [self.stoi[c] for c in text]

        # Training statistics, updated after every merge.
        self.stats = {
            "vocab_sizes": [len(self.chars)],
            "data_sizes": [len(self.data)],
            "compression_ratios": [1.0],
            "merge_counts": [],
            "tokens_created": [],
            "max_token_lengths": [1],
        }

        self.original_length = len(self.data)
        self.max_token_length = 1

    def get_digram_stats(self) -> Counter:
        """Count occurrences of each adjacent pair of token ids (digram)"""
        counts = Counter()
        for pair in zip(self.data, self.data[1:]):
            counts[pair] += 1
        return counts

    def encode_to_vocab_size(self, target_vocab_size: int, plot_interval: Optional[int] = None,
                             print_interval: int = 100) -> None:
        """Train until reaching target vocabulary size (or until no pairs remain)"""
        pbar = tqdm(total=target_vocab_size, desc="Training BPE", initial=len(self.chars))

        iteration = 0
        while len(self.itos) < target_vocab_size:
            result = self._merge_step()
            if result is None:
                # No digrams left to merge.
                break

            iteration += 1
            pbar.update(1)

            if print_interval and iteration % print_interval == 0:
                self._print_progress(iteration)

            if plot_interval and iteration % plot_interval == 0:
                self.plot_statistics(iteration=iteration)

        pbar.close()

    def _merge_step(self) -> Optional[tuple]:
        """Perform one merge operation"""
        stats = self.get_digram_stats()
        if not stats:
            return None

        top_pair, count = max(stats.items(), key=lambda x: x[1])
        new_token = self._add_token(top_pair)
        self.data = self._replace_pairs(top_pair, new_token)
        self._update_stats(count)

        return new_token, count

    def _add_token(self, pair: tuple) -> int:
        """Add new token to vocabulary"""
        token_str = self.itos[pair[0]] + self.itos[pair[1]]
        token_id = len(self.itos)
        self.stoi[token_str] = token_id
        self.itos[token_id] = token_str
        self.max_token_length = max(self.max_token_length, len(token_str))
        return token_id

    def _replace_pairs(self, pair: tuple, new_token: int) -> List[int]:
        """Replace all occurrences of pair with the new token (greedy, left to right)"""
        result = []
        i = 0
        while i < len(self.data):
            if i < len(self.data) - 1 and self.data[i] == pair[0] and self.data[i + 1] == pair[1]:
                result.append(new_token)
                i += 2
            else:
                result.append(self.data[i])
                i += 1
        return result

    def _update_stats(self, merge_count: int):
        """Update training statistics"""
        self.stats["vocab_sizes"].append(len(self.itos))
        self.stats["data_sizes"].append(len(self.data))
        compression = self.original_length / len(self.data)
        self.stats["compression_ratios"].append(compression)
        self.stats["merge_counts"].append(merge_count)
        self.stats["tokens_created"].append(self.itos[len(self.itos) - 1])
        self.stats["max_token_lengths"].append(self.max_token_length)

    def plot_statistics(self, iteration: Optional[int] = None):
        """Plot training statistics"""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        if iteration is not None:
            fig.suptitle(f"BPE training statistics after {iteration} merges")

        ax1.plot(self.stats["vocab_sizes"], self.stats["data_sizes"])
        ax1.set_title("Vocabulary vs Dataset Size")

        ax2.plot(self.stats["vocab_sizes"], self.stats["compression_ratios"])
        ax2.set_title("Compression Ratio Progress")

        if self.stats["merge_counts"]:
            ax3.hist(self.stats["merge_counts"], bins=30)
            ax3.set_title("Merge Counts Distribution")

        if self.stats["tokens_created"]:
            lengths = [len(t) for t in self.stats["tokens_created"]]
            ax4.plot(range(len(lengths)), lengths)
            ax4.set_title("Token Length Evolution")

        plt.tight_layout()
        plt.show()

    def save_to_file(self, filepath: Path):
        """Save encoder state as JSON (itos is rebuilt from stoi on load)"""
        state = {
            "chars": self.chars,
            "stoi": self.stoi,
            "max_token_length": self.max_token_length,
            "stats": self.stats,
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: Path):
        """Load encoder state (suitable for tokenization, not for resuming training)"""
        with open(filepath, 'r', encoding='utf-8') as f:
            state = json.load(f)

        instance = cls("")  # start from an empty encoder, then restore its state
        instance.chars = state["chars"]
        instance.stoi = state["stoi"]
        instance.itos = {int(i): s for s, i in state["stoi"].items()}
        instance.max_token_length = state["max_token_length"]
        instance.stats = state["stats"]

        return instance

    def _print_progress(self, iteration: int):
        """Print training progress"""
        print(f"\nIteration {iteration}:")
        print(f"Vocabulary size: {len(self.itos):,}")
        print(f"Data size: {len(self.data):,}")
        print(f"Compression ratio: {self.stats['compression_ratios'][-1]:.2f}")

        if self.stats["merge_counts"]:
            last_merge = self.stats["merge_counts"][-1]
            last_token = self.stats["tokens_created"][-1]
            print(f"Last merge count: {last_merge:,}")
            print(f"Last token created: '{last_token}'")

        print(f"Max token length: {self.max_token_length}")


class TokenizerInternal:
    """Tokenizer using a trained BPE model"""

    def __init__(self, encoder: BytePairEncoder):
        self.stoi = encoder.stoi
        self.max_token_length = encoder.max_token_length
        self._trie = self._build_trie()

    def _build_trie(self) -> TrieNode:
        """Build trie for efficient tokenization"""
        root = TrieNode()
        for token in self.stoi:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_token = True
            node.token = token
        return root

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text using trie-based matching"""
        tokens = []
        pos = 0
        while pos < len(text):
            token = self._find_longest_token(text[pos:])
            tokens.append(token)
            pos += len(token)
        return tokens

    def _find_longest_token(self, text: str) -> str:
        """Find longest matching token starting at current position"""
        node = self._trie
        # Fall back to the raw character so unseen characters still yield a token.
        longest = text[0]

        for char in text[:self.max_token_length]:
            if char not in node.children:
                break
            node = node.children[char]
            if node.is_token:
                longest = node.token

        return longest
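

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module): trains the
    # encoder on a tiny made-up string, tokenizes a sample sentence, and
    # round-trips the vocabulary through JSON. The corpus, the target
    # vocabulary size of 64, and the "bpe_state.json" path are hypothetical
    # values chosen only for this demonstration.
    sample_text = "low lower lowest newer newest wider widest " * 50

    encoder = BytePairEncoder(sample_text)
    encoder.encode_to_vocab_size(target_vocab_size=64, print_interval=0)

    tokenizer = TokenizerInternal(encoder)
    print(tokenizer.tokenize("lower and wider"))

    # Persist and restore the trained vocabulary, then tokenize again to
    # confirm the loaded encoder behaves like the in-memory one.
    state_path = Path("bpe_state.json")
    encoder.save_to_file(state_path)
    restored = BytePairEncoder.load_from_file(state_path)
    print(TokenizerInternal(restored).tokenize("lower and wider"))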