File size: 4,569 Bytes
e931b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import List

import pandas as pd
import streamlit as st
from langchain.schema import Document
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever

from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING, DIVIDER_HTML, RetrieverButtons
from backend.callbacks.vector_sql_callbacks import VectorSQLSearchDBCallBackHandler, VectorSQLSearchLLMCallBackHandler
from ui.utils import display
from logger import logger


def _docs_to_dataframe(docs: "List[Document]") -> pd.DataFrame:
    """Flatten documents into a DataFrame: each row is the document's
    ``metadata`` dict plus an ``abstract`` column holding ``page_content``."""
    return pd.DataFrame([{**d.metadata, 'abstract': d.page_content} for d in docs])


def _render_search_results_header(selected_table: str) -> None:
    """Write the markdown heading shown above the vector-search result table."""
    st.markdown(f"### Vector Search Results from `{selected_table}` \n"
                "> Here we get documents from MyScaleDB with given sql statement \n\n")


def process_sql_query(selected_table: str, query_type: str):
    """Run a vector-SQL query against ``selected_table`` and render the results.

    The query text is read from ``st.session_state.query_sql``. Output is
    written inside an expanded "Query Log" expander. Behaviour depends on
    ``query_type``:

    * ``RetrieverButtons.vector_sql_query_from_db``: runs the table's
      ``sql_retriever`` and shows the returned documents.
    * ``RetrieverButtons.vector_sql_query_with_llm``: runs the table's
      ``sql_chain`` and shows the retrieved documents, the LLM answer, and the
      documents the chain reports under ``'sources'``.

    Args:
        selected_table: Key into ``MYSCALE_TABLES`` and into the per-table
            entries of ``st.session_state[CHAINS_RETRIEVERS_MAPPING]``.
        query_type: Identifies which query button triggered the call.

    Raises:
        NotImplementedError: If ``query_type`` is not one of the two
            supported values.
        Exception: Any error raised while querying or rendering is re-raised
            unchanged, after a short notice has been written to the page.
    """
    place_holder = st.empty()
    logger.info(
        f"button-1: {st.session_state[RetrieverButtons.vector_sql_query_from_db]}, "
        f"button-2: {st.session_state[RetrieverButtons.vector_sql_query_with_llm]}, "
        f"table: {selected_table}, "
        f"content: {st.session_state.query_sql}"
    )
    with place_holder.expander('🪵 Query Log', expanded=True):
        try:
            if query_type == RetrieverButtons.vector_sql_query_from_db:
                callback = VectorSQLSearchDBCallBackHandler()
                vector_sql_retriever: VectorSQLDatabaseChainRetriever = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_retriever"]
                relevant_docs: List[Document] = vector_sql_retriever.get_relevant_documents(
                    query=st.session_state.query_sql,
                    callbacks=[callback]
                )

                # Mark the callback's progress bar as complete once retrieval returns.
                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> SQL Statement -> MyScaleDB -> Results] Done! ✅"
                )

                _render_search_results_header(selected_table)
                display(_docs_to_dataframe(relevant_docs))
            elif query_type == RetrieverButtons.vector_sql_query_with_llm:
                callback = VectorSQLSearchLLMCallBackHandler(table=selected_table)
                vector_sql_chain: CustomRetrievalQAWithSourcesChain = \
                    st.session_state[CHAINS_RETRIEVERS_MAPPING][selected_table]["sql_chain"]
                chain_results = vector_sql_chain(
                    inputs=st.session_state.query_sql,
                    callbacks=[callback]
                )

                callback.progress_bar.progress(
                    value=1.0,
                    text="[Question -> LLM -> SQL Statement -> MyScaleDB -> "
                         "(Question,Results) -> LLM -> Results] Done! ✅"
                )

                documents_reference: List[Document] = chain_results["source_documents"]
                _render_search_results_header(selected_table)
                display(_docs_to_dataframe(documents_reference))
                st.markdown(
                    "### Answer from LLM \n"
                    "> The response of the LLM when given the vector search results. \n\n"
                )
                st.write(chain_results['answer'])
                st.markdown(
                    f"### References from `{selected_table}`\n"
                    "> Here shows that which documents used by LLM \n\n"
                )
                # 'sources' holds the subset of documents the chain attributes
                # its answer to; it may be empty.
                if not chain_results['sources']:
                    st.write("No documents is used by LLM.")
                else:
                    display(
                        dataframe=_docs_to_dataframe(chain_results['sources']),
                        columns_=['ref_id'] + MYSCALE_TABLES[selected_table].must_have_col_names,
                        index='ref_id'
                    )
            else:
                raise NotImplementedError(f"Unsupported query type: {query_type}")
            st.markdown(DIVIDER_HTML, unsafe_allow_html=True)
        except Exception:
            st.write('Oops 😡 Something bad happened...')
            # Bare `raise` re-raises the active exception with its traceback intact.
            raise