lqhl's picture
Synced repo using 'sync_with_huggingface' Github Action
e931b70 verified
import time
from os import environ
from time import sleep
import streamlit as st
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT
from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION_MANAGER, \
CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, USER_PRIVATE_FILES, \
EL_BUILD_KB_WITH_FILES, \
EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \
USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \
EL_UPLOAD_FILES_STATUS, EL_SELECTED_KBS, EL_UPLOAD_FILES
from backend.constants.variables import USER_INFO, USER_NAME, JUMP_QUERY_ASK, RETRIEVER_TOOLS
from backend.construct.build_agents import build_agents
from backend.chat_bot.session_manager import SessionManager
from backend.callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler
from logger import logger
environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"]
TOOL_NAMES = {
"langchain_retriever_tool": "Self-querying retriever",
"vecsql_retriever_tool": "Vector SQL",
}
def on_chat_submit():
with st.session_state.next_round.container():
with st.chat_message("user"):
st.write(st.session_state.chat_input)
with st.chat_message("assistant"):
container = st.container()
st_callback = ChatDataAgentCallBackHandler(
container, collapse_completed_thoughts=False
)
ret = st.session_state.agent(
{"input": st.session_state.chat_input}, callbacks=[st_callback]
)
logger.info(f"ret:{ret}")
def clear_history():
if "agent" in st.session_state:
st.session_state.agent.memory.clear()
def back_to_main():
if USER_INFO in st.session_state:
del st.session_state[USER_INFO]
if USER_NAME in st.session_state:
del st.session_state[USER_NAME]
if JUMP_QUERY_ASK in st.session_state:
del st.session_state[JUMP_QUERY_ASK]
if EL_SESSION_SELECTOR in st.session_state:
del st.session_state[EL_SESSION_SELECTOR]
if CHAT_CURRENT_USER_SESSIONS in st.session_state:
del st.session_state[CHAT_CURRENT_USER_SESSIONS]
def refresh_sessions():
chat_session_manager: SessionManager = st.session_state[CHAT_SESSION_MANAGER]
current_user_name = st.session_state[USER_NAME]
current_user_sessions = chat_session_manager.list_sessions(current_user_name)
if not isinstance(current_user_sessions, dict) or not current_user_sessions:
# generate a default session for current user.
chat_session_manager.add_session(
user_id=current_user_name,
session_id=f"{current_user_name}?default",
system_prompt=DEFAULT_SYSTEM_PROMPT,
)
st.session_state[CHAT_CURRENT_USER_SESSIONS] = chat_session_manager.list_sessions(current_user_name)
current_user_sessions = st.session_state[CHAT_CURRENT_USER_SESSIONS]
else:
st.session_state[CHAT_CURRENT_USER_SESSIONS] = current_user_sessions
# load current user files.
st.session_state[USER_PRIVATE_FILES] = st.session_state[CHAT_KNOWLEDGE_TABLE].list_files(
current_user_name
)
# load current user private knowledge bases.
st.session_state[USER_PERSONAL_KNOWLEDGE_BASES] = \
st.session_state[CHAT_KNOWLEDGE_TABLE].list_private_knowledge_bases(current_user_name)
logger.info(f"current user name: {current_user_name}, "
f"user private knowledge bases: {st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]}, "
f"user private files: {st.session_state[USER_PRIVATE_FILES]}")
st.session_state[AVAILABLE_RETRIEVAL_TOOLS] = {
# public retrieval tools
**st.session_state[RETRIEVER_TOOLS],
# private retrieval tools
**st.session_state[CHAT_KNOWLEDGE_TABLE].as_retrieval_tools(current_user_name),
}
# print(f"sel_session is {st.session_state.sel_session}, current_user_sessions is {current_user_sessions}")
print(f"current_user_sessions is {current_user_sessions}")
st.session_state[EL_SESSION_SELECTOR] = current_user_sessions[0]
# process for session add and delete.
def on_session_change_submit():
if "session_manager" in st.session_state and "session_editor" in st.session_state:
try:
for elem in st.session_state.session_editor["added_rows"]:
if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem:
if elem["session_id"] != "" and "?" not in elem["session_id"]:
st.session_state.session_manager.add_session(
user_id=st.session_state.user_name,
session_id=f"{st.session_state.user_name}?{elem['session_id']}",
system_prompt=elem["system_prompt"],
)
else:
st.toast("`session_id` shouldn't be neither empty nor contain char `?`.", icon="❌")
raise KeyError(
"`session_id` shouldn't be neither empty nor contain char `?`."
)
else:
st.toast("`You should fill both `session_id` and `system_prompt` to add a column!", icon="❌")
raise KeyError(
"You should fill both `session_id` and `system_prompt` to add a column!"
)
for elem in st.session_state.session_editor["deleted_rows"]:
user_name = st.session_state[USER_NAME]
session_id = st.session_state[CHAT_CURRENT_USER_SESSIONS][elem]['session_id']
user_with_session_id = f"{user_name}?{session_id}"
st.session_state.session_manager.remove_session(session_id=user_with_session_id)
st.toast(f"session `{user_with_session_id}` removed.", icon="βœ…")
refresh_sessions()
except Exception as e:
sleep(2)
st.error(f"{type(e)}: {str(e)}")
finally:
st.session_state.session_editor["added_rows"] = []
st.session_state.session_editor["deleted_rows"] = []
refresh_agent()
def create_private_knowledge_base_as_tool():
current_user_name = st.session_state[USER_NAME]
if (
EL_PERSONAL_KB_NAME in st.session_state
and EL_PERSONAL_KB_DESCRIPTION in st.session_state
and EL_BUILD_KB_WITH_FILES in st.session_state
and len(st.session_state[EL_PERSONAL_KB_NAME]) > 0
and len(st.session_state[EL_PERSONAL_KB_DESCRIPTION]) > 0
and len(st.session_state[EL_BUILD_KB_WITH_FILES]) > 0
):
st.session_state[CHAT_KNOWLEDGE_TABLE].create_private_knowledge_base(
user_id=current_user_name,
tool_name=st.session_state[EL_PERSONAL_KB_NAME],
tool_description=st.session_state[EL_PERSONAL_KB_DESCRIPTION],
files=[f["file_name"] for f in st.session_state[EL_BUILD_KB_WITH_FILES]],
)
refresh_sessions()
else:
st.session_state[EL_UPLOAD_FILES_STATUS].error(
"You should fill all fields to build up a tool!"
)
sleep(2)
def remove_private_knowledge_bases():
if EL_PERSONAL_KB_NEEDS_REMOVE in st.session_state and st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]:
private_knowledge_bases_needs_remove = st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]
private_knowledge_base_names = [item["tool_name"] for item in private_knowledge_bases_needs_remove]
# remove these private knowledge bases.
st.session_state[CHAT_KNOWLEDGE_TABLE].remove_private_knowledge_bases(
user_id=st.session_state[USER_NAME],
private_knowledge_bases=private_knowledge_base_names
)
refresh_sessions()
else:
st.session_state[EL_UPLOAD_FILES_STATUS].error(
"You should specify at least one private knowledge base to delete!"
)
time.sleep(2)
def refresh_agent():
with st.spinner("Initializing session..."):
user_name = st.session_state[USER_NAME]
session_id = st.session_state[EL_SESSION_SELECTOR]['session_id']
user_with_session_id = f"{user_name}?{session_id}"
if EL_SELECTED_KBS in st.session_state:
selected_knowledge_bases = st.session_state[EL_SELECTED_KBS]
else:
selected_knowledge_bases = ["Wikipedia + Vector SQL"]
logger.info(f"selected_knowledge_bases: {selected_knowledge_bases}")
if EL_SESSION_SELECTOR in st.session_state:
system_prompt = st.session_state[EL_SESSION_SELECTOR]["system_prompt"]
else:
system_prompt = DEFAULT_SYSTEM_PROMPT
st.session_state["agent"] = build_agents(
session_id=user_with_session_id,
tool_names=selected_knowledge_bases,
system_prompt=system_prompt
)
def add_file():
user_name = st.session_state[USER_NAME]
if EL_UPLOAD_FILES not in st.session_state or len(st.session_state[EL_UPLOAD_FILES]) == 0:
st.session_state[EL_UPLOAD_FILES_STATUS].error("Please upload files!", icon="⚠️")
sleep(2)
return
try:
st.session_state[EL_UPLOAD_FILES_STATUS].info("Uploading...")
st.session_state[CHAT_KNOWLEDGE_TABLE].add_by_file(
user_id=user_name,
files=st.session_state[EL_UPLOAD_FILES]
)
refresh_sessions()
except ValueError as e:
st.session_state[EL_UPLOAD_FILES_STATUS].error("Failed to upload! " + str(e))
sleep(2)
def clear_files():
st.session_state[CHAT_KNOWLEDGE_TABLE].clear(user_id=st.session_state[USER_NAME])
refresh_sessions()