from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from typing import Annotated, List
from mistralai import Mistral
from auth import verify_token
import os
from schemas.common import APIResponse, ChatMessage, ChatMemoryRequest # Import common schemas
import asyncio
import time

router = APIRouter(prefix="/mistral", tags=["mistral"])

mistral_key = os.environ.get('MISTRAL_KEY', '')
if not mistral_key:
    raise RuntimeError("MISTRAL_KEY environment variable not set.")
mistral_client = Mistral(api_key=mistral_key)

# Rate limiting variables
last_mistral_call_time = 0
mistral_call_lock = asyncio.Lock()
MIN_INTERVAL = 1.0  # Minimum 1 second between calls
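# Every handler below acquires this lock for the full duration of its Mistral call,
# so concurrent requests are serialized and start at least MIN_INTERVAL seconds apart.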

class LLMRequest(BaseModel):
    model: str
    prompt: str

@router.post("/chat-stream")
async def mistral_chat_stream(request: LLMRequest, token: Annotated[str, Depends(verify_token)]):
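    """Stream a completion for a single prompt back to the client as plain text."""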
    async def generate():
        global last_mistral_call_time
        async with mistral_call_lock:
            current_time = time.monotonic()
            elapsed = current_time - last_mistral_call_time
            if elapsed < MIN_INTERVAL:
                await asyncio.sleep(MIN_INTERVAL - elapsed)
            last_mistral_call_time = time.monotonic()

            try:
                response = await mistral_client.chat.stream_async(
                    model=request.model,
                    messages=[
                        {
                            "role": "user",
                            "content": request.prompt,
                        }
                    ],
                )
                async for chunk in response:
                    # Each streamed event wraps a CompletionChunk in chunk.data;
                    # yield the incremental text delta when one is present.
                    if chunk.data.choices and chunk.data.choices[0].delta.content is not None:
                        yield chunk.data.choices[0].delta.content
                    # Chunks without a text delta (e.g. the final chunk) are skipped.
            except Exception as e:
                # Log the error and yield an error message if streaming fails
                print(f"Error during Mistral stream: {e}")
                yield f"Error: {str(e)}"

    return StreamingResponse(generate(), media_type="text/plain")

@router.post("/chat", response_model=APIResponse)
async def mistral_chat(request: LLMRequest, token: Annotated[str, Depends(verify_token)]):
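    """Run a single-prompt chat completion and wrap the result in an APIResponse."""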
    global last_mistral_call_time
    async with mistral_call_lock:
        current_time = time.monotonic()
        elapsed = current_time - last_mistral_call_time
        if elapsed < MIN_INTERVAL:
            await asyncio.sleep(MIN_INTERVAL - elapsed)
        last_mistral_call_time = time.monotonic()

        try:
            response = await mistral_client.chat.complete_async(
                model=request.model,
                messages=[
                    {
                        "role": "user",
                        "content": request.prompt,
                    }
                ],
            )
            if response.choices and response.choices[0].message:
                content = response.choices[0].message.content
                return APIResponse(success=True, data={"response": content})
            else:
                return APIResponse(success=False, error="No response content received from Mistral.", data=response.dict())
        except Exception as e:
            print(f"Error calling Mistral chat completion: {e}")
            raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}")

@router.post("/chat-with-memory", response_model=APIResponse)
async def mistral_chat_with_memory(request: ChatMemoryRequest, token: Annotated[str, Depends(verify_token)]):
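    """Run a chat completion over the full message history supplied by the client."""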
    global last_mistral_call_time
    async with mistral_call_lock:
        current_time = time.monotonic()
        elapsed = current_time - last_mistral_call_time
        if elapsed < MIN_INTERVAL:
            await asyncio.sleep(MIN_INTERVAL - elapsed)
        last_mistral_call_time = time.monotonic()

        try:
            # Convert Pydantic models to dicts for the Mistral API call
            messages_dict = [msg.dict() for msg in request.messages]

            response = await mistral_client.chat.complete_async(
                model=request.model,
                messages=messages_dict,
            )
            if response.choices and response.choices[0].message:
                content = response.choices[0].message.content
                return APIResponse(success=True, data={"response": content})
            else:
                return APIResponse(success=False, error="No response content received from Mistral.", data=response.dict())
        except Exception as e:
            print(f"Error calling Mistral chat completion with memory: {e}")
            raise HTTPException(status_code=500, detail=f"Mistral API error: {str(e)}")
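

# Illustrative client sketch (not used by the app): how a caller might consume the
# /mistral/chat-stream endpoint. Assumptions: the API is served at http://localhost:8000,
# verify_token accepts a bearer token, httpx is installed, and "mistral-small-latest"
# is an available model; adjust these for your deployment.
async def _demo_chat_stream() -> None:
    import httpx  # local import so the router itself does not depend on httpx

    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            "http://localhost:8000/mistral/chat-stream",
            json={"model": "mistral-small-latest", "prompt": "Hello!"},
            headers={"Authorization": "Bearer <token>"},
        ) as resp:
            async for text in resp.aiter_text():
                print(text, end="", flush=True)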