|
import chainlit as cl |
|
from langchain.callbacks.base import BaseCallbackHandler |
|
from langchain.schema import StrOutputParser |
|
from langchain.schema.runnable import RunnableConfig, RunnablePassthrough |
|
from langchain_openai import ChatOpenAI |
|
|
|
import configs |
|
from prompts import prompt |
|
from utils import format_docs, process_documents |
|
|
|
# Built once at import time and shared by every chat session. process_documents
# presumably returns a vector store (it must expose .as_retriever()) — confirm in utils.
doc_search = process_documents(configs.DOCS_STORAGE_PATH)

# `model=` (alias of `model_name`) selects the OpenAI model. The previous `name=`
# only set the Runnable's display name, so configs.CHAT_MODEL was ignored and the
# library's default model was used instead.
model = ChatOpenAI(model=configs.CHAT_MODEL, streaming=True)
|
|
|
|
|
@cl.on_chat_start
async def on_chat_start():
    """Assemble the question-answering chain for a new chat session.

    The user's question is fanned out two ways: through the document retriever
    (hits rendered to text by ``format_docs``) as ``context``, and unchanged as
    ``question``. Both feed ``prompt``, whose output goes to ``model``; the reply
    is reduced to a plain string. The chain is stored in the Chainlit user
    session under the key ``"runnable"`` for ``on_message`` to pick up.
    """
    chain_inputs = {
        "context": doc_search.as_retriever() | format_docs,
        "question": RunnablePassthrough(),
    }
    qa_chain = chain_inputs | prompt | model | StrOutputParser()
    cl.user_session.set("runnable", qa_chain)
|
|
|
|
|
class PostMessageHandler(BaseCallbackHandler):
    """Collect retrieved-document sources and attach them to a Chainlit message.

    ``on_retriever_end`` records a ``(source, page)`` pair from each retrieved
    document's metadata. ``on_llm_end`` renders the pairs one per line as
    ``source#page=N`` (or just ``source`` when no page is known) and appends them
    to the message as an inline ``Text`` element named "Sources".
    """

    def __init__(self, msg: cl.Message):
        super().__init__()
        self.msg = msg
        # dict used as an insertion-ordered set: de-duplicates pairs while keeping
        # retrieval order, so the rendered list is stable from run to run.
        self.sources = {}

    def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
        for d in documents:
            source = d.metadata.get("source")
            if source is None:
                continue  # nothing to link to
            # Store a (source, page) pair. Storing the bare source string (as
            # before) made the unpacking in on_llm_end raise, so no sources showed.
            self.sources.setdefault((source, d.metadata.get("page")), None)

    def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
        if not self.sources:
            return
        sources_text = "\n".join(
            source if page is None else f"{source}#page={page}"
            for source, page in self.sources
        )
        self.msg.elements.append(
            cl.Text(name="Sources", content=sources_text, display="inline")
        )


@cl.on_message
async def on_message(message: cl.Message):
    """Answer one user message with the session's chain, streaming the reply.

    Tokens are streamed into a new Chainlit message inside a "QA Assistant" step.
    ``PostMessageHandler`` attaches the retrieved sources to that message. The
    question/answer pair is appended to the history file, and then the message
    is finalized with ``send()``.
    """
    runnable = cl.user_session.get("runnable")
    msg = cl.Message(content="")

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    # NOTE(review): the attribute is spelled HIISTORY_FILE; it must match the name
    # defined in configs — confirm whether the double "I" is intentional.
    # Explicit UTF-8 so non-ASCII questions/answers don't depend on the OS locale.
    with open(configs.HIISTORY_FILE, "a", encoding="utf-8") as f:
        f.write(f"{message.content}[SEP]{msg.content}[END]\n\n")

    await msg.send()
|
|