Joshua Sundance Bailey committed on
Commit
679726e
1 Parent(s): 64e3f44

parameterize research assistant llms

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -26,7 +26,7 @@ from llm_resources import (
26
  get_runnable,
27
  get_texts_and_multiretriever,
28
  )
29
- from research_assistant.chain import chain as research_assistant_chain
30
 
31
  __version__ = "2.0.1"
32
 
@@ -367,7 +367,7 @@ with sidebar:
367
 
368
 
369
  # --- LLM Instantiation ---
370
- st.session_state.llm = get_llm(
371
  provider=st.session_state.provider,
372
  model=model,
373
  provider_api_key=provider_api_key,
@@ -382,6 +382,8 @@ st.session_state.llm = get_llm(
382
  "AZURE_OPENAI_MODEL_VERSION": st.session_state.AZURE_OPENAI_MODEL_VERSION,
383
  },
384
  )
 
 
385
 
386
  # --- Chat History ---
387
  for msg in STMEMORY.messages:
@@ -448,12 +450,16 @@ if st.session_state.llm:
448
  WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
449
  ]
450
  if st.session_state.provider in ("Azure OpenAI", "OpenAI"):
 
 
 
 
451
  st_callback = StreamlitCallbackHandler(st.container())
452
  callbacks.append(st_callback)
453
  research_assistant_tool = Tool.from_function(
454
  func=lambda s: research_assistant_chain.invoke(
455
  {"question": s},
456
- config=get_config(callbacks),
457
  ),
458
  name="web-research-assistant",
459
  description="this assistant returns a comprehensive report based on web research. for quick facts, use duckduckgo instead.",
@@ -473,7 +479,7 @@ if st.session_state.llm:
473
  doc_chain_tool = Tool.from_function(
474
  func=lambda s: st.session_state.doc_chain.invoke(
475
  s,
476
- config=get_config(callbacks),
477
  ),
478
  name="user-document-chat",
479
  description="this assistant returns a response based on the user's custom context. if the user's meaning is unclear, perhaps the answer is here. generally speaking, try this tool before conducting web research.",
 
26
  get_runnable,
27
  get_texts_and_multiretriever,
28
  )
29
+ from research_assistant.chain import get_chain as get_research_assistant_chain
30
 
31
  __version__ = "2.0.1"
32
 
 
367
 
368
 
369
  # --- LLM Instantiation ---
370
+ get_llm_args = dict(
371
  provider=st.session_state.provider,
372
  model=model,
373
  provider_api_key=provider_api_key,
 
382
  "AZURE_OPENAI_MODEL_VERSION": st.session_state.AZURE_OPENAI_MODEL_VERSION,
383
  },
384
  )
385
+ get_llm_args_temp_zero = get_llm_args | {"temperature": 0.0}
386
+ st.session_state.llm = get_llm(**get_llm_args)
387
 
388
  # --- Chat History ---
389
  for msg in STMEMORY.messages:
 
450
  WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
451
  ]
452
  if st.session_state.provider in ("Azure OpenAI", "OpenAI"):
453
+ research_assistant_chain = get_research_assistant_chain(
454
+ search_llm=get_llm(**get_llm_args_temp_zero), # type: ignore
455
+ writer_llm=get_llm(**get_llm_args_temp_zero), # type: ignore
456
+ )
457
  st_callback = StreamlitCallbackHandler(st.container())
458
  callbacks.append(st_callback)
