# app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Optional
import os
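# NOTE: the imports below target the legacy llama-index API (pre-0.10), where
# ServiceContext still exists; 0.10+ releases moved these into llama_index.core
# and replaced ServiceContext with Settings.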
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.llms import OpenAI
import logging
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
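# In production you would typically replace the wildcard with an explicit
# allow-list, e.g. (hypothetical origin):
#   allow_origins=["https://your-frontend.example.com"]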

class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 1024
    parameters: Optional[Dict] = {}


class ChatResponse(BaseModel):
    content: str
    stop: Optional[List[str]] = None
    usage: Dict[str, int]
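
# For reference, a /chat request body matching the models above would look
# like this (field values are illustrative only):
#
#   {
#       "messages": [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "What do the indexed documents say?"}
#       ],
#       "temperature": 0.7,
#       "max_tokens": 1024
#   }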
# Initialize LlamaIndex components
@app.on_event("startup")
async def startup_event():
try:
# Initialize OpenAI client (can be replaced with other LLM providers)
llm = OpenAI(
model="gpt-3.5-turbo",
temperature=0.7,
api_key=os.getenv("OPENAI_API_KEY")
)
# Create service context
service_context = ServiceContext.from_defaults(llm=llm)
# Load documents (adjust path as needed)
if os.path.exists("data"):
documents = SimpleDirectoryReader("data").load_data()
app.state.index = VectorStoreIndex.from_documents(
documents,
service_context=service_context
)
else:
# Create empty index if no documents
app.state.index = VectorStoreIndex([])
app.state.query_engine = app.state.index.as_query_engine()
logger.info("LlamaIndex initialization completed successfully")
except Exception as e:
logger.error(f"Error during startup: {str(e)}")
raise
@app.get("/health")
async def health_check():
    return {"status": "healthy"}
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
try:
# Extract the last user message
last_message = next(
(msg.content for msg in reversed(request.messages) if msg.role == "user"),
None
)
if not last_message:
raise HTTPException(
status_code=400,
detail="No user message found in the conversation"
)
# Get response from LlamaIndex
response = app.state.query_engine.query(
last_message,
similarity_top_k=3, # Adjust as needed
)
# Format response
return ChatResponse(
content=str(response),
stop=None,
usage={
"prompt_tokens": 0, # Add actual token counting if needed
"completion_tokens": 0,
"total_tokens": 0
}
)
except Exception as e:
logger.error(f"Error processing chat request: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)