squareV4 / handler.py
jla25's picture
Update handler.py
dac55f6 verified
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
class EndpointHandler:
def __init__(self, model_dir):
"""
Inicializa el handler con el modelo y tokenizador.
"""
# Cargar el tokenizador y el modelo desde el directorio proporcionado
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
self.model.eval() # Poner el modelo en modo evaluación
def preprocess(self, data):
"""
Preprocesa los datos de entrada para el modelo.
"""
# Validar entrada
if not isinstance(data, dict) or "inputs" not in data:
raise ValueError("Entrada inválida. Debe ser un diccionario con la clave 'inputs'.")
input_text = f"Generate a valid JSON capturing data from this text: {data['inputs']}"
# Tokenizar entrada
tokens = self.tokenizer(
input_text,
return_tensors="pt",
truncation=True,
padding="max_length",
max_length=512
)
return tokens
def inference(self, inputs):
"""
Realiza la inferencia con el modelo.
"""
generate_kwargs = {
"max_length": 512,
"num_beams": 5,
"do_sample": False,
"temperature": 0.7,
"top_k": 50,
"top_p": 0.9,
"repetition_penalty": 2.0,
"early_stopping": True # Asegurar que no sea None
}
with torch.no_grad():
outputs = self.model.generate(**inputs, **generate_kwargs)
return outputs
def postprocess(self, model_outputs):
"""
Procesa las salidas del modelo para devolver resultados.
"""
# Decodificar la salida generada por el modelo
decoded_output = self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)
return {"response": decoded_output}
def __call__(self, data):
"""
Ejecuta el pipeline de preprocesamiento, inferencia y postprocesamiento.
"""
# Preprocesar entrada
tokens = self.preprocess(data)
# Realizar inferencia
model_outputs = self.inference(tokens)
# Postprocesar y devolver resultados
return self.postprocess(model_outputs)