File size: 1,539 Bytes
032fa14
 
 
 
3863700
258e3f1
032fa14
 
 
 
 
 
 
3863700
f460c33
032fa14
3863700
032fa14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258e3f1
 
032fa14
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import json

class EndpointHandler:
    def __init__(self, model_dir):
        # Cargar el modelo y el tokenizador desde el directorio del modelo
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
        self.model.eval()  # Configurar el modelo en modo de evaluación

    def preprocess(self, data):
        # Preprocesamiento de la entrada
        if isinstance(data, dict) and "inputs" in data:
            input_text = "Generate a valid JSON capturing data from this text: " + data["inputs"]
        else:
            raise ValueError("Esperando un diccionario con la clave 'inputs'")

        # Tokenización de la entrada
        tokens = self.tokenizer(input_text, return_tensors="pt", truncation=True, padding=True)
        return tokens

    def inference(self, tokens):
        # Realizar la inferencia
        with torch.no_grad():
            outputs = self.model.generate(**tokens)
        return outputs

    def postprocess(self, outputs):
        # Decodificar la salida del modelo
        decoded_output = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        return {"generated_text": decoded_output}

    def __call__(self, data):
        # Llamada principal del handler para procesamiento completo
        tokens = self.preprocess(data)
        outputs = self.inference(tokens)
        result = self.postprocess(outputs)
        return result