Spaces:
Running
Running
import time | |
from os import environ | |
from time import sleep | |
import streamlit as st | |
from backend.constants.prompts import DEFAULT_SYSTEM_PROMPT | |
from backend.constants.streamlit_keys import CHAT_KNOWLEDGE_TABLE, CHAT_SESSION_MANAGER, \ | |
CHAT_CURRENT_USER_SESSIONS, EL_SESSION_SELECTOR, USER_PRIVATE_FILES, \ | |
EL_BUILD_KB_WITH_FILES, \ | |
EL_PERSONAL_KB_NAME, EL_PERSONAL_KB_DESCRIPTION, \ | |
USER_PERSONAL_KNOWLEDGE_BASES, AVAILABLE_RETRIEVAL_TOOLS, EL_PERSONAL_KB_NEEDS_REMOVE, \ | |
EL_UPLOAD_FILES_STATUS, EL_SELECTED_KBS, EL_UPLOAD_FILES | |
from backend.constants.variables import USER_INFO, USER_NAME, JUMP_QUERY_ASK, RETRIEVER_TOOLS | |
from backend.construct.build_agents import build_agents | |
from backend.chat_bot.session_manager import SessionManager | |
from backend.callbacks.arxiv_callbacks import ChatDataAgentCallBackHandler | |
from logger import logger | |
environ["OPENAI_API_BASE"] = st.secrets["OPENAI_API_BASE"] | |
TOOL_NAMES = { | |
"langchain_retriever_tool": "Self-querying retriever", | |
"vecsql_retriever_tool": "Vector SQL", | |
} | |
def on_chat_submit(): | |
with st.session_state.next_round.container(): | |
with st.chat_message("user"): | |
st.write(st.session_state.chat_input) | |
with st.chat_message("assistant"): | |
container = st.container() | |
st_callback = ChatDataAgentCallBackHandler( | |
container, collapse_completed_thoughts=False | |
) | |
ret = st.session_state.agent( | |
{"input": st.session_state.chat_input}, callbacks=[st_callback] | |
) | |
logger.info(f"ret:{ret}") | |
def clear_history(): | |
if "agent" in st.session_state: | |
st.session_state.agent.memory.clear() | |
def back_to_main(): | |
if USER_INFO in st.session_state: | |
del st.session_state[USER_INFO] | |
if USER_NAME in st.session_state: | |
del st.session_state[USER_NAME] | |
if JUMP_QUERY_ASK in st.session_state: | |
del st.session_state[JUMP_QUERY_ASK] | |
if EL_SESSION_SELECTOR in st.session_state: | |
del st.session_state[EL_SESSION_SELECTOR] | |
if CHAT_CURRENT_USER_SESSIONS in st.session_state: | |
del st.session_state[CHAT_CURRENT_USER_SESSIONS] | |
def refresh_sessions(): | |
chat_session_manager: SessionManager = st.session_state[CHAT_SESSION_MANAGER] | |
current_user_name = st.session_state[USER_NAME] | |
current_user_sessions = chat_session_manager.list_sessions(current_user_name) | |
if not isinstance(current_user_sessions, dict) or not current_user_sessions: | |
# generate a default session for current user. | |
chat_session_manager.add_session( | |
user_id=current_user_name, | |
session_id=f"{current_user_name}?default", | |
system_prompt=DEFAULT_SYSTEM_PROMPT, | |
) | |
st.session_state[CHAT_CURRENT_USER_SESSIONS] = chat_session_manager.list_sessions(current_user_name) | |
current_user_sessions = st.session_state[CHAT_CURRENT_USER_SESSIONS] | |
else: | |
st.session_state[CHAT_CURRENT_USER_SESSIONS] = current_user_sessions | |
# load current user files. | |
st.session_state[USER_PRIVATE_FILES] = st.session_state[CHAT_KNOWLEDGE_TABLE].list_files( | |
current_user_name | |
) | |
# load current user private knowledge bases. | |
st.session_state[USER_PERSONAL_KNOWLEDGE_BASES] = \ | |
st.session_state[CHAT_KNOWLEDGE_TABLE].list_private_knowledge_bases(current_user_name) | |
logger.info(f"current user name: {current_user_name}, " | |
f"user private knowledge bases: {st.session_state[USER_PERSONAL_KNOWLEDGE_BASES]}, " | |
f"user private files: {st.session_state[USER_PRIVATE_FILES]}") | |
st.session_state[AVAILABLE_RETRIEVAL_TOOLS] = { | |
# public retrieval tools | |
**st.session_state[RETRIEVER_TOOLS], | |
# private retrieval tools | |
**st.session_state[CHAT_KNOWLEDGE_TABLE].as_retrieval_tools(current_user_name), | |
} | |
# print(f"sel_session is {st.session_state.sel_session}, current_user_sessions is {current_user_sessions}") | |
print(f"current_user_sessions is {current_user_sessions}") | |
st.session_state[EL_SESSION_SELECTOR] = current_user_sessions[0] | |
# process for session add and delete. | |
def on_session_change_submit(): | |
if "session_manager" in st.session_state and "session_editor" in st.session_state: | |
try: | |
for elem in st.session_state.session_editor["added_rows"]: | |
if len(elem) > 0 and "system_prompt" in elem and "session_id" in elem: | |
if elem["session_id"] != "" and "?" not in elem["session_id"]: | |
st.session_state.session_manager.add_session( | |
user_id=st.session_state.user_name, | |
session_id=f"{st.session_state.user_name}?{elem['session_id']}", | |
system_prompt=elem["system_prompt"], | |
) | |
else: | |
st.toast("`session_id` shouldn't be neither empty nor contain char `?`.", icon="β") | |
raise KeyError( | |
"`session_id` shouldn't be neither empty nor contain char `?`." | |
) | |
else: | |
st.toast("`You should fill both `session_id` and `system_prompt` to add a column!", icon="β") | |
raise KeyError( | |
"You should fill both `session_id` and `system_prompt` to add a column!" | |
) | |
for elem in st.session_state.session_editor["deleted_rows"]: | |
user_name = st.session_state[USER_NAME] | |
session_id = st.session_state[CHAT_CURRENT_USER_SESSIONS][elem]['session_id'] | |
user_with_session_id = f"{user_name}?{session_id}" | |
st.session_state.session_manager.remove_session(session_id=user_with_session_id) | |
st.toast(f"session `{user_with_session_id}` removed.", icon="β ") | |
refresh_sessions() | |
except Exception as e: | |
sleep(2) | |
st.error(f"{type(e)}: {str(e)}") | |
finally: | |
st.session_state.session_editor["added_rows"] = [] | |
st.session_state.session_editor["deleted_rows"] = [] | |
refresh_agent() | |
def create_private_knowledge_base_as_tool(): | |
current_user_name = st.session_state[USER_NAME] | |
if ( | |
EL_PERSONAL_KB_NAME in st.session_state | |
and EL_PERSONAL_KB_DESCRIPTION in st.session_state | |
and EL_BUILD_KB_WITH_FILES in st.session_state | |
and len(st.session_state[EL_PERSONAL_KB_NAME]) > 0 | |
and len(st.session_state[EL_PERSONAL_KB_DESCRIPTION]) > 0 | |
and len(st.session_state[EL_BUILD_KB_WITH_FILES]) > 0 | |
): | |
st.session_state[CHAT_KNOWLEDGE_TABLE].create_private_knowledge_base( | |
user_id=current_user_name, | |
tool_name=st.session_state[EL_PERSONAL_KB_NAME], | |
tool_description=st.session_state[EL_PERSONAL_KB_DESCRIPTION], | |
files=[f["file_name"] for f in st.session_state[EL_BUILD_KB_WITH_FILES]], | |
) | |
refresh_sessions() | |
else: | |
st.session_state[EL_UPLOAD_FILES_STATUS].error( | |
"You should fill all fields to build up a tool!" | |
) | |
sleep(2) | |
def remove_private_knowledge_bases(): | |
if EL_PERSONAL_KB_NEEDS_REMOVE in st.session_state and st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE]: | |
private_knowledge_bases_needs_remove = st.session_state[EL_PERSONAL_KB_NEEDS_REMOVE] | |
private_knowledge_base_names = [item["tool_name"] for item in private_knowledge_bases_needs_remove] | |
# remove these private knowledge bases. | |
st.session_state[CHAT_KNOWLEDGE_TABLE].remove_private_knowledge_bases( | |
user_id=st.session_state[USER_NAME], | |
private_knowledge_bases=private_knowledge_base_names | |
) | |
refresh_sessions() | |
else: | |
st.session_state[EL_UPLOAD_FILES_STATUS].error( | |
"You should specify at least one private knowledge base to delete!" | |
) | |
time.sleep(2) | |
def refresh_agent(): | |
with st.spinner("Initializing session..."): | |
user_name = st.session_state[USER_NAME] | |
session_id = st.session_state[EL_SESSION_SELECTOR]['session_id'] | |
user_with_session_id = f"{user_name}?{session_id}" | |
if EL_SELECTED_KBS in st.session_state: | |
selected_knowledge_bases = st.session_state[EL_SELECTED_KBS] | |
else: | |
selected_knowledge_bases = ["Wikipedia + Vector SQL"] | |
logger.info(f"selected_knowledge_bases: {selected_knowledge_bases}") | |
if EL_SESSION_SELECTOR in st.session_state: | |
system_prompt = st.session_state[EL_SESSION_SELECTOR]["system_prompt"] | |
else: | |
system_prompt = DEFAULT_SYSTEM_PROMPT | |
st.session_state["agent"] = build_agents( | |
session_id=user_with_session_id, | |
tool_names=selected_knowledge_bases, | |
system_prompt=system_prompt | |
) | |
def add_file(): | |
user_name = st.session_state[USER_NAME] | |
if EL_UPLOAD_FILES not in st.session_state or len(st.session_state[EL_UPLOAD_FILES]) == 0: | |
st.session_state[EL_UPLOAD_FILES_STATUS].error("Please upload files!", icon="β οΈ") | |
sleep(2) | |
return | |
try: | |
st.session_state[EL_UPLOAD_FILES_STATUS].info("Uploading...") | |
st.session_state[CHAT_KNOWLEDGE_TABLE].add_by_file( | |
user_id=user_name, | |
files=st.session_state[EL_UPLOAD_FILES] | |
) | |
refresh_sessions() | |
except ValueError as e: | |
st.session_state[EL_UPLOAD_FILES_STATUS].error("Failed to upload! " + str(e)) | |
sleep(2) | |
def clear_files(): | |
st.session_state[CHAT_KNOWLEDGE_TABLE].clear(user_id=st.session_state[USER_NAME]) | |
refresh_sessions() | |