Spaces:
Running
Running
import json | |
import time | |
import pandas as pd | |
from os import environ | |
import datetime | |
import streamlit as st | |
from langchain.schema import Document | |
from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \ | |
ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \ | |
ChatDataSQLAskCallBackHandler | |
from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage | |
from auth0_component import login_button | |
from helper import build_tools, build_agents, build_all, sel_map, display | |
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE'] | |
st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico") | |
st.header("ChatData") | |
if 'retriever' not in st.session_state: | |
st.session_state["sel_map_obj"] = build_all() | |
st.session_state["tools"] = build_tools() | |
def on_chat_submit(): | |
ret = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]({"input": st.session_state.chat_input}) | |
print(ret) | |
def clear_history(): | |
st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.clear() | |
AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID'] | |
AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN'] | |
def login(): | |
if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask): | |
return True | |
st.subheader("π€ Welcom to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! π€ ") | |
st.write("You can now chat with ArXiv and Wikipedia! You can also try to build your RAG system with those knowledge base via [our public read-only credentials!](https://github.com/myscale/ChatData#data-schema) π\n") | |
st.write("Built purely with streamlit π , LangChain π¦π and love for AI!") | |
st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!") | |
st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)") | |
st.info("We used [Auth0](https://auth0.com) as our identity provider. " | |
"We will **NOT** collect any of your conversation in any form for any purpose.") | |
st.divider() | |
col1, col2 = st.columns(2, gap='large') | |
with col1.container(): | |
st.write("Try out MyScale's Self-query and Vector SQL retrievers!") | |
st.write("In this demo, you will be able to see how those retrievers " | |
"**digest** -> **translate** -> **retrieve** -> **answer** to your question!") | |
st.write("It is a step-by-step tour to understand RAG pipeline.") | |
st.session_state["jump_query_ask"] = st.button("Query / Ask") | |
with col2.container(): | |
st.write("Now with the power of LangChain's Conversantional Agents, we are able to build " | |
"conversational chatbot with RAG! The agent will decide when and what to retrieve " | |
"based on your question!") | |
st.write("All those conversation history management and retrievers are provided within one MyScale instance!") | |
st.write("Log in to Chat with RAG!") | |
login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0") | |
if st.session_state.auth0 is not None: | |
st.session_state.user_info = dict(st.session_state.auth0) | |
if 'email' in st.session_state.user_info: | |
email = st.session_state.user_info["email"] | |
else: | |
email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}" | |
st.session_state["user_name"] = email | |
del st.session_state.auth0 | |
st.experimental_rerun() | |
if st.session_state.jump_query_ask: | |
st.experimental_rerun() | |
def back_to_main(): | |
if "user_info" in st.session_state: | |
del st.session_state.user_info | |
if "user_name" in st.session_state: | |
del st.session_state.user_name | |
if "jump_query_ask" in st.session_state: | |
del st.session_state.jump_query_ask | |
if login(): | |
if "user_name" in st.session_state: | |
st.session_state["agents"] = build_agents(st.session_state.user_name) | |
with st.sidebar: | |
st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type") | |
st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel") | |
st.button("Clear Chat History", on_click=clear_history) | |
st.button("Logout", on_click=back_to_main) | |
for msg in st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.chat_memory.messages: | |
speaker = "user" if isinstance(msg, HumanMessage) else "assistant" | |
if isinstance(msg, FunctionMessage): | |
with st.chat_message("Knowledge Base", avatar="π"): | |
print(type(msg.content)) | |
st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*") | |
st.write("Retrieved from knowledge base:") | |
st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content)))) | |
else: | |
if len(msg.content) > 0: | |
with st.chat_message(speaker): | |
print(type(msg), msg.dict()) | |
st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*") | |
st.write(f"{msg.content}") | |
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input") | |
elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask: | |
sel = st.selectbox('Choose the knowledge base you want to ask with:', | |
options=['ArXiv Papers', 'Wikipedia']) | |
sel_map[sel]['hint']() | |
tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers']) | |
with tab_sql: | |
sel_map[sel]['hint_sql']() | |
st.text_input("Ask a question:", key='query_sql') | |
cols = st.columns([1, 1, 1, 4]) | |
cols[0].button("Query", key='search_sql') | |
cols[1].button("Ask", key='ask_sql') | |
cols[2].button("Back", key='back_sql', on_click=back_to_main) | |
plc_hldr = st.empty() | |
if st.session_state.search_sql: | |
plc_hldr = st.empty() | |
print(st.session_state.query_sql) | |
with plc_hldr.expander('Query Log', expanded=True): | |
callback = ChatDataSQLSearchCallBackHandler() | |
try: | |
docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents( | |
st.session_state.query_sql, callbacks=[callback]) | |
callback.progress_bar.progress(value=1.0, text="Done!") | |
docs = pd.DataFrame( | |
[{**d.metadata, 'abstract': d.page_content} for d in docs]) | |
display(docs) | |
except Exception as e: | |
st.write('Oops π΅ Something bad happened...') | |
raise e | |
if st.session_state.ask_sql: | |
plc_hldr = st.empty() | |
print(st.session_state.query_sql) | |
with plc_hldr.expander('Chat Log', expanded=True): | |
callback = ChatDataSQLAskCallBackHandler() | |
try: | |
ret = st.session_state.sel_map_obj[sel]["sql_chain"]( | |
st.session_state.query_sql, callbacks=[callback]) | |
callback.progress_bar.progress(value=1.0, text="Done!") | |
st.markdown( | |
f"### Answer from LLM\n{ret['answer']}\n### References") | |
docs = ret['sources'] | |
docs = pd.DataFrame( | |
[{**d.metadata, 'abstract': d.page_content} for d in docs]) | |
display( | |
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id') | |
except Exception as e: | |
st.write('Oops π΅ Something bad happened...') | |
raise e | |
with tab_self_query: | |
st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='π‘') | |
st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"]) | |
st.text_input("Ask a question:", key='query_self') | |
cols = st.columns([1, 1, 1, 4]) | |
cols[0].button("Query", key='search_self') | |
cols[1].button("Ask", key='ask_self') | |
cols[2].button("Back", key='back_self', on_click=back_to_main) | |
plc_hldr = st.empty() | |
if st.session_state.search_self: | |
plc_hldr = st.empty() | |
print(st.session_state.query_self) | |
with plc_hldr.expander('Query Log', expanded=True): | |
call_back = None | |
callback = ChatDataSelfSearchCallBackHandler() | |
try: | |
docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents( | |
st.session_state.query_self, callbacks=[callback]) | |
print(docs) | |
callback.progress_bar.progress(value=1.0, text="Done!") | |
docs = pd.DataFrame( | |
[{**d.metadata, 'abstract': d.page_content} for d in docs]) | |
display(docs, sel_map[sel]["must_have_cols"]) | |
except Exception as e: | |
st.write('Oops π΅ Something bad happened...') | |
raise e | |
if st.session_state.ask_self: | |
plc_hldr = st.empty() | |
print(st.session_state.query_self) | |
with plc_hldr.expander('Chat Log', expanded=True): | |
call_back = None | |
callback = ChatDataSelfAskCallBackHandler() | |
try: | |
ret = st.session_state.sel_map_obj[sel]["chain"]( | |
st.session_state.query_self, callbacks=[callback]) | |
callback.progress_bar.progress(value=1.0, text="Done!") | |
st.markdown( | |
f"### Answer from LLM\n{ret['answer']}\n### References") | |
docs = ret['sources'] | |
docs = pd.DataFrame( | |
[{**d.metadata, 'abstract': d.page_content} for d in docs]) | |
display( | |
docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id') | |
except Exception as e: | |
st.write('Oops π΅ Something bad happened...') | |
raise e |