Joshua Sundance Bailey commited on
Commit
622ac66
1 Parent(s): e4344c4
.idea/.name CHANGED
@@ -1 +1 @@
1
- langchain-streamlit-demo
 
1
+ langchain-streamlit-demo
.idea/inspectionProfiles/Project_Default.xml CHANGED
@@ -18,4 +18,4 @@
18
  </inspection_tool>
19
  <inspection_tool class="PyShadowingNamesInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
20
  </profile>
21
- </component>
 
18
  </inspection_tool>
19
  <inspection_tool class="PyShadowingNamesInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
20
  </profile>
21
+ </component>
.idea/inspectionProfiles/profiles_settings.xml CHANGED
@@ -3,4 +3,4 @@
3
  <option name="USE_PROJECT_PROFILE" value="false" />
4
  <version value="1.0" />
5
  </settings>
6
- </component>
 
3
  <option name="USE_PROJECT_PROFILE" value="false" />
4
  <version value="1.0" />
5
  </settings>
6
+ </component>
.idea/kubernetes-settings.xml CHANGED
@@ -3,4 +3,4 @@
3
  <component name="KubernetesSettings">
4
  <option name="contextName" value="swca-aks" />
5
  </component>
6
- </project>
 
3
  <component name="KubernetesSettings">
4
  <option name="contextName" value="swca-aks" />
5
  </component>
6
+ </project>
.idea/langchain-streamlit-demo.iml CHANGED
@@ -5,4 +5,4 @@
5
  <orderEntry type="jdk" jdkName="Remote Python 3.11.4 Docker (&lt;none&gt;:&lt;none&gt;) (5)" jdkType="Python SDK" />
6
  <orderEntry type="sourceFolder" forTests="false" />
7
  </component>
8
- </module>
 
5
  <orderEntry type="jdk" jdkName="Remote Python 3.11.4 Docker (&lt;none&gt;:&lt;none&gt;) (5)" jdkType="Python SDK" />
6
  <orderEntry type="sourceFolder" forTests="false" />
7
  </component>
8
+ </module>
.idea/misc.xml CHANGED
@@ -1,4 +1,4 @@
1
  <?xml version="1.0" encoding="UTF-8"?>
2
  <project version="4">
3
  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.11.4 Docker (&lt;none&gt;:&lt;none&gt;) (5)" project-jdk-type="Python SDK" />
4
- </project>
 
1
  <?xml version="1.0" encoding="UTF-8"?>
2
  <project version="4">
3
  <component name="ProjectRootManager" version="2" project-jdk-name="Remote Python 3.11.4 Docker (&lt;none&gt;:&lt;none&gt;) (5)" project-jdk-type="Python SDK" />
4
+ </project>
.idea/modules.xml CHANGED
@@ -5,4 +5,4 @@
5
  <module fileurl="file://$PROJECT_DIR$/.idea/langchain-streamlit-demo.iml" filepath="$PROJECT_DIR$/.idea/langchain-streamlit-demo.iml" />
6
  </modules>
7
  </component>
8
- </project>
 
5
  <module fileurl="file://$PROJECT_DIR$/.idea/langchain-streamlit-demo.iml" filepath="$PROJECT_DIR$/.idea/langchain-streamlit-demo.iml" />
6
  </modules>
7
  </component>
8
+ </project>
.idea/vcs.xml CHANGED
@@ -3,4 +3,4 @@
3
  <component name="VcsDirectoryMappings">
4
  <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
  </component>
6
- </project>
 
3
  <component name="VcsDirectoryMappings">
4
  <mapping directory="$PROJECT_DIR$" vcs="Git" />
5
  </component>
6
+ </project>
langchain-streamlit-demo/app.py CHANGED
@@ -1,7 +1,7 @@
1
  import os
2
  from datetime import datetime
3
  from tempfile import NamedTemporaryFile
4
- from typing import Union
5
 
6
  import anthropic
7
  import langsmith.utils
@@ -18,12 +18,15 @@ from langchain.document_loaders import PyPDFLoader
18
  from langchain.embeddings import OpenAIEmbeddings
