# tectopia/routers/mistral.py
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Annotated, List
from mistralai import Mistral
from auth import verify_token
import os
from schemas.common import APIResponse, ChatMessage, ChatMemoryRequest # Import common schemas
import asyncio
import time
router = APIRouter(prefix="/mistral", tags=["mistral"])
mistral_key = os.environ.get('MISTRAL_KEY', '')
if not mistral_key:
    raise RuntimeError("MISTRAL_KEY environment variable not set.")
mistral_client = Mistral(api_key=mistral_key)
# Rate limiting variables
last_mistral_call_time = 0
mistral_call_lock = asyncio.Lock()
MIN_INTERVAL = 1.0 # Minimum 1 second between calls
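# The lock above serializes the rate-limit gate for every handler in this module,
# so Mistral call starts are spaced by at least MIN_INTERVAL seconds within a
# single worker process (it does not coordinate across processes).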

class LLMRequest(BaseModel):
    model: str
    prompt: str
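
# For reference, a body accepted by both /chat and /chat-stream; the model id is
# only an example, any model available to the configured Mistral account works:
#   {"model": "mistral-small-latest", "prompt": "Summarize FastAPI in one sentence."}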
@router.post("/chat-stream")
async def mistral_chat_stream(request: LLMRequest, token: Annotated[str, Depends(verify_token)]):
async def generate():
global last_mistral_call_time
async with mistral_call_lock:
current_time = time.monotonic()
elapsed = current_time - last_mistral_call_time
if elapsed < MIN_INTERVAL:
await asyncio.sleep(MIN_INTERVAL - elapsed)
last_mistral_call_time = time.monotonic()
try:
response = await mistral_client.chat.stream_async(
model=request.model,
messages=[
{
"role": "user",
"content": request.prompt,
}
],
)
async for chunk in response:
# Ensure chunk.data exists and has the expected structure
if hasattr(chunk, 'choices') and chunk.choices:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
# Handle potential variations if needed, e.g., logging unexpected structures
# else:
# print(f"Unexpected chunk structure: {chunk}")
except Exception as e:
# Log the error and yield an error message if streaming fails
print(f"Error during Mistral stream: {e}")
yield f"Error: {str(e)}"
return StreamingResponse(generate(), media_type="text/plain")
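
# A minimal client-side sketch for consuming the stream, assuming the app runs on
# localhost:8000 and verify_token expects a bearer token (httpx is not a dependency
# of this module; it is used here only for illustration):
#
#   async with httpx.AsyncClient() as client:
#       async with client.stream(
#           "POST",
#           "http://localhost:8000/mistral/chat-stream",
#           json={"model": "mistral-small-latest", "prompt": "Hello"},
#           headers={"Authorization": "Bearer <token>"},
#       ) as resp:
#           async for text in resp.aiter_text():
#               print(text, end="")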
@router.post("/chat", response_model=APIResponse)
async def mistral_chat(request: LLMRequest, token: Annotated[str, Depends(verify_token)]):
global last_mistral_call_time
async with mistral_call_lock:
current_time = time.monotonic()
elapsed = current_time - last_mistral_call_time
if elapsed < MIN_INTERVAL:
await asyncio.sleep(MIN_INTERVAL - elapsed)
last_mistral_call_time = time.monotonic()
try:
response = await mistral_client.chat.complete_async(
model=request.model,
messages=[
{
"role": "user",
"content": request.prompt,
}
],
)
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return APIResponse(success=True, data={"response": content})
else:
return APIResponse(success=False, error="No response content received from Mistral.", data=response.dict())
except Exception as e:
print(f"Error calling Mistral chat completion: {e}")
raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}")
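
# Going by the constructor calls above, a successful reply from /chat (and from
# /chat-with-memory below) serializes roughly as:
#   {"success": true, "data": {"response": "<model output>"}}
# The exact field set depends on schemas.common.APIResponse.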
@router.post("/chat-with-memory", response_model=APIResponse)
async def mistral_chat_with_memory(request: ChatMemoryRequest, token: Annotated[str, Depends(verify_token)]):
global last_mistral_call_time
async with mistral_call_lock:
current_time = time.monotonic()
elapsed = current_time - last_mistral_call_time
if elapsed < MIN_INTERVAL:
await asyncio.sleep(MIN_INTERVAL - elapsed)
last_mistral_call_time = time.monotonic()
try:
# Convert Pydantic models to dicts for the Mistral API call
messages_dict = [msg.dict() for msg in request.messages]
response = await mistral_client.chat.complete_async(
model=request.model,
messages=messages_dict,
)
if response.choices and response.choices[0].message:
content = response.choices[0].message.content
return APIResponse(success=True, data={"response": content})
else:
return APIResponse(success=False, error="No response content received from Mistral.", data=response.dict())
except Exception as e:
print(f"Error calling Mistral chat completion with memory: {e}")
raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}")
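
# An illustrative /chat-with-memory body, assuming ChatMemoryRequest carries a
# `model` string plus a `messages` list of ChatMessage(role, content) items, as
# used above:
#   {
#       "model": "mistral-small-latest",
#       "messages": [
#           {"role": "user", "content": "My name is Sam."},
#           {"role": "assistant", "content": "Nice to meet you, Sam."},
#           {"role": "user", "content": "What is my name?"}
#       ]
#   }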