from typing import Any

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the model and tokenizer from the checkpoint directory the
        # serving runtime passes in as `path`.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        inputs = data.get("inputs")
        parameters = data.get("parameters")

        # Validate the request body before running inference.
        if inputs is None:
            raise ValueError("'inputs' is missing from the request body")

        if not isinstance(inputs, str):
            raise ValueError(f"Expected 'inputs' to be a str, but found {type(inputs)}")

        if parameters is not None and not isinstance(parameters, dict):
            raise ValueError(f"Expected 'parameters' to be a dict, but found {type(parameters)}")

        # Tokenize the input text, truncating anything past 1024 tokens.
        tokens = self.tokenizer(
            inputs,
            max_length=1024,
            truncation=True,
            return_tensors="pt",
            return_attention_mask=False,
        )

        # Move the token IDs onto the same device as the model.
        input_ids = tokens.input_ids.to(device)

        # Run generation without tracking gradients, forwarding any
        # user-supplied parameters (e.g. max_new_tokens, num_beams)
        # straight to generate().
        with torch.no_grad():
            if parameters is None:
                output = self.model.generate(input_ids)
            else:
                output = self.model.generate(input_ids, **parameters)

        # Decode the first (and only) sequence in the batch.
        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return {"generated_text": generated_text}