File size: 5,303 Bytes
c587c68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import os
import uvicorn
from fastapi import FastAPI, HTTPException, Depends
from typing import List, Tuple, Optional, Union, Dict, Any
import torch

from config import ServerConfig
from guard_chat import SafeModelPipeline
from pydantic import BaseModel
from dataclasses import dataclass, field
import logging

# Initialize FastAPI app
app = FastAPI(title="Safe Chat Model API")


# Global variables for model and config.
# Both stay None until startup_event() assigns them; the endpoints below check
# for None and respond with an HTTP error until initialization has completed.
pipeline: Optional[SafeModelPipeline] = None
config: Optional[ServerConfig] = None

class ChatRequest(BaseModel):
    """Request body for ``POST /chat``.

    The whole model is handed to ``SafeModelPipeline.generate_response``; how
    each field is interpreted is decided by the pipeline, not by this module.
    """
    # Chat history. Presumably a list of {"role": ..., "content": ...} dicts —
    # TODO confirm the expected keys against SafeModelPipeline.
    messages: List[Dict[str, str]]
    conversation_id: Optional[str] = None  # To track different conversations
    # Generation overrides. NOTE(review): because these are Optional, a client
    # may send an explicit null, which reaches the pipeline as None rather than
    # the default shown here.
    max_new_tokens: Optional[int] = 80
    temperature: Optional[float] = 0.7
    do_sample: Optional[bool] = True
    top_p: Optional[float] = 0.9
    top_k: Optional[int] = 50

class ChatResponse(BaseModel):
    """Response body for ``POST /chat``.

    Populated field-by-field from the attributes of the object returned by
    ``pipeline.generate_response``.
    """
    response: str  # text returned to the client
    # Presumably safety verdicts from the guard model for the user input and the
    # generated output — TODO confirm the value set in SafeModelPipeline.
    input_safety: str
    output_safety: str
    # Presumably True when the guard blocked or replaced content — TODO confirm.
    filtered: bool
    conversation_id: str  # identifier echoed back from the pipeline result





def get_config():
    """FastAPI dependency that hands out the module-level ``ServerConfig``.

    Returns:
        The configuration loaded by ``startup_event``.

    Raises:
        HTTPException: status 500 while the configuration has not been loaded.
    """
    loaded = config
    if loaded is not None:
        return loaded
    raise HTTPException(status_code=500, detail="Server not initialized")



@app.post("/update_config")
async def update_config(
    temperature: Optional[float] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    max_new_tokens: Optional[int] = None,
    config: ServerConfig = Depends(get_config)
):
    """Update generation parameters on the live pipeline.

    Every query parameter that is supplied (not None) overwrites the attribute
    of the same name on ``pipeline.args``; omitted parameters are left as-is.
    The ``config`` dependency is not read here — resolving it makes the endpoint
    fail with 500 until the server configuration has been loaded.

    Returns:
        ``{"message": "Configuration updated successfully"}``.

    Raises:
        HTTPException: 500 if the pipeline is not initialized, or if assigning a
            value on ``pipeline.args`` raises.
    """
    # Checked outside the try block so this HTTPException reaches the client
    # unchanged instead of being caught and re-wrapped by the handler below.
    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    # NOTE(review): values are applied without range checks (e.g. negative
    # temperature or max_new_tokens); invalid values would only surface when
    # the pipeline next generates.
    updates = {
        "temperature": temperature,
        "top_k": top_k,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
    }
    try:
        for name, value in updates.items():
            if value is not None:
                setattr(pipeline.args, name, value)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"message": "Configuration updated successfully"}

@app.get("/health")
async def health_check():
    """Readiness probe.

    Returns ``{"status": "healthy"}`` once both the pipeline and the config
    have been loaded; otherwise raises HTTPException with status 503.
    """
    ready = pipeline is not None and config is not None
    if not ready:
        raise HTTPException(status_code=503, detail="Server not fully initialized")
    return {"status": "healthy"}


@app.on_event("startup")
async def startup_event():
    """Initialize the server on startup.

    Loads ``ServerConfig`` from the YAML file named by the ``CONFIG_PATH``
    environment variable (default ``config.yaml``) and builds the
    ``SafeModelPipeline`` from it. Both objects are built into locals and
    published to the module globals only after everything succeeded, so a
    failure never leaves ``config`` set while ``pipeline`` is still None.

    Raises:
        Exception: any error from loading the config or constructing the
            pipeline is logged with its traceback and re-raised.
    """
    global pipeline, config

    # Load config from YAML
    config_path = os.getenv("CONFIG_PATH", "config.yaml")
    try:
        loaded_config = ServerConfig.from_yaml(config_path)

        # Initialize pipeline with config
        loaded_pipeline = SafeModelPipeline(
            model_args=loaded_config.to_chat_arguments(),
            system_prompt=loaded_config.system_prompt,
            max_history_tokens=loaded_config.max_history_tokens
        )
    except Exception as e:
        # logging.exception records the full traceback, which print() dropped,
        # and routes the error through the same logging setup as the endpoints.
        logging.exception("Error initializing server: %s", e)
        raise

    config = loaded_config
    pipeline = loaded_pipeline

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """Chat endpoint that handles conversation and calls the safe model pipeline.

    Passes the full ``ChatRequest`` to ``pipeline.generate_response`` and maps
    the attributes of its result onto a ``ChatResponse``.

    Raises:
        HTTPException: 400 if ``request.messages`` is empty; 500 if the pipeline
            is not initialized, the pipeline call raises, it returns a falsy
            result, or its result cannot be mapped onto a ``ChatResponse``.
    """
    # NOTE(review): these INFO logs include full user message content.
    logging.info("Received request: %s", request)

    if pipeline is None:
        raise HTTPException(status_code=500, detail="Model not initialized")

    # Validated outside any broad ``except`` so the client receives this 400
    # as-is instead of having it re-wrapped into a 500.
    if not request.messages:
        raise HTTPException(status_code=400, detail="No messages provided")

    logging.info("Processing messages: %s", request.messages)

    # Call the safe model pipeline's generate_response method
    try:
        response = await pipeline.generate_response(request)
    except Exception as e:
        logging.error("Pipeline error: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=f"Pipeline error: {str(e)}") from e
    logging.info("Generated response: %s", response)

    # If response is None or invalid, raise an error
    if not response:
        raise HTTPException(status_code=500, detail="Failed to generate response")

    try:
        return ChatResponse(
            response=response.response,
            conversation_id=response.conversation_id,
            input_safety=response.input_safety,
            output_safety=response.output_safety,
            filtered=response.filtered
        )
    except Exception as e:
        # e.g. a missing attribute on the pipeline result, or a value that fails
        # ChatResponse validation.
        logging.error("Error in chat endpoint: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="Host to run the server on")
    parser.add_argument("--port", type=int, default=8000,
                        help="Port to run the server on")
    # Default to an already-exported CONFIG_PATH so running without --config
    # does not silently override it with "config.yaml".
    parser.add_argument("--config", type=str,
                        default=os.getenv("CONFIG_PATH", "config.yaml"),
                        help="Path to config file (defaults to $CONFIG_PATH, "
                             "then config.yaml)")

    args = parser.parse_args()

    # startup_event() reads the config path from this environment variable.
    os.environ["CONFIG_PATH"] = args.config

    # Run the server
    uvicorn.run(app, host=args.host, port=args.port)