alexpantex commited on
Commit
f4e126c
·
verified ·
1 Parent(s): dd047bd

Upload scripts/api.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/api.py +67 -0
scripts/api.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ sys.path.append(sys.path[0].replace('scripts', ''))
3
+ import os
4
+ from typing import List
5
+
6
+ from fastapi import FastAPI, HTTPException
7
+ from pydantic import BaseModel
8
+ from scripts.run import config, search_engine
9
+ from scripts.preprocess import preprocess_text
10
+
11
+ app = FastAPI(
12
+ title="Prompt Search API",
13
+ description="A RESTful API to find top-n most similar prompts.",
14
+ version="1.0.0"
15
+ )
16
+
17
+ class QueryRequest(BaseModel):
18
+ query: str
19
+ n_results: int = 5
20
+
21
+ class SimilarQuery(BaseModel):
22
+ prompt: str
23
+ score: float
24
+
25
+ class QueryResponse(BaseModel):
26
+ query: str
27
+ similar_queries: List[SimilarQuery]
28
+
29
+ @app.get("/")
30
+ def root():
31
+ return {"message": "Welcome to the Prompt Search API. Use '/search' endpoint to find similar prompts."}
32
+
33
+ @app.post("/search", response_model=QueryResponse)
34
+ async def search_prompts(query_request: QueryRequest):
35
+ """
36
+ Accepts a query prompt and returns the top n similar prompts.
37
+ Args:
38
+ query_request: JSON input with query prompt and number of results to return.
39
+ Returns:
40
+ A list of top-n similar prompts with similarity scores.
41
+ """
42
+
43
+ query = query_request.query
44
+ n_results = query_request.n_results
45
+
46
+ if not query.strip():
47
+ raise HTTPException(status_code=400, detail="Query prompt cannot be empty.")
48
+ if n_results <= 0:
49
+ raise HTTPException(status_code=400, detail="Number of results must be greater than zero.")
50
+
51
+ try:
52
+ q = preprocess_text(query)
53
+ print(q)
54
+ results = search_engine.most_similar(q, n=n_results)
55
+
56
+ print("Results:", results) # Check if results have expected structure
57
+ result_dict = [{"prompt": r['prompt'], "score": float(r['score'])} for r in results]
58
+ return QueryResponse(query=query, similar_queries=result_dict)
59
+ # return [{"query": query, "similar_queries": results}]
60
+ except Exception as e:
61
+ print(e)
62
+ raise HTTPException(status_code=500, detail=str(e))
63
+
64
+ # Entry point for configuring parameters and running the app
65
+ if __name__ == "__main__":
66
+ import uvicorn
67
+ uvicorn.run(app, host = config["server"]["host"], port = int(os.getenv("PORT", config["server"]["port"])))