"""
Minimal (byte-level) Byte Pair Encoding tokenizer.
Unlike RegexTokenizer:
- Operates on integer codes from an encodec codebook.
"""
import regex as re
from .base import Tokenizer, get_stats, merge
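
# Rough usage sketch (illustrative only; the codebook_size / vocab_size values and
# the "5 6 5 6 7 5 6" code string below are made up, not MARS5 settings):
#
#   tok = CodebookTokenizer(codebook_size=1024)
#   tok.train("5 6 5 6 7 5 6", vocab_size=1024 + 2)   # learn 2 merges over the base codes
#   ids = tok.encode_ordinary("5 6 5 6 7 5 6")        # -> merged token ids, some >= 1024
#   codes = tok.decode_int(ids)                       # -> [5, 6, 5, 6, 7, 5, 6]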

class CodebookTokenizer(Tokenizer):

    def __init__(self, pattern=None, codebook_size=1024):
        """
        - pattern: optional regex split pattern; the code stream is not split by
          default, so this is unused unless explicitly provided
        - codebook_size: number of base codes in the encodec codebook; merged
          tokens get ids starting at codebook_size
        """
        self.merges = {} # (int, int) -> int
        self.pattern = pattern
        self.compiled_pattern = re.compile(self.pattern) if self.pattern is not None else None
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.inverse_special_tokens = {}
        self.codebook_size = codebook_size
        self.vocab = self._build_vocab() # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        """Train on a space-separated string of codebook indices, e.g. "12 873 5"."""
        assert vocab_size >= self.codebook_size
        num_merges = vocab_size - self.codebook_size

        # the code stream is a continuous signal, so there is no regex splitting:
        # the whole input is treated as a single chunk
        text_chunks = [text,] # re.findall(self.compiled_pattern, text)

        # input preprocessing: parse the space-separated indices into ints
        ids = [[int(idx) for idx in ch.split(' ')] for ch in text_chunks]

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        # each base code is rendered as a zero-padded byte string like b" 0012",
        # so decoded token strings can be split back into integers later
        vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)} # idx -> bytes
        for i in range(num_merges):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunk_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunk_ids, stats)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = self.codebook_size + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables
        self.merges = merges # used in encode()
        self.vocab = vocab # used in decode()
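
    # Worked example of one training round (hypothetical codes, codebook_size=1024):
    # with ids = [[5, 6, 5, 6, 7]] the pair (5, 6) is the most frequent (2 occurrences),
    # so it is minted as token 1024, merges[(5, 6)] = 1024, and ids become
    # [[1024, 1024, 7]]; vocab[1024] = b" 0005 0006" is the concatenation of the two
    # rendered base codes, which decode() later stitches back into a code string.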

    def register_special_tokens(self, special_tokens):
        # special_tokens is a dictionary of str -> int
        # example: {"<|endoftext|>": 100257}
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

    def decode(self, ids):
        # given ids (list of integers), return Python string
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def decode_int(self, ids) -> list:
        # decode to the space-separated string form, then recover the integer
        # codes; any registered special tokens are kept as strings
        ret: str = self.decode(ids)
        for s in self.special_tokens:
            ret = ret.replace(s, ' ' + s + ' ')
        ret = ret.strip()
        ret = [int(t) if t[0].isnumeric() else t for t in ret.split(' ') if len(t) > 0]
        return ret
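
    # Example (assuming merges {(5, 6): 1024} and a registered special token
    # {'<|endoftext|>': 4096}; both values are illustrative):
    #   decode([1024, 7])        -> " 0005 0006 0007"
    #   decode_int([1024, 7])    -> [5, 6, 7]
    #   decode_int([1024, 4096]) -> [5, 6, '<|endoftext|>']  # specials stay as strings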

    def _encode_chunk(self, chunk_ids):
        # given a list of integer codes, greedily apply the learned merges
        # and return the resulting token ids
        ids = list(chunk_ids)
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list, arbitrarily
            # we can detect this terminating case by a membership check
            if pair not in self.merges:
                break # nothing else can be merged anymore
            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids

    def encode_ordinary(self, text):
        """Encoding that ignores any special tokens."""
        # the code stream is treated as a single chunk (no regex splitting)
        text_chunks = [text,] # re.findall(self.compiled_pattern, text)
        # all chunks are encoded separately, then results are joined
        ids = []
        for chunk in text_chunks:
            chunk_ids = [int(idx) for idx in chunk.split(' ')]
            chunk_ids = self._encode_chunk(chunk_ids)
            ids.extend(chunk_ids)
        return ids
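
    # Example (with the illustrative merges {(5, 6): 1024} from above):
    #   encode_ordinary("5 6 7 5 6") -> [1024, 7, 1024]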

    def encode(self, text, allowed_special="none_raise"):
        """
        Unlike encode_ordinary, this function handles special tokens.
        allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
        if none_raise, then an error is raised if any special token is encountered in text
        this is the default tiktoken behavior right now as well
        any other behavior is either annoying, or a major footgun
        """
        # decode the user desire w.r.t. handling of special tokens
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)
        # otherwise, we have to be careful with potential special tokens in text
        # we handle special tokens by splitting the text
        # based on the occurrence of any exact match with any of the special tokens
        # we can use re.split for this. note that surrounding the pattern with ()
        # makes it into a capturing group, so the special tokens will be included
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        special_chunks = re.split(special_pattern, text)
        # now all the special tokens are separated from the rest of the text
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for part in special_chunks:
            part = part.strip()
            if len(part) == 0: continue
            if part in special:
                # this is a special token, encode it separately as a special case
                ids.append(special[part])
            else:
                # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
        return ids
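
    # Example (illustrative values; assumes register_special_tokens({'<|endoftext|>': 4096})
    # was called and merges contain {(5, 6): 1024}):
    #   encode("5 6 <|endoftext|> 7", allowed_special="all") -> [1024, 4096, 7]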

    def load(self, model_file):
        """Inverse of save() but only for the model file"""
        model_file = str(model_file)
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = self.codebook_size
        with open(model_file, 'r', encoding="utf-8") as f:
            # read the version
            version = f.readline().strip()
            assert version == "minbpe v1"
            # read the pattern
            self.pattern = f.readline().strip()
            # read the special tokens
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)
            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()
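
    # As implied by the reads above, the .model file is expected to look like:
    #   minbpe v1
    #   <split pattern>
    #   <number of special tokens>
    #   <special token> <id>          (one line per special token)
    #   <idx1> <idx2>                 (one line per merge, in merge order)
    # merged token ids are not stored; they are re-assigned sequentially from codebook_size.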

    def _build_vocab(self):
        # vocab is simply and deterministically derived from merges
        vocab = {idx: f" {idx:04d}".encode('utf-8') for idx in range(self.codebook_size)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab
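

if __name__ == "__main__":
    # Minimal smoke test / usage sketch. The code string, sizes and the special
    # token id below are made-up illustrations, not MARS5 settings. Run as a
    # module so the relative import of .base resolves, e.g.
    # `python -m mars5.minbpe.codebook` from the repository root.
    codes = "5 6 5 6 7 5 6"
    tok = CodebookTokenizer(codebook_size=1024)
    tok.train(codes, vocab_size=1024 + 2, verbose=True)   # learn 2 merges
    tok.register_special_tokens({"<|endoftext|>": 4096})  # hypothetical special id
    ids = tok.encode(codes + " <|endoftext|>", allowed_special="all")
    print("encoded:", ids)
    print("decoded:", tok.decode_int(ids))
    assert tok.decode_int(ids)[:-1] == [int(c) for c in codes.split()]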