Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import re | |
import logging | |
import json | |
import os | |
from pathlib import Path | |
from typing import Any, Callable, Sequence, Mapping, Iterable, Protocol, ClassVar, runtime_checkable | |
from sentencepiece import SentencePieceProcessor | |
import gguf | |
from .gguf_writer import GGUFWriter | |
logger = logging.getLogger(__name__)


class SpecialVocab:
    """Collect special-token metadata from a Hugging Face style model directory.

    Sources, in load order:

    * ``tokenizer.json`` -- BPE merges (only when ``load_merges`` is set) and the
      ``added_tokens`` table used to resolve special-token strings to ids;
    * ``tokenizer_config.json`` -- ``chat_template``, ``add_<type>_token`` flags
      and ``<type>_token`` strings;
    * ``config.json`` -- ``<type>_token_id`` values;
    * ``merges.txt`` -- fallback merges when none were found in ``tokenizer.json``.

    The first id recorded for a special token type wins. The collected data is
    written to a GGUF file by :meth:`add_to_gguf`.
    """

    merges: list[str]
    add_special_token: dict[str, bool]
    special_token_ids: dict[str, int]
    chat_template: str | Sequence[Mapping[str, str]] | None

    def __init__(
        self, path: str | os.PathLike[str], load_merges: bool = False,
        special_token_types: Iterable[str] | None = None,
        n_vocab: int | None = None,
    ):
        """Load special-token metadata from the directory *path*.

        Args:
            path: model directory containing the tokenizer/config JSON files.
            load_merges: also collect BPE merges.
            special_token_types: token type names to look up (e.g. ``'bos'``);
                defaults to bos/eos/unk/sep/pad/cls/mask.
            n_vocab: if given, special token ids at or above this value are skipped.

        Raises:
            ValueError: if a negative special token id is encountered.
        """
        self.special_token_ids = {}
        self.add_special_token = {}
        self.n_vocab = n_vocab
        self.load_merges = load_merges
        self.merges = []
        self.chat_template = None
        if special_token_types is not None:
            # Materialise once: the types are iterated by more than one loader, so a
            # one-shot iterator (e.g. a generator) would be empty for the later ones.
            self.special_token_types = tuple(special_token_types)
        else:
            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad', 'cls', 'mask')
        self._load(Path(path))

    def __repr__(self) -> str:
        return '<SpecialVocab with {} merges, special tokens {}, add special tokens {}>'.format(
            len(self.merges), self.special_token_ids or "unset", self.add_special_token or "unset",
        )

    def add_to_gguf(self, gw: GGUFWriter, quiet: bool = False) -> None:
        """Write the collected merges, special token ids, add-token flags and chat template to *gw*.

        Entries for which *gw* has no ``add_<type>_token_id`` / ``add_add_<type>_token``
        method are skipped with a warning. When ``quiet`` is true, the informational
        log lines are suppressed; warnings are still emitted.
        """
        if self.merges:
            if not quiet:
                logger.info(f'Adding {len(self.merges)} merge(s).')
            gw.add_token_merges(self.merges)
        elif self.load_merges:
            logger.warning('Adding merges requested but no merges found, output may be non-functional.')
        for typ, tokid in self.special_token_ids.items():
            id_handler: Callable[[int], None] | None = getattr(gw, f'add_{typ}_token_id', None)
            if id_handler is None:
                logger.warning(f'No handler for special token type {typ} with id {tokid} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting special token type {typ} to {tokid}')
            id_handler(tokid)
        for typ, value in self.add_special_token.items():
            add_handler: Callable[[bool], None] | None = getattr(gw, f'add_add_{typ}_token', None)
            if add_handler is None:
                logger.warning(f'No handler for add_{typ}_token with value {value} - skipping')
                continue
            if not quiet:
                logger.info(f'Setting add_{typ}_token to {value}')
            add_handler(value)
        if self.chat_template is not None:
            if not quiet:
                logger.info(f'Setting chat_template to {self.chat_template}')
            gw.add_chat_template(self.chat_template)

    def _load(self, path: Path) -> None:
        """Run every loader against *path*; ``merges.txt`` is only a fallback."""
        self._try_load_from_tokenizer_json(path)
        self._try_load_from_config_json(path)
        if self.load_merges and not self.merges:
            self._try_load_merges_txt(path)

    def _try_load_merges_txt(self, path: Path) -> bool:
        """Read ``merges.txt`` into ``self.merges``; return False if the file is absent.

        A leading ``#`` line (version header) is skipped, blank lines are ignored,
        and lines that do not split into exactly two fields are logged and ignored.
        """
        merges_file = path / 'merges.txt'
        if not merges_file.is_file():
            return False
        with open(merges_file, 'r', encoding = 'utf-8') as fp:
            first_line = next(fp, '').strip()
            if not first_line.startswith('#'):
                fp.seek(0)
                line_num = 0
            else:
                line_num = 1
            merges = []
            for line in fp:
                line_num += 1
                line = line.strip()
                if not line:
                    continue
                parts = line.split(None, 3)
                if len(parts) != 2:
                    logger.warning(f'{merges_file.name}: Line {line_num}: Entry malformed, ignoring')
                    continue
                merges.append(f'{parts[0]} {parts[1]}')
        self.merges = merges
        return True

    def _set_special_token(self, typ: str, tid: Any) -> None:
        """Record *tid* for *typ* unless it is not an int, out of range, or *typ* is already set.

        Raises:
            ValueError: if *tid* is a negative int.
        """
        if not isinstance(tid, int):
            return
        if tid < 0:
            raise ValueError(f'invalid value for special token type {typ}: {tid}')
        if self.n_vocab is None or tid < self.n_vocab:
            if typ in self.special_token_ids:
                return
            self.special_token_ids[typ] = tid
            return
        logger.warning(f'Special token type {typ}, id {tid} out of range, must be under {self.n_vocab} - skipping')

    @staticmethod
    def _normalize_merges(merges: Any) -> list[str] | None:
        """Return *merges* as a list of ``"left right"`` strings, or None if unrecognised.

        Accepts the legacy form (a list of ``"left right"`` strings) and the pair form
        (a list of ``[left, right]`` lists). In the pair form, spaces inside a part
        are encoded as ``chr(ord(' ') + 256)`` so the joined string stays splittable.
        """
        if not isinstance(merges, list) or not merges:
            return None
        first = merges[0]
        if isinstance(first, str):
            return merges
        if isinstance(first, list) and len(first) == 2 and isinstance(first[0], str):
            if any(' ' in part for pair in merges for part in pair):
                logger.warning(f'Spaces in merges detected, encoding as {chr(ord(" ") + 256)!r}')
            return [
                ' '.join(
                    ''.join(chr(ord(c) + 256) if c == ' ' else c for c in part)
                    for part in pair
                )
                for pair in merges
            ]
        return None

    def _try_load_from_tokenizer_json(self, path: Path) -> bool:
        """Load merges, chat template, add-token flags and special ids from the tokenizer files.

        ``tokenizer.json`` supplies merges and the ``added_tokens`` table;
        ``tokenizer_config.json`` supplies the rest. Special token strings are
        resolved to ids through the first matching ``added_tokens`` entry.
        Returns True unless nothing could be attempted (always True as written).
        """
        tokenizer_file = path / 'tokenizer.json'
        added_tokens: list[dict[str, Any]] = []
        if tokenizer_file.is_file():
            with open(tokenizer_file, encoding = 'utf-8') as f:
                tokenizer = json.load(f)
            if self.load_merges:
                merges = self._normalize_merges((tokenizer.get('model') or {}).get('merges'))
                if merges is not None:
                    self.merges = merges
            # ``added_tokens`` may be absent or null; both mean "no added tokens".
            added_tokens = tokenizer.get('added_tokens') or []
        tokenizer_config_file = path / 'tokenizer_config.json'
        if not tokenizer_config_file.is_file():
            return True
        with open(tokenizer_config_file, encoding = 'utf-8') as f:
            tokenizer_config = json.load(f)
        chat_template = tokenizer_config.get('chat_template')
        if chat_template is None or isinstance(chat_template, (str, list)):
            self.chat_template = chat_template
        else:
            logger.warning(f'Bad type for chat_template field in {tokenizer_config_file!r} - ignoring')
        for typ in self.special_token_types:
            add_entry = tokenizer_config.get(f'add_{typ}_token')
            if isinstance(add_entry, bool):
                self.add_special_token[typ] = add_entry
            entry = tokenizer_config.get(f'{typ}_token')
            if isinstance(entry, str):
                tc_content = entry
            elif isinstance(entry, dict):
                entry_content = entry.get('content')
                if not isinstance(entry_content, str):
                    continue
                tc_content = entry_content
            else:
                continue
            # We only need the first match here.
            maybe_token_id = next(
                (atok.get('id') for atok in added_tokens if atok.get('content') == tc_content),
                None,
            )
            self._set_special_token(typ, maybe_token_id)
        return True

    def _try_load_from_config_json(self, path: Path) -> bool:
        """Read ``<type>_token_id`` values from ``config.json``; return False if it is absent."""
        config_file = path / 'config.json'
        if not config_file.is_file():
            return False
        with open(config_file, encoding = 'utf-8') as f:
            config = json.load(f)
        for typ in self.special_token_types:
            self._set_special_token(typ, config.get(f'{typ}_token_id'))
        return True
