陈君至
commited on
Commit
·
ec21955
1
Parent(s):
0a411b5
Add application file
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Smurfs/.DS_Store +0 -0
- Smurfs/__init__.py +0 -0
- Smurfs/__pycache__/__init__.cpython-39.pyc +0 -0
- Smurfs/agents/__init__.py +0 -0
- Smurfs/agents/__pycache__/__init__.cpython-39.pyc +0 -0
- Smurfs/agents/__pycache__/base.cpython-39.pyc +0 -0
- Smurfs/agents/answer_agent/__pycache__/answer.cpython-39.pyc +0 -0
- Smurfs/agents/answer_agent/__pycache__/prompt.cpython-39.pyc +0 -0
- Smurfs/agents/answer_agent/answer.py +303 -0
- Smurfs/agents/answer_agent/prompt.py +73 -0
- Smurfs/agents/base.py +51 -0
- Smurfs/agents/executor_agent/__pycache__/__init__.cpython-39.pyc +0 -0
- Smurfs/agents/executor_agent/__pycache__/executor.cpython-39.pyc +0 -0
- Smurfs/agents/executor_agent/__pycache__/prompt.cpython-39.pyc +0 -0
- Smurfs/agents/executor_agent/executor.py +246 -0
- Smurfs/agents/executor_agent/prompt.py +58 -0
- Smurfs/agents/memory_agent/memory_agent.py +0 -0
- Smurfs/agents/memory_agent/prompt.py +16 -0
- Smurfs/agents/planning_agent/__pycache__/planner.cpython-39.pyc +0 -0
- Smurfs/agents/planning_agent/__pycache__/prompt.cpython-39.pyc +0 -0
- Smurfs/agents/planning_agent/planner.py +137 -0
- Smurfs/agents/planning_agent/prompt.py +44 -0
- Smurfs/agents/verifier_agent/__pycache__/prompt.cpython-39.pyc +0 -0
- Smurfs/agents/verifier_agent/__pycache__/verifier.cpython-39.pyc +0 -0
- Smurfs/agents/verifier_agent/prompt.py +25 -0
- Smurfs/agents/verifier_agent/verifier.py +90 -0
- Smurfs/data/.DS_Store +0 -0
- Smurfs/data/__init__.py +0 -0
- Smurfs/data/post_process.py +65 -0
- Smurfs/data/utils.py +53 -0
- Smurfs/deploy/__init__.py +3 -0
- Smurfs/deploy/__pycache__/__init__.cpython-39.pyc +0 -0
- Smurfs/deploy/cli_inference.py +58 -0
- Smurfs/deploy/gradio_inference.py +223 -0
- Smurfs/eval/hotpot_qa/__pycache__/utils.cpython-39.pyc +0 -0
- Smurfs/eval/hotpot_qa/post_process.py +109 -0
- Smurfs/eval/hotpot_qa/run_eval.py +395 -0
- Smurfs/eval/hotpot_qa/utils.py +117 -0
- Smurfs/inference/__init__.py +0 -0
- Smurfs/inference/__pycache__/__init__.cpython-39.pyc +0 -0
- Smurfs/inference/__pycache__/inference.cpython-39.pyc +0 -0
- Smurfs/inference/__pycache__/server.cpython-39.pyc +0 -0
- Smurfs/inference/__pycache__/smurfs_worker.cpython-39.pyc +0 -0
- Smurfs/inference/__pycache__/utils.cpython-39.pyc +0 -0
- Smurfs/inference/functioncall_inference.py +533 -0
- Smurfs/inference/inference.py +527 -0
- Smurfs/inference/server.py +179 -0
- Smurfs/inference/smurfs_worker.py +1040 -0
- Smurfs/inference/utils.py +356 -0
- Smurfs/model/__init__.py +0 -0
Smurfs/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Smurfs/__init__.py
ADDED
File without changes
|
Smurfs/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (160 Bytes). View file
|
|
Smurfs/agents/__init__.py
ADDED
File without changes
|
Smurfs/agents/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (167 Bytes). View file
|
|
Smurfs/agents/__pycache__/base.cpython-39.pyc
ADDED
Binary file (2.07 kB). View file
|
|
Smurfs/agents/answer_agent/__pycache__/answer.cpython-39.pyc
ADDED
Binary file (6.1 kB). View file
|
|
Smurfs/agents/answer_agent/__pycache__/prompt.cpython-39.pyc
ADDED
Binary file (4.72 kB). View file
|
|
Smurfs/agents/answer_agent/answer.py
ADDED
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Smurfs.agents.base import BaseAgent
|
2 |
+
from Smurfs.agents.answer_agent.prompt import answer_generation_direct_prompt, answer_generation_prompt, final_answer_generation_prompt, tool_check_prompt, hotpot_answer_parser_prompt
|
3 |
+
from typing import Any
|
4 |
+
|
5 |
+
class answer_agent(BaseAgent):
|
6 |
+
direct_prompt: Any
|
7 |
+
answer_prompt: Any
|
8 |
+
final_prompt: Any
|
9 |
+
tool_check_prompt: Any
|
10 |
+
HP_parser_prompt: Any
|
11 |
+
|
12 |
+
def __init__(self, *args, **kwargs):
|
13 |
+
direct_prompt = answer_generation_direct_prompt
|
14 |
+
answer_prompt = answer_generation_prompt
|
15 |
+
final_prompt = final_answer_generation_prompt
|
16 |
+
check_prompt = tool_check_prompt
|
17 |
+
name = "Answer Agent"
|
18 |
+
kwargs.update({"direct_prompt": direct_prompt})
|
19 |
+
kwargs.update({"answer_prompt": answer_prompt})
|
20 |
+
kwargs.update({"final_prompt": final_prompt})
|
21 |
+
kwargs.update({"tool_check_prompt": check_prompt})
|
22 |
+
kwargs.update({"name": name})
|
23 |
+
kwargs.update({"HP_parser_prompt": hotpot_answer_parser_prompt})
|
24 |
+
super().__init__(
|
25 |
+
*args,
|
26 |
+
**kwargs,
|
27 |
+
)
|
28 |
+
|
29 |
+
def run(self, query_id, task, **kwargs):
|
30 |
+
"""agent run one step"""
|
31 |
+
if task == "direct":
|
32 |
+
self.get_memory(**kwargs)
|
33 |
+
agent_prompt = self.get_prompt(task)
|
34 |
+
message = [{'role': 'user',
|
35 |
+
'content': agent_prompt}]
|
36 |
+
result = self.llm.prediction(message)
|
37 |
+
self.log(query_id, result)
|
38 |
+
self.colorful_print(result, "Answer Directly")
|
39 |
+
return result
|
40 |
+
|
41 |
+
elif task == "answer":
|
42 |
+
self.get_memory(**kwargs)
|
43 |
+
agent_prompt = self.get_prompt(task)
|
44 |
+
message = [{'role': 'user',
|
45 |
+
'content': agent_prompt}]
|
46 |
+
ind = 0
|
47 |
+
while True:
|
48 |
+
try:
|
49 |
+
result = self.llm.prediction(message)
|
50 |
+
self.log(query_id, result)
|
51 |
+
self.colorful_print(result, "Answer Generation")
|
52 |
+
break
|
53 |
+
except Exception as e:
|
54 |
+
print(f"answer generation fails: {e}")
|
55 |
+
self.log(query_id, f"answer generation fails: {e}")
|
56 |
+
if ind > 2:
|
57 |
+
return -1
|
58 |
+
ind += 1
|
59 |
+
continue
|
60 |
+
return result
|
61 |
+
|
62 |
+
elif task == "final":
|
63 |
+
self.get_memory(**kwargs)
|
64 |
+
agent_prompt = self.get_prompt(task)
|
65 |
+
message = [{'role': 'user',
|
66 |
+
'content': agent_prompt}]
|
67 |
+
ind = 0
|
68 |
+
while True:
|
69 |
+
try:
|
70 |
+
result = self.llm.prediction(message)
|
71 |
+
self.log(query_id, result)
|
72 |
+
self.colorful_print(result, "Final Answer Generation")
|
73 |
+
break
|
74 |
+
except Exception as e:
|
75 |
+
print(f"answer generation fails: {e}")
|
76 |
+
self.log(query_id, f"answer generation fails: {e}")
|
77 |
+
if ind > 2:
|
78 |
+
return -1
|
79 |
+
ind += 1
|
80 |
+
continue
|
81 |
+
return result
|
82 |
+
|
83 |
+
elif task == "tool_check":
|
84 |
+
self.get_memory(**kwargs)
|
85 |
+
agent_prompt = self.get_prompt(task)
|
86 |
+
message = [{'role': 'user',
|
87 |
+
'content': agent_prompt}]
|
88 |
+
ind = 0
|
89 |
+
while True:
|
90 |
+
try:
|
91 |
+
result = self.llm.prediction(message)
|
92 |
+
result = eval(result)
|
93 |
+
a = result["Reason"]
|
94 |
+
b = result["Choice"]
|
95 |
+
self.log(query_id, result)
|
96 |
+
self.colorful_print(result, "Tool Check")
|
97 |
+
if 'yes' in b.lower():
|
98 |
+
return -1, a
|
99 |
+
else:
|
100 |
+
return 1, a
|
101 |
+
except Exception as e:
|
102 |
+
print(f"tool check fails: {e}")
|
103 |
+
self.log(query_id, f"tool check fails: {e}")
|
104 |
+
if ind > self.max_retry:
|
105 |
+
return -1, 'fail'
|
106 |
+
ind += 1
|
107 |
+
continue
|
108 |
+
|
109 |
+
elif task == "parse":
|
110 |
+
self.get_memory(**kwargs)
|
111 |
+
agent_prompt = self.get_prompt(task)
|
112 |
+
message = [{'role': 'user',
|
113 |
+
'content': agent_prompt}]
|
114 |
+
ind = 0
|
115 |
+
while True:
|
116 |
+
try:
|
117 |
+
result = self.llm.prediction(message)
|
118 |
+
# result = eval(result)
|
119 |
+
# a = result["Reason"]
|
120 |
+
# b = result["Choice"]
|
121 |
+
self.colorful_print(result, "Parse Answer Hotpot QA")
|
122 |
+
self.log(query_id, result)
|
123 |
+
# if 'yes' in b.lower():
|
124 |
+
# return result, -1
|
125 |
+
# else:
|
126 |
+
# return result, 1
|
127 |
+
return result
|
128 |
+
except Exception as e:
|
129 |
+
print(f"answer parse fails: {e}")
|
130 |
+
self.log(query_id, f"answer parse fails: {e}")
|
131 |
+
if ind > self.max_retry:
|
132 |
+
return "answer parse fails"
|
133 |
+
ind += 1
|
134 |
+
continue
|
135 |
+
|
136 |
+
|
137 |
+
def get_memory(self, **kwargs):
|
138 |
+
"""get relevant memory and add it to agent's memory"""
|
139 |
+
self.memory = kwargs
|
140 |
+
|
141 |
+
def get_prompt(self, task):
|
142 |
+
"""get the prompt for the agent"""
|
143 |
+
if task == "direct":
|
144 |
+
agent_prompt = self.direct_prompt.format(**self.memory)
|
145 |
+
elif task == "answer":
|
146 |
+
agent_prompt = self.answer_prompt.format(**self.memory)
|
147 |
+
elif task == "final":
|
148 |
+
agent_prompt = self.final_prompt.format(**self.memory)
|
149 |
+
elif task == "tool_check":
|
150 |
+
agent_prompt = self.tool_check_prompt.format(**self.memory)
|
151 |
+
elif task == "parse":
|
152 |
+
agent_prompt = self.HP_parser_prompt.format(**self.memory)
|
153 |
+
return agent_prompt
|
154 |
+
|
155 |
+
class stream_answer_agent(BaseAgent):
|
156 |
+
direct_prompt: Any
|
157 |
+
answer_prompt: Any
|
158 |
+
final_prompt: Any
|
159 |
+
tool_check_prompt: Any
|
160 |
+
HP_parser_prompt: Any
|
161 |
+
|
162 |
+
def __init__(self, *args, **kwargs):
|
163 |
+
direct_prompt = answer_generation_direct_prompt
|
164 |
+
answer_prompt = answer_generation_prompt
|
165 |
+
final_prompt = final_answer_generation_prompt
|
166 |
+
check_prompt = tool_check_prompt
|
167 |
+
name = "Answer Agent"
|
168 |
+
kwargs.update({"direct_prompt": direct_prompt})
|
169 |
+
kwargs.update({"answer_prompt": answer_prompt})
|
170 |
+
kwargs.update({"final_prompt": final_prompt})
|
171 |
+
kwargs.update({"tool_check_prompt": check_prompt})
|
172 |
+
kwargs.update({"name": name})
|
173 |
+
kwargs.update({"HP_parser_prompt": hotpot_answer_parser_prompt})
|
174 |
+
super().__init__(
|
175 |
+
*args,
|
176 |
+
**kwargs,
|
177 |
+
)
|
178 |
+
|
179 |
+
def run(self, query_id, task, **kwargs):
|
180 |
+
"""agent run one step"""
|
181 |
+
if task == "direct":
|
182 |
+
self.get_memory(**kwargs)
|
183 |
+
agent_prompt = self.get_prompt(task)
|
184 |
+
message = [{'role': 'user',
|
185 |
+
'content': agent_prompt}]
|
186 |
+
result = self.llm.prediction(message)
|
187 |
+
self.log(query_id, result)
|
188 |
+
self.colorful_print(result, "Answer Directly")
|
189 |
+
return result, "Answer Directly", self.name, result
|
190 |
+
|
191 |
+
elif task == "answer":
|
192 |
+
self.get_memory(**kwargs)
|
193 |
+
agent_prompt = self.get_prompt(task)
|
194 |
+
message = [{'role': 'user',
|
195 |
+
'content': agent_prompt}]
|
196 |
+
ind = 0
|
197 |
+
while True:
|
198 |
+
try:
|
199 |
+
result = self.llm.prediction(message)
|
200 |
+
self.log(query_id, result)
|
201 |
+
self.colorful_print(result, "Answer Generation")
|
202 |
+
break
|
203 |
+
except Exception as e:
|
204 |
+
print(f"answer generation fails: {e}")
|
205 |
+
self.log(query_id, f"answer generation fails: {e}")
|
206 |
+
if ind > 2:
|
207 |
+
return -1, "Answer Generation", self.name, str(e)
|
208 |
+
ind += 1
|
209 |
+
continue
|
210 |
+
return result, "Answer Generation", self.name, result
|
211 |
+
|
212 |
+
elif task == "final":
|
213 |
+
self.get_memory(**kwargs)
|
214 |
+
agent_prompt = self.get_prompt(task)
|
215 |
+
message = [{'role': 'user',
|
216 |
+
'content': agent_prompt}]
|
217 |
+
ind = 0
|
218 |
+
while True:
|
219 |
+
try:
|
220 |
+
result = self.llm.prediction(message)
|
221 |
+
self.log(query_id, result)
|
222 |
+
self.colorful_print(result, "Final Answer Generation")
|
223 |
+
break
|
224 |
+
except Exception as e:
|
225 |
+
print(f"answer generation fails: {e}")
|
226 |
+
self.log(query_id, f"answer generation fails: {e}")
|
227 |
+
if ind > 2:
|
228 |
+
return -1, "Final Answer Generation", self.name, str(e)
|
229 |
+
ind += 1
|
230 |
+
continue
|
231 |
+
return result, "Final Answer Generation", self.name, result
|
232 |
+
|
233 |
+
elif task == "tool_check":
|
234 |
+
self.get_memory(**kwargs)
|
235 |
+
agent_prompt = self.get_prompt(task)
|
236 |
+
message = [{'role': 'user',
|
237 |
+
'content': agent_prompt}]
|
238 |
+
ind = 0
|
239 |
+
while True:
|
240 |
+
try:
|
241 |
+
result = self.llm.prediction(message)
|
242 |
+
result = eval(result)
|
243 |
+
a = result["Reason"]
|
244 |
+
b = result["Choice"]
|
245 |
+
self.log(query_id, result)
|
246 |
+
self.colorful_print(result, "Tool Check")
|
247 |
+
if 'yes' in b.lower():
|
248 |
+
return -1, a
|
249 |
+
else:
|
250 |
+
return 1, a
|
251 |
+
except Exception as e:
|
252 |
+
print(f"tool check fails: {e}")
|
253 |
+
self.log(query_id, f"tool check fails: {e}")
|
254 |
+
if ind > self.max_retry:
|
255 |
+
return -1, 'fail'
|
256 |
+
ind += 1
|
257 |
+
continue
|
258 |
+
|
259 |
+
elif task == "parse":
|
260 |
+
self.get_memory(**kwargs)
|
261 |
+
agent_prompt = self.get_prompt(task)
|
262 |
+
message = [{'role': 'user',
|
263 |
+
'content': agent_prompt}]
|
264 |
+
ind = 0
|
265 |
+
while True:
|
266 |
+
try:
|
267 |
+
result = self.llm.prediction(message)
|
268 |
+
# result = eval(result)
|
269 |
+
# a = result["Reason"]
|
270 |
+
# b = result["Choice"]
|
271 |
+
self.colorful_print(result, "Parse Answer Hotpot QA")
|
272 |
+
self.log(query_id, result)
|
273 |
+
# if 'yes' in b.lower():
|
274 |
+
# return result, -1
|
275 |
+
# else:
|
276 |
+
# return result, 1
|
277 |
+
return result
|
278 |
+
except Exception as e:
|
279 |
+
print(f"answer parse fails: {e}")
|
280 |
+
self.log(query_id, f"answer parse fails: {e}")
|
281 |
+
if ind > self.max_retry:
|
282 |
+
return "answer parse fails"
|
283 |
+
ind += 1
|
284 |
+
continue
|
285 |
+
|
286 |
+
|
287 |
+
def get_memory(self, **kwargs):
|
288 |
+
"""get relevant memory and add it to agent's memory"""
|
289 |
+
self.memory = kwargs
|
290 |
+
|
291 |
+
def get_prompt(self, task):
|
292 |
+
"""get the prompt for the agent"""
|
293 |
+
if task == "direct":
|
294 |
+
agent_prompt = self.direct_prompt.format(**self.memory)
|
295 |
+
elif task == "answer":
|
296 |
+
agent_prompt = self.answer_prompt.format(**self.memory)
|
297 |
+
elif task == "final":
|
298 |
+
agent_prompt = self.final_prompt.format(**self.memory)
|
299 |
+
elif task == "tool_check":
|
300 |
+
agent_prompt = self.tool_check_prompt.format(**self.memory)
|
301 |
+
elif task == "parse":
|
302 |
+
agent_prompt = self.HP_parser_prompt.format(**self.memory)
|
303 |
+
return agent_prompt
|
Smurfs/agents/answer_agent/prompt.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
|
3 |
+
tool_check_prompt = """As a powerful language model, you're equipped to answer user's question with accumulated knowledge.
|
4 |
+
However, in some cases, you need to use external APIs to answer accurately.
|
5 |
+
Thus, you need to check whether the user's question requires you to call an external API to solve it.
|
6 |
+
Here are some tips to help you check:
|
7 |
+
1. If the user's question requires real-time information, since your knowledge base isn't updated in real-time, any such question will demand an API call.
|
8 |
+
2. If you need to obtain information (e.g., ID, name, phone number, geographical location, rank, etc.), you need to call the database APIs if you are not sure.
|
9 |
+
3. If the question demand a database search or internet research to generate an answer, this is another situation where an API call is necessary.
|
10 |
+
4. If the question demand coding and math calculation to generate an answer (e.g., algebraic operation, coding problem), you must call external APIs no matter how simple you think it is.
|
11 |
+
If need, please output 'YES'; If not, please output 'NO'
|
12 |
+
You need to give reasons first and then decide whether to keep it or not. You must only output in a parsible JSON format. Two example outputs look like:
|
13 |
+
Example 1: {{\"Reason\": \"The reason why you think you do not need to call an external API to solve the user's question\", \"Choice\": \"No\"}}
|
14 |
+
Example 2: {{\"Reason\": \"The reason why you think you need to call an external API to solve the user's question\", \"Choice\": \"Yes\"}}
|
15 |
+
This is the user's question: {question}
|
16 |
+
Output: """
|
17 |
+
tool_check_prompt = PromptTemplate.from_template(tool_check_prompt)
|
18 |
+
|
19 |
+
answer_generation_prompt = """
|
20 |
+
You should answer the question based on the response output by the API tool.
|
21 |
+
Please note that:
|
22 |
+
1. Answer the question in natural language based on the API response reasonably and effectively.
|
23 |
+
2. The user cannot directly get API response, so you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible.
|
24 |
+
3. Do not output answer that is too long. Output in 3-6 sentences is OK.
|
25 |
+
|
26 |
+
This is the user's question:
|
27 |
+
{question}
|
28 |
+
This is the API response:
|
29 |
+
{call_result}
|
30 |
+
Output:"""
|
31 |
+
answer_generation_prompt = PromptTemplate.from_template(answer_generation_prompt)
|
32 |
+
|
33 |
+
final_answer_generation_prompt = """
|
34 |
+
You will be given a complex question and you need to solve it step by step by decomposing it to a series of subtasks that can be solved using a single tool(functions).
|
35 |
+
At this step, you need to analyse the previous subtasks and their execution result to generate the answer to the original question reasonably and accurately.
|
36 |
+
Please note that:
|
37 |
+
1. Answer the question in natural language based on the subtask results reasonably and effectively.
|
38 |
+
2. The user cannot directly get the subtask results, so you need to make full use of the subtask results and give the information in the response that can satisfy the user's question in as much detail as possible.
|
39 |
+
This is the user's question:
|
40 |
+
{question}
|
41 |
+
There are logs of previous subtasks and execution results:
|
42 |
+
{previous_log}
|
43 |
+
Output:"""
|
44 |
+
final_answer_generation_prompt = PromptTemplate.from_template(final_answer_generation_prompt)
|
45 |
+
|
46 |
+
answer_generation_direct_prompt = """"You need to answer the user's question.
|
47 |
+
This is the user's question: {question}
|
48 |
+
Output:"""
|
49 |
+
answer_generation_direct_prompt = PromptTemplate.from_template(answer_generation_direct_prompt)
|
50 |
+
|
51 |
+
hotpot_answer_parser_prompt = """
|
52 |
+
You will need to extract a concise answer from the detailed answer to answer the question in a consice language style.
|
53 |
+
Only output your concise answer to the answer.
|
54 |
+
For example:
|
55 |
+
Question:
|
56 |
+
VIVA Media AG changed it's name in 2004. What does their new acronym stand for?
|
57 |
+
Detailed Answer:
|
58 |
+
The new name of VIVA Media AG since its change in 2004 is "VIVA Media GmbH". In this acronym, "GmbH" is a German term which means "Gesellschaft mit beschränkter Haftung", translating to "company with limited liability" in English. So, the acronym denotes that it is a type of business organization similar to a limited liability company (LLC). VIVA Media GmbH is a company that specializes in publishing, producing, and developing high-quality games for different platforms.
|
59 |
+
Output:
|
60 |
+
Gesellschaft mit beschränkter Haftung
|
61 |
+
|
62 |
+
Question:
|
63 |
+
Jaclyn Stapp is married to the former frontman of a band that disbanded in what year?
|
64 |
+
Detailed Answer:
|
65 |
+
The band Creed effectively ended on December 29, 2002. However, they had a reunion tour that started on August 6, 2009, and ended on October 20, 2009. They also released an album called "Full Circle" on October 27, 2009. Despite these reunions, the band's meteoric rise came to a halt when it split up again in 2004.
|
66 |
+
Output:
|
67 |
+
2004
|
68 |
+
|
69 |
+
Question:
|
70 |
+
{question}
|
71 |
+
Detailed Answer:
|
72 |
+
{detailed_answer}
|
73 |
+
Output:"""
|
Smurfs/agents/base.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from pydantic import BaseModel, Field
|
3 |
+
from Smurfs.model.base import BaseLM
|
4 |
+
# import json
|
5 |
+
import os
|
6 |
+
from langchain.prompts import PromptTemplate
|
7 |
+
from typing import Any
|
8 |
+
from termcolor import colored
|
9 |
+
|
10 |
+
class BaseAgent(BaseModel):
|
11 |
+
name: str
|
12 |
+
llm: BaseLM
|
13 |
+
prompt: Any
|
14 |
+
logger_dir: str
|
15 |
+
memory: dict = Field(default={})
|
16 |
+
max_retry: int = Field(default=10)
|
17 |
+
|
18 |
+
@abstractmethod
|
19 |
+
def run(self, **kwargs):
|
20 |
+
"""agent run one step"""
|
21 |
+
pass
|
22 |
+
|
23 |
+
@abstractmethod
|
24 |
+
def get_memory(self, **kwargs):
|
25 |
+
"""get relevant memory and add it to agent's memory"""
|
26 |
+
pass
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def get_prompt(self, **kwargs):
|
30 |
+
"""get the prompt for the agent"""
|
31 |
+
pass
|
32 |
+
|
33 |
+
def log(self, query_id, content):
|
34 |
+
"""write log to the logger file"""
|
35 |
+
logger_file = os.path.join(self.logger_dir, f"{query_id}.txt")
|
36 |
+
with open(logger_file, "a+") as file:
|
37 |
+
file.write("\n##########\n")
|
38 |
+
file.write(f"{self.name}: \n\n")
|
39 |
+
file.write(str(content))
|
40 |
+
file.write("\n##########\n")
|
41 |
+
|
42 |
+
def colorful_print(self, content, task):
|
43 |
+
"""print out message in different color"""
|
44 |
+
role_to_color = {
|
45 |
+
"Answer Agent": "red",
|
46 |
+
"Executor Agent": "green",
|
47 |
+
"Planning Agent": "blue",
|
48 |
+
"Verifier Agent": "yellow",
|
49 |
+
}
|
50 |
+
|
51 |
+
print(colored(f"##########{task}##########\n{content}\n", role_to_color[self.name]))
|
Smurfs/agents/executor_agent/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (176 Bytes). View file
|
|
Smurfs/agents/executor_agent/__pycache__/executor.cpython-39.pyc
ADDED
Binary file (5.79 kB). View file
|
|
Smurfs/agents/executor_agent/__pycache__/prompt.cpython-39.pyc
ADDED
Binary file (2.44 kB). View file
|
|
Smurfs/agents/executor_agent/executor.py
ADDED
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Smurfs.agents.base import BaseAgent
|
2 |
+
from Smurfs.agents.executor_agent.prompt import generate_thought_prompt, choose_tool_prompt, choose_parameter_prompt
|
3 |
+
from Smurfs.inference.utils import change_name, standardize, contain
|
4 |
+
from Smurfs.inference.server import get_rapidapi_response
|
5 |
+
from typing import Any
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import time
|
9 |
+
import requests
|
10 |
+
|
11 |
+
class executor_agent(BaseAgent):
|
12 |
+
thought_prompt: Any
|
13 |
+
tool_prompt: Any
|
14 |
+
parameter_prompt: Any
|
15 |
+
|
16 |
+
def __init__(self, *args, **kwargs):
|
17 |
+
thought_prompt = generate_thought_prompt
|
18 |
+
tool_prompt = choose_tool_prompt
|
19 |
+
parameter_prompt = choose_parameter_prompt
|
20 |
+
name = "Executor Agent"
|
21 |
+
kwargs.update({"thought_prompt": thought_prompt})
|
22 |
+
kwargs.update({"tool_prompt": tool_prompt})
|
23 |
+
kwargs.update({"parameter_prompt": parameter_prompt})
|
24 |
+
kwargs.update({"name": name})
|
25 |
+
super().__init__(
|
26 |
+
*args,
|
27 |
+
**kwargs,
|
28 |
+
)
|
29 |
+
|
30 |
+
def run(self, query_id, task, **kwargs):
|
31 |
+
"""agent run one step"""
|
32 |
+
if task == "thought":
|
33 |
+
self.get_memory(**kwargs)
|
34 |
+
agent_prompt = self.get_prompt(task)
|
35 |
+
message = [{'role': 'user',
|
36 |
+
'content': agent_prompt}]
|
37 |
+
ind = 0
|
38 |
+
|
39 |
+
while True:
|
40 |
+
try:
|
41 |
+
result = self.llm.prediction(message)
|
42 |
+
self.log(query_id, result)
|
43 |
+
self.colorful_print(result, "Thought Generation")
|
44 |
+
return result
|
45 |
+
except Exception as e:
|
46 |
+
print(f"generating thought fails: {e}")
|
47 |
+
self.log(query_id, f"generating thought fails: {e}")
|
48 |
+
if ind > self.max_retry:
|
49 |
+
return -1
|
50 |
+
ind += 1
|
51 |
+
continue
|
52 |
+
|
53 |
+
elif task == "tool":
|
54 |
+
thought = kwargs["thought"]
|
55 |
+
kwargs["question"] = kwargs["question"]+f"thought: {thought}\n"
|
56 |
+
del kwargs["thought"]
|
57 |
+
self.get_memory(**kwargs)
|
58 |
+
agent_prompt = self.get_prompt(task)
|
59 |
+
message = [{'role': 'user',
|
60 |
+
'content': agent_prompt}]
|
61 |
+
ind = 0
|
62 |
+
while True:
|
63 |
+
try:
|
64 |
+
result = self.llm.prediction(message)
|
65 |
+
start = result.find("{")
|
66 |
+
end = result.rfind("}")
|
67 |
+
result = eval(result[start:end+1])
|
68 |
+
self.colorful_print(result, "Choose Tool")
|
69 |
+
tool = result['ID']
|
70 |
+
self.log(query_id, result)
|
71 |
+
return tool
|
72 |
+
except Exception as e:
|
73 |
+
print(f"choosing tool fails: {e}")
|
74 |
+
self.log(query_id, f"choosing tool fails: {e}")
|
75 |
+
if ind > self.max_retry:
|
76 |
+
return -1
|
77 |
+
ind += 1
|
78 |
+
continue
|
79 |
+
|
80 |
+
elif task == "parameter":
|
81 |
+
thought = kwargs["thought"]
|
82 |
+
del kwargs["thought"]
|
83 |
+
kwargs["question"] = kwargs["question"]+f"thought: {thought}\n"
|
84 |
+
api_dic = kwargs["api_dic"]
|
85 |
+
if len(api_dic["required_parameters"]) == 0 and len(api_dic["optional_parameters"]) == 0:
|
86 |
+
return {}
|
87 |
+
self.get_memory(**kwargs)
|
88 |
+
agent_prompt = self.get_prompt(task)
|
89 |
+
message = [{'role': 'user',
|
90 |
+
'content': agent_prompt}]
|
91 |
+
|
92 |
+
ind = 0
|
93 |
+
while True:
|
94 |
+
try:
|
95 |
+
result = self.llm.prediction(message)
|
96 |
+
start = result.find("{")
|
97 |
+
end = result.rfind("}")
|
98 |
+
self.colorful_print(result[start:end+1], "Generate Parameters")
|
99 |
+
result = result[start:end+1]
|
100 |
+
clean_answer = eval(
|
101 |
+
result.replace(": true", ": True").replace(": false", ": False").replace("```", "").strip())
|
102 |
+
# a = clean_answer["Parameters"]
|
103 |
+
# clean_answer = clean_answer["Parameters"]
|
104 |
+
self.log(query_id, clean_answer)
|
105 |
+
return clean_answer
|
106 |
+
except Exception as e:
|
107 |
+
print(f"choose parameter fails: {e}")
|
108 |
+
self.log(query_id, f"choose parameter fails: {e}")
|
109 |
+
if ind > self.max_retry:
|
110 |
+
return -1
|
111 |
+
ind += 1
|
112 |
+
continue
|
113 |
+
|
114 |
+
def get_memory(self, **kwargs):
|
115 |
+
"""get relevant memory and add it to agent's memory"""
|
116 |
+
self.memory = kwargs
|
117 |
+
|
118 |
+
def get_prompt(self, task):
|
119 |
+
"""get the prompt for the agent"""
|
120 |
+
if task == "thought":
|
121 |
+
agent_prompt = self.thought_prompt.format(**self.memory)
|
122 |
+
elif task == "tool":
|
123 |
+
agent_prompt = self.tool_prompt.format(**self.memory)
|
124 |
+
elif task == "parameter":
|
125 |
+
agent_prompt = self.parameter_prompt.format(**self.memory)
|
126 |
+
|
127 |
+
return agent_prompt
|
128 |
+
|
129 |
+
class stream_executor_agent(BaseAgent):
|
130 |
+
thought_prompt: Any
|
131 |
+
tool_prompt: Any
|
132 |
+
parameter_prompt: Any
|
133 |
+
|
134 |
+
def __init__(self, *args, **kwargs):
|
135 |
+
thought_prompt = generate_thought_prompt
|
136 |
+
tool_prompt = choose_tool_prompt
|
137 |
+
parameter_prompt = choose_parameter_prompt
|
138 |
+
name = "Executor Agent"
|
139 |
+
kwargs.update({"thought_prompt": thought_prompt})
|
140 |
+
kwargs.update({"tool_prompt": tool_prompt})
|
141 |
+
kwargs.update({"parameter_prompt": parameter_prompt})
|
142 |
+
kwargs.update({"name": name})
|
143 |
+
super().__init__(
|
144 |
+
*args,
|
145 |
+
**kwargs,
|
146 |
+
)
|
147 |
+
|
148 |
+
def run(self, query_id, task, **kwargs):
|
149 |
+
"""agent run one step"""
|
150 |
+
if task == "thought":
|
151 |
+
self.get_memory(**kwargs)
|
152 |
+
agent_prompt = self.get_prompt(task)
|
153 |
+
message = [{'role': 'user',
|
154 |
+
'content': agent_prompt}]
|
155 |
+
ind = 0
|
156 |
+
|
157 |
+
while True:
|
158 |
+
try:
|
159 |
+
result = self.llm.prediction(message)
|
160 |
+
self.log(query_id, result)
|
161 |
+
self.colorful_print(result, "Thought Generation")
|
162 |
+
return result, "Thought Generation", self.name, result
|
163 |
+
except Exception as e:
|
164 |
+
print(f"generating thought fails: {e}")
|
165 |
+
self.log(query_id, f"generating thought fails: {e}")
|
166 |
+
if ind > self.max_retry:
|
167 |
+
return -1, "Thought Generation", self.name, str(e)
|
168 |
+
ind += 1
|
169 |
+
continue
|
170 |
+
|
171 |
+
elif task == "tool":
|
172 |
+
thought = kwargs["thought"]
|
173 |
+
kwargs["question"] = kwargs["question"]+f"thought: {thought}\n"
|
174 |
+
del kwargs["thought"]
|
175 |
+
self.get_memory(**kwargs)
|
176 |
+
agent_prompt = self.get_prompt(task)
|
177 |
+
message = [{'role': 'user',
|
178 |
+
'content': agent_prompt}]
|
179 |
+
ind = 0
|
180 |
+
while True:
|
181 |
+
try:
|
182 |
+
result = self.llm.prediction(message)
|
183 |
+
start = result.find("{")
|
184 |
+
end = result.rfind("}")
|
185 |
+
result = eval(result[start:end+1])
|
186 |
+
self.colorful_print(result, "Choose Tool")
|
187 |
+
tool = result['ID']
|
188 |
+
self.log(query_id, result)
|
189 |
+
return tool, "Choose Tool", self.name, result
|
190 |
+
except Exception as e:
|
191 |
+
print(f"choosing tool fails: {e}")
|
192 |
+
self.log(query_id, f"choosing tool fails: {e}")
|
193 |
+
if ind > self.max_retry:
|
194 |
+
return -1, "Choose Tool", self.name, str(e)
|
195 |
+
ind += 1
|
196 |
+
continue
|
197 |
+
|
198 |
+
elif task == "parameter":
|
199 |
+
thought = kwargs["thought"]
|
200 |
+
del kwargs["thought"]
|
201 |
+
kwargs["question"] = kwargs["question"]+f"thought: {thought}\n"
|
202 |
+
api_dic = kwargs["api_dic"]
|
203 |
+
if len(api_dic["required_parameters"]) == 0 and len(api_dic["optional_parameters"]) == 0:
|
204 |
+
return {}
|
205 |
+
self.get_memory(**kwargs)
|
206 |
+
agent_prompt = self.get_prompt(task)
|
207 |
+
message = [{'role': 'user',
|
208 |
+
'content': agent_prompt}]
|
209 |
+
|
210 |
+
ind = 0
|
211 |
+
while True:
|
212 |
+
try:
|
213 |
+
result = self.llm.prediction(message)
|
214 |
+
start = result.find("{")
|
215 |
+
end = result.rfind("}")
|
216 |
+
self.colorful_print(result[start:end+1], "Generate Parameters")
|
217 |
+
result = result[start:end+1]
|
218 |
+
clean_answer = eval(
|
219 |
+
result.replace(": true", ": True").replace(": false", ": False").replace("```", "").strip())
|
220 |
+
# a = clean_answer["Parameters"]
|
221 |
+
# clean_answer = clean_answer["Parameters"]
|
222 |
+
self.log(query_id, clean_answer)
|
223 |
+
return clean_answer, "Generate Parameters", self.name, result[start:end+1]
|
224 |
+
except Exception as e:
|
225 |
+
print(f"choose parameter fails: {e}")
|
226 |
+
self.log(query_id, f"choose parameter fails: {e}")
|
227 |
+
if ind > self.max_retry:
|
228 |
+
return -1, "Generate Parameters", self.name, str(e)
|
229 |
+
ind += 1
|
230 |
+
continue
|
231 |
+
|
232 |
+
def get_memory(self, **kwargs):
|
233 |
+
"""get relevant memory and add it to agent's memory"""
|
234 |
+
self.memory = kwargs
|
235 |
+
|
236 |
+
def get_prompt(self, task):
|
237 |
+
"""get the prompt for the agent"""
|
238 |
+
if task == "thought":
|
239 |
+
agent_prompt = self.thought_prompt.format(**self.memory)
|
240 |
+
elif task == "tool":
|
241 |
+
agent_prompt = self.tool_prompt.format(**self.memory)
|
242 |
+
elif task == "parameter":
|
243 |
+
agent_prompt = self.parameter_prompt.format(**self.memory)
|
244 |
+
|
245 |
+
return agent_prompt
|
246 |
+
|
Smurfs/agents/executor_agent/prompt.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
|
3 |
+
generate_thought_prompt = """\
|
4 |
+
You need to analyse the previous execution history and generate your internal reasoning and thoughts on the task, and how you plan to solve it based on the current attempts.
|
5 |
+
|
6 |
+
Do not output thought that is too long. Output in 2-3 sentences is OK.
|
7 |
+
|
8 |
+
This is the user's task:
|
9 |
+
{question}
|
10 |
+
|
11 |
+
This is the Tool List:
|
12 |
+
{tool_list}
|
13 |
+
|
14 |
+
This is the previous execution history:
|
15 |
+
{previous_log}
|
16 |
+
|
17 |
+
This is the hint comes from the evaluator:
|
18 |
+
{hint}
|
19 |
+
|
20 |
+
Output:"""
|
21 |
+
generate_thought_prompt = PromptTemplate.from_template(generate_thought_prompt)
|
22 |
+
|
23 |
+
choose_tool_prompt = """\
|
24 |
+
This is the user's question:
|
25 |
+
{question}
|
26 |
+
These are the tools you can select to solve the question:
|
27 |
+
Tool List:
|
28 |
+
{tool_list}
|
29 |
+
|
30 |
+
Please note that:
|
31 |
+
1. You should only chooce one tool from the Tool List to solve this question.
|
32 |
+
2. You must ONLY output the ID of the tool and your reason for choosing it in a parsible JSON format. An example output looks like:
|
33 |
+
'''
|
34 |
+
Example: {{\"ID\": ID of the tool, \"Reason\": The reason for choosing the tool}}
|
35 |
+
'''
|
36 |
+
Output: """
|
37 |
+
choose_tool_prompt = PromptTemplate.from_template(choose_tool_prompt)
|
38 |
+
|
39 |
+
choose_parameter_prompt="""\
|
40 |
+
Given a user's question and a API tool documentation, you need to output parameters according to the API tool documentation to successfully call the API to solve the user's question.
|
41 |
+
Please note that:
|
42 |
+
1. The Example in the API tool documentation can help you better understand the use of the API.
|
43 |
+
2. Ensure the parameters you output are correct. The output must contain the required parameters, and can contain the optional parameters based on the question. If no paremters in the required parameters and optional parameters, just leave it as {{}}
|
44 |
+
3. If the user's question mentions other APIs, you should ONLY consider the API tool documentation I give and do not consider other APIs.
|
45 |
+
4. The question may have dependencies on answers of other questions, so we will provide logs of previous questions and answers for your reference.
|
46 |
+
5. You must ONLY output in a parsible JSON Format. The example output looks like:
|
47 |
+
'''
|
48 |
+
Example: {{\"keyword\": \"Artificial Intelligence\", \"language\": \"English\"}}
|
49 |
+
'''
|
50 |
+
|
51 |
+
There are logs of previous questions and answers:
|
52 |
+
{previous_log}
|
53 |
+
|
54 |
+
This is the current user's question: {question}
|
55 |
+
|
56 |
+
This is API tool documentation: {api_dic}
|
57 |
+
Output:"""
|
58 |
+
choose_parameter_prompt = PromptTemplate.from_template(choose_parameter_prompt)
|
Smurfs/agents/memory_agent/memory_agent.py
ADDED
File without changes
|
Smurfs/agents/memory_agent/prompt.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
|
3 |
+
mem_choose_prompt = """You are a memory agent that controls the memory of the agent system.
|
4 |
+
The agent system is trying to solve a complex question step by step by solving its subtasks one by one.
|
5 |
+
Among those subtasks, some subtask may need execution history of other subtasks to be solved.
|
6 |
+
Your task is to decide which subtasks' execution history is needed by the agent system to solve the current subtask.
|
7 |
+
Please note that:
|
8 |
+
1. If the current subtask is independent of the other subtasks, just output {{\"task\":}}
|
9 |
+
2.
|
10 |
+
You must only output in a parsible JSON format. Two example outputs look like:
|
11 |
+
Example 1: {{\"Reason\": \"The reason why you think you do not need to call an external API to solve the user's question\", \"Choice\": \"No\"}}
|
12 |
+
Example 2: {{\"Reason\": \"The reason why you think you need to call an external API to solve the user's question\", \"Choice\": \"Yes\"}}
|
13 |
+
This is the current subtask: {question}
|
14 |
+
This is the previous execution history: {history}
|
15 |
+
Output: """
|
16 |
+
tool_check_prompt = PromptTemplate.from_template(mem_choose_prompt)
|
Smurfs/agents/planning_agent/__pycache__/planner.cpython-39.pyc
ADDED
Binary file (3.99 kB). View file
|
|
Smurfs/agents/planning_agent/__pycache__/prompt.cpython-39.pyc
ADDED
Binary file (3.1 kB). View file
|
|
Smurfs/agents/planning_agent/planner.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Smurfs.agents.base import BaseAgent
|
2 |
+
from Smurfs.agents.planning_agent.prompt import task_decompose_prompt, hotpot_task_decompose_prompt
|
3 |
+
|
4 |
+
class planning_agent(BaseAgent):
|
5 |
+
def __init__(self, llm, logger_dir):
|
6 |
+
super().__init__(
|
7 |
+
prompt = task_decompose_prompt,
|
8 |
+
llm = llm,
|
9 |
+
name = "Planning Agent",
|
10 |
+
logger_dir = logger_dir
|
11 |
+
)
|
12 |
+
|
13 |
+
def run(self, question, query_id):
|
14 |
+
"""agent run one step"""
|
15 |
+
self.get_memory(question)
|
16 |
+
agent_prompt = self.get_prompt()
|
17 |
+
message = [{'role': 'user',
|
18 |
+
'content': agent_prompt}]
|
19 |
+
ind = 0
|
20 |
+
while True:
|
21 |
+
try:
|
22 |
+
result = self.llm.prediction(message)
|
23 |
+
# print(result)
|
24 |
+
start = result.find("{")
|
25 |
+
end = result.find("}")
|
26 |
+
result = eval(result[start:end+1])
|
27 |
+
self.colorful_print(result, "Task Decompose")
|
28 |
+
subtasks = result['Tasks']
|
29 |
+
self.log(query_id, result)
|
30 |
+
# print(a)
|
31 |
+
return subtasks
|
32 |
+
except Exception as e:
|
33 |
+
print(f"task deompose fails: {e}")
|
34 |
+
self.log(query_id, f"task deompose fails: {e}")
|
35 |
+
if ind > self.max_retry:
|
36 |
+
return -1
|
37 |
+
ind += 1
|
38 |
+
continue
|
39 |
+
|
40 |
+
def get_memory(self, question):
|
41 |
+
"""get relevant memory and add it to agent's memory"""
|
42 |
+
self.memory = {"question": question}
|
43 |
+
|
44 |
+
def get_prompt(self):
|
45 |
+
"""get the prompt for the agent"""
|
46 |
+
agent_prompt = self.prompt.format(**self.memory)
|
47 |
+
return agent_prompt
|
48 |
+
|
49 |
+
class hotpot_planning_agent(BaseAgent):
|
50 |
+
def __init__(self, llm, logger_dir):
|
51 |
+
super().__init__(
|
52 |
+
prompt = hotpot_task_decompose_prompt,
|
53 |
+
llm = llm,
|
54 |
+
name = "Planning Agent",
|
55 |
+
logger_dir = logger_dir
|
56 |
+
)
|
57 |
+
|
58 |
+
def run(self, question, query_id):
|
59 |
+
"""agent run one step"""
|
60 |
+
self.get_memory(question)
|
61 |
+
agent_prompt = self.get_prompt()
|
62 |
+
message = [{'role': 'user',
|
63 |
+
'content': agent_prompt}]
|
64 |
+
ind = 0
|
65 |
+
while True:
|
66 |
+
try:
|
67 |
+
result = self.llm.prediction(message)
|
68 |
+
# print(result)
|
69 |
+
start = result.find("{")
|
70 |
+
end = result.find("}")
|
71 |
+
result = eval(result[start:end+1])
|
72 |
+
self.colorful_print(result, "Task Decompose")
|
73 |
+
subtasks = result['Tasks']
|
74 |
+
self.log(query_id, result)
|
75 |
+
# print(a)
|
76 |
+
return subtasks
|
77 |
+
except Exception as e:
|
78 |
+
print(f"task deompose fails: {e}")
|
79 |
+
self.log(query_id, f"task deompose fails: {e}")
|
80 |
+
if ind > self.max_retry:
|
81 |
+
return -1
|
82 |
+
ind += 1
|
83 |
+
continue
|
84 |
+
|
85 |
+
def get_memory(self, question):
|
86 |
+
"""get relevant memory and add it to agent's memory"""
|
87 |
+
self.memory = {"question": question}
|
88 |
+
|
89 |
+
def get_prompt(self):
|
90 |
+
"""get the prompt for the agent"""
|
91 |
+
agent_prompt = self.prompt.format(**self.memory)
|
92 |
+
return agent_prompt
|
93 |
+
|
94 |
+
class stream_hotpot_planning_agent(BaseAgent):
|
95 |
+
def __init__(self, llm, logger_dir):
|
96 |
+
super().__init__(
|
97 |
+
prompt = hotpot_task_decompose_prompt,
|
98 |
+
llm = llm,
|
99 |
+
name = "Planning Agent",
|
100 |
+
logger_dir = logger_dir
|
101 |
+
)
|
102 |
+
|
103 |
+
def run(self, question, query_id):
|
104 |
+
"""agent run one step"""
|
105 |
+
self.get_memory(question)
|
106 |
+
agent_prompt = self.get_prompt()
|
107 |
+
message = [{'role': 'user',
|
108 |
+
'content': agent_prompt}]
|
109 |
+
ind = 0
|
110 |
+
while True:
|
111 |
+
try:
|
112 |
+
result = self.llm.prediction(message)
|
113 |
+
# print(result)
|
114 |
+
start = result.find("{")
|
115 |
+
end = result.find("}")
|
116 |
+
result = eval(result[start:end+1])
|
117 |
+
self.colorful_print(result, "Task Decompose")
|
118 |
+
subtasks = result['Tasks']
|
119 |
+
self.log(query_id, result)
|
120 |
+
# print(a)
|
121 |
+
return subtasks, "Task Decompose", self.name, result
|
122 |
+
except Exception as e:
|
123 |
+
print(f"task deompose fails: {e}")
|
124 |
+
self.log(query_id, f"task deompose fails: {e}")
|
125 |
+
if ind > 5:
|
126 |
+
return -1, "Task Decompose", self.name, str(e)
|
127 |
+
ind += 1
|
128 |
+
continue
|
129 |
+
|
130 |
+
def get_memory(self, question):
|
131 |
+
"""get relevant memory and add it to agent's memory"""
|
132 |
+
self.memory = {"question": question}
|
133 |
+
|
134 |
+
def get_prompt(self):
|
135 |
+
"""get the prompt for the agent"""
|
136 |
+
agent_prompt = self.prompt.format(**self.memory)
|
137 |
+
return agent_prompt
|
Smurfs/agents/planning_agent/prompt.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
|
3 |
+
task_decompose_prompt = """
|
4 |
+
You need to decompose a complex user's question into some simple subtasks and let the model execute it step by step.
|
5 |
+
Please note that:
|
6 |
+
1. You should only decompose this complex user's question into some simple subtasks which can be executed easily by using a single tool.
|
7 |
+
2. Each simple subtask should be expressed into natural language.
|
8 |
+
3. Each subtask should contain the necessary information from the original question and should be complete, explicit and self-consistent.
|
9 |
+
4. You must ONLY output in a parsible JSON format. An example output looks like:
|
10 |
+
'''
|
11 |
+
{{\"Tasks\": [\"Task 1\", \"Task 2\", ...]}}
|
12 |
+
'''
|
13 |
+
|
14 |
+
This is the user's question: I'm planning a trip to Turkey and need information about postal codes in Istanbul. Can you provide me with the postal code and district for Istanbul province with plate number 34? Additionally, I would like to know if there are any transit agencies available in Istanbul. Please fetch their names and contact numbers.
|
15 |
+
Output: {{\"Tasks\": [\"Find the postal codes and districts for plate number 34 in Istanbul.\", \"Search for transit agencies and their contact numbers in Istanbul.\"]}}
|
16 |
+
|
17 |
+
This is the user's question: I recently moved to a new address and I need to update my information. Can you retrieve my address details using the postal code 75094080? Additionally, I would like to know the companies that offer shipping services.
|
18 |
+
Output: {{\"Tasks\": [\"retrieve the address details using the postal code 75094080\", \"search for companies that offer shipping services to my address\"]}}
|
19 |
+
|
20 |
+
This is the user's question: {question}
|
21 |
+
Output:
|
22 |
+
"""
|
23 |
+
task_decompose_prompt = PromptTemplate.from_template(task_decompose_prompt)
|
24 |
+
|
25 |
+
hotpot_task_decompose_prompt = """
|
26 |
+
You need to decompose a complex user's question into some simple subtasks and let the model execute it step by step.
|
27 |
+
Please note that:
|
28 |
+
1. You should only decompose this complex user's question into some simple subtasks which can be executed easily by using a single tool.
|
29 |
+
2. Each simple subtask should be expressed into natural language.
|
30 |
+
3. Each subtask should contain the necessary information from the original question and should be complete, explicit and self-consistent.
|
31 |
+
4. You must ONLY output in a parsible JSON format. An example output looks like:
|
32 |
+
'''
|
33 |
+
{{\"Tasks\": [\"Task 1\", \"Task 2\", ...]}}
|
34 |
+
'''
|
35 |
+
|
36 |
+
This is the user's question: What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?
|
37 |
+
Output: {{\"Tasks\": [\"In the film Kiss and Tell, who is the woman who portrayed Corliss Archer?\", \"What government position was held by this woman?\"]}}
|
38 |
+
|
39 |
+
This is the user's question: Were Scott Derrickson and Ed Wood of the same nationality?
|
40 |
+
Output: {{\"Tasks\": [\"search for the nationality of Scott Derrickson\", \"search for the nationality for Ed Wood\", \"Compare whether they have the same nationality\"]}}
|
41 |
+
|
42 |
+
This is the user's question: {question}
|
43 |
+
Output:
|
44 |
+
"""
|
Smurfs/agents/verifier_agent/__pycache__/prompt.cpython-39.pyc
ADDED
Binary file (1.51 kB). View file
|
|
Smurfs/agents/verifier_agent/__pycache__/verifier.cpython-39.pyc
ADDED
Binary file (3.03 kB). View file
|
|
Smurfs/agents/verifier_agent/prompt.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langchain.prompts import PromptTemplate
|
2 |
+
|
3 |
+
final_answer_check_prompt = """
|
4 |
+
An agent is trying to solve the query proposed by the user. \
|
5 |
+
You need to evaluate whether the given query has been completed reasonably and accurately. If so, summarize the solution to the user. If not, summarize the current progress, and propose what is missing.
|
6 |
+
|
7 |
+
You response contains following elements:
|
8 |
+
Speak: (your words to the agent if the task isn't completed, or a complete answer based on the full execution log to the user if the task is finished)
|
9 |
+
Status: (0 or 1. 0 for unfinished and 1 for finished)
|
10 |
+
|
11 |
+
Please note that:
|
12 |
+
1. If the answer says the query can't be solved or it can't answer the query given the current information, please output Status as 0.
|
13 |
+
2. Only output Status as 1 if the query has been answered correctly and accurately.
|
14 |
+
3. If the answer only give a plan instead of a detailed answer, output Status as 0.
|
15 |
+
|
16 |
+
You must only output in a parsible JSON format. Two example outputs look like:
|
17 |
+
Example 1: {{\"Speak\": \"answer based on the full execution log to the user\", \"Status\": \"1\"}}
|
18 |
+
Example 2: {{\"Speak\": \"your words to the group if the task isn't solved\", \"Status\": \"0\"}}
|
19 |
+
|
20 |
+
This is the answer from the previous execution result:
|
21 |
+
{answer}
|
22 |
+
|
23 |
+
This is the original question: {question}
|
24 |
+
Output: """
|
25 |
+
final_answer_check_prompt = PromptTemplate.from_template(final_answer_check_prompt)
|
Smurfs/agents/verifier_agent/verifier.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from Smurfs.agents.base import BaseAgent
|
2 |
+
from Smurfs.agents.verifier_agent.prompt import final_answer_check_prompt
|
3 |
+
|
4 |
+
class verifier_agent(BaseAgent):
|
5 |
+
def __init__(self, llm, logger_dir):
|
6 |
+
super().__init__(
|
7 |
+
prompt = final_answer_check_prompt,
|
8 |
+
llm = llm,
|
9 |
+
name = "Verifier Agent",
|
10 |
+
logger_dir = logger_dir
|
11 |
+
)
|
12 |
+
|
13 |
+
def run(self, question, answer, query_id):
|
14 |
+
"""agent run one step"""
|
15 |
+
self.get_memory(question, answer)
|
16 |
+
agent_prompt = self.get_prompt()
|
17 |
+
message = [{'role': 'user',
|
18 |
+
'content': agent_prompt}]
|
19 |
+
ind = 0
|
20 |
+
while True:
|
21 |
+
try:
|
22 |
+
result = self.llm.prediction(message)
|
23 |
+
start = result.find("{")
|
24 |
+
end = result.find("}")
|
25 |
+
self.colorful_print(result, "Answer Verify")
|
26 |
+
self.log(query_id, result)
|
27 |
+
clean_result = eval(result[start:end+1])
|
28 |
+
speak = clean_result["Speak"]
|
29 |
+
status = clean_result["Status"]
|
30 |
+
return speak, status
|
31 |
+
except Exception as e:
|
32 |
+
print(f"final answer check fails: {e}")
|
33 |
+
self.log(query_id, f"final answer check fails: {e}")
|
34 |
+
if ind > self.max_retry:
|
35 |
+
return -1, -1
|
36 |
+
ind += 1
|
37 |
+
continue
|
38 |
+
|
39 |
+
def get_memory(self, question, answer):
|
40 |
+
"""get relevant memory and add it to agent's memory"""
|
41 |
+
self.memory = {"question": question, "answer": answer}
|
42 |
+
|
43 |
+
def get_prompt(self):
|
44 |
+
"""get the prompt for the agent"""
|
45 |
+
agent_prompt = self.prompt.format(**self.memory)
|
46 |
+
return agent_prompt
|
47 |
+
|
48 |
+
class stream_verifier_agent(BaseAgent):
|
49 |
+
def __init__(self, llm, logger_dir):
|
50 |
+
super().__init__(
|
51 |
+
prompt = final_answer_check_prompt,
|
52 |
+
llm = llm,
|
53 |
+
name = "Verifier Agent",
|
54 |
+
logger_dir = logger_dir
|
55 |
+
)
|
56 |
+
|
57 |
+
def run(self, question, answer, query_id):
|
58 |
+
"""agent run one step"""
|
59 |
+
self.get_memory(question, answer)
|
60 |
+
agent_prompt = self.get_prompt()
|
61 |
+
message = [{'role': 'user',
|
62 |
+
'content': agent_prompt}]
|
63 |
+
ind = 0
|
64 |
+
while True:
|
65 |
+
try:
|
66 |
+
result = self.llm.prediction(message)
|
67 |
+
start = result.find("{")
|
68 |
+
end = result.find("}")
|
69 |
+
self.colorful_print(result, "Answer Verify")
|
70 |
+
self.log(query_id, result)
|
71 |
+
clean_result = eval(result[start:end+1])
|
72 |
+
speak = clean_result["Speak"]
|
73 |
+
status = clean_result["Status"]
|
74 |
+
return speak, status, "Answer Verify", self.name, result
|
75 |
+
except Exception as e:
|
76 |
+
print(f"final answer check fails: {e}")
|
77 |
+
self.log(query_id, f"final answer check fails: {e}")
|
78 |
+
if ind > self.max_retry:
|
79 |
+
return -1, -1, "Answer Verify", self.name, str(e)
|
80 |
+
ind += 1
|
81 |
+
continue
|
82 |
+
|
83 |
+
def get_memory(self, question, answer):
|
84 |
+
"""get relevant memory and add it to agent's memory"""
|
85 |
+
self.memory = {"question": question, "answer": answer}
|
86 |
+
|
87 |
+
def get_prompt(self):
|
88 |
+
"""get the prompt for the agent"""
|
89 |
+
agent_prompt = self.prompt.format(**self.memory)
|
90 |
+
return agent_prompt
|
Smurfs/data/.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
Smurfs/data/__init__.py
ADDED
File without changes
|
Smurfs/data/post_process.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
4 |
+
# print(sys.path)
|
5 |
+
import json
|
6 |
+
from Smurfs.data.utils import tree_steps_counter, total_path_transform
|
7 |
+
import argparse
|
8 |
+
|
9 |
+
def parse_arg():
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument('--input_dir', type=str, default="dir_to_your_data", required=False, help='the directory of the data that needs post-processing')
|
12 |
+
parser.add_argument('--example_dir', type=str, default="dir_to_example_data", required=False, help='the directory of the example data')
|
13 |
+
parser.add_argument('--test_sets', nargs='+', type=str, required=False, help='the test sets that need processing. It should be G2_instruction, G2_category or G3_instruction')
|
14 |
+
args = parser.parse_args()
|
15 |
+
return args
|
16 |
+
|
17 |
+
def main():
|
18 |
+
args = parse_arg()
|
19 |
+
input_dir = args.input_dir
|
20 |
+
test_sets = args.test_sets
|
21 |
+
example_dir = args.example_dir
|
22 |
+
for test_set in test_sets:
|
23 |
+
data_path = os.path.join(input_dir, f"{test_set}_raw.json")
|
24 |
+
example_data_path = os.path.join(example_dir, f"{test_set}.json")
|
25 |
+
|
26 |
+
with open(data_path, 'r') as file:
|
27 |
+
g_data = json.load(file)
|
28 |
+
with open(example_data_path, 'r') as file:
|
29 |
+
g_example_data = json.load(file)
|
30 |
+
g_new_data = {}
|
31 |
+
for g_d in g_data:
|
32 |
+
m = False
|
33 |
+
for d in g_example_data:
|
34 |
+
if g_data[g_d]["query"] == g_example_data[d]["query"]:
|
35 |
+
g_new_data_ele = {"query":"", "available_tools": [], "answer":{}}
|
36 |
+
g_new_answer_ele = {
|
37 |
+
"method": "smurfs",
|
38 |
+
"total_steps": 0,
|
39 |
+
"final_answer": "",
|
40 |
+
"answer_details": []
|
41 |
+
}
|
42 |
+
g_new_data_ele["query"] = g_data[g_d]["query"]
|
43 |
+
g_new_data_ele["available_tools"] = g_example_data[d]["available_tools"]
|
44 |
+
g_new_answer_ele["answer_details"] = [g_data[g_d]["answer"]["answer_details"]]
|
45 |
+
counter = tree_steps_counter(0)
|
46 |
+
counter.count_total_steps(g_data[g_d]["answer"]["answer_details"])
|
47 |
+
g_new_answer_ele["total_steps"] = counter.get_steps()
|
48 |
+
g_new_answer_ele["final_answer"] = g_data[g_d]["answer"]["final_answer"]
|
49 |
+
g_new_data_ele["answer"] = g_new_answer_ele
|
50 |
+
g_new_data[d] = g_new_data_ele
|
51 |
+
m = True
|
52 |
+
break
|
53 |
+
if not m:
|
54 |
+
print(f"{test_set} mismatch! The key is: {g_d}")
|
55 |
+
if test_set == "G2_category":
|
56 |
+
duplicate = g_new_data["43201"]
|
57 |
+
g_new_data["43200"] = duplicate
|
58 |
+
|
59 |
+
output_path = os.path.join(input_dir, f"{test_set}.json")
|
60 |
+
print(output_path)
|
61 |
+
with open(output_path, 'w') as file:
|
62 |
+
json.dump(g_new_data, file, indent=4, ensure_ascii=False)
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
main()
|
Smurfs/data/utils.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class tree_steps_counter:
|
2 |
+
def __init__(self, steps):
|
3 |
+
self.steps = steps
|
4 |
+
|
5 |
+
def count_total_steps(self, root):
|
6 |
+
self.steps += 1
|
7 |
+
if root["next"] == []:
|
8 |
+
return
|
9 |
+
for i in range(len(root["next"])):
|
10 |
+
self.count_total_steps(root["next"][i])
|
11 |
+
|
12 |
+
def get_steps(self):
|
13 |
+
return self.steps
|
14 |
+
|
15 |
+
def total_path_transform(data, index):
|
16 |
+
finish_template = [{
|
17 |
+
"role": "tool",
|
18 |
+
"message": {
|
19 |
+
"name": "Finish",
|
20 |
+
"arguments": {
|
21 |
+
"return_type": "give_answer",
|
22 |
+
"final_answer": data[index]["final_answer"]
|
23 |
+
},
|
24 |
+
"response": ""
|
25 |
+
},
|
26 |
+
"next": []
|
27 |
+
}]
|
28 |
+
answer_path = data[index]["total_path"]["next"][0]["next"]
|
29 |
+
for i in range(len(answer_path)-1, -1, -1):
|
30 |
+
if answer_path[i]["role"] == "plan_global":
|
31 |
+
if answer_path[i]["next"] != []:
|
32 |
+
current_log = answer_path[i]["next"][-1]
|
33 |
+
answer_path[i] = answer_path[i]["next"]
|
34 |
+
else:
|
35 |
+
if i == len(answer_path)-1:
|
36 |
+
answer_path[i] = finish_template[0]
|
37 |
+
continue
|
38 |
+
else:
|
39 |
+
current_log = answer_path[i]
|
40 |
+
while current_log["next"] != []:
|
41 |
+
current_log = current_log["next"][-1]
|
42 |
+
if i == len(answer_path)-1:
|
43 |
+
current_log["next"] = finish_template
|
44 |
+
else:
|
45 |
+
if not isinstance(answer_path[i+1], list):
|
46 |
+
current_log["next"] = [answer_path[i+1]]
|
47 |
+
else:
|
48 |
+
current_log["next"] = answer_path[i+1]
|
49 |
+
if not isinstance(answer_path[0], list):
|
50 |
+
data[index]["total_path"]["next"][0]["next"] = [answer_path[0]]
|
51 |
+
else:
|
52 |
+
data[index]["total_path"]["next"][0]["next"] = answer_path[0]
|
53 |
+
|
Smurfs/deploy/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
global_dict = {
|
2 |
+
"knowledge_base" : None
|
3 |
+
}
|
Smurfs/deploy/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (203 Bytes). View file
|
|
Smurfs/deploy/cli_inference.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import warnings
|
2 |
+
|
3 |
+
# 抑制所有警告
|
4 |
+
warnings.filterwarnings('ignore')
|
5 |
+
|
6 |
+
import os
|
7 |
+
import sys
|
8 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
9 |
+
from Smurfs.inference.smurfs_worker import smurfs_hotpot_worker, smurfs_worker
|
10 |
+
# from Smurfs.tools.tool_env import HotpotToolEnv
|
11 |
+
from Smurfs.tools.tool_env import tool_env
|
12 |
+
from Smurfs.model.openai_model.openai_model import OpenAI_Model, OpenRouter_Model
|
13 |
+
from Smurfs.agents.answer_agent.answer import answer_agent
|
14 |
+
from Smurfs.agents.executor_agent.executor import executor_agent
|
15 |
+
from Smurfs.agents.planning_agent.planner import hotpot_planning_agent
|
16 |
+
from Smurfs.agents.verifier_agent.verifier import verifier_agent
|
17 |
+
import json
|
18 |
+
import threading
|
19 |
+
import joblib
|
20 |
+
from tqdm import tqdm
|
21 |
+
import time
|
22 |
+
|
23 |
+
def run(worker, query, query_id):
|
24 |
+
# global lock
|
25 |
+
final_answer, output_file_ele, solution_file_ele = worker.run(query, query_id)
|
26 |
+
# lock.acquire()
|
27 |
+
worker.save_solution(output_file_ele, solution_file_ele, query_id)
|
28 |
+
# lock.release()
|
29 |
+
return final_answer
|
30 |
+
|
31 |
+
def cli_run(query, worker):
|
32 |
+
pre = run(worker, query, 0)
|
33 |
+
return pre
|
34 |
+
|
35 |
+
if __name__ == '__main__':
|
36 |
+
# model_name = "mistralai/mistral-7b-instruct-v0.2"
|
37 |
+
model_name = "mistralai/mistral-7b-instruct-v0.2"
|
38 |
+
method_name = "cli_inference"
|
39 |
+
tool_doc_path = "Smurfs/tools/math_search.json"
|
40 |
+
# llm = OpenAI_Model(model_name=model_name)
|
41 |
+
llm = OpenRouter_Model(model_name=model_name)
|
42 |
+
# parser_llm = OpenAI_Model(model_name="gpt-4")
|
43 |
+
with open(tool_doc_path, "r") as f:
|
44 |
+
available_tools = json.load(f)
|
45 |
+
|
46 |
+
test_set = "cli"
|
47 |
+
|
48 |
+
output_dir = f"data/{method_name}/{test_set}/answer"
|
49 |
+
results_dir = f"data/{method_name}/{test_set}/results.json"
|
50 |
+
if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
|
51 |
+
os.makedirs(f"data/{method_name}/{test_set}/parser_log")
|
52 |
+
if not os.path.exists(output_dir):
|
53 |
+
os.makedirs(output_dir)
|
54 |
+
# HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
|
55 |
+
# worker = smurfs_hotpot_worker(available_tools, HotpotToolEnv, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
|
56 |
+
worker = smurfs_worker(available_tools, tool_env, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
|
57 |
+
query = input("Please Enter Your Task: ")
|
58 |
+
cli_run(query, worker)
|
Smurfs/deploy/gradio_inference.py
ADDED
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
import warnings
|
4 |
+
|
5 |
+
# 抑制所有警告
|
6 |
+
warnings.filterwarnings('ignore')
|
7 |
+
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
11 |
+
from Smurfs.inference.smurfs_worker import smurfs_hotpot_worker, smurfs_worker, stream_smurfs_worker
|
12 |
+
# from Smurfs.tools.tool_env import HotpotToolEnv
|
13 |
+
from Smurfs.deploy import global_dict
|
14 |
+
from Smurfs.tools.tool_env import tool_env
|
15 |
+
from Smurfs.model.openai_model.openai_model import OpenAI_Model
|
16 |
+
from Smurfs.agents.answer_agent.answer import stream_answer_agent
|
17 |
+
from Smurfs.agents.executor_agent.executor import stream_executor_agent
|
18 |
+
from Smurfs.agents.planning_agent.planner import stream_hotpot_planning_agent
|
19 |
+
from Smurfs.agents.verifier_agent.verifier import stream_verifier_agent
|
20 |
+
|
21 |
+
from Smurfs.tools.docqa.api import tool_env as docqa_tool_env
|
22 |
+
from Smurfs.tools.hotpotQA.api import tool_env as hotpot_tool_env
|
23 |
+
from Smurfs.tools.math.api import tool_env as math_tool_env
|
24 |
+
from Smurfs.tools.shell.api import tool_env as shell_tool_env
|
25 |
+
from Smurfs.tools.websearch.api import tool_env as websearch_tool_env
|
26 |
+
|
27 |
+
|
28 |
+
import json
|
29 |
+
import threading
|
30 |
+
import joblib
|
31 |
+
from tqdm import tqdm
|
32 |
+
import time
|
33 |
+
|
34 |
+
from PyPDF2 import PdfReader
|
35 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
36 |
+
from langchain.vectorstores import FAISS
|
37 |
+
from langchain_openai import OpenAIEmbeddings
|
38 |
+
from datetime import datetime
|
39 |
+
current_datetime = datetime.now()
|
40 |
+
# user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
|
41 |
+
# inp = gr.Textbox(placeholder="Enter your task")
|
42 |
+
# css = """
|
43 |
+
# .btn {background-color: blue; color: white;}
|
44 |
+
# #bot {background-color: blue; color: white;}
|
45 |
+
# #e {display: inline-block; vertical-align: middle;}
|
46 |
+
# """
|
47 |
+
|
48 |
+
# def update(name):
|
49 |
+
# return f"<span style='color: red'>Welcome to Gradio, {name}!</span>"
|
50 |
+
|
51 |
+
tool_env_map = {
|
52 |
+
"shell": shell_tool_env,
|
53 |
+
"math": math_tool_env,
|
54 |
+
"docqa": docqa_tool_env,
|
55 |
+
"hotpotQA": hotpot_tool_env,
|
56 |
+
"websearch": websearch_tool_env
|
57 |
+
}
|
58 |
+
|
59 |
+
total_env, env_name_list = {}, []
|
60 |
+
|
61 |
+
def loading():
|
62 |
+
return "Loading..."
|
63 |
+
|
64 |
+
def load_text_from_pdf(up, key=None):
|
65 |
+
global global_dict
|
66 |
+
if key == None:
|
67 |
+
key = os.environ.get("OPENAI_API_KEY")
|
68 |
+
pdf_path = up.name
|
69 |
+
pdf_reader = PdfReader(pdf_path)
|
70 |
+
text = ""
|
71 |
+
for page in pdf_reader.pages:
|
72 |
+
text += page.extract_text()
|
73 |
+
|
74 |
+
# split into chunks
|
75 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
76 |
+
chunk_size=1000,
|
77 |
+
chunk_overlap=200,
|
78 |
+
add_start_index=True
|
79 |
+
)
|
80 |
+
chunks = text_splitter.split_text(text)
|
81 |
+
|
82 |
+
# create embeddings
|
83 |
+
# embeddings = OpenAIEmbeddings()
|
84 |
+
embeddings = OpenAIEmbeddings(openai_api_key=key)
|
85 |
+
global_dict["knowledge_base"] = FAISS.from_texts(chunks, embeddings)
|
86 |
+
return "upload success!"
|
87 |
+
#return knowledge_base
|
88 |
+
|
89 |
+
def update(query, OPENAI_API_KEY, BING_SUBSCRIPT_KEY, WOLFRAMALPH_APP_ID, WEATHER_API_KEYS):
|
90 |
+
global total_env, env_name_list
|
91 |
+
# print(total_env)
|
92 |
+
# print(BING_SUBSCRIPT_KEY)
|
93 |
+
# print(WOLFRAMALPH_APP_ID)
|
94 |
+
# print(WEATHER_API_KEYS)
|
95 |
+
# model_name = "mistralai/mistral-7b-instruct-v0.2"
|
96 |
+
model_name = "gpt-4"
|
97 |
+
method_name = "cli_inference"
|
98 |
+
tool_doc_path = "Smurfs/tools/tool_doc.json"
|
99 |
+
if OPENAI_API_KEY == None or OPENAI_API_KEY == '':
|
100 |
+
yield [(query, "No OPENAI KEY provided!")]
|
101 |
+
raise KeyError
|
102 |
+
if (BING_SUBSCRIPT_KEY == None or BING_SUBSCRIPT_KEY == ''):
|
103 |
+
yield [(query, "No BING_SUBSCRIPT_KEY provided! Please register one from https://www.microsoft.com/en-us/bing/apis/bing-web-search-api and add it to your keys")]
|
104 |
+
raise KeyError
|
105 |
+
if WOLFRAMALPH_APP_ID == None or WOLFRAMALPH_APP_ID == '':
|
106 |
+
yield [(query, "No WOLFRAMALPH_APP_ID provided! please register one from https://products.wolframalpha.com/api/ and add it to your keys")]
|
107 |
+
raise KeyError
|
108 |
+
if WEATHER_API_KEYS == None or WEATHER_API_KEYS == '':
|
109 |
+
yield [(query, "No WEATHER_API_KEYS provided! Please register one from https://www.weatherapi.com/ and add it to")]
|
110 |
+
raise KeyError
|
111 |
+
llm = OpenAI_Model(model_name=model_name, api_key=OPENAI_API_KEY)
|
112 |
+
#llm = OpenRouter_Model(model_name=model_name)
|
113 |
+
if "docqa" in total_env:
|
114 |
+
sys_prompt = llm.sys_prompt + "You already have access to the file uploaded by the user. So just answer the question from the user, you don't need to find the file first."
|
115 |
+
llm.change_sys_prompt(sys_prompt)
|
116 |
+
else:
|
117 |
+
llm.set_default_sys_prompt()
|
118 |
+
# parser_llm = OpenAI_Model(model_name="gpt-4")
|
119 |
+
with open(tool_doc_path, "r") as f:
|
120 |
+
tool_doc = json.load(f)
|
121 |
+
tool_doc["bing_search"]["api_description"] += f"Today is {current_datetime.year}.{current_datetime.month}.{current_datetime.day}"
|
122 |
+
available_tools = []
|
123 |
+
for env_name in env_name_list:
|
124 |
+
available_tools.append(tool_doc[env_name])
|
125 |
+
|
126 |
+
test_set = "cli"
|
127 |
+
|
128 |
+
output_dir = f"data/{method_name}/{test_set}/answer"
|
129 |
+
results_dir = f"data/{method_name}/{test_set}/results.json"
|
130 |
+
if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
|
131 |
+
os.makedirs(f"data/{method_name}/{test_set}/parser_log")
|
132 |
+
if not os.path.exists(output_dir):
|
133 |
+
os.makedirs(output_dir)
|
134 |
+
# HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
|
135 |
+
# worker = smurfs_hotpot_worker(available_tools, HotpotToolEnv, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
|
136 |
+
#print(total_env)
|
137 |
+
worker = stream_smurfs_worker(available_tools, total_env, llm, method_name, test_set, stream_answer_agent, stream_executor_agent, stream_hotpot_planning_agent, stream_verifier_agent, OPENAI_API_KEY, BING_SUBSCRIPT_KEY, WOLFRAMALPH_APP_ID, WEATHER_API_KEYS)
|
138 |
+
stream_generator = worker.run(query, 0)
|
139 |
+
# messages = []
|
140 |
+
while True:
|
141 |
+
try:
|
142 |
+
response = next(stream_generator)
|
143 |
+
messages = [(query, response)]
|
144 |
+
yield messages
|
145 |
+
except StopIteration:
|
146 |
+
break
|
147 |
+
# query = input("Please Enter Your Task: ")
|
148 |
+
# cli_run(query, worker)
|
149 |
+
|
150 |
+
def update_tools(rs):
|
151 |
+
global total_env, env_name_list
|
152 |
+
total_env = {}
|
153 |
+
env_name_list = []
|
154 |
+
for tool_system in rs:
|
155 |
+
tool = tool_system.split(": ")[0]
|
156 |
+
env = tool_env_map[tool]
|
157 |
+
print(f"env: {env}")
|
158 |
+
for e in env:
|
159 |
+
if e not in env_name_list:
|
160 |
+
total_env[e] = env[e]
|
161 |
+
env_name_list.append(e)
|
162 |
+
print(total_env)
|
163 |
+
#return total_env, env_name_list
|
164 |
+
|
165 |
+
def user(user_msg):
|
166 |
+
return user_msg
|
167 |
+
|
168 |
+
tools = ["math: Tool that can handle mathematical problems",
|
169 |
+
"docqa: Tool that can answer questions about your uploaded file",
|
170 |
+
"hotpotQA: Tool that can do multi-hop commonsense reasoning",
|
171 |
+
"websearch: Tool that can do web search to answer your question"]
|
172 |
+
websearch_example = ["请根据深圳明天的天气推荐给我推荐一套穿搭方案,结果用中文输出。", "今年的中秋节是哪天?用中文输出"]
|
173 |
+
math_example = ["Calc integral of sin(x)+2x^2+3x+1 from 0 to 1", "When both sides of a right triangle are 6 and 8, what is the length of the other side?"]
|
174 |
+
inp = gr.Textbox(placeholder="Please input your task", label="Task")
|
175 |
+
with gr.Blocks() as demo:
|
176 |
+
gr.HTML("""<h1 align="center">Smurfs</h1>""")
|
177 |
+
#gr.Markdown("""<figure><a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3602.gif></a><a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3607.gif></a><a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3623.gif></a><a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3625.gif></a></figure>""")
|
178 |
+
#gr.HTML("""<a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3602.gif>""")
|
179 |
+
with gr.Row():
|
180 |
+
with gr.Column(scale=1):
|
181 |
+
inp.render()
|
182 |
+
rs = gr.Dropdown(choices=tools, label="Tool Systems", multiselect=True)
|
183 |
+
file_output = gr.File(file_types=[".pdf"])
|
184 |
+
with gr.Accordion("Keys", open=False):
|
185 |
+
# model_name = gr.Dropdown(label="Moel Name", choices=["gpt-3.5", "gpt-4o", "gpt-4"])
|
186 |
+
openai_key = gr.Textbox(label="OpenAI API Key", placeholder="Please Enter Your OpenAI API Key")
|
187 |
+
bing_search_key = gr.Textbox(label="BingSearch Key", placeholder="Please Enter Your BingSearch Key from https://www.microsoft.com/en-us/bing/apis/bing-web-search-api")
|
188 |
+
wolframalpha_key = gr.Textbox(label="Wolframalpha API Key", placeholder="Please Enter Your WOLFRAMALPH_APP_ID from https://products.wolframalpha.com/api/")
|
189 |
+
weather_key = gr.Textbox(label="Weather API Key", placeholder="Please Enter Your Weather API Key from https://www.weatherapi.com/")
|
190 |
+
|
191 |
+
|
192 |
+
gr.Examples(["Who is the brother of the 2022 NBA FMVP?", "How much older is Lebron James than his oldest son?", "Calc integral of sin(x)+2x^2+3x+1 from 0 to 1", "Calculate the length of the hypotenuse of a right triangle when the other two sides are 6 and 8", "请根据深圳明天的天气推荐给我推荐一套穿搭方案,结果用中文输出。", "今年的中秋节是哪天?用中文输出"], inp)
|
193 |
+
_submit = gr.Button("Submit")
|
194 |
+
stop = gr.Button("Stop")
|
195 |
+
clear = gr.Button("Clear")
|
196 |
+
#upload = gr.UploadButton("Click to upload your pdf file")
|
197 |
+
# btn = gr.Button("Run", elem_id="bot", elem_classes="btn")
|
198 |
+
#with gr.Column(scale=1, elem_id="e"):
|
199 |
+
# chatbox = gr.HTML()
|
200 |
+
chatbox = gr.Chatbot(height=300)
|
201 |
+
# btn.click(fn=update, inputs=inp, outputs=chatbox)
|
202 |
+
file_output.upload(load_text_from_pdf, [file_output, openai_key], None)
|
203 |
+
#upload.upload(loading, None, inp).then(load_text_from_pdf, upload, inp)
|
204 |
+
rs.change(update_tools, rs, None)
|
205 |
+
click_event = _submit.click(user, inp, inp).then(update, [inp, openai_key, bing_search_key, wolframalpha_key, weather_key], chatbox)
|
206 |
+
stop.click(None, None, None, cancels=[click_event])
|
207 |
+
#inp.submit(user, inp, inp).then(update, inp, chatbox)
|
208 |
+
clear.click(lambda: (None, None), None, [inp, chatbox], queue=False)
|
209 |
+
# theme=gr.themes.Default().set(button_primary_border_color_dark=, hover)
|
210 |
+
demo.load(
|
211 |
+
None,
|
212 |
+
None,
|
213 |
+
_js="""
|
214 |
+
() => {
|
215 |
+
const params = new URLSearchParams(window.location.search);
|
216 |
+
if (!params.has('__theme')){
|
217 |
+
params.set('__theme', 'dark');
|
218 |
+
window.location.search = params.toString();
|
219 |
+
}
|
220 |
+
}
|
221 |
+
""",
|
222 |
+
)
|
223 |
+
demo.queue().launch(server_name='0.0.0.0', share=True, inbrowser=False, server_port=7001)
|
Smurfs/eval/hotpot_qa/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.46 kB). View file
|
|
Smurfs/eval/hotpot_qa/post_process.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
def remove_tags(text):
|
3 |
+
# 使用正则表达式去除形如<xxx>的字符
|
4 |
+
cleaned_text = re.sub(r'<[^>]*>', '', text)
|
5 |
+
return cleaned_text
|
6 |
+
|
7 |
+
import os
|
8 |
+
import sys
|
9 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
10 |
+
from Smurfs.inference.smurfs_worker import smurfs_hotpot_worker
|
11 |
+
from Smurfs.tools.tool_env import HotpotToolEnv
|
12 |
+
from Smurfs.model.openai_model.openai_model import OpenAI_Model, OpenRouter_Model
|
13 |
+
from Smurfs.agents.answer_agent.answer import answer_agent
|
14 |
+
from Smurfs.agents.executor_agent.executor import executor_agent
|
15 |
+
from Smurfs.agents.planning_agent.planner import hotpot_planning_agent
|
16 |
+
from Smurfs.agents.verifier_agent.verifier import verifier_agent
|
17 |
+
from Smurfs.eval.hotpot_qa.utils import eval_result
|
18 |
+
import json
|
19 |
+
import threading
|
20 |
+
import joblib
|
21 |
+
from tqdm import tqdm
|
22 |
+
import time
|
23 |
+
|
24 |
+
def post_run(results, HP_answer_agent):
|
25 |
+
global new_results
|
26 |
+
for res in tqdm(results):
|
27 |
+
pre = res["pre_ans"]
|
28 |
+
query_id = res["id"]
|
29 |
+
ques = res["question"]
|
30 |
+
pre = remove_tags(pre)
|
31 |
+
if len(pre) == 0:
|
32 |
+
res["parsed_pre"] = ""
|
33 |
+
new_results.append(res)
|
34 |
+
continue
|
35 |
+
parsed_pre = HP_answer_agent.run(query_id=query_id, task="parse", question=ques, detailed_answer=pre)
|
36 |
+
res["parsed_pre"] = parsed_pre
|
37 |
+
new_results.append(res)
|
38 |
+
|
39 |
+
if __name__ == "__main__":
|
40 |
+
levels = ['easy', 'medium', 'hard']
|
41 |
+
# model_name = "meta-llama/llama-2-70b-chat"
|
42 |
+
method_name = "llama-2-13b-Smurfs"
|
43 |
+
parser_llm = OpenAI_Model(model_name="gpt-4")
|
44 |
+
for level in levels:
|
45 |
+
# level = 'easy'
|
46 |
+
# model_name = "gpt-4-0613"
|
47 |
+
new_results = []
|
48 |
+
test_set = f"hotpot_qa_{level}"
|
49 |
+
HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/post_process/parser_log")
|
50 |
+
output_dir = f"data/{method_name}/{test_set}/answer"
|
51 |
+
new_results_path = f"data/{method_name}/{test_set}/post_process/results.json"
|
52 |
+
new_results_dir = f"data/{method_name}/{test_set}/post_process/"
|
53 |
+
results_dir = f"data/{method_name}/{test_set}/results.json"
|
54 |
+
if not os.path.exists(new_results_path):
|
55 |
+
os.makedirs(f"data/{method_name}/{test_set}/post_process/parser_log")
|
56 |
+
# os.makedirs(new_results_dir)
|
57 |
+
with open(results_dir, "r") as file:
|
58 |
+
results = json.load(file)
|
59 |
+
|
60 |
+
total_len = len(results)
|
61 |
+
print(total_len)
|
62 |
+
|
63 |
+
threads = []
|
64 |
+
if total_len < 20:
|
65 |
+
for i in range(total_len):
|
66 |
+
if total_len == 0:
|
67 |
+
break
|
68 |
+
|
69 |
+
start = i
|
70 |
+
end = i+1
|
71 |
+
if i == total_len-1:
|
72 |
+
query_cur = results[start:]
|
73 |
+
else:
|
74 |
+
query_cur = results[start: end]
|
75 |
+
t = threading.Thread(target=post_run, args=(query_cur, HP_answer_agent))
|
76 |
+
t.start()
|
77 |
+
threads.append(t)
|
78 |
+
|
79 |
+
else:
|
80 |
+
for i in range(20):
|
81 |
+
|
82 |
+
if total_len == 0:
|
83 |
+
break
|
84 |
+
|
85 |
+
start = round(total_len/20)*i
|
86 |
+
end = round(total_len/20)*(i+1)
|
87 |
+
if i == 19:
|
88 |
+
query_cur = results[start:]
|
89 |
+
else:
|
90 |
+
query_cur = results[start: end]
|
91 |
+
t = threading.Thread(target=post_run, args=(query_cur, HP_answer_agent))
|
92 |
+
t.start()
|
93 |
+
threads.append(t)
|
94 |
+
|
95 |
+
for thread in threads:
|
96 |
+
thread.join()
|
97 |
+
|
98 |
+
with open(new_results_path, "w") as file:
|
99 |
+
json.dump(new_results, file, indent=4, ensure_ascii=False)
|
100 |
+
correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(new_results)
|
101 |
+
print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
|
102 |
+
print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
|
103 |
+
else:
|
104 |
+
with open(new_results_path, "r") as file:
|
105 |
+
new_results = json.load(file)
|
106 |
+
correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(new_results)
|
107 |
+
print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
|
108 |
+
print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
|
109 |
+
|
Smurfs/eval/hotpot_qa/run_eval.py
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
|
4 |
+
from Smurfs.inference.smurfs_worker import smurfs_hotpot_worker
|
5 |
+
from Smurfs.tools.tool_env import HotpotToolEnv
|
6 |
+
from Smurfs.model.openai_model.openai_model import OpenAI_Model, OpenRouter_Model
|
7 |
+
from Smurfs.agents.answer_agent.answer import answer_agent
|
8 |
+
from Smurfs.agents.executor_agent.executor import executor_agent
|
9 |
+
from Smurfs.agents.planning_agent.planner import hotpot_planning_agent
|
10 |
+
from Smurfs.agents.verifier_agent.verifier import verifier_agent
|
11 |
+
from Smurfs.eval.hotpot_qa.utils import eval_result
|
12 |
+
import json
|
13 |
+
import threading
|
14 |
+
import joblib
|
15 |
+
from tqdm import tqdm
|
16 |
+
import time
|
17 |
+
import warnings
|
18 |
+
|
19 |
+
# 抑制所有警告
|
20 |
+
warnings.filterwarnings('ignore')
|
21 |
+
|
22 |
+
def run(worker, query, query_id):
|
23 |
+
# global lock
|
24 |
+
final_answer, output_file_ele, solution_file_ele = worker.run(query, query_id)
|
25 |
+
# lock.acquire()
|
26 |
+
worker.save_solution(output_file_ele, solution_file_ele, query_id)
|
27 |
+
# lock.release()
|
28 |
+
return final_answer
|
29 |
+
|
30 |
+
def run_one_hotpot(ques, ans, HP_answer_agent, worker, query_id):
|
31 |
+
global results
|
32 |
+
# ques, ans = task_instructions[0]
|
33 |
+
# print(ques)
|
34 |
+
# print(ans)
|
35 |
+
pre = run(worker, ques, query_id)
|
36 |
+
print(pre)
|
37 |
+
|
38 |
+
# question = "Where was the first governor after the The Missouri Compromise from?"
|
39 |
+
# detailed_answer = "The first governor of Missouri after the Missouri Compromise was Alexander McNair. He was originally from central Pennsylvania, specifically, he was born in Cumberland County on May 5, 1775, and later lived in Derry township, Lancaster (now Dauphin) County. He also pursued his education in Derry and attended the University of Pennsylvania in Philadelphia for a term."
|
40 |
+
parsed_pre = HP_answer_agent.run(query_id=query_id, task="parse", question=ques, detailed_answer=pre)
|
41 |
+
print(parsed_pre)
|
42 |
+
|
43 |
+
result_ele = {"question": ques, "gt_answer": ans, "pre_ans": pre, "parsed_pre": parsed_pre, "id": query_id}
|
44 |
+
lock.acquire()
|
45 |
+
results.append(result_ele)
|
46 |
+
lock.release()
|
47 |
+
|
48 |
+
def run_hotpot(query_list, HP_answer_agent, worker):
|
49 |
+
with tqdm(total=len(query_list), desc="Processing files", initial=0) as pbar:
|
50 |
+
for i, test_task_ins in enumerate(query_list, start=0):
|
51 |
+
idx = test_task_ins[0]
|
52 |
+
ques, ans = test_task_ins[1]
|
53 |
+
while True:
|
54 |
+
try:
|
55 |
+
run_one_hotpot(ques, ans, HP_answer_agent, worker, idx)
|
56 |
+
break
|
57 |
+
except Exception as e:
|
58 |
+
print(e)
|
59 |
+
print("some error occurs, continue...")
|
60 |
+
time.sleep(60)
|
61 |
+
continue
|
62 |
+
|
63 |
+
pbar.update(1)
|
64 |
+
return
|
65 |
+
|
66 |
+
#测试三个测试集
|
67 |
+
# if __name__ == '__main__':
|
68 |
+
# #store true pre and parse pre in a same json, calculate together
|
69 |
+
# #dump them together
|
70 |
+
# lock = threading.Lock()
|
71 |
+
# levels = ['easy', 'medium', 'hard']
|
72 |
+
# model_name = "gpt-3.5-turbo"
|
73 |
+
# method_name = "GPT3-turbo-Smurfs"
|
74 |
+
# llm = OpenAI_Model(model_name=model_name)
|
75 |
+
# parser_llm = OpenAI_Model(model_name="gpt-4")
|
76 |
+
# task_path = "/Users/chenjunzhi/Desktop/smurfs_more/AutoAct/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa"
|
77 |
+
# with open("/Users/chenjunzhi/Desktop/smurfs_more/Smurfs/Smurfs/tools/hotpot.json", "r") as f:
|
78 |
+
# available_tools = json.load(f)
|
79 |
+
# for level in levels:
|
80 |
+
# # level = 'hard'
|
81 |
+
# # model_name = "gpt-4-0613"
|
82 |
+
# results = []
|
83 |
+
# test_set = f"hotpot_qa_{level}"
|
84 |
+
|
85 |
+
# output_dir = f"data/{method_name}/{test_set}/answer"
|
86 |
+
# results_dir = f"data/{method_name}/{test_set}/results.json"
|
87 |
+
# if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
|
88 |
+
# os.makedirs(f"data/{method_name}/{test_set}/parser_log")
|
89 |
+
# if not os.path.exists(output_dir):
|
90 |
+
# os.makedirs(output_dir)
|
91 |
+
# if os.path.exists(results_dir):
|
92 |
+
# with open(results_dir, "r") as file:
|
93 |
+
# results = json.load(file)
|
94 |
+
|
95 |
+
# items = os.listdir(output_dir)
|
96 |
+
# for i in range(len(items)):
|
97 |
+
# items[i] = items[i].split(".")[0]
|
98 |
+
|
99 |
+
# HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
|
100 |
+
# worker = smurfs_worker(available_tools, tool_env, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
|
101 |
+
# hotpot = joblib.load(f'{task_path}/{level}.joblib').reset_index(drop = True)
|
102 |
+
# task_instructions = [(row['question'], row['answer']) for _, row in hotpot.iterrows()]
|
103 |
+
|
104 |
+
|
105 |
+
# query_to_do = []
|
106 |
+
# # if len(items) != 0:
|
107 |
+
# for idx, q in enumerate(task_instructions):
|
108 |
+
# # print(idx)
|
109 |
+
# if str(idx) in items:
|
110 |
+
# continue
|
111 |
+
# # query_id = q["query_id"]
|
112 |
+
# # if str(query_id) not in test_ids:
|
113 |
+
# # continue
|
114 |
+
# query_to_do_ele = (idx, q)
|
115 |
+
# query_to_do.append(query_to_do_ele)
|
116 |
+
|
117 |
+
# total_len = len(query_to_do)
|
118 |
+
# query_len = len(task_instructions)
|
119 |
+
# print(total_len)
|
120 |
+
|
121 |
+
# threads = []
|
122 |
+
# if total_len < 20:
|
123 |
+
# for i in range(total_len):
|
124 |
+
# if total_len == 0:
|
125 |
+
# break
|
126 |
+
|
127 |
+
# start = i
|
128 |
+
# end = i+1
|
129 |
+
# if i == total_len-1:
|
130 |
+
# query_cur = query_to_do[start:]
|
131 |
+
# else:
|
132 |
+
# query_cur = query_to_do[start: end]
|
133 |
+
# t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
|
134 |
+
# t.start()
|
135 |
+
# threads.append(t)
|
136 |
+
|
137 |
+
# else:
|
138 |
+
# for i in range(20):
|
139 |
+
|
140 |
+
# if total_len == 0:
|
141 |
+
# break
|
142 |
+
|
143 |
+
# start = round(total_len/20)*i
|
144 |
+
# end = round(total_len/20)*(i+1)
|
145 |
+
# if i == 19:
|
146 |
+
# query_cur = query_to_do[start:]
|
147 |
+
# else:
|
148 |
+
# query_cur = query_to_do[start: end]
|
149 |
+
# t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
|
150 |
+
# t.start()
|
151 |
+
# threads.append(t)
|
152 |
+
|
153 |
+
# for thread in threads:
|
154 |
+
# thread.join()
|
155 |
+
|
156 |
+
# with open(results_dir, "w") as file:
|
157 |
+
# json.dump(results, file, indent=4, ensure_ascii=False)
|
158 |
+
|
159 |
+
# with open(results_dir, "r") as file:
|
160 |
+
# eval_data = json.load(file)
|
161 |
+
|
162 |
+
# correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(eval_data)
|
163 |
+
# print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
|
164 |
+
# print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
|
165 |
+
|
166 |
+
# with open(f"data/{method_name}/{test_set}/parsed_result.json", "w") as file:
|
167 |
+
# json.dump(parsed_dict, file, indent=4, ensure_ascii=False)
|
168 |
+
|
169 |
+
# with open(f"data/{method_name}/{test_set}/original_result.json", "w") as file:
|
170 |
+
# json.dump(pre_dict, file, indent=4, ensure_ascii=False)
|
171 |
+
|
172 |
+
|
173 |
+
|
174 |
+
|
175 |
+
|
176 |
+
|
177 |
+
|
178 |
+
# 测试一个query
|
179 |
+
# if __name__ == '__main__':
|
180 |
+
# #store true pre and parse pre in a same json, calculate together
|
181 |
+
# #dump them together
|
182 |
+
# lock = threading.Lock()
|
183 |
+
# levels = ['easy', 'medium', 'hard']
|
184 |
+
# model_name = "gpt-3.5-turbo"
|
185 |
+
# method_name = "GPT3-test-Smurfs"
|
186 |
+
# llm = OpenAI_Model(model_name=model_name)
|
187 |
+
# parser_llm = OpenAI_Model(model_name="gpt-4")
|
188 |
+
# task_path = "/Users/chenjunzhi/Desktop/smurfs_more/AutoAct/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa"
|
189 |
+
# with open("/Users/chenjunzhi/Desktop/smurfs_more/Smurfs/Smurfs/tools/hotpot.json", "r") as f:
|
190 |
+
# available_tools = json.load(f)
|
191 |
+
# # for level in levels:
|
192 |
+
# level = 'hard'
|
193 |
+
# # model_name = "gpt-4-0613"
|
194 |
+
# results = []
|
195 |
+
# test_set = f"hotpot_qa_{level}"
|
196 |
+
|
197 |
+
# output_dir = f"data/{method_name}/{test_set}/answer"
|
198 |
+
# results_dir = f"data/{method_name}/{test_set}/results.json"
|
199 |
+
# if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
|
200 |
+
# os.makedirs(f"data/{method_name}/{test_set}/parser_log")
|
201 |
+
# if not os.path.exists(output_dir):
|
202 |
+
# os.makedirs(output_dir)
|
203 |
+
# if os.path.exists(results_dir):
|
204 |
+
# with open(results_dir, "r") as file:
|
205 |
+
# results = json.load(file)
|
206 |
+
|
207 |
+
# items = os.listdir(output_dir)
|
208 |
+
# for i in range(len(items)):
|
209 |
+
# items[i] = items[i].split(".")[0]
|
210 |
+
|
211 |
+
# HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
|
212 |
+
# worker = smurfs_worker(available_tools, tool_env, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
|
213 |
+
# hotpot = joblib.load(f'{task_path}/{level}.joblib').reset_index(drop = True)
|
214 |
+
# task_instructions = [(row['question'], row['answer']) for _, row in hotpot.iterrows()]
|
215 |
+
# ques, ans = task_instructions[15][0], task_instructions[15][1]
|
216 |
+
# run_one_hotpot(ques, ans, HP_answer_agent, worker, 0)
|
217 |
+
|
218 |
+
|
219 |
+
# # query_to_do = []
|
220 |
+
# # if len(items) != 0:
|
221 |
+
# # for idx, q in enumerate(task_instructions):
|
222 |
+
# # # print(idx)
|
223 |
+
# # if str(idx) in items:
|
224 |
+
# # continue
|
225 |
+
# # # query_id = q["query_id"]
|
226 |
+
# # # if str(query_id) not in test_ids:
|
227 |
+
# # # continue
|
228 |
+
# # query_to_do_ele = (idx, q)
|
229 |
+
# # query_to_do.append(query_to_do_ele)
|
230 |
+
|
231 |
+
# # total_len = len(query_to_do)
|
232 |
+
# # query_len = len(task_instructions)
|
233 |
+
# # print(total_len)
|
234 |
+
|
235 |
+
# # threads = []
|
236 |
+
# # if total_len < 20:
|
237 |
+
# # for i in range(total_len):
|
238 |
+
# # if total_len == 0:
|
239 |
+
# # break
|
240 |
+
|
241 |
+
# # start = i
|
242 |
+
# # end = i+1
|
243 |
+
# # if i == total_len-1:
|
244 |
+
# # query_cur = query_to_do[start:]
|
245 |
+
# # else:
|
246 |
+
# # query_cur = query_to_do[start: end]
|
247 |
+
# # t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
|
248 |
+
# # t.start()
|
249 |
+
# # threads.append(t)
|
250 |
+
|
251 |
+
# # else:
|
252 |
+
# # for i in range(20):
|
253 |
+
|
254 |
+
# # if total_len == 0:
|
255 |
+
# # break
|
256 |
+
|
257 |
+
# # start = round(total_len/20)*i
|
258 |
+
# # end = round(total_len/20)*(i+1)
|
259 |
+
# # if i == 19:
|
260 |
+
# # query_cur = query_to_do[start:]
|
261 |
+
# # else:
|
262 |
+
# # query_cur = query_to_do[start: end]
|
263 |
+
# # t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
|
264 |
+
# # t.start()
|
265 |
+
# # threads.append(t)
|
266 |
+
|
267 |
+
# # for thread in threads:
|
268 |
+
# # thread.join()
|
269 |
+
|
270 |
+
# with open(results_dir, "w") as file:
|
271 |
+
# json.dump(results, file, indent=4, ensure_ascii=False)
|
272 |
+
|
273 |
+
# # with open(results_dir, "r") as file:
|
274 |
+
# # eval_data = json.load(file)
|
275 |
+
|
276 |
+
# # correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(eval_data)
|
277 |
+
# # print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
|
278 |
+
# # print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
|
279 |
+
|
280 |
+
# # with open(f"data/{method_name}/{test_set}/parsed_result.json", "w") as file:
|
281 |
+
# # json.dump(parsed_dict, file, indent=4, ensure_ascii=False)
|
282 |
+
|
283 |
+
# # with open(f"data/{method_name}/{test_set}/original_result.json", "w") as file:
|
284 |
+
# # json.dump(pre_dict, file, indent=4, ensure_ascii=False)
|
285 |
+
|
286 |
+
|
287 |
+
|
288 |
+
|
289 |
+
|
290 |
+
#测试一个个测试集
|
291 |
+
if __name__ == '__main__':
|
292 |
+
#store true pre and parse pre in a same json, calculate together
|
293 |
+
#dump them together
|
294 |
+
lock = threading.Lock()
|
295 |
+
levels = ['easy', 'medium', 'hard']
|
296 |
+
model_name = "meta-llama/llama-2-70b-chat"
|
297 |
+
method_name = "llama-2-13b-Smurfs"
|
298 |
+
# llm = OpenAI_Model(model_name=model_name)
|
299 |
+
llm = OpenRouter_Model(model_name=model_name)
|
300 |
+
parser_llm = OpenAI_Model(model_name="gpt-4")
|
301 |
+
task_path = "/Users/chenjunzhi/Desktop/smurfs_more/AutoAct/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa"
|
302 |
+
with open("/Users/chenjunzhi/Desktop/smurfs_more/Smurfs/Smurfs/tools/hotpot.json", "r") as f:
|
303 |
+
available_tools = json.load(f)
|
304 |
+
for level in levels:
|
305 |
+
# level = 'easy'
|
306 |
+
# model_name = "gpt-4-0613"
|
307 |
+
results = []
|
308 |
+
test_set = f"hotpot_qa_{level}"
|
309 |
+
|
310 |
+
output_dir = f"data/{method_name}/{test_set}/answer"
|
311 |
+
results_dir = f"data/{method_name}/{test_set}/results.json"
|
312 |
+
if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
|
313 |
+
os.makedirs(f"data/{method_name}/{test_set}/parser_log")
|
314 |
+
if not os.path.exists(output_dir):
|
315 |
+
os.makedirs(output_dir)
|
316 |
+
if os.path.exists(results_dir):
|
317 |
+
with open(results_dir, "r") as file:
|
318 |
+
results = json.load(file)
|
319 |
+
|
320 |
+
items = os.listdir(output_dir)
|
321 |
+
for i in range(len(items)):
|
322 |
+
items[i] = items[i].split(".")[0]
|
323 |
+
|
324 |
+
HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
|
325 |
+
worker = smurfs_hotpot_worker(available_tools, HotpotToolEnv, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
|
326 |
+
hotpot = joblib.load(f'{task_path}/{level}.joblib').reset_index(drop = True)
|
327 |
+
task_instructions = [(row['question'], row['answer']) for _, row in hotpot.iterrows()]
|
328 |
+
|
329 |
+
|
330 |
+
query_to_do = []
|
331 |
+
# if len(items) != 0:
|
332 |
+
for idx, q in enumerate(task_instructions):
|
333 |
+
# print(idx)
|
334 |
+
if str(idx) in items:
|
335 |
+
continue
|
336 |
+
# query_id = q["query_id"]
|
337 |
+
# if str(query_id) not in test_ids:
|
338 |
+
# continue
|
339 |
+
query_to_do_ele = (idx, q)
|
340 |
+
query_to_do.append(query_to_do_ele)
|
341 |
+
|
342 |
+
total_len = len(query_to_do)
|
343 |
+
query_len = len(task_instructions)
|
344 |
+
print(total_len)
|
345 |
+
|
346 |
+
threads = []
|
347 |
+
if total_len < 20:
|
348 |
+
for i in range(total_len):
|
349 |
+
if total_len == 0:
|
350 |
+
break
|
351 |
+
|
352 |
+
start = i
|
353 |
+
end = i+1
|
354 |
+
if i == total_len-1:
|
355 |
+
query_cur = query_to_do[start:]
|
356 |
+
else:
|
357 |
+
query_cur = query_to_do[start: end]
|
358 |
+
t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
|
359 |
+
t.start()
|
360 |
+
threads.append(t)
|
361 |
+
|
362 |
+
else:
|
363 |
+
for i in range(20):
|
364 |
+
|
365 |
+
if total_len == 0:
|
366 |
+
break
|
367 |
+
|
368 |
+
start = round(total_len/20)*i
|
369 |
+
end = round(total_len/20)*(i+1)
|
370 |
+
if i == 19:
|
371 |
+
query_cur = query_to_do[start:]
|
372 |
+
else:
|
373 |
+
query_cur = query_to_do[start: end]
|
374 |
+
t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
|
375 |
+
t.start()
|
376 |
+
threads.append(t)
|
377 |
+
|
378 |
+
for thread in threads:
|
379 |
+
thread.join()
|
380 |
+
|
381 |
+
with open(results_dir, "w") as file:
|
382 |
+
json.dump(results, file, indent=4, ensure_ascii=False)
|
383 |
+
|
384 |
+
with open(results_dir, "r") as file:
|
385 |
+
eval_data = json.load(file)
|
386 |
+
|
387 |
+
correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(eval_data)
|
388 |
+
print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
|
389 |
+
print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
|
390 |
+
|
391 |
+
with open(f"data/{method_name}/{test_set}/parsed_result.json", "w") as file:
|
392 |
+
json.dump(parsed_dict, file, indent=4, ensure_ascii=False)
|
393 |
+
|
394 |
+
with open(f"data/{method_name}/{test_set}/original_result.json", "w") as file:
|
395 |
+
json.dump(pre_dict, file, indent=4, ensure_ascii=False)
|
Smurfs/eval/hotpot_qa/utils.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from collections import Counter
|
3 |
+
import string
|
4 |
+
import json
|
5 |
+
|
6 |
+
def format_step(step: str) -> str:
|
7 |
+
step = step.strip('\n').strip().replace('\n', '')
|
8 |
+
if step.startswith("Thought") or step.startswith("Action"):
|
9 |
+
step = step.split()[2:]
|
10 |
+
step = " ".join(step)
|
11 |
+
if "Thought" in step:
|
12 |
+
step = step.split("Thought")[0].strip()
|
13 |
+
if "Action" in step:
|
14 |
+
step = step.split("Action")[0].strip()
|
15 |
+
if "Observation" in step:
|
16 |
+
step = step.split("Observation")[0].strip()
|
17 |
+
return step
|
18 |
+
|
19 |
+
def normalize_answer(s):
|
20 |
+
def remove_articles(text):
|
21 |
+
return re.sub(r"\b(a|an|the)\b", " ", text)
|
22 |
+
|
23 |
+
def white_space_fix(text):
|
24 |
+
return " ".join(text.split())
|
25 |
+
|
26 |
+
def remove_punc(text):
|
27 |
+
exclude = set(string.punctuation)
|
28 |
+
return "".join(ch for ch in text if ch not in exclude)
|
29 |
+
|
30 |
+
def lower(text):
|
31 |
+
return text.lower()
|
32 |
+
|
33 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
34 |
+
|
35 |
+
def f1_score(prediction, ground_truth):
|
36 |
+
normalized_prediction = normalize_answer(prediction)
|
37 |
+
normalized_ground_truth = normalize_answer(ground_truth)
|
38 |
+
|
39 |
+
ZERO_METRIC = (0, 0, 0)
|
40 |
+
|
41 |
+
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
|
42 |
+
return ZERO_METRIC
|
43 |
+
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
|
44 |
+
return ZERO_METRIC
|
45 |
+
|
46 |
+
prediction_tokens = normalized_prediction.split()
|
47 |
+
ground_truth_tokens = normalized_ground_truth.split()
|
48 |
+
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
|
49 |
+
num_same = sum(common.values())
|
50 |
+
if num_same == 0:
|
51 |
+
return ZERO_METRIC
|
52 |
+
precision = 1.0 * num_same / len(prediction_tokens)
|
53 |
+
recall = 1.0 * num_same / len(ground_truth_tokens)
|
54 |
+
f1 = (2 * precision * recall) / (precision + recall)
|
55 |
+
return f1, precision, recall
|
56 |
+
|
57 |
+
def EM(answer, key) -> bool:
|
58 |
+
return normalize_answer(answer) == normalize_answer(key)
|
59 |
+
|
60 |
+
def score_string_similarity(str1, str2):
|
61 |
+
if str1 == str2:
|
62 |
+
return 2.0
|
63 |
+
elif " " in str1 or " " in str2:
|
64 |
+
str1_split = str1.split(" ")
|
65 |
+
str2_split = str2.split(" ")
|
66 |
+
overlap = list(set(str1_split) & set(str2_split))
|
67 |
+
return len(overlap) / max(len(str1_split), len(str2_split))
|
68 |
+
else:
|
69 |
+
return 0.0
|
70 |
+
|
71 |
+
def eval_result_once(question, pre, gt):
|
72 |
+
correct = EM(pre, gt)
|
73 |
+
reward = f1_score(pre, gt)[0]
|
74 |
+
# halted = agent.is_halted()
|
75 |
+
# error = agent.run_error
|
76 |
+
# prompt = agent._build_agent_prompt()
|
77 |
+
save_dict = {"question":question, "answer":gt, "prediction": pre, "EM":correct, "reward":reward}
|
78 |
+
# with open(file_path, 'a') as f:
|
79 |
+
# json.dump(save_dict, f)
|
80 |
+
# f.write("\n")
|
81 |
+
return save_dict
|
82 |
+
|
83 |
+
def eval_result(eval_data):
|
84 |
+
result = []
|
85 |
+
parsed_result = []
|
86 |
+
correct = 0
|
87 |
+
reward = 0
|
88 |
+
parsed_correct = 0
|
89 |
+
parsed_reward = 0
|
90 |
+
total_len = len(eval_data)
|
91 |
+
for d in eval_data:
|
92 |
+
pre = d["pre_ans"]
|
93 |
+
parsed_pre = d["parsed_pre"]
|
94 |
+
gt = d["gt_answer"]
|
95 |
+
question = d["question"]
|
96 |
+
pre_dict = eval_result_once(question, pre, gt)
|
97 |
+
parsed_dict = eval_result_once(question, parsed_pre, gt)
|
98 |
+
result.append(pre_dict)
|
99 |
+
parsed_result.append(parsed_dict)
|
100 |
+
|
101 |
+
correct += pre_dict["EM"]
|
102 |
+
reward += pre_dict["reward"]
|
103 |
+
|
104 |
+
parsed_correct += parsed_dict["EM"]
|
105 |
+
parsed_reward += parsed_dict["reward"]
|
106 |
+
|
107 |
+
correct /= total_len
|
108 |
+
reward /= total_len
|
109 |
+
|
110 |
+
parsed_correct /= total_len
|
111 |
+
parsed_reward /= total_len
|
112 |
+
|
113 |
+
return correct, reward, parsed_correct, parsed_reward, result, parsed_result
|
114 |
+
|
115 |
+
|
116 |
+
|
117 |
+
|
Smurfs/inference/__init__.py
ADDED
File without changes
|
Smurfs/inference/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (170 Bytes). View file
|
|
Smurfs/inference/__pycache__/inference.cpython-39.pyc
ADDED
Binary file (12.1 kB). View file
|
|
Smurfs/inference/__pycache__/server.cpython-39.pyc
ADDED
Binary file (5.27 kB). View file
|
|
Smurfs/inference/__pycache__/smurfs_worker.cpython-39.pyc
ADDED
Binary file (17.1 kB). View file
|
|
Smurfs/inference/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (4.52 kB). View file
|
|
Smurfs/inference/functioncall_inference.py
ADDED
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# — coding: utf-8 –
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import argparse
|
5 |
+
import time
|
6 |
+
import requests
|
7 |
+
import os
|
8 |
+
from utils import change_name, standardize, get_white_list, get_answer_log, get_observation_log, build_tree, get_answer_details, test_sets
|
9 |
+
from tqdm import tqdm
|
10 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
11 |
+
from Smurfs.model.vllm_model.vllm_model import vllm_Model
|
12 |
+
from Smurfs.inference.server import get_rapidapi_response
|
13 |
+
import threading
|
14 |
+
from Smurfs.agents.answer_agent.answer import answer_agent
|
15 |
+
from Smurfs.agents.executor_agent.executor import function_call_executor_agent
|
16 |
+
from Smurfs.agents.planning_agent.planner import planning_agent
|
17 |
+
from Smurfs.agents.verifier_agent.verifier import verifier_agent
|
18 |
+
|
19 |
+
def _Call_function(category, tool_name, api_name, tool_input, strip, white_list, toolbench_key, args):
|
20 |
+
use_rapidapi_key = args.use_rapidapi_key
|
21 |
+
rapidapi_key = os.environ.get("rapidapi_key")
|
22 |
+
api_customization = args.api_customization
|
23 |
+
|
24 |
+
api_name = change_name(standardize(api_name))
|
25 |
+
tool_name = standardize(tool_name)
|
26 |
+
if tool_name not in white_list.keys():
|
27 |
+
print(f"tool name doesn't exist: {tool_name}")
|
28 |
+
return {}, 1
|
29 |
+
standard_tool_name = white_list[tool_name]["standard_tool_name"]
|
30 |
+
payload = {
|
31 |
+
"category": category,
|
32 |
+
"tool_name": standard_tool_name,
|
33 |
+
"api_name": api_name,
|
34 |
+
"tool_input": tool_input,
|
35 |
+
"strip": strip,
|
36 |
+
"toolbench_key": toolbench_key
|
37 |
+
}
|
38 |
+
if use_rapidapi_key or api_customization:
|
39 |
+
payload["rapidapi_key"] = rapidapi_key
|
40 |
+
response = get_rapidapi_response(payload, api_customization=api_customization)
|
41 |
+
else:
|
42 |
+
time.sleep(2) # rate limit: 30 per minute
|
43 |
+
headers = {"toolbench_key": toolbench_key}
|
44 |
+
print(payload)
|
45 |
+
# if tool_input == {}:
|
46 |
+
# response = requests.post("http://8.218.239.54:8080/rapidapi", headers=headers, timeout=15)
|
47 |
+
# else:
|
48 |
+
response = requests.post("http://8.218.239.54:8080/rapidapi", json=payload, headers=headers, timeout=15)
|
49 |
+
if response.status_code != 200:
|
50 |
+
return json.dumps({"error": f"request invalid, data error. status_code={response.status_code}", "response": ""}), 12
|
51 |
+
try:
|
52 |
+
response = response.json()
|
53 |
+
except:
|
54 |
+
print(response)
|
55 |
+
return json.dumps({"error": f"request invalid, data error", "response": ""}), 12
|
56 |
+
# 1 Hallucinating function names
|
57 |
+
# 4 means that the model decides to pruning by itself
|
58 |
+
# 5 represents api call timeout
|
59 |
+
# 6 for 404
|
60 |
+
# 7 means not subscribed
|
61 |
+
# 8 represents unauthorized
|
62 |
+
# 9 represents too many requests
|
63 |
+
# 10 stands for rate limit
|
64 |
+
# 11 message contains "error" field
|
65 |
+
# 12 error sending request
|
66 |
+
if response["error"] == "API not working error...":
|
67 |
+
status_code = 6
|
68 |
+
elif response["error"] == "Unauthorized error...":
|
69 |
+
status_code = 7
|
70 |
+
elif response["error"] == "Unsubscribed error...":
|
71 |
+
status_code = 8
|
72 |
+
elif response["error"] == "Too many requests error...":
|
73 |
+
status_code = 9
|
74 |
+
elif response["error"] == "Rate limit per minute error...":
|
75 |
+
print("Reach api calling limit per minute, sleeping...")
|
76 |
+
time.sleep(10)
|
77 |
+
status_code = 10
|
78 |
+
elif response["error"] == "Message error...":
|
79 |
+
status_code = 11
|
80 |
+
elif response["error"] != "":
|
81 |
+
status_code = "unexpected error, try again!"
|
82 |
+
else:
|
83 |
+
status_code = 0
|
84 |
+
return json.dumps(response), status_code
|
85 |
+
|
86 |
+
def Call_function(category, tool_name, api_name, tool_input, strip, white_list, args):
|
87 |
+
toolbench_key = os.environ.get("toolbench_key")
|
88 |
+
response, status_code = _Call_function(category, tool_name, api_name, tool_input, strip, white_list, toolbench_key, args)
|
89 |
+
if status_code == "unexpected error, try again!":
|
90 |
+
arg = {change_name(k.lower()): v for k, v in tool_input.items()}
|
91 |
+
response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
|
92 |
+
if status_code == "unexpected error, try again!":
|
93 |
+
arg = {change_name(k.replace("-", "_")): v for k, v in tool_input.items()}
|
94 |
+
response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
|
95 |
+
if status_code == "unexpected error, try again!":
|
96 |
+
arg = {change_name(k.replace("\\", "")): v for k, v in tool_input.items()}
|
97 |
+
response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
|
98 |
+
if status_code == "unexpected error, try again!":
|
99 |
+
print(f"Call function fails")
|
100 |
+
with open('wrong_log.json', 'a+', encoding='utf-8') as f:
|
101 |
+
line = json.dumps({
|
102 |
+
"id": 0,
|
103 |
+
"parameters": arg,
|
104 |
+
"wrong": response
|
105 |
+
}, ensure_ascii=False)
|
106 |
+
f.write(line + '\n')
|
107 |
+
return -1
|
108 |
+
return response
|
109 |
+
|
110 |
+
def inference(query, relevant_APIs, white_list, subtask, Answer_Agent, Executor_Agent, Verifier_Agent, query_id, args, max_step=3):
|
111 |
+
tool_check_num = Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
|
112 |
+
#direct answer
|
113 |
+
if tool_check_num == 1:
|
114 |
+
input_dic = {"task": query}
|
115 |
+
answer = Answer_Agent.run(input_dic)
|
116 |
+
return answer, answer, None, None
|
117 |
+
|
118 |
+
previous_log = []
|
119 |
+
history_log = []
|
120 |
+
tool_used_dic = {}
|
121 |
+
relevant_APIs_ids = []
|
122 |
+
for idx in relevant_APIs:
|
123 |
+
ele = relevant_APIs[idx]
|
124 |
+
relevant_APIs_ids.append(str(ele["ID"]))
|
125 |
+
restart_time = 0
|
126 |
+
step_num = 0
|
127 |
+
hint = "Beginnig of the agent. No hint yet"
|
128 |
+
retry_tool_id = 0
|
129 |
+
retry_parameter = 0
|
130 |
+
re_time = 0
|
131 |
+
subtask_id = 0
|
132 |
+
restart = 0
|
133 |
+
while True:
|
134 |
+
if step_num >= max_step:
|
135 |
+
print("\n\nReach steps limits, return answers!\n\n")
|
136 |
+
answer_log = get_answer_log(history_log)
|
137 |
+
answer = Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
138 |
+
return answer, previous_log, re_time, history_log
|
139 |
+
|
140 |
+
if step_num not in tool_used_dic.keys():
|
141 |
+
tool_used_dic[step_num] = []
|
142 |
+
|
143 |
+
tool_used = tool_used_dic[step_num]
|
144 |
+
|
145 |
+
tool_list = []
|
146 |
+
for idx in relevant_APIs:
|
147 |
+
ele = idx
|
148 |
+
ID = str(ele['api_name'])
|
149 |
+
if ID in tool_used:
|
150 |
+
continue
|
151 |
+
# des = ele['description']
|
152 |
+
# name = ele["tool_name"]
|
153 |
+
# tool_list.append({"ID": ID, "tool_name": name, "description": des})
|
154 |
+
tool_list.append(ele)
|
155 |
+
|
156 |
+
if len(tool_list) == 0:
|
157 |
+
if len(previous_log) == 0:
|
158 |
+
answer_log = get_answer_log(history_log)
|
159 |
+
partial_answer = Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
160 |
+
answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
|
161 |
+
return answer, previous_log, re_time, history_log
|
162 |
+
else:
|
163 |
+
delete_log = previous_log.pop()
|
164 |
+
tool_used_dic[step_num] = []
|
165 |
+
step_num -= 1
|
166 |
+
tool_used_dic[step_num].append(delete_log["tool"])
|
167 |
+
restart_time += 1
|
168 |
+
re_time += 1
|
169 |
+
continue
|
170 |
+
|
171 |
+
current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "","id": subtask_id}
|
172 |
+
|
173 |
+
answer_log = get_answer_log(previous_log)
|
174 |
+
|
175 |
+
if retry_tool_id == 4:
|
176 |
+
# tool_id = tool_list[0]["ID"]
|
177 |
+
tool_list = tool_list[0]
|
178 |
+
thought = Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
179 |
+
|
180 |
+
else:
|
181 |
+
thought = Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
182 |
+
# tool_id = Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
|
183 |
+
|
184 |
+
try:
|
185 |
+
tool_id = int(tool_id)
|
186 |
+
tool_id = str(tool_id)
|
187 |
+
if tool_id not in relevant_APIs_ids:
|
188 |
+
re_time += 1
|
189 |
+
retry_tool_id += 1
|
190 |
+
print("Tool ID wrong! Generate tool_id that do not exist!")
|
191 |
+
continue
|
192 |
+
tool_des_json = relevant_APIs[str(tool_id)]
|
193 |
+
retry_tool_id = 0
|
194 |
+
except:
|
195 |
+
retry_tool_id += 1
|
196 |
+
print("Tool ID wrong! Generate tool_id that do not exist!")
|
197 |
+
continue
|
198 |
+
|
199 |
+
tool_name_list = tool_des_json["tool_name"].split(":")
|
200 |
+
category_name = tool_name_list[0]
|
201 |
+
tool_name = tool_name_list[1]
|
202 |
+
api_name = tool_name_list[2]
|
203 |
+
API_doc = tool_des_json
|
204 |
+
|
205 |
+
while True:
|
206 |
+
try:
|
207 |
+
parameters = {}
|
208 |
+
|
209 |
+
if retry_parameter == 4:
|
210 |
+
restart = 1
|
211 |
+
retry_parameter = 0
|
212 |
+
print("No Para! Restart!")
|
213 |
+
break
|
214 |
+
|
215 |
+
parameter = Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
|
216 |
+
if parameter == -1:
|
217 |
+
retry_parameter += 1
|
218 |
+
re_time += 1
|
219 |
+
continue
|
220 |
+
|
221 |
+
if parameter == {}:
|
222 |
+
retry_parameter = 0
|
223 |
+
parameters = {}
|
224 |
+
break
|
225 |
+
|
226 |
+
for key in parameter:
|
227 |
+
value = parameter[key]
|
228 |
+
key = change_name(key)
|
229 |
+
parameters[key] = value
|
230 |
+
|
231 |
+
retry_parameter = 0
|
232 |
+
break
|
233 |
+
|
234 |
+
except:
|
235 |
+
if retry_parameter == 4:
|
236 |
+
parameters = {}
|
237 |
+
retry_parameter = 0
|
238 |
+
restart = 1
|
239 |
+
break
|
240 |
+
retry_parameter += 1
|
241 |
+
print("parameter generation fails, try again!")
|
242 |
+
re_time += 1
|
243 |
+
continue
|
244 |
+
|
245 |
+
|
246 |
+
api_name = change_name(standardize(api_name))
|
247 |
+
|
248 |
+
if restart != 1:
|
249 |
+
try:
|
250 |
+
observation = Call_function(category_name, tool_name, api_name, parameters, "truncate", white_list, args)
|
251 |
+
except:
|
252 |
+
observation = -1
|
253 |
+
|
254 |
+
if observation == -1:
|
255 |
+
restart = 1
|
256 |
+
observation = str({"error": "", "response": "call API fails"})
|
257 |
+
|
258 |
+
if restart == 1:
|
259 |
+
tool_used_dic[step_num].append(str(tool_id))
|
260 |
+
print('****Try Again For This Step****')
|
261 |
+
re_time += 1
|
262 |
+
restart = 0
|
263 |
+
continue
|
264 |
+
|
265 |
+
|
266 |
+
if len(previous_log) != 0:
|
267 |
+
previous_id = previous_log[-1]["id"]
|
268 |
+
else:
|
269 |
+
previous_id = -1
|
270 |
+
|
271 |
+
current_log["tool"] = str(tool_id)
|
272 |
+
current_log["thought"] = thought
|
273 |
+
current_log["action"] = api_name
|
274 |
+
current_log["action_input"] = parameters
|
275 |
+
current_log["observation"] = observation
|
276 |
+
previous_log.append(current_log)
|
277 |
+
|
278 |
+
observation_log = get_observation_log(previous_log)
|
279 |
+
|
280 |
+
answer = Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
|
281 |
+
|
282 |
+
previous_log[-1]["answer"] = answer
|
283 |
+
|
284 |
+
history_log_ele = {"thought": thought, "action": tool_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
|
285 |
+
history_log.append(history_log_ele)
|
286 |
+
subtask_id += 1
|
287 |
+
|
288 |
+
speak, status = Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
|
289 |
+
if speak == -1 and status == -1:
|
290 |
+
step_num += 1
|
291 |
+
continue
|
292 |
+
|
293 |
+
try:
|
294 |
+
if int(status) == 0:
|
295 |
+
hint = speak
|
296 |
+
step_num += 1
|
297 |
+
continue
|
298 |
+
except:
|
299 |
+
step_num += 1
|
300 |
+
continue
|
301 |
+
|
302 |
+
else:
|
303 |
+
return answer, previous_log, re_time, history_log
|
304 |
+
|
305 |
+
def decompose_inference(query, relevant_APIs, api_list, white_list, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, query_id, args):
|
306 |
+
while True:
|
307 |
+
subtasks = Planning_Agent.run(question=query, query_id=query_id)
|
308 |
+
if subtasks == -1:
|
309 |
+
continue
|
310 |
+
break
|
311 |
+
task_log = ""
|
312 |
+
history_log = []
|
313 |
+
previous_log_totals = []
|
314 |
+
re_time_total = 0
|
315 |
+
print(subtasks)
|
316 |
+
relevant_API_list = []
|
317 |
+
# tool_id = 0
|
318 |
+
for api in api_list:
|
319 |
+
for relevant_API in relevant_APIs:
|
320 |
+
if relevant_API[0] == api["tool_name"] and relevant_API[1] == api["api_name"]:
|
321 |
+
# new_tool_name = api["category_name"]+":"+api["tool_name"]+":"+api["api_name"]
|
322 |
+
ele = api
|
323 |
+
# ele = {"ID": tool_id, "tool_name": new_tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
|
324 |
+
# for para in api["required_parameters"]:
|
325 |
+
# para_ele = {
|
326 |
+
# para["name"]: {
|
327 |
+
# "type": para["type"],
|
328 |
+
# "description": para["description": ]
|
329 |
+
# }
|
330 |
+
# }
|
331 |
+
relevant_API_list.append(ele)
|
332 |
+
# tool_id += 1
|
333 |
+
|
334 |
+
for subtask in subtasks:
|
335 |
+
task_log += f"question: {subtask}\n"
|
336 |
+
answer, previous_log, re_time, previous_log_total = inference(task_log, relevant_API_list, white_list, subtask, Answer_Agent, Executor_Agent, Verifier_Agent, query_id, args)
|
337 |
+
previous_log_totals.append(previous_log_total)
|
338 |
+
print(answer)
|
339 |
+
history_log += previous_log
|
340 |
+
re_time_total += re_time
|
341 |
+
task_log += f"answer: {answer}\n"
|
342 |
+
final_answer = Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
|
343 |
+
return final_answer, history_log, task_log, re_time_total, previous_log_totals
|
344 |
+
|
345 |
+
def test(query_json, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args):
|
346 |
+
while True:
|
347 |
+
try:
|
348 |
+
global lock
|
349 |
+
total_query = len(query_json)
|
350 |
+
with tqdm(total=total_query, desc="Processing files", initial=0) as pbar:
|
351 |
+
for i, test_query in enumerate(query_json, start=0):
|
352 |
+
idx = test_query[0]
|
353 |
+
test_query = test_query[1]
|
354 |
+
query = test_query["query"]
|
355 |
+
relevant_APIs = test_query["relevant APIs"]
|
356 |
+
api_list = test_query["api_list"]
|
357 |
+
final_answer, previous_log, task_log,re_time, previous_log_totals = decompose_inference(query, relevant_APIs, api_list, white_list, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, idx, args)
|
358 |
+
answer_details, total_steps = get_answer_details(final_answer, previous_log)
|
359 |
+
solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
|
360 |
+
output_file_ele = {
|
361 |
+
"query": query,
|
362 |
+
"restart_time": re_time,
|
363 |
+
"answer": {
|
364 |
+
"method": "decompose_dfs",
|
365 |
+
"total_steps": total_steps,
|
366 |
+
"final_answer": final_answer,
|
367 |
+
"answer_details": answer_details
|
368 |
+
}
|
369 |
+
}
|
370 |
+
|
371 |
+
solution_file_ele = {
|
372 |
+
"query": query,
|
373 |
+
"total_steps": solution_total_steps,
|
374 |
+
"task_log": task_log,
|
375 |
+
"final_answer": final_answer,
|
376 |
+
"answer_path": answer_details,
|
377 |
+
"total_path": solution_tree
|
378 |
+
}
|
379 |
+
file_name = f"{idx}.json"
|
380 |
+
output_file = os.path.join(output_dir, file_name)
|
381 |
+
whole_solution_file = os.path.join(whole_solution_dir, file_name)
|
382 |
+
lock.acquire()
|
383 |
+
with open(output_file, "w") as file:
|
384 |
+
json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
|
385 |
+
with open(whole_solution_file, "w") as file:
|
386 |
+
json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
|
387 |
+
lock.release()
|
388 |
+
pbar.update(1)
|
389 |
+
return
|
390 |
+
|
391 |
+
except Exception as e:
|
392 |
+
print(e)
|
393 |
+
print("some error occurs, continue...")
|
394 |
+
time.sleep(60)
|
395 |
+
continue
|
396 |
+
|
397 |
+
def parse_arg():
|
398 |
+
parser = argparse.ArgumentParser()
|
399 |
+
parser.add_argument('--test_query_id_path', type=str, default="toolbench_data/data/test_query_ids", required=False, help='test query ids for different test sets')
|
400 |
+
parser.add_argument('--method_name', type=str, default="smurfs-all", required=False, help='the inference method')
|
401 |
+
parser.add_argument('--model_name', type=str, default="your_model_name", required=False, help='the model name for the vllm model')
|
402 |
+
parser.add_argument('--query_file_dir', type=str, default="toolbench_data/data/test_instruction", required=False, help='the directory that contains test sets')
|
403 |
+
parser.add_argument('--tool_env_dir', type=str, default="toolbench_data/data/toolenv/tools", required=False, help='tool environment for the toolbench')
|
404 |
+
parser.add_argument('--toolbench_key', type=str, default="",required=False, help='your toolbench key to request rapidapi service')
|
405 |
+
parser.add_argument('--rapidapi_key', type=str, default="",required=False, help='your rapidapi key to request rapidapi service')
|
406 |
+
parser.add_argument('--use_rapidapi_key', action="store_true", help="To use customized rapidapi service or not.")
|
407 |
+
parser.add_argument('--api_customization', action="store_true", help="To use customized api or not.")
|
408 |
+
args = parser.parse_args()
|
409 |
+
|
410 |
+
return args
|
411 |
+
|
412 |
+
if __name__ == '__main__':
|
413 |
+
threads = []
|
414 |
+
lock = threading.Lock()
|
415 |
+
|
416 |
+
args = parse_arg()
|
417 |
+
|
418 |
+
test_query_id_path = args.test_query_id_path
|
419 |
+
method_name = args.method_name
|
420 |
+
model_name = args.model_name
|
421 |
+
query_file_dir = args.query_file_dir
|
422 |
+
tool_env_dir = args.tool_env_dir
|
423 |
+
|
424 |
+
toolbench_key = args.toolbench_key
|
425 |
+
rapidapi_key = args.rapidapi_key
|
426 |
+
use_rapidapi_key = args.use_rapidapi_key
|
427 |
+
api_customization = args.api_customization
|
428 |
+
|
429 |
+
chat = vllm_Model(model_name=model_name)
|
430 |
+
|
431 |
+
|
432 |
+
for test_set in test_sets:
|
433 |
+
total_output_file = f"data/{method_name}/{test_set}_raw.json"
|
434 |
+
|
435 |
+
test_ids = list(json.load(open(os.path.join(test_query_id_path, test_set+".json"), "r")).keys())
|
436 |
+
|
437 |
+
query_file = f'{query_file_dir}/{test_set}.json'
|
438 |
+
|
439 |
+
output_dir = f"data/{method_name}/{test_set}/answer"
|
440 |
+
|
441 |
+
whole_solution_dir = f"data/{method_name}/{test_set}/whole"
|
442 |
+
|
443 |
+
logger_dir = f"data/{method_name}/{test_set}/agent_log"
|
444 |
+
|
445 |
+
if not os.path.exists(output_dir):
|
446 |
+
os.makedirs(output_dir)
|
447 |
+
|
448 |
+
if not os.path.exists(whole_solution_dir):
|
449 |
+
os.makedirs(whole_solution_dir)
|
450 |
+
|
451 |
+
if not os.path.exists(logger_dir):
|
452 |
+
os.makedirs(logger_dir)
|
453 |
+
|
454 |
+
Answer_Agent = answer_agent(llm=chat, logger_dir=logger_dir)
|
455 |
+
Executor_Agent = function_call_executor_agent(llm=chat, logger_dir=logger_dir, max_observation_length=4096, use_rapidapi_key=use_rapidapi_key, api_customization=api_customization, toolbench_key="HcnlyY4DUKOr3mMas51dewgfzAhBVST7EWv0FPtNRjpoi6buJk", service_url="http://8.218.239.54:8080/rapidapi")
|
456 |
+
Planning_Agent = planning_agent(llm=chat, logger_dir=logger_dir)
|
457 |
+
Verifier_Agent = verifier_agent(llm=chat, logger_dir=logger_dir)
|
458 |
+
|
459 |
+
items = os.listdir(output_dir)
|
460 |
+
for i in range(len(items)):
|
461 |
+
items[i] = items[i].split(".")[0]
|
462 |
+
|
463 |
+
white_list = get_white_list(tool_env_dir)
|
464 |
+
|
465 |
+
with open(query_file) as file:
|
466 |
+
query_json = json.load(file)
|
467 |
+
# with open(tool_doc_dir) as file:
|
468 |
+
# tool_doc = json.load(file)
|
469 |
+
|
470 |
+
print(len(items))
|
471 |
+
query_json_to_do = []
|
472 |
+
# if len(items) != 0:
|
473 |
+
for idx, q in enumerate(query_json):
|
474 |
+
# print(idx)
|
475 |
+
if str(idx) in items:
|
476 |
+
continue
|
477 |
+
query_id = q["query_id"]
|
478 |
+
if str(query_id) not in test_ids:
|
479 |
+
continue
|
480 |
+
query_json_to_do_ele = (idx, q)
|
481 |
+
query_json_to_do.append(query_json_to_do_ele)
|
482 |
+
# else:
|
483 |
+
# query_json_to_do = query_json
|
484 |
+
|
485 |
+
total_len = len(query_json_to_do)
|
486 |
+
query_len = len(query_json)
|
487 |
+
print(total_len)
|
488 |
+
|
489 |
+
if total_len < 20:
|
490 |
+
for i in range(total_len):
|
491 |
+
if total_len == 0:
|
492 |
+
break
|
493 |
+
|
494 |
+
start = i
|
495 |
+
end = i+1
|
496 |
+
if i == total_len-1:
|
497 |
+
query_json_cur = query_json_to_do[start:]
|
498 |
+
else:
|
499 |
+
query_json_cur = query_json_to_do[start: end]
|
500 |
+
t = threading.Thread(target=test, args=(query_json_cur, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args))
|
501 |
+
t.start()
|
502 |
+
threads.append(t)
|
503 |
+
|
504 |
+
|
505 |
+
else:
|
506 |
+
for i in range(20):
|
507 |
+
|
508 |
+
if total_len == 0:
|
509 |
+
break
|
510 |
+
|
511 |
+
start = round(total_len/20)*i
|
512 |
+
end = round(total_len/20)*(i+1)
|
513 |
+
if i == 19:
|
514 |
+
query_json_cur = query_json_to_do[start:]
|
515 |
+
else:
|
516 |
+
query_json_cur = query_json_to_do[start: end]
|
517 |
+
t = threading.Thread(target=test, args=(query_json_cur, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args))
|
518 |
+
t.start()
|
519 |
+
threads.append(t)
|
520 |
+
|
521 |
+
for thread in threads:
|
522 |
+
thread.join()
|
523 |
+
|
524 |
+
total_json = {}
|
525 |
+
items = os.listdir(output_dir)
|
526 |
+
for item in items:
|
527 |
+
item_path = os.path.join(output_dir, item)
|
528 |
+
idx = item.split(".")[0]
|
529 |
+
total_json[str(idx)] = json.load(open(item_path, 'r'))
|
530 |
+
|
531 |
+
with open(total_output_file, 'w') as file:
|
532 |
+
json.dump(total_json, file, indent=4, ensure_ascii=False)
|
533 |
+
|
Smurfs/inference/inference.py
ADDED
@@ -0,0 +1,527 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# — coding: utf-8 –
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import argparse
|
5 |
+
import time
|
6 |
+
import requests
|
7 |
+
import os
|
8 |
+
from tqdm import tqdm
|
9 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
10 |
+
from Smurfs.inference.utils import change_name, standardize, get_white_list, get_answer_log, get_observation_log, build_tree, get_answer_details, test_sets
|
11 |
+
from Smurfs.model.vllm_model.vllm_model import vllm_Model
|
12 |
+
from Smurfs.inference.server import get_rapidapi_response
|
13 |
+
import threading
|
14 |
+
from Smurfs.agents.answer_agent.answer import answer_agent
|
15 |
+
from Smurfs.agents.executor_agent.executor import executor_agent
|
16 |
+
from Smurfs.agents.planning_agent.planner import planning_agent
|
17 |
+
from Smurfs.agents.verifier_agent.verifier import verifier_agent
|
18 |
+
import warnings
|
19 |
+
|
20 |
+
warnings.filterwarnings('ignore')
|
21 |
+
|
22 |
+
def _Call_function(category, tool_name, api_name, tool_input, strip, white_list, toolbench_key, args):
|
23 |
+
use_rapidapi_key = args.use_rapidapi_key
|
24 |
+
rapidapi_key = os.environ.get("rapidapi_key")
|
25 |
+
api_customization = args.api_customization
|
26 |
+
|
27 |
+
api_name = change_name(standardize(api_name))
|
28 |
+
tool_name = standardize(tool_name)
|
29 |
+
if tool_name not in white_list.keys():
|
30 |
+
print(f"tool name doesn't exist: {tool_name}")
|
31 |
+
return {}, 1
|
32 |
+
standard_tool_name = white_list[tool_name]["standard_tool_name"]
|
33 |
+
payload = {
|
34 |
+
"category": category,
|
35 |
+
"tool_name": standard_tool_name,
|
36 |
+
"api_name": api_name,
|
37 |
+
"tool_input": tool_input,
|
38 |
+
"strip": strip,
|
39 |
+
"toolbench_key": toolbench_key
|
40 |
+
}
|
41 |
+
if use_rapidapi_key or api_customization:
|
42 |
+
payload["rapidapi_key"] = rapidapi_key
|
43 |
+
response = get_rapidapi_response(payload, api_customization=api_customization)
|
44 |
+
else:
|
45 |
+
time.sleep(2) # rate limit: 30 per minute
|
46 |
+
headers = {"toolbench_key": toolbench_key}
|
47 |
+
print(payload)
|
48 |
+
# if tool_input == {}:
|
49 |
+
# response = requests.post("http://8.218.239.54:8080/rapidapi", headers=headers, timeout=15)
|
50 |
+
# else:
|
51 |
+
response = requests.post("http://8.218.239.54:8080/rapidapi", json=payload, headers=headers, timeout=15)
|
52 |
+
if response.status_code != 200:
|
53 |
+
return json.dumps({"error": f"request invalid, data error. status_code={response.status_code}", "response": ""}), 12
|
54 |
+
try:
|
55 |
+
response = response.json()
|
56 |
+
except:
|
57 |
+
print(response)
|
58 |
+
return json.dumps({"error": f"request invalid, data error", "response": ""}), 12
|
59 |
+
# 1 Hallucinating function names
|
60 |
+
# 4 means that the model decides to pruning by itself
|
61 |
+
# 5 represents api call timeout
|
62 |
+
# 6 for 404
|
63 |
+
# 7 means not subscribed
|
64 |
+
# 8 represents unauthorized
|
65 |
+
# 9 represents too many requests
|
66 |
+
# 10 stands for rate limit
|
67 |
+
# 11 message contains "error" field
|
68 |
+
# 12 error sending request
|
69 |
+
if response["error"] == "API not working error...":
|
70 |
+
status_code = 6
|
71 |
+
elif response["error"] == "Unauthorized error...":
|
72 |
+
status_code = 7
|
73 |
+
elif response["error"] == "Unsubscribed error...":
|
74 |
+
status_code = 8
|
75 |
+
elif response["error"] == "Too many requests error...":
|
76 |
+
status_code = 9
|
77 |
+
elif response["error"] == "Rate limit per minute error...":
|
78 |
+
print("Reach api calling limit per minute, sleeping...")
|
79 |
+
time.sleep(10)
|
80 |
+
status_code = 10
|
81 |
+
elif response["error"] == "Message error...":
|
82 |
+
status_code = 11
|
83 |
+
elif response["error"] != "":
|
84 |
+
status_code = "unexpected error, try again!"
|
85 |
+
else:
|
86 |
+
status_code = 0
|
87 |
+
return json.dumps(response), status_code
|
88 |
+
|
89 |
+
def Call_function(category, tool_name, api_name, tool_input, strip, white_list, args):
|
90 |
+
toolbench_key = os.environ.get("toolbench_key")
|
91 |
+
response, status_code = _Call_function(category, tool_name, api_name, tool_input, strip, white_list, toolbench_key, args)
|
92 |
+
if status_code == "unexpected error, try again!":
|
93 |
+
arg = {change_name(k.lower()): v for k, v in tool_input.items()}
|
94 |
+
response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
|
95 |
+
if status_code == "unexpected error, try again!":
|
96 |
+
arg = {change_name(k.replace("-", "_")): v for k, v in tool_input.items()}
|
97 |
+
response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
|
98 |
+
if status_code == "unexpected error, try again!":
|
99 |
+
arg = {change_name(k.replace("\\", "")): v for k, v in tool_input.items()}
|
100 |
+
response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
|
101 |
+
if status_code == "unexpected error, try again!":
|
102 |
+
print(f"Call function fails")
|
103 |
+
with open('wrong_log.json', 'a+', encoding='utf-8') as f:
|
104 |
+
line = json.dumps({
|
105 |
+
"id": 0,
|
106 |
+
"parameters": arg,
|
107 |
+
"wrong": response
|
108 |
+
}, ensure_ascii=False)
|
109 |
+
f.write(line + '\n')
|
110 |
+
return -1
|
111 |
+
return response
|
112 |
+
|
113 |
+
def inference(query, relevant_APIs, white_list, subtask, Answer_Agent, Executor_Agent, Verifier_Agent, query_id, args, max_step=3):
|
114 |
+
tool_check_num = Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
|
115 |
+
#direct answer
|
116 |
+
if tool_check_num == 1:
|
117 |
+
input_dic = {"task": query}
|
118 |
+
answer = Answer_Agent.run(input_dic)
|
119 |
+
return answer, answer, None, None
|
120 |
+
|
121 |
+
previous_log = []
|
122 |
+
history_log = []
|
123 |
+
tool_used_dic = {}
|
124 |
+
relevant_APIs_ids = []
|
125 |
+
for idx in relevant_APIs:
|
126 |
+
ele = relevant_APIs[idx]
|
127 |
+
relevant_APIs_ids.append(str(ele["ID"]))
|
128 |
+
restart_time = 0
|
129 |
+
step_num = 0
|
130 |
+
hint = "Beginnig of the agent. No hint yet"
|
131 |
+
retry_tool_id = 0
|
132 |
+
retry_parameter = 0
|
133 |
+
re_time = 0
|
134 |
+
subtask_id = 0
|
135 |
+
restart = 0
|
136 |
+
while True:
|
137 |
+
if step_num >= max_step:
|
138 |
+
print("\n\nReach steps limits, return answers!\n\n")
|
139 |
+
answer_log = get_answer_log(history_log)
|
140 |
+
answer = Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
141 |
+
return answer, previous_log, re_time, history_log
|
142 |
+
|
143 |
+
if step_num not in tool_used_dic.keys():
|
144 |
+
tool_used_dic[step_num] = []
|
145 |
+
|
146 |
+
tool_used = tool_used_dic[step_num]
|
147 |
+
|
148 |
+
tool_list = []
|
149 |
+
for idx in relevant_APIs:
|
150 |
+
ele = relevant_APIs[idx]
|
151 |
+
ID = str(ele['ID'])
|
152 |
+
if ID in tool_used:
|
153 |
+
continue
|
154 |
+
des = ele['description']
|
155 |
+
name = ele["tool_name"]
|
156 |
+
tool_list.append({"ID": ID, "tool_name": name, "description": des})
|
157 |
+
|
158 |
+
if len(tool_list) == 0:
|
159 |
+
if len(previous_log) == 0:
|
160 |
+
answer_log = get_answer_log(history_log)
|
161 |
+
partial_answer = Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
162 |
+
answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
|
163 |
+
return answer, previous_log, re_time, history_log
|
164 |
+
else:
|
165 |
+
delete_log = previous_log.pop()
|
166 |
+
tool_used_dic[step_num] = []
|
167 |
+
step_num -= 1
|
168 |
+
tool_used_dic[step_num].append(delete_log["tool"])
|
169 |
+
restart_time += 1
|
170 |
+
re_time += 1
|
171 |
+
continue
|
172 |
+
|
173 |
+
current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "","id": subtask_id}
|
174 |
+
|
175 |
+
answer_log = get_answer_log(previous_log)
|
176 |
+
|
177 |
+
if retry_tool_id == 4:
|
178 |
+
tool_id = tool_list[0]["ID"]
|
179 |
+
tool_list = tool_list[0]
|
180 |
+
thought = Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
181 |
+
|
182 |
+
else:
|
183 |
+
thought = Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
184 |
+
tool_id = Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
|
185 |
+
|
186 |
+
try:
|
187 |
+
tool_id = int(tool_id)
|
188 |
+
tool_id = str(tool_id)
|
189 |
+
if tool_id not in relevant_APIs_ids:
|
190 |
+
re_time += 1
|
191 |
+
retry_tool_id += 1
|
192 |
+
print("Tool ID wrong! Generate tool_id that do not exist!")
|
193 |
+
continue
|
194 |
+
tool_des_json = relevant_APIs[str(tool_id)]
|
195 |
+
retry_tool_id = 0
|
196 |
+
except:
|
197 |
+
retry_tool_id += 1
|
198 |
+
print("Tool ID wrong! Generate tool_id that do not exist!")
|
199 |
+
continue
|
200 |
+
|
201 |
+
tool_name_list = tool_des_json["tool_name"].split(":")
|
202 |
+
category_name = tool_name_list[0]
|
203 |
+
tool_name = tool_name_list[1]
|
204 |
+
api_name = tool_name_list[2]
|
205 |
+
API_doc = tool_des_json
|
206 |
+
|
207 |
+
while True:
|
208 |
+
try:
|
209 |
+
parameters = {}
|
210 |
+
|
211 |
+
if retry_parameter == 4:
|
212 |
+
restart = 1
|
213 |
+
retry_parameter = 0
|
214 |
+
print("No Para! Restart!")
|
215 |
+
break
|
216 |
+
|
217 |
+
parameter = Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
|
218 |
+
if parameter == -1:
|
219 |
+
retry_parameter += 1
|
220 |
+
re_time += 1
|
221 |
+
continue
|
222 |
+
|
223 |
+
if parameter == {}:
|
224 |
+
retry_parameter = 0
|
225 |
+
parameters = {}
|
226 |
+
break
|
227 |
+
|
228 |
+
for key in parameter:
|
229 |
+
value = parameter[key]
|
230 |
+
key = change_name(key)
|
231 |
+
parameters[key] = value
|
232 |
+
|
233 |
+
retry_parameter = 0
|
234 |
+
break
|
235 |
+
|
236 |
+
except:
|
237 |
+
if retry_parameter == 4:
|
238 |
+
parameters = {}
|
239 |
+
retry_parameter = 0
|
240 |
+
restart = 1
|
241 |
+
break
|
242 |
+
retry_parameter += 1
|
243 |
+
print("parameter generation fails, try again!")
|
244 |
+
re_time += 1
|
245 |
+
continue
|
246 |
+
|
247 |
+
|
248 |
+
api_name = change_name(standardize(api_name))
|
249 |
+
|
250 |
+
if restart != 1:
|
251 |
+
try:
|
252 |
+
observation = Call_function(category_name, tool_name, api_name, parameters, "truncate", white_list, args)
|
253 |
+
except:
|
254 |
+
observation = -1
|
255 |
+
|
256 |
+
if observation == -1:
|
257 |
+
restart = 1
|
258 |
+
observation = str({"error": "", "response": "call API fails"})
|
259 |
+
|
260 |
+
if restart == 1:
|
261 |
+
tool_used_dic[step_num].append(str(tool_id))
|
262 |
+
print('****Try Again For This Step****')
|
263 |
+
re_time += 1
|
264 |
+
restart = 0
|
265 |
+
continue
|
266 |
+
|
267 |
+
|
268 |
+
if len(previous_log) != 0:
|
269 |
+
previous_id = previous_log[-1]["id"]
|
270 |
+
else:
|
271 |
+
previous_id = -1
|
272 |
+
|
273 |
+
current_log["tool"] = str(tool_id)
|
274 |
+
current_log["thought"] = thought
|
275 |
+
current_log["action"] = api_name
|
276 |
+
current_log["action_input"] = parameters
|
277 |
+
current_log["observation"] = observation
|
278 |
+
previous_log.append(current_log)
|
279 |
+
|
280 |
+
observation_log = get_observation_log(previous_log)
|
281 |
+
|
282 |
+
answer = Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
|
283 |
+
|
284 |
+
previous_log[-1]["answer"] = answer
|
285 |
+
|
286 |
+
history_log_ele = {"thought": thought, "action": tool_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
|
287 |
+
history_log.append(history_log_ele)
|
288 |
+
subtask_id += 1
|
289 |
+
|
290 |
+
speak, status = Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
|
291 |
+
if speak == -1 and status == -1:
|
292 |
+
step_num += 1
|
293 |
+
continue
|
294 |
+
|
295 |
+
try:
|
296 |
+
if int(status) == 0:
|
297 |
+
hint = speak
|
298 |
+
step_num += 1
|
299 |
+
continue
|
300 |
+
except:
|
301 |
+
step_num += 1
|
302 |
+
continue
|
303 |
+
|
304 |
+
else:
|
305 |
+
return answer, previous_log, re_time, history_log
|
306 |
+
|
307 |
+
def decompose_inference(query, relevant_APIs, api_list, white_list, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, query_id, args):
|
308 |
+
while True:
|
309 |
+
subtasks = Planning_Agent.run(question=query, query_id=query_id)
|
310 |
+
if subtasks == -1:
|
311 |
+
continue
|
312 |
+
break
|
313 |
+
task_log = ""
|
314 |
+
history_log = []
|
315 |
+
previous_log_totals = []
|
316 |
+
re_time_total = 0
|
317 |
+
print(subtasks)
|
318 |
+
relevant_API_list = {}
|
319 |
+
tool_id = 0
|
320 |
+
for api in api_list:
|
321 |
+
for relevant_API in relevant_APIs:
|
322 |
+
if relevant_API[0] == api["tool_name"] and relevant_API[1] == api["api_name"]:
|
323 |
+
new_tool_name = api["category_name"]+":"+api["tool_name"]+":"+api["api_name"]
|
324 |
+
ele = {"ID": tool_id, "tool_name": new_tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
|
325 |
+
relevant_API_list[str(tool_id)] = ele
|
326 |
+
tool_id += 1
|
327 |
+
|
328 |
+
for subtask in subtasks:
|
329 |
+
task_log += f"question: {subtask}\n"
|
330 |
+
answer, previous_log, re_time, previous_log_total = inference(task_log, relevant_API_list, white_list, subtask, Answer_Agent, Executor_Agent, Verifier_Agent, query_id, args)
|
331 |
+
previous_log_totals.append(previous_log_total)
|
332 |
+
print(answer)
|
333 |
+
history_log += previous_log
|
334 |
+
re_time_total += re_time
|
335 |
+
task_log += f"answer: {answer}\n"
|
336 |
+
final_answer = Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
|
337 |
+
return final_answer, history_log, task_log, re_time_total, previous_log_totals
|
338 |
+
|
339 |
+
def test(query_json, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args):
|
340 |
+
while True:
|
341 |
+
try:
|
342 |
+
global lock
|
343 |
+
total_query = len(query_json)
|
344 |
+
with tqdm(total=total_query, desc="Processing files", initial=0) as pbar:
|
345 |
+
for i, test_query in enumerate(query_json, start=0):
|
346 |
+
idx = test_query[0]
|
347 |
+
test_query = test_query[1]
|
348 |
+
query = test_query["query"]
|
349 |
+
relevant_APIs = test_query["relevant APIs"]
|
350 |
+
api_list = test_query["api_list"]
|
351 |
+
final_answer, previous_log, task_log,re_time, previous_log_totals = decompose_inference(query, relevant_APIs, api_list, white_list, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, idx, args)
|
352 |
+
answer_details, total_steps = get_answer_details(final_answer, previous_log)
|
353 |
+
solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
|
354 |
+
output_file_ele = {
|
355 |
+
"query": query,
|
356 |
+
"restart_time": re_time,
|
357 |
+
"answer": {
|
358 |
+
"method": "decompose_dfs",
|
359 |
+
"total_steps": total_steps,
|
360 |
+
"final_answer": final_answer,
|
361 |
+
"answer_details": answer_details
|
362 |
+
}
|
363 |
+
}
|
364 |
+
|
365 |
+
solution_file_ele = {
|
366 |
+
"query": query,
|
367 |
+
"total_steps": solution_total_steps,
|
368 |
+
"task_log": task_log,
|
369 |
+
"final_answer": final_answer,
|
370 |
+
"answer_path": answer_details,
|
371 |
+
"total_path": solution_tree
|
372 |
+
}
|
373 |
+
file_name = f"{idx}.json"
|
374 |
+
output_file = os.path.join(output_dir, file_name)
|
375 |
+
whole_solution_file = os.path.join(whole_solution_dir, file_name)
|
376 |
+
lock.acquire()
|
377 |
+
with open(output_file, "w") as file:
|
378 |
+
json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
|
379 |
+
with open(whole_solution_file, "w") as file:
|
380 |
+
json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
|
381 |
+
lock.release()
|
382 |
+
pbar.update(1)
|
383 |
+
return
|
384 |
+
|
385 |
+
except Exception as e:
|
386 |
+
print(e)
|
387 |
+
print("some error occurs, continue...")
|
388 |
+
time.sleep(60)
|
389 |
+
continue
|
390 |
+
|
391 |
+
def parse_arg():
|
392 |
+
parser = argparse.ArgumentParser()
|
393 |
+
parser.add_argument('--test_query_id_path', type=str, default="toolbench_data/data/test_query_ids", required=False, help='test query ids for different test sets')
|
394 |
+
parser.add_argument('--method_name', type=str, default="smurfs-all", required=False, help='the inference method')
|
395 |
+
parser.add_argument('--model_name', type=str, default="your_model_name", required=False, help='the model name for the vllm model')
|
396 |
+
parser.add_argument('--query_file_dir', type=str, default="toolbench_data/data/test_instruction", required=False, help='the directory that contains test sets')
|
397 |
+
parser.add_argument('--tool_env_dir', type=str, default="toolbench_data/data/toolenv/tools", required=False, help='tool environment for the toolbench')
|
398 |
+
parser.add_argument('--toolbench_key', type=str, default="",required=False, help='your toolbench key to request rapidapi service')
|
399 |
+
parser.add_argument('--rapidapi_key', type=str, default="",required=False, help='your rapidapi key to request rapidapi service')
|
400 |
+
parser.add_argument('--use_rapidapi_key', action="store_true", help="To use customized rapidapi service or not.")
|
401 |
+
parser.add_argument('--api_customization', action="store_true", help="To use customized api or not.")
|
402 |
+
args = parser.parse_args()
|
403 |
+
|
404 |
+
return args
|
405 |
+
|
406 |
+
if __name__ == '__main__':
|
407 |
+
threads = []
|
408 |
+
lock = threading.Lock()
|
409 |
+
|
410 |
+
args = parse_arg()
|
411 |
+
|
412 |
+
test_query_id_path = args.test_query_id_path
|
413 |
+
method_name = args.method_name
|
414 |
+
model_name = args.model_name
|
415 |
+
query_file_dir = args.query_file_dir
|
416 |
+
tool_env_dir = args.tool_env_dir
|
417 |
+
|
418 |
+
toolbench_key = args.toolbench_key
|
419 |
+
rapidapi_key = args.rapidapi_key
|
420 |
+
use_rapidapi_key = args.use_rapidapi_key
|
421 |
+
api_customization = args.api_customization
|
422 |
+
|
423 |
+
chat = vllm_Model(model_name=model_name)
|
424 |
+
|
425 |
+
|
426 |
+
for test_set in test_sets:
|
427 |
+
total_output_file = f"data/{method_name}/{test_set}_raw.json"
|
428 |
+
|
429 |
+
test_ids = list(json.load(open(os.path.join(test_query_id_path, test_set+".json"), "r")).keys())
|
430 |
+
|
431 |
+
query_file = f'{query_file_dir}/{test_set}.json'
|
432 |
+
|
433 |
+
output_dir = f"data/{method_name}/{test_set}/answer"
|
434 |
+
|
435 |
+
whole_solution_dir = f"data/{method_name}/{test_set}/whole"
|
436 |
+
|
437 |
+
logger_dir = f"data/{method_name}/{test_set}/agent_log"
|
438 |
+
|
439 |
+
if not os.path.exists(output_dir):
|
440 |
+
os.makedirs(output_dir)
|
441 |
+
|
442 |
+
if not os.path.exists(whole_solution_dir):
|
443 |
+
os.makedirs(whole_solution_dir)
|
444 |
+
|
445 |
+
if not os.path.exists(logger_dir):
|
446 |
+
os.makedirs(logger_dir)
|
447 |
+
|
448 |
+
Answer_Agent = answer_agent(llm=chat, logger_dir=logger_dir)
|
449 |
+
Executor_Agent = executor_agent(llm=chat, logger_dir=logger_dir)
|
450 |
+
Planning_Agent = planning_agent(llm=chat, logger_dir=logger_dir)
|
451 |
+
Verifier_Agent = verifier_agent(llm=chat, logger_dir=logger_dir)
|
452 |
+
|
453 |
+
items = os.listdir(output_dir)
|
454 |
+
for i in range(len(items)):
|
455 |
+
items[i] = items[i].split(".")[0]
|
456 |
+
|
457 |
+
white_list = get_white_list(tool_env_dir)
|
458 |
+
|
459 |
+
with open(query_file) as file:
|
460 |
+
query_json = json.load(file)
|
461 |
+
# with open(tool_doc_dir) as file:
|
462 |
+
# tool_doc = json.load(file)
|
463 |
+
|
464 |
+
print(len(items))
|
465 |
+
query_json_to_do = []
|
466 |
+
# if len(items) != 0:
|
467 |
+
for idx, q in enumerate(query_json):
|
468 |
+
# print(idx)
|
469 |
+
if str(idx) in items:
|
470 |
+
continue
|
471 |
+
query_id = q["query_id"]
|
472 |
+
if str(query_id) not in test_ids:
|
473 |
+
continue
|
474 |
+
query_json_to_do_ele = (idx, q)
|
475 |
+
query_json_to_do.append(query_json_to_do_ele)
|
476 |
+
# else:
|
477 |
+
# query_json_to_do = query_json
|
478 |
+
|
479 |
+
total_len = len(query_json_to_do)
|
480 |
+
query_len = len(query_json)
|
481 |
+
print(total_len)
|
482 |
+
|
483 |
+
if total_len < 20:
|
484 |
+
for i in range(total_len):
|
485 |
+
if total_len == 0:
|
486 |
+
break
|
487 |
+
|
488 |
+
start = i
|
489 |
+
end = i+1
|
490 |
+
if i == total_len-1:
|
491 |
+
query_json_cur = query_json_to_do[start:]
|
492 |
+
else:
|
493 |
+
query_json_cur = query_json_to_do[start: end]
|
494 |
+
t = threading.Thread(target=test, args=(query_json_cur, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args))
|
495 |
+
t.start()
|
496 |
+
threads.append(t)
|
497 |
+
|
498 |
+
|
499 |
+
else:
|
500 |
+
for i in range(20):
|
501 |
+
|
502 |
+
if total_len == 0:
|
503 |
+
break
|
504 |
+
|
505 |
+
start = round(total_len/20)*i
|
506 |
+
end = round(total_len/20)*(i+1)
|
507 |
+
if i == 19:
|
508 |
+
query_json_cur = query_json_to_do[start:]
|
509 |
+
else:
|
510 |
+
query_json_cur = query_json_to_do[start: end]
|
511 |
+
t = threading.Thread(target=test, args=(query_json_cur, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args))
|
512 |
+
t.start()
|
513 |
+
threads.append(t)
|
514 |
+
|
515 |
+
for thread in threads:
|
516 |
+
thread.join()
|
517 |
+
|
518 |
+
total_json = {}
|
519 |
+
items = os.listdir(output_dir)
|
520 |
+
for item in items:
|
521 |
+
item_path = os.path.join(output_dir, item)
|
522 |
+
idx = item.split(".")[0]
|
523 |
+
total_json[str(idx)] = json.load(open(item_path, 'r'))
|
524 |
+
|
525 |
+
with open(total_output_file, 'w') as file:
|
526 |
+
json.dump(total_json, file, indent=4, ensure_ascii=False)
|
527 |
+
|
Smurfs/inference/server.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
from typing import Union
|
5 |
+
import sys
|
6 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
7 |
+
from Smurfs.inference.utils import standardize, change_name
|
8 |
+
import random
|
9 |
+
|
10 |
+
|
11 |
+
class Info(BaseModel):
|
12 |
+
category: str
|
13 |
+
tool_name: str
|
14 |
+
api_name: str
|
15 |
+
tool_input: Union[str, dict]
|
16 |
+
strip: str
|
17 |
+
|
18 |
+
def prepare_tool_name_and_url(tools_root, info):
|
19 |
+
category = info.category
|
20 |
+
standard_category = category.replace(" ", "_").replace(",", "_").replace("/", "_")
|
21 |
+
while " " in standard_category or "," in standard_category:
|
22 |
+
standard_category = standard_category.replace(" ", "_").replace(",", "_")
|
23 |
+
standard_category = standard_category.replace("__", "_")
|
24 |
+
|
25 |
+
tool_name = info.tool_name
|
26 |
+
api_name = change_name(standardize(info.api_name))
|
27 |
+
if not tool_name.endswith(f"_for_{standard_category}"):
|
28 |
+
tool_name = standardize(info.tool_name)
|
29 |
+
code_string = f"""from {tools_root}.{standard_category}.{tool_name}.api import {api_name}"""
|
30 |
+
tool_name += f"_for_{standard_category}"
|
31 |
+
else:
|
32 |
+
tmp_tool_name = standardize(tool_name.replace(f"_for_{standard_category}", ""))
|
33 |
+
code_string = f"""from {tools_root}.{standard_category}.{tmp_tool_name}.api import {api_name}"""
|
34 |
+
return tool_name, standard_category, api_name, code_string
|
35 |
+
|
36 |
+
def process_error(response):
|
37 |
+
save_cache_flag = False
|
38 |
+
switch_flag = False
|
39 |
+
if "The request to the API has timed out. Please try again later, or if the issue persists" in str(response):
|
40 |
+
return_dict = {"error": "API temporarily not working error...", "response": response}
|
41 |
+
|
42 |
+
if "Your Client (working) ---> Gateway (working) ---> API (not working)" in str(response):
|
43 |
+
return_dict = {"error": "API not working error...", "response": response}
|
44 |
+
|
45 |
+
elif "Unauthorized" in str(response) or "unauthorized" in str(response):
|
46 |
+
save_cache_flag = True
|
47 |
+
return_dict = {"error": "Unauthorized error...", "response": response}
|
48 |
+
|
49 |
+
elif "You are not subscribed to this API." in str(response):
|
50 |
+
switch_flag = True
|
51 |
+
return_dict = {"error": "Unsubscribed error...", "response": response}
|
52 |
+
|
53 |
+
elif "Too many requests" in str(response):
|
54 |
+
switch_flag = True
|
55 |
+
return_dict = {"error": "Too many requests error...", "response": response}
|
56 |
+
|
57 |
+
elif "You have exceeded" in str(response) or "you are being rate limited" in str(response):
|
58 |
+
switch_flag = True
|
59 |
+
return_dict = {"error": "Rate limit error...", "response": response}
|
60 |
+
|
61 |
+
elif "Access restricted. Check credits balance or enter the correct API key." in str(response):
|
62 |
+
switch_flag = True
|
63 |
+
return_dict = {"error": "Rate limit error...", "response": response}
|
64 |
+
|
65 |
+
elif "Oops, an error in the gateway has occurred." in str(response):
|
66 |
+
switch_flag = True
|
67 |
+
return_dict = {"error": "Gateway error...", "response": response}
|
68 |
+
|
69 |
+
elif "Blocked User. Please contact your API provider." in str(response):
|
70 |
+
switch_flag = True
|
71 |
+
return_dict = {"error": "Blocked error...", "response": response}
|
72 |
+
|
73 |
+
elif "error" in str(response):
|
74 |
+
return_dict = {"error": "Message error...", "response": response}
|
75 |
+
|
76 |
+
else:
|
77 |
+
save_cache_flag = True
|
78 |
+
return_dict = {"error": "", "response": response}
|
79 |
+
return return_dict, save_cache_flag, switch_flag
|
80 |
+
|
81 |
+
def run(toolbench_code_string, toolbench_api_name, toolbench_input_params_str):
|
82 |
+
# get observation
|
83 |
+
success_flag = False
|
84 |
+
switch_flag = False
|
85 |
+
save_cache = False
|
86 |
+
exec(toolbench_code_string)
|
87 |
+
try:
|
88 |
+
eval_func_str = f"{toolbench_api_name}({toolbench_input_params_str})"
|
89 |
+
new_func = eval(eval_func_str)
|
90 |
+
response, save_cache, switch_flag = process_error(new_func)
|
91 |
+
success_flag = True
|
92 |
+
except Exception as e:
|
93 |
+
response = {"error": f"Function executing {toolbench_code_string} error...\n{e}", "response": ""}
|
94 |
+
save_cache = False
|
95 |
+
return success_flag, switch_flag, response, save_cache
|
96 |
+
|
97 |
+
|
98 |
+
def dict_shorten(origin: dict, schema: dict):
|
99 |
+
for key, value in list(origin.items()):
|
100 |
+
if key not in schema:
|
101 |
+
del origin[key]
|
102 |
+
else:
|
103 |
+
if isinstance(value, dict):
|
104 |
+
dict_shorten(value, schema[key]) # schema[key] should be a dict
|
105 |
+
elif isinstance(value, list):
|
106 |
+
if value:
|
107 |
+
if isinstance(value[0], dict):
|
108 |
+
for item in value:
|
109 |
+
dict_shorten(item, schema[key][0]) # schema[key] should be a list with only one dict element
|
110 |
+
return origin
|
111 |
+
|
112 |
+
def observation_shorten(schema_root, response_dict, category, tool_name, api_name, strip_method):
|
113 |
+
print(random.random())
|
114 |
+
if strip_method == "filter" or (strip_method == "random" and random.random() > 0.5):
|
115 |
+
if isinstance(response_dict["response"], dict):
|
116 |
+
if os.path.exists(os.path.join(schema_root, category)):
|
117 |
+
if os.path.exists(os.path.join(schema_root, category, tool_name+".json")):
|
118 |
+
schema_dicts = json.load(open(os.path.join(schema_root, category, tool_name+".json"), "r"))
|
119 |
+
api_list = schema_dicts["api_list"]
|
120 |
+
schema = None
|
121 |
+
for schema_dict in api_list:
|
122 |
+
schema_api_name = change_name(standardize(schema_dict["name"]))
|
123 |
+
if schema_api_name == api_name and len(schema_dict["schema"]) > 0:
|
124 |
+
schema = schema_dict["schema"]
|
125 |
+
break
|
126 |
+
if schema is not None:
|
127 |
+
response_dict["response"] = dict_shorten(response_dict["response"], schema)
|
128 |
+
return str(response_dict["response"])
|
129 |
+
|
130 |
+
|
131 |
+
def get_rapidapi_response(input_dict: dict, api_customization: bool=False, tools_root: str="data.toolenv.tools", schema_root: str="data/toolenv/response_examples"):
|
132 |
+
info = Info
|
133 |
+
info.category = input_dict['category']
|
134 |
+
info.tool_name = input_dict['tool_name']
|
135 |
+
info.api_name = input_dict['api_name']
|
136 |
+
info.tool_input = input_dict['tool_input']
|
137 |
+
info.strip = input_dict['strip']
|
138 |
+
rapidapi_key = input_dict['rapidapi_key']
|
139 |
+
|
140 |
+
tool_name, standard_category, api_name, code_string = prepare_tool_name_and_url(tools_root, info)
|
141 |
+
tool_input = info.tool_input
|
142 |
+
|
143 |
+
strip_method = info.strip
|
144 |
+
|
145 |
+
try:
|
146 |
+
tool_input = json.loads(tool_input)
|
147 |
+
except Exception as e:
|
148 |
+
if tool_input == "":
|
149 |
+
tool_input = {}
|
150 |
+
else:
|
151 |
+
print(f"Can not parse tool input into json: {tool_input}")
|
152 |
+
response_dict = {"error": f"Tool input parse error...\n", "response": ""}
|
153 |
+
return response_dict
|
154 |
+
|
155 |
+
input_params_str = ""
|
156 |
+
if len(tool_input) > 0:
|
157 |
+
for key, value in tool_input.items():
|
158 |
+
if isinstance(value, str):
|
159 |
+
input_params_str += f'{key}="{value}", '
|
160 |
+
else:
|
161 |
+
input_params_str += f'{key}={value}, '
|
162 |
+
if not api_customization:
|
163 |
+
input_params_str += f"toolbench_rapidapi_key='{rapidapi_key}'"
|
164 |
+
success_flag, switch_flag, response_dict, save_cache = run(code_string, api_name, input_params_str)
|
165 |
+
observation = observation_shorten(schema_root, response_dict, standard_category, tool_name.replace(f"_for_{standard_category}", ""), api_name, strip_method)
|
166 |
+
result = str(observation)[:2048]
|
167 |
+
return {"error": response_dict['error'], "response": result}
|
168 |
+
|
169 |
+
|
170 |
+
if __name__ == "__main__":
|
171 |
+
result = get_rapidapi_response({
|
172 |
+
"category": "Social",
|
173 |
+
"tool_name": "olato_quotes",
|
174 |
+
"api_name": "love_quote",
|
175 |
+
"tool_input": '{}',
|
176 |
+
"strip": "filter",
|
177 |
+
"rapidapi_key": ""
|
178 |
+
})
|
179 |
+
print(result)
|
Smurfs/inference/smurfs_worker.py
ADDED
@@ -0,0 +1,1040 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# — coding: utf-8 –
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import argparse
|
5 |
+
import time
|
6 |
+
import requests
|
7 |
+
import os
|
8 |
+
from Smurfs.inference.utils import change_name, standardize, get_white_list, get_answer_log, get_observation_log, build_tree, get_answer_details, test_sets
|
9 |
+
from tqdm import tqdm
|
10 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
11 |
+
from Smurfs.model.vllm_model.vllm_model import vllm_Model
|
12 |
+
from Smurfs.inference.server import get_rapidapi_response
|
13 |
+
import threading
|
14 |
+
from Smurfs.agents.answer_agent.answer import answer_agent
|
15 |
+
from Smurfs.agents.executor_agent.executor import executor_agent
|
16 |
+
from Smurfs.agents.planning_agent.planner import planning_agent
|
17 |
+
from Smurfs.agents.verifier_agent.verifier import verifier_agent
|
18 |
+
from termcolor import colored
|
19 |
+
class smurfs_worker:
|
20 |
+
def __init__(self, available_tools, tool_env, llm, method_name, test_set, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent):
|
21 |
+
#available_tools的格式形如toolbench里面的api_list里的格式,只需要api_name
|
22 |
+
#tool_env是一个工具函数里用来存储工具代码的py文件中的所有函数的字典,key为函数名,value是函数对象
|
23 |
+
self.available_tools = available_tools
|
24 |
+
|
25 |
+
self.output_dir = f"data/{method_name}/{test_set}/answer"
|
26 |
+
|
27 |
+
self.whole_solution_dir = f"data/{method_name}/{test_set}/whole"
|
28 |
+
|
29 |
+
self.logger_dir = f"data/{method_name}/{test_set}/agent_log"
|
30 |
+
|
31 |
+
if not os.path.exists(self.output_dir):
|
32 |
+
os.makedirs(self.output_dir)
|
33 |
+
|
34 |
+
if not os.path.exists(self.whole_solution_dir):
|
35 |
+
os.makedirs(self.whole_solution_dir)
|
36 |
+
|
37 |
+
if not os.path.exists(self.logger_dir):
|
38 |
+
os.makedirs(self.logger_dir)
|
39 |
+
|
40 |
+
self.Answer_Agent = Answer_Agent(llm=llm, logger_dir=self.logger_dir)
|
41 |
+
self.Executor_Agent = Executor_Agent(llm=llm, logger_dir=self.logger_dir)
|
42 |
+
self.Planning_Agent = Planning_Agent(llm=llm, logger_dir=self.logger_dir)
|
43 |
+
self.Verifier_Agent = Verifier_Agent(llm=llm, logger_dir=self.logger_dir)
|
44 |
+
self.tool_env = tool_env
|
45 |
+
|
46 |
+
def inference(self, query, relevant_APIs, subtask, query_id, max_step=3):
|
47 |
+
# tool_check_num, reason = self.Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
|
48 |
+
# #direct answer
|
49 |
+
# if tool_check_num == 1:
|
50 |
+
# # input_dic = {"task": query}
|
51 |
+
# answer = self.Answer_Agent.run(question=query, task="direct", query_id=query_id)
|
52 |
+
# previous_log = [{"thought": reason, "action": "", "action_input": "", "observation": "", "answer": answer, "tool": "","id": 0}]
|
53 |
+
# history_log = [{"thought": reason, "action": "", "action_input": "", "observation": "", "answer": answer, "previous_id": -1, "id": 0}]
|
54 |
+
# return answer, previous_log, 0, history_log
|
55 |
+
|
56 |
+
previous_log = []
|
57 |
+
history_log = []
|
58 |
+
tool_used_dic = {}
|
59 |
+
relevant_APIs_ids = []
|
60 |
+
for idx in relevant_APIs:
|
61 |
+
ele = relevant_APIs[idx]
|
62 |
+
relevant_APIs_ids.append(str(ele["ID"]))
|
63 |
+
restart_time = 0
|
64 |
+
step_num = 0
|
65 |
+
hint = "Beginnig of the agent. No hint yet"
|
66 |
+
retry_tool_id = 0
|
67 |
+
retry_parameter = 0
|
68 |
+
re_time = 0
|
69 |
+
subtask_id = 0
|
70 |
+
restart = 0
|
71 |
+
while True:
|
72 |
+
if step_num >= max_step:
|
73 |
+
print("\n\nReach steps limits, return answers!\n\n")
|
74 |
+
answer_log = get_answer_log(history_log)
|
75 |
+
answer = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
76 |
+
return answer, previous_log, re_time, history_log
|
77 |
+
|
78 |
+
if step_num not in tool_used_dic.keys():
|
79 |
+
tool_used_dic[step_num] = []
|
80 |
+
|
81 |
+
tool_used = tool_used_dic[step_num]
|
82 |
+
|
83 |
+
tool_list = []
|
84 |
+
for idx in relevant_APIs:
|
85 |
+
ele = relevant_APIs[idx]
|
86 |
+
ID = str(ele['ID'])
|
87 |
+
if ID in tool_used:
|
88 |
+
continue
|
89 |
+
des = ele['description']
|
90 |
+
name = ele["tool_name"]
|
91 |
+
tool_list.append({"ID": ID, "tool_name": name, "description": des})
|
92 |
+
|
93 |
+
if len(tool_list) == 0:
|
94 |
+
if len(previous_log) == 0:
|
95 |
+
answer_log = get_answer_log(history_log)
|
96 |
+
partial_answer = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
97 |
+
answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
|
98 |
+
return answer, previous_log, re_time, history_log
|
99 |
+
else:
|
100 |
+
delete_log = previous_log.pop()
|
101 |
+
tool_used_dic[step_num] = []
|
102 |
+
step_num -= 1
|
103 |
+
tool_used_dic[step_num].append(delete_log["tool"])
|
104 |
+
restart_time += 1
|
105 |
+
re_time += 1
|
106 |
+
continue
|
107 |
+
|
108 |
+
current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "","id": subtask_id}
|
109 |
+
|
110 |
+
answer_log = get_answer_log(previous_log)
|
111 |
+
|
112 |
+
if retry_tool_id == 4:
|
113 |
+
tool_id = tool_list[0]["ID"]
|
114 |
+
tool_list = tool_list[0]
|
115 |
+
thought = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
116 |
+
|
117 |
+
else:
|
118 |
+
thought = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
119 |
+
tool_id = self.Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
|
120 |
+
|
121 |
+
try:
|
122 |
+
tool_id = int(tool_id)
|
123 |
+
tool_id = str(tool_id)
|
124 |
+
if tool_id not in relevant_APIs_ids:
|
125 |
+
re_time += 1
|
126 |
+
retry_tool_id += 1
|
127 |
+
print("Tool ID wrong! Generate tool_id that do not exist!")
|
128 |
+
continue
|
129 |
+
tool_des_json = relevant_APIs[str(tool_id)]
|
130 |
+
retry_tool_id = 0
|
131 |
+
except:
|
132 |
+
retry_tool_id += 1
|
133 |
+
print("Tool ID wrong! Generate tool_id that do not exist!")
|
134 |
+
continue
|
135 |
+
|
136 |
+
# tool_name_list = tool_des_json["tool_name"].split(":")
|
137 |
+
# category_name = tool_name_list[0]
|
138 |
+
# tool_name = tool_name_list[1]
|
139 |
+
api_name = tool_des_json["tool_name"]
|
140 |
+
API_doc = tool_des_json
|
141 |
+
|
142 |
+
while True:
|
143 |
+
try:
|
144 |
+
parameters = {}
|
145 |
+
|
146 |
+
if retry_parameter == 4:
|
147 |
+
restart = 1
|
148 |
+
retry_parameter = 0
|
149 |
+
print("No Para! Restart!")
|
150 |
+
break
|
151 |
+
|
152 |
+
parameter = self.Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
|
153 |
+
if parameter == -1:
|
154 |
+
retry_parameter += 1
|
155 |
+
re_time += 1
|
156 |
+
continue
|
157 |
+
|
158 |
+
if parameter == {}:
|
159 |
+
retry_parameter = 0
|
160 |
+
parameters = {}
|
161 |
+
break
|
162 |
+
|
163 |
+
for key in parameter:
|
164 |
+
value = parameter[key]
|
165 |
+
key = change_name(key)
|
166 |
+
parameters[key] = value
|
167 |
+
|
168 |
+
retry_parameter = 0
|
169 |
+
break
|
170 |
+
|
171 |
+
except:
|
172 |
+
if retry_parameter == 4:
|
173 |
+
parameters = {}
|
174 |
+
retry_parameter = 0
|
175 |
+
restart = 1
|
176 |
+
break
|
177 |
+
retry_parameter += 1
|
178 |
+
print("parameter generation fails, try again!")
|
179 |
+
re_time += 1
|
180 |
+
continue
|
181 |
+
|
182 |
+
|
183 |
+
# api_name = change_name(standardize(api_name))
|
184 |
+
|
185 |
+
if restart != 1:
|
186 |
+
try:
|
187 |
+
observation = self.Call_function(api_name, parameters)
|
188 |
+
except:
|
189 |
+
observation = -1
|
190 |
+
|
191 |
+
if observation == -1:
|
192 |
+
restart = 1
|
193 |
+
observation = str({"error": "", "response": "call API fails"})
|
194 |
+
|
195 |
+
if restart == 1:
|
196 |
+
tool_used_dic[step_num].append(str(tool_id))
|
197 |
+
print('****Try Again For This Step****')
|
198 |
+
re_time += 1
|
199 |
+
restart = 0
|
200 |
+
continue
|
201 |
+
|
202 |
+
|
203 |
+
if len(previous_log) != 0:
|
204 |
+
previous_id = previous_log[-1]["id"]
|
205 |
+
else:
|
206 |
+
previous_id = -1
|
207 |
+
|
208 |
+
current_log["tool"] = str(tool_id)
|
209 |
+
current_log["thought"] = thought
|
210 |
+
current_log["action"] = api_name
|
211 |
+
current_log["action_input"] = parameters
|
212 |
+
current_log["observation"] = observation
|
213 |
+
previous_log.append(current_log)
|
214 |
+
print("##########Tool Response##########")
|
215 |
+
print(f"{observation}\n")
|
216 |
+
observation_log = get_observation_log(previous_log)
|
217 |
+
|
218 |
+
answer = self.Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
|
219 |
+
|
220 |
+
previous_log[-1]["answer"] = answer
|
221 |
+
|
222 |
+
history_log_ele = {"thought": thought, "action": api_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
|
223 |
+
history_log.append(history_log_ele)
|
224 |
+
subtask_id += 1
|
225 |
+
|
226 |
+
speak, status = self.Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
|
227 |
+
if speak == -1 and status == -1:
|
228 |
+
step_num += 1
|
229 |
+
continue
|
230 |
+
|
231 |
+
try:
|
232 |
+
if int(status) == 0:
|
233 |
+
hint = speak
|
234 |
+
step_num += 1
|
235 |
+
continue
|
236 |
+
except:
|
237 |
+
step_num += 1
|
238 |
+
continue
|
239 |
+
|
240 |
+
else:
|
241 |
+
return answer, previous_log, re_time, history_log
|
242 |
+
|
243 |
+
def decompose_inference(self, query, query_id):
|
244 |
+
while True:
|
245 |
+
subtasks = self.Planning_Agent.run(question=query, query_id=query_id)
|
246 |
+
if subtasks == -1:
|
247 |
+
continue
|
248 |
+
break
|
249 |
+
task_log = ""
|
250 |
+
history_log = []
|
251 |
+
previous_log_totals = []
|
252 |
+
re_time_total = 0
|
253 |
+
# print(subtasks)
|
254 |
+
relevant_API_list = {}
|
255 |
+
tool_id = 0
|
256 |
+
for api in self.available_tools:
|
257 |
+
tool_name = api["api_name"]
|
258 |
+
ele = {"ID": tool_id, "tool_name": tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
|
259 |
+
relevant_API_list[str(tool_id)] = ele
|
260 |
+
tool_id += 1
|
261 |
+
|
262 |
+
for subtask in subtasks:
|
263 |
+
task_log += f"question: {subtask}\n"
|
264 |
+
answer, previous_log, re_time, previous_log_total = self.inference(task_log, relevant_API_list, subtask, query_id)
|
265 |
+
previous_log_totals.append(previous_log_total)
|
266 |
+
# print(answer)
|
267 |
+
history_log += previous_log
|
268 |
+
re_time_total += re_time
|
269 |
+
task_log += f"answer: {answer}\n"
|
270 |
+
final_answer = self.Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
|
271 |
+
return final_answer, history_log, task_log, re_time_total, previous_log_totals
|
272 |
+
|
273 |
+
def run(self, input, query_id):
|
274 |
+
# result = {}
|
275 |
+
# st = time.time()
|
276 |
+
final_answer, previous_log, task_log,re_time, previous_log_totals = self.decompose_inference(input, query_id)
|
277 |
+
answer_details, total_steps = get_answer_details(final_answer, previous_log)
|
278 |
+
solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
|
279 |
+
output_file_ele = {
|
280 |
+
"query": input,
|
281 |
+
"restart_time": re_time,
|
282 |
+
"answer": {
|
283 |
+
"method": "decompose_dfs",
|
284 |
+
"total_steps": total_steps,
|
285 |
+
"final_answer": final_answer,
|
286 |
+
"answer_details": answer_details
|
287 |
+
}
|
288 |
+
}
|
289 |
+
|
290 |
+
solution_file_ele = {
|
291 |
+
"query": input,
|
292 |
+
"total_steps": solution_total_steps,
|
293 |
+
"task_log": task_log,
|
294 |
+
"final_answer": final_answer,
|
295 |
+
"answer_path": answer_details,
|
296 |
+
"total_path": solution_tree
|
297 |
+
}
|
298 |
+
return final_answer, output_file_ele, solution_file_ele
|
299 |
+
|
300 |
+
def save_solution(self, output_file_ele, solution_file_ele, idx):
|
301 |
+
file_name = f"{idx}.json"
|
302 |
+
output_file = os.path.join(self.output_dir, file_name)
|
303 |
+
whole_solution_file = os.path.join(self.whole_solution_dir, file_name)
|
304 |
+
with open(output_file, "w") as file:
|
305 |
+
json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
|
306 |
+
with open(whole_solution_file, "w") as file:
|
307 |
+
json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
|
308 |
+
|
309 |
+
def Call_function(self, tool_name, args):
|
310 |
+
try:
|
311 |
+
print(tool_name)
|
312 |
+
func = self.tool_env[tool_name]
|
313 |
+
observation = func(**args)
|
314 |
+
return observation
|
315 |
+
except Exception as e:
|
316 |
+
print(e)
|
317 |
+
print(f"Call function fails")
|
318 |
+
with open('wrong_log.json', 'a+', encoding='utf-8') as f:
|
319 |
+
line = json.dumps({
|
320 |
+
"id": 0,
|
321 |
+
"parameters": args,
|
322 |
+
"tool": tool_name,
|
323 |
+
"wrong": str(e)
|
324 |
+
}, ensure_ascii=False)
|
325 |
+
f.write(line + '\n')
|
326 |
+
return -1
|
327 |
+
|
328 |
+
class smurfs_hotpot_worker:
|
329 |
+
def __init__(self, available_tools, tool_env, llm, method_name, test_set, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent):
|
330 |
+
#available_tools的格式形如toolbench里面的api_list里的格式,只需要api_name
|
331 |
+
#tool_env是一个工具函数里用来存储工具代码的py文件中的所有函数的字典,key为函数名,value是函数对象
|
332 |
+
self.available_tools = available_tools
|
333 |
+
|
334 |
+
self.output_dir = f"data/{method_name}/{test_set}/answer"
|
335 |
+
|
336 |
+
self.whole_solution_dir = f"data/{method_name}/{test_set}/whole"
|
337 |
+
|
338 |
+
self.logger_dir = f"data/{method_name}/{test_set}/agent_log"
|
339 |
+
|
340 |
+
if not os.path.exists(self.output_dir):
|
341 |
+
os.makedirs(self.output_dir)
|
342 |
+
|
343 |
+
if not os.path.exists(self.whole_solution_dir):
|
344 |
+
os.makedirs(self.whole_solution_dir)
|
345 |
+
|
346 |
+
if not os.path.exists(self.logger_dir):
|
347 |
+
os.makedirs(self.logger_dir)
|
348 |
+
|
349 |
+
self.Answer_Agent = Answer_Agent(llm=llm, logger_dir=self.logger_dir)
|
350 |
+
self.Executor_Agent = Executor_Agent(llm=llm, logger_dir=self.logger_dir)
|
351 |
+
self.Planning_Agent = Planning_Agent(llm=llm, logger_dir=self.logger_dir)
|
352 |
+
self.Verifier_Agent = Verifier_Agent(llm=llm, logger_dir=self.logger_dir)
|
353 |
+
self.tool_class = tool_env
|
354 |
+
self.tool_env = {}
|
355 |
+
|
356 |
+
def inference(self, query, relevant_APIs, subtask, query_id, max_step=3):
|
357 |
+
# tool_check_num = self.Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
|
358 |
+
# #direct answer
|
359 |
+
# if tool_check_num == 1:
|
360 |
+
# input_dic = {"task": query}
|
361 |
+
# answer = self.Answer_Agent.run(input_dic)
|
362 |
+
# return answer, answer, None, None
|
363 |
+
|
364 |
+
previous_log = []
|
365 |
+
history_log = []
|
366 |
+
tool_used_dic = {}
|
367 |
+
relevant_APIs_ids = []
|
368 |
+
for idx in relevant_APIs:
|
369 |
+
ele = relevant_APIs[idx]
|
370 |
+
relevant_APIs_ids.append(str(ele["ID"]))
|
371 |
+
restart_time = 0
|
372 |
+
step_num = 0
|
373 |
+
hint = "Beginnig of the agent. No hint yet"
|
374 |
+
retry_tool_id = 0
|
375 |
+
retry_parameter = 0
|
376 |
+
re_time = 0
|
377 |
+
subtask_id = 0
|
378 |
+
restart = 0
|
379 |
+
while True:
|
380 |
+
if step_num >= max_step:
|
381 |
+
print("\n\nReach steps limits, return answers!\n\n")
|
382 |
+
answer_log = get_answer_log(history_log)
|
383 |
+
answer = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
384 |
+
return answer, previous_log, re_time, history_log
|
385 |
+
|
386 |
+
if step_num not in tool_used_dic.keys():
|
387 |
+
tool_used_dic[step_num] = []
|
388 |
+
|
389 |
+
tool_used = tool_used_dic[step_num]
|
390 |
+
|
391 |
+
tool_list = []
|
392 |
+
for idx in relevant_APIs:
|
393 |
+
ele = relevant_APIs[idx]
|
394 |
+
ID = str(ele['ID'])
|
395 |
+
if ID in tool_used:
|
396 |
+
continue
|
397 |
+
des = ele['description']
|
398 |
+
name = ele["tool_name"]
|
399 |
+
tool_list.append({"ID": ID, "tool_name": name, "description": des})
|
400 |
+
|
401 |
+
if len(tool_list) == 0:
|
402 |
+
if len(previous_log) == 0:
|
403 |
+
answer_log = get_answer_log(history_log)
|
404 |
+
partial_answer = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
405 |
+
answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
|
406 |
+
return answer, previous_log, re_time, history_log
|
407 |
+
else:
|
408 |
+
delete_log = previous_log.pop()
|
409 |
+
tool_used_dic[step_num] = []
|
410 |
+
step_num -= 1
|
411 |
+
tool_used_dic[step_num].append(delete_log["tool"])
|
412 |
+
restart_time += 1
|
413 |
+
re_time += 1
|
414 |
+
continue
|
415 |
+
|
416 |
+
current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "","id": subtask_id}
|
417 |
+
|
418 |
+
answer_log = get_answer_log(previous_log)
|
419 |
+
|
420 |
+
if retry_tool_id == 4:
|
421 |
+
tool_id = tool_list[0]["ID"]
|
422 |
+
tool_list = tool_list[0]
|
423 |
+
thought = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
424 |
+
|
425 |
+
else:
|
426 |
+
thought = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
427 |
+
tool_id = self.Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
|
428 |
+
|
429 |
+
try:
|
430 |
+
tool_id = int(tool_id)
|
431 |
+
tool_id = str(tool_id)
|
432 |
+
if tool_id not in relevant_APIs_ids:
|
433 |
+
re_time += 1
|
434 |
+
retry_tool_id += 1
|
435 |
+
print("Tool ID wrong! Generate tool_id that do not exist!")
|
436 |
+
continue
|
437 |
+
tool_des_json = relevant_APIs[str(tool_id)]
|
438 |
+
retry_tool_id = 0
|
439 |
+
except:
|
440 |
+
retry_tool_id += 1
|
441 |
+
print("Tool ID wrong! Generate tool_id that do not exist!")
|
442 |
+
continue
|
443 |
+
|
444 |
+
# tool_name_list = tool_des_json["tool_name"].split(":")
|
445 |
+
# category_name = tool_name_list[0]
|
446 |
+
# tool_name = tool_name_list[1]
|
447 |
+
api_name = tool_des_json["tool_name"]
|
448 |
+
API_doc = tool_des_json
|
449 |
+
|
450 |
+
while True:
|
451 |
+
try:
|
452 |
+
parameters = {}
|
453 |
+
|
454 |
+
if retry_parameter == 4:
|
455 |
+
restart = 1
|
456 |
+
retry_parameter = 0
|
457 |
+
print("No Para! Restart!")
|
458 |
+
break
|
459 |
+
|
460 |
+
parameter = self.Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
|
461 |
+
if parameter == -1:
|
462 |
+
retry_parameter += 1
|
463 |
+
re_time += 1
|
464 |
+
continue
|
465 |
+
|
466 |
+
if parameter == {}:
|
467 |
+
retry_parameter = 0
|
468 |
+
parameters = {}
|
469 |
+
break
|
470 |
+
|
471 |
+
for key in parameter:
|
472 |
+
value = parameter[key]
|
473 |
+
key = change_name(key)
|
474 |
+
parameters[key] = value
|
475 |
+
|
476 |
+
retry_parameter = 0
|
477 |
+
break
|
478 |
+
|
479 |
+
except:
|
480 |
+
if retry_parameter == 4:
|
481 |
+
parameters = {}
|
482 |
+
retry_parameter = 0
|
483 |
+
restart = 1
|
484 |
+
break
|
485 |
+
retry_parameter += 1
|
486 |
+
print("parameter generation fails, try again!")
|
487 |
+
re_time += 1
|
488 |
+
continue
|
489 |
+
|
490 |
+
|
491 |
+
# api_name = change_name(standardize(api_name))
|
492 |
+
|
493 |
+
if restart != 1:
|
494 |
+
try:
|
495 |
+
observation = self.Call_function(api_name, parameters)
|
496 |
+
except:
|
497 |
+
observation = -1
|
498 |
+
|
499 |
+
if observation == -1:
|
500 |
+
restart = 1
|
501 |
+
observation = str({"error": "", "response": "call API fails"})
|
502 |
+
|
503 |
+
if restart == 1:
|
504 |
+
tool_used_dic[step_num].append(str(tool_id))
|
505 |
+
print('****Try Again For This Step****')
|
506 |
+
re_time += 1
|
507 |
+
restart = 0
|
508 |
+
continue
|
509 |
+
|
510 |
+
|
511 |
+
if len(previous_log) != 0:
|
512 |
+
previous_id = previous_log[-1]["id"]
|
513 |
+
else:
|
514 |
+
previous_id = -1
|
515 |
+
|
516 |
+
current_log["tool"] = str(tool_id)
|
517 |
+
current_log["thought"] = thought
|
518 |
+
current_log["action"] = api_name
|
519 |
+
current_log["action_input"] = parameters
|
520 |
+
current_log["observation"] = observation
|
521 |
+
print("##########Tool Response##########")
|
522 |
+
print(f"{observation}\n")
|
523 |
+
previous_log.append(current_log)
|
524 |
+
|
525 |
+
observation_log = get_observation_log(previous_log)
|
526 |
+
|
527 |
+
answer = self.Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
|
528 |
+
|
529 |
+
previous_log[-1]["answer"] = answer
|
530 |
+
|
531 |
+
history_log_ele = {"thought": thought, "action": api_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
|
532 |
+
history_log.append(history_log_ele)
|
533 |
+
subtask_id += 1
|
534 |
+
|
535 |
+
speak, status = self.Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
|
536 |
+
if speak == -1 and status == -1:
|
537 |
+
step_num += 1
|
538 |
+
continue
|
539 |
+
|
540 |
+
try:
|
541 |
+
if int(status) == 0:
|
542 |
+
hint = speak
|
543 |
+
step_num += 1
|
544 |
+
continue
|
545 |
+
except:
|
546 |
+
step_num += 1
|
547 |
+
continue
|
548 |
+
|
549 |
+
else:
|
550 |
+
return answer, previous_log, re_time, history_log
|
551 |
+
|
552 |
+
def decompose_inference(self, query, query_id):
|
553 |
+
while True:
|
554 |
+
subtasks = self.Planning_Agent.run(question=query, query_id=query_id)
|
555 |
+
if subtasks == -1:
|
556 |
+
continue
|
557 |
+
break
|
558 |
+
task_log = ""
|
559 |
+
history_log = []
|
560 |
+
previous_log_totals = []
|
561 |
+
re_time_total = 0
|
562 |
+
# print(subtasks)
|
563 |
+
relevant_API_list = {}
|
564 |
+
tool_id = 0
|
565 |
+
for api in self.available_tools:
|
566 |
+
tool_name = api["api_name"]
|
567 |
+
ele = {"ID": tool_id, "tool_name": tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
|
568 |
+
relevant_API_list[str(tool_id)] = ele
|
569 |
+
tool_id += 1
|
570 |
+
|
571 |
+
for subtask in subtasks:
|
572 |
+
task_log += f"question: {subtask}\n"
|
573 |
+
answer, previous_log, re_time, previous_log_total = self.inference(task_log, relevant_API_list, subtask, query_id)
|
574 |
+
previous_log_totals.append(previous_log_total)
|
575 |
+
# print(answer)
|
576 |
+
history_log += previous_log
|
577 |
+
re_time_total += re_time
|
578 |
+
task_log += f"answer: {answer}\n"
|
579 |
+
final_answer = self.Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
|
580 |
+
return final_answer, history_log, task_log, re_time_total, previous_log_totals
|
581 |
+
|
582 |
+
def run(self, input, query_id):
|
583 |
+
# result = {}
|
584 |
+
# st = time.time()
|
585 |
+
HPEnv = self.tool_class()
|
586 |
+
self.tool_env = {
|
587 |
+
"BingSearch": HPEnv.BingSearch,
|
588 |
+
"Retrieve": HPEnv.Retrieve,
|
589 |
+
"Lookup": HPEnv.Lookup
|
590 |
+
}
|
591 |
+
final_answer, previous_log, task_log,re_time, previous_log_totals = self.decompose_inference(input, query_id)
|
592 |
+
answer_details, total_steps = get_answer_details(final_answer, previous_log)
|
593 |
+
solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
|
594 |
+
output_file_ele = {
|
595 |
+
"query": input,
|
596 |
+
"restart_time": re_time,
|
597 |
+
"answer": {
|
598 |
+
"method": "decompose_dfs",
|
599 |
+
"total_steps": total_steps,
|
600 |
+
"final_answer": final_answer,
|
601 |
+
"answer_details": answer_details
|
602 |
+
}
|
603 |
+
}
|
604 |
+
|
605 |
+
solution_file_ele = {
|
606 |
+
"query": input,
|
607 |
+
"total_steps": solution_total_steps,
|
608 |
+
"task_log": task_log,
|
609 |
+
"final_answer": final_answer,
|
610 |
+
"answer_path": answer_details,
|
611 |
+
"total_path": solution_tree
|
612 |
+
}
|
613 |
+
return final_answer, output_file_ele, solution_file_ele
|
614 |
+
|
615 |
+
def save_solution(self, output_file_ele, solution_file_ele, idx):
|
616 |
+
file_name = f"{idx}.json"
|
617 |
+
output_file = os.path.join(self.output_dir, file_name)
|
618 |
+
whole_solution_file = os.path.join(self.whole_solution_dir, file_name)
|
619 |
+
with open(output_file, "w") as file:
|
620 |
+
json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
|
621 |
+
with open(whole_solution_file, "w") as file:
|
622 |
+
json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
|
623 |
+
|
624 |
+
def Call_function(self, tool_name, args):
|
625 |
+
try:
|
626 |
+
print(tool_name)
|
627 |
+
func = self.tool_env[tool_name]
|
628 |
+
observation = func(**args)
|
629 |
+
return observation
|
630 |
+
except Exception as e:
|
631 |
+
print(e)
|
632 |
+
print(f"Call function fails")
|
633 |
+
with open('wrong_log.json', 'a+', encoding='utf-8') as f:
|
634 |
+
line = json.dumps({
|
635 |
+
"id": 0,
|
636 |
+
"parameters": args,
|
637 |
+
"tool": tool_name,
|
638 |
+
"wrong": str(e)
|
639 |
+
}, ensure_ascii=False)
|
640 |
+
f.write(line + '\n')
|
641 |
+
return -1
|
642 |
+
|
643 |
+
|
644 |
+
class stream_smurfs_worker:
|
645 |
+
def __init__(self, available_tools, tool_env, llm, method_name, test_set, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, OPENAI_API_KEY, BING_SUBSCRIPT_KEY, WOLFRAMALPH_APP_ID, WEATHER_API_KEYS):
|
646 |
+
#available_tools的格式形如toolbench里面的api_list里的格式,只需要api_name
|
647 |
+
#tool_env是一个工具函数里用来存储工具代码的py文件中的所有函数的字典,key为函数名,value是函数对象
|
648 |
+
self.OPENAI_API_KEY = OPENAI_API_KEY
|
649 |
+
self.BING_SUBSCRIPT_KEY = BING_SUBSCRIPT_KEY
|
650 |
+
self.WOLFRAMALPH_APP_ID = WOLFRAMALPH_APP_ID
|
651 |
+
self.WEATHER_API_KEYS = WEATHER_API_KEYS
|
652 |
+
#print(self.BING_SUBSCRIPT_KEY)
|
653 |
+
self.available_tools = available_tools
|
654 |
+
|
655 |
+
self.output_dir = f"data/{method_name}/{test_set}/answer"
|
656 |
+
|
657 |
+
self.whole_solution_dir = f"data/{method_name}/{test_set}/whole"
|
658 |
+
|
659 |
+
self.logger_dir = f"data/{method_name}/{test_set}/agent_log"
|
660 |
+
|
661 |
+
if not os.path.exists(self.output_dir):
|
662 |
+
os.makedirs(self.output_dir)
|
663 |
+
|
664 |
+
if not os.path.exists(self.whole_solution_dir):
|
665 |
+
os.makedirs(self.whole_solution_dir)
|
666 |
+
|
667 |
+
if not os.path.exists(self.logger_dir):
|
668 |
+
os.makedirs(self.logger_dir)
|
669 |
+
|
670 |
+
self.Answer_Agent = Answer_Agent(llm=llm, logger_dir=self.logger_dir)
|
671 |
+
self.Executor_Agent = Executor_Agent(llm=llm, logger_dir=self.logger_dir)
|
672 |
+
self.Planning_Agent = Planning_Agent(llm=llm, logger_dir=self.logger_dir)
|
673 |
+
self.Verifier_Agent = Verifier_Agent(llm=llm, logger_dir=self.logger_dir)
|
674 |
+
self.tool_env = tool_env
|
675 |
+
|
676 |
+
# def colorful_html(self, task, content, name):
|
677 |
+
# """print out message in different color"""
|
678 |
+
# role_to_color = {
|
679 |
+
# "Answer Agent": "red",
|
680 |
+
# "Executor Agent": "green",
|
681 |
+
# "Planning Agent": "blue",
|
682 |
+
# "Verifier Agent": "yellow",
|
683 |
+
# }
|
684 |
+
# color = role_to_color[name]
|
685 |
+
# html_text = f"<span style='color: {color}'>##########{task}##########<br>{content}<br></span>"
|
686 |
+
# # print(colored(f"##########{task}##########\n{content}\n", role_to_color[name]))
|
687 |
+
# return html_text
|
688 |
+
|
689 |
+
def colorful_html(self, task, content, name):
|
690 |
+
"""print out message in different color"""
|
691 |
+
role_to_color = {
|
692 |
+
"Answer Agent": "red",
|
693 |
+
"Executor Agent": "green",
|
694 |
+
"Planning Agent": "blue",
|
695 |
+
"Verifier Agent": "yellow",
|
696 |
+
}
|
697 |
+
color = role_to_color[name]
|
698 |
+
if task != "Final Answer Generation":
|
699 |
+
html_text = f"""<details><summary>{task}<br></summary>{content}<br></details>"""
|
700 |
+
else:
|
701 |
+
html_text = content
|
702 |
+
# html_text = f"<span style='color: {color}'>##########{task}##########<br>{content}<br></span>"
|
703 |
+
# print(colored(f"##########{task}##########\n{content}\n", role_to_color[name]))
|
704 |
+
return html_text
|
705 |
+
|
706 |
+
def inference(self, query, relevant_APIs, subtask, query_id, max_step=3):
|
707 |
+
# tool_check_num, reason = self.Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
|
708 |
+
# #direct answer
|
709 |
+
# if tool_check_num == 1:
|
710 |
+
# # input_dic = {"task": query}
|
711 |
+
# answer = self.Answer_Agent.run(question=query, task="direct", query_id=query_id)
|
712 |
+
# previous_log = [{"thought": reason, "action": "", "action_input": "", "observation": "", "answer": answer, "tool": "","id": 0}]
|
713 |
+
# history_log = [{"thought": reason, "action": "", "action_input": "", "observation": "", "answer": answer, "previous_id": -1, "id": 0}]
|
714 |
+
# return answer, previous_log, 0, history_log
|
715 |
+
|
716 |
+
previous_log = []
|
717 |
+
history_log = []
|
718 |
+
tool_used_dic = {}
|
719 |
+
relevant_APIs_ids = []
|
720 |
+
for idx in relevant_APIs:
|
721 |
+
ele = relevant_APIs[idx]
|
722 |
+
relevant_APIs_ids.append(str(ele["ID"]))
|
723 |
+
restart_time = 0
|
724 |
+
step_num = 0
|
725 |
+
hint = "Beginnig of the agent. No hint yet"
|
726 |
+
retry_tool_id = 0
|
727 |
+
retry_parameter = 0
|
728 |
+
re_time = 0
|
729 |
+
subtask_id = 0
|
730 |
+
restart = 0
|
731 |
+
while True:
|
732 |
+
if step_num >= max_step:
|
733 |
+
yield("<br><br>Reach steps limits, return answers!<br><br>")
|
734 |
+
answer_log = get_answer_log(history_log)
|
735 |
+
answer, task, agent_name, result = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
736 |
+
yield self.colorful_html(task, result, agent_name)
|
737 |
+
yield answer, previous_log, re_time, history_log
|
738 |
+
yield "stop"
|
739 |
+
|
740 |
+
if step_num not in tool_used_dic.keys():
|
741 |
+
tool_used_dic[step_num] = []
|
742 |
+
|
743 |
+
tool_used = tool_used_dic[step_num]
|
744 |
+
|
745 |
+
tool_list = []
|
746 |
+
for idx in relevant_APIs:
|
747 |
+
ele = relevant_APIs[idx]
|
748 |
+
ID = str(ele['ID'])
|
749 |
+
if ID in tool_used:
|
750 |
+
continue
|
751 |
+
des = ele['description']
|
752 |
+
name = ele["tool_name"]
|
753 |
+
tool_list.append({"ID": ID, "tool_name": name, "description": des})
|
754 |
+
|
755 |
+
if len(tool_list) == 0:
|
756 |
+
if len(previous_log) == 0:
|
757 |
+
answer_log = get_answer_log(history_log)
|
758 |
+
partial_answer, task, agent_name, result = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
|
759 |
+
answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
|
760 |
+
yield self.colorful_html(task, result, agent_name)
|
761 |
+
yield answer, previous_log, re_time, history_log
|
762 |
+
yield "stop"
|
763 |
+
else:
|
764 |
+
delete_log = previous_log.pop()
|
765 |
+
tool_used_dic[step_num] = []
|
766 |
+
step_num -= 1
|
767 |
+
tool_used_dic[step_num].append(delete_log["tool"])
|
768 |
+
restart_time += 1
|
769 |
+
re_time += 1
|
770 |
+
continue
|
771 |
+
|
772 |
+
current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "","id": subtask_id}
|
773 |
+
|
774 |
+
answer_log = get_answer_log(previous_log)
|
775 |
+
|
776 |
+
if retry_tool_id == 4:
|
777 |
+
tool_id = tool_list[0]["ID"]
|
778 |
+
tool_list = tool_list[0]
|
779 |
+
thought, task, agent_name, result = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
780 |
+
yield self.colorful_html(task, result, agent_name)
|
781 |
+
else:
|
782 |
+
thought, task, agent_name, result = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
|
783 |
+
yield self.colorful_html(task, result, agent_name)
|
784 |
+
tool_id, task, agent_name, result = self.Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
|
785 |
+
yield self.colorful_html(task, result, agent_name)
|
786 |
+
|
787 |
+
try:
|
788 |
+
tool_id = int(tool_id)
|
789 |
+
tool_id = str(tool_id)
|
790 |
+
if tool_id not in relevant_APIs_ids:
|
791 |
+
re_time += 1
|
792 |
+
retry_tool_id += 1
|
793 |
+
yield("Tool ID wrong! Generate tool_id that do not exist!<br>")
|
794 |
+
continue
|
795 |
+
tool_des_json = relevant_APIs[str(tool_id)]
|
796 |
+
retry_tool_id = 0
|
797 |
+
except:
|
798 |
+
retry_tool_id += 1
|
799 |
+
yield("Tool ID wrong! Generate tool_id that do not exist!<br>")
|
800 |
+
continue
|
801 |
+
|
802 |
+
# tool_name_list = tool_des_json["tool_name"].split(":")
|
803 |
+
# category_name = tool_name_list[0]
|
804 |
+
# tool_name = tool_name_list[1]
|
805 |
+
api_name = tool_des_json["tool_name"]
|
806 |
+
API_doc = tool_des_json
|
807 |
+
|
808 |
+
while True:
|
809 |
+
try:
|
810 |
+
parameters = {}
|
811 |
+
|
812 |
+
if retry_parameter == 4:
|
813 |
+
restart = 1
|
814 |
+
retry_parameter = 0
|
815 |
+
yield("No Para! Restart!<br>")
|
816 |
+
break
|
817 |
+
|
818 |
+
parameter, task, agent_name, result = self.Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
|
819 |
+
yield self.colorful_html(task, result, agent_name)
|
820 |
+
|
821 |
+
if parameter == -1:
|
822 |
+
retry_parameter += 1
|
823 |
+
re_time += 1
|
824 |
+
continue
|
825 |
+
|
826 |
+
if parameter == {}:
|
827 |
+
retry_parameter = 0
|
828 |
+
parameters = {}
|
829 |
+
break
|
830 |
+
|
831 |
+
for key in parameter:
|
832 |
+
value = parameter[key]
|
833 |
+
key = change_name(key)
|
834 |
+
parameters[key] = value
|
835 |
+
|
836 |
+
retry_parameter = 0
|
837 |
+
break
|
838 |
+
|
839 |
+
except:
|
840 |
+
if retry_parameter == 4:
|
841 |
+
parameters = {}
|
842 |
+
retry_parameter = 0
|
843 |
+
restart = 1
|
844 |
+
break
|
845 |
+
retry_parameter += 1
|
846 |
+
yield("parameter generation fails, try again!<br>")
|
847 |
+
re_time += 1
|
848 |
+
continue
|
849 |
+
|
850 |
+
|
851 |
+
# api_name = change_name(standardize(api_name))
|
852 |
+
|
853 |
+
if restart != 1:
|
854 |
+
try:
|
855 |
+
observation = self.Call_function(api_name, parameters)
|
856 |
+
except:
|
857 |
+
observation = -1
|
858 |
+
|
859 |
+
if observation == -1:
|
860 |
+
restart = 1
|
861 |
+
observation = str({"error": "", "response": "call API fails"})
|
862 |
+
|
863 |
+
if restart == 1:
|
864 |
+
tool_used_dic[step_num].append(str(tool_id))
|
865 |
+
yield('****Try Again For This Step****<br>')
|
866 |
+
re_time += 1
|
867 |
+
restart = 0
|
868 |
+
continue
|
869 |
+
|
870 |
+
|
871 |
+
if len(previous_log) != 0:
|
872 |
+
previous_id = previous_log[-1]["id"]
|
873 |
+
else:
|
874 |
+
previous_id = -1
|
875 |
+
|
876 |
+
current_log["tool"] = str(tool_id)
|
877 |
+
current_log["thought"] = thought
|
878 |
+
current_log["action"] = api_name
|
879 |
+
current_log["action_input"] = parameters
|
880 |
+
current_log["observation"] = observation
|
881 |
+
previous_log.append(current_log)
|
882 |
+
yield(f"<details><summary>Tool Response</summary>{observation}<br></details>")
|
883 |
+
# print(f"{observation}\n")
|
884 |
+
observation_log = get_observation_log(previous_log)
|
885 |
+
|
886 |
+
answer, task, agent_name, result = self.Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
|
887 |
+
yield self.colorful_html(task, result, agent_name)
|
888 |
+
previous_log[-1]["answer"] = answer
|
889 |
+
|
890 |
+
history_log_ele = {"thought": thought, "action": api_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
|
891 |
+
history_log.append(history_log_ele)
|
892 |
+
subtask_id += 1
|
893 |
+
|
894 |
+
speak, status, task, agent_name, result = self.Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
|
895 |
+
yield self.colorful_html(task, result, agent_name)
|
896 |
+
|
897 |
+
if speak == -1 and status == -1:
|
898 |
+
step_num += 1
|
899 |
+
continue
|
900 |
+
|
901 |
+
try:
|
902 |
+
if int(status) == 0:
|
903 |
+
hint = speak
|
904 |
+
step_num += 1
|
905 |
+
continue
|
906 |
+
except:
|
907 |
+
step_num += 1
|
908 |
+
continue
|
909 |
+
|
910 |
+
else:
|
911 |
+
yield answer, previous_log, re_time, history_log
|
912 |
+
yield "stop"
|
913 |
+
|
914 |
+
def run(self, query, query_id):
|
915 |
+
output = ""
|
916 |
+
count = 0
|
917 |
+
while True:
|
918 |
+
subtasks, task, agent_name, result = self.Planning_Agent.run(question=query, query_id=query_id)
|
919 |
+
if subtasks == -1:
|
920 |
+
count += 1
|
921 |
+
if count >= 1:
|
922 |
+
yield "Task Decompose Fails! Your OpenAI Key can not function correctly."
|
923 |
+
raise RuntimeError
|
924 |
+
continue
|
925 |
+
|
926 |
+
break
|
927 |
+
output += self.colorful_html(task, result, agent_name)
|
928 |
+
yield output
|
929 |
+
task_log = ""
|
930 |
+
history_log = []
|
931 |
+
previous_log_totals = []
|
932 |
+
re_time_total = 0
|
933 |
+
# print(subtasks)
|
934 |
+
relevant_API_list = {}
|
935 |
+
tool_id = 0
|
936 |
+
for api in self.available_tools:
|
937 |
+
tool_name = api["api_name"]
|
938 |
+
ele = {"ID": tool_id, "tool_name": tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
|
939 |
+
relevant_API_list[str(tool_id)] = ele
|
940 |
+
tool_id += 1
|
941 |
+
|
942 |
+
for subtask in subtasks:
|
943 |
+
sub_output = ""
|
944 |
+
output_ele = "<details><summary>subtask: {subtask}</summary>{sub_output}</details>"
|
945 |
+
task_log += f"question: {subtask}\n"
|
946 |
+
inference_generator = self.inference(task_log, relevant_API_list, subtask, query_id)
|
947 |
+
# old = None
|
948 |
+
while True:
|
949 |
+
try:
|
950 |
+
result = next(inference_generator)
|
951 |
+
# if result == "stop":
|
952 |
+
# break
|
953 |
+
if isinstance(result, str):
|
954 |
+
sub_output += result
|
955 |
+
sub_out = output_ele.format(subtask=subtask, sub_output=sub_output)
|
956 |
+
output_ins = output + sub_out
|
957 |
+
yield output_ins
|
958 |
+
else:
|
959 |
+
break
|
960 |
+
except StopIteration:
|
961 |
+
break
|
962 |
+
output += sub_out
|
963 |
+
answer, previous_log, re_time, previous_log_total = result
|
964 |
+
#answer, previous_log, re_time, previous_log_total = self.inference(task_log, relevant_API_list, subtask, query_id)
|
965 |
+
previous_log_totals.append(previous_log_total)
|
966 |
+
# print(answer)
|
967 |
+
history_log += previous_log
|
968 |
+
re_time_total += re_time
|
969 |
+
task_log += f"answer: {answer}\n"
|
970 |
+
final_answer, task, agent_name, result = self.Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
|
971 |
+
output += self.colorful_html(task, result, agent_name)
|
972 |
+
yield output
|
973 |
+
# return final_answer, history_log, task_log, re_time_total, previous_log_totals
|
974 |
+
|
975 |
+
# def run(self, input, query_id):
|
976 |
+
# # result = {}
|
977 |
+
# # st = time.time()
|
978 |
+
# final_answer, previous_log, task_log,re_time, previous_log_totals = self.decompose_inference(input, query_id)
|
979 |
+
# answer_details, total_steps = get_answer_details(final_answer, previous_log)
|
980 |
+
# solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
|
981 |
+
# output_file_ele = {
|
982 |
+
# "query": input,
|
983 |
+
# "restart_time": re_time,
|
984 |
+
# "answer": {
|
985 |
+
# "method": "decompose_dfs",
|
986 |
+
# "total_steps": total_steps,
|
987 |
+
# "final_answer": final_answer,
|
988 |
+
# "answer_details": answer_details
|
989 |
+
# }
|
990 |
+
# }
|
991 |
+
|
992 |
+
# solution_file_ele = {
|
993 |
+
# "query": input,
|
994 |
+
# "total_steps": solution_total_steps,
|
995 |
+
# "task_log": task_log,
|
996 |
+
# "final_answer": final_answer,
|
997 |
+
# "answer_path": answer_details,
|
998 |
+
# "total_path": solution_tree
|
999 |
+
# }
|
1000 |
+
# return final_answer, output_file_ele, solution_file_ele
|
1001 |
+
|
1002 |
+
def save_solution(self, output_file_ele, solution_file_ele, idx):
|
1003 |
+
file_name = f"{idx}.json"
|
1004 |
+
output_file = os.path.join(self.output_dir, file_name)
|
1005 |
+
whole_solution_file = os.path.join(self.whole_solution_dir, file_name)
|
1006 |
+
with open(output_file, "w") as file:
|
1007 |
+
json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
|
1008 |
+
with open(whole_solution_file, "w") as file:
|
1009 |
+
json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
|
1010 |
+
|
1011 |
+
def Call_function(self, tool_name, args):
|
1012 |
+
try:
|
1013 |
+
print(tool_name)
|
1014 |
+
# print(self.BING_SUBSCRIPT_KEY)
|
1015 |
+
if tool_name == "bing_search" or tool_name == "BingSearch":
|
1016 |
+
args["key"] = self.BING_SUBSCRIPT_KEY
|
1017 |
+
if tool_name == "forecast_weather" or tool_name == "get_weather_today":
|
1018 |
+
args["KEY"] = self.WEATHER_API_KEYS
|
1019 |
+
if tool_name == "getWolframAlphaResults":
|
1020 |
+
args["APPID"] = self.WOLFRAMALPH_APP_ID
|
1021 |
+
|
1022 |
+
print(args)
|
1023 |
+
func = self.tool_env[tool_name]
|
1024 |
+
observation = func(**args)
|
1025 |
+
return observation
|
1026 |
+
except Exception as e:
|
1027 |
+
print(e)
|
1028 |
+
print(f"Call function fails")
|
1029 |
+
with open('wrong_log.json', 'a+', encoding='utf-8') as f:
|
1030 |
+
line = json.dumps({
|
1031 |
+
"id": 0,
|
1032 |
+
"parameters": args,
|
1033 |
+
"tool": tool_name,
|
1034 |
+
"wrong": str(e)
|
1035 |
+
}, ensure_ascii=False)
|
1036 |
+
f.write(line + '\n')
|
1037 |
+
return -1
|
1038 |
+
|
1039 |
+
|
1040 |
+
|
Smurfs/inference/utils.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# — coding: utf-8 –
|
2 |
+
import json
|
3 |
+
import re
|
4 |
+
import os
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
def get_white_list(tool_root_dir):
|
8 |
+
# print(tool_root_dir)
|
9 |
+
white_list_dir = os.path.join(tool_root_dir)
|
10 |
+
white_list = {}
|
11 |
+
for cate in tqdm(os.listdir(white_list_dir)):
|
12 |
+
if not os.path.isdir(os.path.join(white_list_dir,cate)):
|
13 |
+
continue
|
14 |
+
for file in os.listdir(os.path.join(white_list_dir,cate)):
|
15 |
+
if not file.endswith(".json"):
|
16 |
+
continue
|
17 |
+
standard_tool_name = file.split(".")[0]
|
18 |
+
# print(standard_tool_name)
|
19 |
+
with open(os.path.join(white_list_dir,cate,file)) as reader:
|
20 |
+
js_data = json.load(reader)
|
21 |
+
origin_tool_name = js_data["tool_name"]
|
22 |
+
white_list[standardize(origin_tool_name)] = {"description": js_data["tool_description"], "standard_tool_name": standard_tool_name}
|
23 |
+
return white_list
|
24 |
+
|
25 |
+
def build_index(base_path):
|
26 |
+
index = {}
|
27 |
+
for root, dirs, files in os.walk(base_path):
|
28 |
+
for dir_name in dirs:
|
29 |
+
if dir_name not in index:
|
30 |
+
index[dir_name] = []
|
31 |
+
index[dir_name].append(root)
|
32 |
+
return index
|
33 |
+
|
34 |
+
|
35 |
+
def change_name(name):
|
36 |
+
change_list = ["from", "class", "return", "false", "true", "id", "and", "", "ID"]
|
37 |
+
if name in change_list:
|
38 |
+
name = "is_" + name.lower()
|
39 |
+
return name
|
40 |
+
|
41 |
+
|
42 |
+
def standardize(string):
|
43 |
+
res = re.compile("[^\\u4e00-\\u9fa5^a-z^A-Z^0-9^_]")
|
44 |
+
string = res.sub("_", string)
|
45 |
+
string = re.sub(r"(_)\1+", "_", string).lower()
|
46 |
+
while True:
|
47 |
+
if len(string) == 0:
|
48 |
+
return string
|
49 |
+
if string[0] == "_":
|
50 |
+
string = string[1:]
|
51 |
+
else:
|
52 |
+
break
|
53 |
+
while True:
|
54 |
+
if len(string) == 0:
|
55 |
+
return string
|
56 |
+
if string[-1] == "_":
|
57 |
+
string = string[:-1]
|
58 |
+
else:
|
59 |
+
break
|
60 |
+
if string[0].isdigit():
|
61 |
+
string = "get_" + string
|
62 |
+
return string
|
63 |
+
|
64 |
+
def get_answer_log(log):
|
65 |
+
if log == []:
|
66 |
+
return "Beginnig of the agent. No log yet"
|
67 |
+
answer_logs = []
|
68 |
+
for ele in log:
|
69 |
+
answer_log = {"thought": "", "answer": ""}
|
70 |
+
answer_log["thought"] = ele["thought"]
|
71 |
+
answer_log["answer"] = ele["answer"]
|
72 |
+
answer_logs.append(answer_log)
|
73 |
+
return answer_logs
|
74 |
+
|
75 |
+
def get_observation_log(log):
|
76 |
+
if log == []:
|
77 |
+
return ""
|
78 |
+
answer_logs = []
|
79 |
+
for i, ele in enumerate(log):
|
80 |
+
if i == len(log)-1:
|
81 |
+
answer_log = {"thought": "", "observation": ""}
|
82 |
+
answer_log["thought"] = ele["thought"]
|
83 |
+
answer_log["observation"] = ele["observation"]
|
84 |
+
answer_logs.append(answer_log)
|
85 |
+
else:
|
86 |
+
answer_log = {"thought": "", "answer": ""}
|
87 |
+
answer_log["thought"] = ele["thought"]
|
88 |
+
answer_log["answer"] = ele["answer"]
|
89 |
+
answer_logs.append(answer_log)
|
90 |
+
return answer_logs
|
91 |
+
|
92 |
+
def build_tree(previous_log_totals, task_log):
|
93 |
+
|
94 |
+
total_root_list = []
|
95 |
+
total_total_steps = 0
|
96 |
+
task_log_list = task_log.split("question: ")[1:]
|
97 |
+
for i in range(len(task_log_list)):
|
98 |
+
task_log_list[i] = task_log_list[i].split("answer: ")
|
99 |
+
|
100 |
+
for j, previous_log_total in enumerate(previous_log_totals):
|
101 |
+
|
102 |
+
if previous_log_total == None:
|
103 |
+
|
104 |
+
answer_detail = {
|
105 |
+
"role": "plan_global",
|
106 |
+
"message": {
|
107 |
+
"subtask": task_log_list[j][0],
|
108 |
+
"subtask_answer": task_log_list[j][1]
|
109 |
+
},
|
110 |
+
"total_steps": 0,
|
111 |
+
"next": []
|
112 |
+
}
|
113 |
+
total_root_list.append(answer_detail)
|
114 |
+
continue
|
115 |
+
|
116 |
+
next_list = []
|
117 |
+
root_list = []
|
118 |
+
total_steps = 0
|
119 |
+
for i in range(len(previous_log_total)):
|
120 |
+
|
121 |
+
current_log = previous_log_total[i]
|
122 |
+
|
123 |
+
tool_call_list = []
|
124 |
+
api_name = current_log["action"]
|
125 |
+
parameter = current_log["action_input"]
|
126 |
+
response = current_log["observation"]
|
127 |
+
next_ele = {
|
128 |
+
"role": "tool",
|
129 |
+
"message": {
|
130 |
+
"name": api_name,
|
131 |
+
"arguments": parameter,
|
132 |
+
"response": response
|
133 |
+
},
|
134 |
+
"next": []
|
135 |
+
}
|
136 |
+
tool_call_list.append(next_ele)
|
137 |
+
total_steps += 1
|
138 |
+
if len(tool_call_list) > 1:
|
139 |
+
for k in range(len(tool_call_list)-2, -1, -1):
|
140 |
+
tool_call_list[k]["next"].append(tool_call_list[k+1])
|
141 |
+
next_list.append(tool_call_list[0])
|
142 |
+
total_total_steps += total_steps
|
143 |
+
|
144 |
+
for i in range(len(next_list)-1, -1, -1):
|
145 |
+
current_log = next_list[i]
|
146 |
+
current_log_pre_id = previous_log_total[i]["previous_id"]
|
147 |
+
if current_log_pre_id == -1:
|
148 |
+
# print(current_log)
|
149 |
+
root_list.append(current_log)
|
150 |
+
else:
|
151 |
+
next_list[current_log_pre_id]["next"].append(current_log)
|
152 |
+
answer_detail = {
|
153 |
+
"role": "plan_global",
|
154 |
+
"message": {
|
155 |
+
"subtask": task_log_list[j][0],
|
156 |
+
"subtask_answer": task_log_list[j][1]
|
157 |
+
},
|
158 |
+
"total_steps": total_steps,
|
159 |
+
"next": root_list
|
160 |
+
}
|
161 |
+
total_root_list.append(answer_detail)
|
162 |
+
|
163 |
+
answer_details = {
|
164 |
+
"role": "system",
|
165 |
+
"message": "",
|
166 |
+
"next": [
|
167 |
+
{
|
168 |
+
"role": "user",
|
169 |
+
"message": "",
|
170 |
+
"next": total_root_list
|
171 |
+
}
|
172 |
+
]
|
173 |
+
}
|
174 |
+
return answer_details, total_total_steps
|
175 |
+
|
176 |
+
def get_answer_details(final_answer, previous_log):
|
177 |
+
|
178 |
+
next_list = []
|
179 |
+
total_steps = 0
|
180 |
+
for i in range(len(previous_log)):
|
181 |
+
current_log = previous_log[i]
|
182 |
+
if not isinstance(current_log, dict):
|
183 |
+
next_ele = {
|
184 |
+
"role": "assistant",
|
185 |
+
"message": current_log,
|
186 |
+
"next": []
|
187 |
+
}
|
188 |
+
next_list.append(next_ele)
|
189 |
+
total_steps += 1
|
190 |
+
continue
|
191 |
+
|
192 |
+
api_name = current_log["action"]
|
193 |
+
parameter = current_log["action_input"]
|
194 |
+
response = current_log["observation"]
|
195 |
+
next_ele = {
|
196 |
+
"role": "tool",
|
197 |
+
"message": {
|
198 |
+
"name": api_name,
|
199 |
+
"arguments": parameter,
|
200 |
+
"response": response
|
201 |
+
},
|
202 |
+
"next": []
|
203 |
+
}
|
204 |
+
next_list.append(next_ele)
|
205 |
+
total_steps += 1
|
206 |
+
answer_ele = {
|
207 |
+
"role": "tool",
|
208 |
+
"message": {
|
209 |
+
"name": "Finish",
|
210 |
+
"arguments": {
|
211 |
+
"return_type": "give_answer",
|
212 |
+
"final_answer": final_answer
|
213 |
+
},
|
214 |
+
"response": ""
|
215 |
+
},
|
216 |
+
"next": []
|
217 |
+
}
|
218 |
+
next_list.append(answer_ele)
|
219 |
+
for i in range(len(next_list)-2, -1, -1):
|
220 |
+
next_list[i]["next"].append(next_list[i+1])
|
221 |
+
next_result = next_list[0]
|
222 |
+
answer_details = {
|
223 |
+
"role": "system",
|
224 |
+
"message": "",
|
225 |
+
"next": [
|
226 |
+
{
|
227 |
+
"role": "user",
|
228 |
+
"message": "",
|
229 |
+
"next": [next_result]
|
230 |
+
}
|
231 |
+
]
|
232 |
+
}
|
233 |
+
return answer_details, total_steps
|
234 |
+
|
235 |
+
def contain(candidate_list, white_list):
|
236 |
+
output = []
|
237 |
+
for cand in candidate_list:
|
238 |
+
if cand not in white_list.keys():
|
239 |
+
return False
|
240 |
+
output.append(white_list[cand])
|
241 |
+
return output
|
242 |
+
|
243 |
+
# def fetch_api_json(api_list, tool_root_dir):
|
244 |
+
# data_dict = {"api_list":[]}
|
245 |
+
# for item in api_list:
|
246 |
+
# cate_name = item["category_name"]
|
247 |
+
# tool_name = standardize(item["tool_name"])
|
248 |
+
# api_name = change_name(standardize(item["api_name"]))
|
249 |
+
# tool_json = json.load(open(os.path.join(tool_root_dir, cate_name, tool_name + ".json"), "r"))
|
250 |
+
# append_flag = False
|
251 |
+
# api_dict_names = []
|
252 |
+
# for api_dict in tool_json["api_list"]:
|
253 |
+
# api_dict_names.append(api_dict["name"])
|
254 |
+
# pure_api_name = change_name(standardize(api_dict["name"]))
|
255 |
+
# if pure_api_name != api_name:
|
256 |
+
# continue
|
257 |
+
# api_json = {}
|
258 |
+
# api_json["category_name"] = cate_name
|
259 |
+
# api_json["api_name"] = api_dict["name"]
|
260 |
+
# api_json["api_description"] = api_dict["description"]
|
261 |
+
# api_json["required_parameters"] = api_dict["required_parameters"]
|
262 |
+
# api_json["optional_parameters"] = api_dict["optional_parameters"]
|
263 |
+
# api_json["tool_name"] = tool_json["tool_name"]
|
264 |
+
# data_dict["api_list"].append(api_json)
|
265 |
+
# append_flag = True
|
266 |
+
# break
|
267 |
+
# if not append_flag:
|
268 |
+
# print(api_name, api_dict_names)
|
269 |
+
# return data_dict
|
270 |
+
|
271 |
+
# def api_json_to_openai_json(api_json,standard_tool_name):
|
272 |
+
# description_max_length=256
|
273 |
+
# templete = {
|
274 |
+
# "name": "",
|
275 |
+
# "description": "",
|
276 |
+
# "parameters": {
|
277 |
+
# "type": "object",
|
278 |
+
# "properties": {
|
279 |
+
# },
|
280 |
+
# "required": [],
|
281 |
+
# "optional": [],
|
282 |
+
# }
|
283 |
+
# }
|
284 |
+
|
285 |
+
# map_type = {
|
286 |
+
# "NUMBER": "integer",
|
287 |
+
# "STRING": "string",
|
288 |
+
# "BOOLEAN": "boolean"
|
289 |
+
# }
|
290 |
+
|
291 |
+
# pure_api_name = change_name(standardize(api_json["api_name"]))
|
292 |
+
# templete["name"] = pure_api_name+ f"_for_{standard_tool_name}"
|
293 |
+
# templete["name"] = templete["name"][-64:]
|
294 |
+
|
295 |
+
# templete["description"] = f"This is the subfunction for tool \"{standard_tool_name}\", you can use this tool."
|
296 |
+
|
297 |
+
# if api_json["api_description"].strip() != "":
|
298 |
+
# tuncated_description = api_json['api_description'].strip().replace(api_json['api_name'],templete['name'])[:description_max_length]
|
299 |
+
# templete["description"] = templete["description"] + f"The description of this function is: \"{tuncated_description}\""
|
300 |
+
# if "required_parameters" in api_json.keys() and len(api_json["required_parameters"]) > 0:
|
301 |
+
# for para in api_json["required_parameters"]:
|
302 |
+
# name = standardize(para["name"])
|
303 |
+
# name = change_name(name)
|
304 |
+
# if para["type"] in map_type:
|
305 |
+
# param_type = map_type[para["type"]]
|
306 |
+
# else:
|
307 |
+
# param_type = "string"
|
308 |
+
# prompt = {
|
309 |
+
# "type":param_type,
|
310 |
+
# "description":para["description"][:description_max_length],
|
311 |
+
# }
|
312 |
+
|
313 |
+
# default_value = para['default']
|
314 |
+
# if len(str(default_value)) != 0:
|
315 |
+
# prompt = {
|
316 |
+
# "type":param_type,
|
317 |
+
# "description":para["description"][:description_max_length],
|
318 |
+
# "example_value": default_value
|
319 |
+
# }
|
320 |
+
# else:
|
321 |
+
# prompt = {
|
322 |
+
# "type":param_type,
|
323 |
+
# "description":para["description"][:description_max_length]
|
324 |
+
# }
|
325 |
+
|
326 |
+
# templete["parameters"]["properties"][name] = prompt
|
327 |
+
# templete["parameters"]["required"].append(name)
|
328 |
+
# for para in api_json["optional_parameters"]:
|
329 |
+
# name = standardize(para["name"])
|
330 |
+
# name = change_name(name)
|
331 |
+
# if para["type"] in map_type:
|
332 |
+
# param_type = map_type[para["type"]]
|
333 |
+
# else:
|
334 |
+
# param_type = "string"
|
335 |
+
|
336 |
+
# default_value = para['default']
|
337 |
+
# if len(str(default_value)) != 0:
|
338 |
+
# prompt = {
|
339 |
+
# "type":param_type,
|
340 |
+
# "description":para["description"][:description_max_length],
|
341 |
+
# "example_value": default_value
|
342 |
+
# }
|
343 |
+
# else:
|
344 |
+
# prompt = {
|
345 |
+
# "type":param_type,
|
346 |
+
# "description":para["description"][:description_max_length]
|
347 |
+
# }
|
348 |
+
|
349 |
+
# templete["parameters"]["properties"][name] = prompt
|
350 |
+
# templete["parameters"]["optional"].append(name)
|
351 |
+
|
352 |
+
# return templete, api_json["category_name"], pure_api_name
|
353 |
+
|
354 |
+
|
355 |
+
|
356 |
+
test_sets = ["G2_category"]
|
Smurfs/model/__init__.py
ADDED
File without changes
|