|
""" |
|
Minimal (codebook-level) Byte Pair Encoding tokenizer.
|
|
|
Unlike RegexTokenizer:

- Operates on integer codes from an encodec codebook rather than raw UTF-8 bytes.
- Does not regex-split the input; the whole sequence is treated as a single chunk
  of whitespace-separated indices.
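
Illustrative usage (assumes the input "text" is a whitespace-separated string
of integer codebook indices, e.g. a flattened encodec code stream):

    tok = CodebookTokenizer(codebook_size=1024)
    tok.train("12 97 12 97 12 97 403", vocab_size=1026)
    ids = tok.encode_ordinary("12 97 403")   # [1024, 403] with the merges above
    codes = tok.decode_int(ids)              # [12, 97, 403]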
|
""" |
|
|
|
import regex as re |
|
from .base import Tokenizer, get_stats, merge |
|
|
|
|
|
class CodebookTokenizer(Tokenizer): |
|
|
|
def __init__(self, pattern=None, codebook_size=1024): |
|
""" |
|
- pattern: optional string to override the default (GPT-4 split pattern) |
|
- special_tokens: str -> int dictionary of special tokens |
|
example: {'<|endoftext|>': 100257} |
|
""" |
|
self.merges = {} |
|
self.pattern = pattern |
|
        # no split pattern is required for codebook input; only compile one if it was given
        self.compiled_pattern = re.compile(self.pattern) if self.pattern is not None else None
|
self.special_tokens = {} |
|
self.inverse_special_tokens = {} |
|
self.codebook_size = codebook_size |
|
self.vocab = self._build_vocab() |
|
|
|
def train(self, text, vocab_size, verbose=False): |
|
assert vocab_size >= self.codebook_size |
|
num_merges = vocab_size - self.codebook_size |
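        # the input "text" is expected to be a whitespace-separated string of
        # integer codebook indices; no regex splitting is applied, the whole
        # sequence is treated as a single chunk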
|
|
|
|
|
|
|
text_chunks = [text,] |
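        # parse each chunk into a list of integer codebook ids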
|
|
|
|
|
ids = [[int(idx) for idx in ch.split(' ')] for ch in text_chunks] |
|
|
|
|
|
merges = {} |
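        # base vocab: each codebook index maps to the bytes of " %04d", i.e. a
        # leading space plus its zero-padded decimal string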
|
|
|
vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)} |
|
|
|
for i in range(num_merges): |
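            # count the number of times every consecutive pair appears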
|
|
|
stats = {} |
|
for chunk_ids in ids: |
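                # passing in stats will update it in place, adding up counts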
|
|
|
get_stats(chunk_ids, stats) |
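            # find the pair with the highest count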
|
|
|
pair = max(stats, key=stats.get) |
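            # mint a new token: assign it the next available id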
|
|
|
idx = self.codebook_size + i |
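            # replace all occurrences of pair in ids with idx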
|
|
|
ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids] |
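            # save the merge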
|
|
|
merges[pair] = idx |
|
vocab[idx] = vocab[pair[0]] + vocab[pair[1]] |
|
|
|
if verbose: |
|
print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences") |
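        # save class variables, used by encode()/decode()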
|
|
|
|
|
self.merges = merges |
|
self.vocab = vocab |
|
|
|
def register_special_tokens(self, special_tokens): |
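        # special_tokens is a dict of str -> int
        # example: {"<|endoftext|>": 100257}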
|
|
|
|
|
self.special_tokens = special_tokens |
|
self.inverse_special_tokens = {v: k for k, v in special_tokens.items()} |
|
|
|
def decode(self, ids): |
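        # given ids (list of integers), return the decoded Python string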
|
|
|
part_bytes = [] |
|
for idx in ids: |
|
if idx in self.vocab: |
|
part_bytes.append(self.vocab[idx]) |
|
elif idx in self.inverse_special_tokens: |
|
part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8")) |
|
else: |
|
raise ValueError(f"invalid token id: {idx}") |
|
text_bytes = b"".join(part_bytes) |
|
text = text_bytes.decode("utf-8", errors="replace") |
|
return text |
|
|
|
    def decode_int(self, ids) -> list:
        """Decode token ids back to integer codebook indices; any special
        tokens present are returned as their string form."""
        text = self.decode(ids)
        # pad special tokens with spaces so they split cleanly from the indices
        for s in self.special_tokens:
            text = text.replace(s, ' ' + s + ' ')
        tokens = [t for t in text.strip().split(' ') if len(t) > 0]
        return [int(t) if t[0].isnumeric() else t for t in tokens]
|
|
|
    def _encode_chunk(self, chunk_ids):
        # chunk_ids is a list of integer codebook indices for one chunk;
        # return the token ids after applying all learned merges
        ids = list(chunk_ids)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
|
|
|
def encode_ordinary(self, text): |
|
"""Encoding that ignores any special tokens.""" |
|
|
|
text_chunks = [text,] |
|
|
|
ids = [] |
|
for chunk in text_chunks: |
|
|
|
chunk_ids = [int(idx) for idx in chunk.split(' ')] |
|
chunk_ids = self._encode_chunk(chunk_ids) |
|
ids.extend(chunk_ids) |
|
return ids |
|
|
|
def encode(self, text, allowed_special="none_raise"): |
|
""" |
|
Unlike encode_ordinary, this function handles special tokens. |
|
allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens |
|
if none_raise, then an error is raised if any special token is encountered in text |
|
this is the default tiktoken behavior right now as well |
|
any other behavior is either annoying, or a major footgun |
|
""" |
|
|
|
special = None |
|
if allowed_special == "all": |
|
special = self.special_tokens |
|
elif allowed_special == "none": |
|
special = {} |
|
elif allowed_special == "none_raise": |
|
special = {} |
|
assert all(token not in text for token in self.special_tokens) |
|
elif isinstance(allowed_special, set): |
|
special = {k: v for k, v in self.special_tokens.items() if k in allowed_special} |
|
else: |
|
raise ValueError(f"allowed_special={allowed_special} not understood") |
|
if not special: |
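            # shortcut: if there are no special tokens, just use the ordinary encoding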
|
|
|
return self.encode_ordinary(text) |
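        # otherwise, we have to be careful with potential special tokens in text
        # we handle special tokens by splitting the text
        # based on the occurrence of any exact match with any of the special tokens
        # we can use re.split for this. note that surrounding the pattern with ()
        # makes it into a capturing group, so the special tokens will be included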
|
|
|
|
|
|
|
|
|
|
|
special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")" |
|
special_chunks = re.split(special_pattern, text) |
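        # now all the special tokens are separated from the rest of the text
        # all chunks are encoded separately, then the results are joined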
|
|
|
|
|
ids = [] |
|
for part in special_chunks: |
|
part = part.strip() |
|
if len(part) == 0: continue |
|
if part in special: |
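                # this is a special token, encode it separately as a special case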
|
|
|
ids.append(special[part]) |
|
else: |
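                # this is an ordinary sequence of codebook indices, encode it normally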
|
|
|
ids.extend(self.encode_ordinary(part)) |
|
return ids |
|
|
|
|
|
def load(self, model_file): |
|
"""Inverse of save() but only for the model file""" |
|
model_file = str(model_file) |
|
assert model_file.endswith(".model") |
|
|
|
merges = {} |
|
special_tokens = {} |
|
idx = self.codebook_size |
|
with open(model_file, 'r', encoding="utf-8") as f: |
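            # read the version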
|
|
|
version = f.readline().strip() |
|
assert version == "minbpe v1" |
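            # read the pattern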
|
|
|
self.pattern = f.readline().strip() |
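            # read the special tokens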
|
|
|
num_special = int(f.readline().strip()) |
|
for _ in range(num_special): |
|
special, special_idx = f.readline().strip().split() |
|
special_tokens[special] = int(special_idx) |
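            # read the merges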
|
|
|
for line in f: |
|
|
|
idx1, idx2 = map(int, line.split()) |
|
merges[(idx1, idx2)] = idx |
|
idx += 1 |
|
self.merges = merges |
|
self.special_tokens = special_tokens |
|
self.vocab = self._build_vocab() |
|
|
|
|
|
def _build_vocab(self): |
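        # vocab is deterministically derived from codebook_size, merges and special tokens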
|
|
|
vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)} |
|
for (p0, p1), idx in self.merges.items(): |
|
vocab[idx] = vocab[p0] + vocab[p1] |
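        # special tokens are added on top and never participate in merges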
|
for special, idx in self.special_tokens.items(): |
|
vocab[idx] = special.encode("utf-8") |
|
return vocab |