# NOTE(review): this file was captured from a web page view, and three lines of
# page residue preceded the code here (kept as comments so the module parses):
#   File size: 10,990 Bytes
#   19bd5a9 401cf68 19bd5a9 17c6622 19bd5a9 |
#   (a line-number gutter running from 1 to 207)
# The gutter counts 207 lines while fewer survive below, so some source lines and
# the URLs inside several strings appear to have been lost in the capture.
import datetime
import json
import time
from os import environ

import pandas as pd
import streamlit as st
from auth0_component import login_button
from langchain.schema import Document
from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage

from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
    ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
    ChatDataSQLAskCallBackHandler
from helper import build_tools, build_agents, build_all, sel_map, display
# Point the OpenAI client at the endpoint configured in Streamlit secrets.
environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']
# NOTE(review): page_icon is empty in the recovered source; its value (probably a
# URL) looks like it was stripped during capture — restore the intended icon.
st.set_page_config(page_title="ChatData", page_icon="")

# Build the retrievers/chains and the agent tools once per browser session.
# Streamlit re-executes this script on every interaction, so the guard must test
# keys this branch actually stores. The previous check looked for 'retriever',
# which nothing visible in this file stores, so everything was rebuilt on every rerun.
if "sel_map_obj" not in st.session_state or "tools" not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()
    st.session_state["tools"] = build_tools()
def on_chat_submit():
    """Send the submitted chat message to the currently selected agent.

    Used as the ``on_submit`` callback of ``st.chat_input``. Reads the knowledge
    base (``sel``), the retriever type (``ret_type``) and the message text
    (``chat_input``) from session state. The agent's return value is not used:
    the chat view renders the conversation from the agent's memory, which
    presumably records this exchange.
    """
    agent = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]
    agent({"input": st.session_state.chat_input})
def clear_history():
    """Clear the conversation memory of the currently selected agent.

    Used as the ``on_click`` callback of the "Clear Chat History" button.
    """
    # NOTE(review): this function's body was missing from the recovered source.
    # This reconstruction clears the same memory object whose
    # ``chat_memory.messages`` the chat view renders.
    agent = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]
    agent.memory.clear()
# Auth0 application settings for the login button, read from Streamlit secrets.
# NOTE(review): AUTH0_CLIENT_ID is passed to login_button() in login(), but its
# definition was missing from the recovered source. The secret name mirrors
# AUTH0_DOMAIN — confirm it matches the deployed secrets.
AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']
def login():
    """Show the landing page unless the user has already moved past it.

    Returns True immediately when a user is logged in (``user_name`` is in
    session state) or chose the "Query / Ask" tour in an earlier run. Otherwise
    it renders the introduction, the "Query / Ask" button and the Auth0 login
    button, and stores ``user_info``/``user_name`` when the login component
    reports a user. It then returns True if the user is now logged in or has
    just pressed "Query / Ask", and False otherwise.
    """
    if "user_name" in st.session_state or st.session_state.get("jump_query_ask"):
        return True

    # NOTE(review): the link targets inside the Markdown strings below, and some
    # emoji (the lone "π" characters), were lost when this file was captured.
    # Restore them from the original source. Recoverable emoji have been repaired.
    st.subheader("🤗 Welcome to [MyScale]('s [ChatData](! 🤗 ")
    st.write("You can now chat with ArXiv and Wikipedia! You can also try to build your RAG system with those knowledge base via [our public read-only credentials!]( π\n")
    st.write("Built purely with streamlit π , LangChain 🦜🔗 and love for AI!")
    st.write("Follow us on [Twitter]( and [Discord](!")
    st.warning("To use chat, please jump to [](")
    # NOTE(review): the call that rendered this notice was lost in the capture;
    # st.write is a stand-in.
    st.write("We used [Auth0]( as our identity provider. "
             "We will **NOT** collect any of your conversation in any form for any purpose.")

    col1, col2 = st.columns(2, gap='large')
    with col1.container():
        st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
        st.write("In this demo, you will be able to see how those retrievers "
                 "**digest** -> **translate** -> **retrieve** -> **answer** to your question!")
        st.write("It is a step-by-step tour to understand RAG pipeline.")
        st.session_state["jump_query_ask"] = st.button("Query / Ask")
    with col2.container():
        st.write("Now with the power of LangChain's Conversational Agents, we are able to build "
                 "conversational chatbot with RAG! The agent will decide when and what to retrieve "
                 "based on your question!")
        st.write("All those conversation history management and retrievers are provided within one MyScale instance!")
        st.write("Log in to Chat with RAG!")
        # The component's value is read back from session state under "auth0"
        # below (presumably the third argument is the component key).
        login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")

    # .get() avoids an AttributeError if the component has not stored a value yet.
    auth_info = st.session_state.get("auth0")
    if auth_info is not None:
        st.session_state.user_info = dict(auth_info)
        if 'email' in st.session_state.user_info:
            email = st.session_state.user_info["email"]
        else:
            # No e-mail from the identity provider: build a unique stand-in name.
            email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}"
        st.session_state["user_name"] = email
        # Drop the component value once consumed (presumably so the same login
        # payload is not processed again).
        del st.session_state.auth0

    # NOTE(review): the tail of this function was incomplete in the recovered
    # source (an `if st.session_state.jump_query_ask:` with no body). It now
    # reports whether the user may proceed, matching the early return above, so
    # the caller can render the next page in this same run.
    return "user_name" in st.session_state or bool(st.session_state.get("jump_query_ask"))
