# gemma-3-dnd / handler.py
# Author: YFolla
# Last commit: "Update handler.py" (e7c571a, verified)
import os
import traceback
from typing import Any, Dict, List

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
class EndpointHandler():
    """Custom Inference Endpoint handler for a merged Gemma-3 chat model.

    Loads the tokenizer and model once at startup, then serves requests of
    the form ``{"inputs": "<user text>", "parameters": {...}}`` by applying
    the tokenizer's chat template and generating a completion.
    """

    def __init__(self, path=""):
        """
        Load the model and tokenizer.

        Args:
            path: Directory containing the model files; passed automatically
                by the Inference Endpoint infrastructure.
        """
        # trust_remote_code=True because the base model (or Unsloth
        # modifications surviving the merge) may ship custom model code.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,  # bfloat16 for better stability
            device_map="auto",           # use GPU if available on the endpoint
            trust_remote_code=True,
        )
        # NOTE(review): some tokenizers ship without a pad token; uncomment if
        # generation ever complains about missing padding.
        # if self.tokenizer.pad_token is None:
        #     self.tokenizer.pad_token = self.tokenizer.eos_token
        print("Handler initialized: Model and tokenizer loaded.")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Handle one inference request using manual generation.

        Args:
            data: Request payload. Must contain an ``"inputs"`` string; may
                contain a ``"parameters"`` dict of generation kwargs that
                override the defaults below.

        Returns:
            ``[{"generated_text": ...}]`` on success, or ``[{"error": ...}]``
            describing the failure. Never raises.
        """
        try:
            # Read without popping so the caller's request dict is not mutated.
            inputs_text = data.get("inputs")
            parameters = data.get("parameters", {})
            if inputs_text is None:
                return [{"error": "Missing 'inputs' key in request data."}]
            if not isinstance(inputs_text, str):
                return [{"error": "Invalid 'inputs' format. Must be a single string for this handler."}]
            if not isinstance(parameters, dict):
                # Fail early with a clear message instead of letting
                # params.update() raise a generic TypeError below.
                return [{"error": "Invalid 'parameters' format. Must be a dict."}]

            # Default generation parameters; user-provided values override.
            params = {
                "max_new_tokens": 64,
                "temperature": 1.0,
                "top_p": 0.95,
                "top_k": 64,
                "do_sample": True,  # Explicitly enable sampling
                "pad_token_id": self.tokenizer.eos_token_id,  # Use EOS for padding
            }
            params.update(parameters)
            print(f"Received input: '{inputs_text}'")
            print(f"Using parameters: {params}")

            # Apply the chat template explicitly so the model sees its
            # expected prompt structure (user turn + generation prompt).
            messages = [{"role": "user", "content": inputs_text}]
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            print(f"Formatted prompt: '{prompt}'")

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
            # no_grad: inference only, skip autograd bookkeeping.
            with torch.no_grad():
                outputs = self.model.generate(**inputs, **params)

            # outputs[0] is the full sequence (prompt + completion);
            # decode only the newly generated tail.
            input_length = inputs.input_ids.shape[1]
            generated_ids = outputs[0][input_length:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            print(f"Generated IDs length: {len(generated_ids)}")
            print(f"Decoded generated text: '{generated_text}'")
            return [{"generated_text": generated_text}]
        except Exception as e:
            # Top-level boundary: report the failure to the client instead of
            # crashing the endpoint worker.
            print(f"Error during inference: {e}")
            print(traceback.format_exc())
            return [{"error": f"Inference failed: {str(e)}"}]