Spaces:
Running
Running
from logger import logger | |
from typing import Dict, Any, Union | |
import streamlit as st | |
from backend.constants.myscale_tables import MYSCALE_TABLES | |
from backend.constants.variables import CHAINS_RETRIEVERS_MAPPING | |
from backend.construct.build_chains import build_retrieval_qa_with_sources_chain | |
from backend.construct.build_retriever_tool import create_retriever_tool | |
from backend.construct.build_retrievers import build_self_query_retriever, build_vector_sql_db_chain_retriever | |
from backend.types.chains_and_retrievers import ChainsAndRetrievers, MetadataColumn | |
from langchain_community.embeddings import HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings, \ | |
SentenceTransformerEmbeddings | |
def load_embedding_model_for_table(table_name: str) -> \
        Union[SentenceTransformerEmbeddings, HuggingFaceInstructEmbeddings]:
    """Create the embedding model registered for ``table_name``.

    The model is produced by calling ``emb_model()`` on the table's entry in
    ``MYSCALE_TABLES``. A Streamlit spinner is displayed for the duration of
    that call.

    Args:
        table_name: Key into ``MYSCALE_TABLES`` identifying the table.

    Returns:
        Whatever ``MYSCALE_TABLES[table_name].emb_model()`` returns.
    """
    spinner_text = f"Loading embedding models for [{table_name}] ..."
    with st.spinner(spinner_text):
        # Look the table up inside the spinner so the spinner is already
        # visible while the (potentially slow) model factory runs.
        model_factory = MYSCALE_TABLES[table_name].emb_model
        return model_factory()
def load_embedding_models() -> Dict[str, Union[HuggingFaceEmbeddings, HuggingFaceInstructEmbeddings]]:
    """Load one embedding model per entry of ``MYSCALE_TABLES``.

    Returns:
        A dict keyed by each item yielded when iterating ``MYSCALE_TABLES``,
        whose value is the result of ``load_embedding_model_for_table`` for
        that item.
    """
    return {
        table: load_embedding_model_for_table(table)
        for table in MYSCALE_TABLES
    }
def update_retriever_tools():
    """Build the retriever tools for every table in ``MYSCALE_TABLES``.

    For each table two tools are created with ``create_retriever_tool``: one
    wrapping the ``"retriever"`` entry and one wrapping the
    ``"sql_retriever"`` entry found under
    ``st.session_state[CHAINS_RETRIEVERS_MAPPING][table]``. Both receive the
    table's ``tool_desc`` unpacked as extra positional arguments.

    NOTE(review): presumably the session-state mapping must already have been
    populated before this is called -- confirm against the caller.

    Returns:
        A dict mapping ``"<table> + Self Querying"`` and
        ``"<table> + Vector SQL"`` to the corresponding tool objects.
    """
    tools_by_name = {}
    for table in MYSCALE_TABLES:
        logger.info(f"Updating retriever tools [<retriever>, <sql_retriever>] for table {table}")

        # Tool backed by the self-query retriever.
        self_query_tool = create_retriever_tool(
            st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["retriever"],
            *MYSCALE_TABLES[table].tool_desc
        )
        # Tool backed by the vector SQL retriever.
        vector_sql_tool = create_retriever_tool(
            st.session_state[CHAINS_RETRIEVERS_MAPPING][table]["sql_retriever"],
            *MYSCALE_TABLES[table].tool_desc
        )

        tools_by_name[f"{table} + Self Querying"] = self_query_tool
        tools_by_name[f"{table} + Vector SQL"] = vector_sql_tool
    return tools_by_name
def build_chains_retriever_for_table(table_name: str) -> ChainsAndRetrievers:
    """Assemble the retrievers, QA chains and metadata columns for one table.

    Two retrievers are built for ``table_name`` (self-query and vector SQL),
    and each is wrapped in a chain via
    ``build_retrieval_qa_with_sources_chain``. The table's
    ``metadata_col_attributes`` are converted to ``MetadataColumn`` objects.

    Args:
        table_name: Key into ``MYSCALE_TABLES`` identifying the table.

    Returns:
        A ``ChainsAndRetrievers`` bundling the metadata columns, both
        retrievers and both chains.
    """
    # Read the table's column attributes first, before any builder runs.
    column_attributes = MYSCALE_TABLES[table_name].metadata_col_attributes

    # Self-query retriever and its chain.
    sq_retriever = build_self_query_retriever(table_name)
    sq_chain = build_retrieval_qa_with_sources_chain(
        table_name=table_name,
        retriever=sq_retriever,
        chain_name="Self Query Retriever",
    )

    # Vector SQL retriever and its chain.
    sql_retriever = build_vector_sql_db_chain_retriever(table_name)
    sql_chain = build_retrieval_qa_with_sources_chain(
        table_name=table_name,
        retriever=sql_retriever,
        chain_name="Vector SQL DB Retriever",
    )

    columns = []
    for attr in column_attributes:
        columns.append(
            MetadataColumn(name=attr.name, desc=attr.description, type=attr.type)
        )

    return ChainsAndRetrievers(
        metadata_columns=columns,
        retriever=sq_retriever,
        chain=sq_chain,
        sql_retriever=sql_retriever,
        sql_chain=sql_chain,
    )
def build_chains_and_retrievers() -> Dict[str, Dict[str, Any]]:
    """Build chains and retrievers for every table in ``MYSCALE_TABLES``.

    Returns:
        A dict keyed by each item yielded when iterating ``MYSCALE_TABLES``,
        whose value is the ``to_dict()`` form of the object returned by
        ``build_chains_retriever_for_table`` for that item.
    """
    built = {}
    for table in MYSCALE_TABLES:
        logger.info(f"Building chains, retrievers for table {table}")
        bundle = build_chains_retriever_for_table(table)
        built[table] = bundle.to_dict()
    return built