# ChatData / chat.py
# Last synced by lqhl via the 'sync_with_huggingface' GitHub Action
# (commit 0e573d0, verified).
import json
import pandas as pd
from os import environ
from time import sleep
import datetime
import streamlit as st
from lib.sessions import SessionManager
from lib.private_kb import PrivateKnowledgeBase
from langchain.schema import HumanMessage, FunctionMessage
from callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
from lib.json_conv import CustomJSONDecoder
from lib.helper import (
build_agents,
MYSCALE_HOST,
MYSCALE_PASSWORD,
MYSCALE_PORT,
MYSCALE_USER,
DEFAULT_SYSTEM_PROMPT,
UNSTRUCTURED_API,
)
from login import back_to_main
# Point the OpenAI client at the API base configured in Streamlit secrets.
environ.update(OPENAI_API_BASE=st.secrets["OPENAI_API_BASE"])
# Internal retriever tool identifiers mapped to their human-readable labels.
TOOL_NAMES = dict(
    langchain_retriever_tool="Self-querying retriever",
    vecsql_retriever_tool="Vector SQL",
)
def on_chat_submit():
    """Handle a message submitted through ``st.chat_input``.

    Inside the ``next_round`` placeholder, echo the submitted text as a user
    bubble. Then run ``st.session_state.agent`` on it inside an assistant
    bubble, with a ``ChatDataAgentCallBackHandler`` writing into a fresh
    container. The agent's return value is printed to stdout.
    """
    state = st.session_state
    with state.next_round.container():
        prompt = state.chat_input
        with st.chat_message("user"):
            st.write(prompt)
        with st.chat_message("assistant"):
            handler = ChatDataAgentCallBackHandler(
                st.container(), collapse_completed_thoughts=False
            )
            result = state.agent({"input": prompt}, callbacks=[handler])
    print(result)
def clear_history():
    """Clear the agent's conversation memory, if an agent has been built."""
    if "agent" not in st.session_state:
        return
    st.session_state["agent"].memory.clear()
def back_to_main():
    """Log the current user out by dropping their per-user session state.

    Removes each key below from ``st.session_state`` when present.

    NOTE: this definition shadows the ``back_to_main`` imported from ``login``
    at the top of the file.
    """
    for key in (
        "user_info",
        "user_name",
        "jump_query_ask",
        "sel_sess",
        "current_sessions",
        # The agent is built for "<user_name>?<session_id>" and carries that
        # chat's memory. chat_page only rebuilds it when the key is missing,
        # so keeping it would hand the previous user's conversation to
        # whoever logs in next in this browser session.
        "agent",
    ):
        if key in st.session_state:
            del st.session_state[key]
def on_session_change_submit():
    """Apply the rows added/deleted in the "Session Management" data editor.

    Callback for the "Submit Change!" button. Reads the pending
    ``added_rows`` / ``deleted_rows`` from ``st.session_state.session_editor``.

    Every added row is validated before anything is written, so one bad row
    cannot leave earlier rows half-applied. Valid rows are stored through the
    session manager as ``"<user_name>?<session_id>"``. Deleted rows are
    resolved by position against ``current_sessions`` and removed.

    On success the cached session list is reloaded. On failure the error is
    shown with ``st.error`` and the callback pauses for two seconds. In every
    case the editor's pending added/deleted rows are cleared and the agent is
    rebuilt.

    NOTE(review): ``edited_rows`` (in-place edits such as a changed system
    prompt) are not applied here, although the sidebar invites prompt edits.
    """
    if (
        "session_manager" not in st.session_state
        or "session_editor" not in st.session_state
    ):
        return
    editor = st.session_state.session_editor
    print(editor)
    try:
        # Validate first: nothing is written unless every added row is valid.
        new_sessions = []
        for row in editor["added_rows"]:
            if not row or "system_prompt" not in row or "session_id" not in row:
                raise ValueError(
                    "You should fill both `session_id` and `system_prompt` to add a column!"
                )
            session_id = row["session_id"]
            # An empty editor cell may arrive as None; treat it like "".
            # "?" separates user name and session id in the stored key.
            if not session_id or "?" in session_id:
                raise ValueError(
                    "`session_id` should neither be empty nor contain question marks."
                )
            new_sessions.append((session_id, row["system_prompt"]))

        manager = st.session_state.session_manager
        user_name = st.session_state.user_name
        for session_id, system_prompt in new_sessions:
            manager.add_session(
                user_id=user_name,
                session_id=f"{user_name}?{session_id}",
                system_prompt=system_prompt,
            )
        # deleted_rows holds row positions in the list the editor displayed,
        # i.e. current_sessions before this refresh.
        for row_index in editor["deleted_rows"]:
            removed_id = st.session_state.current_sessions[row_index]["session_id"]
            manager.remove_session(session_id=f"{user_name}?{removed_id}")
        refresh_sessions()
    except Exception as e:
        st.error(f"{type(e)}: {str(e)}")
        sleep(2)
    finally:
        editor["added_rows"] = []
        editor["deleted_rows"] = []
    refresh_agent()
