Spaces:
Running
Running
import os | |
import uvicorn | |
from fastapi import FastAPI, HTTPException, Depends | |
from typing import List, Tuple, Optional, Union, Dict, Any | |
import torch | |
from config import ServerConfig | |
from guard_chat import SafeModelPipeline | |
from pydantic import BaseModel | |
from dataclasses import dataclass, field | |
import logging | |
# Initialize FastAPI app
app = FastAPI(title="Safe Chat Model API")
# Global variables for model and config.
# Both start as None and are populated by startup_event(); the handlers
# below check them for None and answer with an HTTP error until then.
pipeline: Optional[SafeModelPipeline] = None
config: Optional[ServerConfig] = None
class ChatRequest(BaseModel):
    """Request body for one chat turn: message history plus generation settings."""

    # Conversation so far; each entry maps str -> str (presumably
    # {"role": ..., "content": ...} entries — confirm against SafeModelPipeline).
    messages: List[Dict[str, str]]
    # Client-chosen identifier used to tell different conversations apart.
    conversation_id: Union[str, None] = None
    # Per-request generation settings. This model only carries them; the whole
    # request is handed to the pipeline, which decides whether to honour them.
    max_new_tokens: Union[int, None] = 80
    temperature: Union[float, None] = 0.7
    do_sample: Union[bool, None] = True
    top_p: Union[float, None] = 0.9
    top_k: Union[int, None] = 50
class ChatResponse(BaseModel):
    """Response body for one chat turn; every field is copied from the pipeline result."""

    # Reply text produced by the pipeline.
    response: str
    # Presumably the safety classification of the user input — confirm the
    # label format with SafeModelPipeline.
    input_safety: str
    # Presumably the safety classification of the model output — confirm the
    # label format with SafeModelPipeline.
    output_safety: str
    # Flag reported by the pipeline (presumably True when the exchange was
    # filtered — confirm).
    filtered: bool
    # Conversation identifier reported by the pipeline.
    conversation_id: str
def get_config():
    """FastAPI dependency that hands out the loaded ServerConfig.

    Raises:
        HTTPException: 500 "Server not initialized" while the module-level
            ``config`` is still None.
    """
    loaded = config
    if loaded is not None:
        return loaded
    raise HTTPException(status_code=500, detail="Server not initialized")
async def update_config(
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    config: ServerConfig = Depends(get_config),
):
    """Update generation parameters on the loaded pipeline.

    Each argument that is not None is written onto ``pipeline.args``. The
    server ``config`` object itself is not modified, and values are not
    range-checked here.

    NOTE(review): no ``@app`` route decorator is visible on this function in
    this file; unless it is registered elsewhere, FastAPI will not serve it.

    Args:
        temperature: New value for ``pipeline.args.temperature``, or None to
            leave it unchanged.
        top_k: New value for ``pipeline.args.top_k``, or None to leave it
            unchanged.
        top_p: New value for ``pipeline.args.top_p``, or None to leave it
            unchanged.
        max_new_tokens: New value for ``pipeline.args.max_new_tokens``, or
            None to leave it unchanged.
        config: Injected by ``get_config`` and not read here. Declaring it
            makes the request fail with 500 "Server not initialized" until
            startup has loaded the config.

    Returns:
        ``{"message": "Configuration updated successfully"}``.

    Raises:
        HTTPException: 500 "Model not initialized" when no pipeline is loaded;
            500 with the error text if assigning an attribute fails.
    """
    # Raised outside the try below so the detail is not re-wrapped into
    # "500: Model not initialized" by the generic handler.
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    # Insertion order matches the original assignment order.
    updates = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    try:
        for name, value in updates.items():
            if value is not None:
                setattr(pipeline.args, name, value)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"message": "Configuration updated successfully"}
async def health_check():
    """Report whether both the config and the model pipeline are loaded.

    NOTE(review): no ``@app`` route decorator is visible on this function in
    this file; confirm it is registered elsewhere.

    Returns:
        ``{"status": "healthy"}`` once both module-level globals are set.

    Raises:
        HTTPException: 503 "Server not fully initialized" while either the
            pipeline or the config is still None.
    """
    ready = pipeline is not None and config is not None
    if not ready:
        raise HTTPException(status_code=503, detail="Server not fully initialized")
    return {"status": "healthy"}
