import os
from pathlib import Path

import tiktoken
from tiktoken.load import load_tiktoken_bpe


class Llama3Tokenizer:
    """Thin wrapper around tiktoken that keeps track of Llama-3 special IDs."""

    def __init__(self, model_path):
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)

        mergeable = load_tiktoken_bpe(model_path)

        # Named special tokens used by the Llama-3 chat format
        self.special = {
            "<|begin_of_text|>": 128000,
            "<|end_of_text|>": 128001,
            "<|start_header_id|>": 128006,
            "<|end_header_id|>": 128007,
            "<|eot_id|>": 128009,
        }
        # Fill the remaining reserved special-token slots starting at 128002,
        # skipping IDs already claimed by the named tokens above
        self.special.update({f"<|reserved_{i}|>": 128002 + i
                             for i in range(256)
                             if 128002 + i not in self.special.values()})

        self.model = tiktoken.Encoding(
            name=Path(model_path).name,
            pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)"
                    r"|[^\r\n\p{L}\p{N}]?\p{L}+"
                    r"|\p{N}{1,3}"
                    r"| ?[^\s\p{L}\p{N}]+[\r\n]*"
                    r"|\s*[\r\n]+"
                    r"|\s+(?!\S)"
                    r"|\s+",
            mergeable_ranks=mergeable,
            special_tokens=self.special,
        )

    def encode(self, text, bos=False, eos=False, allowed_special=set()):
        ids: list[int] = []

        if bos:
            # Prepend the beginning-of-text token ID
            ids.append(self.special["<|begin_of_text|>"])

        ids.extend(
            self.model.encode(
                text,
                allowed_special=allowed_special,
            )
        )

        if eos:
            # Append the end-of-text token ID
            ids.append(self.special["<|end_of_text|>"])

        return ids

    def decode(self, ids):
        return self.model.decode(ids)


class ChatFormat:

    def __init__(self, tokenizer: Llama3Tokenizer, *,
                 default_system="You are a helpful assistant."):
        self.tok = tokenizer
        self.default_system = default_system

    def _header(self, role):
        """Encode <|start_header_id|>role<|end_header_id|>\n\n"""
        return (
            [self.tok.special["<|start_header_id|>"]]
            + self.tok.encode(role)
            + [self.tok.special["<|end_header_id|>"]]
            + self.tok.encode("\n\n")
        )

    def encode(self, user_message, system_message=None, allowed_special=None):
        sys_msg = system_message if system_message is not None else self.default_system
        if allowed_special is None:
            allowed_special = set()

        ids = [self.tok.special["<|begin_of_text|>"]]

        # System message block
        ids += self._header("system")
        ids += self.tok.encode(sys_msg, allowed_special=allowed_special)
        ids += [self.tok.special["<|eot_id|>"]]

        # User message block
        ids += self._header("user")
        ids += self.tok.encode(user_message)
        ids += [self.tok.special["<|eot_id|>"]]

        # Assistant header: the model's reply is generated after this
        ids += self._header("assistant")

        return ids

    def decode(self, ids):
        return self.tok.decode(ids)


def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
    """Return only the assistant reply that follows the final assistant header."""
    index = text.find(header_end)

    if index != -1:
        # Strip everything up to and including the assistant header
        return text[index + len(header_end):].strip()
    else:
        return text
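

# Minimal usage sketch (assumptions: a Llama-3 tiktoken BPE file such as
# "tokenizer.model" has been downloaded locally; the path below is a
# placeholder, not part of this module).
if __name__ == "__main__":
    tokenizer = Llama3Tokenizer("tokenizer.model")  # hypothetical local path
    chat = ChatFormat(tokenizer)

    # Build the prompt token IDs for a single-turn conversation
    prompt_ids = chat.encode("What is the capital of France?")

    # Round-trip back to text; the decoded string still contains the special
    # tokens, so clean_text() can be applied to generated output to keep only
    # the assistant reply
    print(prompt_ids[:8])
    print(chat.decode(prompt_ids))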