def build_session_manager():
    """Create a ``SessionManager`` bound to ``st.session_state``.

    The manager is constructed with the MyScale host, port and credentials
    imported from ``lib.helper``.
    """
    credentials = {
        "host": MYSCALE_HOST,
        "port": MYSCALE_PORT,
        "username": MYSCALE_USER,
        "password": MYSCALE_PASSWORD,
    }
    return SessionManager(st.session_state, **credentials)
def refresh_sessions():
    """Reload the user's sessions, files and tools into ``st.session_state``.

    Sets the following keys:

    - ``current_sessions``: the user's sessions. If the user has none, a
      ``"default"`` session with ``DEFAULT_SYSTEM_PROMPT`` is created first.
    - ``user_files`` and ``user_tools``: from the private knowledge base.
    - ``tools_with_users``: the shared ``tools`` merged with the user's
      private knowledge-base tools.
    - ``sel_sess``: the refreshed row for the currently selected session, or
      for ``"default"`` when nothing is selected. Falls back to the first row
      when that session no longer exists.
    """
    state = st.session_state
    user_name = state.user_name
    manager = state.session_manager

    state["current_sessions"] = manager.list_sessions(user_name)
    # NOTE(review): the dict check is kept from the original; list_sessions
    # presumably returns a list of row dicts — confirm.
    if type(state.current_sessions) is not dict and len(state.current_sessions) <= 0:
        manager.add_session(user_name, f"{user_name}?default", DEFAULT_SYSTEM_PROMPT)
        state["current_sessions"] = manager.list_sessions(user_name)

    state["user_files"] = state.private_kb.list_files(user_name)
    state["user_tools"] = state.private_kb.list_tools(user_name)
    state["tools_with_users"] = {
        **state.tools,
        **state.private_kb.as_tools(user_name),
    }

    # Keep the user's current selection across refreshes. The old lookup
    # tested `"" not in st.session_state` (always true) and read a
    # non-existent `sel_session`, so every refresh jumped back to "default".
    wanted = state.sel_sess["session_id"] if "sel_sess" in state else "default"
    session_ids = [session["session_id"] for session in state.current_sessions]
    position = session_ids.index(wanted) if wanted in session_ids else 0
    state.sel_sess = state.current_sessions[position]
def build_kb_as_tool():
    """Create a private knowledge-base tool from the "Build!" form fields.

    Reads the ``b_tool_name``, ``b_tool_desc`` and ``b_tool_files`` widget
    values. If any is missing or empty, an error is written to the
    ``tool_status`` placeholder and the callback pauses two seconds.
    Otherwise the tool is created from the selected files' names and the
    cached session data is reloaded.
    """
    state = st.session_state
    fields = ("b_tool_name", "b_tool_desc", "b_tool_files")
    complete = all(key in state for key in fields) and all(
        len(state[key]) > 0 for key in fields
    )
    if not complete:
        state.tool_status.error("You should fill all fields to build up a tool!")
        sleep(2)
        return
    state.private_kb.create_tool(
        state.user_name,
        state.b_tool_name,
        state.b_tool_desc,
        [entry["file_name"] for entry in state.b_tool_files],
    )
    refresh_sessions()
def remove_kb():
    """Delete the knowledge-base tools chosen in the "Delete" multiselect.

    If no tool is selected, an error is written to the ``tool_status``
    placeholder and the callback pauses two seconds. Otherwise the selected
    tools are removed by name and the cached session data is reloaded.
    """
    state = st.session_state
    if "r_tool_names" not in state or len(state.r_tool_names) <= 0:
        state.tool_status.error("You should specify at least one tool to delete!")
        sleep(2)
        return
    doomed = [tool["tool_name"] for tool in state.r_tool_names]
    state.private_kb.remove_tools(state.user_name, doomed)
    refresh_sessions()