class BaseVocab(Protocol):
    """Minimal structural interface shared by every vocabulary kind in this module."""

    # Tokenizer model identifier; the implementations here use e.g. "gpt2" or "llama".
    tokenizer_model: ClassVar[str]
    # Short name of the vocabulary kind; the implementations here use e.g. "bpe" or "spm".
    name: ClassVar[str]
class Vocab(BaseVocab, Protocol):
    """Structural interface for a vocabulary that is loaded from a model directory."""

    # Total number of tokens, including added tokens.
    vocab_size: int
    # Added token text -> token id.
    added_tokens_dict: dict[str, int]
    # Added token texts in id order.
    added_tokens_list: list[str]
    # Path of the tokenizer file the vocabulary was read from.
    fname_tokenizer: Path

    def __init__(self, base_path: Path):
        """Load the vocabulary found under *base_path*."""
        ...

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield ``(text, score, token_type)`` for every token."""
        ...
class NoVocab(BaseVocab):
    """Placeholder used for models that carry no integrated vocabulary."""

    name = "no_vocab"
    tokenizer_model = "no_vocab"

    def __repr__(self) -> str:
        description = "<NoVocab for a model without integrated vocabulary>"
        return description
class BpeVocab(Vocab):
    """GPT-2 style byte-level BPE vocabulary.

    Loaded from a "slow" ``vocab.json`` (plus an optional ``added_tokens.json``)
    when present. Otherwise it is loaded from a "fast" ``tokenizer.json``, which
    must describe a BPE model without byte fallback and with a ``ByteLevel`` decoder.
    """

    tokenizer_model = "gpt2"
    name = "bpe"

    def __init__(self, base_path: Path):
        """Load the vocabulary from the directory *base_path*.

        Raises:
            FileNotFoundError: if no usable GPT-2 BPE tokenizer files are found,
                including when ``tokenizer.json`` describes a different tokenizer.
            ValueError: if the added-token ids do not directly follow the base
                vocabulary ids.
        """
        added_tokens: dict[str, int] = {}

        if (fname_tokenizer := base_path / 'vocab.json').exists():
            # "slow" tokenizer
            with open(fname_tokenizer, encoding="utf-8") as f:
                self.vocab = json.load(f)

            try:
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
            # An entry that repeats a main-vocabulary token with the same id is a pure
            # duplicate; drop it so it is not counted twice. An entry that reuses a
            # vocabulary string with a *different* id is kept, and the sequential-id
            # check below rejects it.
            added_tokens = {
                text: idx for text, idx in added_tokens.items()
                if self.vocab.get(text) != idx
            }
        else:
            # "fast" tokenizer
            fname_tokenizer = base_path / 'tokenizer.json'

            # if this fails, FileNotFoundError propagates to caller
            with open(fname_tokenizer, encoding="utf-8") as f:
                tokenizer_json = json.load(f)

            tokenizer_model: dict[str, Any] = tokenizer_json['model']
            # A missing or null decoder cannot be ByteLevel: report it as "not this
            # tokenizer" rather than crashing on the lookup.
            decoder = tokenizer_json.get('decoder') or {}
            if (
                tokenizer_model['type'] != 'BPE' or tokenizer_model.get('byte_fallback', False)
                or decoder.get('type') != 'ByteLevel'
            ):
                raise FileNotFoundError('Cannot find GPT-2 BPE tokenizer')

            self.vocab = tokenizer_model["vocab"]

            if (added := tokenizer_json.get('added_tokens')) is not None:
                # Added tokens here can be duplicates of the main vocabulary.
                added_tokens = {item['content']: item['id']
                                for item in added
                                if item['content'] not in self.vocab}

        vocab_size = len(self.vocab)
        expected_ids = list(range(vocab_size, vocab_size + len(added_tokens)))
        actual_ids = sorted(added_tokens.values())
        if expected_ids != actual_ids:
            expected_end_id = vocab_size + len(actual_ids) - 1
            raise ValueError(f"Expected the {len(actual_ids)} added token ID(s) to be sequential in the range "
                             f"{vocab_size} - {expected_end_id}; got {actual_ids}")

        items = sorted(added_tokens.items(), key=lambda text_idx: text_idx[1])
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [text for (text, idx) in items]
        self.vocab_size_base = vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def bpe_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield ``(text, 0.0, NORMAL)`` for each base token in id order.

        NOTE: *text* is the ``str`` key taken from ``self.vocab``; it is not encoded
        to bytes, despite the annotation. Ids are assumed to be exactly
        ``0 .. len(vocab) - 1``. A gap raises ``KeyError``.
        """
        reverse_vocab = {idx: encoded_tok for encoded_tok, idx in self.vocab.items()}

        for i in range(len(self.vocab)):
            yield reverse_vocab[i], 0.0, gguf.TokenType.NORMAL

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield each added token as UTF-8 bytes with score -1000.0 and type CONTROL."""
        for text in self.added_tokens_list:
            score = -1000.0
            yield text.encode("utf-8"), score, gguf.TokenType.CONTROL

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield the base tokens followed by the added tokens."""
        yield from self.bpe_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<BpeVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class SentencePieceVocab(Vocab):
    """Vocabulary read from a SentencePiece ``tokenizer.model`` file.

    The model file is looked up in *base_path* first and then in its parent
    directory. A sibling ``added_tokens.json`` is consulted only in the first
    case. Its entries whose ids lie beyond the SentencePiece vocabulary become
    extra user-defined tokens.
    """

    tokenizer_model = "llama"
    name = "spm"

    def __init__(self, base_path: Path):
        """Load ``tokenizer.model`` and the optional ``added_tokens.json``.

        Raises:
            FileNotFoundError: if ``tokenizer.model`` exists in neither *base_path*
                nor its parent.
            ValueError: if the extra token ids do not directly follow the
                SentencePiece vocabulary.
        """
        added_tokens: dict[str, int] = {}

        fname_tokenizer = base_path / 'tokenizer.model'
        if fname_tokenizer.exists():
            # normal location: an added_tokens.json may sit next to the model
            try:
                with open(base_path / 'added_tokens.json', encoding="utf-8") as f:
                    added_tokens = json.load(f)
            except FileNotFoundError:
                pass
        else:
            # alternate location: the parent directory
            fname_tokenizer = base_path.parent / 'tokenizer.model'
            if not fname_tokenizer.exists():
                raise FileNotFoundError('Cannot find tokenizer.model')

        self.sentencepiece_tokenizer = SentencePieceProcessor()
        self.sentencepiece_tokenizer.LoadFromFile(str(fname_tokenizer))
        base_size = self.sentencepiece_tokenizer.vocab_size()

        # Only ids past the SentencePiece vocabulary are genuinely new pieces.
        extra_pieces = {
            tok_id: piece
            for piece, tok_id in added_tokens.items()
            if tok_id >= base_size
        }
        actual_new_ids = sorted(extra_pieces)
        expected_new_ids = [*range(base_size, base_size + len(extra_pieces))]
        if actual_new_ids != expected_new_ids:
            raise ValueError(f"Expected new token IDs {expected_new_ids} to be sequential; got {actual_new_ids}")

        # Token pieces that were added to the base vocabulary.
        self.added_tokens_dict = added_tokens
        self.added_tokens_list = [extra_pieces[tok_id] for tok_id in actual_new_ids]
        self.vocab_size_base = base_size
        self.vocab_size = base_size + len(self.added_tokens_list)
        self.fname_tokenizer = fname_tokenizer

    def sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield ``(piece_bytes, score, token_type)`` for every SentencePiece id."""
        sp = self.sentencepiece_tokenizer
        # Checked in this order for every id; a later match overrides an earlier one,
        # so BYTE > UNUSED > CONTROL > UNKNOWN > NORMAL.
        # NOTE: I think added_tokens are user defined.
        # ref: https://github.com/google/sentencepiece/blob/master/src/sentencepiece_model.proto
        # if tokenizer.is_user_defined(i): toktype = gguf.TokenType.USER_DEFINED
        classifiers = (
            (sp.IsUnknown, gguf.TokenType.UNKNOWN),
            (sp.IsControl, gguf.TokenType.CONTROL),
            (sp.IsUnused, gguf.TokenType.UNUSED),
            (sp.IsByte, gguf.TokenType.BYTE),
        )
        for token_id in range(sp.vocab_size()):
            piece_bytes = sp.IdToPiece(token_id).encode("utf-8")
            score: float = sp.GetScore(token_id)
            toktype = gguf.TokenType.NORMAL
            for matches, kind in classifiers:
                if matches(token_id):
                    toktype = kind
            yield piece_bytes, score, toktype

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield the extra pieces as USER_DEFINED tokens with a fixed score of -1000.0."""
        yield from (
            (piece.encode("utf-8"), -1000.0, gguf.TokenType.USER_DEFINED)
            for piece in self.added_tokens_list
        )

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield the SentencePiece tokens followed by the extra pieces."""
        yield from self.sentencepiece_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return "<SentencePieceVocab with {} base tokens and {} added tokens>".format(
            self.vocab_size_base, len(self.added_tokens_list),
        )
class LlamaHfVocab(Vocab):
    """Llama-style BPE (byte-fallback) vocabulary loaded through ``transformers``.

    Requires a ``tokenizer.json`` describing a BPE model with byte fallback and a
    ``Sequence`` decoder. The tokenizer object is created with
    ``transformers.AutoTokenizer`` from local files only.
    """

    tokenizer_model = "llama"
    name = "hfft"

    def __init__(self, base_path: Path):
        """Load the tokenizer found in *base_path*.

        Raises:
            FileNotFoundError: if ``tokenizer.json`` is missing or does not describe
                a Llama byte-fallback BPE tokenizer.
            TypeError: if the tokenizer looks like Llama 3 (must use ``BpeVocab``).
            ImportError: if ``transformers`` is not installed.
        """
        fname_tokenizer = base_path / 'tokenizer.json'
        # if this fails, FileNotFoundError propagates to caller
        with open(fname_tokenizer, encoding='utf-8') as f:
            tokenizer_json = json.load(f)

        # pre-check so we know if we need transformers
        tokenizer_model: dict[str, Any] = tokenizer_json['model']
        is_llama3 = (
            tokenizer_model['type'] == 'BPE' and tokenizer_model.get('ignore_merges', False)
            and not tokenizer_model.get('byte_fallback', True)
        )
        if is_llama3:
            raise TypeError('Llama 3 must be converted with BpeVocab')

        # A missing or null decoder cannot be the expected ``Sequence``: report it as
        # "not this tokenizer" rather than crashing on the lookup.
        decoder = tokenizer_json.get('decoder') or {}
        if (
            tokenizer_model['type'] != 'BPE' or not tokenizer_model.get('byte_fallback', False)
            or decoder.get('type') != 'Sequence'
        ):
            raise FileNotFoundError('Cannot find Llama BPE tokenizer')

        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "To use LlamaHfVocab, please install the `transformers` package. "
                "You can install it with `pip install transformers`."
            ) from e

        # Allow the tokenizer to default to slow or fast versions.
        # Explicitly set tokenizer to use local paths.
        self.tokenizer = AutoTokenizer.from_pretrained(
            base_path,
            cache_dir=base_path,
            local_files_only=True,
        )
        assert self.tokenizer.is_fast  # assume tokenizer.json is used

        # Initialize lists and dictionaries for added tokens
        self.added_tokens_list = []
        self.added_tokens_dict = {}
        self.added_tokens_ids = set()

        # Process added tokens in id order
        for tok, tokidx in sorted(
            self.tokenizer.get_added_vocab().items(), key=lambda x: x[1]
        ):
            # Only consider added tokens that are not in the base vocabulary
            if tokidx >= self.tokenizer.vocab_size:
                self.added_tokens_list.append(tok)
                self.added_tokens_dict[tok] = tokidx
                self.added_tokens_ids.add(tokidx)

        # Store special tokens and their IDs. Fetch the vocabulary once rather than
        # once per special token.
        full_vocab = self.tokenizer.get_vocab()
        self.specials = {
            tok: full_vocab[tok]
            for tok in self.tokenizer.all_special_tokens
        }
        self.special_ids = set(self.tokenizer.all_special_ids)

        # Set vocabulary sizes
        self.vocab_size_base = self.tokenizer.vocab_size
        self.vocab_size = self.vocab_size_base + len(self.added_tokens_list)

        self.fname_tokenizer = fname_tokenizer

    def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield ``(utf8_text, score, token_type)`` for ids ``0 .. vocab_size_base - 1``."""
        reverse_vocab = {
            idx: encoded_tok for encoded_tok, idx in self.tokenizer.get_vocab().items()
        }

        for token_id in range(self.vocab_size_base):
            # Skip processing added tokens here
            if token_id in self.added_tokens_ids:
                continue

            # Convert token text to bytes
            token_text = reverse_vocab[token_id].encode("utf-8")

            # Yield token text, score, and type
            yield token_text, self.get_token_score(token_id), self.get_token_type(
                token_id, token_text, self.special_ids  # Reuse already stored special IDs
            )

    def get_token_type(self, token_id: int, token_text: bytes, special_ids: set[int]) -> gguf.TokenType:
        """Classify a token: ``<0xXX>`` byte tokens → BYTE, ids in *special_ids* → CONTROL, else NORMAL."""
        # Special case for byte tokens
        if re.fullmatch(br"<0x[0-9A-Fa-f]{2}>", token_text):
            return gguf.TokenType.BYTE

        # Determine token type based on whether it's a special token
        return gguf.TokenType.CONTROL if token_id in special_ids else gguf.TokenType.NORMAL

    def get_token_score(self, token_id: int) -> float:
        """Return the score for *token_id*; currently a constant -1000.0 for every token."""
        # Placeholder for actual logic to determine the token's score
        # This needs to be implemented based on specific requirements
        return -1000.0  # Default score

    def added_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield added tokens: special ones are classified via ``get_token_type``, others as USER_DEFINED."""
        for text in self.added_tokens_list:
            if text in self.specials:
                toktype = self.get_token_type(self.specials[text], b'', self.special_ids)
                score = self.get_token_score(self.specials[text])
            else:
                toktype = gguf.TokenType.USER_DEFINED
                score = -1000.0

            yield text.encode("utf-8"), score, toktype

    def has_newline_token(self):
        """Return True if the vocabulary contains ``<0x0A>`` or a literal ``"\\n"`` token."""
        return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab

    def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
        """Yield the base tokens followed by the added tokens."""
        yield from self.hf_tokens()
        yield from self.added_tokens()

    def __repr__(self) -> str:
        return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"