alexpantex's picture
Upload scripts/api.py with huggingface_hub
f4e126c verified
raw
history blame
2.19 kB
import logging
import os
import sys
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Strip 'scripts' from the first sys.path entry and append the result, so the
# `scripts.*` imports below resolve when this file is run from scripts/.
# NOTE(review): str.replace removes every occurrence of 'scripts' in the path.
sys.path.append(sys.path[0].replace('scripts', ''))

from scripts.run import config, search_engine
from scripts.preprocess import preprocess_text
# The FastAPI application object; title, description and version are passed
# to FastAPI as the API's metadata.
app = FastAPI(
    version="1.0.0",
    title="Prompt Search API",
    description="A RESTful API to find top-n most similar prompts.",
)
class QueryRequest(BaseModel):
    """Request body accepted by the ``POST /search`` endpoint."""

    # Prompt text to search for; the /search handler rejects blank values (HTTP 400).
    query: str
    # How many similar prompts to request; defaults to 5. The /search handler
    # rejects values <= 0 (HTTP 400).
    n_results: int = 5
class SimilarQuery(BaseModel):
    """A single search hit returned by ``POST /search``."""

    # Text of the matching prompt, taken from the search result's 'prompt' entry.
    prompt: str
    # Similarity score from the search result's 'score' entry, coerced to float
    # by the handler.
    # NOTE(review): whether higher means "more similar" depends on search_engine — confirm.
    score: float
class QueryResponse(BaseModel):
    """Response body (``response_model``) of ``POST /search``."""

    # The query exactly as submitted by the client (not the preprocessed form).
    query: str
    # Matches from the search engine, in the order the engine returned them.
    similar_queries: List[SimilarQuery]
@app.get("/")
def root():
    """Greet the caller and point them at the ``/search`` endpoint."""
    welcome_text = "Welcome to the Prompt Search API. Use '/search' endpoint to find similar prompts."
    return dict(message=welcome_text)
logger = logging.getLogger(__name__)


@app.post("/search", response_model=QueryResponse)
async def search_prompts(query_request: QueryRequest):
    """
    Accepts a query prompt and returns the top n similar prompts.

    The query is passed through ``preprocess_text`` before searching, but the
    response echoes the query exactly as the client sent it.

    Args:
        query_request: JSON input with query prompt and number of results to return.

    Returns:
        QueryResponse holding the original query and the prompts returned by
        ``search_engine.most_similar(..., n=n_results)``, each with its score.

    Raises:
        HTTPException: 400 if the query is blank or ``n_results`` is not
            positive; 500 if preprocessing, searching, or reading the
            results fails.
    """
    query = query_request.query
    n_results = query_request.n_results

    if not query.strip():
        raise HTTPException(status_code=400, detail="Query prompt cannot be empty.")
    if n_results <= 0:
        raise HTTPException(status_code=400, detail="Number of results must be greater than zero.")

    try:
        processed_query = preprocess_text(query)
        logger.debug("Preprocessed query: %r", processed_query)
        # NOTE(review): most_similar runs synchronously inside an async handler,
        # so a slow search blocks the event loop. If it is CPU-heavy, consider a
        # plain `def` endpoint or fastapi.concurrency.run_in_threadpool.
        results = search_engine.most_similar(processed_query, n=n_results)
        logger.debug("Search results: %r", results)
        # Built inside the try so malformed engine output (missing keys,
        # non-numeric scores) is reported as a 500 like other search failures.
        similar = [
            SimilarQuery(prompt=r["prompt"], score=float(r["score"]))
            for r in results
        ]
    except Exception as exc:
        # Keep the full traceback in the server log. Do not echo internal
        # error text back to the (untrusted) client.
        logger.exception("Search failed for query %r", query)
        raise HTTPException(
            status_code=500, detail="Internal error while searching prompts."
        ) from exc

    return QueryResponse(query=query, similar_queries=similar)
# Script entry point: serve the app with uvicorn, binding to the host from
# config["server"] and to the PORT environment variable if set, falling back
# to config["server"]["port"].
if __name__ == "__main__":
    import uvicorn

    server_settings = config["server"]
    bind_host = server_settings["host"]
    bind_port = int(os.getenv("PORT", server_settings["port"]))
    uvicorn.run(app, host=bind_host, port=bind_port)