import os
from fastapi import FastAPI, Depends, HTTPException
from fastapi.responses import JSONResponse, RedirectResponse
from src.api.models.embedding_models import (
    CreateEmbeddingRequest,
    UpdateEmbeddingRequest,
    DeleteEmbeddingRequest,
)
from src.api.database import get_db, Database, QueryExecutionError, HealthCheckError
from src.api.services.embedding_service import EmbeddingService
from src.api.services.huggingface_service import HuggingFaceService
from src.api.exceptions import DatasetNotFoundError, DatasetPushError, OpenAIError
import pandas as pd
import logging
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Set up structured logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

description = """A FastAPI application for similarity search with PostgreSQL and OpenAI embeddings.

Direct/API URL:
https://re-mind-similarity-search.hf.space
"""

# Initialize FastAPI app
app = FastAPI(
    title="Similarity Search API",
    description=description,
    version="1.0.0",
)


# Root endpoint redirects to /docs
@app.get("/")
async def root():
    return RedirectResponse(url="/docs")


# Health check endpoint
@app.get("/health")
async def health_check(db: Database = Depends(get_db)):
    try:
        is_healthy = await db.health_check()
        if not is_healthy:
            raise HTTPException(status_code=500, detail="Database is unhealthy")
        return {"status": "healthy"}
    except HealthCheckError as e:
        raise HTTPException(status_code=500, detail=str(e))
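
# Illustrative health check call (host and port are assumptions for a local
# run); a healthy database yields {"status": "healthy"}:
#   curl http://localhost:7860/health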


# Dependency to get EmbeddingService
def get_embedding_service() -> EmbeddingService:
    return EmbeddingService(openai_api_key=os.getenv("OPENAI_API_KEY"))


# Dependency to get HuggingFaceService
def get_huggingface_service() -> HuggingFaceService:
    return HuggingFaceService()


# Endpoint to create embeddings
@app.post("/create_embedding")
async def create_embedding(
    request: CreateEmbeddingRequest,
    db: Database = Depends(get_db),
    embedding_service: EmbeddingService = Depends(get_embedding_service),
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Create embeddings for the target column in the dataset.
    """
    try:
        # Step 1: Query the database
        logger.info("Fetching data from the database...")
        result = await db.fetch(request.query)
        df = pd.DataFrame(result)

        # Step 2: Generate embeddings
        df = await embedding_service.create_embeddings(
            df, request.target_column, request.output_column
        )

        # Step 3: Push to Hugging Face Hub
        await huggingface_service.push_to_hub(df, request.dataset_name)

        return JSONResponse(
            content={
                "message": "Embeddings created and pushed to Hugging Face Hub.",
                "dataset_name": request.dataset_name,
                "num_rows": len(df),
            }
        )
    except QueryExecutionError as e:
        logger.error(f"Database query failed: {e}")
        raise HTTPException(status_code=500, detail=f"Database query failed: {e}")
    except OpenAIError as e:
        logger.error(f"OpenAI API error: {e}")
        raise HTTPException(status_code=500, detail=f"OpenAI API error: {e}")
    except DatasetPushError as e:
        logger.error(f"Failed to push dataset: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to push dataset: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")


# Endpoint to read embeddings
@app.get("/read_embeddings/{dataset_name}")
async def read_embeddings(
    dataset_name: str,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Read embeddings from a Hugging Face dataset.
    """
    try:
        df = await huggingface_service.read_dataset(dataset_name)
        return df.to_dict(orient="records")
    except DatasetNotFoundError as e:
        logger.error(f"Dataset not found: {e}")
        raise HTTPException(status_code=404, detail=f"Dataset not found: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")


# Endpoint to update embeddings
@app.post("/update_embeddings")
async def update_embeddings(
    request: UpdateEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Update embeddings in a Hugging Face dataset.
    """
    try:
        await huggingface_service.update_dataset(
            request.dataset_name, request.updates
        )
        return {
            "message": "Embeddings updated successfully.",
            "dataset_name": request.dataset_name,
        }
    except DatasetPushError as e:
        logger.error(f"Failed to update dataset: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to update dataset: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")


# Endpoint to delete embeddings
@app.post("/delete_embeddings")
async def delete_embeddings(
    request: DeleteEmbeddingRequest,
    huggingface_service: HuggingFaceService = Depends(get_huggingface_service),
):
    """
    Delete embeddings from a Hugging Face dataset.
    """
    try:
        await huggingface_service.delete_dataset(request.dataset_name)
        return {
            "message": "Embeddings deleted successfully.",
            "dataset_name": request.dataset_name,
        }
    except DatasetPushError as e:
        logger.error(f"Failed to delete columns: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to delete columns: {e}")
    except Exception as e:
        logger.error(f"An error occurred: {e}")
        raise HTTPException(status_code=500, detail=f"An error occurred: {e}")
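

# Optional local entry point. Serving with uvicorn is a standard way to run a
# FastAPI app; the host and port below are assumptions (7860 is the port
# Hugging Face Spaces conventionally expects) and may differ from the actual
# deployment command.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)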