from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.schema import BaseRetriever
import streamlit as st

from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.chains.stuff_documents import CustomStuffDocumentChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.prompts import COMBINE_PROMPT
from backend.constants.variables import GLOBAL_CONFIG


def build_retrieval_qa_with_sources_chain(
        table_name: str,
        retriever: BaseRetriever,
        chain_name: str = "<chain_name>"
) -> CustomRetrievalQAWithSourcesChain:
    """Build a RetrievalQAWithSources chain over the given MyScaleDB table.

    The chain retrieves documents with ``retriever``, stuffs them into the
    combine prompt, and returns the answer together with its source documents.
    """
    with st.spinner(f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'):
        # Combine retrieved documents into a single prompt; the custom stuff
        # chain assigns a ref_id to each document so the answer can cite sources.
        custom_stuff_document_chain = CustomStuffDocumentChain(
            llm_chain=LLMChain(
                prompt=COMBINE_PROMPT,
                llm=ChatOpenAI(
                    model_name=GLOBAL_CONFIG.chat_model,
                    openai_api_key=GLOBAL_CONFIG.openai_api_key,
                    temperature=0.6
                ),
            ),
            document_prompt=MYSCALE_TABLES[table_name].doc_prompt,
            document_variable_name="summaries",
        )
        # Wrap the retriever and the document-combining chain into a QA chain
        # that also returns the source documents it answered from.
        chain = CustomRetrievalQAWithSourcesChain(
            retriever=retriever,
            combine_documents_chain=custom_stuff_document_chain,
            return_source_documents=True,
            max_tokens_limit=12000,
        )
    return chain
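
# --- Hypothetical usage sketch (illustrative only, not part of the app flow) ---
# Assuming a retriever for one of the MYSCALE_TABLES entries has already been
# built elsewhere (the table name and `some_retriever` below are placeholders),
# the chain could be invoked from a Streamlit page roughly like this:
#
#     qa_chain = build_retrieval_qa_with_sources_chain(
#         table_name="Wikipedia",       # placeholder; must be a key in MYSCALE_TABLES
#         retriever=some_retriever,     # any langchain BaseRetriever instance
#         chain_name="wiki-qa",
#     )
#     result = qa_chain({"question": "What is MyScale?"}, return_only_outputs=False)
#     st.write(result["answer"])
#     st.write(result["source_documents"])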