Spaces:
Running
Running
File size: 1,641 Bytes
e931b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from langchain.chains import LLMChain
from langchain.chat_models import ChatOpenAI
from langchain.prompts import ChatPromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate
from langchain.schema import BaseRetriever
import streamlit as st
from backend.chains.retrieval_qa_with_sources import CustomRetrievalQAWithSourcesChain
from backend.chains.stuff_documents import CustomStuffDocumentChain
from backend.constants.myscale_tables import MYSCALE_TABLES
from backend.constants.prompts import COMBINE_PROMPT
from backend.constants.variables import GLOBAL_CONFIG
def build_retrieval_qa_with_sources_chain(
    table_name: str,
    retriever: BaseRetriever,
    chain_name: str = "<chain_name>"
) -> CustomRetrievalQAWithSourcesChain:
    """Assemble a retrieval QA-with-sources chain for one MyScale table.

    A Streamlit spinner is displayed for the duration of the build. The
    resulting chain answers through a ``CustomStuffDocumentChain`` that feeds
    retrieved documents into ``COMBINE_PROMPT`` under the ``summaries``
    variable, using a ``ChatOpenAI`` model configured from ``GLOBAL_CONFIG``.

    Args:
        table_name: Key into ``MYSCALE_TABLES``; the entry's ``doc_prompt``
            is used as the per-document prompt. It is also shown in the
            spinner message.
        retriever: Retriever handed to the QA chain unchanged.
        chain_name: Label used only in the spinner message.

    Returns:
        A ``CustomRetrievalQAWithSourcesChain`` created with
        ``return_source_documents=True`` and ``max_tokens_limit=12000``.

    Note:
        ``table_name`` is not validated here; the ``MYSCALE_TABLES`` lookup
        fails for unknown names (presumably a ``KeyError`` -- confirm
        against the definition of ``MYSCALE_TABLES``).
    """
    spinner_text = f'Building QA source chain named `{chain_name}` for MyScaleDB/{table_name} ...'
    with st.spinner(spinner_text):
        # Chat model that writes the final answer.
        chat_llm = ChatOpenAI(
            model_name=GLOBAL_CONFIG.chat_model,
            openai_api_key=GLOBAL_CONFIG.openai_api_key,
            temperature=0.6,
        )
        answer_chain = LLMChain(prompt=COMBINE_PROMPT, llm=chat_llm)

        # Each table supplies its own prompt for rendering a single document.
        per_document_prompt = MYSCALE_TABLES[table_name].doc_prompt

        # Rendered documents are passed to the LLM chain as "summaries".
        stuff_chain = CustomStuffDocumentChain(
            llm_chain=answer_chain,
            document_prompt=per_document_prompt,
            document_variable_name="summaries",
        )

        return CustomRetrievalQAWithSourcesChain(
            retriever=retriever,
            combine_documents_chain=stuff_chain,
            return_source_documents=True,
            max_tokens_limit=12000,
        )
|