| | import torch |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
| | import os |
| | from typing import Dict, List, Any |
| |
|
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for a causal chat LLM.

    Loads a tokenizer/model pair from the endpoint-supplied model directory
    and serves single-turn chat generation requests of the form
    ``{"inputs": "<user text>", "parameters": {...generate kwargs...}}``.
    """

    def __init__(self, path: str = ""):
        """Load the model and tokenizer.

        The ``path`` argument is automatically passed by the Inference
        Endpoint infrastructure, pointing to the directory where the model
        files are located.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            path,
            trust_remote_code=True,
        )
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        # Inference-only handler: put the model in eval mode so dropout and
        # other train-time layers are disabled (torch.no_grad() alone does
        # not do this).
        self.model.eval()

        print("Handler initialized: Model and tokenizer loaded.")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Handle one inference request using manual generation.

        Args:
            data: Request payload with a required ``"inputs"`` string and an
                optional ``"parameters"`` dict of ``generate()`` overrides.

        Returns:
            ``[{"generated_text": ...}]`` on success, or ``[{"error": ...}]``
            describing the failure (this handler never raises to the caller).
        """
        try:
            # Use .get() instead of .pop(): the original destructively
            # mutated the caller's request dict. ``or {}`` also guards
            # against an explicit ``"parameters": None`` in the payload,
            # which would make params.update() raise TypeError.
            inputs_text = data.get("inputs")
            parameters = data.get("parameters") or {}

            if inputs_text is None:
                return [{"error": "Missing 'inputs' key in request data."}]

            if not isinstance(inputs_text, str):
                return [{"error": "Invalid 'inputs' format. Must be a single string for this handler."}]

            # Generation defaults; any caller-supplied parameters override.
            params = {
                "max_new_tokens": 64,
                "temperature": 1.0,
                "top_p": 0.95,
                "top_k": 64,
                "do_sample": True,
                "pad_token_id": self.tokenizer.eos_token_id,
            }
            params.update(parameters)

            print(f"Received input: '{inputs_text}'")
            print(f"Using parameters: {params}")

            # Wrap the raw text in the model's chat template and leave the
            # assistant turn open for generation.
            messages = [{"role": "user", "content": inputs_text}]
            prompt = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            print(f"Formatted prompt: '{prompt}'")

            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

            with torch.no_grad():
                outputs = self.model.generate(**inputs, **params)

            # generate() returns prompt + completion; slice off the prompt
            # tokens so only the newly generated text is decoded.
            input_length = inputs.input_ids.shape[1]
            generated_ids = outputs[0][input_length:]
            generated_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

            print(f"Generated IDs length: {len(generated_ids)}")
            print(f"Decoded generated text: '{generated_text}'")

            return [{"generated_text": generated_text}]

        except Exception as e:
            # Top-level boundary: log the full traceback and report the
            # failure to the client instead of crashing the endpoint worker.
            import traceback
            print(f"Error during inference: {e}")
            print(traceback.format_exc())
            return [{"error": f"Inference failed: {str(e)}"}]
| |
|