from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class EndpointHandler:
    def __init__(self, model_dir):
        # Load the tokenizer and seq2seq model once at startup, and switch the
        # model to evaluation mode so layers like dropout are disabled.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()

    def preprocess(self, data):
        # Expect a payload of the form {"inputs": "<raw text>"} and prepend
        # the instruction prompt before tokenizing.
        if isinstance(data, dict) and "inputs" in data:
            input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
        else:
            raise ValueError("Expected a dictionary with an 'inputs' key")

        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
        return tokens

    def inference(self, tokens):
        # Generate without tracking gradients; this is inference only.
        with torch.no_grad():
            outputs = self.model.generate(**tokens)
        return outputs

    def postprocess(self, outputs):
        # Decode the first generated sequence back into plain text.
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": decoded_output}

    def __call__(self, data):
        # Full request pipeline: tokenize, generate, decode.
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result
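

# A minimal local usage sketch, not part of the handler contract: it assumes a
# local model directory "./model" containing the fine-tuned seq2seq weights,
# and the sample text below is an illustrative placeholder.
if __name__ == "__main__":
    handler = EndpointHandler("./model")
    result = handler({"inputs": "Invoice 2041: three chairs delivered to Acme Corp on May 5."})
    print(result["generated_text"])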