Joshua Sundance Bailey commited on
Commit
dd9bfbd
1 Parent(s): 0609980

rag summarization

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -27,7 +27,7 @@ from langsmith.client import Client
27
  from streamlit_feedback import streamlit_feedback
28
 
29
  from qagen import get_rag_qa_gen_chain
30
- from summarize import get_summarization_chain
31
 
32
  __version__ = "0.0.10"
33
 
@@ -421,14 +421,17 @@ if st.session_state.llm:
421
  full_response: Union[str, None]
422
  if use_document_chat:
423
  if document_chat_chain_type == "Summarization":
424
- st.session_state.doc_chain = get_summarization_chain(
425
- st.session_state.llm,
426
  prompt,
 
 
427
  )
428
- full_response = st.session_state.doc_chain.run(
429
- st.session_state.texts,
430
- callbacks=callbacks,
431
- tags=["Streamlit Chat"],
 
 
432
  )
433
 
434
  st.markdown(full_response)
 
27
  from streamlit_feedback import streamlit_feedback
28
 
29
  from qagen import get_rag_qa_gen_chain
30
+ from summarize import get_rag_summarization_chain
31
 
32
  __version__ = "0.0.10"
33
 
 
421
  full_response: Union[str, None]
422
  if use_document_chat:
423
  if document_chat_chain_type == "Summarization":
424
+ st.session_state.doc_chain = get_rag_summarization_chain(
 
425
  prompt,
426
+ st.session_state.retriever,
427
+ st.session_state.llm,
428
  )
429
+ full_response = st.session_state.doc_chain.invoke(
430
+ prompt,
431
+ dict(
432
+ callbacks=callbacks,
433
+ tags=["Streamlit Chat"],
434
+ ),
435
  )
436
 
437
  st.markdown(full_response)
langchain-streamlit-demo/summarize.py CHANGED
@@ -2,6 +2,8 @@ from langchain.chains.base import Chain
2
  from langchain.chains.summarize import load_summarize_chain
3
  from langchain.prompts import PromptTemplate
4
  from langchain.schema.language_model import BaseLanguageModel
 
 
5
 
6
  prompt_template = """Write a concise summary of the following text, based on the user input.
7
  User input: {query}
@@ -49,3 +51,16 @@ def get_summarization_chain(
49
  input_key="input_documents",
50
  output_key="output_text",
51
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  from langchain.chains.summarize import load_summarize_chain
3
  from langchain.prompts import PromptTemplate
4
  from langchain.schema.language_model import BaseLanguageModel
5
+ from langchain.schema.retriever import BaseRetriever
6
+ from langchain.schema.runnable import RunnableSequence, RunnablePassthrough
7
 
8
  prompt_template = """Write a concise summary of the following text, based on the user input.
9
  User input: {query}
 
51
  input_key="input_documents",
52
  output_key="output_text",
53
  )
54
+
55
+
56
+ def get_rag_summarization_chain(
57
+ prompt: str,
58
+ retriever: BaseRetriever,
59
+ llm: BaseLanguageModel,
60
+ input_key: str = "prompt",
61
+ ) -> RunnableSequence:
62
+ return (
63
+ {"input_documents": retriever, input_key: RunnablePassthrough()}
64
+ | get_summarization_chain(llm, prompt)
65
+ | (lambda output: output["output_text"])
66
+ )