import json
import time
import pandas as pd
from os import environ
import datetime
import streamlit as st
from langchain.schema import Document

from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
    ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
    ChatDataSQLAskCallBackHandler

from langchain.schema import BaseMessage, HumanMessage, AIMessage, FunctionMessage, SystemMessage
from auth0_component import login_button


from helper import build_tools, build_agents, build_all, sel_map, display

environ['OPENAI_API_BASE'] = st.secrets['OPENAI_API_BASE']

st.set_page_config(page_title="ChatData", page_icon="https://myscale.com/favicon.ico")
st.header("ChatData")


# Build the retrievers, chains and tools once per session.
if 'sel_map_obj' not in st.session_state:
    st.session_state["sel_map_obj"] = build_all()
    st.session_state["tools"] = build_tools()
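
# Sketch (hypothetical helper, not used by the app): build an expensive
# resource at most once per session. The guard key must be a key that is
# actually written inside the guarded block; otherwise the build runs again
# on every Streamlit rerun. Any dict-like object can be passed as `state`,
# e.g. _init_once(st.session_state, "sel_map_obj", build_all).
def _init_once(state, key, factory):
    """Return state[key], calling factory() to create it only if missing."""
    if key not in state:
        state[key] = factory()
    return state[key]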

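def on_chat_submit():
    # Run the agent selected in the sidebar (knowledge base x retriever type)
    # on the submitted message; the agent's memory keeps the exchange.
    ret = st.session_state.agents[st.session_state.sel][st.session_state.ret_type]({"input": st.session_state.chat_input})
    print(ret)

def clear_history():
    # Clear the stored conversation of the currently selected agent.
    st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.clear()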
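
# Sketch (hypothetical helper, not used by the app): assuming build_agents()
# returns a nested mapping (knowledge base -> retriever type -> agent), both
# on_chat_submit and clear_history index it with the two sidebar values.
# Wrapping the lookup gives a clearer error when a combination is missing.
def _pick_agent(agents, kb, ret_type):
    """Return agents[kb][ret_type], raising a descriptive KeyError if absent."""
    try:
        return agents[kb][ret_type]
    except KeyError as exc:
        raise KeyError(f"no agent for knowledge base {kb!r} "
                       f"and retriever type {ret_type!r}") from exc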

AUTH0_CLIENT_ID = st.secrets['AUTH0_CLIENT_ID']
AUTH0_DOMAIN = st.secrets['AUTH0_DOMAIN']

def login():
    if "user_name" in st.session_state or ("jump_query_ask" in st.session_state and st.session_state.jump_query_ask):
        return True
    st.subheader("🤗 Welcome to [MyScale](https://myscale.com)'s [ChatData](https://github.com/myscale/ChatData)! 🤗 ")
    st.write("You can now chat with ArXiv and Wikipedia! You can also try building your own RAG system on these knowledge bases via [our public read-only credentials!](https://github.com/myscale/ChatData#data-schema) 🌟\n")
    st.write("Built purely with Streamlit 👑, LangChain 🦜🔗 and love for AI!")
    st.write("Follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!")
    st.warning("To use chat, please jump to [https://myscale-chatdata.hf.space](https://myscale-chatdata.hf.space)")
    st.info("We use [Auth0](https://auth0.com) as our identity provider. "
            "We will **NOT** collect any of your conversations in any form for any purpose.")
    st.divider()
    col1, col2 = st.columns(2, gap='large')
    with col1.container():
        st.write("Try out MyScale's Self-query and Vector SQL retrievers!")
        st.write("In this demo, you will be able to see how these retrievers "
                 "**digest** -> **translate** -> **retrieve** -> **answer** your question!")
        st.write("It is a step-by-step tour of the RAG pipeline.")
        st.session_state["jump_query_ask"] = st.button("Query / Ask")
    with col2.container():
        st.write("Now with the power of LangChain's Conversational Agents, we are able to build "
                 "a conversational chatbot with RAG! The agent will decide when and what to retrieve "
                 "based on your question!")
        st.write("Conversation history management and all retrievers are provided within a single MyScale instance!")
        st.write("Log in to Chat with RAG!")
        login_button(AUTH0_CLIENT_ID, AUTH0_DOMAIN, "auth0")
    # login_button stores the Auth0 user info under the "auth0" key.
    if st.session_state.auth0 is not None:
        st.session_state.user_info = dict(st.session_state.auth0)
        if 'email' in st.session_state.user_info:
            email = st.session_state.user_info["email"]
        else:
            email = f"{st.session_state.user_info['nickname']}@{st.session_state.user_info['sub']}"
        st.session_state["user_name"] = email
        del st.session_state.auth0
        st.experimental_rerun()
    if st.session_state.jump_query_ask:
        st.experimental_rerun()
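
# Sketch (hypothetical helper, not used by the app): the user name passed to
# build_agents() is the Auth0 email when present, otherwise
# "<nickname>@<sub>", mirroring the branch in login() above.
def _derive_user_name(user_info):
    """Return a user identifier from an Auth0 user-info dict."""
    if "email" in user_info:
        return user_info["email"]
    return f"{user_info['nickname']}@{user_info['sub']}"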
        
def back_to_main():
    if "user_info" in st.session_state:
        del st.session_state.user_info
    if "user_name" in st.session_state:
        del st.session_state.user_name
    if "jump_query_ask" in st.session_state:
        del st.session_state.jump_query_ask
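
# Sketch (hypothetical helper, not used by the app): back_to_main() deletes
# each key only if it exists. dict.pop with a default does the same in one
# call per key and never raises. This is shown on a plain dict; it should
# also work on st.session_state if that object supports MutableMapping's
# pop (an assumption, not verified here).
def _clear_keys(state, keys):
    """Remove every key in `keys` from `state`, ignoring missing ones."""
    for key in keys:
        state.pop(key, None)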

if login():
    if "user_name" in st.session_state:
        st.session_state["agents"] = build_agents(st.session_state.user_name)
        with st.sidebar:
            st.radio("Retriever Type", ["Self-querying retriever", "Vector SQL"], key="ret_type")
            st.selectbox("Knowledge Base", ["ArXiv Papers", "Wikipedia", "ArXiv + Wikipedia"], key="sel")
            st.button("Clear Chat History", on_click=clear_history)
            st.button("Logout", on_click=back_to_main)
        for msg in st.session_state.agents[st.session_state.sel][st.session_state.ret_type].memory.chat_memory.messages:
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            if isinstance(msg, FunctionMessage):
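                with st.chat_message("Knowledge Base", avatar="📖"):
                    print(type(msg.content))
                    st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
                    st.write("Retrieved from knowledge base:")
                    try:
                        # msg.content is expected to be a Python repr of the retrieved
                        # records (possibly `Document` objects, hence that import). eval
                        # executes arbitrary code, so this assumes the stored content is trusted.
                        st.dataframe(pd.DataFrame.from_records(map(dict, eval(msg.content))))
                    except Exception:
                        # Fall back to showing the raw content if it cannot be parsed.
                        st.write(msg.content)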
            else:
                if len(msg.content) > 0:
                    with st.chat_message(speaker):
                        print(type(msg), msg.dict())
                        st.write(f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*")
                        st.write(f"{msg.content}")
        st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
    elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
        
        sel = st.selectbox('Choose the knowledge base you want to ask with:',
                           options=['ArXiv Papers', 'Wikipedia'])
        sel_map[sel]['hint']()
        tab_sql, tab_self_query = st.tabs(['Vector SQL', 'Self-Query Retrievers'])
        with tab_sql:
            sel_map[sel]['hint_sql']()
            st.text_input("Ask a question:", key='query_sql')
            cols = st.columns([1, 1, 1, 4])
            cols[0].button("Query", key='search_sql')
            cols[1].button("Ask", key='ask_sql')
            cols[2].button("Back", key='back_sql', on_click=back_to_main)
            plc_hldr = st.empty()
            if st.session_state.search_sql:
                plc_hldr = st.empty()
                print(st.session_state.query_sql)
                with plc_hldr.expander('Query Log', expanded=True):
                    callback = ChatDataSQLSearchCallBackHandler()
                    try:
                        docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
                            st.session_state.query_sql, callbacks=[callback])
                        callback.progress_bar.progress(value=1.0, text="Done!")
                        docs = pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in docs])
                        display(docs)
                    except Exception as e:
                        st.write('Oops 😵 Something bad happened...')
                        raise e

            if st.session_state.ask_sql:
                plc_hldr = st.empty()
                print(st.session_state.query_sql)
                with plc_hldr.expander('Chat Log', expanded=True):
                    callback = ChatDataSQLAskCallBackHandler()
                    try:
                        ret = st.session_state.sel_map_obj[sel]["sql_chain"](
                            st.session_state.query_sql, callbacks=[callback])
                        callback.progress_bar.progress(value=1.0, text="Done!")
                        st.markdown(
                            f"### Answer from LLM\n{ret['answer']}\n### References")
                        docs = ret['sources']
                        docs = pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in docs])
                        display(
                            docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
                    except Exception as e:
                        st.write('Oops 😵 Something bad happened...')
                        raise e


        with tab_self_query:
            st.info("You can retrieve papers with the `Query` button, or ask questions based on the retrieved papers with the `Ask` button.", icon='💡')
            st.dataframe(st.session_state.sel_map_obj[sel]["metadata_columns"])
            st.text_input("Ask a question:", key='query_self')
            cols = st.columns([1, 1, 1, 4])
            cols[0].button("Query", key='search_self')
            cols[1].button("Ask", key='ask_self')
            cols[2].button("Back", key='back_self', on_click=back_to_main)
            plc_hldr = st.empty()
            if st.session_state.search_self:
                plc_hldr = st.empty()
                print(st.session_state.query_self)
                with plc_hldr.expander('Query Log', expanded=True):
                    callback = ChatDataSelfSearchCallBackHandler()
                    try:
                        docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
                            st.session_state.query_self, callbacks=[callback])
                        print(docs)
                        callback.progress_bar.progress(value=1.0, text="Done!")
                        docs = pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in docs])
                        display(docs, sel_map[sel]["must_have_cols"])
                    except Exception as e:
                        st.write('Oops 😵 Something bad happened...')
                        raise e

            if st.session_state.ask_self:
                plc_hldr = st.empty()
                print(st.session_state.query_self)
                with plc_hldr.expander('Chat Log', expanded=True):
                    callback = ChatDataSelfAskCallBackHandler()
                    try:
                        ret = st.session_state.sel_map_obj[sel]["chain"](
                            st.session_state.query_self, callbacks=[callback])
                        callback.progress_bar.progress(value=1.0, text="Done!")
                        st.markdown(
                            f"### Answer from LLM\n{ret['answer']}\n### References")
                        docs = ret['sources']
                        docs = pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in docs])
                        display(
                            docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
                    except Exception as e:
                        st.write('Oops 😵 Something bad happened...')
                        raise e