File size: 5,564 Bytes
e931b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import streamlit as st
from langchain.chat_models import ChatOpenAI
from langchain.prompts.prompt import PromptTemplate
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.retrievers.self_query.myscale import MyScaleTranslator
from langchain.utilities.sql_database import SQLDatabase
from langchain.vectorstores import MyScaleSettings
from langchain_experimental.retrievers.vector_sql_database import VectorSQLDatabaseChainRetriever
from langchain_experimental.sql.vector_sql import VectorSQLDatabaseChain
from sqlalchemy import create_engine, MetaData

from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.prompts import MYSCALE_PROMPT
from backend.constants.variables import TABLE_EMBEDDINGS_MAPPING, GLOBAL_CONFIG
from backend.retrievers.vector_sql_output_parser import VectorSQLRetrieveOutputParser
from backend.vector_store.myscale_without_metadata import MyScaleWithoutMetadataJson
from logger import logger


@st.cache_resource
def build_self_query_retriever(table_name: str) -> SelfQueryRetriever:
    """Build a `SelfQueryRetriever` backed by one MyScaleDB table.

    The function first wraps the table in a `MyScaleWithoutMetadataJson`
    vector store (connection details come from `GLOBAL_CONFIG`, the table
    layout from `MYSCALE_TABLES[table_name]`, and the embedding model from
    `st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name]`), then hands
    that store to `SelfQueryRetriever.from_llm` together with a
    zero-temperature `ChatOpenAI` model and a `MyScaleTranslator`.

    Decorated with `st.cache_resource`, so Streamlit reuses the built
    retriever for repeated calls with the same `table_name`.

    Args:
        table_name: Key used to look up the table description in
            `MYSCALE_TABLES` and the embedding model in the session state.

    Returns:
        The retriever produced by `SelfQueryRetriever.from_llm`.
    """
    with st.spinner(f"Building VectorStore for MyScaleDB/{table_name} ..."):
        connection_kwargs = dict(
            host=GLOBAL_CONFIG.myscale_host,
            port=GLOBAL_CONFIG.myscale_port,
            username=GLOBAL_CONFIG.myscale_user,
            password=GLOBAL_CONFIG.myscale_password,
        )
        # Resolve the table description once instead of on every attribute read.
        table_cfg = MYSCALE_TABLES[table_name]
        store_settings = MyScaleSettings(
            **connection_kwargs,
            database=table_cfg.database,
            table=table_cfg.table,
            column_map={
                "id": "id",
                "text": table_cfg.text_col_name,
                "vector": table_cfg.vector_col_name,
                # TODO refine MyScaleDB metadata in langchain.
                "metadata": table_cfg.metadata_col_name,
            },
        )
        embedding_model = st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name]
        vector_store = MyScaleWithoutMetadataJson(
            embedding=embedding_model,
            config=store_settings,
            must_have_cols=table_cfg.must_have_col_names,
        )

    with st.spinner(f"Building SelfQueryRetriever for MyScaleDB/{table_name} ..."):
        # temperature=0 is passed so the query-building model is as
        # repeatable as the backend allows.
        query_llm = ChatOpenAI(
            model_name=GLOBAL_CONFIG.query_model,
            base_url=GLOBAL_CONFIG.openai_api_base,
            api_key=GLOBAL_CONFIG.openai_api_key,
            temperature=0,
        )
        table_cfg = MYSCALE_TABLES[table_name]
        self_query_retriever: SelfQueryRetriever = SelfQueryRetriever.from_llm(
            llm=query_llm,
            vectorstore=vector_store,
            document_contents=table_cfg.table_contents,
            metadata_field_info=table_cfg.metadata_col_attributes,
            use_original_query=False,
            structured_query_translator=MyScaleTranslator(),
        )
    return self_query_retriever