def refresh_agent():
    """(Re)build ``st.session_state.agent`` for the selected chat session.

    The agent is keyed by ``"<user_name>?<session_id>"``. It uses the tools
    chosen in the ``selected_tools`` widget and the selected session's
    system prompt.

    If no session is selected yet, it falls back to the ``"default"`` session
    with ``DEFAULT_SYSTEM_PROMPT``. Previously that fallback was unreachable,
    because ``sel_sess`` was dereferenced before the check.
    """
    session = st.session_state.get(
        "sel_sess",
        {"session_id": "default", "system_prompt": DEFAULT_SYSTEM_PROMPT},
    )
    full_session_id = f"{st.session_state.user_name}?{session['session_id']}"
    # NOTE(review): this fallback tool name differs from the tool multiselect's
    # default in chat_page ("Wikipedia + Self Querying"); confirm build_agents
    # accepts it.
    tools = st.session_state.get(
        "selected_tools", ["LangChain Self Query Retriever For Wikipedia"]
    )
    with st.spinner("Initializing session..."):
        print("??? Changed to ", full_session_id)
        st.session_state["agent"] = build_agents(
            full_session_id,
            tools,
            system_prompt=session["system_prompt"],
        )
def add_file():
    """Upload the files from the "Upload files" widget into the private KB.

    If nothing was uploaded, an error is written to the ``tool_status``
    placeholder and the callback pauses two seconds. Otherwise it shows
    "Uploading...", adds the files for the current user and reloads the
    cached session data. A ``ValueError`` raised during that step is reported
    in ``tool_status``, followed by a two-second pause.
    """
    state = st.session_state
    has_files = "uploaded_files" in state and len(state.uploaded_files) != 0
    if not has_files:
        state.tool_status.error("Please upload files!", icon="⚠️")
        sleep(2)
        return
    try:
        state.tool_status.info("Uploading...")
        state.private_kb.add_by_file(state.user_name, state.uploaded_files)
        refresh_sessions()
    except ValueError as err:
        state.tool_status.error(f"Failed to upload! {err}")
        sleep(2)
def clear_files():
    """Clear the current user's private knowledge base, then reload cached state.

    Wired to the "Clear Files and All Tools" button.
    """
    knowledge_base = st.session_state.private_kb
    knowledge_base.clear(st.session_state.user_name)
    refresh_sessions()
def _write_timestamp(msg):
    """Write ``msg.additional_kwargs["timestamp"]`` (POSIX seconds) as an italic ISO line.

    Messages without a timestamp are shown without one instead of raising ``KeyError``.
    """
    timestamp = msg.additional_kwargs.get("timestamp")
    if timestamp is not None:
        st.write(f"*{datetime.datetime.fromtimestamp(timestamp).isoformat()}*")


def _render_chat_history(messages):
    """Replay stored chat messages; function (knowledge-base) results are shown as tables."""
    for msg in messages:
        if isinstance(msg, FunctionMessage):
            with st.chat_message("Knowledge Base", avatar="πŸ“–"):
                _write_timestamp(msg)
                st.write("Retrieved from knowledge base:")
                try:
                    st.dataframe(
                        pd.DataFrame.from_records(
                            json.loads(msg.content, cls=CustomJSONDecoder)
                        ),
                        use_container_width=True,
                    )
                except Exception:
                    # Content is not a list of JSON records: show it raw.
                    # `Exception` rather than a bare `except:` so Streamlit's
                    # BaseException-based stop/rerun signals still propagate.
                    st.write(msg.content)
        elif len(msg.content) > 0:
            speaker = "user" if isinstance(msg, HumanMessage) else "assistant"
            with st.chat_message(speaker):
                print(type(msg), msg.dict())
                _write_timestamp(msg)
                st.write(f"{msg.content}")


