# test_vllm / main.py
# Uploaded by qhuang0805123 ("Upload main.py", commit c587c68, verified)
import os
import uvicorn
from fastapi import FastAPI, HTTPException, Depends
from typing import List, Tuple, Optional, Union, Dict, Any
import torch
from config import ServerConfig
from guard_chat import SafeModelPipeline
from pydantic import BaseModel
from dataclasses import dataclass, field
import logging
# Module-level server state. Both slots start out empty and are assigned by
# the startup hook once the config file and model pipeline have been loaded.
config: Optional[ServerConfig] = None
pipeline: Optional[SafeModelPipeline] = None

# The FastAPI application object that the routes in this module attach to.
app = FastAPI(title="Safe Chat Model API")
class ChatRequest(BaseModel):
    # Request body accepted by the /chat endpoint. Documented with comments
    # instead of a docstring so the generated schema description is unchanged.

    # Conversation turns; each turn is a flat string-to-string mapping.
    messages: List[Dict[str, str]]
    # Optional identifier used to tell different conversations apart.
    conversation_id: Optional[str] = None

    # Generation settings carried on the request, with their defaults.
    max_new_tokens: Optional[int] = 80
    temperature: Optional[float] = 0.7
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50
class ChatResponse(BaseModel):
    # Response body returned by the /chat endpoint. Every field is copied
    # from the object the pipeline's generate_response call produces.

    # Text returned to the caller.
    response: str

    # Safety fields reported for the input and for the output.
    input_safety: str
    output_safety: str
    # Boolean flag copied from the pipeline result's `filtered` attribute.
    filtered: bool

    # Identifier of the conversation this reply belongs to.
    conversation_id: str
def get_config():
    """Return the loaded server config; intended for use with ``Depends``.

    Raises:
        HTTPException: status 500 while the module-level ``config`` is
            still ``None``.
    """
    if config is not None:
        return config
    raise HTTPException(status_code=500, detail="Server not initialized")
@app.post("/update_config")
async def update_config(
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    config: ServerConfig = Depends(get_config),
):
    """Update generation parameters"""
    # Each parameter left as None is skipped, so callers can change any
    # subset of the generation settings stored on `pipeline.args`.
    #
    # `config` is not read in the body; declaring the dependency makes the
    # request fail through `get_config` when the server config is missing.
    #
    # Raises HTTPException(500) when the pipeline has not been created, or
    # when assigning one of the settings fails.

    # Checked outside the try block so this HTTPException reaches the client
    # as written instead of being caught and re-wrapped by the handler below.
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    # Attribute name on `pipeline.args` -> requested value, in the same
    # order the settings were applied before.
    requested = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    try:
        for attr_name, value in requested.items():
            if value is not None:
                setattr(pipeline.args, attr_name, value)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Configuration updated successfully"}
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    # The server counts as healthy only once both module-level slots have
    # been populated; otherwise the caller gets a 503.
    ready = pipeline is not None and config is not None
    if not ready:
        raise HTTPException(status_code=503, detail="Server not fully initialized")
    return {"status": "healthy"}
@app.on_event("startup")
async def startup_event():
    """Load the server config and build the model pipeline at startup.

    The config path comes from the ``CONFIG_PATH`` environment variable and
    falls back to ``config.yaml``. Any failure is printed and re-raised.
    """
    global pipeline, config
    try:
        # Resolve and load the YAML config first; the pipeline is built from it.
        yaml_path = os.environ.get("CONFIG_PATH", "config.yaml")
        config = ServerConfig.from_yaml(yaml_path)

        # Collect the constructor arguments, then build the pipeline.
        pipeline_kwargs = {
            "model_args": config.to_chat_arguments(),
            "system_prompt": config.system_prompt,
            "max_history_tokens": config.max_history_tokens,
        }
        pipeline = SafeModelPipeline(**pipeline_kwargs)
    except Exception as err:
        print(f"Error initializing server: {err!s}")
        raise
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint that handles conversation and calls the safe model pipeline"""
    # Flow: validate the request, await `pipeline.generate_response(request)`,
    # then copy the result's fields into a ChatResponse.
    #
    # Raises HTTPException:
    #   500 - pipeline not created, pipeline call failed, pipeline returned a
    #         falsy result, or the result could not be turned into a response.
    #   400 - the request carries no messages.
    #
    # Each try block wraps only the call that can fail, so the deliberate
    # HTTPExceptions below reach the client with their own status and detail
    # rather than being caught by a blanket handler and re-raised as a 500.
    logging.info("Received request: %s", request)

    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    # Validate request
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")
    logging.info("Processing messages: %s", request.messages)

    # Call the safe model pipeline's generate_response method
    try:
        response = await pipeline.generate_response(request)
    except Exception as e:
        logging.error("Pipeline error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"Pipeline error: {e}") from e
    logging.info("Generated response: %s", response)

    # If response is None or otherwise falsy, report a failure
    if not response:
        raise HTTPException(status_code=500, detail="Failed to generate response")

    # Reading the result's attributes or validating the response model can
    # still fail; report that as a server error, as before.
    try:
        return ChatResponse(
            response=response.response,
            conversation_id=response.conversation_id,
            input_safety=response.input_safety,
            output_safety=response.output_safety,
            filtered=response.filtered,
        )
    except Exception as e:
        logging.error("Error in chat endpoint: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import argparse

    # Command-line options used when this module is run directly:
    # (flag, type, default, help text).
    cli_options = (
        ("--host", str, "0.0.0.0", "Host to run the server on"),
        ("--port", int, 8000, "Port to run the server on"),
        ("--config", str, "config.yaml", "Path to config file"),
    )
    arg_parser = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in cli_options:
        arg_parser.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )
    args = arg_parser.parse_args()

    # The startup hook reads CONFIG_PATH, so export the chosen path before
    # the server starts.
    os.environ["CONFIG_PATH"] = args.config

    # Run the server
    uvicorn.run(app, host=args.host, port=args.port)