Update handler.py
handler.py  CHANGED  (+44 -45)

@@ -1,69 +1,68 @@
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
-import json
-
-
-model_name = "jla25/squareV4"
-
-tokenizer = AutoTokenizer.from_pretrained(model_name)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
-
 
 class EndpointHandler:
     def __init__(self, model_dir):
+        """
+        Initialize the handler with the model and tokenizer.
+        """
+        # Load the tokenizer and model from the provided directory
         self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
         self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
-        self.model.eval()
+        self.model.eval()  # Put the model in evaluation mode
 
     def preprocess(self, data):
-
-
+        """
+        Preprocess the input data for the model.
+        """
+        # Validate the input
+        if not isinstance(data, dict) or "inputs" not in data:
+            raise ValueError("Invalid input. It must be a dictionary with the key 'inputs'.")
 
-
-
-
-
-
+        input_text = f"Generate a valid JSON capturing data from this text: {data['inputs']}"
+        # Tokenize the input
+        tokens = self.tokenizer(
+            input_text,
+            return_tensors="pt",
+            truncation=True,
+            padding="max_length",
+            max_length=512
+        )
         return tokens
 
-    def inference(self,
+    def inference(self, inputs):
+        """
+        Run inference with the model.
+        """
         generate_kwargs = {
             "max_length": 512,
             "num_beams": 5,
             "do_sample": False,
-            "temperature": 0.
+            "temperature": 0.7,
             "top_k": 50,
-            "top_p": 0.
-            "repetition_penalty": 2.
+            "top_p": 0.9,
+            "repetition_penalty": 2.0,
+            "early_stopping": True  # Ensure this is not None
         }
         with torch.no_grad():
-            outputs = self.model.generate(**
+            outputs = self.model.generate(**inputs, **generate_kwargs)
         return outputs
 
-    def
-
-
-
-
-
-
-
-    def postprocess(self, outputs):
-        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-        cleaned_output = self.clean_output(decoded_output)
-
-        # Always print the generated text for debugging
-        print(f"Text generated by the model: {decoded_output}")
-        print(f"Cleaned JSON: {cleaned_output}")
-
-        return {"response": cleaned_output}
+    def postprocess(self, model_outputs):
+        """
+        Process the model outputs to return the results.
+        """
+        # Decode the output generated by the model
+        decoded_output = self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)
+        return {"response": decoded_output}
 
     def __call__(self, data):
+        """
+        Run the preprocessing, inference, and postprocessing pipeline.
+        """
+        # Preprocess the input
         tokens = self.preprocess(data)
-
-
-
-
-
-# Create a handler instance
-handler = EndpointHandler(model_name)
+        # Run inference
+        model_outputs = self.inference(tokens)
+        # Postprocess and return the results
+        return self.postprocess(model_outputs)
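
For context, a minimal sketch of how the updated handler would be exercised once deployed. This follows the custom-handler pattern used by Hugging Face Inference Endpoints, where the toolkit instantiates `EndpointHandler(model_dir)` once and then calls it with each request payload; the module name `handler`, the use of the repo id `jla25/squareV4` (taken from the removed module-level code), and the sample input sentence are illustrative assumptions, not part of this commit:

# Minimal local smoke test for EndpointHandler. Illustrative assumptions:
# the module name "handler" and the sample payload are not from the commit.
from handler import EndpointHandler

# from_pretrained accepts either a local directory or a Hub repo id, so the
# id used by the removed module-level code ("jla25/squareV4") works here too.
handler = EndpointHandler("jla25/squareV4")

# __call__ expects a dict with an "inputs" key, which preprocess() now enforces.
result = handler({"inputs": "Customer Ana ordered 3 units of product SQ-12."})
print(result["response"])  # decoded model output, expected to be JSON-like text

One caveat about the new `generate_kwargs`: with `do_sample=False`, beam search ignores the sampling parameters (`temperature`, `top_k`, `top_p`), and recent transformers releases emit a warning when they are set in this combination, so only `max_length`, `num_beams`, `repetition_penalty`, and `early_stopping` actually affect the output here.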