async def startup_event():
    """Load the server config and build the safe model pipeline.

    The YAML config path comes from the ``CONFIG_PATH`` environment variable
    (default ``"config.yaml"``). On success the results are published as the
    module-level ``config`` and ``pipeline`` globals.

    NOTE(review): no ``@app.on_event("startup")`` (or lifespan) registration
    is visible in this file. If it is not wired up elsewhere, the globals stay
    None and every endpoint reports "not initialized" — confirm.

    Raises:
        Exception: Anything raised by ``ServerConfig.from_yaml`` or
            ``SafeModelPipeline`` is logged with its traceback and re-raised,
            so startup fails loudly.
    """
    global pipeline, config
    config_path = os.getenv("CONFIG_PATH", "config.yaml")
    try:
        new_config = ServerConfig.from_yaml(config_path)
        new_pipeline = SafeModelPipeline(
            model_args=new_config.to_chat_arguments(),
            system_prompt=new_config.system_prompt,
            max_history_tokens=new_config.max_history_tokens,
        )
    except Exception as e:
        # logging (not print) keeps errors in the same stream as the rest of
        # the server and records the traceback.
        logging.exception("Error initializing server: %s", e)
        raise
    # Publish both only after everything succeeded, so a failed pipeline
    # build never leaves a config without a matching pipeline.
    config = new_config
    pipeline = new_pipeline
async def chat_endpoint(request: ChatRequest):
    """Run one chat turn through the safe model pipeline.

    The whole ``request`` is forwarded to ``pipeline.generate_response``. The
    fields of its result are copied into a ``ChatResponse``.

    NOTE(review): no ``@app.post(...)`` route decorator is visible on this
    function in this file. Unless it is registered elsewhere, FastAPI never
    calls it — confirm.

    Args:
        request: Message history plus optional generation settings.

    Returns:
        ChatResponse with the reply text, conversation id, input/output
        safety values and ``filtered`` flag reported by the pipeline.

    Raises:
        HTTPException:
            - 500 if the pipeline is not initialized.
            - 400 if ``request.messages`` is empty.
            - 500 if the pipeline raises or returns a falsy result.
            - 500 if its result cannot be turned into a ``ChatResponse``.
    """
    # NOTE(review): these INFO logs include full user messages and model
    # output; consider DEBUG level if conversations are sensitive.
    logging.info("Received request: %s", request)
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")
    # Validation errors are raised outside any generic handler so clients
    # actually receive the 400 (previously it was re-wrapped into a 500).
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")
    logging.info("Processing messages: %s", request.messages)

    try:
        response = await pipeline.generate_response(request)
    except Exception as e:
        logging.error("Pipeline error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"Pipeline error: {e}") from e
    logging.info("Generated response: %s", response)

    # A None/falsy result means the pipeline produced nothing usable.
    if not response:
        raise HTTPException(status_code=500, detail="Failed to generate response")

    # Only failures while building the response model (missing attributes,
    # validation errors) are reported through this generic 500 handler.
    try:
        return ChatResponse(
            response=response.response,
            conversation_id=response.conversation_id,
            input_safety=response.input_safety,
            output_safety=response.output_safety,
            filtered=response.filtered,
        )
    except Exception as e:
        logging.error("Error in chat endpoint: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    # 0.0.0.0 binds every interface so the server is reachable from outside
    # a container.
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="Host to run the server on")
    parser.add_argument("--port", type=int, default=8000,
                        help="Port to run the server on")
    # Default to an already-exported CONFIG_PATH, so running without --config
    # no longer silently overrides it with "config.yaml".
    parser.add_argument("--config", type=str,
                        default=os.getenv("CONFIG_PATH", "config.yaml"),
                        help="Path to config file")
    args = parser.parse_args()

    # startup_event() reads the config location from this environment variable.
    os.environ["CONFIG_PATH"] = args.config

    # Run the server
    uvicorn.run(app, host=args.host, port=args.port)