File size: 4,067 Bytes
deb2e68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7c571a
deb2e68
 
 
 
 
 
 
 
 
 
 
9fd337d
deb2e68
 
 
9fd337d
deb2e68
 
9fd337d
deb2e68
 
9fd337d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
deb2e68
9fd337d
deb2e68
9fd337d
 
 
 
5ab33b7
9fd337d
 
 
 
 
 
deb2e68
9fd337d
 
deb2e68
9fd337d
 
deb2e68
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import os
from typing import Dict, List, Any

class EndpointHandler():
    """Custom handler for a Hugging Face Inference Endpoint.

    Loads a causal language model and its tokenizer once at start-up, then
    serves text-generation requests of the form
    ``{"inputs": "<prompt>", "parameters": {...}}``.
    """

    def __init__(self, path=""):
        """
        Load the model and tokenizer. The 'path' argument is automatically
        passed by the Inference Endpoint infrastructure, pointing to the
        directory where your model files are located.
        """
        # Use trust_remote_code=True because the base model might rely on it
        # or if Unsloth modifications require it, even after merging.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16, # Use bfloat16 for better stability
            device_map="auto", # Use GPU if available on the endpoint
            trust_remote_code=True
        )
        # Optional: Explicitly set pad token if needed
        # if self.tokenizer.pad_token is None:
        #    self.tokenizer.pad_token = self.tokenizer.eos_token
        print("Handler initialized: Model and tokenizer loaded.")


    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle one inference request using manual generation.

        Args:
            data: Request payload. Expected keys:
                  - "inputs": a single prompt string (required)
                  - "parameters": optional dict of generation overrides

        Returns:
            A one-element list: ``[{"generated_text": ...}]`` on success or
            ``[{"error": ...}]`` on failure — the handler never raises, so
            the endpoint always gets a well-formed response.
        """
        try:
            # Read with .get() (NOT .pop()) so the caller's payload dict is
            # not mutated as a side effect of handling the request.
            inputs_text = data.get("inputs")
            # Treat an explicit `"parameters": null` the same as absent.
            parameters = data.get("parameters") or {}

            if inputs_text is None:
                return [{"error": "Missing 'inputs' key in request data."}]

            # Basic input validation
            if not isinstance(inputs_text, str):
                 return [{"error": "Invalid 'inputs' format. Must be a single string for this handler."}]
            # Validate explicitly instead of letting params.update() blow up
            # inside the broad except below with an opaque TypeError.
            if not isinstance(parameters, dict):
                return [{"error": "Invalid 'parameters' format. Must be a dict of generation options."}]

            # Default generation parameters; user-provided values take priority.
            params = {
                "max_new_tokens": 64,
                "temperature": 1.0,
                "top_p": 0.95,
                "top_k": 64,
                "do_sample": True, # Explicitly enable sampling
                "pad_token_id": self.tokenizer.eos_token_id # Use EOS for padding
            }
            # Update with user-provided parameters
            params.update(parameters)

            print(f"Received input: '{inputs_text}'")
            print(f"Using parameters: {params}")

            # Apply the tokenizer's chat template explicitly so the prompt
            # carries the model's expected role markers and generation prompt,
            # rather than relying on implicit formatting.
            messages = [{"role": "user", "content": inputs_text}]
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            print(f"Formatted prompt: '{prompt}'")

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            # Generate without tracking gradients (inference only).
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **params)

            # outputs[0] contains the full sequence (prompt + generation);
            # slice off the prompt tokens and decode only the new tail.
            input_length = inputs.input_ids.shape[1]
            generated_ids = outputs[0][input_length:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

            print(f"Generated IDs length: {len(generated_ids)}")
            print(f"Decoded generated text: '{generated_text}'")

            # Return the results
            return [{"generated_text": generated_text}]

        except Exception as e:
            # Top-level boundary: log the full traceback and report the error
            # in-band so the endpoint responds instead of crashing the worker.
            import traceback
            print(f"Error during inference: {e}")
            print(traceback.format_exc())
            return [{"error": f"Inference failed: {str(e)}"}]