# amaye15
# Feat - Delete Embeddings - Updated
# a106258
# raw
# history blame
# 6.18 kB
import os
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import JSONResponse, RedirectResponse
from pydantic import BaseModel
from typing import List, Dict
from src.api.models.embedding_models import (
CreateEmbeddingRequest,
UpdateEmbeddingRequest,
DeleteEmbeddingRequest,
)
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
from src.api.services.embedding_service import EmbeddingService
from src.api.services.huggingface_service import HuggingFaceService
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
import pandas as pd
import logging
from dotenv import load_dotenv
# Load environment variables from a local .env file (if one exists) into
# os.environ. This must run before anything reads the environment, e.g. the
# OPENAI_API_KEY lookup in get_embedding_service.
load_dotenv()
# Set up structured logging. basicConfig configures the root logger (and is a
# no-op if the root logger already has handlers), so call it once at import.
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
# Module-level logger used by the endpoint handlers below.
logger = logging.getLogger(__name__)
# Text shown as the API description in the generated OpenAPI docs (/docs).
# NOTE: the line breaks inside this string are part of its content.
description = """A FastAPI application for similarity search with PostgreSQL and OpenAI embeddings.
Direct/API URL:
https://re-mind-similarity-search.hf.space
"""
# Initialize FastAPI app
app = FastAPI(
    title="Similarity Search API",
    description=description,
    version="1.0.0",
)
@app.get("/")
async def root():
    """Send requests for the bare root URL to the interactive API docs."""
    docs_location = "/docs"
    return RedirectResponse(url=docs_location)
@app.get("/health")
async def health_check(db: Database = Depends(get_db)):
    """
    Report whether the service's database is healthy.

    Returns:
        ``{"status": "healthy"}`` when ``db.health_check()`` returns a truthy
        value.

    Raises:
        HTTPException: 500 when ``db.health_check()`` returns a falsy value,
            or when it raises ``HealthCheckError`` (the error text becomes
            the response detail).
    """
    # Keep the try body minimal: only the call that can raise HealthCheckError.
    try:
        is_healthy = await db.health_check()
    except HealthCheckError as e:
        logger.error(f"Database health check failed: {e}")
        # Chain the cause so the original traceback is not lost.
        raise HTTPException(status_code=500, detail=str(e)) from e

    if not is_healthy:
        raise HTTPException(status_code=500, detail="Database is unhealthy")
    return {"status": "healthy"}
def get_embedding_service() -> EmbeddingService:
    """
    FastAPI dependency that builds an ``EmbeddingService``.

    The OpenAI key comes from the ``OPENAI_API_KEY`` environment variable;
    if that variable is unset, ``None`` is passed through unchanged.
    """
    api_key = os.environ.get("OPENAI_API_KEY")
    return EmbeddingService(openai_api_key=api_key)
def get_huggingface_service() -> HuggingFaceService:
    """FastAPI dependency returning a newly constructed ``HuggingFaceService``."""
    service = HuggingFaceService()
    return service
@app.post("/create_embedding")
async def create_embedding(
    request: CreateEmbeddingRequest,
    db: Database = Depends(get_db),
    embedding_service: EmbeddingService = Depends(get_embedding_service),
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Create embeddings for the target column in the dataset.

    Runs ``request.query`` against the database, embeds
    ``request.target_column`` into ``request.output_column`` and pushes the
    resulting table to the Hugging Face Hub as ``request.dataset_name``.

    Returns:
        JSONResponse with a confirmation message, the dataset name and the
        number of rows pushed.

    Raises:
        HTTPException: 500 on a database query error, an OpenAI error, a
            dataset push error, or any other unexpected failure.
    """
    try:
        # Step 1: Query the database
        # NOTE(review): request.query is client-supplied and passed to
        # db.fetch verbatim, so callers can run arbitrary SQL with this
        # service's DB credentials. Confirm the endpoint is access-restricted
        # and/or the DB role is read-only.
        logger.info("Fetching data from the database...")
        result = await db.fetch(request.query)
        df = pd.DataFrame(result)
        # Step 2: Generate embeddings
        df = await embedding_service.create_embeddings(
            df, request.target_column, request.output_column
        )
        # Step 3: Push to Hugging Face Hub
        await huggingface_service.push_to_hub(df, request.dataset_name)
    except QueryExecutionError as e:
        logger.error(f"Database query failed: {e}")
        raise HTTPException(
            status_code=500, detail=f"Database query failed: {e}"
        ) from e
    except OpenAIError as e:
        logger.error(f"OpenAI API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}") from e
    except DatasetPushError as e:
        logger.error(f"Failed to push dataset: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to push dataset: {e}"
        ) from e
    except Exception as e:
        # Unexpected failure: logger.exception also records the traceback.
        logger.exception(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    return JSONResponse(
        content={
            "message": "Embeddings created and pushed to Hugging Face Hub.",
            "dataset_name": request.dataset_name,
            "num_rows": len(df),
        }
    )
@app.get("/read_embeddings/{dataset_name}")
async def read_embeddings(
    dataset_name: str,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Read embeddings from a Hugging Face dataset.

    Returns:
        The dataset rows as a list of dicts (``DataFrame.to_dict`` with
        ``orient="records"``).

    Raises:
        HTTPException: 404 when the service raises ``DatasetNotFoundError``;
            500 for any other failure while reading or converting the data.
    """
    try:
        df = await huggingface_service.read_dataset(dataset_name)
        # Conversion stays inside the try so its failures map to a 500 too.
        return df.to_dict(orient="records")
    except DatasetNotFoundError as e:
        logger.error(f"Dataset not found: {e}")
        raise HTTPException(status_code=404, detail=f"Dataset not found: {e}") from e
    except Exception as e:
        # Unexpected failure: logger.exception also records the traceback.
        logger.exception(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e
@app.post("/update_embeddings")
async def update_embeddings(
    request: UpdateEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Update embeddings in a Hugging Face dataset.

    Passes ``request.updates`` to the service for the dataset named
    ``request.dataset_name``.

    Returns:
        A confirmation message and the dataset name.

    Raises:
        HTTPException: 500 when the service raises ``DatasetPushError`` or
            any other exception.
    """
    try:
        # The service's return value is not needed for the response.
        await huggingface_service.update_dataset(
            request.dataset_name, request.updates
        )
    except DatasetPushError as e:
        logger.error(f"Failed to update dataset: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to update dataset: {e}"
        ) from e
    except Exception as e:
        # Unexpected failure: logger.exception also records the traceback.
        logger.exception(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    return {
        "message": "Embeddings updated successfully.",
        "dataset_name": request.dataset_name,
    }
@app.post("/delete_embeddings")
async def delete_embeddings(
    request: DeleteEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Delete embeddings from a Hugging Face dataset.

    Asks the service to delete the dataset named ``request.dataset_name``.

    Returns:
        A confirmation message and the dataset name.

    Raises:
        HTTPException: 500 when the service raises ``DatasetPushError`` or
            any other exception.
    """
    try:
        await huggingface_service.delete_dataset(request.dataset_name)
    except DatasetPushError as e:
        # The endpoint deletes the whole dataset (not individual columns),
        # so report it as such.
        logger.error(f"Failed to delete dataset: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to delete dataset: {e}"
        ) from e
    except Exception as e:
        # Unexpected failure: logger.exception also records the traceback.
        logger.exception(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}") from e

    return {
        "message": "Embeddings deleted successfully.",
        "dataset_name": request.dataset_name,
    }