def _render_chat_sidebar():
    """Draw the sidebar: session management/selection, tool settings and chat controls."""
    with st.sidebar:
        with st.expander("Session Management"):
            if "current_sessions" not in st.session_state:
                refresh_sessions()
            # NOTE(review): several icon literals on this page (e.g. "πŸ€–")
            # look like mis-decoded UTF-8 emoji; Streamlit expects a single
            # emoji for `icon=` — confirm against the upstream file.
            st.info(
                "Here you can set up your session! \n\nYou can **change your prompt** here!",
                icon="πŸ€–",
            )
            st.info(
                (
                    "**Add columns by clicking the empty row**.\n"
                    "And **delete columns by selecting rows with a press on `DEL` Key**"
                ),
                icon="πŸ’‘",
            )
            st.info(
                "Don't forget to **click `Submit Change` to save your change**!",
                icon="πŸ“’",
            )
            st.data_editor(
                st.session_state.current_sessions,
                num_rows="dynamic",
                key="session_editor",
                use_container_width=True,
            )
            st.button("Submit Change!", on_click=on_session_change_submit)
        with st.expander("Session Selection", expanded=True):
            st.info(
                "If no session is attach to your account, then we will add a default session to you!",
                icon="❀️",
            )
            # Pre-select the currently selected session (or "default"). The old
            # check `"" not in st.session_state` was always true and the
            # fallback read a non-existent `sel_session` key.
            try:
                selected_id = (
                    st.session_state.sel_sess["session_id"]
                    if "sel_sess" in st.session_state
                    else "default"
                )
                dfl_indx = [
                    x["session_id"] for x in st.session_state.current_sessions
                ].index(selected_id)
            except Exception as e:
                print("*** ", str(e))
                dfl_indx = 0
            st.selectbox(
                "Choose a session to chat:",
                options=st.session_state.current_sessions,
                index=dfl_indx,
                key="sel_sess",
                format_func=lambda x: x["session_id"],
                on_change=refresh_agent,
            )
            print(st.session_state.sel_sess)
        with st.expander("Tool Settings", expanded=True):
            st.info(
                "We provides you several knowledge base tools for you. We are building more tools!",
                icon="πŸ”§",
            )
            # Placeholder the tool/file callbacks write status messages into.
            st.session_state["tool_status"] = st.empty()
            tab_kb, tab_file = st.tabs(
                [
                    "Knowledge Bases",
                    "File Upload",
                ]
            )
            with tab_kb:
                st.markdown("#### Build You Own Knowledge")
                st.multiselect(
                    "Select Files to Build up",
                    st.session_state.user_files,
                    placeholder="You should upload files first",
                    key="b_tool_files",
                    format_func=lambda x: x["file_name"],
                )
                st.text_input("Tool Name", "get_relevant_documents", key="b_tool_name")
                st.text_input(
                    "Tool Description",
                    "Searches among user's private files and returns related documents",
                    key="b_tool_desc",
                )
                st.button("Build!", on_click=build_kb_as_tool)
                st.markdown("### Knowledge Base Selection")
                if (
                    "user_tools" in st.session_state
                    and len(st.session_state.user_tools) > 0
                ):
                    st.markdown("***User Created Knowledge Bases***")
                    st.dataframe(st.session_state.user_tools)
                st.multiselect(
                    "Select a Knowledge Base Tool",
                    st.session_state.tools.keys()
                    if "tools_with_users" not in st.session_state
                    else st.session_state.tools_with_users,
                    default=["Wikipedia + Self Querying"],
                    key="selected_tools",
                    on_change=refresh_agent,
                )
                st.markdown("### Delete Knowledge Base")
                st.multiselect(
                    "Choose Knowledge Base to Remove",
                    st.session_state.user_tools,
                    format_func=lambda x: x["tool_name"],
                    key="r_tool_names",
                )
                st.button("Delete", on_click=remove_kb)
            with tab_file:
                st.info(
                    (
                        "We adopted [Unstructured API](https://unstructured.io/api-key) "
                        "here and we only store the processed texts from your documents. "
                        "For privacy concerns, please refer to "
                        "[our policy issue](https://myscale.com/privacy/)."
                    ),
                    icon="πŸ“ƒ",
                )
                st.file_uploader(
                    "Upload files", key="uploaded_files", accept_multiple_files=True
                )
                st.markdown("### Uploaded Files")
                st.dataframe(
                    st.session_state.private_kb.list_files(st.session_state.user_name),
                    use_container_width=True,
                )
                col_1, col_2 = st.columns(2)
                with col_1:
                    st.button("Add Files", on_click=add_file)
                with col_2:
                    st.button("Clear Files and All Tools", on_click=clear_files)
        st.button("Clear Chat History", on_click=clear_history)
        st.button("Logout", on_click=back_to_main)


def chat_page():
    """Render the chat page: sidebar controls, chat history and the input box.

    Steps:

    1. Lazily create the per-browser-session state: the default selected
       session, the private knowledge base and the session manager.
    2. Draw the sidebar.
    3. Build the agent if none exists, then replay its memory.
    4. Wire ``st.chat_input`` to ``on_chat_submit``.
    """
    if "sel_sess" not in st.session_state:
        st.session_state["sel_sess"] = {
            "session_id": "default",
            "system_prompt": DEFAULT_SYSTEM_PROMPT,
        }
    if "private_kb" not in st.session_state:
        st.session_state["private_kb"] = PrivateKnowledgeBase(
            host=MYSCALE_HOST,
            port=MYSCALE_PORT,
            username=MYSCALE_USER,
            password=MYSCALE_PASSWORD,
            embedding=st.session_state.embeddings["Wikipedia"],
            parser_api_key=UNSTRUCTURED_API,
        )
    if "session_manager" not in st.session_state:
        st.session_state["session_manager"] = build_session_manager()

    _render_chat_sidebar()

    if "agent" not in st.session_state:
        refresh_agent()
    print("!!! ", st.session_state.agent.memory.chat_memory.session_id)
    _render_chat_history(st.session_state.agent.memory.chat_memory.messages)
    # Placeholder that on_chat_submit fills with the next user/assistant turn.
    st.session_state["next_round"] = st.empty()
    st.chat_input("Input Message", on_submit=on_chat_submit, key="chat_input")