Joshua Sundance Bailey commited on
Commit
923e6fa
1 Parent(s): bfaa0c3
langchain-streamlit-demo/app.py CHANGED
@@ -7,7 +7,6 @@ import anthropic
7
  import langsmith.utils
8
  import openai
9
  import streamlit as st
10
- from langchain.callbacks import StreamlitCallbackHandler
11
  from langchain.callbacks.base import BaseCallbackHandler
12
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
13
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
@@ -360,40 +359,17 @@ for msg in STMEMORY.messages:
360
 
361
  # --- Current Chat ---
362
  if st.session_state.llm:
363
- # --- Document Chat ---
364
- if st.session_state.retriever:
365
- if document_chat_chain_type == "Summarization":
366
- st.session_state.doc_chain = "summarization"
367
- elif document_chat_chain_type == "Q&A Generation":
368
- st.session_state.doc_chain = get_rag_qa_gen_chain(
369
- st.session_state.retriever,
370
- st.session_state.llm,
371
- )
372
- else:
373
- st.session_state.doc_chain = RetrievalQA.from_chain_type(
374
- llm=st.session_state.llm,
375
- chain_type=document_chat_chain_type,
376
- retriever=st.session_state.retriever,
377
- memory=MEMORY,
378
- )
379
-
380
- else:
381
- # --- Regular Chat ---
382
- chat_prompt = ChatPromptTemplate.from_messages(
383
- [
384
- (
385
- "system",
386
- system_prompt + "\nIt's currently {time}.",
387
- ),
388
- MessagesPlaceholder(variable_name="chat_history"),
389
- ("human", "{query}"),
390
- ],
391
- ).partial(time=lambda: str(datetime.now()))
392
- st.session_state.chain = LLMChain(
393
- prompt=chat_prompt,
394
- llm=st.session_state.llm,
395
- memory=MEMORY,
396
- )
397
 
398
  # --- Chat Input ---
399
  prompt = st.chat_input(placeholder="Ask me a question!")
@@ -419,57 +395,60 @@ if st.session_state.llm:
419
  use_document_chat = all(
420
  [
421
  document_chat,
422
- st.session_state.doc_chain,
423
  st.session_state.retriever,
424
  ],
425
  )
426
 
 
 
427
  try:
428
- full_response: Union[str, None]
429
- if use_document_chat:
430
- if document_chat_chain_type in ("Summarization", "Q&A Generation"):
431
- if document_chat_chain_type == "Summarization":
432
- st.session_state.doc_chain = get_rag_summarization_chain(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
433
  prompt,
434
  st.session_state.retriever,
435
  st.session_state.llm,
436
  )
437
- full_response = st.session_state.doc_chain.invoke(
438
- prompt,
439
- config,
440
- )
441
-
442
- else:
443
- st_handler = StreamlitCallbackHandler(st.container())
444
- callbacks.append(st_handler)
445
- full_response = st.session_state.doc_chain(
446
- {"query": prompt},
447
- callbacks=callbacks,
448
- tags=["Streamlit Chat"],
449
- return_only_outputs=True,
450
- )[st.session_state.doc_chain.output_key]
451
- st_handler._complete_current_thought()
452
-
453
  st.markdown(full_response)
454
 
455
- else:
456
- message_placeholder = st.empty()
457
- stream_handler = StreamHandler(message_placeholder)
458
- callbacks.append(stream_handler)
459
- full_response = st.session_state.chain(
460
- {"query": prompt},
461
- callbacks=callbacks,
462
- tags=["Streamlit Chat"],
463
- return_only_outputs=True,
464
- )[st.session_state.chain.output_key]
465
- message_placeholder.markdown(full_response)
466
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
467
  st.error(
468
  f"Please enter a valid {st.session_state.provider} API key.",
469
  icon="❌",
470
  )
471
- full_response = None
472
- if full_response:
473
  # --- Tracing ---
474
  if st.session_state.client:
475
  st.session_state.run = RUN_COLLECTOR.traced_runs[0]
 
7
  import langsmith.utils
8
  import openai
9
  import streamlit as st
 
10
  from langchain.callbacks.base import BaseCallbackHandler
11
  from langchain.callbacks.tracers.langchain import LangChainTracer, wait_for_all_tracers
12
  from langchain.callbacks.tracers.run_collector import RunCollectorCallbackHandler
 
359
 
360
  # --- Current Chat ---
361
  if st.session_state.llm:
362
+ # --- Regular Chat ---
363
+ chat_prompt = ChatPromptTemplate.from_messages(
364
+ [
365
+ (
366
+ "system",
367
+ system_prompt + "\nIt's currently {time}.",
368
+ ),
369
+ MessagesPlaceholder(variable_name="chat_history"),
370
+ ("human", "{query}"),
371
+ ],
372
+ ).partial(time=lambda: str(datetime.now()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
  # --- Chat Input ---
375
  prompt = st.chat_input(placeholder="Ask me a question!")
 
395
  use_document_chat = all(
396
  [
397
  document_chat,
 
398
  st.session_state.retriever,
399
  ],
400
  )
401
 
402
+ full_response = None
403
+
404
  try:
405
+ if not use_document_chat:
406
+ message_placeholder = st.empty()
407
+ stream_handler = StreamHandler(message_placeholder)
408
+ callbacks.append(stream_handler)
409
+ st.session_state.chain = LLMChain(
410
+ prompt=chat_prompt,
411
+ llm=st.session_state.llm,
412
+ memory=MEMORY,
413
+ ) | (lambda output: output["text"])
414
+ config = {"callbacks": callbacks, "tags": ["Streamlit Chat"]}
415
+ full_response = st.session_state.chain.invoke(prompt, config)
416
+ message_placeholder.markdown(full_response)
417
+
418
+ else:
419
+
420
+ def get_rag_runnable():
421
+ if document_chat_chain_type == "Q&A Generation":
422
+ return get_rag_qa_gen_chain(
423
+ st.session_state.retriever,
424
+ st.session_state.llm,
425
+ )
426
+ elif document_chat_chain_type == "Summarization":
427
+ return get_rag_summarization_chain(
428
  prompt,
429
  st.session_state.retriever,
430
  st.session_state.llm,
431
  )
432
+ else:
433
+ return RetrievalQA.from_chain_type(
434
+ llm=st.session_state.llm,
435
+ chain_type=document_chat_chain_type,
436
+ retriever=st.session_state.retriever,
437
+ memory=MEMORY,
438
+ output_key="output_text",
439
+ ) | (lambda output: output["output_text"])
440
+
441
+ st.session_state.doc_chain = get_rag_runnable()
442
+
443
+ full_response = st.session_state.doc_chain.invoke(prompt, config)
 
 
 
 
444
  st.markdown(full_response)
445
 
 
 
 
 
 
 
 
 
 
 
 
446
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
447
  st.error(
448
  f"Please enter a valid {st.session_state.provider} API key.",
449
  icon="❌",
450
  )
451
+ if full_response is not None:
 
452
  # --- Tracing ---
453
  if st.session_state.client:
454
  st.session_state.run = RUN_COLLECTOR.traced_runs[0]
langchain-streamlit-demo/qagen.py CHANGED
@@ -1,4 +1,3 @@
1
- from functools import reduce
2
  from typing import List
3
 
4
  from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
@@ -59,25 +58,6 @@ CHAT_PROMPT = ChatPromptTemplate.from_messages(
59
  ).partial(format_instructions=PYDANTIC_PARSER.get_format_instructions)
60
 
61
 
62
- def combine_qa_pair_lists(
63
- qa_pair_lists: List[QuestionAnswerPairList],
64
- ) -> QuestionAnswerPairList:
65
- def reducer(
66
- accumulator: QuestionAnswerPairList,
67
- current: QuestionAnswerPairList,
68
- ) -> QuestionAnswerPairList:
69
- return QuestionAnswerPairList(
70
- QuestionAnswerPairs=accumulator.QuestionAnswerPairs
71
- + current.QuestionAnswerPairs,
72
- )
73
-
74
- return reduce(
75
- reducer,
76
- qa_pair_lists,
77
- QuestionAnswerPairList(QuestionAnswerPairs=[]),
78
- )
79
-
80
-
81
  def get_rag_qa_gen_chain(
82
  retriever: BaseRetriever,
83
  llm: BaseLanguageModel,
@@ -88,5 +68,5 @@ def get_rag_qa_gen_chain(
88
  | CHAT_PROMPT
89
  | llm
90
  | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
91
- | (lambda parsed_output: combine_qa_pair_lists(parsed_output).to_str())
92
  )
 
 
1
  from typing import List
2
 
3
  from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
 
58
  ).partial(format_instructions=PYDANTIC_PARSER.get_format_instructions)
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def get_rag_qa_gen_chain(
62
  retriever: BaseRetriever,
63
  llm: BaseLanguageModel,
 
68
  | CHAT_PROMPT
69
  | llm
70
  | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
71
+ | (lambda parsed_output: parsed_output.to_str())
72
  )