@st.cache_resource
def build_vector_sql_db_chain_retriever(table_name: str) -> VectorSQLDatabaseChainRetriever:
    """Build a retriever that fetches related documents from MyScaleDB.

    A SQLAlchemy engine is created for the database that holds
    `MYSCALE_TABLES[table_name]`, a `VectorSQLDatabaseChain` is built on top
    of it (the LLM writes the SQL, `VectorSQLRetrieveOutputParser` rewrites
    it), and the chain is wrapped in a `VectorSQLDatabaseChainRetriever`.

    Decorated with `st.cache_resource`, so Streamlit reuses the built
    retriever for repeated calls with the same `table_name`.

    Args:
        table_name: Key used to look up the table description in
            `MYSCALE_TABLES` and the embedding model in
            `st.session_state[TABLE_EMBEDDINGS_MAPPING]`.

    Returns:
        A `VectorSQLDatabaseChainRetriever` whose page content is read from
        the table's text column.
    """
    # Function-scope import: keeps this block self-contained without
    # touching the module's import section.
    from urllib.parse import quote

    with st.spinner(f'Building Vector SQL Database Retriever for MyScaleDB/{table_name}...'):
        # `== False` is kept on purpose: only an explicit falsy flag selects
        # plain http, while an unset (None) value still falls through to https.
        if GLOBAL_CONFIG.myscale_enable_https == False:  # noqa: E712
            protocol = "http"
        else:
            protocol = "https"

        # Credentials are percent-encoded so characters such as '@', ':', '/',
        # '%' or '+' in the user name or password cannot corrupt the URL.
        # `quote(..., safe="")` is used rather than `quote_plus` so that a
        # space becomes %20 and survives plain percent-decoding.
        user = quote(str(GLOBAL_CONFIG.myscale_user), safe="")
        password = quote(str(GLOBAL_CONFIG.myscale_password), safe="")

        table_cfg = MYSCALE_TABLES[table_name]
        engine = create_engine(
            f'clickhouse://{user}:{password}@'
            f'{GLOBAL_CONFIG.myscale_host}:{GLOBAL_CONFIG.myscale_port}'
            f'/{table_cfg.database}?protocol={protocol}'
        )
        # NOTE(review): `MetaData(bind=...)` is SQLAlchemy 1.x API (the `bind`
        # argument was removed in 2.0) -- confirm against the pinned version.
        metadata = MetaData(bind=engine)
        logger.info(f"{table_name} metadata is : {metadata}")
        prompt = PromptTemplate(
            input_variables=["input", "table_info", "top_k"],
            template=MYSCALE_PROMPT,
        )
        # The custom output parser rewrites the generated search SQL so that
        # the table's required columns are always part of the query.
        output_parser = VectorSQLRetrieveOutputParser.from_embeddings(
            model=st.session_state[TABLE_EMBEDDINGS_MAPPING][table_name],
            # Columns the rewritten SQL must select.
            must_have_columns=table_cfg.must_have_col_names
        )

        # `db_chain` will generate a SQL
        vector_sql_db_chain: VectorSQLDatabaseChain = VectorSQLDatabaseChain.from_llm(
            llm=ChatOpenAI(
                model_name=GLOBAL_CONFIG.query_model,
                base_url=GLOBAL_CONFIG.openai_api_base,
                api_key=GLOBAL_CONFIG.openai_api_key,
                temperature=0
            ),
            prompt=prompt,
            top_k=10,
            return_direct=True,
            db=SQLDatabase(
                engine,
                None,
                metadata,
                include_tables=[table_cfg.table],
                max_string_length=1024
            ),
            sql_cmd_parser=output_parser,  # TODO needs update `langchain`, fix return type.
            native_format=True
        )

        # `retriever` can search a group of documents with `db_chain`
        vector_sql_db_chain_retriever = VectorSQLDatabaseChainRetriever(
            sql_db_chain=vector_sql_db_chain,
            page_content_key=table_cfg.text_col_name
        )
    return vector_sql_db_chain_retriever