|
from typing import Any, List, Dict |
|
from llama_cpp import Llama |
|
import numpy as np |
|
import torch |
|
from transformers import AutoTokenizer, LogitsProcessorList |
|
|
|
class EndpointHandler:
    """
    Inference handler serving a llama_cpp GGUF Llama model.

    A request may carry a ``vocab_list``; when it does, every sampled token is
    restricted to the token ids those words encode to (plus end-of-sequence
    ids, so generation can still terminate).
    """

    def __init__(self, path=""):
        """
        Initialize the model handler using llama_cpp.

        Args:
            path: Unused here; the model and tokenizer repositories are
                hard-coded below.
        """
        self.model = Llama.from_pretrained(
            repo_id="bartowski/Meta-Llama-3.1-8B-Instruct-GGUF",
            filename="Meta-Llama-3.1-8B-Instruct-Q6_K.gguf",
        )
        # NOTE(review): whitelist token ids are produced by this Hugging Face
        # tokenizer, not by the GGUF model's own tokenizer. This presumably
        # relies on Llama 3.1 and 3.2 sharing one vocabulary — confirm.
        self.tokenizer = AutoTokenizer.from_pretrained("taylorj94/Llama-3.2-1B")

    def get_allowed_token_ids(self, vocab_list: List[str]) -> set[int]:
        """
        Generate a set of token IDs for a given list of allowed words.

        For each word, the plain, capitalized, and uppercase forms are encoded,
        each both as-is and with a leading space. Every token id of every
        encoding is collected, so a word that splits into several sub-word
        tokens contributes all of its pieces (the whitelist is token-level,
        not word-level).

        Args:
            vocab_list: Allowed words. A single string is treated as one word.

        Returns:
            The union of token ids over all variations of all words.
        """
        if isinstance(vocab_list, str):
            # Iterating a bare string would whitelist its individual characters.
            vocab_list = [vocab_list]

        allowed_ids: set[int] = set()
        for word in vocab_list:
            forms = {word, word.capitalize(), word.upper()}
            variations = forms | {" " + form for form in forms}
            for variation in variations:
                allowed_ids.update(
                    self.tokenizer.encode(variation, add_special_tokens=False)
                )
        return allowed_ids

    def filter_allowed_tokens(self, input_ids: torch.Tensor, scores: np.ndarray, allowed_token_ids: set[int]) -> np.ndarray:
        """
        Mask ``scores`` in place so only the allowed token ids can be sampled.

        Args:
            input_ids: Tokens generated so far. Unused; present to match the
                ``(input_ids, scores)`` logits-processor calling convention.
            scores: Logits whose last axis indexes the vocabulary, either 1-D
                (a single row) or 2-D (one row per sequence).
            allowed_token_ids: Ids to keep. Ids outside ``[0, vocab_size)``
                are ignored.

        Returns:
            ``scores`` itself, with every non-allowed entry set to ``-inf``.

        Raises:
            ValueError: If ``scores`` is neither 1-D nor 2-D.
        """
        if scores.ndim not in (1, 2):
            raise ValueError(f"Unsupported scores dimension: {scores.ndim}")

        vocab_size = scores.shape[-1]
        ids = np.fromiter(allowed_token_ids, dtype=np.int64)
        # Drop ids the vocabulary cannot index (negative or too large).
        ids = ids[(ids >= 0) & (ids < vocab_size)]

        # Build the mask once; it is identical for every row.
        mask = np.zeros(vocab_size, dtype=bool)
        mask[ids] = True
        scores[..., ~mask] = float('-inf')
        return scores

    def _stop_token_ids(self) -> set[int]:
        """Return end-of-sequence token ids that must remain sampleable."""
        stop_ids = {self.model.token_eos()}
        # The HF tokenizer may report a different EOS id than the GGUF model;
        # keeping both sampleable is harmless.
        tokenizer_eos = getattr(self.tokenizer, "eos_token_id", None)
        if tokenizer_eos is not None:
            stop_ids.add(tokenizer_eos)
        return stop_ids

    def __call__(self, data: Any) -> List[Dict[str, Any]]:
        """
        Handle the request, performing inference with a restricted vocabulary.

        Args:
            data: Request payload read via ``.get``: ``inputs`` (required user
                message), optional ``parameters`` (``max_length``,
                ``temperature``, ``repeat_penalty``) and optional
                ``vocab_list``.

        Returns:
            A one-element list with ``generated_text`` and the whitelisted
            ``allowed_token_ids`` (empty when no ``vocab_list`` was given).

        Raises:
            ValueError: If ``inputs`` is missing or empty.
        """
        inputs = data.get("inputs", None)
        parameters = data.get("parameters") or {}
        vocab_list = data.get("vocab_list", None)

        if not inputs:
            raise ValueError("The 'inputs' field is required.")

        logits_processors = None
        allowed_token_ids: set[int] = set()

        if vocab_list:
            allowed_token_ids = self.get_allowed_token_ids(vocab_list)
            # Without the stop ids the model could never end its turn and would
            # always run to max_tokens. They are not reported back to the caller.
            filter_ids = allowed_token_ids | self._stop_token_ids()
            logits_processors = LogitsProcessorList([
                lambda input_ids, scores: self.filter_allowed_tokens(input_ids, scores, filter_ids)
            ])

        response = self.model.create_chat_completion(
            messages=[
                {"role": "user", "content": inputs}
            ],
            max_tokens=parameters.get("max_length", 30),
            logits_processor=logits_processors,
            temperature=parameters.get("temperature", 1),
            repeat_penalty=parameters.get("repeat_penalty", 1.0),
        )

        generated_text = response["choices"][0]["message"]["content"]

        return [{"generated_text": generated_text, "allowed_token_ids": list(allowed_token_ids)}]