import os import logging from fastapi import FastAPI, HTTPException from pydantic import BaseModel from google.cloud import storage from transformers import pipeline import json from google.auth.exceptions import DefaultCredentialsError # Configuración de GCS # Cargar las variables de entorno API_KEY = os.getenv("API_KEY") GCS_BUCKET_NAME = os.getenv("GCS_BUCKET_NAME") GOOGLE_APPLICATION_CREDENTIALS_JSON = os.getenv("GOOGLE_APPLICATION_CREDENTIALS_JSON") HF_API_TOKEN = os.getenv("HF_API_TOKEN") # Configuración de logs logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) try: # Intentar cargar las credenciales de servicio de GCS desde la variable de entorno credentials_info = json.loads(GOOGLE_APPLICATION_CREDENTIALS_JSON) # Cargar el JSON de credenciales storage_client = storage.Client.from_service_account_info(credentials_info) # Crear cliente de GCS bucket = storage_client.bucket(GCS_BUCKET_NAME) # Acceder al bucket # Verificación exitosa logger.info(f"Conexión con Google Cloud Storage exitosa. Bucket: {GCS_BUCKET_NAME}") except (DefaultCredentialsError, json.JSONDecodeError, KeyError, ValueError) as e: # Manejo de errores en caso de que las credenciales sean incorrectas o faltantes logger.error(f"Error al cargar las credenciales o bucket: {e}") raise RuntimeError(f"Error al cargar las credenciales o bucket: {e}") # Configurar la aplicación FastAPI app = FastAPI() # Configuración de logs logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) class PredictionRequest(BaseModel): model_name: str pipeline_task: str input_text: str # Función para obtener la URL del modelo desde GCS def get_gcs_model_url(bucket_name: str, model_name: str): """ Obtiene la URL del modelo desde GCS. """ try: model_dir = f"models/{model_name}/" # Verificar si la carpeta del modelo existe en GCS bucket = storage_client.get_bucket(bucket_name) blobs = bucket.list_blobs(prefix=model_dir) # Verificar si existen archivos en el directorio del modelo file_list = [blob.name for blob in blobs] if not file_list: raise HTTPException(status_code=404, detail="No se encontraron los archivos del modelo en GCS.") # Construir la URL GCS del modelo (en este caso solo la ruta del directorio) gcs_url = f"gs://{bucket_name}/{model_dir}" return gcs_url except Exception as e: logger.error(f"Error al obtener la URL del modelo desde GCS: {str(e)}") raise HTTPException(status_code=500, detail="Error al obtener la URL del modelo desde GCS.") # Función para cargar el pipeline directamente desde GCS como URL def load_pipeline_from_gcs(model_name: str, pipeline_task: str): """ Carga el pipeline directamente desde la URL del modelo en GCS sin usar RAM ni almacenamiento temporal. """ try: # Obtener la URL del modelo desde GCS model_url = get_gcs_model_url(GCS_BUCKET_NAME, model_name) # Cargar el pipeline directamente desde la URL del modelo nlp_pipeline = pipeline( task=pipeline_task, model=model_url, # Usamos la URL de GCS como modelo ) return nlp_pipeline except Exception as e: logger.error(f"Error al cargar el pipeline desde GCS: {str(e)}") raise HTTPException(status_code=500, detail="Error al cargar el pipeline desde GCS.") # Endpoint para realizar la predicción @app.post("/predict") def predict(request: PredictionRequest): """ Endpoint para recibir solicitudes POST con datos JSON y realizar la predicción. """ try: # Extraer los parámetros de la solicitud JSON model_name = request.model_name pipeline_task = request.pipeline_task input_text = request.input_text # Cargar el pipeline directamente desde GCS sin usar RAM ni almacenamiento temporal nlp_pipeline = load_pipeline_from_gcs(model_name, pipeline_task) # Realizar la predicción result = nlp_pipeline(input_text) return {"response": result} except HTTPException as e: logger.error(f"Error en la predicción: {e.detail}") raise e except Exception as e: logger.error(f"Error en la predicción: {str(e)}") raise HTTPException(status_code=500, detail=str(e)) if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)