bart-large-cnn / handler.py
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from typing import Any

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load the model and tokenizer from the repository path,
        # moving the model to the GPU if one is available.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(path).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any]:
        inputs = data.get("inputs")
        parameters = data.get("parameters")

        # Validate the request body before touching the model.
        if inputs is None:
            raise ValueError("'inputs' is missing from the request body")
        if not isinstance(inputs, str):
            raise ValueError(f"Expected 'inputs' to be a str, but found {type(inputs)}")
        if parameters is not None and not isinstance(parameters, dict):
            raise ValueError(f"Expected 'parameters' to be a dict, but found {type(parameters)}")

        # Truncate the input to 1024 tokens to prevent errors with BART and long text.
        tokens = self.tokenizer(
            inputs,
            max_length=1024,
            truncation=True,
            return_tensors="pt",
            return_attention_mask=False,
        )

        # Ensure the input_ids and the model are on the same device to prevent errors.
        input_ids = tokens.input_ids.to(device)

        # Gradient calculation is not needed for inference.
        with torch.no_grad():
            if parameters is None:
                output = self.model.generate(input_ids)
            else:
                output = self.model.generate(input_ids, **parameters)

        generated_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
        return {"generated_text": generated_text}