from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Annotated
from mistralai import Mistral
from auth import verify_token
import os
from schemas.common import APIResponse, ChatMemoryRequest  # Import common schemas
import asyncio
import time

router = APIRouter(prefix="/mistral", tags=["mistral"])

mistral_key = os.environ.get("MISTRAL_KEY", "")
if not mistral_key:
    raise RuntimeError("MISTRAL_KEY environment variable not set.")

mistral_client = Mistral(api_key=mistral_key)

# Rate-limiting state shared by all endpoints in this module
last_mistral_call_time = 0.0
mistral_call_lock = asyncio.Lock()
MIN_INTERVAL = 1.0  # Minimum 1 second between calls


async def _respect_rate_limit() -> None:
    """Sleep as needed so consecutive Mistral calls are at least MIN_INTERVAL apart."""
    global last_mistral_call_time
    async with mistral_call_lock:
        elapsed = time.monotonic() - last_mistral_call_time
        if elapsed < MIN_INTERVAL:
            await asyncio.sleep(MIN_INTERVAL - elapsed)
        last_mistral_call_time = time.monotonic()


class LLMRequest(BaseModel):
    model: str
    prompt: str


@router.post("/chat-stream")
async def mistral_chat_stream(request: LLMRequest, token: Annotated[str, Depends(verify_token)]):
    async def generate():
        await _respect_rate_limit()
        try:
            response = await mistral_client.chat.stream_async(
                model=request.model,
                messages=[{"role": "user", "content": request.prompt}],
            )
            async for chunk in response:
                # Each streamed event wraps its payload in `chunk.data`;
                # skip events that lack the expected structure.
                if chunk.data and chunk.data.choices:
                    content = chunk.data.choices[0].delta.content
                    if content is not None:
                        yield content
        except Exception as e:
            # Log the error and yield an error message if streaming fails
            print(f"Error during Mistral stream: {e}")
            yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type="text/plain")


@router.post("/chat", response_model=APIResponse)
async def mistral_chat(request: LLMRequest, token: Annotated[str, Depends(verify_token)]):
    await _respect_rate_limit()
    try:
        response = await mistral_client.chat.complete_async(
            model=request.model,
            messages=[{"role": "user", "content": request.prompt}],
        )
        if response.choices and response.choices[0].message:
            content = response.choices[0].message.content
            return APIResponse(success=True, data={"response": content})
        return APIResponse(
            success=False,
            error="No response content received from Mistral.",
            data=response.model_dump(),
        )
    except Exception as e:
        print(f"Error calling Mistral chat completion: {e}")
        raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}")


@router.post("/chat-with-memory", response_model=APIResponse)
async def mistral_chat_with_memory(request: ChatMemoryRequest, token: Annotated[str, Depends(verify_token)]):
    await _respect_rate_limit()
    try:
        # Convert Pydantic message models to plain dicts for the Mistral API call
        messages_dict = [msg.model_dump() for msg in request.messages]
        response = await mistral_client.chat.complete_async(
            model=request.model,
            messages=messages_dict,
        )
        if response.choices and response.choices[0].message:
            content = response.choices[0].message.content
            return APIResponse(success=True, data={"response": content})
        return APIResponse(
            success=False,
            error="No response content received from Mistral.",
            data=response.model_dump(),
        )
    except Exception as e:
        print(f"Error calling Mistral chat completion with memory: {e}")
        raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}")
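

# ---------------------------------------------------------------------------
# Usage sketch: how a client might consume the /mistral/chat-stream endpoint.
# Everything below is illustrative, not part of the router: it assumes the app
# serving this module is reachable at http://localhost:8000, that verify_token
# accepts the placeholder Bearer token shown, that "mistral-small-latest" is a
# valid model for your account, and that httpx is installed. Adjust before use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import httpx

    async def _stream_demo() -> None:
        async with httpx.AsyncClient(base_url="http://localhost:8000") as client:
            # Open the POST as a stream so chunks are printed as they arrive
            async with client.stream(
                "POST",
                "/mistral/chat-stream",
                json={"model": "mistral-small-latest", "prompt": "Say hello."},
                headers={"Authorization": "Bearer <your-token>"},  # placeholder credential
            ) as resp:
                resp.raise_for_status()
                async for text in resp.aiter_text():
                    print(text, end="", flush=True)

    asyncio.run(_stream_demo())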