File size: 3,599 Bytes
63df3f2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
import json
import os
import logging
from txtai.embeddings import Embeddings

# Configure root logging at INFO and create this module's logger; every
# helper and endpoint below reports success/failure through `logger`.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application object; routes are registered on it further down.
app = FastAPI()

# Enable CORS for every origin, method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True is a
# very permissive policy -- confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods
    allow_headers=["*"],  # Allows all headers
)

# Single module-level Embeddings instance shared by all request handlers:
# create_index() calls .index() on it, save_embeddings() calls .save(),
# load_embeddings() calls .load() and query_index() calls .search().
# NOTE(review): because the instance is shared, it presumably holds whichever
# index was built or loaded last -- confirm behaviour under concurrent requests.
embeddings = Embeddings({"path": "avsolatorio/GIST-all-MiniLM-L6-v2"})

class DocumentRequest(BaseModel):
    # Request body accepted by POST /create_index/.

    # Name of the index; save_embeddings() uses it as the folder name
    # under "indexes/".
    index_id: str
    # Document texts to index; create_index() uses each text's position in
    # this list as its id in the index.
    documents: List[str]

class QueryRequest(BaseModel):
    # Request body accepted by POST /query_index/.

    # Name of a previously created index; load_embeddings() reads it from
    # the folder "indexes/<index_id>".
    index_id: str
    # Search text handed to embeddings.search().
    query: str
    # Passed as the second argument to embeddings.search().
    num_results: int

def save_embeddings(index_id, document_list):
    """Persist the shared embeddings index and its documents to disk.

    Writes the current state of the module-level ``embeddings`` object to
    ``indexes/<index_id>/embeddings`` and the given documents to
    ``indexes/<index_id>/document_list.json``.

    Args:
        index_id: Name of the index; used as a single directory name under
            ``indexes/``.
        document_list: JSON-serialisable list of documents to store next to
            the index.

    Raises:
        HTTPException: 400 if ``index_id`` is not a plain directory name,
            500 if anything fails while writing.
    """
    # index_id originates from the request body and becomes part of a
    # filesystem path, so reject anything that could leave "indexes/".
    if (
        not isinstance(index_id, str)
        or index_id in ("", ".", "..")
        or any(ch in index_id for ch in ("/", "\\", "\0"))
    ):
        logger.error(f"Rejected invalid index_id: {index_id!r}")
        raise HTTPException(status_code=400, detail="Invalid index_id")

    try:
        folder_path = f"indexes/{index_id}"
        os.makedirs(folder_path, exist_ok=True)

        # Save embeddings
        embeddings.save(f"{folder_path}/embeddings")

        # Save document_list
        with open(f"{folder_path}/document_list.json", "w", encoding="utf-8") as f:
            json.dump(document_list, f)
        logger.info(f"Embeddings and document list saved for index_id: {index_id}")
    except Exception as e:
        logger.error(f"Error saving embeddings for index_id {index_id}: {str(e)}")
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(status_code=500, detail=f"Error saving embeddings: {str(e)}") from e

def load_embeddings(index_id):
    """Load a saved index into the shared embeddings object.

    Calls ``embeddings.load`` on ``indexes/<index_id>/embeddings`` and reads
    the documents stored in ``indexes/<index_id>/document_list.json``.

    Args:
        index_id: Name of the index; used as a single directory name under
            ``indexes/``.

    Returns:
        The document list decoded from ``document_list.json``.

    Raises:
        HTTPException: 400 if ``index_id`` is not a plain directory name,
            404 if the index folder does not exist, 500 if loading fails.
    """
    # index_id originates from the request body and becomes part of a
    # filesystem path, so reject anything that could leave "indexes/".
    if (
        not isinstance(index_id, str)
        or index_id in ("", ".", "..")
        or any(ch in index_id for ch in ("/", "\\", "\0"))
    ):
        logger.error(f"Rejected invalid index_id: {index_id!r}")
        raise HTTPException(status_code=400, detail="Invalid index_id")

    folder_path = f"indexes/{index_id}"

    # Checked outside the try block: raised inside it, the 404 would be
    # caught by the generic handler below and reported as a 500.
    if not os.path.exists(folder_path):
        logger.error(f"Index not found for index_id: {index_id}")
        raise HTTPException(status_code=404, detail="Index not found")

    try:
        # Load embeddings
        embeddings.load(f"{folder_path}/embeddings")

        # Load document_list
        with open(f"{folder_path}/document_list.json", "r", encoding="utf-8") as f:
            document_list = json.load(f)

        logger.info(f"Embeddings and document list loaded for index_id: {index_id}")
        return document_list
    except Exception as e:
        logger.error(f"Error loading embeddings for index_id {index_id}: {str(e)}")
        # Chain the cause so the original traceback is kept for debugging.
        raise HTTPException(status_code=500, detail=f"Error loading embeddings: {str(e)}") from e

@app.post("/create_index/")
async def create_index(request: DocumentRequest):
    """Build an index from the posted documents and save it to disk.

    Each document is indexed under its position in ``request.documents``;
    the original texts are then saved via ``save_embeddings`` so that
    ``query_index`` can map result ids back to text.

    Args:
        request: Index name and the documents to index.

    Returns:
        ``{"message": "Index created successfully"}`` on success.

    Raises:
        HTTPException: whatever ``save_embeddings`` raised, unchanged;
            otherwise 500 for any other failure.
    """
    try:
        # Each entry is an (id, text, None) tuple, with the list position as id.
        document_list = [(i, text, None) for i, text in enumerate(request.documents)]
        embeddings.index(document_list)
        save_embeddings(request.index_id, request.documents)  # Save the original documents
        logger.info(f"Index created successfully for index_id: {request.index_id}")
        return {"message": "Index created successfully"}
    except HTTPException:
        # Already a well-formed HTTP error (status + detail); re-wrapping it
        # would overwrite its status code and nest its message.
        raise
    except Exception as e:
        logger.error(f"Error creating index: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error creating index: {str(e)}") from e

@app.post("/query_index/")
async def query_index(request: QueryRequest):
    """Search a saved index and return the matching document texts.

    Loads the index named by ``request.index_id``, runs
    ``embeddings.search(query, num_results)`` and maps each hit back to its
    original text.

    Args:
        request: Index name, query text and number of results requested.

    Returns:
        ``{"queried_texts": [...]}`` with the texts in the order the search
        returned them.

    Raises:
        HTTPException: whatever ``load_embeddings`` raised, unchanged;
            otherwise 500 for any other failure.
    """
    try:
        document_list = load_embeddings(request.index_id)
        results = embeddings.search(request.query, request.num_results)
        # Element 0 of each hit is the document id, which is the document's
        # position in the stored list.
        queried_texts = [document_list[hit[0]] for hit in results]
        logger.info(f"Query executed successfully for index_id: {request.index_id}")
        return {"queried_texts": queried_texts}
    except HTTPException:
        # Keep the status code and detail chosen by load_embeddings instead
        # of masking every failure as a generic 500.
        raise
    except Exception as e:
        logger.error(f"Error querying index: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error querying index: {str(e)}") from e

# Start the server only when this file is executed directly, not on import.
if __name__ == "__main__":
    # uvicorn is imported here, so it is only needed for this launch path.
    import uvicorn
    # Serve the app on all network interfaces, port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)