"""FastAPI server exposing a safety-guarded chat model over HTTP.

Endpoints:
    POST /chat           -- run one request through the SafeModelPipeline.
    POST /update_config  -- override generation parameters on the live pipeline.
    GET  /health         -- report whether startup initialization completed.

The pipeline and its configuration are created once in ``startup_event`` from
the YAML file named by the ``CONFIG_PATH`` environment variable.
"""

import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel

from config import ServerConfig
from guard_chat import SafeModelPipeline

logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(title="Safe Chat Model API")

# Global variables for model and config; both stay None until startup_event()
# has run successfully.
pipeline: Optional[SafeModelPipeline] = None
config: Optional[ServerConfig] = None


class ChatRequest(BaseModel):
    """Request body for POST /chat.

    The whole object is handed to ``pipeline.generate_response``; how the
    individual fields are interpreted is decided there.
    """

    # NOTE(review): presumably chat-style {"role": ..., "content": ...} dicts --
    # confirm against SafeModelPipeline.generate_response.
    messages: List[Dict[str, str]]
    conversation_id: Optional[str] = None  # To track different conversations
    # Per-request generation settings.
    max_new_tokens: Optional[int] = 80
    temperature: Optional[float] = 0.7
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50


class ChatResponse(BaseModel):
    """Response body for POST /chat, copied field by field from the pipeline result."""

    response: str
    input_safety: str
    output_safety: str
    filtered: bool
    conversation_id: str


def get_config():
    """Dependency to get config.

    Returns:
        The module-level ``config`` object.

    Raises:
        HTTPException: 500 if the server has not been initialized yet.
    """
    if config is None:
        raise HTTPException(status_code=500, detail="Server not initialized")
    return config


@app.post("/update_config")
async def update_config(
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    config: ServerConfig = Depends(get_config),
):
    """Update generation parameters on the live pipeline.

    Only parameters that are not ``None`` are written to ``pipeline.args``;
    the rest keep their current values. ``config`` is not read here: the
    dependency is kept so the request fails early when the server is not
    initialized.

    Returns:
        A dict with a confirmation message.

    Raises:
        HTTPException: 500 if the pipeline is not initialized or an
            attribute cannot be set.
    """
    # Raised outside the try so the error reaches the client unchanged
    # instead of being re-wrapped by the generic handler below.
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    overrides = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    try:
        for name, value in overrides.items():
            if value is not None:
                setattr(pipeline.args, name, value)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Configuration updated successfully"}


@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        ``{"status": "healthy"}`` once both pipeline and config are set.

    Raises:
        HTTPException: 503 while either global is still ``None``.
    """
    if pipeline is None or config is None:
        raise HTTPException(status_code=503, detail="Server not fully initialized")
    return {"status": "healthy"}


@app.on_event("startup")
async def startup_event():
    """Initialize the server on startup.

    Loads ``ServerConfig`` from the YAML file named by ``CONFIG_PATH``
    (default ``config.yaml``) and builds the global ``SafeModelPipeline``.
    Any failure is logged with its traceback and re-raised.
    """
    global pipeline, config
    try:
        # Load config from YAML
        config_path = os.getenv("CONFIG_PATH", "config.yaml")
        config = ServerConfig.from_yaml(config_path)

        # Initialize pipeline with config
        pipeline = SafeModelPipeline(
            model_args=config.to_chat_arguments(),
            system_prompt=config.system_prompt,
            max_history_tokens=config.max_history_tokens,
        )
    except Exception as e:
        logger.exception("Error initializing server: %s", e)
        raise


@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint that handles conversation and calls the safe model pipeline.

    Returns:
        A ``ChatResponse`` built from the pipeline result.

    Raises:
        HTTPException: 400 if ``request.messages`` is empty; 500 if the
            pipeline is not initialized, raises, returns a falsy result, or
            the result cannot be converted into a ``ChatResponse``.
    """
    logger.info("Received request: %s", request)

    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    try:
        # Validate request
        if not request.messages:
            raise HTTPException(status_code=400, detail="No messages provided")

        logger.info("Processing messages: %s", request.messages)

        # Call the safe model pipeline's generate_response method
        try:
            response = await pipeline.generate_response(request)
            logger.info("Generated response: %s", response)
        except Exception as e:
            logger.error("Pipeline error: %s", e, exc_info=True)
            raise HTTPException(
                status_code=500, detail=f"Pipeline error: {str(e)}"
            ) from e

        # If response is None or invalid, raise an error
        if not response:
            raise HTTPException(status_code=500, detail="Failed to generate response")

        return ChatResponse(
            response=response.response,
            conversation_id=response.conversation_id,
            input_safety=response.input_safety,
            output_safety=response.output_safety,
            filtered=response.filtered,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) must keep their status
        # code and detail; without this clause the handler below would turn
        # every one of them into a generic 500.
        raise
    except Exception as e:
        logger.error("Error in chat endpoint: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


def main() -> None:
    """Parse command-line options and run the server under uvicorn."""
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to run the server on")
    parser.add_argument("--port", type=int, default=8000, help="Port to run the server on")
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
    args = parser.parse_args()

    # Set config path environment variable; startup_event() reads it.
    os.environ["CONFIG_PATH"] = args.config

    # Run the server
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()