# ChatData / chat.py
# Fangrui Liu
# update sdk_version
# 401cf68
# raw
# history blame
# 10.9 kB
import json
import time
import pandas as pd
from os import environ
import datetime
import streamlit as st
from langchain.schema import Document
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
ChatDataSQLAskCallBackHandler
from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
from auth0_component import login_button
from helper import build_tools, build_agents, build_all, sel_map, display
# Export the API base from Streamlit secrets before the builders below run.
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']

st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
st.header("ChatData")

# Build the per-knowledge-base retrievers/chains and the agent tools once per
# browser session. Guard on the key that is actually stored here: the old guard
# tested 'retriever', which nothing in this file sets, so everything was rebuilt
# on every Streamlit rerun.
if "sel_map_obj" not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()
    st.session_state["tools"] = build_tools()
def on_chat_submit():
    """``st.chat_input`` callback: send the submitted text to the active agent.

    The agent is picked by the sidebar's knowledge base (``sel``) and retriever
    type (``ret_type``) from ``st.session_state.agents``. Its reply is not
    rendered here; presumably the agent's memory records the exchange and the
    page re-reads it on the following run -- confirm in ``helper.build_agents``.
    """
    agent = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]
    # Deliberately not printed: dumping the agent output to stdout put users'
    # conversations into the server logs, contradicting the landing page's
    # "We will NOT collect any of your conversation" statement.
    agent({"input": st.session_state.chat_input})
def clear_history():
    """Wipe the memory of the agent currently selected in the sidebar.

    The agent is looked up by knowledge base (``sel``) and then by retriever
    type (``ret_type``) in ``st.session_state.agents``.
    """
    agents_for_kb = st.session_state.agents[st.session_state.sel]
    agents_for_kb[st.session_state.ret_type].memory.clear()
# Auth0 application settings for the login button, read from Streamlit secrets
# (a missing secret raises at import time, before any page is drawn below).
AUTH0_CLIENT_ID, AUTH0_DOMAIN = st.secrets['AUTH0_CLIENT_ID'], st.secrets['AUTH0_DOMAIN']
def login():
    """Show the landing page unless the visitor has already picked a mode.

    Returns True immediately when ``user_name`` is in session state (logged in)
    or ``jump_query_ask`` is truthy (the login-free Query / Ask mode was chosen).
    Otherwise it draws the landing page with a "Query / Ask" button and the
    Auth0 login button, then:

    * if the Auth0 component produced a user, stores it as ``user_info``,
      derives ``user_name`` from it and reruns the script;
    * if "Query / Ask" was clicked, reruns the script;
    * otherwise returns False (previously an implicit None; both are falsy).
    """
    if "user_name" in st.session_state or st.session_state.get("jump_query_ask"):
        return True
    st.subheader("πŸ€— Welcome to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! πŸ€— ")
    st.write("You can now chat with ArXiv and Wikipedia! You can also try to build your RAG system with those knowledge base via [our public read-only credentials!](https://github.com/myscale/ChatData#data-schema) 🌟\n")
    st.write("Built purely with streamlit πŸ‘‘ , LangChain πŸ¦œπŸ”— and love for AI!")
    st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
    st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)")
    st.info("We used [Auth0](https://auth0.com) as our identity provider. "
            "We will **NOT** collect any of your conversation in any form for any purpose.")
    st.divider()
    col1, col2 = st.columns(2, gap='large')
    with col1.container():
        st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
        st.write("In this demo, you will be able to see how those retrievers "
                 "**digest** -> **translate** -> **retrieve** -> **answer** to your question!")
        st.write("It is a step-by-step tour to understand RAG pipeline.")
        # A Streamlit button is True only on the run right after its click;
        # storing it lets the early return above keep the visitor in this mode.
        st.session_state["jump_query_ask"] = st.button("Query / Ask")
    with col2.container():
        st.write("Now with the power of LangChain's Conversational Agents, we are able to build "
                 "conversational chatbot with RAG! The agent will decide when and what to retrieve "
                 "based on your question!")
        st.write("All those conversation history management and retrievers are provided within one MyScale instance!")
        st.write("Log in to Chat with RAG!")
        # Presumably "auth0" is the component key; its value is read back from
        # st.session_state.auth0 just below.
        login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")
    if st.session_state.auth0 is not None:
        user_info = dict(st.session_state.auth0)
        st.session_state.user_info = user_info
        if 'email' in user_info:
            email = user_info["email"]
        else:
            # No e-mail claim: fall back to a synthetic "<nickname>@<sub>" identifier.
            email = f"{user_info['nickname']}@{user_info['sub']}"
        st.session_state["user_name"] = email
        # Consume the component value (presumably so a later logout does not
        # immediately log in again from the stale value).
        del st.session_state.auth0
        st.experimental_rerun()
    if st.session_state.jump_query_ask:
        st.experimental_rerun()
    return False
def back_to_main():
    """Forget the login and mode choice so the next run shows the landing page.

    Removes ``user_info``, ``user_name`` and ``jump_query_ask`` from session
    state, in that order; keys that are absent are skipped.
    """
    for state_key in ("user_info", "user_name", "jump_query_ask"):
        if state_key in st.session_state:
            del st.session_state[state_key]
