ChatData / ui /retrievers.py
lqhl's picture
Synced repo using 'sync_with_huggingface' Github Action
e931b70 verified
raw
history blame
3.87 kB
import streamlit as st
from streamlit_extras.add_vertical_space import add_vertical_space
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, RetrieverButtons
from backend.retrievers.self_query import process_self_query
from backend.retrievers.vector_sql_query import process_sql_query
from backend.constants.variables import JUMP_QUERY_ASK, USER_NAME, USER_INFO
def back_to_main():
    """Drop the user-related entries from Streamlit's session state.

    Used as the ``on_click`` callback of the "Back" button rendered by
    ``render_retrievers``. Each of ``USER_INFO``, ``USER_NAME`` and
    ``JUMP_QUERY_ASK`` is deleted only if it is currently present, so the
    callback does not raise when some or all of them are missing.
    """
    # NOTE(review): presumably clearing these keys makes the next rerun show
    # the app's main view -- confirm against the page-routing code.
    for state_key in (USER_INFO, USER_NAME, JUMP_QUERY_ASK):
        if state_key in st.session_state:
            del st.session_state[state_key]
def _render_table_selector() -> str:
    """Render the knowledge-base picker and return the chosen table name.

    The page is split into two equal columns:

    * left  -- a select box listing the keys of ``MYSCALE_TABLES``, followed
      by whatever the chosen entry's ``hint()`` method renders;
    * right -- one unit of vertical spacing, an info banner, and whatever
      the chosen entry's ``hint_sql()`` method renders.

    Returns:
        The key of ``MYSCALE_TABLES`` picked in the select box.
    """
    picker_col, schema_col = st.columns(2)

    with picker_col:
        selected_table = st.selectbox(
            label='Each public knowledge base is stored in a MyScaleDB table, which is read-only.',
            options=MYSCALE_TABLES.keys(),
        )
        # Look the entry up once; both columns describe the same table.
        table_entry = MYSCALE_TABLES[selected_table]
        table_entry.hint()

    with schema_col:
        add_vertical_space(1)
        st.info(
            "Here is your selected public knowledge base schema in MyScaleDB",
            icon='📚',
        )
        table_entry.hint_sql()

    return selected_table
def render_retrievers():
    """Render the retriever page: back button, table picker and two tabs.

    Layout, top to bottom:

    1. a "Back" button whose ``on_click`` callback is ``back_to_main``;
    2. a sub-header and the table selector (``_render_table_selector``);
    3. two tabs, "Vector SQL" and "Self-querying Retriever", filled by
       ``render_tab_sql`` and ``render_tab_self_query`` respectively, both
       receiving the table chosen in step 2.
    """
    st.button("⬅️ Back", key="back_sql", on_click=back_to_main)
    st.subheader('Please choose a public knowledge base to search.')

    chosen_table = _render_table_selector()

    # Tab titles and their renderers are listed in the same order so they
    # can be paired positionally.
    tab_titles = ['Vector SQL', 'Self-querying Retriever']
    tab_renderers = (render_tab_sql, render_tab_self_query)

    for tab, render_tab in zip(st.tabs(tabs=tab_titles), tab_renderers):
        with tab:
            render_tab(chosen_table)
def render_tab_sql(selected_table: str):
    """Render the "Vector SQL" tab for ``selected_table``.

    Shows a warning about which metadata may be filtered on, the table's
    ``"metadata_columns"`` entry from the ``CHAINS_RETRIEVERS_MAPPING``
    session-state mapping, a question input (widget key ``'query_sql'``)
    and two buttons. For each button whose session-state value is truthy
    on this run, ``process_sql_query`` is called with the table name and
    that button's key.

    Args:
        selected_table: Name of the knowledge-base table; used as a key into
            the ``CHAINS_RETRIEVERS_MAPPING`` session-state mapping.
    """
    st.warning(
        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
        icon="⚠️",
    )
    table_config = st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]
    st.dataframe(table_config["metadata_columns"])

    # The fourth (narrowest) column is left empty.
    question_col, db_col, llm_col, _unused_col = st.columns([8, 3, 3, 2])
    question_col.text_input("Input your question:", key='query_sql')

    with db_col.container():
        # Spacer before the button, inside the same container.
        add_vertical_space(2)
        st.button("Retrieve from MyScaleDB ➡️", key=RetrieverButtons.vector_sql_query_from_db)

    with llm_col.container():
        add_vertical_space(2)
        st.button("Retrieve and answer with LLM ➡️", key=RetrieverButtons.vector_sql_query_with_llm)

    # Check the plain-retrieval button first, then the LLM one; each button's
    # widget key is also the mode argument passed to the query processor.
    button_keys = (
        RetrieverButtons.vector_sql_query_from_db,
        RetrieverButtons.vector_sql_query_with_llm,
    )
    for button_key in button_keys:
        if st.session_state[button_key]:
            process_sql_query(selected_table, button_key)
def render_tab_self_query(selected_table):
    """Render the "Self-querying Retriever" tab for ``selected_table``.

    Shows a warning about which metadata may be filtered on, the table's
    ``"metadata_columns"`` entry from the ``CHAINS_RETRIEVERS_MAPPING``
    session-state mapping, a question input (widget key ``'query_self'``)
    and two buttons keyed ``'search_self'`` and ``'ask_self'``. When a
    button's session-state value is truthy on this run,
    ``process_self_query`` is called with the table name and the matching
    ``RetrieverButtons`` member.

    Args:
        selected_table: Name of the knowledge-base table; used as a key into
            the ``CHAINS_RETRIEVERS_MAPPING`` session-state mapping.
    """
    st.warning(
        "When you input a query with filtering conditions, you need to ensure that your filters are applied only to "
        "the metadata we provide. This table allows filters to be established on the following metadata fields:",
        icon="⚠️",
    )
    table_config = st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]
    st.dataframe(table_config["metadata_columns"])

    # The fourth (narrowest) column is left empty.
    question_col, db_col, llm_col, _unused_col = st.columns([8, 3, 3, 2])
    question_col.text_input("Input your question:", key='query_self')

    with db_col.container():
        # Spacer before the button, inside the same container.
        add_vertical_space(2)
        st.button("Retrieve from MyScaleDB ➡️", key='search_self')

    with llm_col.container():
        add_vertical_space(2)
        st.button("Retrieve and answer with LLM ➡️", key='ask_self')

    # Unlike the widget keys above, the mode handed to the processor is a
    # RetrieverButtons member, so each widget key is paired with its mode.
    # Plain retrieval is checked first, then the LLM variant.
    key_to_mode = (
        ('search_self', RetrieverButtons.self_query_from_db),
        ('ask_self', RetrieverButtons.self_query_with_llm),
    )
    for widget_key, mode in key_to_mode:
        if getattr(st.session_state, widget_key):
            process_self_query(selected_table, mode)