File size: 2,257 Bytes
65ec16b
2f4df5a
33d2c2b
2f4df5a
 
 
33d2c2b
2f4df5a
33d2c2b
2f4df5a
 
 
 
 
 
33d2c2b
2f4df5a
33d2c2b
2f4df5a
33d2c2b
 
 
 
0a82a06
 
 
 
33d2c2b
 
 
 
 
0b80774
33d2c2b
2f4df5a
36f12a2
33d2c2b
 
 
 
 
 
 
36f12a2
 
 
 
 
 
 
 
 
2f4df5a
36f12a2
33d2c2b
2f4df5a
 
33d2c2b
 
2f4df5a
33d2c2b
2f4df5a
33d2c2b
 
 
2f4df5a
 
 
33d2c2b
2f4df5a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import torch
from llama_cpp import Llama  # Library for GGUF model handling
from typing import Any, List, Dict


class FixedVocabLogitsProcessor:
    """
    A custom logits processor for GGUF-compatible models.

    Restricts sampling to a fixed vocabulary: every position whose index is
    not in ``allowed_ids`` is overwritten with ``fill_value``.

    Instances can be used in two ways:
      * ``processor.apply(logits)`` - the original entry point.
      * ``processor(input_ids, scores)`` - the two-argument callable form
        used by llama-cpp-python logits processors; ``input_ids`` is ignored.

    Args:
        allowed_ids (set[int]): Token IDs that keep their original logit.
        fill_value: Value written to every other position (default ``-inf``).
    """

    def __init__(self, allowed_ids: set[int], fill_value=float('-inf')):
        self.allowed_ids = allowed_ids
        self.fill_value = fill_value

    def __call__(self, input_ids, scores):
        """
        Apply the restriction using the ``(input_ids, scores)`` signature.
        Args:
            input_ids: Tokens generated so far (unused).
            scores: Logits for the next token; modified in place.
        Returns:
            The same ``scores`` object, after masking.
        """
        return self.apply(scores)

    def apply(self, logits: torch.FloatTensor):
        """
        Modify logits to restrict to allowed token IDs.
        Args:
            logits: Indexable, mutable container of scores (torch tensor,
                NumPy array or plain list). Its first dimension is treated
                as the token-ID axis. Modified in place.
        Returns:
            The same ``logits`` object, after masking.
        """
        blocked = [
            token_id
            for token_id in range(len(logits))
            if token_id not in self.allowed_ids
        ]
        if not blocked:
            return logits

        try:
            # One indexed assignment instead of one Python-level write per
            # vocabulary entry; tensors and arrays accept a list of indices.
            logits[blocked] = self.fill_value
        except TypeError:
            # Plain sequences (e.g. list) reject list indices; write one by one.
            for token_id in blocked:
                logits[token_id] = self.fill_value
        return logits


class EndpointHandler:
    """
    Request handler that generates text from a GGUF model while restricting
    the generated tokens to a caller-supplied vocabulary.
    """

    # Model location used when ``path`` does not lead to an existing file.
    DEFAULT_MODEL_PATH = '/repository/model.gguf'

    def __init__(self, path=""):
        """
        Initialize the GGUF model handler.
        Args:
            path (str): Path to the GGUF file, or to a directory containing
                ``model.gguf``. When empty or not resolvable, the default
                ``/repository/model.gguf`` is loaded.
        """
        self.model = Llama(model_path=self._resolve_model_path(path))
        # Not used by this handler; kept because other code may read it.
        # NOTE(review): in llama-cpp-python `tokenizer` appears to be a method,
        # so this stores the bound method rather than a tokenizer - confirm.
        self.tokenizer = self.model.tokenizer  # GGUF-specific tokenizer, if available

    @staticmethod
    def _resolve_model_path(path):
        """Return the model file to load: ``path`` itself, ``path/model.gguf``, or the default."""
        if path:
            if os.path.isfile(path):
                return path
            candidate = os.path.join(path, "model.gguf")
            if os.path.isfile(candidate):
                return candidate
        return EndpointHandler.DEFAULT_MODEL_PATH

    @staticmethod
    def _to_bytes(text):
        """Return ``text`` as UTF-8 bytes, the form passed to ``tokenize`` here."""
        if isinstance(text, bytes):
            return text
        return str(text).encode("utf-8")

    def __call__(self, data: Any) -> List[Dict[str, str]]:
        """
        Handle the request, performing inference with a restricted vocabulary.
        Args:
            data (Any): Request payload (a dict). Recognised keys:
                ``inputs`` (str prompt), ``parameters`` (dict; ``max_length``
                caps the number of generated tokens, default 30) and
                ``vocab_list`` (iterable of words whose tokens are allowed).
                The recognised keys are popped from ``data``.
        Returns:
            List[Dict[str, str]]: Single-element list holding ``generated_text``.
        Raises:
            ValueError: If ``vocab_list`` is missing/empty or ``inputs`` is
                not a string.
        """
        # Extract inputs and parameters
        inputs = data.pop("inputs", data)
        parameters = data.pop("parameters", {}) or {}  # tolerate an explicit null
        vocab_list = data.pop("vocab_list", None)

        if not vocab_list:
            raise ValueError("You must provide a 'vocab_list' to define allowed tokens.")
        if not isinstance(inputs, (str, bytes)):
            raise ValueError("'inputs' must be a string prompt.")

        # Define allowed tokens dynamically. BOS is suppressed so that it does
        # not end up in the allowed set for every word.
        allowed_ids = set()
        for word in vocab_list:
            allowed_ids.update(
                self.model.tokenize(self._to_bytes(word), add_bos=False)
            )

        # Tokenize input
        input_ids = self.model.tokenize(self._to_bytes(inputs))

        max_tokens = int(parameters.get("max_length", 30))
        processor = FixedVocabLogitsProcessor(allowed_ids)

        def restrict_vocab(_input_ids, scores):
            # Logits processors are called with (input_ids, scores).
            return processor.apply(scores)

        # Local import: only needed on the inference path.
        from llama_cpp import LogitsProcessorList

        # Perform inference. generate() yields token ids one at a time with
        # no built-in length limit, so the limit is enforced here.
        output_ids = []
        if max_tokens > 0:
            eos_id = self.model.token_eos()
            token_stream = self.model.generate(
                input_ids,
                logits_processor=LogitsProcessorList([restrict_vocab]),
            )
            for token_id in token_stream:
                if token_id == eos_id:
                    break
                output_ids.append(token_id)
                if len(output_ids) >= max_tokens:
                    break

        # Decode the output; detokenize() returns bytes, so convert to str to
        # match the declared return type and stay JSON-serializable.
        generated_text = self.model.detokenize(output_ids).decode(
            "utf-8", errors="ignore"
        )

        return [{"generated_text": generated_text}]