Spaces:
Paused
Paused
Daniel Marques
commited on
Commit
·
abff149
1
Parent(s):
e04f8da
fix: add websocket in handlerToken
Browse files
main.py
CHANGED
@@ -20,8 +20,6 @@ from langchain.memory import ConversationBufferMemory
|
|
20 |
from langchain.callbacks.base import BaseCallbackHandler
|
21 |
from langchain.schema import LLMResult
|
22 |
|
23 |
-
from varstate import State
|
24 |
-
|
25 |
# from langchain.embeddings import HuggingFaceEmbeddings
|
26 |
from load_models import load_model
|
27 |
|
@@ -58,9 +56,9 @@ DB = Chroma(
|
|
58 |
RETRIEVER = DB.as_retriever()
|
59 |
|
60 |
class MyCustomSyncHandler(BaseCallbackHandler):
|
61 |
-
def __init__(self
|
62 |
self.end = False
|
63 |
-
self.
|
64 |
|
65 |
def on_llm_start(
|
66 |
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
@@ -77,14 +75,8 @@ class MyCustomSyncHandler(BaseCallbackHandler):
|
|
77 |
|
78 |
print(token)
|
79 |
|
80 |
-
|
81 |
# Create State
|
82 |
-
|
83 |
-
tokenMessageLLM = State()
|
84 |
-
|
85 |
-
get, update = tokenMessageLLM.create('')
|
86 |
-
|
87 |
-
handlerToken = MyCustomSyncHandler(update)
|
88 |
|
89 |
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
|
90 |
|
@@ -253,8 +245,8 @@ async def create_upload_file(file: UploadFile):
|
|
253 |
|
254 |
return {"filename": file.filename}
|
255 |
|
256 |
-
@api_app.websocket("/ws")
|
257 |
-
async def websocket_endpoint(websocket: WebSocket):
|
258 |
global QA
|
259 |
|
260 |
await websocket.accept()
|
@@ -265,16 +257,11 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
265 |
while True:
|
266 |
prompt = await websocket.receive_text()
|
267 |
|
268 |
-
|
269 |
|
270 |
-
if (oldReceiveText != prompt)
|
271 |
oldReceiveText = prompt
|
272 |
-
|
273 |
-
|
274 |
-
print(statusProcess);
|
275 |
-
|
276 |
-
tokenState = get()
|
277 |
-
await websocket.send_text(f"token: {tokenState}")
|
278 |
|
279 |
except WebSocketDisconnect:
|
280 |
print('disconnect')
|
|
|
20 |
from langchain.callbacks.base import BaseCallbackHandler
|
21 |
from langchain.schema import LLMResult
|
22 |
|
|
|
|
|
23 |
# from langchain.embeddings import HuggingFaceEmbeddings
|
24 |
from load_models import load_model
|
25 |
|
|
|
56 |
RETRIEVER = DB.as_retriever()
|
57 |
|
58 |
class MyCustomSyncHandler(BaseCallbackHandler):
|
59 |
+
def __init__(self):
|
60 |
self.end = False
|
61 |
+
self.callback = None
|
62 |
|
63 |
def on_llm_start(
|
64 |
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
|
|
75 |
|
76 |
print(token)
|
77 |
|
|
|
78 |
# Create State
|
79 |
+
handlerToken = MyCustomSyncHandler()
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
LLM = load_model(device_type=DEVICE_TYPE, model_id=MODEL_ID, model_basename=MODEL_BASENAME, stream=True, callbacks=[handlerToken])
|
82 |
|
|
|
245 |
|
246 |
return {"filename": file.filename}
|
247 |
|
248 |
+
@api_app.websocket("/ws/{client_id}")
|
249 |
+
async def websocket_endpoint(websocket: WebSocket, client_id: int):
|
250 |
global QA
|
251 |
|
252 |
await websocket.accept()
|
|
|
257 |
while True:
|
258 |
prompt = await websocket.receive_text()
|
259 |
|
260 |
+
handlerToken.callback = websocket.send_text;
|
261 |
|
262 |
+
if (oldReceiveText != prompt):
|
263 |
oldReceiveText = prompt
|
264 |
+
asyncio.run(QA(prompt))
|
|
|
|
|
|
|
|
|
|
|
265 |
|
266 |
except WebSocketDisconnect:
|
267 |
print('disconnect')
|