459
  research_assistant_tool = Tool.from_function(
460
  func=lambda s: research_assistant_chain.invoke(
461
  {"question": s},
462
+ # config=get_config(callbacks),
463
  ),
464
  name="web-research-assistant",
465
  description="this assistant returns a comprehensive report based on web research. for quick facts, use duckduckgo instead.",
 
479
  doc_chain_tool = Tool.from_function(
480
  func=lambda s: st.session_state.doc_chain.invoke(
481
  s,
482
+ # config=get_config(callbacks),
483
  ),
484
  name="user-document-chat",
485
  description="this assistant returns a response based on the user's custom context. if the user's meaning is unclear, perhaps the answer is here. generally speaking, try this tool before conducting web research.",
langchain-streamlit-demo/research_assistant/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
- from research_assistant.chain import chain
2
 
3
- __all__ = ["chain"]
 
1
+ from research_assistant.chain import get_chain
2
 
3
+ __all__ = ["get_chain"]
langchain-streamlit-demo/research_assistant/chain.py CHANGED
@@ -1,16 +1,18 @@
1
  from langchain_core.pydantic_v1 import BaseModel
2
  from langchain_core.runnables import RunnablePassthrough
3
 
4
- from research_assistant.search.web import chain as search_chain
5
- from research_assistant.writer import chain as writer_chain
 
 
6
 
7
- chain_notypes = (
8
- RunnablePassthrough().assign(research_summary=search_chain) | writer_chain
9
- )
10
 
 
 
 
 
11
 
12
- class InputType(BaseModel):
13
- question: str
14
 
15
-
16
- chain = chain_notypes.with_types(input_type=InputType)
 
1
  from langchain_core.pydantic_v1 import BaseModel
2
  from langchain_core.runnables import RunnablePassthrough
3
 
4
+ from research_assistant.search.web import get_search_chain
5
+ from research_assistant.writer import get_writer_chain
6
+ from langchain.llms.base import BaseLLM
7
+ from langchain.schema.runnable import Runnable
8
 
 
 
 
9
 
10
+ def get_chain(search_llm: BaseLLM, writer_llm: BaseLLM) -> Runnable:
11
+ chain_notypes = RunnablePassthrough().assign(
12
+ research_summary=get_search_chain(search_llm),
13
+ ) | get_writer_chain(writer_llm)
14
 
15
+ class InputType(BaseModel):
16
+ question: str
17
 
18
+ return chain_notypes.with_types(input_type=InputType)
 
langchain-streamlit-demo/research_assistant/search/web.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any
3
 
4
  import requests
5
  from bs4 import BeautifulSoup
6
- from langchain.chat_models import ChatOpenAI
7
  from langchain.prompts import ChatPromptTemplate
8
  from langchain.retrievers.tavily_search_api import TavilySearchAPIRetriever
9
  from langchain.utilities import DuckDuckGoSearchAPIWrapper
@@ -130,25 +130,6 @@ Using the above text, answer in short the following question:
130
  if the question cannot be answered using the text, imply summarize the text. Include all factual information, numbers, stats etc if available.""" # noqa: E501
131
  SUMMARY_PROMPT = ChatPromptTemplate.from_template(SUMMARY_TEMPLATE)
132
 
133
- scrape_and_summarize: Runnable[Any, Any] = (
134
- RunnableParallel(
135
- {
136
- "question": lambda x: x["question"],
137
- "text": lambda x: scrape_text(x["url"])[:10000],
138
- "url": lambda x: x["url"],
139
- },
140
- )
141
- | RunnableParallel(
142
- {
143
- "summary": SUMMARY_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser(),
144
- "url": lambda x: x["url"],
145
- },
146
- )
147
- | RunnableLambda(lambda x: f"Source Url: {x['url']}\nSummary: {x['summary']}")
148
- )
149
-
150
- multi_search = get_links | scrape_and_summarize.map() | (lambda x: "\n".join(x))
151
-
152
 
153
  def load_json(s):
154
  try:
@@ -157,24 +138,41 @@ def load_json(s):
157
  return {}
158
 
159
 
160
- search_query = SEARCH_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser() | load_json
161
- choose_agent = (
162
- CHOOSE_AGENT_PROMPT | ChatOpenAI(temperature=0) | StrOutputParser() | load_json
163
- )
164
-
165
- get_search_queries = (
166
- RunnablePassthrough().assign(
167
- agent_prompt=RunnableParallel({"task": lambda x: x})
168
- | choose_agent
169
- | (lambda x: x.get("agent_role_prompt")),
 
 
 
 
 
 
170
  )
171
- | search_query
172
- )
173
 
 
174
 
175
- chain = (
176
- get_search_queries
177
- | (lambda x: [{"question": q} for q in x])
178
- | multi_search.map()
179
- | (lambda x: "\n\n".join(x))
180
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
  import requests
5
  from bs4 import BeautifulSoup
6
+ from langchain.llms.base import BaseLLM
7
  from langchain.prompts import ChatPromptTemplate
8
  from langchain.retrievers.tavily_search_api import TavilySearchAPIRetriever
9
  from langchain.utilities import DuckDuckGoSearchAPIWrapper
 
130
  if the question cannot be answered using the text, imply summarize the text. Include all factual information, numbers, stats etc if available.""" # noqa: E501
131
  SUMMARY_PROMPT = ChatPromptTemplate.from_template(SUMMARY_TEMPLATE)
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  def load_json(s):
135
  try:
 
138
  return {}
139
 
140
 
141
+ def get_search_chain(model: BaseLLM) -> Runnable:
142
+ scrape_and_summarize: Runnable[Any, Any] = (
143
+ RunnableParallel(
144
+ {
145
+ "question": lambda x: x["question"],
146
+ "text": lambda x: scrape_text(x["url"])[:10000],
147
+ "url": lambda x: x["url"],
148
+ },
149
+ )
150
+ | RunnableParallel(
151
+ {
152
+ "summary": SUMMARY_PROMPT | model | StrOutputParser(),
153
+ "url": lambda x: x["url"],
154
+ },
155
+ )
156
+ | RunnableLambda(lambda x: f"Source Url: {x['url']}\nSummary: {x['summary']}")
157
  )
 
 
158
 
159
+ multi_search = get_links | scrape_and_summarize.map() | (lambda x: "\n".join(x))
160
 
161
+ search_query = SEARCH_PROMPT | model | StrOutputParser() | load_json
162
+ choose_agent = CHOOSE_AGENT_PROMPT | model | StrOutputParser() | load_json
163
+
164
+ get_search_queries = (
165
+ RunnablePassthrough().assign(
166
+ agent_prompt=RunnableParallel({"task": lambda x: x})
167
+ | choose_agent
168
+ | (lambda x: x.get("agent_role_prompt")),
169
+ )
170
+ | search_query
171
+ )
172
+
173
+ return (
174
+ get_search_queries
175
+ | (lambda x: [{"question": q} for q in x])
176
+ | multi_search.map()
177
+ | (lambda x: "\n\n".join(x))
178
+ )
langchain-streamlit-demo/research_assistant/writer.py CHANGED
@@ -1,7 +1,8 @@
1
- from langchain.chat_models import ChatOpenAI
2
  from langchain.prompts import ChatPromptTemplate
3
  from langchain_core.output_parsers import StrOutputParser
4
  from langchain_core.runnables import ConfigurableField
 
 
5
 
6
  WRITER_SYSTEM_PROMPT = "You are an AI critical thinker research assistant. Your sole purpose is to write well written, critically acclaimed, objective and structured reports on given text." # noqa: E501
7
 
@@ -50,7 +51,6 @@ Use appropriate Markdown syntax to format the outline and ensure readability.
50
 
51
  Please do your best, this is very important to my career.""" # noqa: E501
52
 
53
- model = ChatOpenAI(temperature=0)
54
  prompt = ChatPromptTemplate.from_messages(
55
  [
56
  ("system", WRITER_SYSTEM_PROMPT),
@@ -72,4 +72,7 @@ prompt = ChatPromptTemplate.from_messages(
72
  ],
73
  ),
74
  )
75
- chain = prompt | model | StrOutputParser()
 
 
 
 
 
1
  from langchain.prompts import ChatPromptTemplate
2
  from langchain_core.output_parsers import StrOutputParser
3
  from langchain_core.runnables import ConfigurableField
4
+ from langchain.llms.base import BaseLLM
5
+ from langchain.schema.runnable import Runnable
6
 
7
  WRITER_SYSTEM_PROMPT = "You are an AI critical thinker research assistant. Your sole purpose is to write well written, critically acclaimed, objective and structured reports on given text." # noqa: E501
8
 
 
51
 
52
  Please do your best, this is very important to my career.""" # noqa: E501
53
 
 
54
  prompt = ChatPromptTemplate.from_messages(
55
  [
56
  ("system", WRITER_SYSTEM_PROMPT),
 
72
  ],
73
  ),
74
  )
75
+
76
+
77
+ def get_writer_chain(model: BaseLLM) -> Runnable:
78
+ return prompt | model | StrOutputParser()