def back_to_main():
    """Return the session to the landing page.

    Forgets the logged-in user and the "Query / Ask" choice by removing the
    ``user_info``, ``user_name`` and ``jump_query_ask`` entries from session
    state, in that order. Entries that are already absent are skipped.
    """
    for state_key in ("user_info", "user_name", "jump_query_ask"):
        if state_key in st.session_state:
            del st.session_state[state_key]
def _render_search(sel, retriever_name, query, callback_cls):
    """Run ``query`` through a retriever of knowledge base ``sel`` and tabulate the hits."""
    with st.empty().expander('Query Log', expanded=True):
        # Instantiate the handler inside the expander so anything it draws
        # (presumably its progress bar) appears in the log panel.
        callback = callback_cls()
        try:
            docs = st.session_state.sel_map_obj[sel][retriever_name].get_relevant_documents(
                query, callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            docs = pd.DataFrame(
                [{**d.metadata, 'abstract': d.page_content} for d in docs])
            display(docs, sel_map[sel]["must_have_cols"])
        except Exception:
            st.write('Oops 😵 Something bad happened...')
            raise


def _render_ask(sel, chain_name, query, callback_cls):
    """Answer ``query`` with a QA chain of knowledge base ``sel`` and show its references."""
    with st.empty().expander('Chat Log', expanded=True):
        # Same reason as in _render_search: the handler renders inside the expander.
        callback = callback_cls()
        try:
            ret = st.session_state.sel_map_obj[sel][chain_name](
                query, callbacks=[callback])
            callback.progress_bar.progress(value=1.0, text="Done!")
            st.markdown(f"### Answer from LLM\n{ret['answer']}\n### References")
            docs = pd.DataFrame(
                [{**d.metadata, 'abstract': d.page_content} for d in ret['sources']])
            display(docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
        except Exception:
            st.write('Oops 😵 Something bad happened...')
            raise


def _render_chat_page():
    """Render the sidebar controls, the stored conversation and the chat input."""
    # Rebuilt for the current user on every run. Do not cache this naively:
    # back_to_main() does not remove "agents", so a cached copy could outlive a logout.
    st.session_state["agents"] = build_agents(st.session_state.user_name)
    with st.sidebar:
        # NOTE(review): the widget call on this line was lost in the capture;
        # selectbox is a stand-in that stores the choice under "ret_type".
        st.selectbox("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
        st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)

    agent = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]
    for msg in agent.memory.chat_memory.messages:
        if isinstance(msg, FunctionMessage):
            # avatar must be a real emoji or image: the garbled "π" made Streamlit
            # raise. The original emoji is unrecoverable; 📚 is a stand-in.
            with st.chat_message("Knowledge Base", avatar="📚"):
                st.write("Retrieved from knowledge base:")
                # SECURITY NOTE(review): eval() executes msg.content as Python code.
                # The content comes from stored conversation memory; if it can ever
                # contain untrusted text, this is arbitrary code execution. Prefer
                # ast.literal_eval (or JSON) if the stored format allows it.
                st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
        elif msg.content:
            # Function results are shown only as the table above, not a second time
            # as raw text. The old debug print() of every message was removed: it
            # wrote users' conversations to server logs on each rerun.
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                st.markdown(msg.content)
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")


def _render_query_ask_page():
    """Render the step-by-step Vector SQL / Self-Query retriever tour."""
    sel = st.selectbox('Choose the knowledge base you want to ask with:',
                       options=['ArXiv Papers', 'Wikipedia'])
    tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
    with tab_sql:
        st.text_input("Ask a question:", key='query_sql')
        cols = st.columns([1, 1, 1, 4])
        cols[0].button("Query", key='search_sql')
        cols[1].button("Ask", key='ask_sql')
        cols[2].button("Back", key='back_sql', on_click=back_to_main)
        if st.session_state.search_sql:
            _render_search(sel, "sql_retriever", st.session_state.query_sql,
                           ChatDataSQLSearchCallBackHandler)
        if st.session_state.ask_sql:
            _render_ask(sel, "sql_chain", st.session_state.query_sql,
                        ChatDataSQLAskCallBackHandler)
    with tab_self_query:
        # icon must be a valid emoji; the garbled 'π‘' made st.info raise.
        st.info("You can retrieve papers with button `Query` or ask questions based on retrieved papers with button `Ask`.", icon='💡')
        st.text_input("Ask a question:", key='query_self')
        cols = st.columns([1, 1, 1, 4])
        cols[0].button("Query", key='search_self')
        cols[1].button("Ask", key='ask_self')
        cols[2].button("Back", key='back_self', on_click=back_to_main)
        if st.session_state.search_self:
            _render_search(sel, "retriever", st.session_state.query_self,
                           ChatDataSelfSearchCallBackHandler)
        if st.session_state.ask_self:
            _render_ask(sel, "chain", st.session_state.query_self,
                        ChatDataSelfAskCallBackHandler)


# Page routing: logged-in users get the chat page; visitors who pressed
# "Query / Ask" get the retriever tour.
if login():
    if "user_name" in st.session_state:
        _render_chat_page()
    elif st.session_state.get("jump_query_ask"):
        _render_query_ask_page()