f5_model_final / app /controllers /plan_chat_controller.py
EL GHAFRAOUI AYOUB
C'
6f14d8b
raw
history blame
2.35 kB
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse, PlainTextResponse
from pydantic import BaseModel, Field
from typing import List, Literal
from app.helpers.plan_chat import ask_plan_question
from app.helpers.token_auth import get_token
from app.helpers.get_current_uesr import get_user_from_token
from app.models.project import Project
from app.helpers.vectorization import search_similar
# Router on which this module's POST /plan-chat endpoint is registered.
router = APIRouter()
class HistoryItem(BaseModel):
    """One earlier message of the conversation, sent along with a plan-chat request."""

    # Text of the message.
    message: str
    # Origin of the message; only the literal values "user" and "ai" validate.
    # NOTE(review): the trailing underscore sidesteps the `from` keyword, and no
    # alias is declared, so the request key is `from_` -- confirm that this is
    # what clients actually send.
    from_: Literal["user", "ai"]
class PlanChatPayload(BaseModel):
    """Request body accepted by the POST /plan-chat endpoint."""

    # The user's question; used both for the similarity search and as the
    # question handed to ask_plan_question.
    query: str
    # Earlier messages of the conversation, forwarded unchanged to
    # ask_plan_question.
    history: List[HistoryItem]
    # Identifier used to look up the Project record.
    project_id: str
@router.post("/plan-chat")
async def plan_chat(data: PlanChatPayload, token: str = Depends(get_token)):
    """
    Handle chat messages for plan generation with context from scraped content.

    Resolves the caller from ``token``, loads the project named by
    ``data.project_id``, runs a similarity search for ``data.query`` and
    streams the chunks produced by ``ask_plan_question`` back as plain text.

    Args:
        data: Request body carrying the question, the chat history and the
            project id.
        token: Token supplied by the ``get_token`` dependency.

    Returns:
        A ``StreamingResponse`` (``text/plain``) on success, or a
        ``PlainTextResponse`` with status 500 carrying the error text when an
        unexpected exception occurs while preparing the stream.

    Raises:
        HTTPException: 401 when no user is resolved from the token, 404 when
            the project does not exist.
    """
    try:
        # Validate user
        user = await get_user_from_token(token=token)
        if not user:
            raise HTTPException(status_code=401, detail="Invalid token")

        # Get project context
        # NOTE(review): nothing here checks that `user` may access this
        # project -- confirm whether an ownership check is required.
        project = await Project.get_or_none(id=data.project_id)
        if not project:
            raise HTTPException(status_code=404, detail="Project not found")

        # Get relevant context from vectorstore
        context = await search_similar(data.query)

        # Prepare system prompt
        system_prompt = (
            "You are a solution architect assistant specialized in cloud architecture. "
            "Use the following context to help answer questions about the project plan. "
            "Focus on providing specific, actionable advice based on the project requirements "
            "and scraped documentation.\n\n"
            f"Project Context: {context}\n"
            f"Project Requirements: {project.requirements}\n"
            f"Project Features: {project.features}\n"
            f"Solution Stack: {project.solution_stack}\n"
        )

        async def response_stream():
            # The generator body only runs once the response is being sent,
            # i.e. after plan_chat has returned, so errors raised while
            # streaming are not handled by the try/except below.
            async for chunk in ask_plan_question(
                question=data.query,
                history=data.history,
                project_context=system_prompt,
            ):
                yield chunk

        return StreamingResponse(response_stream(), media_type="text/plain")
    except HTTPException:
        # Deliberate HTTP errors (401/404 above) must reach FastAPI unchanged;
        # without this clause the generic handler turned them into 500s.
        raise
    except Exception as e:
        return PlainTextResponse(str(e), status_code=500)