pvanand commited on
Commit
63df3f2
1 Parent(s): 823b7be

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +99 -0
main.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from pydantic import BaseModel
4
+ from typing import List
5
+ import json
6
+ import os
7
+ import logging
8
+ from txtai.embeddings import Embeddings
9
+
10
+ # Set up logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ app = FastAPI()
15
+
16
+ # Enable CORS
17
+ app.add_middleware(
18
+ CORSMiddleware,
19
+ allow_origins=["*"], # Allows all origins
20
+ allow_credentials=True,
21
+ allow_methods=["*"], # Allows all methods
22
+ allow_headers=["*"], # Allows all headers
23
+ )
24
+
25
+ embeddings = Embeddings({"path": "avsolatorio/GIST-all-MiniLM-L6-v2"})
26
+
27
+ class DocumentRequest(BaseModel):
28
+ index_id: str
29
+ documents: List[str]
30
+
31
+ class QueryRequest(BaseModel):
32
+ index_id: str
33
+ query: str
34
+ num_results: int
35
+
36
+ def save_embeddings(index_id, document_list):
37
+ try:
38
+ folder_path = f"indexes/{index_id}"
39
+ os.makedirs(folder_path, exist_ok=True)
40
+
41
+ # Save embeddings
42
+ embeddings.save(f"{folder_path}/embeddings")
43
+
44
+ # Save document_list
45
+ with open(f"{folder_path}/document_list.json", "w") as f:
46
+ json.dump(document_list, f)
47
+ logger.info(f"Embeddings and document list saved for index_id: {index_id}")
48
+ except Exception as e:
49
+ logger.error(f"Error saving embeddings for index_id {index_id}: {str(e)}")
50
+ raise HTTPException(status_code=500, detail=f"Error saving embeddings: {str(e)}")
51
+
52
+ def load_embeddings(index_id):
53
+ try:
54
+ folder_path = f"indexes/{index_id}"
55
+
56
+ if not os.path.exists(folder_path):
57
+ logger.error(f"Index not found for index_id: {index_id}")
58
+ raise HTTPException(status_code=404, detail="Index not found")
59
+
60
+ # Load embeddings
61
+ embeddings.load(f"{folder_path}/embeddings")
62
+
63
+ # Load document_list
64
+ with open(f"{folder_path}/document_list.json", "r") as f:
65
+ document_list = json.load(f)
66
+
67
+ logger.info(f"Embeddings and document list loaded for index_id: {index_id}")
68
+ return document_list
69
+ except Exception as e:
70
+ logger.error(f"Error loading embeddings for index_id {index_id}: {str(e)}")
71
+ raise HTTPException(status_code=500, detail=f"Error loading embeddings: {str(e)}")
72
+
73
+ @app.post("/create_index/")
74
+ async def create_index(request: DocumentRequest):
75
+ try:
76
+ document_list = [(i, text, None) for i, text in enumerate(request.documents)]
77
+ embeddings.index(document_list)
78
+ save_embeddings(request.index_id, request.documents) # Save the original documents
79
+ logger.info(f"Index created successfully for index_id: {request.index_id}")
80
+ return {"message": "Index created successfully"}
81
+ except Exception as e:
82
+ logger.error(f"Error creating index: {str(e)}")
83
+ raise HTTPException(status_code=500, detail=f"Error creating index: {str(e)}")
84
+
85
+ @app.post("/query_index/")
86
+ async def query_index(request: QueryRequest):
87
+ try:
88
+ document_list = load_embeddings(request.index_id)
89
+ results = embeddings.search(request.query, request.num_results)
90
+ queried_texts = [document_list[idx[0]] for idx in results]
91
+ logger.info(f"Query executed successfully for index_id: {request.index_id}")
92
+ return {"queried_texts": queried_texts}
93
+ except Exception as e:
94
+ logger.error(f"Error querying index: {str(e)}")
95
+ raise HTTPException(status_code=500, detail=f"Error querying index: {str(e)}")
96
+
97
+ if __name__ == "__main__":
98
+ import uvicorn
99
+ uvicorn.run(app, host="0.0.0.0", port=7860)