# test_api/fastapi_app.py
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn
import json

from typegpt_api import generate
from api_info import developer_info, model_providers

app = FastAPI()
# Allow cross-origin requests from any domain; tighten these settings in
# production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/health_check")
async def health_check():
return {"status": "OK"}
@app.get("/models")
async def get_models():
try:
response = {
"object": "list",
"data": []
}
for provider, info in model_providers.items():
for model in info["models"]:
response["data"].append({
"id": model,
"object": "model",
"provider": provider,
"description": info["description"]
})
return JSONResponse(content=response)
except Exception as e:
return JSONResponse(content={"error": str(e)}, status_code=500)

@app.post("/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completions endpoint."""
    # Parse the JSON payload.
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)

    # Extract parameters, falling back to OpenAI-style defaults.
    model = body.get("model")
    messages = body.get("messages")
    temperature = body.get("temperature", 0.7)
    top_p = body.get("top_p", 1.0)
    n = body.get("n", 1)
    stream = body.get("stream", False)
    stop = body.get("stop")
    max_tokens = body.get("max_tokens")
    presence_penalty = body.get("presence_penalty", 0.0)
    frequency_penalty = body.get("frequency_penalty", 0.0)
    logit_bias = body.get("logit_bias")
    user = body.get("user")
    timeout = 30  # seconds; adjust as needed

    # Validate required parameters.
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)

    # Call the generate function and relay its output.
    try:
        if stream:
            # Re-emit chunks from generate() as Server-Sent Events.
            # NOTE: generate() appears to be synchronous, so each chunk is
            # produced while blocking the event loop.
            async def generate_stream():
                response = generate(
                    model=model,
                    messages=messages,
                    temperature=temperature,
                    top_p=top_p,
                    n=n,
                    stream=True,
                    stop=stop,
                    max_tokens=max_tokens,
                    presence_penalty=presence_penalty,
                    frequency_penalty=frequency_penalty,
                    logit_bias=logit_bias,
                    user=user,
                    timeout=timeout,
                )
                for chunk in response:
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    # Transfer-Encoding is set by the ASGI server itself;
                    # setting it manually here can produce duplicate headers.
                },
            )
        else:
            response = generate(
                model=model,
                messages=messages,
                temperature=temperature,
                top_p=top_p,
                n=n,
                stream=False,
                stop=stop,
                max_tokens=max_tokens,
                presence_penalty=presence_penalty,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                user=user,
                timeout=timeout,
            )
            return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
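
# A usage sketch for /chat/completions with streaming enabled. Assumes a local
# server on port 8000 and the `httpx` package; "example-model" is a
# placeholder, not a model guaranteed to exist in model_providers.
#
#   import httpx, json
#   payload = {
#       "model": "example-model",
#       "messages": [{"role": "user", "content": "Hello"}],
#       "stream": True,
#   }
#   with httpx.stream("POST", "http://localhost:8000/chat/completions",
#                     json=payload, timeout=60) as r:
#       for line in r.iter_lines():
#           if line.startswith("data: ") and line != "data: [DONE]":
#               print(json.loads(line[len("data: "):]))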
@app.get("/developer_info")
async def get_developer_info():
return JSONResponse(content=developer_info)

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
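
# Alternatively, serve the app with the uvicorn CLI (assuming this file is
# saved as fastapi_app.py):
#
#   uvicorn fastapi_app:app --host 0.0.0.0 --port 8000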