Joshua Sundance Bailey committed on
Commit
048798b
1 Parent(s): f310350

agentic rag

Browse files
langchain-streamlit-demo/app.py CHANGED
@@ -442,6 +442,7 @@ if st.session_state.llm:
442
  MEMORY,
443
  chat_prompt,
444
  prompt,
 
445
  )
446
 
447
  # --- LLM call ---
 
442
  MEMORY,
443
  chat_prompt,
444
  prompt,
445
+ STMEMORY,
446
  )
447
 
448
  # --- LLM call ---
langchain-streamlit-demo/llm_resources.py CHANGED
@@ -1,8 +1,15 @@
 
1
  from tempfile import NamedTemporaryFile
2
  from typing import Tuple, List, Optional, Dict
3
 
 
 
 
 
 
 
4
  from langchain.callbacks.base import BaseCallbackHandler
5
- from langchain.chains import RetrievalQA, LLMChain
6
  from langchain.chat_models import (
7
  AzureChatOpenAI,
8
  ChatOpenAI,
@@ -11,15 +18,16 @@ from langchain.chat_models import (
11
  )
12
  from langchain.document_loaders import PyPDFLoader
13
  from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
 
14
  from langchain.retrievers import EnsembleRetriever
15
- from langchain.schema import Document, BaseRetriever
16
- from langchain.text_splitter import RecursiveCharacterTextSplitter
17
- from langchain.vectorstores import FAISS
18
-
19
  from langchain.retrievers.multi_query import MultiQueryRetriever
20
  from langchain.retrievers.multi_vector import MultiVectorRetriever
 
 
21
  from langchain.storage import InMemoryStore
22
- import uuid
 
 
23
 
24
  from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
25
  from qagen import get_rag_qa_gen_chain
@@ -34,6 +42,7 @@ def get_runnable(
34
  memory,
35
  chat_prompt,
36
  summarization_prompt,
 
37
  ):
38
  if not use_document_chat:
39
  return LLMChain(
@@ -54,13 +63,44 @@ def get_runnable(
54
  llm,
55
  )
56
  else:
57
- return RetrievalQA.from_chain_type(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  llm=llm,
59
- chain_type=document_chat_chain_type,
60
- retriever=retriever,
61
- memory=memory,
62
- output_key="output_text",
63
- ) | (lambda output: output["output_text"])
 
 
 
 
 
 
 
 
 
64
 
65
 
66
  def get_llm(
 
1
+ import uuid
2
  from tempfile import NamedTemporaryFile
3
  from typing import Tuple, List, Optional, Dict
4
 
5
+ from langchain.agents import AgentExecutor
6
+ from langchain.agents.agent_toolkits import create_retriever_tool
7
+ from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
8
+ AgentTokenBufferMemory,
9
+ )
10
+ from langchain.agents.openai_functions_agent.base import OpenAIFunctionsAgent
11
  from langchain.callbacks.base import BaseCallbackHandler
12
+ from langchain.chains import LLMChain
13
  from langchain.chat_models import (
14
  AzureChatOpenAI,
15
  ChatOpenAI,
 
18
  )
19
  from langchain.document_loaders import PyPDFLoader
20
  from langchain.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
21
+ from langchain.prompts import MessagesPlaceholder
22
  from langchain.retrievers import EnsembleRetriever
 
 
 
 
23
  from langchain.retrievers.multi_query import MultiQueryRetriever
24
  from langchain.retrievers.multi_vector import MultiVectorRetriever
25
+ from langchain.schema import Document, BaseRetriever
26
+ from langchain.schema.runnable import RunnablePassthrough
27
  from langchain.storage import InMemoryStore
28
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
29
+ from langchain.vectorstores import FAISS
30
+ from langchain_core.messages import SystemMessage
31
 
32
  from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
33
  from qagen import get_rag_qa_gen_chain
 
42
  memory,
43
  chat_prompt,
44
  summarization_prompt,
45
+ chat_history,
46
  ):
47
  if not use_document_chat:
48
  return LLMChain(
 
63
  llm,
64
  )
65
  else:
66
+ tool = create_retriever_tool(
67
+ retriever,
68
+ "search_user_document",
69
+ "Retrieves custom context provided by the user for this conversation. Use this if you cannot answer immediately and confidently.",
70
+ )
71
+ tools = [tool]
72
+ memory_key = "agent_history"
73
+ system_message = SystemMessage(
74
+ content=(
75
+ "Do your best to answer the questions. "
76
+ "Feel free to use any tools available to look up "
77
+ "relevant information, only if necessary"
78
+ ),
79
+ )
80
+ prompt = OpenAIFunctionsAgent.create_prompt(
81
+ system_message=system_message,
82
+ extra_prompt_messages=[MessagesPlaceholder(variable_name=memory_key)],
83
+ )
84
+ agent = OpenAIFunctionsAgent(llm=llm, tools=tools, prompt=prompt)
85
+
86
+ agent_memory = AgentTokenBufferMemory(
87
+ chat_memory=chat_history,
88
+ memory_key=memory_key,
89
  llm=llm,
90
+ )
91
+
92
+ agent_executor = AgentExecutor(
93
+ agent=agent,
94
+ tools=tools,
95
+ memory=agent_memory,
96
+ verbose=True,
97
+ return_intermediate_steps=True,
98
+ )
99
+ return (
100
+ {"input": RunnablePassthrough()}
101
+ | agent_executor
102
+ | (lambda output: output["output"])
103
+ )
104
 
105
 
106
  def get_llm(