from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch


class EndpointHandler:

    def __init__(self, model_dir):
        """
        Initialize the handler with the model and tokenizer.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        # Run on GPU when available; fall back to CPU otherwise.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def preprocess(self, data):
        """
        Preprocess the input data for the model.
        """
        if not isinstance(data, dict) or "inputs" not in data:
            raise ValueError("Invalid input. Expected a dictionary with an 'inputs' key.")

        input_text = f"Generate a valid JSON capturing data from this text: {data['inputs']}"

        tokens = self.tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=512,
        )
        # Move the input tensors to the same device as the model.
        return tokens.to(self.device)

    def inference(self, inputs):
        """
        Run inference with the model.
        """
        # Beam search here is deterministic (do_sample=False), so sampling-only
        # parameters such as temperature, top_k, and top_p are omitted:
        # transformers ignores them and emits warnings when do_sample is False.
        generate_kwargs = {
            "max_length": 512,
            "num_beams": 5,
            "do_sample": False,
            "repetition_penalty": 2.0,
            "early_stopping": True,
        }
        with torch.no_grad():
            outputs = self.model.generate(**inputs, **generate_kwargs)
        return outputs

    def postprocess(self, model_outputs):
        """
        Process the model outputs and return the result.
        """
        decoded_output = self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)
        return {"response": decoded_output}

    def __call__(self, data):
        """
        Run the preprocessing, inference, and postprocessing pipeline.
        """
        tokens = self.preprocess(data)
        model_outputs = self.inference(tokens)
        return self.postprocess(model_outputs)
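
# --- Usage sketch ---
# A minimal local smoke test, assuming "./model" is a hypothetical directory
# holding a seq2seq checkpoint (e.g. a fine-tuned T5). On Hugging Face
# Inference Endpoints the serving toolkit instantiates EndpointHandler itself;
# this block only illustrates the expected request and response shapes.
if __name__ == "__main__":
    handler = EndpointHandler("./model")
    result = handler({"inputs": "Order #123: two chairs shipped to Madrid on 2024-05-01."})
    print(result)  # e.g. {"response": "<generated JSON string>"}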