"""FastAPI service exposing the ``afabd`` DuckDB dataset over HTTP.

Endpoints:
    POST /sql  -- run a SQL statement against a read-only connection to
                  ``afabd.db`` and return the result rows as JSON records.
    POST /llm  -- answer a natural-language prompt using the query engine
                  built by ``query_engine.set_query_engine`` over
                  ``llm_afabd.db``.

Both database files are downloaded from the Hugging Face dataset repo
``pdrMottaS/afabd-duckdb`` in the application's startup hook.
"""
import json
import os

import duckdb
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download
from llama_index.core.indices.struct_store import NLSQLTableQueryEngine

from models import SQL, Prompt
from query_engine import set_query_engine

app = FastAPI()

# NOTE(review): every origin, method and header is allowed, and credentials
# are enabled. Nothing visible in this file uses cookies or auth, so
# allow_credentials=True looks unnecessary -- confirm and consider disabling it
# or restricting the origin list.
origins = ["*"]
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level state populated by startup(); the handlers below read it.
query_engine: NLSQLTableQueryEngine
query_file_path = ""
llm_file_path = ""


@app.on_event("startup")
def startup():
    """Download both DuckDB files and build the natural-language query engine.

    ``llm_afabd.db`` is handed to ``set_query_engine`` for the /llm endpoint;
    ``afabd.db`` is the database that /sql statements run against.
    """
    global query_file_path, llm_file_path, query_engine
    dataset_name = "pdrMottaS/afabd-duckdb"
    llm_file_path = hf_hub_download(
        repo_id=dataset_name, filename="llm_afabd.db", repo_type="dataset"
    )
    query_file_path = hf_hub_download(
        repo_id=dataset_name, filename="afabd.db", repo_type="dataset"
    )
    query_engine = set_query_engine(llm_file_path)


@app.post("/sql")
def query_database(query_data: SQL):
    """Execute ``query_data.query`` against ``afabd.db`` and return the rows.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    worker thread pool: the DuckDB call is blocking and would otherwise stall
    the event loop for every other request.

    Returns:
        A JSON array with one object per result row.

    Raises:
        HTTPException: 400 carrying DuckDB's error message when the statement
            fails to parse, bind, or execute.
    """
    # SECURITY(review): this runs caller-supplied SQL verbatim. read_only=True
    # prevents writes to the database file, but DuckDB table functions can
    # still reach other files/URLs unless external access is disabled (e.g.
    # config={"enable_external_access": False}). Review before exposing this.
    #
    # The context manager guarantees the connection is closed after every
    # request, including failed ones.
    with duckdb.connect(query_file_path, read_only=True) as conn:
        try:
            df = conn.execute(query_data.query).fetch_df()
        except duckdb.Error as err:
            # A bad statement is a client error, not a server fault.
            raise HTTPException(status_code=400, detail=str(err)) from err
    # Round-trip through pandas' JSON writer to obtain plain JSON-compatible
    # Python values for the response body.
    return JSONResponse(json.loads(df.to_json(orient="records")))


@app.post("/llm")
def llm(prompt_data: Prompt):
    """Answer ``prompt_data.promt`` with the natural-language query engine.

    Plain ``def`` so the blocking LLM call runs in FastAPI's thread pool.
    The engine's result is converted with ``str()`` because ``JSONResponse``
    serialises with ``json.dumps`` and cannot encode the engine's response
    object directly. The ``promt`` spelling matches the ``Prompt`` model
    field and the existing response key, so it is kept for client
    compatibility.
    """
    response = query_engine.query(prompt_data.promt)
    return JSONResponse({"promt": prompt_data.promt, "response": str(response)})


if __name__ == "__main__":
    # Only start the server when executed as a script, so importing this
    # module (e.g. via `uvicorn module:app`) does not launch a second server.
    uvicorn.run(app, host="0.0.0.0", port=7860)