Joshua Sundance Bailey commited on
Commit
47c2ffc
1 Parent(s): 622ac66

qagen & summarize

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -7,12 +7,12 @@ import anthropic
7
  import langsmith.utils
8
  import openai
9
  import streamlit as st
10
- from langchain import LLMChain
11
  from langchain.callbacks import StreamlitCallbackHandler
12
  from langchain.callbacks.base import BaseCallbackHandler
13
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
14
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
15
  from langchain.chains import RetrievalQA
 
16
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
17
  from langchain.document_loaders import PyPDFLoader
18
  from langchain.embeddings import OpenAIEmbeddings
@@ -26,6 +26,7 @@ from langsmith.client import Client
26
  from streamlit_feedback import streamlit_feedback
27
 
28
  from qagen import get_qa_gen_chain, combine_qa_pair_lists
 
29
 
30
  __version__ = "0.0.6"
31
 
@@ -216,7 +217,14 @@ with sidebar:
216
  )
217
  document_chat_chain_type = st.selectbox(
218
  label="Document Chat Chain Type",
219
- options=["stuff", "refine", "map_reduce", "map_rerank", "Q&A Generation"],
 
 
 
 
 
 
 
220
  index=0,
221
  help=chain_type_help,
222
  disabled=not document_chat,
@@ -331,13 +339,7 @@ if st.session_state.llm:
331
  # --- Document Chat ---
332
  if st.session_state.retriever:
333
  if document_chat_chain_type == "Summarization":
334
- raise NotImplementedError
335
- # st.session_state.doc_chain = RetrievalQA.from_chain_type(
336
- # llm=st.session_state.llm,
337
- # chain_type=chain_type,
338
- # retriever=st.session_state.retriever,
339
- # memory=MEMORY,
340
- # )
341
  elif document_chat_chain_type == "Q&A Generation":
342
  st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
343
 
@@ -393,7 +395,17 @@ if st.session_state.llm:
393
  full_response: Union[str, None]
394
  if use_document_chat:
395
  if document_chat_chain_type == "Summarization":
396
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
397
  elif document_chat_chain_type == "Q&A Generation":
398
  config: Dict[str, Any] = dict(
399
  callbacks=callbacks,
@@ -409,14 +421,21 @@ if st.session_state.llm:
409
  config,
410
  )
411
  results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
412
- full_response = "\n".join(
413
- f"**Q:** {result.question}\n**A:** {result.answer}\n"
414
- for result in results
 
 
 
 
 
 
 
 
 
415
  )
416
- for idx, result in enumerate(results, start=1):
417
- st.markdown(f"{idx}. **Q:** {result.question}")
418
- st.markdown(f"{idx}. **A:** {result.answer}")
419
- st.markdown("\n")
420
 
421
  else:
422
  st_handler = StreamlitCallbackHandler(st.container())
 
7
  import langsmith.utils
8
  import openai
9
  import streamlit as st
 
10
  from langchain.callbacks import StreamlitCallbackHandler
11
  from langchain.callbacks.base import BaseCallbackHandler
12
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
13
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
14
  from langchain.chains import RetrievalQA
15
+ from langchain.chains.llm import LLMChain
16
  from langchain.chat_models import ChatOpenAI, ChatAnyscale, ChatAnthropic
17
  from langchain.document_loaders import PyPDFLoader
18
  from langchain.embeddings import OpenAIEmbeddings
 
26
  from streamlit_feedback import streamlit_feedback
27
 
28
  from qagen import get_qa_gen_chain, combine_qa_pair_lists
29
+ from summarize import get_summarization_chain
30
 
31
  __version__ = "0.0.6"
32
 
 
217
  )
