ValentinGuigon committed
Commit f72ff7c · 1 Parent(s): 8891c1e

Testing a simpler agent

Files changed (1)
  1. agents/agent.py +149 -68
agents/agent.py CHANGED
@@ -1,80 +1,162 @@
- """LangGraph Agent for GAIA Benchmark"""
- from tools.SearchToolkit import wiki_search, web_search, arxiv_search, vector_store
- from tools.MathsToolkit import (
-     multiply, add, subtract, divide, modulus, power, square_root
- )
- from tools.ImagesToolkit import (
-     analyze_image,
-     transform_image,
-     draw_on_image,
-     generate_simple_image,
-     combine_images
- )
- from tools.DocumentsToolkit import (
-     save_and_read_file,
-     download_file_from_url,
-     extract_text_from_image,
-     analyze_csv_file,
-     analyze_excel_file,
-     analyze_word_file,
-     analyze_pdf_file
- )
- from tools.CodeToolkit import execute_code_multilang
- from langchain_groq import ChatGroq
- from langchain_core.messages import SystemMessage, HumanMessage
- from langgraph.prebuilt import tools_condition, ToolNode
- from langgraph.graph import START, StateGraph, MessagesState
+ """LangGraph Agent"""
  import os
  from dotenv import load_dotenv
- # Load environment variables
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition
+ from langgraph.prebuilt import ToolNode
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.document_loaders import ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import Client, create_client
+
  load_dotenv()

- prompt_path = os.path.join(os.path.dirname(__file__), "../prompts")
-
- # Load system prompt
- with open(os.path.join(prompt_path, "system_prompt.txt"), "r", encoding="utf-8") as f:
+
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers.
+     Args:
+         a: first int
+         b: second int
+     """
+     return a * b
+
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a + b
+
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a - b
+
+
+ @tool
+ def divide(a: int, b: int) -> int:
+     """Divide two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a % b
+
+
+ @tool
+ def wiki_search(query: str) -> str:
+     """Search Wikipedia for a query and return maximum 2 results.
+
+     Args:
+         query: The search query."""
+     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"wiki_results": formatted_search_docs}
+
+
+ @tool
+ def web_search(query: str) -> str:
+     """Search Tavily for a query and return maximum 3 results.
+
+     Args:
+         query: The search query."""
+     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"web_results": formatted_search_docs}
+
+
+ @tool
+ def arvix_search(query: str) -> str:
+     """Search Arxiv for a query and return maximum 3 result.
+
+     Args:
+         query: The search query."""
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"arvix_results": formatted_search_docs}
+
+
+ # load the system prompt from the file
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
      system_prompt = f.read()
+
+ # System message
  sys_msg = SystemMessage(content=system_prompt)

- # Toolset
- tools = [
-     # SearchToolkit
-     web_search,
-     wiki_search,
-     arxiv_search,
-
-     # MathsToolkit
+ # build a retriever
+ embeddings = HuggingFaceEmbeddings(
+     model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY"))
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents_langchain",
+ )
+ create_retriever_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(),
+     name="Question Search",
+     description="A tool to retrieve similar questions from a vector store.",
+ )
+
+ tools = [
      multiply,
      add,
      subtract,
      divide,
      modulus,
-     power,
-     square_root,
-
-     # DocumentsToolkit
-     save_and_read_file,
-     download_file_from_url,
-     extract_text_from_image,
-     analyze_csv_file,
-     analyze_excel_file,
-     analyze_word_file,
-     analyze_pdf_file,
-
-     # CodeToolkit
-     execute_code_multilang,
-
-     # ImagesToolkit
-     analyze_image,
-     transform_image,
-     draw_on_image,
-     generate_simple_image,
-     combine_images,
+     wiki_search,
+     web_search,
+     arvix_search,
  ]

- # Build LangGraph workflow
-

  def build_graph():
      llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
@@ -84,14 +166,13 @@ def build_graph():
          return {"messages": [llm_with_tools.invoke(state["messages"])]}

      def retriever(state: MessagesState):
-         similar = vector_store.similarity_search(state["messages"][0].content)
-         if similar:
-             reference = HumanMessage(
-                 content=f"Here is a similar Q&A that might help: \n\n{similar[0].page_content}"
-             )
-             return {"messages": [sys_msg] + state["messages"] + [reference]}
-         else:
-             return {"messages": [sys_msg] + state["messages"]}
+         """Retriever node"""
+         similar_question = vector_store.similarity_search(
+             state["messages"][0].content)
+         example_msg = HumanMessage(
+             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
+         )
+         return {"messages": [sys_msg] + state["messages"] + [example_msg]}

      builder = StateGraph(MessagesState)
      builder.add_node("retriever", retriever)
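The diff cuts off before the rest of build_graph() is shown. Below is a minimal sketch of how the remaining wiring typically looks in LangGraph, assuming an assistant node bound to the tool list and the standard tools_condition routing; the node names, edges, and the example invocation are assumptions for illustration, not part of this commit.

# Hypothetical continuation of build_graph(); it reuses `tools`, `sys_msg`,
# `vector_store`, and the imports defined earlier in agents/agent.py. Node
# names and edge layout are assumptions, not taken from the commit.
def build_graph():
    llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
    llm_with_tools = llm.bind_tools(tools)

    def assistant(state: MessagesState):
        """Call the tool-enabled LLM on the accumulated messages."""
        return {"messages": [llm_with_tools.invoke(state["messages"])]}

    def retriever(state: MessagesState):
        """Prepend the system prompt and a similar Q&A pulled from the vector store."""
        similar_question = vector_store.similarity_search(
            state["messages"][0].content)
        example_msg = HumanMessage(
            content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
        )
        return {"messages": [sys_msg] + state["messages"] + [example_msg]}

    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    builder.add_node("tools", ToolNode(tools))  # executes whichever tool the LLM calls
    builder.add_edge(START, "retriever")
    builder.add_edge("retriever", "assistant")
    builder.add_conditional_edges("assistant", tools_condition)  # route to "tools" or finish
    builder.add_edge("tools", "assistant")
    return builder.compile()

A call to the compiled graph could then look like this (hypothetical question):

graph = build_graph()
result = graph.invoke({"messages": [HumanMessage(content="What is 6 * 7?")]})
print(result["messages"][-1].content)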