LLMBB-Agent / database_server.py
vlff李飞飞
优化
2e2dc41
import multiprocessing
import json
import os
import uvicorn
from fastapi import FastAPI, Request, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from utils import extract_and_cache_document, service, cache_file_popup_url, cache_root, cache_file, code_interpreter_ws, update_pop_url, change_checkbox_state
from starlette.middleware.sessions import SessionMiddleware
# Optional: redirect the Hugging Face model cache (currently disabled):
# os.environ["TRANSFORMERS_CACHE"] = ".cache/huggingface/"

app = FastAPI()

# Cross-origin settings. ``allow_origins`` is not passed here, so the
# CORSMiddleware default applies (presumably no allowed origins — confirm
# against the installed Starlette version).
app.add_middleware(
    CORSMiddleware,
    allow_headers=['*'],
    allow_methods=['*'],
    allow_credentials=True,
)

# Expose the ``code_interpreter_ws`` directory under the /static URL prefix.
app.mount(
    '/static',
    StaticFiles(directory=code_interpreter_ws),
    name='static',
)
@app.middleware("http")
async def access_token_auth(request: Request, call_next):
    """Reject every request that does not carry a valid, enabled access token.

    The token is taken from, in order of precedence: the ``Authorization``
    header, the ``access_token`` query parameter, and the session.  The
    account record is read from the token's ``info.json`` via ``service.get``
    and the request proceeds only if that record is a JSON object whose
    ``enabled`` field is truthy.  Otherwise a 401 response is returned.

    On success the validated token is written into the session, so later
    requests that carry only the session cookie are authenticated as the same
    account.
    """
    access_token = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    if not access_token:
        return Response(status_code=401, content="the token is not valid")

    raw_info = service.get(access_token, "info.json", False)
    # An empty/missing record means an unknown token: answer 401 instead of
    # letting json.loads(None) raise and surface as a 500.
    account_info = json.loads(raw_info) if raw_info else None
    if not isinstance(account_info, dict) or not account_info.get("enabled"):
        return Response(status_code=401, content="the token is not valid")

    # Store the token that was actually validated on this request.
    # ``setdefault`` would keep an older session token even when a different
    # valid token was presented in the header or query string.
    request.session["access_token"] = access_token
    return await call_next(request)
@app.get('/')
@app.get('/healthz')
async def healthz(request: Request):
    """Health probe served at both ``/`` and ``/healthz``.

    Always answers with the JSON object ``{"healthz": true}``.  The HTTP
    middleware registered in this module has no path exemptions, so this
    probe is still subject to its access-token check.
    """
    status = {"healthz": True}
    return JSONResponse(content=status)
@app.post('/token/add')
async def add_token(request: Request):
    """Store the posted JSON body as an ``info.json`` account record (admin only).

    The caller's token is resolved the same way as in the auth middleware
    (``Authorization`` header, then ``access_token`` query parameter, then
    session).  Only callers whose own ``info.json`` is an enabled account with
    ``role == 'admin'`` may proceed; everyone else receives 401.

    Returns:
        ``{"success": true}`` after the record has been upserted.
    """
    access_token = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    raw_info = service.get(access_token, "info.json", False) if access_token else None
    # Tolerate a missing/empty record instead of crashing in json.loads(None).
    account_info = json.loads(raw_info) if raw_info else None
    is_admin = (
        isinstance(account_info, dict)
        and bool(account_info.get("enabled"))
        and account_info.get("role") == 'admin'
    )
    # Only enabled admin accounts may write account records.
    if not is_admin:
        return Response(status_code=401, content="the token is not valid")

    data = await request.json()
    # NOTE(review): this upserts the posted record into the *caller's own*
    # info.json (keyed by the admin's token), not under a new token. Confirm
    # whether the target token should instead be taken from ``data``.
    service.upsert(access_token, "info.json", json.dumps(data, ensure_ascii=False), False)
    return JSONResponse({"success": True})
