import os
from typing import Any, Dict, List

import numpy as np
from llama_cpp import Llama, LogitsProcessorList  # GGUF model loading and logits-processor hooks


class FixedVocabLogitsProcessor:
    """
    A logits processor that restricts generation to a fixed set of token IDs.

    Matches llama-cpp-python's LogitsProcessor interface: a callable taking
    (input_ids, scores) as numpy arrays and returning the modified scores.
    """

    def __init__(self, allowed_ids: set[int], fill_value: float = float("-inf")):
        self.allowed_ids = allowed_ids
        self.fill_value = fill_value

    def __call__(self, input_ids: np.ndarray, scores: np.ndarray) -> np.ndarray:
        """Mask every logit whose token ID is not in the allowed set."""
        mask = np.full_like(scores, self.fill_value)
        allowed = list(self.allowed_ids)
        mask[allowed] = scores[allowed]  # keep original scores only for allowed tokens
        return mask
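
# Minimal sketch of the masking behaviour (toy 5-token vocabulary; the values
# below are illustrative, not part of the handler's runtime):
#   proc = FixedVocabLogitsProcessor(allowed_ids={1, 3})
#   scores = np.array([0.5, 1.2, -0.3, 2.0, 0.1], dtype=np.float32)
#   proc(np.array([], dtype=np.intc), scores)  # -> [-inf, 1.2, -inf, 2.0, -inf]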


class EndpointHandler:
    def __init__(self, path: str = ""):
        """
        Initialize the GGUF model handler.

        Args:
            path (str): Path to the directory containing the GGUF file.
        """
        # Fall back to the default Inference Endpoints location when no path is given.
        model_path = os.path.join(path, "model.gguf") if path else "/repository/model.gguf"
        self.model = Llama(model_path=model_path)
        self.tokenizer = self.model.tokenizer()  # tokenizer() is a method, not an attribute

    def __call__(self, data: Any) -> List[Dict[str, str]]:
        """
        Handle the request, performing inference with a restricted vocabulary.

        Args:
            data (Any): Input data with "inputs", optional "parameters",
                and a required "vocab_list".

        Returns:
            List[Dict[str, str]]: Generated output.
        """
        # Extract inputs and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {})
        vocab_list = data.pop("vocab_list", None)
        if not vocab_list:
            raise ValueError("You must provide a 'vocab_list' to define allowed tokens.")

        # Build the set of allowed token IDs from the vocabulary words.
        # tokenize() expects bytes; add_bos=False avoids adding a BOS token per word.
        allowed_ids: set[int] = set()
        for word in vocab_list:
            for tid in self.model.tokenize(word.encode("utf-8"), add_bos=False):
                allowed_ids.add(tid)

        # Tokenize the prompt (tokenize() also expects bytes here)
        input_ids = self.model.tokenize(inputs.encode("utf-8"))

        # Perform inference. generate() yields token IDs one at a time, so
        # collect up to max_length tokens and stop early at end-of-sequence.
        max_new_tokens = parameters.get("max_length", 30)
        processor = LogitsProcessorList([FixedVocabLogitsProcessor(allowed_ids)])
        output_ids: List[int] = []
        for token_id in self.model.generate(input_ids, logits_processor=processor):
            if token_id == self.model.token_eos():
                break
            output_ids.append(token_id)
            if len(output_ids) >= max_new_tokens:
                break

        # Decode the output (detokenize() returns bytes)
        generated_text = self.model.detokenize(output_ids).decode("utf-8", errors="ignore")
        return [{"generated_text": generated_text}]
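

# Minimal usage sketch, assuming a GGUF model at /repository/model.gguf and a
# dict-style request payload; the prompt and vocab_list below are illustrative.
if __name__ == "__main__":
    handler = EndpointHandler(path="/repository")
    result = handler({
        "inputs": "The sky is",
        "parameters": {"max_length": 10},
        "vocab_list": ["blue", "green", "red"],
    })
    print(result)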