19
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
20
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
 
21
  from langchain.schema.retriever import BaseRetriever
22
  from langchain.text_splitter import RecursiveCharacterTextSplitter
23
  from langchain.vectorstores import FAISS
24
  from langsmith.client import Client
25
  from streamlit_feedback import streamlit_feedback
26
 
 
 
27
  __version__ = "0.0.6"
28
 
29
  # --- Initialization ---
@@ -46,6 +49,7 @@ st_init_null(
46
  "document_chat_chain_type",
47
  "llm",
48
  "ls_tracer",
 
49
  "retriever",
50
  "run",
51
  "run_id",
@@ -120,11 +124,11 @@ DEFAULT_CHUNK_OVERLAP = 0
120
 
121
 
122
  @st.cache_data
123
- def get_retriever(
124
  uploaded_file_bytes: bytes,
125
  chunk_size: int = DEFAULT_CHUNK_SIZE,
126
  chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
127
- ) -> BaseRetriever:
128
  with NamedTemporaryFile() as temp_file:
129
  temp_file.write(uploaded_file_bytes)
130
  temp_file.seek(0)
@@ -138,7 +142,7 @@ def get_retriever(
138
  texts = text_splitter.split_documents(documents)
139
  embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
140
  db = FAISS.from_documents(texts, embeddings)
141
- return db.as_retriever()
142
 
143
 
144
  # --- Sidebar ---
@@ -152,10 +156,12 @@ with sidebar:
152
  index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
153
  )
154
 
155
- provider = MODEL_DICT[model]
156
 
157
- provider_api_key = PROVIDER_KEY_DICT.get(provider) or st.text_input(
158
- f"{provider} API key",
 
 
159
  type="password",
160
  )
161
 
@@ -170,7 +176,7 @@ with sidebar:
170
 
171
  openai_api_key = (
172
  provider_api_key
173
- if provider == "OpenAI"
174
  else OPENAI_API_KEY
175
  or st.sidebar.text_input("OpenAI API Key: ", type="password")
176
  )
@@ -210,7 +216,7 @@ with sidebar:
210
  )
