Joshua Sundance Bailey commited on
Commit
bfaa0c3
1 Parent(s): dd9bfbd
langchain-streamlit-demo/app.py CHANGED
@@ -409,6 +409,13 @@ if st.session_state.llm:
409
  if st.session_state.ls_tracer:
410
  callbacks.append(st.session_state.ls_tracer)
411
 
 
 
 
 
 
 
 
412
  use_document_chat = all(
413
  [
414
  document_chat,
@@ -420,46 +427,18 @@ if st.session_state.llm:
420
  try:
421
  full_response: Union[str, None]
422
  if use_document_chat:
423
- if document_chat_chain_type == "Summarization":
424
- st.session_state.doc_chain = get_rag_summarization_chain(
425
- prompt,
426
- st.session_state.retriever,
427
- st.session_state.llm,
428
- )
 
429
  full_response = st.session_state.doc_chain.invoke(
430
  prompt,
431
- dict(
432
- callbacks=callbacks,
433
- tags=["Streamlit Chat"],
434
- ),
435
- )
436
-
437
- st.markdown(full_response)
438
- elif document_chat_chain_type == "Q&A Generation":
439
- config: Dict[str, Any] = dict(
440
- callbacks=callbacks,
441
- tags=["Streamlit Chat"],
442
- )
443
- if st.session_state.provider == "Anthropic":
444
- config["max_concurrency"] = 5
445
- raw_results = st.session_state.doc_chain.invoke(prompt, config)
446
- results = raw_results.QuestionAnswerPairs
447
-
448
- def _to_str(idx, qap):
449
- question_piece = f"{idx}. **Q:** {qap.question}"
450
- whitespace = " " * (len(str(idx)) + 2)
451
- answer_piece = f"{whitespace}**A:** {qap.answer}"
452
- return f"{question_piece}\n\n{answer_piece}"
453
-
454
- full_response = "\n\n".join(
455
- [
456
- _to_str(idx, qap)
457
- for idx, qap in enumerate(results, start=1)
458
- ],
459
  )
460
 
461
- st.markdown(full_response)
462
-
463
  else:
464
  st_handler = StreamlitCallbackHandler(st.container())
465
  callbacks.append(st_handler)
@@ -470,7 +449,9 @@ if st.session_state.llm:
470
  return_only_outputs=True,
471
  )[st.session_state.doc_chain.output_key]
472
  st_handler._complete_current_thought()
473
- st.markdown(full_response)
 
 
474
  else:
475
  message_placeholder = st.empty()
476
  stream_handler = StreamHandler(message_placeholder)
 
409
  if st.session_state.ls_tracer:
410
  callbacks.append(st.session_state.ls_tracer)
411
 
412
+ config: Dict[str, Any] = dict(
413
+ callbacks=callbacks,
414
+ tags=["Streamlit Chat"],
415
+ )
416
+ if st.session_state.provider == "Anthropic":
417
+ config["max_concurrency"] = 5
418
+
419
  use_document_chat = all(
420
  [
421
  document_chat,
 
427
  try:
428
  full_response: Union[str, None]
429
  if use_document_chat:
430
+ if document_chat_chain_type in ("Summarization", "Q&A Generation"):
431
+ if document_chat_chain_type == "Summarization":
432
+ st.session_state.doc_chain = get_rag_summarization_chain(
433
+ prompt,
434
+ st.session_state.retriever,
435
+ st.session_state.llm,
436
+ )
437
  full_response = st.session_state.doc_chain.invoke(
438
  prompt,
439
+ config,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
440
  )
441
 
 
 
442
  else:
443
  st_handler = StreamlitCallbackHandler(st.container())
444
  callbacks.append(st_handler)
 
449
  return_only_outputs=True,
450
  )[st.session_state.doc_chain.output_key]
451
  st_handler._complete_current_thought()
452
+
453
+ st.markdown(full_response)
454
+
455
  else:
456
  message_placeholder = st.empty()
457
  stream_handler = StreamHandler(message_placeholder)
langchain-streamlit-demo/qagen.py CHANGED
@@ -15,10 +15,24 @@ class QuestionAnswerPair(BaseModel):
15
  question: str = Field(..., description="The question that will be answered.")
16
  answer: str = Field(..., description="The answer to the question that was asked.")
17
 
 
 
 
 
 
 
18
 
19
  class QuestionAnswerPairList(BaseModel):
20
  QuestionAnswerPairs: List[QuestionAnswerPair]
21
 
 
 
 
 
 
 
 
 
22
 
23
  PYDANTIC_PARSER: PydanticOutputParser = PydanticOutputParser(
24
  pydantic_object=QuestionAnswerPairList,
@@ -74,4 +88,5 @@ def get_rag_qa_gen_chain(
74
  | CHAT_PROMPT
75
  | llm
76
  | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
 
77
  )
 
15
  question: str = Field(..., description="The question that will be answered.")
16
  answer: str = Field(..., description="The answer to the question that was asked.")
17
 
18
+ def to_str(self, idx: int) -> str:
19
+ question_piece = f"{idx}. **Q:** {self.question}"
20
+ whitespace = " " * (len(str(idx)) + 2)
21
+ answer_piece = f"{whitespace}**A:** {self.answer}"
22
+ return f"{question_piece}\n\n{answer_piece}"
23
+
24
 
25
  class QuestionAnswerPairList(BaseModel):
26
  QuestionAnswerPairs: List[QuestionAnswerPair]
27
 
28
+ def to_str(self) -> str:
29
+ return "\n\n".join(
30
+ [
31
+ qap.to_str(idx)
32
+ for idx, qap in enumerate(self.QuestionAnswerPairs, start=1)
33
+ ],
34
+ )
35
+
36
 
37
  PYDANTIC_PARSER: PydanticOutputParser = PydanticOutputParser(
38
  pydantic_object=QuestionAnswerPairList,
 
88
  | CHAT_PROMPT
89
  | llm
90
  | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
91
+ | (lambda parsed_output: combine_qa_pair_lists(parsed_output).to_str())
92
  )