import json
from collections import Counter
from pathlib import Path
from typing import List, Optional

from matplotlib import pyplot as plt
from tqdm import tqdm


class TrieNode:
    """Node in the prefix tree (trie) for fast token matching"""

    def __init__(self):
        self.children = {}
        self.is_token = False
        self.token = None


class BytePairEncoder:
    def __init__(self, text: str):
        # Initialize vocabulary from characters
        self.chars = sorted(list(set(text)))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}
        # Initial encoding of text
        self.data = [self.stoi[c] for c in text]
        # Statistics tracking
        self.stats = {
            "vocab_sizes": [len(self.chars)],
            "data_sizes": [len(self.data)],
            "compression_ratios": [1.0],
            "merge_counts": [],
            "tokens_created": [],
            "max_token_lengths": [1],
        }
        # Store original length for compression ratio
        self.original_length = len(self.data)
        self.max_token_length = 1

    def get_digram_stats(self) -> Counter:
        """Get digram counts"""
        counts = Counter()
        for pair in zip(self.data, self.data[1:]):
            pair = (int(pair[0]), int(pair[1]))
            counts[pair] += 1
        return counts

    def encode_to_vocab_size(self, target_vocab_size: int, plot_interval: Optional[int] = None,
                             print_interval: int = 100) -> None:
        """Train until reaching target vocabulary size"""
        pbar = tqdm(total=target_vocab_size, desc="Training BPE", initial=len(self.chars))
        iteration = 0
        while len(self.itos) < target_vocab_size:
            result = self._merge_step()
            if result is None:
                break
            iteration += 1
            pbar.update(1)
            if print_interval and iteration % print_interval == 0:
                self._print_progress(iteration)
            if plot_interval and iteration % plot_interval == 0:
                self.plot_statistics(iteration=iteration)
        pbar.close()

    def _merge_step(self):
        """Perform one merge operation"""
        stats = self.get_digram_stats()
        if not stats:
            return None
        top_pair, count = max(stats.items(), key=lambda x: x[1])
        new_token = self._add_token(top_pair)
        self.data = self._replace_pairs(top_pair, new_token)
        self._update_stats(count)
        return new_token, count

    def _add_token(self, pair: tuple) -> int:
        """Add new token to vocabulary"""
        token_str = self.itos[pair[0]] + self.itos[pair[1]]
        token_id = len(self.itos)
        self.stoi[token_str] = token_id
        self.itos[token_id] = token_str
        self.max_token_length = max(self.max_token_length, len(token_str))
        return token_id

    def _replace_pairs(self, pair: tuple, new_token: int) -> List[int]:
        """Replace all occurrences of pair with new token"""
        result = []
        i = 0
        while i < len(self.data):
            if i < len(self.data) - 1 and self.data[i] == pair[0] and self.data[i + 1] == pair[1]:
                result.append(new_token)
                i += 2
            else:
                result.append(self.data[i])
                i += 1
        return result

    def _update_stats(self, merge_count: int):
        """Update training statistics"""
        self.stats["vocab_sizes"].append(len(self.itos))
        self.stats["data_sizes"].append(len(self.data))
        compression = self.original_length / len(self.data)
        self.stats["compression_ratios"].append(compression)
        self.stats["merge_counts"].append(merge_count)
        self.stats["tokens_created"].append(self.itos[len(self.itos) - 1])
        self.stats["max_token_lengths"].append(self.max_token_length)

    def plot_statistics(self, iteration: Optional[int] = None):
        """Plot training statistics"""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        # Tag the figure with the current iteration when one is provided
        if iteration is not None:
            fig.suptitle(f"Training statistics at iteration {iteration}")
        # Plot training metrics
        ax1.plot(self.stats["vocab_sizes"], self.stats["data_sizes"])
        ax1.set_title("Vocabulary vs Dataset Size")
        ax2.plot(self.stats["vocab_sizes"], self.stats["compression_ratios"])
        ax2.set_title("Compression Ratio Progress")
        if self.stats["merge_counts"]:
            ax3.hist(self.stats["merge_counts"], bins=30)
            ax3.set_title("Merge Counts Distribution")
        if self.stats["tokens_created"]:
            lengths = [len(t) for t in self.stats["tokens_created"]]
            ax4.plot(range(len(lengths)), lengths)
            ax4.set_title("Token Length Evolution")
        plt.tight_layout()
        plt.show()

    def save_to_file(self, filepath: Path):
        """Save encoder state"""
        state = {
            "chars": self.chars,
            "stoi": self.stoi,
            "max_token_length": self.max_token_length,
            "stats": self.stats
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: Path):
        """Load encoder state"""
        with open(filepath, 'r', encoding='utf-8') as f:
            state = json.load(f)
        instance = cls("")  # Create empty instance
        instance.chars = state["chars"]
        instance.stoi = state["stoi"]
        instance.itos = {int(i): s for s, i in state["stoi"].items()}
        instance.max_token_length = state["max_token_length"]
        instance.stats = state["stats"]
        return instance

    def _print_progress(self, iteration: int):
        """Print training progress"""
        print(f"\nIteration {iteration}:")
        print(f"Vocabulary size: {len(self.itos):,}")
        print(f"Data size: {len(self.data):,}")
        print(f"Compression ratio: {self.stats['compression_ratios'][-1]:.2f}")
        if self.stats["merge_counts"]:
            last_merge = self.stats["merge_counts"][-1]
            last_token = self.stats["tokens_created"][-1]
            print(f"Last merge count: {last_merge:,}")
            print(f"Last token created: '{last_token}'")
        print(f"Max token length: {self.max_token_length}")


class TokenizerInternal:
    """Tokenizer using trained BPE model"""

    def __init__(self, encoder: BytePairEncoder):
        self.stoi = encoder.stoi
        self.max_token_length = encoder.max_token_length
        self._trie = self._build_trie()

    def _build_trie(self) -> TrieNode:
        """Build trie for efficient tokenization"""
        root = TrieNode()
        for token in self.stoi:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_token = True
            node.token = token
        return root

    def tokenize(self, text: str) -> List[str]:
        """Tokenize text using trie-based matching"""
        tokens = []
        pos = 0
        while pos < len(text):
            token = self._find_longest_token(text[pos:])
            tokens.append(token)
            pos += len(token)
        return tokens

    def _find_longest_token(self, text: str) -> str:
        """Find longest matching token starting at current position"""
        node = self._trie
        # Fall back to the raw character if nothing in the vocabulary matches
        # (e.g. a character that never appeared in the training text).
        longest = text[0]
        for char in text[:self.max_token_length]:
            if char not in node.children:
                break
            node = node.children[char]
            if node.is_token:
                longest = node.token
        return longest
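

if __name__ == "__main__":
    # Minimal usage sketch: train a BPE vocabulary on a text file, save the
    # encoder state, and tokenize a sample string. The corpus path
    # "corpus.txt", the output path "bpe_state.json", and the target
    # vocabulary size of 512 are placeholder assumptions; substitute your own.
    text = Path("corpus.txt").read_text(encoding="utf-8")
    encoder = BytePairEncoder(text)
    encoder.encode_to_vocab_size(512, print_interval=100)
    encoder.save_to_file(Path("bpe_state.json"))

    tokenizer = TokenizerInternal(encoder)
    print(tokenizer.tokenize("hello world"))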