211
  document_chat_chain_type = st.selectbox(
212
  label="Document Chat Chain Type",
213
- options=["stuff", "refine", "map_reduce", "map_rerank"],
214
  index=0,
215
  help=chain_type_help,
216
  disabled=not document_chat,
@@ -218,7 +224,10 @@ with sidebar:
218
 
219
  if uploaded_file:
220
  if openai_api_key:
221
- st.session_state.retriever = get_retriever(
 
 
 
222
  uploaded_file_bytes=uploaded_file.getvalue(),
223
  chunk_size=chunk_size,
224
  chunk_overlap=chunk_overlap,
@@ -280,7 +289,7 @@ with sidebar:
280
 
281
  # --- LLM Instantiation ---
282
  if provider_api_key:
283
- if provider == "OpenAI":
284
  st.session_state.llm = ChatOpenAI(
285
  model=model,
286
  openai_api_key=provider_api_key,
@@ -288,7 +297,7 @@ if provider_api_key:
288
  streaming=True,
289
  max_tokens=max_tokens,
290
  )
291
- elif provider == "Anthropic":
292
  st.session_state.llm = ChatAnthropic(
293
  model_name=model,
294
  anthropic_api_key=provider_api_key,
@@ -296,7 +305,7 @@ if provider_api_key:
296
  streaming=True,
297
  max_tokens_to_sample=max_tokens,
298
  )
299
- elif provider == "Anyscale Endpoints":
300
  st.session_state.llm = ChatAnyscale(
301
  model=model,
302
  anyscale_api_key=provider_api_key,
@@ -321,18 +330,24 @@ for msg in STMEMORY.messages:
321
  if st.session_state.llm:
322
  # --- Document Chat ---
323
  if st.session_state.retriever:
324
- # st.session_state.doc_chain = ConversationalRetrievalChain.from_llm(
325
- # st.session_state.llm,
326
- # st.session_state.retriever,
327
- # memory=MEMORY,
328
- # )
329
-
330
- st.session_state.doc_chain = RetrievalQA.from_chain_type(
331
- llm=st.session_state.llm,
332
- chain_type=document_chat_chain_type,
333
- retriever=st.session_state.retriever,
334
- memory=MEMORY,
335
- )
 
 
 
 
 
 
336
 
337
  else:
338
  # --- Regular Chat ---
@@ -375,17 +390,45 @@ if st.session_state.llm:
375
  )
376
 
377
  try:
 
378
  if use_document_chat:
379
- st_handler = StreamlitCallbackHandler(st.container())
380
- callbacks.append(st_handler)
381
- full_response = st.session_state.doc_chain(
382
- {"query": prompt},
383
- callbacks=callbacks,
384
- tags=["Streamlit Chat"],
385
- return_only_outputs=True,
386
- )[st.session_state.doc_chain.output_key]
387
- st_handler._complete_current_thought()
388
- st.markdown(full_response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  else:
390
  message_placeholder = st.empty()
391
  stream_handler = StreamHandler(message_placeholder)
@@ -399,7 +442,7 @@ if st.session_state.llm:
399
  message_placeholder.markdown(full_response)
400
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
401
  st.error(
402
- f"Please enter a valid {provider} API key.",
403
  icon="❌",
404
  )
405
  full_response = None
@@ -468,4 +511,4 @@ if st.session_state.llm:
468
  st.warning("Invalid feedback score.")
469
 
470
  else:
471
- st.error(f"Please enter a valid {provider} API key.", icon="❌")
 
1
  import os
2
  from datetime import datetime
3
  from tempfile import NamedTemporaryFile
4
+ from typing import Tuple, List, Dict, Any, Union
5
 
6
  import anthropic
7
  import langsmith.utils
 
18
  from langchain.embeddings import OpenAIEmbeddings
19
  from langchain.memory import ConversationBufferMemory, StreamlitChatMessageHistory
20
  from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
21
+ from langchain.schema.document import Document
22
  from langchain.schema.retriever import BaseRetriever
23
  from langchain.text_splitter import RecursiveCharacterTextSplitter
24
  from langchain.vectorstores import FAISS
25
  from langsmith.client import Client
26
  from streamlit_feedback import streamlit_feedback
27
 
28
+ from qagen import get_qa_gen_chain, combine_qa_pair_lists
29
+
30
  __version__ = "0.0.6"
31
 
32
  # --- Initialization ---
 
49
  "document_chat_chain_type",
50
  "llm",
51
  "ls_tracer",
52
+ "provider",
53
  "retriever",
54
  "run",
55
  "run_id",
 
124
 
125
 
126
  @st.cache_data
127
+ def get_texts_and_retriever(
128
  uploaded_file_bytes: bytes,
129
  chunk_size: int = DEFAULT_CHUNK_SIZE,
130
  chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
131
+ ) -> Tuple[List[Document], BaseRetriever]:
132
  with NamedTemporaryFile() as temp_file:
133
  temp_file.write(uploaded_file_bytes)
134
  temp_file.seek(0)
 
142
  texts = text_splitter.split_documents(documents)
143
  embeddings = OpenAIEmbeddings(openai_api_key=openai_api_key)
144
  db = FAISS.from_documents(texts, embeddings)
145
+ return texts, db.as_retriever()
146
 
147
 
148
  # --- Sidebar ---
 
156
  index=SUPPORTED_MODELS.index(DEFAULT_MODEL),
157
  )
158
 
159
+ st.session_state.provider = MODEL_DICT[model]
160
 
161
+ provider_api_key = PROVIDER_KEY_DICT.get(
162
+ st.session_state.provider,
163
+ ) or st.text_input(
164
+ f"{st.session_state.provider} API key",
165
  type="password",
166
  )
167
 
 
176
 
177
  openai_api_key = (
178
  provider_api_key
179
+ if st.session_state.provider == "OpenAI"
180
  else OPENAI_API_KEY
181
  or st.sidebar.text_input("OpenAI API Key: ", type="password")
182
  )
 
216
  )
217
  document_chat_chain_type = st.selectbox(
218
  label="Document Chat Chain Type",
219
+ options=["stuff", "refine", "map_reduce", "map_rerank", "Q&A Generation"],
220
  index=0,
221
  help=chain_type_help,
222
  disabled=not document_chat,
 
224
 
225
  if uploaded_file:
226
  if openai_api_key:
227
+ (
228
+ st.session_state.texts,
229
+ st.session_state.retriever,
230
+ ) = get_texts_and_retriever(
231
  uploaded_file_bytes=uploaded_file.getvalue(),
232
  chunk_size=chunk_size,
233
  chunk_overlap=chunk_overlap,
 
289
 
290
  # --- LLM Instantiation ---
291
  if provider_api_key:
292
+ if st.session_state.provider == "OpenAI":
293
  st.session_state.llm = ChatOpenAI(
294
  model=model,
295
  openai_api_key=provider_api_key,
 
297
  streaming=True,
298
  max_tokens=max_tokens,
299
  )
300
+ elif st.session_state.provider == "Anthropic":
301
  st.session_state.llm = ChatAnthropic(
302
  model_name=model,
303
  anthropic_api_key=provider_api_key,
 
305
  streaming=True,
306
  max_tokens_to_sample=max_tokens,
307
  )
308
+ elif st.session_state.provider == "Anyscale Endpoints":
309
  st.session_state.llm = ChatAnyscale(
310
  model=model,
311
  anyscale_api_key=provider_api_key,
 
330
  if st.session_state.llm:
331
  # --- Document Chat ---
332
  if st.session_state.retriever:
333
+ if document_chat_chain_type == "Summarization":
334
+ raise NotImplementedError
335
+ # st.session_state.doc_chain = RetrievalQA.from_chain_type(
336
+ # llm=st.session_state.llm,
337
+ # chain_type=chain_type,
338
+ # retriever=st.session_state.retriever,
339
+ # memory=MEMORY,
340
+ # )
341
+ elif document_chat_chain_type == "Q&A Generation":
342
+ st.session_state.doc_chain = get_qa_gen_chain(st.session_state.llm)
343
+
344
+ else:
345
+ st.session_state.doc_chain = RetrievalQA.from_chain_type(
346
+ llm=st.session_state.llm,
347
+ chain_type=document_chat_chain_type,
348
+ retriever=st.session_state.retriever,
349
+ memory=MEMORY,
350
+ )
351
 
352
  else:
353
  # --- Regular Chat ---
 
390
  )
391
 
392
  try:
393
+ full_response: Union[str, None]
394
  if use_document_chat:
395
+ if document_chat_chain_type == "Summarization":
396
+ raise NotImplementedError
397
+ elif document_chat_chain_type == "Q&A Generation":
398
+ config: Dict[str, Any] = dict(
399
+ callbacks=callbacks,
400
+ tags=["Streamlit Chat"],
401
+ )
402
+ if st.session_state.provider == "Anthropic":
403
+ config["max_concurrency"] = 5
404
+ raw_results = st.session_state.doc_chain.batch(
405
+ [
406
+ {"input": doc.page_content, "prompt": prompt}
407
+ for doc in st.session_state.texts
408
+ ],
409
+ config,
410
+ )
411
+ results = combine_qa_pair_lists(raw_results).QuestionAnswerPairs
412
+ full_response = "\n".join(
413
+ f"**Q:** {result.question}\n**A:** {result.answer}\n"
414
+ for result in results
415
+ )
416
+ for idx, result in enumerate(results, start=1):
417
+ st.markdown(f"{idx}. **Q:** {result.question}")
418
+ st.markdown(f"{idx}. **A:** {result.answer}")
419
+ st.markdown("\n")
420
+
421
+ else:
422
+ st_handler = StreamlitCallbackHandler(st.container())
423
+ callbacks.append(st_handler)
424
+ full_response = st.session_state.doc_chain(
425
+ {"query": prompt},
426
+ callbacks=callbacks,
427
+ tags=["Streamlit Chat"],
428
+ return_only_outputs=True,
429
+ )[st.session_state.doc_chain.output_key]
430
+ st_handler._complete_current_thought()
431
+ st.markdown(full_response)
432
  else:
433
  message_placeholder = st.empty()
434
  stream_handler = StreamHandler(message_placeholder)
 
442
  message_placeholder.markdown(full_response)
443
  except (openai.error.AuthenticationError, anthropic.AuthenticationError):
444
  st.error(
445
+ f"Please enter a valid {st.session_state.provider} API key.",
446
  icon="❌",
447
  )
448
  full_response = None
 
511
  st.warning("Invalid feedback score.")
512
 
513
  else:
514
+ st.error(f"Please enter a valid {st.session_state.provider} API key.", icon="❌")
langchain-streamlit-demo/qagen.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import reduce
2
+ from typing import List
3
+
4
+ from langchain.output_parsers import PydanticOutputParser, OutputFixingParser
5
+ from langchain.prompts.chat import (
6
+ ChatPromptTemplate,
7
+ )
8
+ from langchain.schema.language_model import BaseLanguageModel
9
+ from langchain.schema.runnable import RunnableSequence
10
+ from pydantic import BaseModel, field_validator, Field
11
+
12
+
13
+ class QuestionAnswerPair(BaseModel):
14
+ question: str = Field(..., description="The question that will be answered.")
15
+ answer: str = Field(..., description="The answer to the question that was asked.")
16
+
17
+ @field_validator("question")
18
+ def validate_question(cls, v: str) -> str:
19
+ if not v.endswith("?"):
20
+ raise ValueError("Question must end with a question mark.")
21
+ return v
22
+
23
+
24
+ class QuestionAnswerPairList(BaseModel):
25
+ QuestionAnswerPairs: List[QuestionAnswerPair]
26
+
27
+
28
+ PYDANTIC_PARSER: PydanticOutputParser = PydanticOutputParser(
29
+ pydantic_object=QuestionAnswerPairList,
30
+ )
31
+
32
+
33
+ templ1 = """You are a smart assistant designed to help college professors come up with reading comprehension questions.
34
+ Given a piece of text, you must come up with question and answer pairs that can be used to test a student's reading comprehension abilities.
35
+ Generate as many question/answer pairs as you can.
36
+ When coming up with the question/answer pairs, you must respond in the following format:
37
+ {format_instructions}
38
+
39
+ Do not provide additional commentary and do not wrap your response in Markdown formatting. Return RAW, VALID JSON.
40
+ """
41
+ templ2 = """{prompt}
42
+ Please create question/answer pairs, in the specified JSON format, for the following text:
43
+ ----------------
44
+ {input}"""
45
+ CHAT_PROMPT = ChatPromptTemplate.from_messages(
46
+ [
47
+ ("system", templ1),
48
+ ("human", templ2),
49
+ ],
50
+ ).partial(format_instructions=PYDANTIC_PARSER.get_format_instructions)
51
+
52
+
53
+ def combine_qa_pair_lists(
54
+ qa_pair_lists: List[QuestionAnswerPairList],
55
+ ) -> QuestionAnswerPairList:
56
+ def reducer(
57
+ accumulator: QuestionAnswerPairList,
58
+ current: QuestionAnswerPairList,
59
+ ) -> QuestionAnswerPairList:
60
+ return QuestionAnswerPairList(
61
+ QuestionAnswerPairs=accumulator.QuestionAnswerPairs
62
+ + current.QuestionAnswerPairs,
63
+ )
64
+
65
+ return reduce(
66
+ reducer,
67
+ qa_pair_lists,
68
+ QuestionAnswerPairList(QuestionAnswerPairs=[]),
69
+ )
70
+
71
+
72
+ def get_qa_gen_chain(llm: BaseLanguageModel) -> RunnableSequence:
73
+ return (
74
+ CHAT_PROMPT | llm | OutputFixingParser.from_llm(llm=llm, parser=PYDANTIC_PARSER)
75
+ )