AkwabaGPT / app.py
Monsia's picture
app v 0.1.1
667b788
import chainlit as cl
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import StrOutputParser
from langchain.schema.runnable import RunnableConfig, RunnablePassthrough
from langchain_openai import ChatOpenAI
import configs
from prompts import prompt
from utils import format_docs, process_documents
# Document index built once at import time from the configured storage path;
# every chat session derives its retriever from this shared object.
doc_search = process_documents(configs.DOCS_STORAGE_PATH)

# Streaming chat model shared by all sessions.
# The model identifier must be passed as `model` (alias of `model_name`).
# `name` only labels the runnable, so passing the identifier there left
# ChatOpenAI on its default model and silently ignored configs.CHAT_MODEL.
model = ChatOpenAI(model=configs.CHAT_MODEL, streaming=True)
@cl.on_chat_start
async def on_chat_start():
    """Assemble the question-answering chain and store it in the user session.

    The chain feeds the incoming question both to a retriever (whose results
    are rendered by ``format_docs`` into the prompt's ``context``) and,
    unchanged, into the prompt's ``question`` slot; the prompt output goes to
    the chat model and is parsed to a plain string.
    """
    # Retriever is configured with search_kwargs k=10.
    top_k_retriever = doc_search.as_retriever(search_kwargs={"k": 10})

    # Mapping of prompt variables to the runnables that produce them.
    prompt_inputs = {
        "context": top_k_retriever | format_docs,
        "question": RunnablePassthrough(),
    }

    qa_chain = prompt_inputs | prompt | model | StrOutputParser()

    # Stored under the "runnable" key, which the message handler reads back.
    cl.user_session.set("runnable", qa_chain)
@cl.on_message
async def on_message(message: cl.Message):
    """Answer one user message with the session's chain and log the exchange.

    The reply is streamed token by token into a Chainlit message. While the
    chain runs, a callback handler collects the sources of the retrieved
    documents and attaches them to the reply as an inline "Sources" text
    element. Once streaming finishes, the question/answer pair is appended to
    the history file and the reply is sent.

    Args:
        message: The incoming user message; its ``content`` is the question.
    """
    runnable = cl.user_session.get("runnable")
    msg = cl.Message(content="")

    class PostMessageHandler(BaseCallbackHandler):
        """
        Callback handler for handling the retriever and LLM processes.
        Used to post the sources of the retrieved documents as a Chainlit element.
        """

        def __init__(self, msg: cl.Message):
            super().__init__()
            self.msg = msg
            self.sources = set()  # Unique (source, page) pairs

        def on_retriever_end(self, documents, *, run_id, parent_run_id, **kwargs):
            """Record a (source, page) pair for every retrieved document."""
            for d in documents:
                source = d.metadata.get("source")
                if source is None:
                    # Nothing to cite for this document; keep the others.
                    continue
                # Store a real pair so on_llm_end can unpack it. Previously a
                # bare string was stored, which made the two-name unpacking
                # below raise ValueError. `page` is optional because not every
                # document has page metadata.
                self.sources.add((source, d.metadata.get("page")))

        def on_llm_end(self, response, *, run_id, parent_run_id, **kwargs):
            """Attach the collected sources to the reply as an inline element."""
            if not self.sources:
                return
            sources_text = "\n".join(
                f"{source}" if page is None else f"{source}#page={page}"
                for source, page in self.sources
            )
            self.msg.elements.append(
                cl.Text(name="Sources", content=sources_text, display="inline")
            )

    async with cl.Step(type="run", name="QA Assistant"):
        async for chunk in runnable.astream(
            message.content,
            config=RunnableConfig(
                callbacks=[cl.LangchainCallbackHandler(), PostMessageHandler(msg)]
            ),
        ):
            await msg.stream_token(chunk)

    # Append the exchange to the history log. UTF-8 is set explicitly so that
    # non-ASCII questions/answers do not depend on the platform default.
    # NOTE(review): `HIISTORY_FILE` looks like a typo for `HISTORY_FILE`; it is
    # kept as-is because the attribute is defined in `configs`, outside this view.
    with open(configs.HIISTORY_FILE, "a", encoding="utf-8") as f:
        f.write(f"""{message.content}[SEP]{msg.content}[END]\n\n""")

    await msg.send()