Spaces:
Running
Running
File size: 3,480 Bytes
e931b70 a796108 e931b70 fab8405 e931b70 fab8405 e931b70 e4853cf a796108 e931b70 45180a0 0e573d0 e931b70 9061790 e931b70 9061790 e931b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import os
import time
import streamlit as st
from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \
DATA_INITIALIZE_STARTED
from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \
TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config
from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools
from backend.types.global_config import GlobalConfig
from logger import logger
from ui.chat_page import chat_page
from ui.home import render_home
from ui.retrievers import render_retrievers
# warnings.filterwarnings("ignore", category=UserWarning)
def prepare_environment():
    """Export process-level settings and register the app's global configuration.

    Sets two library toggles as environment variables, copies four secrets
    (OpenAI base/key, Auth0 client id/domain) from ``st.secrets`` into
    ``os.environ``, then builds a ``GlobalConfig`` from the secrets and hands
    it to ``update_global_config``.

    Secrets are read with ``[...]`` indexing, so a missing required secret
    surfaces as an error here. The exception is ``MYSCALE_ENABLE_HTTPS``,
    which falls back to ``True``.
    """
    secrets = st.secrets

    # Library behaviour toggles (values are fixed, not taken from secrets).
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    os.environ["LANGCHAIN_TRACING_V2"] = "false"

    # Mirror these secrets into the process environment under the same names.
    for secret_name in ("OPENAI_API_BASE", "OPENAI_API_KEY", "AUTH0_CLIENT_ID", "AUTH0_DOMAIN"):
        os.environ[secret_name] = secrets[secret_name]

    config = GlobalConfig(
        openai_api_base=secrets['OPENAI_API_BASE'],
        openai_api_key=secrets['OPENAI_API_KEY'],
        auth0_client_id=secrets['AUTH0_CLIENT_ID'],
        auth0_domain=secrets['AUTH0_DOMAIN'],
        myscale_user=secrets['MYSCALE_USER'],
        myscale_password=secrets['MYSCALE_PASSWORD'],
        myscale_host=secrets['MYSCALE_HOST'],
        myscale_port=secrets['MYSCALE_PORT'],
        # Same hard-coded model is used for query generation and chat.
        query_model="gpt-3.5-turbo-0125",
        chat_model="gpt-3.5-turbo-0125",
        # NOTE(review): kwarg name `untrusted_api` is fed from the
        # UNSTRUCTURED_API secret; looks like a misspelling in GlobalConfig
        # itself. Kept as-is because it is GlobalConfig's interface.
        untrusted_api=secrets['UNSTRUCTURED_API'],
        myscale_enable_https=secrets.get('MYSCALE_ENABLE_HTTPS', True),
    )
    update_global_config(config)
# when refresh browser, all session keys will be cleaned.
def initialize_session_state():
    """Seed per-session Streamlit state keys with their defaults.

    Only keys absent from ``st.session_state`` are written (and logged), so
    values that already exist in the session are left untouched.
    """
    # Ordered (key, default) pairs; order determines the order of log lines.
    defaults = (
        (DATA_INITIALIZE_STATUS, DATA_INITIALIZE_NOT_STATED),
        (JUMP_QUERY_ASK, False),
    )
    for state_key, default_value in defaults:
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = default_value
        logger.info(f"Initialize session state key: {state_key}")
def initialize_chat_data():
    """Build the ChatData resources once per session and store them in session state.

    Does nothing when ``DATA_INITIALIZE_STATUS`` is already COMPLETED.
    Otherwise it marks the status STARTED, then calls three builders in order
    and stores their results:

    - ``load_embedding_models()`` into ``TABLE_EMBEDDINGS_MAPPING``
    - ``build_chains_and_retrievers()`` into ``CHAINS_RETRIEVERS_MAPPING``
    - ``update_retriever_tools()`` into ``RETRIEVER_TOOLS``

    Finally it marks the status COMPLETED and logs the elapsed time.

    If any builder raises, the status stays STARTED. Because the guard below
    only skips on COMPLETED, the next script run retries the full
    initialization.

    Requires ``DATA_INITIALIZE_STATUS`` to already exist in session state,
    for example via ``initialize_session_state``.
    """
    if st.session_state[DATA_INITIALIZE_STATUS] == DATA_INITIALIZE_COMPLETED:
        return

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # system clock adjustments the way time.time() can.
    start_time = time.perf_counter()
    st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED
    # Builders run in this fixed order; whether later ones depend on earlier
    # results is not visible here, so keep the order. TODO confirm.
    st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
    st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
    st.session_state[RETRIEVER_TOOLS] = update_retriever_tools()
    # mark data initialization finished.
    st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED
    elapsed = time.perf_counter() - start_time
    logger.info(f"ChatData initialized finished in {round(elapsed, 3)} seconds, "
                f"session state keys: {list(st.session_state.keys())}")
# Page configuration is the first Streamlit command the script issues.
st.set_page_config(
    page_title="ChatData",
    page_icon="https://myscale.com/favicon.ico",
    initial_sidebar_state="expanded",
    layout="wide",
)

# Bootstrap, in order: environment/global config, session defaults,
# then the ChatData resources (built once per session).
prepare_environment()
initialize_session_state()
initialize_chat_data()

# Route to a page:
# - chat page when a user name is present in session state;
# - retrievers page when a jump to query/ask was requested;
# - home page otherwise.
if USER_NAME in st.session_state:
    chat_page()
elif st.session_state[JUMP_QUERY_ASK]:
    render_retrievers()
else:
    render_home()
|