Synced repo using 'sync_with_huggingface' Github Action

Changed files:
- .streamlit/config.toml +1 -5
- .streamlit/secrets.example.toml +2 -1
- app.py +74 -121
- backend/__init__.py +0 -0
- backend/callbacks/__init__.py +0 -0
- backend/callbacks/arxiv_callbacks.py +46 -0
- backend/callbacks/llm_thought_with_table.py +36 -0
- backend/callbacks/self_query_callbacks.py +57 -0
- backend/callbacks/vector_sql_callbacks.py +53 -0
- backend/chains/__init__.py +0 -0
- backend/chains/retrieval_qa_with_sources.py +70 -0
- backend/chains/stuff_documents.py +65 -0
- backend/chat_bot/__init__.py +0 -0
- backend/chat_bot/chat.py +225 -0
- backend/chat_bot/json_decoder.py +24 -0
- backend/chat_bot/message_converter.py +67 -0
- backend/chat_bot/private_knowledge_base.py +167 -0
- backend/chat_bot/session_manager.py +96 -0
- backend/chat_bot/tools.py +100 -0
- backend/constants/__init__.py +0 -0
- backend/constants/myscale_tables.py +128 -0
- backend/constants/prompts.py +128 -0
- backend/constants/streamlit_keys.py +35 -0
- backend/constants/variables.py +58 -0
- backend/construct/__init__.py +0 -0
- backend/construct/build_agents.py +82 -0
- backend/construct/build_all.py +95 -0
- backend/construct/build_chains.py +39 -0
- backend/construct/build_chat_bot.py +36 -0
- backend/construct/build_retriever_tool.py +45 -0
- backend/construct/build_retrievers.py +120 -0
- backend/retrievers/__init__.py +0 -0
- backend/retrievers/self_query.py +89 -0
- backend/retrievers/vector_sql_output_parser.py +23 -0
- backend/retrievers/vector_sql_query.py +95 -0
- backend/types/__init__.py +0 -0
- backend/types/chains_and_retrievers.py +34 -0
- backend/types/global_config.py +22 -0
- backend/types/table_config.py +25 -0
- backend/vector_store/__init__.py +0 -0
- backend/vector_store/myscale_without_metadata.py +52 -0
- logger.py +18 -0
- requirements.txt +9 -7
- ui/__init__.py +0 -0
- ui/chat_page.py +196 -0
- ui/home.py +156 -0
- ui/retrievers.py +97 -0
- ui/utils.py +18 -0
.streamlit/config.toml
CHANGED
@@ -1,6 +1,2 @@
 [theme]
-
-backgroundColor="#FFFFFF"
-secondaryBackgroundColor="#D4CEFF"
-textColor="#262730"
-font="sans serif"
+base="dark"
.streamlit/secrets.example.toml
CHANGED
@@ -1,7 +1,8 @@
-MYSCALE_HOST = "msc-
+MYSCALE_HOST = "msc-950b9f1f.us-east-1.aws.myscale.com" # read-only database provided by MyScale
 MYSCALE_PORT = 443
 MYSCALE_USER = "chatdata"
 MYSCALE_PASSWORD = "myscale_rocks"
+MYSCALE_ENABLE_HTTPS = true
 OPENAI_API_BASE = "https://api.openai.com/v1"
 OPENAI_API_KEY = "<your-openai-key>"
 UNSTRUCTURED_API = "<your-unstructured-io-api>" # optional if you don't upload documents
app.py
CHANGED
@@ -1,133 +1,86 @@
-import …
-
+import os
+import time
+
 import streamlit as st
 
-from …
-…
-from chat import chat_page
-from login import login, back_to_main
-from lib.helper import build_tools, build_all, sel_map, display
+from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \
+    DATA_INITIALIZE_STARTED
+from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \
+    TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config
+from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools
+from backend.types.global_config import GlobalConfig
+from logger import logger
+from ui.chat_page import chat_page
+from ui.home import render_home
+from ui.retrievers import render_retrievers
 
-st.set_page_config(page_title="ChatData",
-                   page_icon="https://myscale.com/favicon.ico")
-st.markdown(
-    f"""
-    <style>
-    .st-e4 {{
-        max-width: 500px
-    }}
-    </style>""",
-    unsafe_allow_html=True,
-)
-st.header("ChatData")
-
-st.session_state…
-
-if login():
-    if "user_name" in st.session_state:
-        chat_page()
-    elif "jump_query_ask" in st.session_state and st.session_state.jump_query_ask:
-        …
-        plc_hldr = st.empty()
-        if st.session_state.search_sql:
-            plc_hldr = st.empty()
-            print(st.session_state.query_sql)
-            with plc_hldr.expander('Query Log', expanded=True):
-                callback = ChatDataSQLSearchCallBackHandler()
-                try:
-                    docs = st.session_state.sel_map_obj[sel]["sql_retriever"].get_relevant_documents(
-                        st.session_state.query_sql, callbacks=[callback])
-                    callback.progress_bar.progress(value=1.0, text="Done!")
-                    docs = pd.DataFrame(
-                        [{**d.metadata, 'abstract': d.page_content} for d in docs])
-                    display(docs)
-                except Exception as e:
-                    st.write('Oops 😵 Something bad happened...')
-                    raise e
-
-        if st.session_state.ask_sql:
-            plc_hldr = st.empty()
-            print(st.session_state.query_sql)
-            with plc_hldr.expander('Chat Log', expanded=True):
-                callback = ChatDataSQLAskCallBackHandler()
-                try:
-                    ret = st.session_state.sel_map_obj[sel]["sql_chain"](
-                        st.session_state.query_sql, callbacks=[callback])
-                    callback.progress_bar.progress(value=1.0, text="Done!")
-                    st.markdown(
-                        f"### Answer from LLM\n{ret['answer']}\n### References")
-                    docs = ret['sources']
-                    docs = pd.DataFrame(
-                        [{**d.metadata, 'abstract': d.page_content} for d in docs])
-                    display(
-                        docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
-                except Exception as e:
-                    st.write('Oops 😵 Something bad happened...')
-                    raise e
-        …
-        plc_hldr = st.empty()
-        print(st.session_state.query_self)
-        with plc_hldr.expander('Query Log', expanded=True):
-            call_back = None
-            callback = ChatDataSelfSearchCallBackHandler()
-            try:
-                docs = st.session_state.sel_map_obj[sel]["retriever"].get_relevant_documents(
-                    st.session_state.query_self, callbacks=[callback])
-                print(docs)
-                callback.progress_bar.progress(value=1.0, text="Done!")
-                docs = pd.DataFrame(
-                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
-                display(docs, sel_map[sel]["must_have_cols"])
-            except Exception as e:
-                st.write('Oops 😵 Something bad happened...')
-                raise e
-        …
-                ret = st.session_state.sel_map_obj[sel]["chain"](
-                    st.session_state.query_self, callbacks=[callback])
-                callback.progress_bar.progress(value=1.0, text="Done!")
-                st.markdown(
-                    f"### Answer from LLM\n{ret['answer']}\n### References")
-                docs = ret['sources']
-                docs = pd.DataFrame(
-                    [{**d.metadata, 'abstract': d.page_content} for d in docs])
-                display(
-                    docs, ['ref_id'] + sel_map[sel]["must_have_cols"], index='ref_id')
-            except Exception as e:
-                st.write('Oops 😵 Something bad happened...')
-                raise e
+
+# warnings.filterwarnings("ignore", category=UserWarning)
+
+def prepare_environment():
+    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+    os.environ["LANGCHAIN_TRACING_V2"] = "false"
+    # os.environ["LANGCHAIN_API_KEY"] = ""
+    os.environ["OPENAI_API_BASE"] = st.secrets['OPENAI_API_BASE']
+    os.environ["OPENAI_API_KEY"] = st.secrets['OPENAI_API_KEY']
+    os.environ["AUTH0_CLIENT_ID"] = st.secrets['AUTH0_CLIENT_ID']
+    os.environ["AUTH0_DOMAIN"] = st.secrets['AUTH0_DOMAIN']
+
+    update_global_config(GlobalConfig(
+        openai_api_base=st.secrets['OPENAI_API_BASE'],
+        openai_api_key=st.secrets['OPENAI_API_KEY'],
+        auth0_client_id=st.secrets['AUTH0_CLIENT_ID'],
+        auth0_domain=st.secrets['AUTH0_DOMAIN'],
+        myscale_user=st.secrets['MYSCALE_USER'],
+        myscale_password=st.secrets['MYSCALE_PASSWORD'],
+        myscale_host=st.secrets['MYSCALE_HOST'],
+        myscale_port=st.secrets['MYSCALE_PORT'],
+        query_model="gpt-3.5-turbo-0125",
+        chat_model="gpt-3.5-turbo-0125",
+        untrusted_api=st.secrets['UNSTRUCTURED_API'],
+        myscale_enable_https=st.secrets.get('MYSCALE_ENABLE_HTTPS', True),
+    ))
+
+
+# when refresh browser, all session keys will be cleaned.
+def initialize_session_state():
+    if DATA_INITIALIZE_STATUS not in st.session_state:
+        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_NOT_STATED
+        logger.info(f"Initialize session state key: {DATA_INITIALIZE_STATUS}")
+    if JUMP_QUERY_ASK not in st.session_state:
+        st.session_state[JUMP_QUERY_ASK] = False
+        logger.info(f"Initialize session state key: {JUMP_QUERY_ASK}")
+
+
+def initialize_chat_data():
+    if st.session_state[DATA_INITIALIZE_STATUS] != DATA_INITIALIZE_COMPLETED:
+        start_time = time.time()
+        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED
+        st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
+        st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
+        st.session_state[RETRIEVER_TOOLS] = update_retriever_tools()
+        # mark data initialization finished.
+        st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED
+        end_time = time.time()
+        logger.info(f"ChatData initialized finished in {round(end_time - start_time, 3)} seconds, "
+                    f"session state keys: {list(st.session_state.keys())}")
+
+
+st.set_page_config(
+    page_title="ChatData",
+    page_icon="https://myscale.com/favicon.ico",
+    initial_sidebar_state="expanded",
+    layout="wide",
+)
+
+prepare_environment()
+initialize_session_state()
+initialize_chat_data()
+
+if USER_NAME in st.session_state:
+    chat_page()
+else:
+    if st.session_state[JUMP_QUERY_ASK]:
+        render_retrievers()
+    else:
+        render_home()
backend/__init__.py
ADDED
File without changes

backend/callbacks/__init__.py
ADDED
File without changes
backend/callbacks/arxiv_callbacks.py
ADDED
@@ -0,0 +1,46 @@

import json
import textwrap
from typing import Dict, Any, List

from langchain.callbacks.streamlit.streamlit_callback_handler import (
    LLMThought,
    StreamlitCallbackHandler,
)


class LLMThoughtWithKnowledgeBase(LLMThought):
    def on_tool_end(
        self,
        output: str,
        color=None,
        observation_prefix=None,
        llm_prefix=None,
        **kwargs: Any,
    ) -> None:
        try:
            self._container.markdown(
                "\n\n".join(
                    ["### Retrieved Documents:"]
                    + [
                        f"**{i+1}**: {textwrap.shorten(r['page_content'], width=80)}"
                        for i, r in enumerate(json.loads(output))
                    ]
                )
            )
        except Exception as e:
            super().on_tool_end(output, color, observation_prefix, llm_prefix, **kwargs)


class ChatDataAgentCallBackHandler(StreamlitCallbackHandler):
    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        if self._current_thought is None:
            self._current_thought = LLMThoughtWithKnowledgeBase(
                parent_container=self._parent_container,
                expanded=self._expand_new_thoughts,
                collapse_on_complete=self._collapse_completed_thoughts,
                labeler=self._thought_labeler,
            )

        self._current_thought.on_llm_start(serialized, prompts)
backend/callbacks/llm_thought_with_table.py
ADDED
@@ -0,0 +1,36 @@

from typing import Any, Dict, List

import streamlit as st
from langchain_core.outputs import LLMResult
from streamlit.external.langchain import StreamlitCallbackHandler


class ChatDataSelfQueryCallBack(StreamlitCallbackHandler):
    def __init__(self):
        super().__init__(st.container())
        self._current_thought = None
        self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery CallBack...")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.progress_bar.progress(value=0.35, text="Communicate with LLM...")
        pass

    def on_chain_end(self, outputs, **kwargs) -> None:
        if len(kwargs['tags']) == 0:
            self.progress_bar.progress(value=0.75, text="Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:

        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        st.markdown("### Generate filter by LLM \n"
                    "> Here we get `query_constructor` results \n\n")

        self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
        for item in response.generations:
            st.markdown(f"{item[0].text}")

        pass
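Note: a minimal usage sketch (not part of this commit) of how a handler like ChatDataSelfQueryCallBack is attached to a LangChain retriever call, mirroring the pattern the removed app.py used; `my_retriever` is a hypothetical placeholder for a retriever built elsewhere in the app.

# Hypothetical sketch, assuming `my_retriever` is a LangChain retriever built elsewhere.
from backend.callbacks.llm_thought_with_table import ChatDataSelfQueryCallBack

callback = ChatDataSelfQueryCallBack()
docs = my_retriever.get_relevant_documents(
    "What is a Bayesian network?",  # sample query, borrowed from the app's hints
    callbacks=[callback],           # LangChain forwards chain/LLM events to the handler
)
callback.progress_bar.progress(value=1.0, text="Done!")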
backend/callbacks/self_query_callbacks.py
ADDED
@@ -0,0 +1,57 @@

from typing import Dict, Any, List

import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult


class CustomSelfQueryRetrieverCallBackHandler(StreamlitCallbackHandler):
    def __init__(self):
        super().__init__(st.container())
        self._current_thought = None
        self.progress_bar = st.progress(value=0.0, text="Executing ChatData SelfQuery...")

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        self.progress_bar.progress(value=0.35, text="Communicate with LLM...")
        pass

    def on_chain_end(self, outputs, **kwargs) -> None:
        if len(kwargs['tags']) == 0:
            self.progress_bar.progress(value=0.75, text="Searching in DB...")
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        st.markdown("### Generate filter by LLM \n"
                    "> Here we get `query_constructor` results \n\n")
        self.progress_bar.progress(value=0.5, text="Generate filter by LLM...")
        for item in response.generations:
            st.markdown(f"{item[0].text}")
        pass


class ChatDataSelfAskCallBackHandler(StreamlitCallbackHandler):
    def __init__(self) -> None:
        super().__init__(st.container())
        self.progress_bar = st.progress(value=0.2, text="Executing ChatData SelfQuery Chain...")

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:

        if len(kwargs['tags']) != 0:
            self.progress_bar.progress(value=0.5, text="We got filter info from LLM...")
            st.markdown("### Generate filter by LLM \n"
                        "> Here we get `query_constructor` results \n\n")
            for item in response.generations:
                st.markdown(f"{item[0].text}")
        pass

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        if cid.endswith(".CustomStuffDocumentChain"):
            self.progress_bar.progress(value=0.7, text="Asking LLM with related documents...")
backend/callbacks/vector_sql_callbacks.py
ADDED
@@ -0,0 +1,53 @@

import streamlit as st
from langchain.callbacks.streamlit.streamlit_callback_handler import (
    StreamlitCallbackHandler,
)
from langchain.schema.output import LLMResult
from sql_formatter.core import format_sql


class VectorSQLSearchDBCallBackHandler(StreamlitCallbackHandler):
    def __init__(self) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.2

    def on_llm_start(self, serialized, prompts, **kwargs) -> None:
        pass

    def on_llm_end(
        self,
        response: LLMResult,
        *args,
        **kwargs,
    ):
        text = response.generations[0][0].text
        if text.replace(" ", "").upper().startswith("SELECT"):
            st.markdown("### Generated Vector Search SQL Statement \n"
                        "> This sql statement is generated by LLM \n\n")
            st.markdown(f"""```sql\n{format_sql(text, max_len=80)}\n```""")
        self.prog_value += self.prog_interval
        self.progress_bar.progress(
            value=self.prog_value, text="Searching in DB...")

    def on_chain_start(self, serialized, inputs, **kwargs) -> None:
        cid = ".".join(serialized["id"])
        self.prog_value += self.prog_interval
        self.progress_bar.progress(
            value=self.prog_value, text=f"Running Chain `{cid}`..."
        )

    def on_chain_end(self, outputs, **kwargs) -> None:
        pass


class VectorSQLSearchLLMCallBackHandler(VectorSQLSearchDBCallBackHandler):
    def __init__(self, table: str) -> None:
        self.progress_bar = st.progress(value=0.0, text="Writing SQL...")
        self.status_bar = st.empty()
        self.prog_value = 0
        self.prog_interval = 0.1
        self.table = table
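For reference, a small sketch of the pretty-printing step on_llm_end performs before rendering the generated SQL, assuming the sql_formatter package is available; the query string below is made up.

# Sketch of the formatting call used above; the statement is a made-up example.
from sql_formatter.core import format_sql

raw = "SELECT id, title FROM default.ChatArXiv WHERE pubdate > '2018-02-01' LIMIT 5"
print(format_sql(raw, max_len=80))  # reflows the statement onto keyword-aligned lines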
backend/chains/__init__.py
ADDED
File without changes
backend/chains/retrieval_qa_with_sources.py
ADDED
@@ -0,0 +1,70 @@

import inspect
from typing import Dict, Any, Optional, List

from langchain.callbacks.manager import (
    AsyncCallbackManagerForChainRun,
    CallbackManagerForChainRun,
)
from langchain.chains.qa_with_sources.retrieval import RetrievalQAWithSourcesChain
from langchain.docstore.document import Document

from logger import logger


class CustomRetrievalQAWithSourcesChain(RetrievalQAWithSourcesChain):
    """QA with source chain for Chat ArXiv app with references

    This chain will automatically assign reference number to the article,
    Then parse it back to titles or anything else.
    """

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        logger.info(f"\033[91m\033[1m{self._chain_type}\033[0m")
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        accepts_run_manager = (
            "run_manager" in inspect.signature(self._get_docs).parameters
        )
        if accepts_run_manager:
            docs: List[Document] = self._get_docs(inputs, run_manager=_run_manager)
        else:
            docs: List[Document] = self._get_docs(inputs)  # type: ignore[call-arg]

        answer = self.combine_documents_chain.run(
            input_documents=docs, callbacks=_run_manager.get_child(), **inputs
        )
        # parse source with ref_id
        sources = []
        ref_cnt = 1
        for d in docs:
            ref_id = d.metadata['ref_id']
            if f"Doc #{ref_id}" in answer:
                answer = answer.replace(f"Doc #{ref_id}", f"#{ref_id}")
            if f"#{ref_id}" in answer:
                title = d.metadata['title'].replace('\n', '')
                d.metadata['ref_id'] = ref_cnt
                answer = answer.replace(f"#{ref_id}", f"{title} [{ref_cnt}]")
                sources.append(d)
                ref_cnt += 1

        result: Dict[str, Any] = {
            self.answer_key: answer,
            self.sources_answer_key: sources,
        }
        if self.return_source_documents:
            result["source_documents"] = docs
        return result

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        raise NotImplementedError

    @property
    def _chain_type(self) -> str:
        return "custom_retrieval_qa_with_sources_chain"
backend/chains/stuff_documents.py
ADDED
@@ -0,0 +1,65 @@

from typing import Any, List, Tuple

from langchain.callbacks.manager import Callbacks
from langchain.chains.combine_documents.stuff import StuffDocumentsChain
from langchain.docstore.document import Document
from langchain.schema.prompt_template import format_document


class CustomStuffDocumentChain(StuffDocumentsChain):
    """Combine arxiv documents with PDF reference number"""

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Construct inputs from kwargs and docs.

        Format and the join all the documents together into one input with name
        `self.document_variable_name`. The pluck any additional variables
        from **kwargs.

        Args:
            docs: List of documents to format and then join into single input
            **kwargs: additional inputs to chain, will pluck any other required
                arguments from here.

        Returns:
            dictionary of inputs to LLMChain
        """
        # Format each document according to the prompt
        doc_strings = []
        for doc_id, doc in enumerate(docs):
            # add temp reference number in metadata
            doc.metadata.update({'ref_id': doc_id})
            doc.page_content = doc.page_content.replace('\n', ' ')
            doc_strings.append(format_document(doc, self.document_prompt))
        # Join the documents together to put them in the prompt.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        inputs[self.document_variable_name] = self.document_separator.join(
            doc_strings)
        return inputs

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM.

        Args:
            docs: List of documents to join together into one variable
            callbacks: Optional callbacks to pass along
            **kwargs: additional parameters to use to get inputs to LLMChain.

        Returns:
            The first element returned is the single string output. The second
            element returned is a dictionary of other keys to return.
        """
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        output = self.llm_chain.predict(callbacks=callbacks, **inputs)
        return output, {}

    @property
    def _chain_type(self) -> str:
        return "custom_stuff_document_chain"
backend/chat_bot/__init__.py
ADDED
File without changes
backend/chat_bot/chat.py
ADDED
@@ -0,0 +1,225 @@

import time

from os import environ
from time import sleep
import streamlit as st

from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION_MANAGER, \
    CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, USER_PRIVATE_FILES, \
    EL_BUILD_KB_WITH_FILES, \
    EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
    USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
    EL_UPLOAD_FILES_STATUS, EL_SELECTED_KBS, EL_UPLOAD_FILES
from backend.constants.variables import USER_INFO, USER_NAME, JUMP_QUERY_ASK, RETRIEVER_TOOLS
from backend.construct.build_agents import build_agents
from backend.chat_bot.session_manager import SessionManager
from backend.callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler

from logger import logger

environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]

TOOL_NAMES = {
    "langchain_retriever_tool": "Self-querying retriever",
    "vecsql_retriever_tool": "Vector SQL",
}


def on_chat_submit():
    with st.session_state.next_round.container():
        with st.chat_message("user"):
            st.write(st.session_state.chat_input)
        with st.chat_message("assistant"):
            container = st.container()
            st_callback = ChatDataAgentCallBackHandler(
                container, collapse_completed_thoughts=False
            )
            ret = st.session_state.agent(
                {"input": st.session_state.chat_input}, callbacks=[st_callback]
            )
            logger.info(f"ret:{ret}")


def clear_history():
    if "agent" in st.session_state:
        st.session_state.agent.memory.clear()


def back_to_main():
    if USER_INFO in st.session_state:
        del st.session_state[USER_INFO]
    if USER_NAME in st.session_state:
        del st.session_state[USER_NAME]
    if JUMP_QUERY_ASK in st.session_state:
        del st.session_state[JUMP_QUERY_ASK]
    if EL_SESSION_SELECTOR in st.session_state:
        del st.session_state[EL_SESSION_SELECTOR]
    if CHAT_CURRENT_USER_SESSIONS in st.session_state:
        del st.session_state[CHAT_CURRENT_USER_SESSIONS]


def refresh_sessions():
    chat_session_manager: SessionManager = st.session_state[CHAT_SESSION_MANAGER]
    current_user_name = st.session_state[USER_NAME]
    current_user_sessions = chat_session_manager.list_sessions(current_user_name)

    if not isinstance(current_user_sessions, dict) or not current_user_sessions:
        # generate a default session for current user.
        chat_session_manager.add_session(
            user_id=current_user_name,
            session_id=f"{current_user_name}?default",
            system_prompt=DEFAULT_SYSTEM_PROMPT,
        )
        st.session_state[CHAT_CURRENT_USER_SESSIONS] = chat_session_manager.list_sessions(current_user_name)
        current_user_sessions = st.session_state[CHAT_CURRENT_USER_SESSIONS]
    else:
        st.session_state[CHAT_CURRENT_USER_SESSIONS] = current_user_sessions

    # load current user files.
    st.session_state[USER_PRIVATE_FILES] = st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(
        current_user_name
    )
    # load current user private knowledge bases.
    st.session_state[USER_PERSONAL_KNOWLEDGE_BASES] = \
        st.session_state[CHAT_KNOWLEDGE_TABLE].list_private_knowledge_bases(current_user_name)
    logger.info(f"current user name: {current_user_name}, "
                f"user private knowledge bases: {st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]}, "
                f"user private files: {st.session_state[USER_PRIVATE_FILES]}")
    st.session_state[AVAILABLE_RETRIEVAL_TOOLS] = {
        # public retrieval tools
        **st.session_state[RETRIEVER_TOOLS],
        # private retrieval tools
        **st.session_state[CHAT_KNOWLEDGE_TABLE].as_retrieval_tools(current_user_name),
    }
    # print(f"sel_session is {st.session_state.sel_session}, current_user_sessions is {current_user_sessions}")
    print(f"current_user_sessions is {current_user_sessions}")
    st.session_state[EL_SESSION_SELECTOR] = current_user_sessions[0]


# process for session add and delete.
def on_session_change_submit():
    if "session_manager" in st.session_state and "session_editor" in st.session_state:
        try:
            for elem in st.session_state.session_editor["added_rows"]:
                if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
                    if elem["session_id"] != "" and "?" not in elem["session_id"]:
                        st.session_state.session_manager.add_session(
                            user_id=st.session_state.user_name,
                            session_id=f"{st.session_state.user_name}?{elem['session_id']}",
                            system_prompt=elem["system_prompt"],
                        )
                    else:
                        st.toast("`session_id` shouldn't be neither empty nor contain char `?`.", icon="❌")
                        raise KeyError(
                            "`session_id` shouldn't be neither empty nor contain char `?`."
                        )
                else:
                    st.toast("`You should fill both `session_id` and `system_prompt` to add a column!", icon="❌")
                    raise KeyError(
                        "You should fill both `session_id` and `system_prompt` to add a column!"
                    )
            for elem in st.session_state.session_editor["deleted_rows"]:
                user_name = st.session_state[USER_NAME]
                session_id = st.session_state[CHAT_CURRENT_USER_SESSIONS][elem]['session_id']
                user_with_session_id = f"{user_name}?{session_id}"
                st.session_state.session_manager.remove_session(session_id=user_with_session_id)
                st.toast(f"session `{user_with_session_id}` removed.", icon="✅")

            refresh_sessions()
        except Exception as e:
            sleep(2)
            st.error(f"{type(e)}: {str(e)}")
        finally:
            st.session_state.session_editor["added_rows"] = []
            st.session_state.session_editor["deleted_rows"] = []
            refresh_agent()


def create_private_knowledge_base_as_tool():
    current_user_name = st.session_state[USER_NAME]

    if (
            EL_PERSONAL_KB_NAME in st.session_state
            and EL_PERSONAL_KB_DESCRIPTION in st.session_state
            and EL_BUILD_KB_WITH_FILES in st.session_state
            and len(st.session_state[EL_PERSONAL_KB_NAME]) > 0
            and len(st.session_state[EL_PERSONAL_KB_DESCRIPTION]) > 0
            and len(st.session_state[EL_BUILD_KB_WITH_FILES]) > 0
    ):
        st.session_state[CHAT_KNOWLEDGE_TABLE].create_private_knowledge_base(
            user_id=current_user_name,
            tool_name=st.session_state[EL_PERSONAL_KB_NAME],
            tool_description=st.session_state[EL_PERSONAL_KB_DESCRIPTION],
            files=[f["file_name"] for f in st.session_state[EL_BUILD_KB_WITH_FILES]],
        )
        refresh_sessions()
    else:
        st.session_state[EL_UPLOAD_FILES_STATUS].error(
            "You should fill all fields to build up a tool!"
        )
        sleep(2)


def remove_private_knowledge_bases():
    if EL_PERSONAL_KB_NEEDS_REMOVE in st.session_state and st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]:
        private_knowledge_bases_needs_remove = st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]
        private_knowledge_base_names = [item["tool_name"] for item in private_knowledge_bases_needs_remove]
        # remove these private knowledge bases.
        st.session_state[CHAT_KNOWLEDGE_TABLE].remove_private_knowledge_bases(
            user_id=st.session_state[USER_NAME],
            private_knowledge_bases=private_knowledge_base_names
        )
        refresh_sessions()
    else:
        st.session_state[EL_UPLOAD_FILES_STATUS].error(
            "You should specify at least one private knowledge base to delete!"
        )
        time.sleep(2)


def refresh_agent():
    with st.spinner("Initializing session..."):
        user_name = st.session_state[USER_NAME]
        session_id = st.session_state[EL_SESSION_SELECTOR]['session_id']
        user_with_session_id = f"{user_name}?{session_id}"

        if EL_SELECTED_KBS in st.session_state:
            selected_knowledge_bases = st.session_state[EL_SELECTED_KBS]
        else:
            selected_knowledge_bases = ["Wikipedia + Vector SQL"]

        logger.info(f"selected_knowledge_bases: {selected_knowledge_bases}")
        if EL_SESSION_SELECTOR in st.session_state:
            system_prompt = st.session_state[EL_SESSION_SELECTOR]["system_prompt"]
        else:
            system_prompt = DEFAULT_SYSTEM_PROMPT

        st.session_state["agent"] = build_agents(
            session_id=user_with_session_id,
            tool_names=selected_knowledge_bases,
            system_prompt=system_prompt
        )


def add_file():
    user_name = st.session_state[USER_NAME]
    if EL_UPLOAD_FILES not in st.session_state or len(st.session_state[EL_UPLOAD_FILES]) == 0:
        st.session_state[EL_UPLOAD_FILES_STATUS].error("Please upload files!", icon="⚠️")
        sleep(2)
        return
    try:
        st.session_state[EL_UPLOAD_FILES_STATUS].info("Uploading...")
        st.session_state[CHAT_KNOWLEDGE_TABLE].add_by_file(
            user_id=user_name,
            files=st.session_state[EL_UPLOAD_FILES]
        )
        refresh_sessions()
    except ValueError as e:
        st.session_state[EL_UPLOAD_FILES_STATUS].error("Failed to upload! " + str(e))
        sleep(2)


def clear_files():
    st.session_state[CHAT_KNOWLEDGE_TABLE].clear(user_id=st.session_state[USER_NAME])
    refresh_sessions()
backend/chat_bot/json_decoder.py
ADDED
@@ -0,0 +1,24 @@

import json
import datetime


class CustomJSONEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, datetime.datetime):
            return datetime.datetime.isoformat(obj)
        return json.JSONEncoder.default(self, obj)


class CustomJSONDecoder(json.JSONDecoder):
    def __init__(self, *args, **kwargs):
        json.JSONDecoder.__init__(
            self, object_hook=self.object_hook, *args, **kwargs)

    def object_hook(self, source):
        for k, v in source.items():
            if isinstance(v, str):
                try:
                    source[k] = datetime.datetime.fromisoformat(str(v))
                except:
                    pass
        return source
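A quick round-trip sketch (not part of this commit) of the encoder/decoder pair above: datetimes go out as ISO-8601 strings and come back as datetime objects.

# Round-trip sketch for CustomJSONEncoder/CustomJSONDecoder.
import datetime
import json

from backend.chat_bot.json_decoder import CustomJSONEncoder, CustomJSONDecoder

record = {"msg": "hello", "created_by": datetime.datetime(2024, 3, 1, 12, 30)}
payload = json.dumps(record, cls=CustomJSONEncoder)   # datetime -> "2024-03-01T12:30:00"
revived = json.loads(payload, cls=CustomJSONDecoder)  # ISO strings -> datetime again
assert revived["created_by"] == record["created_by"]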
backend/chat_bot/message_converter.py
ADDED
@@ -0,0 +1,67 @@

import hashlib
import json
import time
from typing import Any

from langchain.memory.chat_message_histories.sql import DefaultMessageConverter
from langchain.schema import BaseMessage, HumanMessage, AIMessage, SystemMessage, ChatMessage, FunctionMessage
from langchain.schema.messages import ToolMessage
from sqlalchemy.orm import declarative_base

from backend.chat_bot.tools import create_message_history_table


def _message_from_dict(message: dict) -> BaseMessage:
    _type = message["type"]
    if _type == "human":
        return HumanMessage(**message["data"])
    elif _type == "ai":
        return AIMessage(**message["data"])
    elif _type == "system":
        return SystemMessage(**message["data"])
    elif _type == "chat":
        return ChatMessage(**message["data"])
    elif _type == "function":
        return FunctionMessage(**message["data"])
    elif _type == "tool":
        return ToolMessage(**message["data"])
    elif _type == "AIMessageChunk":
        message["data"]["type"] = "ai"
        return AIMessage(**message["data"])
    else:
        raise ValueError(f"Got unexpected message type: {_type}")


class DefaultClickhouseMessageConverter(DefaultMessageConverter):
    """The default message converter for SQLChatMessageHistory."""

    def __init__(self, table_name: str):
        super().__init__(table_name)
        self.model_class = create_message_history_table(table_name, declarative_base())

    def to_sql_model(self, message: BaseMessage, session_id: str) -> Any:
        time_stamp = time.time()
        msg_id = hashlib.sha256(
            f"{session_id}_{message}_{time_stamp}".encode('utf-8')).hexdigest()
        user_id, _ = session_id.split("?")
        return self.model_class(
            id=time_stamp,
            msg_id=msg_id,
            user_id=user_id,
            session_id=session_id,
            type=message.type,
            addtionals=json.dumps(message.additional_kwargs),
            message=json.dumps({
                "type": message.type,
                "additional_kwargs": {"timestamp": time_stamp},
                "data": message.dict()})
        )

    def from_sql_model(self, sql_message: Any) -> BaseMessage:
        msg_dump = json.loads(sql_message.message)
        msg = _message_from_dict(msg_dump)
        msg.additional_kwargs = msg_dump["additional_kwargs"]
        return msg

    def get_sql_model_class(self) -> Any:
        return self.model_class
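The converter relies on the composite session-id convention used throughout this commit (chat.py builds ids as "<user>?<session>"); a tiny sketch of the split that to_sql_model performs:

# The "?" separator is the reason on_session_change_submit() rejects ids containing "?".
composite_id = "alice?default"
user_id, session_id = composite_id.split("?")
assert (user_id, session_id) == ("alice", "default")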
backend/chat_bot/private_knowledge_base.py
ADDED
@@ -0,0 +1,167 @@

import hashlib
from datetime import datetime
from typing import List, Optional

import pandas as pd
from clickhouse_connect import get_client
from langchain.schema.embeddings import Embeddings
from langchain.vectorstores.myscale import MyScaleWithoutJSON, MyScaleSettings
from streamlit.runtime.uploaded_file_manager import UploadedFile

from backend.chat_bot.tools import parse_files, extract_embedding
from backend.construct.build_retriever_tool import create_retriever_tool
from logger import logger


class ChatBotKnowledgeTable:
    def __init__(self, host, port, username, password,
                 embedding: Embeddings, parser_api_key: str, db="chat",
                 kb_table="private_kb", tool_table="private_tool") -> None:
        super().__init__()
        personal_files_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{kb_table}(
                entity_id String,
                file_name String,
                text String,
                user_id String,
                created_by DateTime,
                vector Array(Float32),
                CONSTRAINT cons_vec_len CHECK length(vector) = 768,
                VECTOR INDEX vidx vector TYPE MSTG('metric_type=Cosine')
            ) ENGINE = ReplacingMergeTree ORDER BY entity_id
        """

        # `tool_name` represent private knowledge database name.
        private_knowledge_base_schema_ = f"""
            CREATE TABLE IF NOT EXISTS {db}.{tool_table}(
                tool_id String,
                tool_name String,
                file_names Array(String),
                user_id String,
                created_by DateTime,
                tool_description String
            ) ENGINE = ReplacingMergeTree ORDER BY tool_id
        """
        self.personal_files_table = kb_table
        self.private_knowledge_base_table = tool_table
        config = MyScaleSettings(
            host=host,
            port=port,
            username=username,
            password=password,
            database=db,
            table=kb_table,
        )
        self.client = get_client(
            host=config.host,
            port=config.port,
            username=config.username,
            password=config.password,
        )
        self.client.command("SET allow_experimental_object_type=1")
        self.client.command(personal_files_schema_)
        self.client.command(private_knowledge_base_schema_)
        self.parser_api_key = parser_api_key
        self.vector_store = MyScaleWithoutJSON(
            embedding=embedding,
            config=config,
            must_have_cols=["file_name", "text", "created_by"],
        )

    # List all files with given `user_id`
    def list_files(self, user_id: str):
        query = f"""
        SELECT DISTINCT file_name, COUNT(entity_id) AS num_paragraph,
            arrayMax(arrayMap(x->length(x), groupArray(text))) AS max_chars
        FROM {self.vector_store.config.database}.{self.personal_files_table}
        WHERE user_id = '{user_id}' GROUP BY file_name
        """
        return [r for r in self.vector_store.client.query(query).named_results()]

    # Parse and embedding files
    def add_by_file(self, user_id, files: List[UploadedFile]):
        data = parse_files(self.parser_api_key, user_id, files)
        data = extract_embedding(self.vector_store.embeddings, data)
        self.vector_store.client.insert_df(
            table=self.personal_files_table,
            df=pd.DataFrame(data),
            database=self.vector_store.config.database,
        )

    # Remove all files and private_knowledge_bases with given `user_id`
    def clear(self, user_id: str):
        self.vector_store.client.command(
            f"DELETE FROM {self.vector_store.config.database}.{self.personal_files_table} "
            f"WHERE user_id='{user_id}'"
        )
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}'"""
        self.vector_store.client.command(query)

    def create_private_knowledge_base(
            self, user_id: str, tool_name: str, tool_description: str, files: Optional[List[str]] = None
    ):
        self.vector_store.client.insert_df(
            self.private_knowledge_base_table,
            pd.DataFrame(
                [
                    {
                        "tool_id": hashlib.sha256(
                            (user_id + tool_name).encode("utf-8")
                        ).hexdigest(),
                        "tool_name": tool_name,  # tool_name represent user's private knowledge base.
                        "file_names": files,
                        "user_id": user_id,
                        "created_by": datetime.now(),
                        "tool_description": tool_description,
                    }
                ]
            ),
            database=self.vector_store.config.database,
        )

    # Show all private knowledge bases with given `user_id`
    def list_private_knowledge_bases(self, user_id: str, private_knowledge_base=None):
        extended_where = f"AND tool_name = '{private_knowledge_base}'" if private_knowledge_base else ""
        query = f"""
        SELECT tool_name, tool_description, length(file_names)
        FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
        WHERE user_id = '{user_id}' {extended_where}
        """
        return [r for r in self.vector_store.client.query(query).named_results()]

    def remove_private_knowledge_bases(self, user_id: str, private_knowledge_bases: List[str]):
        unique_list = list(set(private_knowledge_bases))
        unique_list = ",".join([f"'{t}'" for t in unique_list])
        query = f"""DELETE FROM {self.vector_store.config.database}.{self.private_knowledge_base_table}
                    WHERE user_id = '{user_id}' AND tool_name IN [{unique_list}]"""
        self.vector_store.client.command(query)

    def as_retrieval_tools(self, user_id, tool_name=None):
        logger.info(f"")
        private_knowledge_bases = self.list_private_knowledge_bases(user_id=user_id, private_knowledge_base=tool_name)
        retrievers = {}
        for private_kb in private_knowledge_bases:
            file_names_sql = f"""
            SELECT arrayJoin(file_names) FROM (
                SELECT file_names
                FROM chat.private_tool
                WHERE user_id = '{user_id}' AND tool_name = '{private_kb["tool_name"]}'
            )
            """
            logger.info(f"user_id is {user_id}, file_names_sql is {file_names_sql}")
            res = self.client.query(file_names_sql)
            file_names = []
            for line in res.result_rows:
                file_names.append(line[0])
            file_names = ', '.join(f"'{item}'" for item in file_names)
            logger.info(f"user_id is {user_id}, file_names is {file_names}")
            retrievers[private_kb["tool_name"]] = create_retriever_tool(
                self.vector_store.as_retriever(
                    search_kwargs={"where_str": f"user_id='{user_id}' AND file_name IN ({file_names})"},
                ),
                tool_name=private_kb["tool_name"],
                description=private_kb["tool_description"],
            )
        return retrievers
backend/chat_bot/session_manager.py
ADDED
@@ -0,0 +1,96 @@

import json

from backend.chat_bot.tools import create_session_table, create_message_history_table
from backend.constants.variables import GLOBAL_CONFIG

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime
from sqlalchemy import orm, create_engine
from logger import logger


def get_sessions(engine, model_class, user_id):
    with orm.sessionmaker(engine)() as session:
        result = (
            session.query(model_class)
            .where(
                model_class.session_id == user_id
            )
            .order_by(model_class.create_by.desc())
        )
    return json.loads(result)


class SessionManager:
    def __init__(
            self,
            session_state,
            host,
            port,
            username,
            password,
            db='chat',
            session_table='sessions',
            msg_table='chat_memory'
    ) -> None:
        if GLOBAL_CONFIG.myscale_enable_https == False:
            conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=http'
        else:
            conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
        self.engine = create_engine(conn_str, echo=False)
        self.session_model_class = create_session_table(
            session_table, declarative_base())
        self.session_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_history_table(msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)
        self.session_orm = orm.sessionmaker(self.engine)
        self.session_state = session_state

    def list_sessions(self, user_id: str):
        with self.session_orm() as session:
            result = (
                session.query(self.session_model_class)
                .where(
                    self.session_model_class.user_id == user_id
                )
                .order_by(self.session_model_class.create_by.desc())
            )
            sessions = []
            for r in result:
                sessions.append({
                    "session_id": r.session_id.split("?")[-1],
                    "system_prompt": r.system_prompt,
                })
            return sessions

    # Update sys_prompt with given session_id
    def modify_system_prompt(self, session_id, sys_prompt):
        with self.session_orm() as session:
            obj = session.query(self.session_model_class).where(
                self.session_model_class.session_id == session_id).first()
            if obj:
                obj.system_prompt = sys_prompt
                session.commit()
            else:
                logger.warning(f"Session {session_id} not found")

    # Add a session(session_id, sys_prompt)
    def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs):
        with self.session_orm() as session:
            elem = self.session_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    # Remove a session and related chat history.
    def remove_session(self, session_id: str):
        with self.session_orm() as session:
            # remove session
            session.query(self.session_model_class).where(self.session_model_class.session_id == session_id).delete()
            # remove related chat history.
            session.query(self.msg_model_class).where(self.msg_model_class.session_id == session_id).delete()
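A hypothetical construction sketch (host and credentials below are placeholders, not values from this commit): the constructor maps myscale_enable_https onto the protocol query parameter of a clickhouse-sqlalchemy URL such as clickhouse://user:pass@host:443/chat?protocol=https, then creates the sessions and chat_memory tables if they are missing.

# Hypothetical usage; placeholders only, and a reachable MyScale/ClickHouse instance is assumed.
import streamlit as st

from backend.chat_bot.session_manager import SessionManager

manager = SessionManager(
    st.session_state,
    host="your-myscale-host",  # placeholder
    port=443,
    username="chatdata",       # matches the example secrets file
    password="***",            # placeholder
)
manager.add_session(user_id="alice", session_id="alice?notes",
                    system_prompt="You are a helpful assistant.")
print(manager.list_sessions("alice"))  # e.g. [{"session_id": "notes", "system_prompt": ...}]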
backend/chat_bot/tools.py
ADDED
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import hashlib
from datetime import datetime
from multiprocessing.pool import ThreadPool
from typing import List

import requests
from clickhouse_sqlalchemy import types, engines
from langchain.schema.embeddings import Embeddings
from sqlalchemy import Column, Text
from streamlit.runtime.uploaded_file_manager import UploadedFile


def parse_files(api_key, user_id, files: List[UploadedFile]):
    def parse_file(file: UploadedFile):
        headers = {
            "accept": "application/json",
            "unstructured-api-key": api_key,
        }
        data = {"strategy": "auto", "ocr_languages": ["eng"]}
        file_hash = hashlib.sha256(file.read()).hexdigest()
        file_data = {"files": (file.name, file.getvalue(), file.type)}
        response = requests.post(
            url="https://api.unstructured.io/general/v0/general",
            headers=headers,
            data=data,
            files=file_data
        )
        json_response = response.json()
        if response.status_code != 200:
            raise ValueError(str(json_response))
        texts = [
            {
                "text": t["text"],
                "file_name": t["metadata"]["filename"],
                "entity_id": hashlib.sha256(
                    (file_hash + t["text"]).encode()
                ).hexdigest(),
                "user_id": user_id,
                "created_by": datetime.now(),
            }
            for t in json_response
            if t["type"] == "NarrativeText" and len(t["text"].split(" ")) > 10
        ]
        return texts

    with ThreadPool(8) as p:
        rows = []
        for r in p.imap_unordered(parse_file, files):
            rows.extend(r)
        return rows


def extract_embedding(embeddings: Embeddings, texts):
    if len(texts) > 0:
        embeddings = embeddings.embed_documents(
            [t["text"] for _, t in enumerate(texts)])
        for i, _ in enumerate(texts):
            texts[i]["vector"] = embeddings[i]
        return texts
    raise ValueError("No texts extracted!")


def create_message_history_table(table_name: str, base_class):
    class Message(base_class):
        __tablename__ = table_name
        id = Column(types.Float64)
        session_id = Column(Text)
        user_id = Column(Text)
        msg_id = Column(Text, primary_key=True)
        type = Column(Text)
        # Should be `additions`; a former developer misspelled the column name.
        addtionals = Column(Text)
        message = Column(Text)
        __table_args__ = (
            engines.MergeTree(
                partition_by='session_id',
                order_by=('id', 'msg_id')
            ),
            {'comment': 'Store Chat History'}
        )

    return Message


def create_session_table(table_name: str, DynamicBase):
    class Session(DynamicBase):
        __tablename__ = table_name
        user_id = Column(Text)
        session_id = Column(Text, primary_key=True)
        system_prompt = Column(Text)
        # Creation time of the session.
        create_by = Column(types.DateTime)
        # Should be `additions`; a former developer misspelled the column name.
        additionals = Column(Text)
        __table_args__ = (
            engines.MergeTree(order_by=session_id),
            {'comment': 'Store Session and Prompts'}
        )

    return Session
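Taken together, `parse_files` and `extract_embedding` form the ingestion path for private files: Unstructured splits each upload into narrative text fragments, and the embedding model attaches one vector per fragment. A minimal sketch of how the two compose (the API key and the way `files` is obtained are placeholders; any langchain `Embeddings` implementation should work):

```python
from langchain_community.embeddings import SentenceTransformerEmbeddings

from backend.chat_bot.tools import extract_embedding, parse_files


def ingest(api_key: str, user_id: str, files) -> list:
    """Parse uploaded files and attach embeddings (sketch).

    `files` is the list returned by st.file_uploader(..., accept_multiple_files=True).
    """
    rows = parse_files(api_key, user_id, files)
    emb_model = SentenceTransformerEmbeddings(
        model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2")
    # Each row now carries: text, file_name, entity_id, user_id, created_by, vector.
    return extract_embedding(emb_model, rows)
```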
backend/constants/__init__.py
ADDED
File without changes
backend/constants/myscale_tables.py
ADDED
@@ -0,0 +1,128 @@
from typing import Dict, List
import streamlit as st
from langchain.chains.query_constructor.schema import AttributeInfo
from langchain_community.embeddings import SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings
from langchain.prompts import PromptTemplate

from backend.types.table_config import TableConfig


def hint_arxiv():
    st.markdown("Here we provide some query samples.")
    st.markdown("- If you want to search papers with filters")
    st.markdown("1. ```What is a Bayesian network? Please use articles published later than Feb 2018 and with more "
                "than 2 categories and whose title like `computer` and must have `cs.CV` in its category. ```")
    st.markdown("2. ```What is a Bayesian network? Please use articles published later than Feb 2018```")
    st.markdown("- If you want to ask questions based on arxiv papers stored in MyScaleDB")
    st.markdown("1. ```Did Geoffrey Hinton write a paper about Capsule Neural Networks?```")
    st.markdown("2. ```Introduce some applications of GANs published around 2019.```")
    st.markdown("3. ```请根据 2019 年左右的文章介绍一下 GAN 的应用都有哪些```")


def hint_sql_arxiv():
    st.markdown('''```sql
CREATE TABLE default.ChatArXiv (
    `abstract` String,
    `id` String,
    `vector` Array(Float32),
    `metadata` Object('JSON'),
    `pubdate` DateTime,
    `title` String,
    `categories` Array(String),
    `authors` Array(String),
    `comment` String,
    `primary_category` String,
    VECTOR INDEX vec_idx vector TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT vec_len CHECK length(vector) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')


def hint_wiki():
    st.markdown("Here we provide some query samples.")
    st.markdown("1. ```Which company did Elon Musk found?```")
    st.markdown("2. ```What is Iron Gwazi?```")
    st.markdown("3. ```苹果的发源地是哪里?```")
    st.markdown("4. ```What is a Ring in mathematics?```")
    st.markdown("5. ```The producer of Rick and Morty.```")
    st.markdown("6. ```How low is the temperature on Pluto?```")


def hint_sql_wiki():
    st.markdown('''```sql
CREATE TABLE wiki.Wikipedia (
    `id` String,
    `title` String,
    `text` String,
    `url` String,
    `wiki_id` UInt64,
    `views` Float32,
    `paragraph_id` UInt64,
    `langs` UInt32,
    `emb` Array(Float32),
    VECTOR INDEX vec_idx emb TYPE MSTG('fp16_storage=1', 'metric_type=Cosine', 'disk_mode=3'),
    CONSTRAINT emb_len CHECK length(emb) = 768)
ENGINE = ReplacingMergeTree ORDER BY id
```''')


MYSCALE_TABLES: Dict[str, TableConfig] = {
    'Wikipedia': TableConfig(
        database="wiki",
        table="Wikipedia",
        table_contents="Snapshot from Wikipedia for 2022. All in English.",
        hint=hint_wiki,
        hint_sql=hint_sql_wiki,
        # doc_prompt is used by the QA-with-sources chain
        doc_prompt=PromptTemplate(
            input_variables=["page_content", "url", "title", "ref_id", "views"],
            template="Title for Doc #{ref_id}: {title}\n\tviews: {views}\n\tcontent: {page_content}\nSOURCE: {url}"
        ),
        metadata_col_attributes=[
            AttributeInfo(name="title", description="title of the wikipedia page", type="string"),
            AttributeInfo(name="text", description="paragraph from this wiki page", type="string"),
            AttributeInfo(name="views", description="number of views", type="float")
        ],
        must_have_col_names=['id', 'title', 'url', 'text', 'views'],
        vector_col_name="emb",
        text_col_name="text",
        metadata_col_name="metadata",
        emb_model=lambda: SentenceTransformerEmbeddings(
            model_name='sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
        ),
        tool_desc=("search_among_wikipedia", "Searches among Wikipedia and returns related wiki pages")
    ),
    'ArXiv Papers': TableConfig(
        database="default",
        table="ChatArXiv",
        # The original description was copy-pasted from the Wikipedia config;
        # this table actually holds arXiv papers.
        table_contents="Snapshot of scientific papers from arXiv. All in English.",
        hint=hint_arxiv,
        hint_sql=hint_sql_arxiv,
        doc_prompt=PromptTemplate(
            input_variables=["page_content", "id", "title", "ref_id", "authors", "pubdate", "categories"],
            template="Title for Doc #{ref_id}: {title}\n\tAbstract: {page_content}\n\tAuthors: {authors}\n\t"
                     "Date of Publication: {pubdate}\n\tCategories: {categories}\nSOURCE: {id}"
        ),
        metadata_col_attributes=[
            AttributeInfo(name="pubdate", description="The year the paper is published", type="timestamp"),
            AttributeInfo(name="authors", description="List of author names", type="list[string]"),
            AttributeInfo(name="title", description="Title of the paper", type="string"),
            AttributeInfo(name="categories", description="arXiv categories of this paper", type="list[string]"),
            AttributeInfo(name="length(categories)", description="number of arXiv categories of this paper", type="int")
        ],
        must_have_col_names=['title', 'id', 'categories', 'abstract', 'authors', 'pubdate'],
        vector_col_name="vector",
        text_col_name="abstract",
        metadata_col_name="metadata",
        emb_model=lambda: HuggingFaceInstructEmbeddings(
            model_name='hkunlp/instructor-xl',
            embed_instruction="Represent the question for retrieving supporting scientific papers: "
        ),
        tool_desc=(
            "search_among_scientific_papers",
            "Searches among scientific papers from ArXiv and returns research papers"
        )
    )
}

ALL_TABLE_NAME: List[str] = [config.table for config in MYSCALE_TABLES.values()]
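Each `doc_prompt` turns one retrieved row into the text block the combine chain sees. A quick sanity check with made-up field values (everything below is sample data, not real rows):

```python
from backend.constants.myscale_tables import MYSCALE_TABLES

wiki = MYSCALE_TABLES["Wikipedia"]
print(wiki.doc_prompt.format(
    ref_id=1,
    title="Bayesian network",
    views=1234.0,
    page_content="A Bayesian network is a probabilistic graphical model...",
    url="https://en.wikipedia.org/wiki/Bayesian_network",
))
# Title for Doc #1: Bayesian network
#     views: 1234.0
#     content: A Bayesian network is a probabilistic graphical model...
# SOURCE: https://en.wikipedia.org/wiki/Bayesian_network
```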
backend/constants/prompts.py
ADDED
@@ -0,0 +1,128 @@
from langchain.prompts import ChatPromptTemplate, \
    SystemMessagePromptTemplate, HumanMessagePromptTemplate

DEFAULT_SYSTEM_PROMPT = (
    "Do your best to answer the questions. "
    "Feel free to use any tools available to look up "
    "relevant information. Please keep all details in query "
    "when calling search functions."
)

COMBINE_PROMPT_TEMPLATE = (
    "You are a helpful document assistant. "
    "Your task is to provide information and answer any questions related to documents given below. "
    "You should use the sections, title and abstract of the selected documents as your source of information "
    "and try to provide concise and accurate answers to any questions asked by the user. "
    "If you are unable to find relevant information in the given sections, "
    "you will need to let the user know that the source does not contain relevant information but still try to "
    "provide an answer based on your general knowledge. You must refer to the corresponding section name and page "
    "that you refer to when answering. "
    "The following is the related information about the document that will help you answer users' questions, "
    "you MUST answer using the question's language:\n\n {summaries} "
    "Now you should answer the user's question. Remember you must use `Doc #` to refer to papers:\n\n"
)

COMBINE_PROMPT = ChatPromptTemplate.from_strings(
    string_messages=[(SystemMessagePromptTemplate, COMBINE_PROMPT_TEMPLATE),
                     (HumanMessagePromptTemplate, '{question}')])

MYSCALE_PROMPT = """
You are a MyScale expert. Given an input question, first create a syntactically correct MyScale query to run, then look at the results of the query and return the answer to the input question.
MyScale queries have a vector distance function called `DISTANCE(column, array)` to compute relevance to the user's question and sort the feature array column by the relevance.
When the query is asking for the {top_k} closest rows, you have to use this distance function to calculate the distance to the entity's array on the vector column and order by the distance to retrieve relevant rows.

*NOTICE*: `DISTANCE(column, array)` only accepts an array column as its first argument and a `NeuralArray(entity)` as its second argument. You also need a user defined function called `NeuralArray(entity)` to retrieve the entity's array.

Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MyScale. You should only order according to the distance function.
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use today() function to get the current date, if the question involves "today". `ORDER BY` clause should always be after `WHERE` clause. DO NOT add semicolon to the end of SQL. Pay attention to the comment in table schema.
Pay attention to the data type when using functions. Always use `AND` to connect conditions in `WHERE` and never use comma.
Make sure you never write an isolated `WHERE` keyword and never use undesired conditions to constrain the query.

Use the following format:

======== table info ========
<some table infos>

Question: "Question here"
SQLQuery: "SQL Query to run"


Here are some examples:

======== table info ========
CREATE TABLE "ChatPaper" (
    abstract String,
    id String,
    vector Array(Float32),
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id

Question: What is Feature Pyramid Network?
SQLQuery: SELECT ChatPaper.abstract, ChatPaper.id FROM ChatPaper ORDER BY DISTANCE(vector, NeuralArray(Feature Pyramid Network)) LIMIT {top_k}


======== table info ========
CREATE TABLE "ChatPaper" (
    abstract String,
    id String,
    vector Array(Float32),
    categories Array(String),
    pubdate DateTime,
    title String,
    authors Array(String),
    primary_category String
) ENGINE = ReplicatedReplacingMergeTree()
ORDER BY id
PRIMARY KEY id

Question: What is PaperRank? What is the contribution of those works? Use papers with more than 2 categories.
SQLQuery: SELECT ChatPaper.title, ChatPaper.id, ChatPaper.authors FROM ChatPaper WHERE length(categories) > 2 ORDER BY DISTANCE(vector, NeuralArray(PaperRank contribution)) LIMIT {top_k}


======== table info ========
CREATE TABLE "ChatArXiv" (
    primary_category String,
    categories Array(String),
    pubdate DateTime,
    abstract String,
    title String,
    paper_id String,
    vector Array(Float32),
    authors Array(String),
) ENGINE = MergeTree()
ORDER BY paper_id
PRIMARY KEY paper_id

Question: Did Geoffrey Hinton write about Capsule Neural Networks? Please use articles published later than 2021.
SQLQuery: SELECT ChatArXiv.title, ChatArXiv.paper_id, ChatArXiv.authors FROM ChatArXiv WHERE has(authors, 'Geoffrey Hinton') AND pubdate > parseDateTimeBestEffort('2021-01-01') ORDER BY DISTANCE(vector, NeuralArray(Capsule Neural Networks)) LIMIT {top_k}


======== table info ========
CREATE TABLE "PaperDatabase" (
    abstract String,
    categories Array(String),
    vector Array(Float32),
    pubdate DateTime,
    id String,
    comments String,
    title String,
    authors Array(String),
    primary_category String
) ENGINE = MergeTree()
ORDER BY id
PRIMARY KEY id

Question: Find papers whose abstract has Mutual Information in it.
SQLQuery: SELECT PaperDatabase.title, PaperDatabase.id FROM PaperDatabase WHERE abstract ILIKE '%Mutual Information%' ORDER BY DISTANCE(vector, NeuralArray(Mutual Information)) LIMIT {top_k}


Let's begin:

======== table info ========
{table_info}

Question: {input}
SQLQuery: """
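`MYSCALE_PROMPT` is a plain string template with three slots: `input`, `table_info` and `top_k`. A small sketch of how it is bound (this mirrors the binding in `backend/construct/build_retrievers.py`; the DDL passed here is a stand-in):

```python
from langchain.prompts import PromptTemplate

from backend.constants.prompts import MYSCALE_PROMPT

prompt = PromptTemplate(
    input_variables=["input", "table_info", "top_k"],
    template=MYSCALE_PROMPT,
)
# `table_info` is normally the real CREATE TABLE DDL pulled from the database.
rendered = prompt.format(
    input="Introduce some applications of GANs published around 2019.",
    table_info='CREATE TABLE "ChatArXiv" (...)',
    top_k=10,
)
print(rendered[-200:])  # the rendered prompt ends with the new question
```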
backend/constants/streamlit_keys.py
ADDED
@@ -0,0 +1,35 @@
# NOTE: "STATED" is a typo for "STARTED"; the name is kept for compatibility
# with existing references (the value string shows the intended spelling).
DATA_INITIALIZE_NOT_STATED = "data_initialize_not_started"
DATA_INITIALIZE_STARTED = "data_initialize_started"
DATA_INITIALIZE_COMPLETED = "data_initialize_completed"


CHAT_SESSION = "sel_sess"
CHAT_KNOWLEDGE_TABLE = "private_kb"

CHAT_SESSION_MANAGER = "session_manager"
CHAT_CURRENT_USER_SESSIONS = "current_sessions"

EL_SESSION_SELECTOR = "el_session_selector"

# all personal knowledge bases under a specific user.
USER_PERSONAL_KNOWLEDGE_BASES = "user_tools"
# all personal files under a specific user.
USER_PRIVATE_FILES = "user_files"
# public and personal knowledge bases.
AVAILABLE_RETRIEVAL_TOOLS = "tools_with_users"

EL_PERSONAL_KB_NEEDS_REMOVE = "el_personal_kb_needs_remove"

# files waiting to be uploaded
EL_UPLOAD_FILES = "el_upload_files"
EL_UPLOAD_FILES_STATUS = "el_upload_files_status"

# use these files to build a private knowledge base
EL_BUILD_KB_WITH_FILES = "el_build_kb_with_files"
# name used when building a personal knowledge base.
EL_PERSONAL_KB_NAME = "el_personal_kb_name"
# description used when building a personal knowledge base.
EL_PERSONAL_KB_DESCRIPTION = "el_personal_kb_description"

# knowledge bases selected by the user.
EL_SELECTED_KBS = "el_selected_kbs"
backend/constants/variables.py
ADDED
@@ -0,0 +1,58 @@
from backend.types.global_config import GlobalConfig

# ***** str variables ***** #
EMBEDDING_MODEL_PREFIX = "embedding_model"
CHAINS_RETRIEVERS_MAPPING = "sel_map_obj"
LANGCHAIN_RETRIEVER = "langchain_retriever"
VECTOR_SQL_RETRIEVER = "vecsql_retriever"
TABLE_EMBEDDINGS_MAPPING = "embeddings"
RETRIEVER_TOOLS = "tools"
DATA_INITIALIZE_STATUS = "data_initialized"
UI_INITIALIZED = "ui_initialized"
JUMP_QUERY_ASK = "jump_query_ask"
USER_NAME = "user_name"
USER_INFO = "user_info"

DIVIDER_HTML = """
<div style="
    height: 4px;
    background: linear-gradient(to right, red, orange, yellow, green, blue, indigo, violet);
    margin-top: 20px;
    margin-bottom: 20px;
"></div>
"""

DIVIDER_THIN_HTML = """
<div style="
    height: 2px;
    background: linear-gradient(to right, blue, darkslateblue, indigo, violet);
    margin-top: 20px;
    margin-bottom: 20px;
"></div>
"""


class RetrieverButtons:
    vector_sql_query_from_db = "vector_sql_query_from_db"
    vector_sql_query_with_llm = "vector_sql_query_with_llm"
    self_query_from_db = "self_query_from_db"
    self_query_with_llm = "self_query_with_llm"


GLOBAL_CONFIG = GlobalConfig()


def update_global_config(new_config: GlobalConfig):
    global GLOBAL_CONFIG
    GLOBAL_CONFIG.openai_api_base = new_config.openai_api_base
    GLOBAL_CONFIG.openai_api_key = new_config.openai_api_key
    GLOBAL_CONFIG.auth0_client_id = new_config.auth0_client_id
    GLOBAL_CONFIG.auth0_domain = new_config.auth0_domain
    GLOBAL_CONFIG.myscale_user = new_config.myscale_user
    GLOBAL_CONFIG.myscale_password = new_config.myscale_password
    GLOBAL_CONFIG.myscale_host = new_config.myscale_host
    GLOBAL_CONFIG.myscale_port = new_config.myscale_port
    GLOBAL_CONFIG.query_model = new_config.query_model
    GLOBAL_CONFIG.chat_model = new_config.chat_model
    GLOBAL_CONFIG.untrusted_api = new_config.untrusted_api
    GLOBAL_CONFIG.myscale_enable_https = new_config.myscale_enable_https
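A sketch of how the app might populate `GLOBAL_CONFIG` at startup from Streamlit secrets (the key names follow `.streamlit/secrets.example.toml`; the model names here are assumptions, not the app's actual choices):

```python
import streamlit as st

from backend.constants.variables import update_global_config
from backend.types.global_config import GlobalConfig

update_global_config(GlobalConfig(
    openai_api_base=st.secrets["OPENAI_API_BASE"],
    openai_api_key=st.secrets["OPENAI_API_KEY"],
    myscale_user=st.secrets["MYSCALE_USER"],
    myscale_password=st.secrets["MYSCALE_PASSWORD"],
    myscale_host=st.secrets["MYSCALE_HOST"],
    myscale_port=st.secrets["MYSCALE_PORT"],
    myscale_enable_https=st.secrets.get("MYSCALE_ENABLE_HTTPS", True),
    query_model="gpt-3.5-turbo-0125",  # assumed; the real value is configured elsewhere
    chat_model="gpt-3.5-turbo-0125",   # assumed
))
```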
backend/construct/__init__.py
ADDED
File without changes
backend/construct/build_agents.py
ADDED
@@ -0,0 +1,82 @@
import os
from typing import Sequence, List

import streamlit as st
from langchain.agents import AgentExecutor
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool

from backend.chat_bot.message_converter import DefaultClickhouseMessageConverter
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.constants.streamlit_keys import AVAILABLE_RETRIEVAL_TOOLS
from backend.constants.variables import GLOBAL_CONFIG, RETRIEVER_TOOLS
from logger import logger

try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base
from langchain.chat_models import ChatOpenAI
from langchain.prompts.chat import MessagesPlaceholder
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import AgentTokenBufferMemory
from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
from langchain.schema.messages import SystemMessage
from langchain.memory import SQLChatMessageHistory


def create_agent_executor(
        agent_name: str,
        session_id: str,
        llm: BaseLanguageModel,
        tools: Sequence[BaseTool],
        system_prompt: str,
        **kwargs
) -> AgentExecutor:
    agent_name = agent_name.replace(" ", "_")
    conn_str = (
        f'clickhouse://{os.environ["MYSCALE_USER"]}:{os.environ["MYSCALE_PASSWORD"]}'
        f'@{os.environ["MYSCALE_HOST"]}:{os.environ["MYSCALE_PORT"]}'
    )
    protocol = "https" if GLOBAL_CONFIG.myscale_enable_https else "http"
    chat_memory = SQLChatMessageHistory(
        session_id,
        connection_string=f'{conn_str}/chat?protocol={protocol}',
        custom_message_converter=DefaultClickhouseMessageConverter(agent_name))
    memory = AgentTokenBufferMemory(llm=llm, chat_memory=chat_memory)

    prompt = OpenAIFunctionsAgent.create_prompt(
        system_message=SystemMessage(content=system_prompt),
        extra_prompt_messages=[MessagesPlaceholder(variable_name="history")],
    )
    agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
    return AgentExecutor(
        agent=agent,
        tools=tools,
        memory=memory,
        verbose=True,
        return_intermediate_steps=True,
        **kwargs
    )


def build_agents(
        session_id: str,
        tool_names: List[str],
        model: str = "gpt-3.5-turbo-0125",
        temperature: float = 0.6,
        system_prompt: str = DEFAULT_SYSTEM_PROMPT
):
    chat_llm = ChatOpenAI(
        model_name=model,
        temperature=temperature,
        base_url=GLOBAL_CONFIG.openai_api_base,
        api_key=GLOBAL_CONFIG.openai_api_key,
        streaming=True
    )
    tools = st.session_state.get(AVAILABLE_RETRIEVAL_TOOLS, st.session_state.get(RETRIEVER_TOOLS))
    selected_tools = [tools[k] for k in tool_names]
    logger.info(f"create agent, use tools: {selected_tools}")
    agent = create_agent_executor(
        agent_name="chat_memory",
        session_id=session_id,
        llm=chat_llm,
        tools=selected_tools,
        system_prompt=system_prompt
    )
    return agent
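A usage sketch, assuming the retriever tools are already cached in `st.session_state` and the `MYSCALE_*` environment variables are set (both are read inside `build_agents`/`create_agent_executor`); the session id and question are placeholders:

```python
from backend.construct.build_agents import build_agents

agent = build_agents(
    session_id="default",
    tool_names=["Wikipedia + Self Querying"],  # keys produced by update_retriever_tools()
)
result = agent({"input": "What is a Bayesian network?"})
print(result["output"])
```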
backend/construct/build_all.py
ADDED
@@ -0,0 +1,95 @@
from logger import logger
from typing import Dict, Any, Union

import streamlit as st

from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING
from backend.construct.build_chains import build_retrieval_qa_with_sources_chain
from backend.construct.build_retriever_tool import create_retriever_tool
from backend.construct.build_retrievers import build_self_query_retriever, build_vector_sql_db_chain_retriever
from backend.types.chains_and_retrievers import ChainsAndRetrievers, MetadataColumn

from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, \
    SentenceTransformerEmbeddings


@st.cache_resource
def load_embedding_model_for_table(table_name: str) -> \
        Union[SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings]:
    with st.spinner(f"Loading embedding models for [{table_name}] ..."):
        embeddings = MYSCALE_TABLES[table_name].emb_model()
        return embeddings


@st.cache_resource
def load_embedding_models() -> Dict[str, Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings]]:
    embedding_models = {}
    for table in MYSCALE_TABLES:
        embedding_models[table] = load_embedding_model_for_table(table)
    return embedding_models


@st.cache_resource
def update_retriever_tools():
    retrievers_tools = {}
    for table in MYSCALE_TABLES:
        logger.info(f"Updating retriever tools [<retriever>, <sql_retriever>] for table {table}")
        retrievers_tools.update(
            {
                f"{table} + Self Querying": create_retriever_tool(
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["retriever"],
                    *MYSCALE_TABLES[table].tool_desc
                ),
                f"{table} + Vector SQL": create_retriever_tool(
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["sql_retriever"],
                    *MYSCALE_TABLES[table].tool_desc
                ),
            })
    return retrievers_tools


@st.cache_resource
def build_chains_retriever_for_table(table_name: str) -> ChainsAndRetrievers:
    metadata_col_attributes = MYSCALE_TABLES[table_name].metadata_col_attributes

    self_query_retriever = build_self_query_retriever(table_name)
    self_query_chain = build_retrieval_qa_with_sources_chain(
        table_name=table_name,
        retriever=self_query_retriever,
        chain_name="Self Query Retriever"
    )

    vector_sql_retriever = build_vector_sql_db_chain_retriever(table_name)
    vector_sql_chain = build_retrieval_qa_with_sources_chain(
        table_name=table_name,
        retriever=vector_sql_retriever,
        chain_name="Vector SQL DB Retriever"
    )

    metadata_columns = [
        MetadataColumn(
            name=attribute.name,
            desc=attribute.description,
            type=attribute.type
        )
        for attribute in metadata_col_attributes
    ]
    return ChainsAndRetrievers(
        metadata_columns=metadata_columns,
        # for self query
        retriever=self_query_retriever,
        chain=self_query_chain,
        # for vector sql
        sql_retriever=vector_sql_retriever,
        sql_chain=vector_sql_chain
    )


@st.cache_resource
def build_chains_and_retrievers() -> Dict[str, Dict[str, Any]]:
    chains_and_retrievers = {}
    for table in MYSCALE_TABLES:
        logger.info(f"Building chains, retrievers for table {table}")
        chains_and_retrievers[table] = build_chains_retriever_for_table(table).to_dict()
    return chains_and_retrievers
backend/construct/build_chains.py
ADDED
@@ -0,0 +1,39 @@
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.schema import BaseRetriever
import streamlit as st

from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.chains.stuff_documents import CustomStuffDocumentChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.prompts import COMBINE_PROMPT
from backend.constants.variables import GLOBAL_CONFIG


def build_retrieval_qa_with_sources_chain(
        table_name: str,
        retriever: BaseRetriever,
        chain_name: str = "<chain_name>"
) -> CustomRetrievalQAWithSourcesChain:
    with st.spinner(f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'):
        # Assign ref_id for documents
        custom_stuff_document_chain = CustomStuffDocumentChain(
            llm_chain=LLMChain(
                prompt=COMBINE_PROMPT,
                llm=ChatOpenAI(
                    model_name=GLOBAL_CONFIG.chat_model,
                    openai_api_key=GLOBAL_CONFIG.openai_api_key,
                    temperature=0.6
                ),
            ),
            document_prompt=MYSCALE_TABLES[table_name].doc_prompt,
            document_variable_name="summaries",
        )
        chain = CustomRetrievalQAWithSourcesChain(
            retriever=retriever,
            combine_documents_chain=custom_stuff_document_chain,
            return_source_documents=True,
            max_tokens_limit=12000,
        )
        return chain
backend/construct/build_chat_bot.py
ADDED
@@ -0,0 +1,36 @@
from backend.chat_bot.private_knowledge_base import ChatBotKnowledgeTable
from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION, CHAT_SESSION_MANAGER
import streamlit as st

from backend.constants.variables import GLOBAL_CONFIG, TABLE_EMBEDDINGS_MAPPING
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.chat_bot.session_manager import SessionManager


def build_chat_knowledge_table():
    if CHAT_KNOWLEDGE_TABLE not in st.session_state:
        st.session_state[CHAT_KNOWLEDGE_TABLE] = ChatBotKnowledgeTable(
            host=GLOBAL_CONFIG.myscale_host,
            port=GLOBAL_CONFIG.myscale_port,
            username=GLOBAL_CONFIG.myscale_user,
            password=GLOBAL_CONFIG.myscale_password,
            # embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["Wikipedia"],
            embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING]["ArXiv Papers"],
            parser_api_key=GLOBAL_CONFIG.untrusted_api,
        )


def initialize_session_manager():
    if CHAT_SESSION not in st.session_state:
        st.session_state[CHAT_SESSION] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    if CHAT_SESSION_MANAGER not in st.session_state:
        st.session_state[CHAT_SESSION_MANAGER] = SessionManager(
            st.session_state,
            host=GLOBAL_CONFIG.myscale_host,
            port=GLOBAL_CONFIG.myscale_port,
            username=GLOBAL_CONFIG.myscale_user,
            password=GLOBAL_CONFIG.myscale_password,
        )
backend/construct/build_retriever_tool.py
ADDED
@@ -0,0 +1,45 @@
import json
from typing import List

from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema import BaseRetriever, Document
from langchain.tools import Tool

from backend.chat_bot.json_decoder import CustomJSONEncoder


class RetrieverInput(BaseModel):
    query: str = Field(description="query to look up in retriever")


def create_retriever_tool(
        retriever: BaseRetriever,
        tool_name: str,
        description: str
) -> Tool:
    """Create a tool to do retrieval of documents.

    Args:
        retriever: The retriever to use for the retrieval
        tool_name: The name for the tool. This will be passed to the language model,
            so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to the language
            model, so should be descriptive.

    Returns:
        Tool class to pass to an agent
    """
    def wrap(func):
        def wrapped_retrieve(*args, **kwargs):
            docs: List[Document] = func(*args, **kwargs)
            return json.dumps([d.dict() for d in docs], cls=CustomJSONEncoder)

        return wrapped_retrieve

    return Tool(
        name=tool_name,
        description=description,
        func=wrap(retriever.get_relevant_documents),
        coroutine=retriever.aget_relevant_documents,
        args_schema=RetrieverInput,
    )
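The `wrap` closure exists because an OpenAI-functions agent expects tool output to be a string, while retrievers return `Document` objects; serializing to JSON keeps the metadata intact. A toy sketch of the wrapped behavior (`EchoRetriever` is invented purely for the demo):

```python
from typing import List

from langchain.callbacks.manager import CallbackManagerForRetrieverRun
from langchain.schema import BaseRetriever, Document

from backend.construct.build_retriever_tool import create_retriever_tool


class EchoRetriever(BaseRetriever):
    """Toy retriever that returns one canned document."""

    def _get_relevant_documents(
        self, query: str, *, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        return [Document(page_content=f"docs about {query}", metadata={"id": "x1"})]


tool = create_retriever_tool(EchoRetriever(), "echo_search", "Returns a canned document")
print(tool.func("capsule networks"))
# -> a JSON list whose items carry page_content and metadata
```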
backend/construct/build_retrievers.py
ADDED
@@ -0,0 +1,120 @@
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.utilities.sql_database import SQLDatabase
from langchain.vectorstores import MyScaleSettings
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
from sqlalchemy import create_engine, MetaData

from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.prompts import MYSCALE_PROMPT
from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG
from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser
from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson
from logger import logger


@st.cache_resource
def build_self_query_retriever(table_name: str) -> SelfQueryRetriever:
    with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."):
        myscale_connection = {
            "host": GLOBAL_CONFIG.myscale_host,
            "port": GLOBAL_CONFIG.myscale_port,
            "username": GLOBAL_CONFIG.myscale_user,
            "password": GLOBAL_CONFIG.myscale_password,
        }
        myscale_settings = MyScaleSettings(
            **myscale_connection,
            database=MYSCALE_TABLES[table_name].database,
            table=MYSCALE_TABLES[table_name].table,
            column_map={
                "id": "id",
                "text": MYSCALE_TABLES[table_name].text_col_name,
                "vector": MYSCALE_TABLES[table_name].vector_col_name,
                # TODO refine MyScaleDB metadata in langchain.
                "metadata": MYSCALE_TABLES[table_name].metadata_col_name
            }
        )
        myscale_vector_store = MyScaleWithoutMetadataJson(
            embedding=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
            config=myscale_settings,
            must_have_cols=MYSCALE_TABLES[table_name].must_have_col_names
        )

    with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."):
        retriever: SelfQueryRetriever = SelfQueryRetriever.from_llm(
            llm=ChatOpenAI(
                model_name=GLOBAL_CONFIG.query_model,
                base_url=GLOBAL_CONFIG.openai_api_base,
                api_key=GLOBAL_CONFIG.openai_api_key,
                temperature=0
            ),
            vectorstore=myscale_vector_store,
            document_contents=MYSCALE_TABLES[table_name].table_contents,
            metadata_field_info=MYSCALE_TABLES[table_name].metadata_col_attributes,
            use_original_query=False,
            structured_query_translator=MyScaleTranslator()
        )
    return retriever


@st.cache_resource
def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever:
    """Get a group of relevant docs from MyScaleDB."""
    with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'):
        protocol = "https" if GLOBAL_CONFIG.myscale_enable_https else "http"
        engine = create_engine(
            f'clickhouse://{GLOBAL_CONFIG.myscale_user}:{GLOBAL_CONFIG.myscale_password}@'
            f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}'
            f'/{MYSCALE_TABLES[table_name].database}?protocol={protocol}'
        )
        metadata = MetaData(bind=engine)
        logger.info(f"{table_name} metadata is : {metadata}")
        prompt = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template=MYSCALE_PROMPT,
        )
        # The custom output parser rewrites the generated SQL so that
        # additional (must-have) columns can be selected.
        output_parser = VectorSQLRetrieveOutputParser.from_embeddings(
            model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
            # columns that must appear in the rewritten SELECT.
            must_have_columns=MYSCALE_TABLES[table_name].must_have_col_names
        )

        # `db_chain` generates the SQL statement.
        vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm(
            llm=ChatOpenAI(
                model_name=GLOBAL_CONFIG.query_model,
                base_url=GLOBAL_CONFIG.openai_api_base,
                api_key=GLOBAL_CONFIG.openai_api_key,
                temperature=0
            ),
            prompt=prompt,
            top_k=10,
            return_direct=True,
            db=SQLDatabase(
                engine,
                None,
                metadata,
                include_tables=[MYSCALE_TABLES[table_name].table],
                max_string_length=1024
            ),
            sql_cmd_parser=output_parser,  # TODO needs a `langchain` update to fix the return type.
            native_format=True
        )

        # The retriever fetches a group of documents through `db_chain`.
        vector_sql_db_chain_retriever = VectorSQLDatabaseChainRetriever(
            sql_db_chain=vector_sql_db_chain,
            page_content_key=MYSCALE_TABLES[table_name].text_col_name
        )
    return vector_sql_db_chain_retriever
backend/retrievers/__init__.py
ADDED
File without changes
backend/retrievers/self_query.py
ADDED
@@ -0,0 +1,89 @@
from typing import List

import pandas as pd
import streamlit as st
from langchain.retrievers import SelfQueryRetriever
from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig

from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
from backend.callbacks.self_query_callbacks import ChatDataSelfAskCallBackHandler, CustomSelfQueryRetrieverCallBackHandler
from ui.utils import display
from logger import logger


def process_self_query(selected_table, query_type):
    place_holder = st.empty()
    logger.info(
        f"button-1: {RetrieverButtons.self_query_from_db}, "
        f"button-2: {RetrieverButtons.self_query_with_llm}, "
        f"content: {st.session_state.query_self}"
    )
    with place_holder.expander('🪵 Chat Log', expanded=True):
        try:
            if query_type == RetrieverButtons.self_query_from_db:
                callback = CustomSelfQueryRetrieverCallBackHandler()
                retriever: SelfQueryRetriever = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["retriever"]
                config: RunnableConfig = {"callbacks": [callback]}

                relevant_docs = retriever.invoke(
                    input=st.session_state.query_self,
                    config=config
                )

                callback.progress_bar.progress(
                    value=1.0, text="[Question -> LLM -> Query filter -> MyScaleDB -> Results] Done!✅")

                st.markdown(f"### Self Query Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
                display(
                    dataframe=pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs]
                    ),
                    columns_=MYSCALE_TABLES[selected_table].must_have_col_names
                )
            elif query_type == RetrieverButtons.self_query_with_llm:
                # callback = CustomSelfQueryRetrieverCallBackHandler()
                callback = ChatDataSelfAskCallBackHandler()
                chain: CustomRetrievalQAWithSourcesChain = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["chain"]
                chain_results = chain(st.session_state.query_self, callbacks=[callback])
                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> Query filter -> MyScaleDB -> Related Results -> LLM -> LLM Answer] Done!✅"
                )

                documents_reference: List[Document] = chain_results["source_documents"]
                st.markdown(f"### SelfQueryRetriever Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB by `SelfQueryRetriever` \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in documents_reference]
                    )
                )
                st.markdown(
                    f"### Answer from LLM \n"
                    f"> The response of the LLM when given the `SelfQueryRetriever` results. \n\n"
                )
                st.write(chain_results['answer'])
                st.markdown(
                    f"### References from `{selected_table}`\n"
                    f"> These are the documents the LLM used. \n\n"
                )
                if len(chain_results['sources']) == 0:
                    st.write("No documents were used by the LLM.")
                else:
                    display(
                        dataframe=pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']]
                        ),
                        columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
                        index='ref_id'
                    )
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception as e:
            st.write('Oops 😵 Something bad happened...')
            raise e
backend/retrievers/vector_sql_output_parser.py
ADDED
@@ -0,0 +1,23 @@
from typing import Dict, Any, List

from langchain_experimental.sql.vector_sql import VectorSQLOutputParser


class VectorSQLRetrieveOutputParser(VectorSQLOutputParser):
    """Based on VectorSQLOutputParser.
    It also modifies the SQL so that all required columns are selected.
    """
    must_have_columns: List[str]

    @property
    def _type(self) -> str:
        return "vector_sql_retrieve_custom"

    def parse(self, text: str) -> Dict[str, Any]:
        text = text.strip()
        start = text.upper().find("SELECT")
        if start >= 0:
            end = text.upper().find("FROM")
            # Replace the generated SELECT list with the must-have columns.
            text = text.replace(
                text[start + len("SELECT") + 1: end - 1], ", ".join(self.must_have_columns))
        return super().parse(text)
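The column rewrite is easiest to see in isolation. A minimal sketch with a stand-in embedding model (`DummyEmbeddings` is invented for the demo; `from_embeddings` is called exactly as in `backend/construct/build_retrievers.py`):

```python
from typing import List

from langchain.schema.embeddings import Embeddings

from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser


class DummyEmbeddings(Embeddings):
    """Stand-in model so the parser can run offline."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [[0.0, 0.0, 0.0] for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        return [0.0, 0.0, 0.0]


parser = VectorSQLRetrieveOutputParser.from_embeddings(
    model=DummyEmbeddings(),
    must_have_columns=["title", "id", "authors"],
)
sql = "SELECT abstract FROM ChatArXiv ORDER BY DISTANCE(vector, NeuralArray(GAN)) LIMIT 10"
print(parser.parse(sql))
# The SELECT list becomes "title, id, authors"; the base parser then replaces
# NeuralArray(GAN) with the (dummy) embedding array.
```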
backend/retrievers/vector_sql_query.py
ADDED
@@ -0,0 +1,95 @@
from typing import List

import pandas as pd
import streamlit as st
from langchain.schema import Document
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever

from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
from backend.callbacks.vector_sql_callbacks import VectorSQLSearchDBCallBackHandler, VectorSQLSearchLLMCallBackHandler
from ui.utils import display
from logger import logger


def process_sql_query(selected_table: str, query_type: str):
    place_holder = st.empty()
    logger.info(
        f"button-1: {st.session_state[RetrieverButtons.vector_sql_query_from_db]}, "
        f"button-2: {st.session_state[RetrieverButtons.vector_sql_query_with_llm]}, "
        f"table: {selected_table}, "
        f"content: {st.session_state.query_sql}"
    )
    with place_holder.expander('🪵 Query Log', expanded=True):
        try:
            if query_type == RetrieverButtons.vector_sql_query_from_db:
                callback = VectorSQLSearchDBCallBackHandler()
                vector_sql_retriever: VectorSQLDatabaseChainRetriever = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_retriever"]
                relevant_docs: List[Document] = vector_sql_retriever.get_relevant_documents(
                    query=st.session_state.query_sql,
                    callbacks=[callback]
                )

                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> SQL Statement -> MyScaleDB -> Results] Done! ✅"
                )

                st.markdown(f"### Vector Search Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB with the generated SQL statement \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in relevant_docs]
                    )
                )
            elif query_type == RetrieverButtons.vector_sql_query_with_llm:
                callback = VectorSQLSearchLLMCallBackHandler(table=selected_table)
                vector_sql_chain: CustomRetrievalQAWithSourcesChain = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_chain"]
                chain_results = vector_sql_chain(
                    inputs=st.session_state.query_sql,
                    callbacks=[callback]
                )

                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> SQL Statement -> MyScaleDB -> "
                         "(Question,Results) -> LLM -> Results] Done! ✅"
                )

                documents_reference: List[Document] = chain_results["source_documents"]
                st.markdown(f"### Vector Search Results from `{selected_table}` \n"
                            f"> Here we get documents from MyScaleDB with the generated SQL statement \n\n")
                display(
                    pd.DataFrame(
                        [{**d.metadata, 'abstract': d.page_content} for d in documents_reference]
                    )
                )
                st.markdown(
                    f"### Answer from LLM \n"
                    f"> The response of the LLM when given the vector search results. \n\n"
                )
                st.write(chain_results['answer'])
                st.markdown(
                    f"### References from `{selected_table}`\n"
                    f"> These are the documents the LLM used. \n\n"
                )
                if len(chain_results['sources']) == 0:
                    st.write("No documents were used by the LLM.")
                else:
                    display(
                        dataframe=pd.DataFrame(
                            [{**d.metadata, 'abstract': d.page_content} for d in chain_results['sources']]
                        ),
                        columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
                        index='ref_id'
                    )
            else:
                raise NotImplementedError(f"Unsupported query type: {query_type}")
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception as e:
            st.write('Oops 😵 Something bad happened...')
            raise e
backend/types/__init__.py
ADDED
File without changes
backend/types/chains_and_retrievers.py
ADDED
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import Dict, List, Any

from langchain.retrievers import SelfQueryRetriever
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever

from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain


@dataclass
class MetadataColumn:
    name: str
    desc: str
    type: str


@dataclass
class ChainsAndRetrievers:
    metadata_columns: List[MetadataColumn]
    retriever: SelfQueryRetriever
    chain: CustomRetrievalQAWithSourcesChain
    sql_retriever: VectorSQLDatabaseChainRetriever
    sql_chain: CustomRetrievalQAWithSourcesChain

    def to_dict(self) -> Dict[str, Any]:
        return {
            "metadata_columns": self.metadata_columns,
            "retriever": self.retriever,
            "chain": self.chain,
            "sql_retriever": self.sql_retriever,
            "sql_chain": self.sql_chain
        }
ADDED
@@ -0,0 +1,22 @@
|
from dataclasses import dataclass
from typing import Optional


@dataclass
class GlobalConfig:
    openai_api_base: Optional[str] = ""
    openai_api_key: Optional[str] = ""

    auth0_client_id: Optional[str] = ""
    auth0_domain: Optional[str] = ""

    myscale_user: Optional[str] = ""
    myscale_password: Optional[str] = ""
    myscale_host: Optional[str] = ""
    myscale_port: Optional[int] = 443

    query_model: Optional[str] = ""
    chat_model: Optional[str] = ""

    untrusted_api: Optional[str] = ""
    myscale_enable_https: Optional[bool] = True
backend/types/table_config.py
ADDED
@@ -0,0 +1,25 @@
from dataclasses import dataclass
from typing import Callable, List

from langchain.chains.query_constructor.schema import AttributeInfo
from langchain.prompts import PromptTemplate


@dataclass
class TableConfig:
    database: str
    table: str
    table_contents: str
    # column names
    must_have_col_names: List[str]
    vector_col_name: str
    text_col_name: str
    metadata_col_name: str
    # hint for UI
    hint: Callable
    hint_sql: Callable
    # for langchain
    doc_prompt: PromptTemplate
    metadata_col_attributes: List[AttributeInfo]
    emb_model: Callable
    tool_desc: tuple
backend/vector_store/__init__.py
ADDED
File without changes
backend/vector_store/myscale_without_metadata.py
ADDED
@@ -0,0 +1,52 @@
from typing import Any, Optional, List

from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.myscale import MyScale, MyScaleSettings

from logger import logger


class MyScaleWithoutMetadataJson(MyScale):
    def __init__(self, embedding: Embeddings, config: Optional[MyScaleSettings] = None,
                 must_have_cols: Optional[List[str]] = None, **kwargs: Any) -> None:
        try:
            super().__init__(embedding, config, **kwargs)
        except Exception as e:
            logger.error(e)
        # Avoid the shared mutable default argument the original `= []` created.
        self.must_have_cols: List[str] = must_have_cols or []

    def _build_qstr(
            self, q_emb: List[float], topk: int, where_str: Optional[str] = None
    ) -> str:
        q_emb_str = ",".join(map(str, q_emb))
        if where_str:
            where_str = f"PREWHERE {where_str}"
        else:
            where_str = ""

        q_str = f"""
            SELECT {self.config.column_map['text']}, dist, {','.join(self.must_have_cols)}
            FROM {self.config.database}.{self.config.table}
            {where_str}
            ORDER BY distance({self.config.column_map['vector']}, [{q_emb_str}])
                AS dist {self.dist_order}
            LIMIT {topk}
            """
        return q_str

    def similarity_search_by_vector(self, embedding: List[float], k: int = 4, where_str: Optional[str] = None,
                                    **kwargs: Any) -> List[Document]:
        q_str = self._build_qstr(embedding, k, where_str)
        try:
            return [
                Document(
                    page_content=r[self.config.column_map["text"]],
                    metadata={k: r[k] for k in self.must_have_cols},
                )
                for r in self.client.query(q_str).named_results()
            ]
        except Exception as e:
            logger.error(
                f"\033[91m\033[1m{type(e)}\033[0m \033[95m{str(e)}\033[0m")
            return []
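A usage sketch, assuming a reachable MyScale instance holding the `default.ChatArXiv` table (credentials are placeholders, and the embedding model's vector dimension must match the table's):

```python
from langchain.vectorstores.myscale import MyScaleSettings
from langchain_community.embeddings import HuggingFaceInstructEmbeddings

from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson

store = MyScaleWithoutMetadataJson(
    embedding=HuggingFaceInstructEmbeddings(model_name="hkunlp/instructor-xl"),
    config=MyScaleSettings(
        host="<myscale-host>", port=443, username="<user>", password="<password>",
        database="default", table="ChatArXiv",
        column_map={"id": "id", "text": "abstract",
                    "vector": "vector", "metadata": "metadata"},
    ),
    must_have_cols=["title", "id", "authors"],
)
docs = store.similarity_search("capsule neural networks", k=4)
# Each Document's metadata holds only the must-have columns; the JSON
# metadata column is never read.
```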
logger.py
ADDED
@@ -0,0 +1,18 @@
import logging


def setup_logger():
    logger_ = logging.getLogger('chat-data')
    logger_.setLevel(logging.INFO)
    if not logger_.handlers:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '%(asctime)s - %(filename)s - %(funcName)s - %(levelname)s - %(message)s - [Thread ID: %(thread)d]'
        )
        console_handler.setFormatter(formatter)
        logger_.addHandler(console_handler)
    return logger_


logger = setup_logger()
requirements.txt
CHANGED
@@ -1,15 +1,17 @@
-langchain
-langchain-
-
-
+langchain==0.2.1
+langchain-community==0.2.1
+langchain-core==0.2.1
+langchain-experimental==0.0.59
+langchain-openai==0.1.7
+sentence-transformers==2.2.2
 InstructorEmbedding
 pandas
-
-streamlit
+streamlit
+streamlit-extras
 streamlit-auth0-component
 altair==4.2.2
 clickhouse-connect
-openai==
+openai==1.35.3
 lark
 tiktoken
 sql-formatter
ui/__init__.py
ADDED
File without changes
ui/chat_page.py
ADDED
@@ -0,0 +1,196 @@
import datetime
import json

import pandas as pd
import streamlit as st
from langchain_core.messages import HumanMessage, FunctionMessage
from streamlit.delta_generator import DeltaGenerator

from backend.chat_bot.json_decoder import CustomJSONDecoder
from backend.constants.streamlit_keys import CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, \
    EL_UPLOAD_FILES_STATUS, USER_PRIVATE_FILES, EL_BUILD_KB_WITH_FILES, \
    EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
    USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
    CHAT_KNOWLEDGE_TABLE, EL_UPLOAD_FILES, EL_SELECTED_KBS
from backend.constants.variables import DIVIDER_HTML, USER_NAME, RETRIEVER_TOOLS
from backend.construct.build_chat_bot import build_chat_knowledge_table, initialize_session_manager
from backend.chat_bot.chat import refresh_sessions, on_session_change_submit, refresh_agent, \
    create_private_knowledge_base_as_tool, \
    remove_private_knowledge_bases, add_file, clear_files, clear_history, back_to_main, on_chat_submit


def render_session_manager():
    with st.expander("🤖 Session Management"):
        if CHAT_CURRENT_USER_SESSIONS not in st.session_state:
            refresh_sessions()
        st.markdown("Here you can update `session_id` and `system_prompt`")
        st.markdown("- Click an empty row to add a new item")
        st.markdown("- To delete an item, click it and press the `DEL` key")
        st.markdown("- Don't forget to submit your changes.")

        st.data_editor(
            data=st.session_state[CHAT_CURRENT_USER_SESSIONS],
            num_rows="dynamic",
            key="session_editor",
            use_container_width=True,
        )
        st.button("⏫ Submit", on_click=on_session_change_submit, type="primary")


def render_session_selection():
    with st.expander("✅ Session Selection", expanded=True):
        st.selectbox(
            "Choose a `session` to chat",
            options=st.session_state[CHAT_CURRENT_USER_SESSIONS],
            index=None,
            key=EL_SESSION_SELECTOR,
            format_func=lambda x: x["session_id"],
            on_change=refresh_agent,
        )


def render_files_manager():
    with st.expander("📃 **Upload your personal files**", expanded=False):
        st.markdown("- Files will be parsed by [Unstructured API](https://unstructured.io/api-key).")
        st.markdown("- All files will be converted into vectors and stored in [MyScaleDB](https://myscale.com/).")
        st.file_uploader(label="⏫ **Upload files**", key=EL_UPLOAD_FILES, accept_multiple_files=True)
        # st.markdown("### Uploaded Files")
        st.dataframe(
            data=st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(st.session_state[USER_NAME]),
            use_container_width=True,
        )
        st.session_state[EL_UPLOAD_FILES_STATUS] = st.empty()
        col_1, col_2 = st.columns(2)
        with col_1:
            st.button(label="Upload files", on_click=add_file)
        with col_2:
            st.button(label="Clear all files and tools", on_click=clear_files)


def _render_create_personal_knowledge_bases(div: DeltaGenerator):
    with div:
        st.markdown("- If you haven't uploaded your personal files yet, please upload them first.")
        st.markdown("- Select some **files** to build your `personal knowledge base`.")
        st.markdown("- Once your `personal knowledge base` is built, "
                    "it will answer your questions using information from your personal **files**.")
        st.multiselect(
            label="⚡️Select some files to build a **personal knowledge base**",
|
78 |
+
options=st.session_state[USER_PRIVATE_FILES],
|
79 |
+
placeholder="You should upload some files first",
|
80 |
+
key=EL_BUILD_KB_WITH_FILES,
|
81 |
+
format_func=lambda x: x["file_name"],
|
82 |
+
)
|
83 |
+
st.text_input(
|
84 |
+
label="⚡️Personal knowledge base name",
|
85 |
+
value="get_relevant_documents",
|
86 |
+
key=EL_PERSONAL_KB_NAME
|
87 |
+
)
|
88 |
+
st.text_input(
|
89 |
+
label="⚡️Personal knowledge base description",
|
90 |
+
value="Searches from some personal files.",
|
91 |
+
key=EL_PERSONAL_KB_DESCRIPTION,
|
92 |
+
)
|
93 |
+
st.button(
|
94 |
+
label="Build 🔧",
|
95 |
+
on_click=create_private_knowledge_base_as_tool
|
96 |
+
)
|
97 |
+
|
98 |
+
|
99 |
+
def _render_remove_personal_knowledge_bases(div: DeltaGenerator):
|
100 |
+
with div:
|
101 |
+
st.markdown("> Here is all your personal knowledge bases.")
|
102 |
+
if USER_PERSONAL_KNOWLEDGE_BASES in st.session_state and len(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]) > 0:
|
103 |
+
st.dataframe(st.session_state[USER_PERSONAL_KNOWLEDGE_BASES])
|
104 |
+
else:
|
105 |
+
st.warning("You don't have any personal knowledge bases, please create a new one.")
|
106 |
+
st.multiselect(
|
107 |
+
label="Choose a personal knowledge base to delete",
|
108 |
+
placeholder="Choose a personal knowledge base to delete",
|
109 |
+
options=st.session_state[USER_PERSONAL_KNOWLEDGE_BASES],
|
110 |
+
format_func=lambda x: x["tool_name"],
|
111 |
+
key=EL_PERSONAL_KB_NEEDS_REMOVE,
|
112 |
+
)
|
113 |
+
st.button("Delete", on_click=remove_private_knowledge_bases, type="primary")
|
114 |
+
|
115 |
+
|
116 |
+
def render_personal_tools_build():
|
117 |
+
with st.expander("🔨 **Build your personal knowledge base**", expanded=True):
|
118 |
+
create_new_kb, kb_manager = st.tabs(["Create personal knowledge base", "Personal knowledge base management"])
|
119 |
+
_render_create_personal_knowledge_bases(create_new_kb)
|
120 |
+
_render_remove_personal_knowledge_bases(kb_manager)
|
121 |
+
|
122 |
+
|
123 |
+
def render_knowledge_base_selector():
|
124 |
+
with st.expander("🙋 **Select some knowledge bases to query**", expanded=True):
|
125 |
+
st.markdown("- Knowledge bases come in two types: `public` and `private`.")
|
126 |
+
st.markdown("- All users can access our `public` knowledge bases.")
|
127 |
+
st.markdown("- Only you can access your `personal` knowledge bases.")
|
128 |
+
options = st.session_state[RETRIEVER_TOOLS].keys()
|
129 |
+
if AVAILABLE_RETRIEVAL_TOOLS in st.session_state:
|
130 |
+
options = st.session_state[AVAILABLE_RETRIEVAL_TOOLS]
|
131 |
+
st.multiselect(
|
132 |
+
label="Select some knowledge base tool",
|
133 |
+
placeholder="Please select some knowledge bases to query",
|
134 |
+
options=options,
|
135 |
+
default=["Wikipedia + Self Querying"],
|
136 |
+
key=EL_SELECTED_KBS,
|
137 |
+
on_change=refresh_agent,
|
138 |
+
)
|
139 |
+
|
140 |
+
|
141 |
+
def chat_page():
|
142 |
+
# initialize resources
|
143 |
+
build_chat_knowledge_table()
|
144 |
+
initialize_session_manager()
|
145 |
+
|
146 |
+
# render sidebar
|
147 |
+
with st.sidebar:
|
148 |
+
left, middle, right = st.columns([1, 1, 2])
|
149 |
+
with left:
|
150 |
+
st.button(label="↩️ Log Out", help="log out and back to main page", on_click=back_to_main)
|
151 |
+
with right:
|
152 |
+
st.markdown(f"👤 `{st.session_state[USER_NAME]}`")
|
153 |
+
st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
|
154 |
+
render_session_manager()
|
155 |
+
render_session_selection()
|
156 |
+
render_files_manager()
|
157 |
+
render_personal_tools_build()
|
158 |
+
render_knowledge_base_selector()
|
159 |
+
|
160 |
+
# render chat history
|
161 |
+
if "agent" not in st.session_state:
|
162 |
+
refresh_agent()
|
163 |
+
for msg in st.session_state.agent.memory.chat_memory.messages:
|
164 |
+
speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
|
165 |
+
if isinstance(msg, FunctionMessage):
|
166 |
+
with st.chat_message(name="from knowledge base", avatar="📚"):
|
167 |
+
st.write(
|
168 |
+
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
|
169 |
+
)
|
170 |
+
st.write("Retrieved from knowledge base:")
|
171 |
+
try:
|
172 |
+
st.dataframe(
|
173 |
+
pd.DataFrame.from_records(
|
174 |
+
json.loads(msg.content, cls=CustomJSONDecoder)
|
175 |
+
),
|
176 |
+
use_container_width=True,
|
177 |
+
)
|
178 |
+
except Exception as e:
|
179 |
+
st.warning(e)
|
180 |
+
st.write(msg.content)
|
181 |
+
else:
|
182 |
+
if len(msg.content) > 0:
|
183 |
+
with st.chat_message(speaker):
|
184 |
+
# print(type(msg), msg.dict())
|
185 |
+
st.write(
|
186 |
+
f"*{datetime.datetime.fromtimestamp(msg.additional_kwargs['timestamp']).isoformat()}*"
|
187 |
+
)
|
188 |
+
st.write(f"{msg.content}")
|
189 |
+
st.session_state["next_round"] = st.empty()
|
190 |
+
from streamlit import _bottom
|
191 |
+
with _bottom:
|
192 |
+
col1, col2 = st.columns([1, 16])
|
193 |
+
with col1:
|
194 |
+
st.button("🗑️", help="Clean chat history", on_click=clear_history, type="secondary")
|
195 |
+
with col2:
|
196 |
+
st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")
|
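The history loop in `chat_page` assumes every stored message carries a UNIX timestamp in `additional_kwargs`; a tiny self-contained illustration of the conversion it performs (the timestamp value here is made up):

import datetime

additional_kwargs = {"timestamp": 1718000000}  # made-up example value
print(datetime.datetime.fromtimestamp(additional_kwargs["timestamp"]).isoformat())
# e.g. '2024-06-10T06:13:20' (exact output depends on the local timezone)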
ui/home.py
ADDED
@@ -0,0 +1,156 @@
+import base64
+
+from streamlit_extras.add_vertical_space import add_vertical_space
+from streamlit_extras.card import card
+from streamlit_extras.colored_header import colored_header
+from streamlit_extras.mention import mention
+from streamlit_extras.tags import tagger_component
+
+from logger import logger
+import os
+
+import streamlit as st
+from auth0_component import login_button
+
+from backend.constants.variables import JUMP_QUERY_ASK, USER_INFO, USER_NAME, DIVIDER_HTML, DIVIDER_THIN_HTML
+from streamlit_extras.let_it_rain import rain
+
+
+def render_home():
+    render_home_header()
+    # st.divider()
+    # st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
+    add_vertical_space(5)
+    render_home_content()
+    # st.divider()
+    st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
+    render_home_footer()
+
+
+def render_home_header():
+    logger.info("render home header")
+    st.header("ChatData - Your Intelligent Assistant")
+    st.markdown(DIVIDER_THIN_HTML, unsafe_allow_html=True)
+    st.markdown("> [ChatData](https://github.com/myscale/ChatData) \
+                is developed by [MyScale](https://myscale.com/); \
+                it is an integration of [LangChain](https://www.langchain.com/) \
+                and [MyScaleDB](https://github.com/myscale/myscaledb)")
+
+    tagger_component(
+        "Keywords:",
+        ["MyScaleDB", "LangChain", "VectorSearch", "ChatBot", "GPT", "arxiv", "wikipedia", "Personal Knowledge Base 📚"],
+        color_name=["darkslateblue", "green", "orange", "darkslategrey", "red", "crimson", "darkcyan", "darkgrey"],
+    )
+    text, col1, col2, col3, _ = st.columns([1, 1, 1, 1, 4])
+    with text:
+        st.markdown("Related:")
+    with col1.container():
+        mention(
+            label="streamlit",
+            icon="streamlit",
+            url="https://streamlit.io/",
+            write=True
+        )
+    with col2.container():
+        mention(
+            label="langchain",
+            icon="🦜🔗",
+            url="https://www.langchain.com/",
+            write=True
+        )
+    with col3.container():
+        mention(
+            label="streamlit-extras",
+            icon="🪢",
+            url="https://github.com/arnaudmiribel/streamlit-extras",
+            write=True
+        )
+
+
+def _render_self_query_chain_content():
+    col1, col2 = st.columns([1, 1], gap='large')
+    with col1.container():
+        st.image(image='../assets/home_page_background_1.png',
+                 caption=None,
+                 width=None,
+                 use_column_width=True,
+                 clamp=False,
+                 channels="RGB",
+                 output_format="PNG")
+    with col2.container():
+        st.header("VectorSearch & SelfQuery with Sources")
+        st.info("In this sample, you will learn how **LangChain** integrates with **MyScaleDB**.")
+        st.markdown("""This example demonstrates two methods for integrating MyScale into LangChain: [Vector SQL](https://api.python.langchain.com/en/latest/sql/langchain_experimental.sql.vector_sql.VectorSQLDatabaseChain.html) and [Self-querying retriever](https://python.langchain.com/v0.2/docs/integrations/retrievers/self_query/myscale_self_query/). For each method, you can choose one of the following options:

1. `Retrieve from MyScaleDB ➡️` - The LLM (GPT) converts user queries into SQL statements with vector search, executes these searches in MyScaleDB, and retrieves relevant content.

2. `Retrieve and answer with LLM ➡️` - After retrieving relevant content from MyScaleDB, the user query along with the retrieved content is sent to the LLM (GPT), which then provides a comprehensive answer.""")
+        add_vertical_space(3)
+        _, middle, _ = st.columns([2, 1, 2], gap='small')
+        with middle.container():
+            st.session_state[JUMP_QUERY_ASK] = st.button("Try sample", use_container_width=False, type="secondary")
+
+
+def _render_chat_bot_content():
+    col1, col2 = st.columns(2, gap='large')
+    with col1.container():
+        st.image(image='../assets/home_page_background_2.png',
+                 caption=None,
+                 width=None,
+                 use_column_width=True,
+                 clamp=False,
+                 channels="RGB",
+                 output_format="PNG")
+    with col2.container():
+        st.header("Chat Bot")
+        st.info("Now you can try our chatbot, built with MyScale and LangChain.")
+        st.markdown("- You need to log in. We use `user_name` to identify each customer.")
+        st.markdown("- You can upload your own PDF files and build your own knowledge base. \
+                    (This is just a sample application. Please do not upload important or confidential files.)")
+        st.markdown("- A default session will be assigned as your initial chat session. \
+                    You can create and switch to other sessions to jump between different chat conversations.")
+        add_vertical_space(1)
+        _, middle, _ = st.columns([1, 2, 1], gap='small')
+        with middle.container():
+            if USER_NAME not in st.session_state:
+                login_button(clientId=os.environ["AUTH0_CLIENT_ID"],
+                             domain=os.environ["AUTH0_DOMAIN"],
+                             key="auth0")
+            # if user_info:
+            #     user_name = user_info.get("nickname", "default") + "_" + user_info.get("email", "null")
+            #     st.session_state[USER_NAME] = user_name
+            #     print(user_info)
+
+
+def render_home_content():
+    logger.info("render home content")
+    _render_self_query_chain_content()
+    add_vertical_space(3)
+    _render_chat_bot_content()
+
+
+def render_home_footer():
+    logger.info("render home footer")
+    st.write(
+        "Please follow us on [Twitter](https://x.com/myscaledb) and [Discord](https://discord.gg/D2qpkqc4Jq)!"
+    )
+    st.write(
+        "For more details, please refer to [our repository on GitHub](https://github.com/myscale/ChatData)!")
+    st.write("Our [privacy policy](https://myscale.com/privacy/), [terms of service](https://myscale.com/terms/)")
+
+    # st.write(
+    #     "Recommended to use the standalone version of Chat-Data, "
+    #     "available [here](https://myscale-chatdata.hf.space/)."
+    # )
+
+    if st.session_state.auth0 is not None:
+        st.session_state[USER_INFO] = dict(st.session_state.auth0)
+        if 'email' in st.session_state[USER_INFO]:
+            email = st.session_state[USER_INFO]["email"]
+        else:
+            email = f"{st.session_state[USER_INFO]['nickname']}@{st.session_state[USER_INFO]['sub']}"
+        st.session_state["user_name"] = email
+        del st.session_state.auth0
+        st.rerun()
+    if st.session_state.jump_query_ask:
+        st.rerun()
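The footer's login handler derives a stable `user_name` even when the Auth0 profile lacks an email, falling back to `nickname@sub`. A small illustration of that fallback with made-up profile data:

user_info = {"nickname": "alice", "sub": "auth0|123456"}  # made-up Auth0 profile
email = user_info["email"] if "email" in user_info else f"{user_info['nickname']}@{user_info['sub']}"
print(email)  # -> alice@auth0|123456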
ui/retrievers.py
ADDED
@@ -0,0 +1,97 @@
+import streamlit as st
+from streamlit_extras.add_vertical_space import add_vertical_space
+
+from backend.constants.myscale_tables import MYSCALE_TABLES
+from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, RetrieverButtons
+from backend.retrievers.self_query import process_self_query
+from backend.retrievers.vector_sql_query import process_sql_query
+from backend.constants.variables import JUMP_QUERY_ASK, USER_NAME, USER_INFO
+
+
+def back_to_main():
+    if USER_INFO in st.session_state:
+        del st.session_state[USER_INFO]
+    if USER_NAME in st.session_state:
+        del st.session_state[USER_NAME]
+    if JUMP_QUERY_ASK in st.session_state:
+        del st.session_state[JUMP_QUERY_ASK]
+
+
+def _render_table_selector() -> str:
+    col1, col2 = st.columns(2)
+    with col1:
+        selected_table = st.selectbox(
+            label='Each public knowledge base is stored in a MyScaleDB table, which is read-only.',
+            options=MYSCALE_TABLES.keys(),
+        )
+        MYSCALE_TABLES[selected_table].hint()
+    with col2:
+        add_vertical_space(1)
+        st.info("Here is your selected public knowledge base schema in MyScaleDB",
+                icon='📚')
+        MYSCALE_TABLES[selected_table].hint_sql()
+
+    return selected_table
+
+
+def render_retrievers():
+    st.button("⬅️ Back", key="back_sql", on_click=back_to_main)
+    st.subheader('Please choose a public knowledge base to search.')
+    selected_table = _render_table_selector()
+
+    tab_sql, tab_self_query = st.tabs(
+        tabs=['Vector SQL', 'Self-querying Retriever']
+    )
+
+    with tab_sql:
+        render_tab_sql(selected_table)
+
+    with tab_self_query:
+        render_tab_self_query(selected_table)
+
+
+def render_tab_sql(selected_table: str):
+    st.warning(
+        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
+        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
+        icon="⚠️")
+    st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])
+
+    cols = st.columns([8, 3, 3, 2])
+    cols[0].text_input("Input your question:", key='query_sql')
+    with cols[1].container():
+        add_vertical_space(2)
+        st.button("Retrieve from MyScaleDB ➡️", key=RetrieverButtons.vector_sql_query_from_db)
+    with cols[2].container():
+        add_vertical_space(2)
+        st.button("Retrieve and answer with LLM ➡️", key=RetrieverButtons.vector_sql_query_with_llm)
+
+    if st.session_state[RetrieverButtons.vector_sql_query_from_db]:
+        process_sql_query(selected_table, RetrieverButtons.vector_sql_query_from_db)
+
+    if st.session_state[RetrieverButtons.vector_sql_query_with_llm]:
+        process_sql_query(selected_table, RetrieverButtons.vector_sql_query_with_llm)
+
+
+def render_tab_self_query(selected_table):
+    st.warning(
+        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
+        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
+        icon="⚠️")
+    st.dataframe(st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["metadata_columns"])
+
+    cols = st.columns([8, 3, 3, 2])
+    cols[0].text_input("Input your question:", key='query_self')
+
+    with cols[1].container():
+        add_vertical_space(2)
+        st.button("Retrieve from MyScaleDB ➡️", key='search_self')
+    with cols[2].container():
+        add_vertical_space(2)
+        st.button("Retrieve and answer with LLM ➡️", key='ask_self')
+
+    if st.session_state.search_self:
+        process_self_query(selected_table, RetrieverButtons.self_query_from_db)
+
+    if st.session_state.ask_self:
+        process_self_query(selected_table, RetrieverButtons.self_query_with_llm)
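Taken together, `render_home`, `render_retrievers`, and `chat_page` imply a small session-state dispatcher at the top level. The sketch below is a hypothetical reconstruction, not the repository's actual app.py (whose body is not shown here); the state keys mirror the constants used above:

# Hypothetical routing sketch; app.py's real logic may differ.
import streamlit as st
from ui.chat_page import chat_page
from ui.home import render_home
from ui.retrievers import render_retrievers

if "user_name" in st.session_state:
    chat_page()                      # logged-in users go to the chat bot
elif st.session_state.get("jump_query_ask"):
    render_retrievers()              # "Try sample" jumps to the public retrievers
else:
    render_home()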
ui/utils.py
ADDED
@@ -0,0 +1,18 @@
+import streamlit as st
+
+
+def display(dataframe, columns_=None, index=None):
+    if len(dataframe) > 0:
+        if index:
+            # set_index returns a new DataFrame; rebind it so the index is applied
+            dataframe = dataframe.set_index(index)
+        if columns_:
+            st.dataframe(dataframe[columns_])
+        else:
+            st.dataframe(dataframe)
+    else:
+        st.write(
+            "Sorry 😵 we didn't find any articles related to your query.\n\n"
+            "Maybe the LLM was too naughty and did not follow our instructions... \n\n"
+            "Please try again and use verbs that may match the datatype.",
+            unsafe_allow_html=True
+        )
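A short usage sketch for `display` (the DataFrame contents are made up; with the `set_index` rebinding above, the `index` column is applied before rendering):

import pandas as pd

df = pd.DataFrame({"id": [1, 2], "title": ["Paper A", "Paper B"], "dist": [0.12, 0.34]})
display(df, columns_=["title", "dist"], index="id")  # shows only the chosen columns
display(df.iloc[0:0])                                # empty result -> apology message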