from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch


class EndpointHandler:
    def __init__(self, path="krisoei/timgpt"):
        if not path:
            raise ValueError("A valid model path or name must be provided.")

        # Load the tokenizer and model once at startup. float16 halves the
        # memory footprint, and device_map="auto" places the weights on
        # whatever GPU(s) are available.
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForCausalLM.from_pretrained(
            path,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        # Build the generation pipeline once and reuse it for every request.
        # Sampling (do_sample with temperature/top_p) yields varied output.
        self.pipe = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
        )

    def __call__(self, data):
        # Validate the request payload before touching the model.
        if not isinstance(data, dict):
            return {"error": "Input must be a JSON object."}

        prompt = data.get("inputs", "")
        # Guard against missing, empty, or non-string inputs; slicing below
        # assumes the prompt is a string.
        if not isinstance(prompt, str) or not prompt:
            return {"error": "No input provided."}

        try:
            outputs = self.pipe(prompt)
            if outputs:
                # The pipeline returns the prompt plus its continuation;
                # slice off the prompt so only newly generated text remains.
                response = outputs[0]["generated_text"]
                response = response[len(prompt):].strip()
                return {"generated_text": response}
            return {"error": "No output generated."}
        except Exception as e:
            return {"error": str(e)}
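
# --- Local smoke test (illustrative sketch, not part of the endpoint
# contract) --- Hugging Face Inference Endpoints instantiate EndpointHandler
# and call it with the parsed request body; this block only mimics that flow
# so the handler can be exercised directly. Running it downloads
# krisoei/timgpt, so it assumes a machine with enough GPU memory.
if __name__ == "__main__":
    handler = EndpointHandler()
    result = handler({"inputs": "Write a short poem about the sea."})
    print(result)  # -> {"generated_text": "..."} or {"error": "..."}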