Spaces: Runtime error

# app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Optional
import os
from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext
from llama_index.llms import OpenAI
import logging

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()
# Configure CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Adjust this in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ChatMessage(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    messages: List[ChatMessage]
    temperature: Optional[float] = 0.7
    max_tokens: Optional[int] = 1024
    parameters: Optional[Dict] = {}


class ChatResponse(BaseModel):
    content: str
    stop: Optional[List[str]] = None
    usage: Dict[str, int]
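
# Illustrative request/response bodies for the models above (example values,
# not taken from the original app):
#
#   ChatRequest:  {"messages": [{"role": "user", "content": "Hello"}],
#                  "temperature": 0.7, "max_tokens": 1024, "parameters": {}}
#   ChatResponse: {"content": "...", "stop": null,
#                  "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}}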

# Initialize LlamaIndex components on application startup
# (the decorator is required so FastAPI actually runs this handler)
@app.on_event("startup")
async def startup_event():
    try:
        # Initialize OpenAI client (can be replaced with other LLM providers)
        llm = OpenAI(
            model="gpt-3.5-turbo",
            temperature=0.7,
            api_key=os.getenv("OPENAI_API_KEY"),
        )

        # Create service context
        service_context = ServiceContext.from_defaults(llm=llm)

        # Load documents (adjust path as needed)
        if os.path.exists("data"):
            documents = SimpleDirectoryReader("data").load_data()
            app.state.index = VectorStoreIndex.from_documents(
                documents,
                service_context=service_context,
            )
        else:
            # Create empty index if no documents
            app.state.index = VectorStoreIndex([], service_context=service_context)

        # similarity_top_k is a query-engine option, so it is configured here
        # rather than being passed to query() on every request
        app.state.query_engine = app.state.index.as_query_engine(similarity_top_k=3)
        logger.info("LlamaIndex initialization completed successfully")
    except Exception as e:
        logger.error(f"Error during startup: {str(e)}")
        raise

# Route decorator restored; the /health path is assumed from the handler name
@app.get("/health")
async def health_check():
    return {"status": "healthy"}

# Route decorator restored; the /chat path is assumed from the handler name
@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    try:
        # Extract the last user message
        last_message = next(
            (msg.content for msg in reversed(request.messages) if msg.role == "user"),
            None,
        )
        if not last_message:
            raise HTTPException(
                status_code=400,
                detail="No user message found in the conversation",
            )

        # Get response from LlamaIndex (similarity_top_k is set on the query
        # engine at startup, so query() only receives the question text)
        response = app.state.query_engine.query(last_message)

        # Format response
        return ChatResponse(
            content=str(response),
            stop=None,
            usage={
                "prompt_tokens": 0,  # Add actual token counting if needed
                "completion_tokens": 0,
                "total_tokens": 0,
            },
        )
    except HTTPException:
        # Re-raise HTTP errors (e.g. the 400 above) instead of turning them into 500s
        raise
    except Exception as e:
        logger.error(f"Error processing chat request: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
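
# A minimal sketch of how the chat route could be exercised from a separate
# client process once the server is running; it assumes the /chat path restored
# above, the default port 7860, and that the `requests` package is installed:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/chat",
#       json={"messages": [{"role": "user", "content": "What do the documents say?"}]},
#   )
#   print(resp.json()["content"])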