jla25 commited on
Commit
dac55f6
verified
1 Parent(s): a8b73a7

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +44 -45
handler.py CHANGED
@@ -1,69 +1,68 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
3
- import json
4
-
5
-
6
- model_name = "jla25/squareV4"
7
-
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
-
11
 
12
  class EndpointHandler:
13
  def __init__(self, model_dir):
 
 
 
 
14
  self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
15
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
16
- self.model.eval()
17
 
18
  def preprocess(self, data):
19
- if not isinstance(data, dict) or "inputs" not in data or data["inputs"] is None:
20
- raise ValueError("La entrada debe ser un diccionario con la clave 'inputs' y un valor v谩lido.")
 
 
 
 
21
 
22
- # Prompt personalizado para guiar al modelo
23
- input_text = f"Generate a valid JSON capturing data from this text:{data['inputs']}"
24
- print(f"Prompt generado para el modelo: {input_text}")
25
- input_text = input_text.encode("utf-8").decode("utf-8")
26
- tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding="max_length", max_length=1024)
 
 
 
 
27
  return tokens
28
 
29
- def inference(self, tokens):
 
 
 
30
  generate_kwargs = {
31
  "max_length": 512,
32
  "num_beams": 5,
33
  "do_sample": False,
34
- "temperature": 0.3,
35
  "top_k": 50,
36
- "top_p": 0.8,
37
- "repetition_penalty": 2.5
 
38
  }
39
  with torch.no_grad():
40
- outputs = self.model.generate(**tokens, **generate_kwargs)
41
  return outputs
42
 
43
- def clean_output(self, output):
44
- try:
45
- start_index = output.index("{")
46
- end_index = output.rindex("}") + 1
47
- return output[start_index:end_index]
48
- except ValueError:
49
- return output
50
-
51
- def postprocess(self, outputs):
52
- decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
53
- cleaned_output = self.clean_output(decoded_output)
54
-
55
- # Imprimir siempre el texto generado para depuraci贸n
56
- print(f"Texto generado por el modelo: {decoded_output}")
57
- print(f"JSON limpiado: {cleaned_output}")
58
-
59
- return {"response": cleaned_output}
60
 
61
  def __call__(self, data):
 
 
 
 
62
  tokens = self.preprocess(data)
63
- outputs = self.inference(tokens)
64
- result = self.postprocess(outputs)
65
- return result
66
-
67
-
68
- # Crear una instancia del handler
69
- handler = EndpointHandler(model_name)
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import torch
 
 
 
 
 
 
 
 
3
 
4
  class EndpointHandler:
5
  def __init__(self, model_dir):
6
+ """
7
+ Inicializa el handler con el modelo y tokenizador.
8
+ """
9
+ # Cargar el tokenizador y el modelo desde el directorio proporcionado
10
  self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
11
  self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
12
+ self.model.eval() # Poner el modelo en modo evaluaci贸n
13
 
14
  def preprocess(self, data):
15
+ """
16
+ Preprocesa los datos de entrada para el modelo.
17
+ """
18
+ # Validar entrada
19
+ if not isinstance(data, dict) or "inputs" not in data:
20
+ raise ValueError("Entrada inv谩lida. Debe ser un diccionario con la clave 'inputs'.")
21
 
22
+ input_text = f"Generate a valid JSON capturing data from this text: {data['inputs']}"
23
+ # Tokenizar entrada
24
+ tokens = self.tokenizer(
25
+ input_text,
26
+ return_tensors="pt",
27
+ truncation=True,
28
+ padding="max_length",
29
+ max_length=512
30
+ )
31
  return tokens
32
 
33
+ def inference(self, inputs):
34
+ """
35
+ Realiza la inferencia con el modelo.
36
+ """
37
  generate_kwargs = {
38
  "max_length": 512,
39
  "num_beams": 5,
40
  "do_sample": False,
41
+ "temperature": 0.7,
42
  "top_k": 50,
43
+ "top_p": 0.9,
44
+ "repetition_penalty": 2.0,
45
+ "early_stopping": True # Asegurar que no sea None
46
  }
47
  with torch.no_grad():
48
+ outputs = self.model.generate(**inputs, **generate_kwargs)
49
  return outputs
50
 
51
+ def postprocess(self, model_outputs):
52
+ """
53
+ Procesa las salidas del modelo para devolver resultados.
54
+ """
55
+ # Decodificar la salida generada por el modelo
56
+ decoded_output = self.tokenizer.decode(model_outputs[0], skip_special_tokens=True)
57
+ return {"response": decoded_output}
 
 
 
 
 
 
 
 
 
 
58
 
59
  def __call__(self, data):
60
+ """
61
+ Ejecuta el pipeline de preprocesamiento, inferencia y postprocesamiento.
62
+ """
63
+ # Preprocesar entrada
64
  tokens = self.preprocess(data)
65
+ # Realizar inferencia
66
+ model_outputs = self.inference(tokens)
67
+ # Postprocesar y devolver resultados
68
+ return self.postprocess(model_outputs)