from langchain.chains.base import Chain
from langchain.chains.summarize import load_summarize_chain
from langchain.prompts import PromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.retriever import BaseRetriever
from langchain.schema.runnable import RunnableSequence, RunnablePassthrough


# Initial prompt: summarizes the first chunk of text with respect to the user's query.
prompt_template = """Write a concise summary of the following text, based on the user input.
User input: {query}
Text:
```
{text}
```
CONCISE SUMMARY:"""


# Refine prompt: folds each subsequent chunk into the running summary.
refine_template = (
    "You are iteratively crafting a summary of the text below based on the user input.\n"
    "User input: {query}\n"
    "We have provided an existing summary up to a certain point: {existing_answer}\n"
    "We have the opportunity to refine the existing summary "
    "(only if needed) with some more context below.\n"
    "------------\n"
    "{text}\n"
    "------------\n"
    "Given the new context, refine the original summary.\n"
    "If the context isn't useful, return the original summary.\n"
    "If the context is useful, refine the summary to include the new context.\n"
    "Your contribution is helping to build a comprehensive summary of a large body of knowledge.\n"
    "You do not have the complete context, so do not discard pieces of the original summary."
)


def get_summarization_chain(
    llm: BaseLanguageModel,
    prompt: str,
) -> Chain:
    """Build a refine-style summarization chain with the user query baked in.

    The query is bound to both prompts as a partial variable, so the returned
    chain only needs `input_documents` at call time.
    """
    _prompt = PromptTemplate.from_template(
        prompt_template,
        partial_variables={"query": prompt},
    )
    refine_prompt = PromptTemplate.from_template(
        refine_template,
        partial_variables={"query": prompt},
    )
    return load_summarize_chain(
        llm=llm,
        chain_type="refine",
        question_prompt=_prompt,
        refine_prompt=refine_prompt,
        return_intermediate_steps=False,
        input_key="input_documents",
        output_key="output_text",
    )
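

# A minimal direct-use sketch of the chain above (hypothetical: assumes `docs`
# is a list of langchain `Document` objects and `llm` is any BaseLanguageModel;
# neither is defined in this module):
#
#     chain = get_summarization_chain(llm, "What are the key findings?")
#     summary = chain({"input_documents": docs})["output_text"]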


def get_rag_summarization_chain(
    prompt: str,
    retriever: BaseRetriever,
    llm: BaseLanguageModel,
    input_key: str = "prompt",
) -> RunnableSequence:
    """Compose retrieval and refine summarization into one runnable sequence.

    The sequence is invoked with the raw query string: the retriever turns it
    into `input_documents`, the summarization chain condenses them, and the
    final lambda unwraps the summary text from the chain's output dict. The
    passthrough forwards the raw input under `input_key`, although the query
    the prompts use is the one already bound in `get_summarization_chain`.
    """
    return (
        {"input_documents": retriever, input_key: RunnablePassthrough()}
        | get_summarization_chain(llm, prompt)
        | (lambda output: output["output_text"])
    )
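

# Example wiring, as a sketch only: a FAISS vector store over two in-memory
# documents and ChatOpenAI stand in for the retriever and model; any
# BaseRetriever / BaseLanguageModel pair would work the same way. Assumes
# the `faiss-cpu` package is installed and OPENAI_API_KEY is set.
if __name__ == "__main__":
    from langchain.chat_models import ChatOpenAI
    from langchain.embeddings import OpenAIEmbeddings
    from langchain.schema import Document
    from langchain.vectorstores import FAISS

    docs = [
        Document(page_content="LangChain composes LLM calls into chains."),
        Document(page_content="Retrievers fetch documents relevant to a query."),
    ]
    retriever = FAISS.from_documents(docs, OpenAIEmbeddings()).as_retriever()
    llm = ChatOpenAI(temperature=0)

    chain = get_rag_summarization_chain("What is LangChain?", retriever, llm)
    # The whole sequence takes the raw query string as its input.
    print(chain.invoke("What is LangChain?"))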