@app.get('/cachedata/{file_name}')
async def cache_data(request: Request, file_name: str):
    """Return the caller's stored ``file_name`` data, JSON-decoded (admin only).

    The caller's token is resolved from the ``Authorization`` header, then the
    ``access_token`` query parameter, then the session.  The data is read with
    ``service.get(access_token, file_name, False)``; when nothing is stored,
    an empty string is returned as the JSON body.  Callers whose ``info.json``
    is not an enabled ``admin`` account receive 401.
    """
    access_token = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )
    raw_info = service.get(access_token, "info.json", False) if access_token else None
    # Tolerate a missing/empty record instead of crashing in json.loads(None).
    account_info = json.loads(raw_info) if raw_info else None
    is_admin = (
        isinstance(account_info, dict)
        and bool(account_info.get("enabled"))
        and account_info.get("role") == 'admin'
    )
    # Only enabled admin accounts may read this data.
    # NOTE(review): confirm that no non-admin client depends on this endpoint.
    if not is_admin:
        return Response(status_code=401, content="the token is not valid")

    data = service.get(access_token, file_name, False)
    content = json.loads(data) if data else ""
    return JSONResponse(content)
@app.post('/endpoint')
async def web_listening(request: Request):
    """Dispatch a task posted by the front end.

    The JSON body must be an object with a ``task`` field:

    * ``change_checkbox`` — calls ``change_checkbox_state`` for ``data['ckid']``.
    * ``cache`` — starts ``extract_and_cache_document`` in a separate process
      and immediately answers ``'caching'`` without waiting for it.
    * ``pop_url`` — calls ``update_pop_url``.  Despite the name this adds a URL;
      "pop" refers to the pop-up UI.

    Raises:
        HTTPException: 400 if the body is not valid JSON, has no ``task``
            field, or names an unsupported task.
    """
    try:
        data = await request.json()
    except json.JSONDecodeError:
        raise HTTPException(status_code=400, detail="request body is not valid JSON")
    if not isinstance(data, dict) or 'task' not in data:
        raise HTTPException(status_code=400, detail="request body must be a JSON object with a 'task' field")
    msg_type = data['task']

    # Resolve the token with the same precedence as the auth middleware, so
    # the task runs for the account that was validated on this request.
    access_token = (
        request.headers.get("Authorization")
        or request.query_params.get("access_token")
        or request.session.get("access_token")
    )

    if msg_type == 'change_checkbox':
        rsp = change_checkbox_state(data['ckid'], cache_file, access_token)
    elif msg_type == 'cache':
        # Extraction runs in a child process; the response does not wait for it.
        cache_obj = multiprocessing.Process(
            target=extract_and_cache_document,
            args=(data, cache_root, access_token),
        )
        cache_obj.start()
        rsp = 'caching'
    elif msg_type == 'pop_url':
        rsp = update_pop_url(data, cache_file_popup_url, access_token)
    else:
        # An unknown task is a client error, not a server failure (was a 500
        # via NotImplementedError).
        raise HTTPException(status_code=400, detail=f"unsupported task: {msg_type!r}")
    return JSONResponse(content=rsp)
import gradio as gr
from assistant_server import demo as assistant_app
from workstation_server import demo as workstation_app
import logging
import secrets

app = gr.mount_gradio_app(app, assistant_app, path="/assistant")
app = gr.mount_gradio_app(app, workstation_app, path="/workstation")

# Session cookies are signed with SECRET_KEY. Passing None would make the
# signing key the literal string "None", which is guessable. Fall back to a
# random per-process key instead; sessions then do not survive a restart.
if not os.getenv("SECRET_KEY"):
    logging.getLogger(__name__).warning(
        "SECRET_KEY is not set; using a random session key for this process"
    )
# SessionMiddleware is added after the access-token middleware, so it wraps
# it and request.session is available there. max_age is in seconds
# (25200 s = 7 h).
app.add_middleware(
    SessionMiddleware,
    secret_key=os.getenv("SECRET_KEY") or secrets.token_urlsafe(32),
    max_age=25200,
)
if __name__ == '__main__':
    # Pass the already-built ``app`` object rather than the import string
    # 'database_server:app'. The string form makes uvicorn import this file a
    # second time under another module name, re-running all of the setup
    # above, and it breaks if the file is renamed. An app object is accepted
    # because reload is off and a single worker is used.
    uvicorn.run(app, host='0.0.0.0', port=7860, reload=False, workers=1)