218
  document_chat_chain_type = st.selectbox(
219
  label="Document Chat Chain Type",
220
+ options=[
221
+ "stuff",
222
+ "refine",
223
+ "map_reduce",
224
+ "map_rerank",
225
+ "Q&A Generation",
226
+ "Summarization",
227
+ ],
228
  index=0,
229
  help=chain_type_help,
230
  disabled=not document_chat,
 
339
  # --- Document Chat ---
340
  if st.session_state.retriever:
341
  if document_chat_chain_type == "Summarization":
342
+ st.session_state.doc_chain = "summarization"
 
 
 
 
 
 
343
  elif document_chat_chain_type == "Q&A Generation":
344
  st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
345
 
 
395
  full_response: Union[str, None]
396
  if use_document_chat:
397
  if document_chat_chain_type == "Summarization":
398
+ st.session_state.doc_chain = get_summarization_chain(
399
+ st.session_state.llm,
400
+ prompt,
401
+ )
402
+ full_response = st.session_state.doc_chain.run(
403
+ st.session_state.texts,
404
+ callbacks=callbacks,
405
+ tags=["Streamlit Chat"],
406
+ )
407
+
408
+ st.markdown(full_response)
409
  elif document_chat_chain_type == "Q&A Generation":
410
  config: Dict[str, Any] = dict(
411
  callbacks=callbacks,
 
421
  config,
422
  )
423
  results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
424
+
425
+ def _to_str(idx, qap):
426
+ question_piece = f"{idx}. **Q:** {qap.question}"
427
+ whitespace = " " * (len(str(idx)) + 2)
428
+ answer_piece = f"{whitespace}**A:** {qap.answer}"
429
+ return f"{question_piece}\n{answer_piece}"
430
+
431
+ output_text = "\n\n".join(
432
+ [
433
+ _to_str(idx, qap)
434
+ for idx, qap in enumerate(results, start=1)
435
+ ],
436
  )
437
+
438
+ st.markdown(output_text)
 
 
439
 
440
  else:
441
  st_handler = StreamlitCallbackHandler(st.container())
langchain-streamlit-demo/summarize.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain.chains.base import Chain
2
+ from langchain.chains.summarize import load_summarize_chain
3
+ from langchain.prompts import PromptTemplate
4
+ from langchain.schema.language_model import BaseLanguageModel
5
+
6
+ prompt_template = """Write a concise summary of the following text, based on the user input.
7
+ User input: {query}
8
+ Text:
9
+ ```
10
+ {text}
11
+ ```
12
+ CONCISE SUMMARY:"""
13
+
14
+ refine_template = (
15
+ "You are iteratively crafting a summary of the text below based on the user input\n"
16
+ "User input: {query}"
17
+ "We have provided an existing summary up to a certain point: {existing_answer}\n"
18
+ "We have the opportunity to refine the existing summary"
19
+ "(only if needed) with some more context below.\n"
20
+ "------------\n"
21
+ "{text}\n"
22
+ "------------\n"
23
+ "Given the new context, refine the original summary.\n"
24
+ "If the context isn't useful, return the original summary.\n"
25
+ "If the context is useful, refine the summary to include the new context.\n"
26
+ "Your contribution is helping to build a comprehensive summary of a large body of knowledge.\n"
27
+ "You do not have the complete context, so do not discard pieces of the original summary."
28
+ )
29
+
30
+
31
+ def get_summarization_chain(
32
+ llm: BaseLanguageModel,
33
+ prompt: str,
34
+ ) -> Chain:
35
+ _prompt = PromptTemplate.from_template(
36
+ prompt_template,
37
+ partial_variables={"query": prompt},
38
+ )
39
+ refine_prompt = PromptTemplate.from_template(
40
+ refine_template,
41
+ partial_variables={"query": prompt},
42
+ )
43
+ return load_summarize_chain(
44
+ llm=llm,
45
+ chain_type="refine",
46
+ question_prompt=_prompt,
47
+ refine_prompt=refine_prompt,
48
+ return_intermediate_steps=False,
49
+ input_key="input_documents",
50
+ output_key="output_text",
51
+ )