def _timestamp_line(msg):
    """Return the italic markdown timestamp shown above a stored chat message.

    ``msg.additional_kwargs['timestamp']`` is interpreted by
    ``datetime.fromtimestamp`` as POSIX seconds in the server's local timezone;
    a message without that key raises KeyError.
    """
    stamp = datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp'])
    return f"*{stamp.isoformat()}*"


def _docs_to_frame(docs):
    """Flatten retrieved documents into a DataFrame: metadata columns plus ``abstract``."""
    return pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])


def _render_chat_page():
    """Render the logged-in chat page: sidebar controls, stored history and input box."""
    # Agents are rebuilt for the logged-in user on every run of the script.
    st.session_state["agents"] = build_agents(st.session_state.user_name)
    with st.sidebar:
        st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
        st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)
    # Read only after the sidebar widgets above have populated `sel` / `ret_type`.
    agent = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]
    # User content is deliberately not printed to stdout (server logs): the landing
    # page promises not to collect conversations.
    for msg in agent.memory.chat_memory.messages:
        if isinstance(msg, FunctionMessage):
            # Tool output: shown as a table of the retrieved records.
            with st.chat_message("Knowledge Base", avatar="πŸ“–"):
                st.write(_timestamp_line(msg))
                st.write("Retrieved from knowledge base:")
                # NOTE(review): eval() executes whatever string is stored in chat
                # memory. If that store can be written by anything other than this
                # app's own tools, switch to a safe parser (ast.literal_eval / json)
                # once the stored format is confirmed.
                st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
        elif msg.content:
            # Empty messages are skipped.
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                st.write(_timestamp_line(msg))
                st.write(f"{msg.content}")
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")


def _render_action_buttons(suffix):
    """Render one tab's Query / Ask / Back row; widget keys end in ``_<suffix>``."""
    cols = st.columns([1, 1, 1, 4])
    cols[0].button("Query", key=f'search_{suffix}')
    cols[1].button("Ask", key=f'ask_{suffix}')
    cols[2].button("Back", key=f'back_{suffix}', on_click=back_to_main)


def _run_search(sel, retriever_key, query_key, handler_cls, show_must_have_cols=False):
    """Query ``sel_map_obj[sel][retriever_key]`` and show the hits in a "Query Log" expander.

    The question is read from ``st.session_state[query_key]``. ``handler_cls`` is
    instantiated inside the expander so any Streamlit elements it creates render
    there. With ``show_must_have_cols`` the table is passed to ``display`` together
    with ``sel_map[sel]["must_have_cols"]``. Any exception is reported in the
    expander and then re-raised.
    """
    with st.empty().expander('Query Log', expanded=True):
        callback = handler_cls()
        try:
            docs = st.session_state.sel_map_obj[sel][retriever_key].get_relevant_documents(
                st.session_state[query_key], callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            docs = _docs_to_frame(docs)
            if show_must_have_cols:
                display(docs, sel_map[sel]["must_have_cols"])
            else:
                display(docs)
        except Exception:
            st.write('Oops 😡 Something bad happened...')
            raise


def _run_ask(sel, chain_key, query_key, handler_cls):
    """Ask ``sel_map_obj[sel][chain_key]`` and show answer + references in a "Chat Log" expander.

    The chain result is read as a mapping with ``'answer'`` (markdown text) and
    ``'sources'`` (documents, tabulated with a ``ref_id`` index). ``handler_cls``
    is instantiated inside the expander; any exception is reported there and
    then re-raised.
    """
    with st.empty().expander('Chat Log', expanded=True):
        callback = handler_cls()
        try:
            ret = st.session_state.sel_map_obj[sel][chain_key](
                st.session_state[query_key], callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            st.markdown(
                f"### Answer from LLM\n{ret['answer']}\n### References")
            docs = _docs_to_frame(ret['sources'])
            display(
                docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
        except Exception:
            st.write('Oops 😡 Something bad happened...')
            raise


def _render_query_ask_page():
    """Render the login-free Query / Ask page with its Vector SQL and Self-Query tabs."""
    sel = st.selectbox('Choose the knowledge base you want to ask with:',
                       options=['ArXiv Papers', 'Wikipedia'])
    sel_map[sel]['hint']()
    tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
    with tab_sql:
        sel_map[sel]['hint_sql']()
        st.text_input("Ask a question:", key='query_sql')
        _render_action_buttons('sql')
        if st.session_state.search_sql:
            _run_search(sel, "sql_retriever", 'query_sql', ChatDataSQLSearchCallBackHandler)
        if st.session_state.ask_sql:
            _run_ask(sel, "sql_chain", 'query_sql', ChatDataSQLAskCallBackHandler)
    with tab_self_query:
        st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='πŸ’‘')
        st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
        st.text_input("Ask a question:", key='query_self')
        _render_action_buttons('self')
        if st.session_state.search_self:
            _run_search(sel, "retriever", 'query_self', ChatDataSelfSearchCallBackHandler,
                        show_must_have_cols=True)
        if st.session_state.ask_self:
            _run_ask(sel, "chain", 'query_self', ChatDataSelfAskCallBackHandler)


# Page dispatch: logged-in users get the chat page; visitors who chose
# "Query / Ask" on the landing page get the retriever playground.
if login():
    if "user_name" in st.session_state:
        _render_chat_page()
    elif st.session_state.get("jump_query_ask"):
        _render_query_ask_page()