Create callback.py
callback.py ADDED  +33 -0
@@ -0,0 +1,33 @@
+"""Callback handlers used in the app."""
+from typing import Any, Dict, List
+
+from langchain.callbacks.base import AsyncCallbackHandler
+
+from schemas import ChatResponse
+
+
+class StreamingLLMCallbackHandler(AsyncCallbackHandler):
+    """Callback handler for streaming LLM responses."""
+
+    def __init__(self, websocket):
+        self.websocket = websocket
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        resp = ChatResponse(sender="bot", message=token, type="stream")
+        await self.websocket.send_json(resp.dict())
+
+
+class QuestionGenCallbackHandler(AsyncCallbackHandler):
+    """Callback handler for question generation."""
+
+    def __init__(self, websocket):
+        self.websocket = websocket
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> None:
+        """Run when LLM starts running."""
+        resp = ChatResponse(
+            sender="bot", message="Synthesizing question...", type="info"
+        )
+        await self.websocket.send_json(resp.dict())
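
For context (not part of this commit): ChatResponse is imported from a schemas module that the diff does not include. Below is a minimal sketch of what that model plausibly looks like, assuming a Pydantic model with exactly the three fields the handlers use; the field comments are illustrative assumptions, not taken from the commit.

    from pydantic import BaseModel


    class ChatResponse(BaseModel):
        """Message payload sent to the client over the WebSocket."""

        sender: str   # assumed values: "bot" or "you"
        message: str  # a streamed token or a status line
        type: str     # assumed values such as "stream", "info", "start", "end", "error"

Likewise, a hedged sketch of how these handlers might be wired up. The FastAPI app, the /chat route, and the direct call to on_llm_new_token are assumptions for illustration only; in a real app LangChain would invoke the handler itself once it is registered with the LLM's async callback manager.

    from fastapi import FastAPI, WebSocket

    from callback import StreamingLLMCallbackHandler

    app = FastAPI()


    @app.websocket("/chat")
    async def chat_endpoint(websocket: WebSocket) -> None:
        await websocket.accept()
        # Bind a handler to this connection; each client gets its own socket.
        stream_handler = StreamingLLMCallbackHandler(websocket)
        # Called directly here only to show the JSON frames the client receives;
        # normally LangChain calls on_llm_new_token as the LLM emits tokens.
        for token in ["Hello", ", ", "world", "!"]:
            await stream_handler.on_llm_new_token(token)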