Hjgugugjhuhjggg committed
Commit d8245fc (verified)
1 Parent(s): 49f2c5a

Update app.py

Files changed (1)
  1. app.py +87 -92
app.py CHANGED
@@ -1,128 +1,123 @@
  import os
- import json
- import requests
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from google.cloud import storage
- from google.auth import exceptions
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- from io import BytesIO
- from dotenv import load_dotenv
- import uvicorn
- import tempfile

- load_dotenv()

  API_KEY = os.getenv("API_KEY")
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")

  try:
-     credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)
-     storage_client = storage.Client.from_service_account_info(credentials_info)
-     bucket = storage_client.bucket(GCS_BUCKET_NAME)
- except (exceptions.DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
-     raise RuntimeError(f"Error loading credentials or bucket: {e}")

  app = FastAPI()

- class DownloadModelRequest(BaseModel):
      model_name: str
      pipeline_task: str
      input_text: str

- class GCSHandler:
-     def __init__(self, bucket_name):
-         self.bucket = storage_client.bucket(bucket_name)
-
-     def file_exists(self, blob_name):
-         return self.bucket.blob(blob_name).exists()

-     def upload_file(self, blob_name, file_stream):
-         blob = self.bucket.blob(blob_name)
-         blob.upload_from_file(file_stream)

-     def download_file(self, blob_name):
-         blob = self.bucket.blob(blob_name)
-         if not blob.exists():
-             raise HTTPException(status_code=404, detail=f"File '{blob_name}' not found.")
-         return BytesIO(blob.download_as_bytes())

- def download_model_from_huggingface(model_name):
-     url = f"https://huggingface.co/{model_name}/tree/main"
-     headers = {"Authorization": f"Bearer {HF_API_TOKEN}"}
-
      try:
-         response = requests.get(url, headers=headers)
-         if response.status_code == 200:
-             model_files = [
-                 "pytorch_model.bin",
-                 "config.json",
-                 "tokenizer.json",
-                 "model.safetensors",
-             ]
-             for file_name in model_files:
-                 file_url = f"https://huggingface.co/{model_name}/resolve/main/{file_name}"
-                 file_content = requests.get(file_url).content
-                 blob_name = f"{model_name}/{file_name}"
-                 bucket.blob(blob_name).upload_from_file(BytesIO(file_content))
-         else:
-             raise HTTPException(status_code=404, detail="Error accessing the Hugging Face file tree.")
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error downloading files from Hugging Face: {e}")

- @app.post("/predict/")
- async def predict(request: DownloadModelRequest):
      try:
-         gcs_handler = GCSHandler(GCS_BUCKET_NAME)
-         model_prefix = request.model_name
-         model_files = [
-             "pytorch_model.bin",
-             "config.json",
-             "tokenizer.json",
-             "model.safetensors",
-         ]
-
-         model_files_exist = all(gcs_handler.file_exists(f"{model_prefix}/{file}") for file in model_files)

-         if not model_files_exist:
-             download_model_from_huggingface(model_prefix)

-         model_files_streams = {file: gcs_handler.download_file(f"{model_prefix}/{file}") for file in model_files if gcs_handler.file_exists(f"{model_prefix}/{file}")}
-
-         config_stream = model_files_streams.get("config.json")
-         tokenizer_stream = model_files_streams.get("tokenizer.json")
-         model_stream = model_files_streams.get("pytorch_model.bin")
-
-         if not config_stream or not tokenizer_stream or not model_stream:
-             raise HTTPException(status_code=500, detail="Required model files missing.")
-
-         with tempfile.TemporaryDirectory() as tmp_dir:
-             config_path = os.path.join(tmp_dir, "config.json")
-             tokenizer_path = os.path.join(tmp_dir, "tokenizer.json")
-             model_path = os.path.join(tmp_dir, "pytorch_model.bin")
-
-             with open(config_path, 'wb') as f:
-                 f.write(config_stream.read())
-             with open(tokenizer_path, 'wb') as f:
-                 f.write(tokenizer_stream.read())
-             with open(model_path, 'wb') as f:
-                 f.write(model_stream.read())
-
-             model = AutoModelForCausalLM.from_pretrained(tmp_dir, from_tf=True)
-             tokenizer = AutoTokenizer.from_pretrained(tmp_dir)
-
-             pipeline_ = pipeline(request.pipeline_task, model=model, tokenizer=tokenizer)
-
-             result = pipeline_(request.input_text)

          return {"response": result}
-
      except HTTPException as e:
          raise e
      except Exception as e:
-         raise HTTPException(status_code=500, detail=f"Error: {e}")

  if __name__ == "__main__":
      uvicorn.run(app, host="0.0.0.0", port=7860)
  import os
+ import logging
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from google.cloud import storage
+ from transformers import pipeline
+ import json
+ from google.auth.exceptions import DefaultCredentialsError

+ # GCS configuration

+ # Load the environment variables
  API_KEY = os.getenv("API_KEY")
  GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME")
  GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON")
  HF_API_TOKEN = os.getenv("HF_API_TOKEN")

+ # Logging configuration
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
  try:
+     # Try to load the GCS service-account credentials from the environment variable
+     credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON)  # Parse the credentials JSON
+     storage_client = storage.Client.from_service_account_info(credentials_info)  # Create the GCS client
+     bucket = storage_client.bucket(GCS_BUCKET_NAME)  # Get the bucket
+
+     # Connection verified
+     logger.info(f"Successfully connected to Google Cloud Storage. Bucket: {GCS_BUCKET_NAME}")
+
+ except (DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e:
+     # Handle missing or invalid credentials
+     logger.error(f"Error loading credentials or bucket: {e}")
+     raise RuntimeError(f"Error loading credentials or bucket: {e}")

+ # Set up the FastAPI application
  app = FastAPI()

+ # Logging configuration
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)
+
+ class PredictionRequest(BaseModel):
      model_name: str
      pipeline_task: str
      input_text: str

+ # Function to get the model's URL from GCS
+ def get_gcs_model_url(bucket_name: str, model_name: str):
+     """
+     Gets the model's URL from GCS.
+     """
+     try:
+         model_dir = f"models/{model_name}/"
+
+         # Check whether the model folder exists in GCS
+         bucket = storage_client.get_bucket(bucket_name)
+         blobs = bucket.list_blobs(prefix=model_dir)

+         # Check whether there are any files in the model directory
+         file_list = [blob.name for blob in blobs]
+         if not file_list:
+             raise HTTPException(status_code=404, detail="Model files not found in GCS.")
+
+         # Build the model's GCS URL (in this case just the directory path)
+         gcs_url = f"gs://{bucket_name}/{model_dir}"
+
+         return gcs_url

+     except Exception as e:
+         logger.error(f"Error getting the model URL from GCS: {str(e)}")
+         raise HTTPException(status_code=500, detail="Error getting the model URL from GCS.")

+ # Function to load the pipeline directly from GCS via its URL
+ def load_pipeline_from_gcs(model_name: str, pipeline_task: str):
+     """
+     Loads the pipeline directly from the model's URL in GCS without using RAM or temporary storage.
+     """
      try:
+         # Get the model's URL from GCS
+         model_url = get_gcs_model_url(GCS_BUCKET_NAME, model_name)
+
+         # Load the pipeline directly from the model's URL
+         nlp_pipeline = pipeline(
+             task=pipeline_task,
+             model=model_url,  # Use the GCS URL as the model
+         )
+
+         return nlp_pipeline
      except Exception as e:
+         logger.error(f"Error loading the pipeline from GCS: {str(e)}")
+         raise HTTPException(status_code=500, detail="Error loading the pipeline from GCS.")

+ # Prediction endpoint
+ @app.post("/predict")
+ def predict(request: PredictionRequest):
+     """
+     Endpoint that receives a POST request with a JSON body and runs the prediction.
+     """
      try:
+         # Extract the parameters from the JSON request
+         model_name = request.model_name
+         pipeline_task = request.pipeline_task
+         input_text = request.input_text

+         # Load the pipeline directly from GCS without using RAM or temporary storage
+         nlp_pipeline = load_pipeline_from_gcs(model_name, pipeline_task)

+         # Run the prediction
+         result = nlp_pipeline(input_text)

          return {"response": result}
+
      except HTTPException as e:
+         logger.error(f"Prediction error: {e.detail}")
          raise e
      except Exception as e:
+         logger.error(f"Prediction error: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))

  if __name__ == "__main__":
+     import uvicorn
      uvicorn.run(app, host="0.0.0.0", port=7860)
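
For reference, a call to the new /predict endpoint sends the three fields of PredictionRequest as JSON. A minimal sketch of a client (the model name and task below are illustrative placeholders, not values from this commit):

    import requests

    # Hypothetical request to the service started by uvicorn on port 7860;
    # "gpt2" and "text-generation" are placeholder values.
    response = requests.post(
        "http://localhost:7860/predict",
        json={
            "model_name": "gpt2",
            "pipeline_task": "text-generation",
            "input_text": "Hello, world",
        },
    )
    print(response.json())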
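
One caveat on the new approach: transformers.pipeline resolves its model argument as a Hugging Face Hub model id or a local directory path; it does not fetch gs:// URLs, so load_pipeline_from_gcs will most likely fail when pipeline() tries to resolve model_url. A minimal sketch of a workaround, assuming the same models/{model_name}/ bucket layout (the helper name and the /tmp cache path are made up for illustration), downloads the blobs to a local directory first:

    import os
    from google.cloud import storage
    from transformers import pipeline

    def load_pipeline_via_local_dir(storage_client: storage.Client, bucket_name: str,
                                    model_name: str, pipeline_task: str):
        # Hypothetical helper, not part of the commit: mirror the model files
        # from gs://{bucket_name}/models/{model_name}/ into a local directory.
        local_dir = os.path.join("/tmp/models", model_name)
        os.makedirs(local_dir, exist_ok=True)
        prefix = f"models/{model_name}/"
        for blob in storage_client.list_blobs(bucket_name, prefix=prefix):
            filename = os.path.basename(blob.name)
            if filename:  # skip "directory" placeholder blobs
                blob.download_to_filename(os.path.join(local_dir, filename))
        # A local path is something pipeline() can load directly.
        return pipeline(task=pipeline_task, model=local_dir)

This does use local disk, unlike what the comments in the commit claim, but it keeps the GCS bucket as the single source of model files.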
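
A smaller point: the bare except Exception in get_gcs_model_url also catches the HTTPException(status_code=404) raised a few lines above it and re-raises it as a 500. If the 404 is meant to reach the client, re-raising HTTPException before the generic handler would preserve it, e.g. as drop-in replacement except clauses for that function:

    except HTTPException:
        raise  # keep the intended 404 instead of converting it to a 500
    except Exception as e:
        logger.error(f"Error getting the model URL from GCS: {str(e)}")
        raise HTTPException(status_code=500, detail="Error getting the model URL from GCS.")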