File size: 8,029 Bytes
f1c672a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from typing import List, Dict, Optional
from tqdm import tqdm
from collections import Counter
from matplotlib import pyplot as plt
import json
from pathlib import Path

class TrieNode:
    """A single vertex of the prefix tree used for token lookup.

    Each node stands for the string spelled out by the characters on the
    path from the root down to it.
    """

    def __init__(self) -> None:
        # Outgoing edges: next character -> child node.
        self.children = {}
        # A fresh node does not terminate any token; whoever builds the trie
        # flips ``is_token`` and fills in ``token`` for nodes that do.
        self.is_token, self.token = False, None

class BytePairEncoder:
    """Character-level byte-pair-encoding (BPE) trainer.

    The vocabulary starts as the distinct characters of the training text.
    Every merge step takes the most frequent adjacent token pair in the
    encoded text, adds the concatenation of the two token strings as a new
    token, and rewrites the encoded text to use it.

    Attributes:
        chars: Sorted list of the distinct characters of the training text.
        stoi: Token string -> token id.
        itos: Token id -> token string.
        data: The training text as a list of token ids; replaced by every
            merge step.
        stats: History lists, appended to after every merge
            (``vocab_sizes``, ``data_sizes``, ``compression_ratios``,
            ``merge_counts``, ``tokens_created``, ``max_token_lengths``).
        original_length: Length of ``data`` before any merge.
        max_token_length: Character length of the longest token so far.
    """

    def __init__(self, text: str):
        """Build the character vocabulary and encode ``text`` with it."""
        # Initialize vocabulary from characters
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for i, ch in enumerate(self.chars)}

        # Initial encoding of text
        self.data = [self.stoi[c] for c in text]

        # Statistics tracking
        self.stats = {
            "vocab_sizes": [len(self.chars)],
            "data_sizes": [len(self.data)],
            "compression_ratios": [1.0],
            "merge_counts": [],
            "tokens_created": [],
            "max_token_lengths": [1],
        }

        # Store original length for compression ratio
        self.original_length = len(self.data)
        self.max_token_length = 1

    def get_digram_stats(self) -> Counter:
        """Count every adjacent pair of token ids in ``self.data``.

        Returns:
            Counter mapping ``(left_id, right_id)`` to its number of
            occurrences. Empty when ``data`` has fewer than two tokens.
        """
        # Keys are normalised to plain ints, as the original loop did.
        return Counter(
            (int(left), int(right))
            for left, right in zip(self.data, self.data[1:])
        )

    def encode_to_vocab_size(self, target_vocab_size: int, plot_interval: Optional[int] = None,
                             print_interval: int = 100) -> None:
        """Merge pairs until the vocabulary reaches ``target_vocab_size``.

        Training stops early when no adjacent pair is left to merge.

        Args:
            target_vocab_size: Vocabulary size at which training stops.
            plot_interval: If truthy, call ``plot_statistics`` every that
                many merges.
            print_interval: If truthy, print a progress report every that
                many merges.
        """
        # Start the bar at the *current* vocabulary size so it stays accurate
        # when training resumes on an encoder that already has merged tokens.
        # The context manager closes the bar even if a step raises.
        with tqdm(total=target_vocab_size, desc="Training BPE",
                  initial=len(self.itos)) as pbar:
            iteration = 0
            while len(self.itos) < target_vocab_size:
                if self._merge_step() is None:
                    break

                iteration += 1
                pbar.update(1)

                if print_interval and iteration % print_interval == 0:
                    self._print_progress(iteration)

                if plot_interval and iteration % plot_interval == 0:
                    self.plot_statistics(iteration=iteration)

    def _merge_step(self) -> Optional[tuple]:
        """Perform one merge operation.

        Returns:
            ``(new_token_id, pair_count)`` for the merged pair, or ``None``
            when ``data`` contains no adjacent pair.
        """
        pair_counts = self.get_digram_stats()
        if not pair_counts:
            return None

        # max() keeps the first pair seen among equally frequent ones.
        top_pair, count = max(pair_counts.items(), key=lambda item: item[1])
        new_token = self._add_token(top_pair)
        self.data = self._replace_pairs(top_pair, new_token)
        self._update_stats(count)

        return new_token, count

    def _add_token(self, pair: tuple) -> int:
        """Register the concatenation of ``pair``'s token strings.

        Args:
            pair: ``(left_id, right_id)`` of existing tokens.

        Returns:
            The id assigned to the new token (the next free integer id).
        """
        left_id, right_id = pair[0], pair[1]
        token_str = self.itos[left_id] + self.itos[right_id]
        token_id = len(self.itos)
        self.stoi[token_str] = token_id
        self.itos[token_id] = token_str
        self.max_token_length = max(self.max_token_length, len(token_str))
        return token_id

    def _replace_pairs(self, pair: tuple, new_token: int) -> List[int]:
        """Return a copy of ``data`` with each ``pair`` replaced by ``new_token``.

        Matching is greedy from left to right, so overlapping occurrences
        (e.g. ``x x x`` for the pair ``(x, x)``) are consumed in
        non-overlapping order.
        """
        data = self.data
        last_index = len(data) - 1
        first, second = pair[0], pair[1]

        merged = []
        i = 0
        while i <= last_index:
            if i < last_index and data[i] == first and data[i + 1] == second:
                merged.append(new_token)
                i += 2
            else:
                merged.append(data[i])
                i += 1
        return merged

    def _update_stats(self, merge_count: int):
        """Append the state reached after a merge to every history list.

        Args:
            merge_count: Occurrence count of the pair that was just merged.
        """
        stats = self.stats
        stats["vocab_sizes"].append(len(self.itos))
        stats["data_sizes"].append(len(self.data))
        stats["compression_ratios"].append(self.original_length / len(self.data))
        stats["merge_counts"].append(merge_count)
        # The newest token always carries the highest id.
        stats["tokens_created"].append(self.itos[len(self.itos) - 1])
        stats["max_token_lengths"].append(self.max_token_length)

    def plot_statistics(self, iteration: Optional[int] = None):
        """Show a 2x2 figure of the recorded training statistics.

        Args:
            iteration: Merge count to display in the figure title; no title
                is added when ``None``.
        """
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        if iteration is not None:
            fig.suptitle(f"Iteration {iteration}")

        # Plot training metrics
        ax1.plot(self.stats["vocab_sizes"], self.stats["data_sizes"])
        ax1.set_title("Vocabulary vs Dataset Size")

        ax2.plot(self.stats["vocab_sizes"], self.stats["compression_ratios"])
        ax2.set_title("Compression Ratio Progress")

        # The last two panels stay empty until at least one merge happened.
        if self.stats["merge_counts"]:
            ax3.hist(self.stats["merge_counts"], bins=30)
            ax3.set_title("Merge Counts Distribution")

        if self.stats["tokens_created"]:
            lengths = [len(t) for t in self.stats["tokens_created"]]
            ax4.plot(range(len(lengths)), lengths)
            ax4.set_title("Token Length Evolution")

        plt.tight_layout()
        plt.show()

    def save_to_file(self, filepath: Path):
        """Write the vocabulary and statistics to ``filepath`` as JSON.

        The encoded training data (``data``) is not part of the saved state.
        """
        state = {
            "chars": self.chars,
            "stoi": self.stoi,
            "max_token_length": self.max_token_length,
            "stats": self.stats
        }
        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(state, f, ensure_ascii=False, indent=2)

    @classmethod
    def load_from_file(cls, filepath: Path):
        """Rebuild an encoder from a file written by ``save_to_file``.

        The returned instance has the saved vocabulary and statistics but an
        empty ``data`` list, since the training text is not stored.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            state = json.load(f)

        instance = cls("")  # Create empty instance
        instance.chars = state["chars"]
        instance.stoi = state["stoi"]
        # itos is not stored; it is the inverse of stoi.
        instance.itos = {int(i): s for s, i in state["stoi"].items()}
        instance.max_token_length = state["max_token_length"]
        instance.stats = state["stats"]

        return instance

    def _print_progress(self, iteration: int):
        """Print a short report of the current training state."""
        print(f"\nIteration {iteration}:")
        print(f"Vocabulary size: {len(self.itos):,}")
        print(f"Data size: {len(self.data):,}")
        print(f"Compression ratio: {self.stats['compression_ratios'][-1]:.2f}")

        if self.stats["merge_counts"]:
            last_merge = self.stats["merge_counts"][-1]
            last_token = self.stats["tokens_created"][-1]
            print(f"Last merge count: {last_merge:,}")
            print(f"Last token created: '{last_token}'")

        print(f"Max token length: {self.max_token_length}")

class TokenizerInternal:
    """Greedy longest-match tokenizer over a trained BPE vocabulary.

    The token strings of the encoder are stored in a prefix tree; text is
    split by repeatedly taking the longest vocabulary token that starts at
    the current position.
    """

    def __init__(self, encoder: BytePairEncoder):
        """Take the vocabulary of ``encoder`` and index it in a trie.

        Args:
            encoder: Encoder whose ``stoi`` and ``max_token_length`` are used.
        """
        self.stoi = encoder.stoi
        self.max_token_length = encoder.max_token_length
        self._trie = self._build_trie()

    def _build_trie(self) -> TrieNode:
        """Build trie for efficient tokenization.

        Returns:
            The root node; every token string in ``self.stoi`` is spelled
            out by a path from it, ending in a node flagged ``is_token``.
        """
        root = TrieNode()
        for token in self.stoi:
            node = root
            for char in token:
                if char not in node.children:
                    node.children[char] = TrieNode()
                node = node.children[char]
            node.is_token = True
            node.token = token
        return root

    def tokenize(self, text: str) -> List[str]:
        """Split ``text`` into vocabulary tokens by greedy longest match.

        A character that starts no known token is emitted as a
        single-character token, so the pieces always concatenate back to
        ``text``.

        Returns:
            The list of token strings, in order; empty for empty ``text``.
        """
        tokens = []
        pos = 0
        text_length = len(text)
        while pos < text_length:
            # Match in place instead of slicing off the remaining text on
            # every step, which copied the whole tail each time.
            token = self._find_longest_token(text, pos)
            tokens.append(token)
            pos += len(token)
        return tokens

    def _find_longest_token(self, text: str, start: int = 0) -> str:
        """Find the longest vocabulary token beginning at ``text[start]``.

        Args:
            text: The string being tokenized.
            start: Index at which the match must begin (defaults to the
                beginning of ``text``).

        Returns:
            The longest matching token, or the single character
            ``text[start]`` when no token matches there.

        Raises:
            IndexError: If ``start`` is not a valid index into ``text``
                (e.g. ``text`` is empty).
        """
        node = self._trie
        # Fallback for characters that are not in the vocabulary.
        longest = text[start]

        # No token is longer than max_token_length, so stop looking there.
        stop = min(len(text), start + self.max_token_length)
        for index in range(start, stop):
            node = node.children.get(text[index])
            if node is None:
                break
            if node.is_token:
                longest = node.token

        return longest