陈君至 committed on
Commit ec21955 · 1 Parent(s): 0a411b5

Add application file

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Smurfs/.DS_Store +0 -0
  2. Smurfs/__init__.py +0 -0
  3. Smurfs/__pycache__/__init__.cpython-39.pyc +0 -0
  4. Smurfs/agents/__init__.py +0 -0
  5. Smurfs/agents/__pycache__/__init__.cpython-39.pyc +0 -0
  6. Smurfs/agents/__pycache__/base.cpython-39.pyc +0 -0
  7. Smurfs/agents/answer_agent/__pycache__/answer.cpython-39.pyc +0 -0
  8. Smurfs/agents/answer_agent/__pycache__/prompt.cpython-39.pyc +0 -0
  9. Smurfs/agents/answer_agent/answer.py +303 -0
  10. Smurfs/agents/answer_agent/prompt.py +73 -0
  11. Smurfs/agents/base.py +51 -0
  12. Smurfs/agents/executor_agent/__pycache__/__init__.cpython-39.pyc +0 -0
  13. Smurfs/agents/executor_agent/__pycache__/executor.cpython-39.pyc +0 -0
  14. Smurfs/agents/executor_agent/__pycache__/prompt.cpython-39.pyc +0 -0
  15. Smurfs/agents/executor_agent/executor.py +246 -0
  16. Smurfs/agents/executor_agent/prompt.py +58 -0
  17. Smurfs/agents/memory_agent/memory_agent.py +0 -0
  18. Smurfs/agents/memory_agent/prompt.py +16 -0
  19. Smurfs/agents/planning_agent/__pycache__/planner.cpython-39.pyc +0 -0
  20. Smurfs/agents/planning_agent/__pycache__/prompt.cpython-39.pyc +0 -0
  21. Smurfs/agents/planning_agent/planner.py +137 -0
  22. Smurfs/agents/planning_agent/prompt.py +44 -0
  23. Smurfs/agents/verifier_agent/__pycache__/prompt.cpython-39.pyc +0 -0
  24. Smurfs/agents/verifier_agent/__pycache__/verifier.cpython-39.pyc +0 -0
  25. Smurfs/agents/verifier_agent/prompt.py +25 -0
  26. Smurfs/agents/verifier_agent/verifier.py +90 -0
  27. Smurfs/data/.DS_Store +0 -0
  28. Smurfs/data/__init__.py +0 -0
  29. Smurfs/data/post_process.py +65 -0
  30. Smurfs/data/utils.py +53 -0
  31. Smurfs/deploy/__init__.py +3 -0
  32. Smurfs/deploy/__pycache__/__init__.cpython-39.pyc +0 -0
  33. Smurfs/deploy/cli_inference.py +58 -0
  34. Smurfs/deploy/gradio_inference.py +223 -0
  35. Smurfs/eval/hotpot_qa/__pycache__/utils.cpython-39.pyc +0 -0
  36. Smurfs/eval/hotpot_qa/post_process.py +109 -0
  37. Smurfs/eval/hotpot_qa/run_eval.py +395 -0
  38. Smurfs/eval/hotpot_qa/utils.py +117 -0
  39. Smurfs/inference/__init__.py +0 -0
  40. Smurfs/inference/__pycache__/__init__.cpython-39.pyc +0 -0
  41. Smurfs/inference/__pycache__/inference.cpython-39.pyc +0 -0
  42. Smurfs/inference/__pycache__/server.cpython-39.pyc +0 -0
  43. Smurfs/inference/__pycache__/smurfs_worker.cpython-39.pyc +0 -0
  44. Smurfs/inference/__pycache__/utils.cpython-39.pyc +0 -0
  45. Smurfs/inference/functioncall_inference.py +533 -0
  46. Smurfs/inference/inference.py +527 -0
  47. Smurfs/inference/server.py +179 -0
  48. Smurfs/inference/smurfs_worker.py +1040 -0
  49. Smurfs/inference/utils.py +356 -0
  50. Smurfs/model/__init__.py +0 -0
Smurfs/.DS_Store ADDED
Binary file (6.15 kB).
 
Smurfs/__init__.py ADDED
File without changes
Smurfs/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (160 Bytes).
 
Smurfs/agents/__init__.py ADDED
File without changes
Smurfs/agents/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (167 Bytes).
 
Smurfs/agents/__pycache__/base.cpython-39.pyc ADDED
Binary file (2.07 kB).
 
Smurfs/agents/answer_agent/__pycache__/answer.cpython-39.pyc ADDED
Binary file (6.1 kB).
 
Smurfs/agents/answer_agent/__pycache__/prompt.cpython-39.pyc ADDED
Binary file (4.72 kB).
 
Smurfs/agents/answer_agent/answer.py ADDED
@@ -0,0 +1,303 @@
from Smurfs.agents.base import BaseAgent
from Smurfs.agents.answer_agent.prompt import answer_generation_direct_prompt, answer_generation_prompt, final_answer_generation_prompt, tool_check_prompt, hotpot_answer_parser_prompt
from typing import Any

class answer_agent(BaseAgent):
    direct_prompt: Any
    answer_prompt: Any
    final_prompt: Any
    tool_check_prompt: Any
    HP_parser_prompt: Any

    def __init__(self, *args, **kwargs):
        direct_prompt = answer_generation_direct_prompt
        answer_prompt = answer_generation_prompt
        final_prompt = final_answer_generation_prompt
        check_prompt = tool_check_prompt
        name = "Answer Agent"
        kwargs.update({"direct_prompt": direct_prompt})
        kwargs.update({"answer_prompt": answer_prompt})
        kwargs.update({"final_prompt": final_prompt})
        kwargs.update({"tool_check_prompt": check_prompt})
        kwargs.update({"name": name})
        kwargs.update({"HP_parser_prompt": hotpot_answer_parser_prompt})
        super().__init__(
            *args,
            **kwargs,
        )

    def run(self, query_id, task, **kwargs):
        """agent run one step"""
        if task == "direct":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            result = self.llm.prediction(message)
            self.log(query_id, result)
            self.colorful_print(result, "Answer Directly")
            return result

        elif task == "answer":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    self.log(query_id, result)
                    self.colorful_print(result, "Answer Generation")
                    break
                except Exception as e:
                    print(f"answer generation fails: {e}")
                    self.log(query_id, f"answer generation fails: {e}")
                    if ind > 2:
                        return -1
                    ind += 1
                    continue
            return result

        elif task == "final":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    self.log(query_id, result)
                    self.colorful_print(result, "Final Answer Generation")
                    break
                except Exception as e:
                    print(f"answer generation fails: {e}")
                    self.log(query_id, f"answer generation fails: {e}")
                    if ind > 2:
                        return -1
                    ind += 1
                    continue
            return result

        elif task == "tool_check":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    result = eval(result)
                    a = result["Reason"]
                    b = result["Choice"]
                    self.log(query_id, result)
                    self.colorful_print(result, "Tool Check")
                    if 'yes' in b.lower():
                        return -1, a
                    else:
                        return 1, a
                except Exception as e:
                    print(f"tool check fails: {e}")
                    self.log(query_id, f"tool check fails: {e}")
                    if ind > self.max_retry:
                        return -1, 'fail'
                    ind += 1
                    continue

        elif task == "parse":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    # result = eval(result)
                    # a = result["Reason"]
                    # b = result["Choice"]
                    self.colorful_print(result, "Parse Answer Hotpot QA")
                    self.log(query_id, result)
                    # if 'yes' in b.lower():
                    #     return result, -1
                    # else:
                    #     return result, 1
                    return result
                except Exception as e:
                    print(f"answer parse fails: {e}")
                    self.log(query_id, f"answer parse fails: {e}")
                    if ind > self.max_retry:
                        return "answer parse fails"
                    ind += 1
                    continue


    def get_memory(self, **kwargs):
        """get relevant memory and add it to agent's memory"""
        self.memory = kwargs

    def get_prompt(self, task):
        """get the prompt for the agent"""
        if task == "direct":
            agent_prompt = self.direct_prompt.format(**self.memory)
        elif task == "answer":
            agent_prompt = self.answer_prompt.format(**self.memory)
        elif task == "final":
            agent_prompt = self.final_prompt.format(**self.memory)
        elif task == "tool_check":
            agent_prompt = self.tool_check_prompt.format(**self.memory)
        elif task == "parse":
            agent_prompt = self.HP_parser_prompt.format(**self.memory)
        return agent_prompt

class stream_answer_agent(BaseAgent):
    direct_prompt: Any
    answer_prompt: Any
    final_prompt: Any
    tool_check_prompt: Any
    HP_parser_prompt: Any

    def __init__(self, *args, **kwargs):
        direct_prompt = answer_generation_direct_prompt
        answer_prompt = answer_generation_prompt
        final_prompt = final_answer_generation_prompt
        check_prompt = tool_check_prompt
        name = "Answer Agent"
        kwargs.update({"direct_prompt": direct_prompt})
        kwargs.update({"answer_prompt": answer_prompt})
        kwargs.update({"final_prompt": final_prompt})
        kwargs.update({"tool_check_prompt": check_prompt})
        kwargs.update({"name": name})
        kwargs.update({"HP_parser_prompt": hotpot_answer_parser_prompt})
        super().__init__(
            *args,
            **kwargs,
        )

    def run(self, query_id, task, **kwargs):
        """agent run one step"""
        if task == "direct":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            result = self.llm.prediction(message)
            self.log(query_id, result)
            self.colorful_print(result, "Answer Directly")
            return result, "Answer Directly", self.name, result

        elif task == "answer":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    self.log(query_id, result)
                    self.colorful_print(result, "Answer Generation")
                    break
                except Exception as e:
                    print(f"answer generation fails: {e}")
                    self.log(query_id, f"answer generation fails: {e}")
                    if ind > 2:
                        return -1, "Answer Generation", self.name, str(e)
                    ind += 1
                    continue
            return result, "Answer Generation", self.name, result

        elif task == "final":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    self.log(query_id, result)
                    self.colorful_print(result, "Final Answer Generation")
                    break
                except Exception as e:
                    print(f"answer generation fails: {e}")
                    self.log(query_id, f"answer generation fails: {e}")
                    if ind > 2:
                        return -1, "Final Answer Generation", self.name, str(e)
                    ind += 1
                    continue
            return result, "Final Answer Generation", self.name, result

        elif task == "tool_check":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    result = eval(result)
                    a = result["Reason"]
                    b = result["Choice"]
                    self.log(query_id, result)
                    self.colorful_print(result, "Tool Check")
                    if 'yes' in b.lower():
                        return -1, a
                    else:
                        return 1, a
                except Exception as e:
                    print(f"tool check fails: {e}")
                    self.log(query_id, f"tool check fails: {e}")
                    if ind > self.max_retry:
                        return -1, 'fail'
                    ind += 1
                    continue

        elif task == "parse":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    # result = eval(result)
                    # a = result["Reason"]
                    # b = result["Choice"]
                    self.colorful_print(result, "Parse Answer Hotpot QA")
                    self.log(query_id, result)
                    # if 'yes' in b.lower():
                    #     return result, -1
                    # else:
                    #     return result, 1
                    return result
                except Exception as e:
                    print(f"answer parse fails: {e}")
                    self.log(query_id, f"answer parse fails: {e}")
                    if ind > self.max_retry:
                        return "answer parse fails"
                    ind += 1
                    continue


    def get_memory(self, **kwargs):
        """get relevant memory and add it to agent's memory"""
        self.memory = kwargs

    def get_prompt(self, task):
        """get the prompt for the agent"""
        if task == "direct":
            agent_prompt = self.direct_prompt.format(**self.memory)
        elif task == "answer":
            agent_prompt = self.answer_prompt.format(**self.memory)
        elif task == "final":
            agent_prompt = self.final_prompt.format(**self.memory)
        elif task == "tool_check":
            agent_prompt = self.tool_check_prompt.format(**self.memory)
        elif task == "parse":
            agent_prompt = self.HP_parser_prompt.format(**self.memory)
        return agent_prompt
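
A minimal usage sketch, not part of this commit, patterned on the commented-out example in Smurfs/deploy/cli_inference.py below; the model name and path are illustrative, and it assumes OpenAI_Model exposes the .prediction() method these agents call:

import os
from Smurfs.model.openai_model.openai_model import OpenAI_Model
from Smurfs.agents.answer_agent.answer import answer_agent

logger_dir = "data/example/log"            # illustrative path
os.makedirs(logger_dir, exist_ok=True)     # self.log() appends to <logger_dir>/<query_id>.txt
llm = OpenAI_Model(model_name="gpt-4")     # assumes OPENAI_API_KEY is set in the environment
agent = answer_agent(llm=llm, logger_dir=logger_dir)
# "direct" mode formats answer_generation_direct_prompt with the kwargs,
# so the template's {question} placeholder must be supplied:
result = agent.run(query_id=0, task="direct", question="What is the capital of France?")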
Smurfs/agents/answer_agent/prompt.py ADDED
@@ -0,0 +1,73 @@
from langchain.prompts import PromptTemplate

tool_check_prompt = """As a powerful language model, you're equipped to answer user's question with accumulated knowledge.
However, in some cases, you need to use external APIs to answer accurately.
Thus, you need to check whether the user's question requires you to call an external API to solve it.
Here are some tips to help you check:
1. If the user's question requires real-time information, since your knowledge base isn't updated in real-time, any such question will demand an API call.
2. If you need to obtain information (e.g., ID, name, phone number, geographical location, rank, etc.), you need to call the database APIs if you are not sure.
3. If the question demand a database search or internet research to generate an answer, this is another situation where an API call is necessary.
4. If the question demand coding and math calculation to generate an answer (e.g., algebraic operation, coding problem), you must call external APIs no matter how simple you think it is.
If need, please output 'YES'; If not, please output 'NO'
You need to give reasons first and then decide whether to keep it or not. You must only output in a parsible JSON format. Two example outputs look like:
Example 1: {{\"Reason\": \"The reason why you think you do not need to call an external API to solve the user's question\", \"Choice\": \"No\"}}
Example 2: {{\"Reason\": \"The reason why you think you need to call an external API to solve the user's question\", \"Choice\": \"Yes\"}}
This is the user's question: {question}
Output: """
tool_check_prompt = PromptTemplate.from_template(tool_check_prompt)

answer_generation_prompt = """
You should answer the question based on the response output by the API tool.
Please note that:
1. Answer the question in natural language based on the API response reasonably and effectively.
2. The user cannot directly get API response, so you need to make full use of the response and give the information in the response that can satisfy the user's question in as much detail as possible.
3. Do not output answer that is too long. Output in 3-6 sentences is OK.

This is the user's question:
{question}
This is the API response:
{call_result}
Output:"""
answer_generation_prompt = PromptTemplate.from_template(answer_generation_prompt)

final_answer_generation_prompt = """
You will be given a complex question and you need to solve it step by step by decomposing it to a series of subtasks that can be solved using a single tool(functions).
At this step, you need to analyse the previous subtasks and their execution result to generate the answer to the original question reasonably and accurately.
Please note that:
1. Answer the question in natural language based on the subtask results reasonably and effectively.
2. The user cannot directly get the subtask results, so you need to make full use of the subtask results and give the information in the response that can satisfy the user's question in as much detail as possible.
This is the user's question:
{question}
There are logs of previous subtasks and execution results:
{previous_log}
Output:"""
final_answer_generation_prompt = PromptTemplate.from_template(final_answer_generation_prompt)

answer_generation_direct_prompt = """"You need to answer the user's question.
This is the user's question: {question}
Output:"""
answer_generation_direct_prompt = PromptTemplate.from_template(answer_generation_direct_prompt)

hotpot_answer_parser_prompt = """
You will need to extract a concise answer from the detailed answer to answer the question in a consice language style.
Only output your concise answer to the answer.
For example:
Question:
VIVA Media AG changed it's name in 2004. What does their new acronym stand for?
Detailed Answer:
The new name of VIVA Media AG since its change in 2004 is "VIVA Media GmbH". In this acronym, "GmbH" is a German term which means "Gesellschaft mit beschränkter Haftung", translating to "company with limited liability" in English. So, the acronym denotes that it is a type of business organization similar to a limited liability company (LLC). VIVA Media GmbH is a company that specializes in publishing, producing, and developing high-quality games for different platforms.
Output:
Gesellschaft mit beschränkter Haftung

Question:
Jaclyn Stapp is married to the former frontman of a band that disbanded in what year?
Detailed Answer:
The band Creed effectively ended on December 29, 2002. However, they had a reunion tour that started on August 6, 2009, and ended on October 20, 2009. They also released an album called "Full Circle" on October 27, 2009. Despite these reunions, the band's meteoric rise came to a halt when it split up again in 2004.
Output:
2004

Question:
{question}
Detailed Answer:
{detailed_answer}
Output:"""
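
One detail worth noting: unlike the other four templates in this file, hotpot_answer_parser_prompt is left as a plain string rather than wrapped with PromptTemplate.from_template. answer_agent.get_prompt still works, because str.format fills the same placeholders; a minimal illustration (the values are made up):

from Smurfs.agents.answer_agent.prompt import hotpot_answer_parser_prompt

text = hotpot_answer_parser_prompt.format(
    question="Were Scott Derrickson and Ed Wood of the same nationality?",
    detailed_answer="Both Scott Derrickson and Ed Wood are American.",
)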
Smurfs/agents/base.py ADDED
@@ -0,0 +1,51 @@
from abc import abstractmethod
from pydantic import BaseModel, Field
from Smurfs.model.base import BaseLM
# import json
import os
from langchain.prompts import PromptTemplate
from typing import Any
from termcolor import colored

class BaseAgent(BaseModel):
    name: str
    llm: BaseLM
    prompt: Any
    logger_dir: str
    memory: dict = Field(default={})
    max_retry: int = Field(default=10)

    @abstractmethod
    def run(self, **kwargs):
        """agent run one step"""
        pass

    @abstractmethod
    def get_memory(self, **kwargs):
        """get relevant memory and add it to agent's memory"""
        pass

    @abstractmethod
    def get_prompt(self, **kwargs):
        """get the prompt for the agent"""
        pass

    def log(self, query_id, content):
        """write log to the logger file"""
        logger_file = os.path.join(self.logger_dir, f"{query_id}.txt")
        with open(logger_file, "a+") as file:
            file.write("\n##########\n")
            file.write(f"{self.name}: \n\n")
            file.write(str(content))
            file.write("\n##########\n")

    def colorful_print(self, content, task):
        """print out message in different color"""
        role_to_color = {
            "Answer Agent": "red",
            "Executor Agent": "green",
            "Planning Agent": "blue",
            "Verifier Agent": "yellow",
        }

        print(colored(f"##########{task}##########\n{content}\n", role_to_color[self.name]))
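
Every concrete agent in this commit fills in the same three hooks. A hypothetical minimal subclass (illustrative only, not in the commit) showing the contract; it follows the same run/get_memory/get_prompt pattern the real agents use:

from typing import Any
from Smurfs.agents.base import BaseAgent

class echo_agent(BaseAgent):                       # hypothetical example
    prompt: Any = "Answer briefly: {question}"

    def run(self, query_id, **kwargs):
        self.get_memory(**kwargs)
        message = [{'role': 'user', 'content': self.get_prompt()}]
        result = self.llm.prediction(message)
        self.log(query_id, result)                 # appends to <logger_dir>/<query_id>.txt
        return result

    def get_memory(self, **kwargs):
        self.memory = kwargs

    def get_prompt(self):
        return self.prompt.format(**self.memory)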
Smurfs/agents/executor_agent/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (176 Bytes).
 
Smurfs/agents/executor_agent/__pycache__/executor.cpython-39.pyc ADDED
Binary file (5.79 kB).
 
Smurfs/agents/executor_agent/__pycache__/prompt.cpython-39.pyc ADDED
Binary file (2.44 kB).
 
Smurfs/agents/executor_agent/executor.py ADDED
@@ -0,0 +1,246 @@
from Smurfs.agents.base import BaseAgent
from Smurfs.agents.executor_agent.prompt import generate_thought_prompt, choose_tool_prompt, choose_parameter_prompt
from Smurfs.inference.utils import change_name, standardize, contain
from Smurfs.inference.server import get_rapidapi_response
from typing import Any
import json
import os
import time
import requests

class executor_agent(BaseAgent):
    thought_prompt: Any
    tool_prompt: Any
    parameter_prompt: Any

    def __init__(self, *args, **kwargs):
        thought_prompt = generate_thought_prompt
        tool_prompt = choose_tool_prompt
        parameter_prompt = choose_parameter_prompt
        name = "Executor Agent"
        kwargs.update({"thought_prompt": thought_prompt})
        kwargs.update({"tool_prompt": tool_prompt})
        kwargs.update({"parameter_prompt": parameter_prompt})
        kwargs.update({"name": name})
        super().__init__(
            *args,
            **kwargs,
        )

    def run(self, query_id, task, **kwargs):
        """agent run one step"""
        if task == "thought":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0

            while True:
                try:
                    result = self.llm.prediction(message)
                    self.log(query_id, result)
                    self.colorful_print(result, "Thought Generation")
                    return result
                except Exception as e:
                    print(f"generating thought fails: {e}")
                    self.log(query_id, f"generating thought fails: {e}")
                    if ind > self.max_retry:
                        return -1
                    ind += 1
                    continue

        elif task == "tool":
            thought = kwargs["thought"]
            kwargs["question"] = kwargs["question"]+f"thought: {thought}\n"
            del kwargs["thought"]
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    start = result.find("{")
                    end = result.rfind("}")
                    result = eval(result[start:end+1])
                    self.colorful_print(result, "Choose Tool")
                    tool = result['ID']
                    self.log(query_id, result)
                    return tool
                except Exception as e:
                    print(f"choosing tool fails: {e}")
                    self.log(query_id, f"choosing tool fails: {e}")
                    if ind > self.max_retry:
                        return -1
                    ind += 1
                    continue

        elif task == "parameter":
            thought = kwargs["thought"]
            del kwargs["thought"]
            kwargs["question"] = kwargs["question"]+f"thought: {thought}\n"
            api_dic = kwargs["api_dic"]
            if len(api_dic["required_parameters"]) == 0 and len(api_dic["optional_parameters"]) == 0:
                return {}
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]

            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    start = result.find("{")
                    end = result.rfind("}")
                    self.colorful_print(result[start:end+1], "Generate Parameters")
                    result = result[start:end+1]
                    clean_answer = eval(
                        result.replace(": true", ": True").replace(": false", ": False").replace("```", "").strip())
                    # a = clean_answer["Parameters"]
                    # clean_answer = clean_answer["Parameters"]
                    self.log(query_id, clean_answer)
                    return clean_answer
                except Exception as e:
                    print(f"choose parameter fails: {e}")
                    self.log(query_id, f"choose parameter fails: {e}")
                    if ind > self.max_retry:
                        return -1
                    ind += 1
                    continue

    def get_memory(self, **kwargs):
        """get relevant memory and add it to agent's memory"""
        self.memory = kwargs

    def get_prompt(self, task):
        """get the prompt for the agent"""
        if task == "thought":
            agent_prompt = self.thought_prompt.format(**self.memory)
        elif task == "tool":
            agent_prompt = self.tool_prompt.format(**self.memory)
        elif task == "parameter":
            agent_prompt = self.parameter_prompt.format(**self.memory)

        return agent_prompt

class stream_executor_agent(BaseAgent):
    thought_prompt: Any
    tool_prompt: Any
    parameter_prompt: Any

    def __init__(self, *args, **kwargs):
        thought_prompt = generate_thought_prompt
        tool_prompt = choose_tool_prompt
        parameter_prompt = choose_parameter_prompt
        name = "Executor Agent"
        kwargs.update({"thought_prompt": thought_prompt})
        kwargs.update({"tool_prompt": tool_prompt})
        kwargs.update({"parameter_prompt": parameter_prompt})
        kwargs.update({"name": name})
        super().__init__(
            *args,
            **kwargs,
        )

    def run(self, query_id, task, **kwargs):
        """agent run one step"""
        if task == "thought":
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0

            while True:
                try:
                    result = self.llm.prediction(message)
                    self.log(query_id, result)
                    self.colorful_print(result, "Thought Generation")
                    return result, "Thought Generation", self.name, result
                except Exception as e:
                    print(f"generating thought fails: {e}")
                    self.log(query_id, f"generating thought fails: {e}")
                    if ind > self.max_retry:
                        return -1, "Thought Generation", self.name, str(e)
                    ind += 1
                    continue

        elif task == "tool":
            thought = kwargs["thought"]
            kwargs["question"] = kwargs["question"]+f"thought: {thought}\n"
            del kwargs["thought"]
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]
            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    start = result.find("{")
                    end = result.rfind("}")
                    result = eval(result[start:end+1])
                    self.colorful_print(result, "Choose Tool")
                    tool = result['ID']
                    self.log(query_id, result)
                    return tool, "Choose Tool", self.name, result
                except Exception as e:
                    print(f"choosing tool fails: {e}")
                    self.log(query_id, f"choosing tool fails: {e}")
                    if ind > self.max_retry:
                        return -1, "Choose Tool", self.name, str(e)
                    ind += 1
                    continue

        elif task == "parameter":
            thought = kwargs["thought"]
            del kwargs["thought"]
            kwargs["question"] = kwargs["question"]+f"thought: {thought}\n"
            api_dic = kwargs["api_dic"]
            if len(api_dic["required_parameters"]) == 0 and len(api_dic["optional_parameters"]) == 0:
                return {}
            self.get_memory(**kwargs)
            agent_prompt = self.get_prompt(task)
            message = [{'role': 'user',
                        'content': agent_prompt}]

            ind = 0
            while True:
                try:
                    result = self.llm.prediction(message)
                    start = result.find("{")
                    end = result.rfind("}")
                    self.colorful_print(result[start:end+1], "Generate Parameters")
                    result = result[start:end+1]
                    clean_answer = eval(
                        result.replace(": true", ": True").replace(": false", ": False").replace("```", "").strip())
                    # a = clean_answer["Parameters"]
                    # clean_answer = clean_answer["Parameters"]
                    self.log(query_id, clean_answer)
                    return clean_answer, "Generate Parameters", self.name, result[start:end+1]
                except Exception as e:
                    print(f"choose parameter fails: {e}")
                    self.log(query_id, f"choose parameter fails: {e}")
                    if ind > self.max_retry:
                        return -1, "Generate Parameters", self.name, str(e)
                    ind += 1
                    continue

    def get_memory(self, **kwargs):
        """get relevant memory and add it to agent's memory"""
        self.memory = kwargs

    def get_prompt(self, task):
        """get the prompt for the agent"""
        if task == "thought":
            agent_prompt = self.thought_prompt.format(**self.memory)
        elif task == "tool":
            agent_prompt = self.tool_prompt.format(**self.memory)
        elif task == "parameter":
            agent_prompt = self.parameter_prompt.format(**self.memory)

        return agent_prompt
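
The "parameter" branch above relies on eval() after normalizing JSON literals to Python syntax. A standalone illustration of that replace chain (the input string is made up):

raw = '{"keyword": "Artificial Intelligence", "language": "English", "safe": true}'
clean = eval(raw.replace(": true", ": True").replace(": false", ": False").replace("```", "").strip())
# -> {'keyword': 'Artificial Intelligence', 'language': 'English', 'safe': True}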
Smurfs/agents/executor_agent/prompt.py ADDED
@@ -0,0 +1,58 @@
from langchain.prompts import PromptTemplate

generate_thought_prompt = """\
You need to analyse the previous execution history and generate your internal reasoning and thoughts on the task, and how you plan to solve it based on the current attempts.

Do not output thought that is too long. Output in 2-3 sentences is OK.

This is the user's task:
{question}

This is the Tool List:
{tool_list}

This is the previous execution history:
{previous_log}

This is the hint comes from the evaluator:
{hint}

Output:"""
generate_thought_prompt = PromptTemplate.from_template(generate_thought_prompt)

choose_tool_prompt = """\
This is the user's question:
{question}
These are the tools you can select to solve the question:
Tool List:
{tool_list}

Please note that:
1. You should only chooce one tool from the Tool List to solve this question.
2. You must ONLY output the ID of the tool and your reason for choosing it in a parsible JSON format. An example output looks like:
'''
Example: {{\"ID\": ID of the tool, \"Reason\": The reason for choosing the tool}}
'''
Output: """
choose_tool_prompt = PromptTemplate.from_template(choose_tool_prompt)

choose_parameter_prompt="""\
Given a user's question and a API tool documentation, you need to output parameters according to the API tool documentation to successfully call the API to solve the user's question.
Please note that:
1. The Example in the API tool documentation can help you better understand the use of the API.
2. Ensure the parameters you output are correct. The output must contain the required parameters, and can contain the optional parameters based on the question. If no paremters in the required parameters and optional parameters, just leave it as {{}}
3. If the user's question mentions other APIs, you should ONLY consider the API tool documentation I give and do not consider other APIs.
4. The question may have dependencies on answers of other questions, so we will provide logs of previous questions and answers for your reference.
5. You must ONLY output in a parsible JSON Format. The example output looks like:
'''
Example: {{\"keyword\": \"Artificial Intelligence\", \"language\": \"English\"}}
'''

There are logs of previous questions and answers:
{previous_log}

This is the current user's question: {question}

This is API tool documentation: {api_dic}
Output:"""
choose_parameter_prompt = PromptTemplate.from_template(choose_parameter_prompt)
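
An illustrative fill of choose_parameter_prompt (all values are made up; the api_dic shape mirrors the required_parameters/optional_parameters keys the executor checks):

from Smurfs.agents.executor_agent.prompt import choose_parameter_prompt

prompt_text = choose_parameter_prompt.format(
    previous_log="(no previous subtasks)",
    question="Find English news articles about Artificial Intelligence",
    api_dic={"required_parameters": ["keyword"], "optional_parameters": ["language"]},
)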
Smurfs/agents/memory_agent/memory_agent.py ADDED
File without changes
Smurfs/agents/memory_agent/prompt.py ADDED
@@ -0,0 +1,16 @@
from langchain.prompts import PromptTemplate

mem_choose_prompt = """You are a memory agent that controls the memory of the agent system.
The agent system is trying to solve a complex question step by step by solving its subtasks one by one.
Among those subtasks, some subtask may need execution history of other subtasks to be solved.
Your task is to decide which subtasks' execution history is needed by the agent system to solve the current subtask.
Please note that:
1. If the current subtask is independent of the other subtasks, just output {{\"task\":}}
2.
You must only output in a parsible JSON format. Two example outputs look like:
Example 1: {{\"Reason\": \"The reason why you think you do not need to call an external API to solve the user's question\", \"Choice\": \"No\"}}
Example 2: {{\"Reason\": \"The reason why you think you need to call an external API to solve the user's question\", \"Choice\": \"Yes\"}}
This is the current subtask: {question}
This is the previous execution history: {history}
Output: """
tool_check_prompt = PromptTemplate.from_template(mem_choose_prompt)
Smurfs/agents/planning_agent/__pycache__/planner.cpython-39.pyc ADDED
Binary file (3.99 kB).
 
Smurfs/agents/planning_agent/__pycache__/prompt.cpython-39.pyc ADDED
Binary file (3.1 kB).
 
Smurfs/agents/planning_agent/planner.py ADDED
@@ -0,0 +1,137 @@
from Smurfs.agents.base import BaseAgent
from Smurfs.agents.planning_agent.prompt import task_decompose_prompt, hotpot_task_decompose_prompt

class planning_agent(BaseAgent):
    def __init__(self, llm, logger_dir):
        super().__init__(
            prompt = task_decompose_prompt,
            llm = llm,
            name = "Planning Agent",
            logger_dir = logger_dir
        )

    def run(self, question, query_id):
        """agent run one step"""
        self.get_memory(question)
        agent_prompt = self.get_prompt()
        message = [{'role': 'user',
                    'content': agent_prompt}]
        ind = 0
        while True:
            try:
                result = self.llm.prediction(message)
                # print(result)
                start = result.find("{")
                end = result.find("}")
                result = eval(result[start:end+1])
                self.colorful_print(result, "Task Decompose")
                subtasks = result['Tasks']
                self.log(query_id, result)
                # print(a)
                return subtasks
            except Exception as e:
                print(f"task deompose fails: {e}")
                self.log(query_id, f"task deompose fails: {e}")
                if ind > self.max_retry:
                    return -1
                ind += 1
                continue

    def get_memory(self, question):
        """get relevant memory and add it to agent's memory"""
        self.memory = {"question": question}

    def get_prompt(self):
        """get the prompt for the agent"""
        agent_prompt = self.prompt.format(**self.memory)
        return agent_prompt

class hotpot_planning_agent(BaseAgent):
    def __init__(self, llm, logger_dir):
        super().__init__(
            prompt = hotpot_task_decompose_prompt,
            llm = llm,
            name = "Planning Agent",
            logger_dir = logger_dir
        )

    def run(self, question, query_id):
        """agent run one step"""
        self.get_memory(question)
        agent_prompt = self.get_prompt()
        message = [{'role': 'user',
                    'content': agent_prompt}]
        ind = 0
        while True:
            try:
                result = self.llm.prediction(message)
                # print(result)
                start = result.find("{")
                end = result.find("}")
                result = eval(result[start:end+1])
                self.colorful_print(result, "Task Decompose")
                subtasks = result['Tasks']
                self.log(query_id, result)
                # print(a)
                return subtasks
            except Exception as e:
                print(f"task deompose fails: {e}")
                self.log(query_id, f"task deompose fails: {e}")
                if ind > self.max_retry:
                    return -1
                ind += 1
                continue

    def get_memory(self, question):
        """get relevant memory and add it to agent's memory"""
        self.memory = {"question": question}

    def get_prompt(self):
        """get the prompt for the agent"""
        agent_prompt = self.prompt.format(**self.memory)
        return agent_prompt

class stream_hotpot_planning_agent(BaseAgent):
    def __init__(self, llm, logger_dir):
        super().__init__(
            prompt = hotpot_task_decompose_prompt,
            llm = llm,
            name = "Planning Agent",
            logger_dir = logger_dir
        )

    def run(self, question, query_id):
        """agent run one step"""
        self.get_memory(question)
        agent_prompt = self.get_prompt()
        message = [{'role': 'user',
                    'content': agent_prompt}]
        ind = 0
        while True:
            try:
                result = self.llm.prediction(message)
                # print(result)
                start = result.find("{")
                end = result.find("}")
                result = eval(result[start:end+1])
                self.colorful_print(result, "Task Decompose")
                subtasks = result['Tasks']
                self.log(query_id, result)
                # print(a)
                return subtasks, "Task Decompose", self.name, result
            except Exception as e:
                print(f"task deompose fails: {e}")
                self.log(query_id, f"task deompose fails: {e}")
                if ind > 5:
                    return -1, "Task Decompose", self.name, str(e)
                ind += 1
                continue

    def get_memory(self, question):
        """get relevant memory and add it to agent's memory"""
        self.memory = {"question": question}

    def get_prompt(self):
        """get the prompt for the agent"""
        agent_prompt = self.prompt.format(**self.memory)
        return agent_prompt
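
The planner's parsing step is worth seeing in isolation. Note that it takes the first closing brace (find, not rfind), which suffices for the flat {"Tasks": [...]} shape the prompt requests but would truncate nested JSON. A standalone snippet mirroring the code above, with a made-up model reply:

result = '{"Tasks": ["search for the nationality of Scott Derrickson", "search for the nationality for Ed Wood"]}'
start, end = result.find("{"), result.find("}")
subtasks = eval(result[start:end + 1])["Tasks"]   # the agents use eval(), not json.loads()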
Smurfs/agents/planning_agent/prompt.py ADDED
@@ -0,0 +1,44 @@
from langchain.prompts import PromptTemplate

task_decompose_prompt = """
You need to decompose a complex user's question into some simple subtasks and let the model execute it step by step.
Please note that:
1. You should only decompose this complex user's question into some simple subtasks which can be executed easily by using a single tool.
2. Each simple subtask should be expressed into natural language.
3. Each subtask should contain the necessary information from the original question and should be complete, explicit and self-consistent.
4. You must ONLY output in a parsible JSON format. An example output looks like:
'''
{{\"Tasks\": [\"Task 1\", \"Task 2\", ...]}}
'''

This is the user's question: I'm planning a trip to Turkey and need information about postal codes in Istanbul. Can you provide me with the postal code and district for Istanbul province with plate number 34? Additionally, I would like to know if there are any transit agencies available in Istanbul. Please fetch their names and contact numbers.
Output: {{\"Tasks\": [\"Find the postal codes and districts for plate number 34 in Istanbul.\", \"Search for transit agencies and their contact numbers in Istanbul.\"]}}

This is the user's question: I recently moved to a new address and I need to update my information. Can you retrieve my address details using the postal code 75094080? Additionally, I would like to know the companies that offer shipping services.
Output: {{\"Tasks\": [\"retrieve the address details using the postal code 75094080\", \"search for companies that offer shipping services to my address\"]}}

This is the user's question: {question}
Output:
"""
task_decompose_prompt = PromptTemplate.from_template(task_decompose_prompt)

hotpot_task_decompose_prompt = """
You need to decompose a complex user's question into some simple subtasks and let the model execute it step by step.
Please note that:
1. You should only decompose this complex user's question into some simple subtasks which can be executed easily by using a single tool.
2. Each simple subtask should be expressed into natural language.
3. Each subtask should contain the necessary information from the original question and should be complete, explicit and self-consistent.
4. You must ONLY output in a parsible JSON format. An example output looks like:
'''
{{\"Tasks\": [\"Task 1\", \"Task 2\", ...]}}
'''

This is the user's question: What government position was held by the woman who portrayed Corliss Archer in the film Kiss and Tell?
Output: {{\"Tasks\": [\"In the film Kiss and Tell, who is the woman who portrayed Corliss Archer?\", \"What government position was held by this woman?\"]}}

This is the user's question: Were Scott Derrickson and Ed Wood of the same nationality?
Output: {{\"Tasks\": [\"search for the nationality of Scott Derrickson\", \"search for the nationality for Ed Wood\", \"Compare whether they have the same nationality\"]}}

This is the user's question: {question}
Output:
"""
Smurfs/agents/verifier_agent/__pycache__/prompt.cpython-39.pyc ADDED
Binary file (1.51 kB).
 
Smurfs/agents/verifier_agent/__pycache__/verifier.cpython-39.pyc ADDED
Binary file (3.03 kB).
 
Smurfs/agents/verifier_agent/prompt.py ADDED
@@ -0,0 +1,25 @@
from langchain.prompts import PromptTemplate

final_answer_check_prompt = """
An agent is trying to solve the query proposed by the user. \
You need to evaluate whether the given query has been completed reasonably and accurately. If so, summarize the solution to the user. If not, summarize the current progress, and propose what is missing.

You response contains following elements:
Speak: (your words to the agent if the task isn't completed, or a complete answer based on the full execution log to the user if the task is finished)
Status: (0 or 1. 0 for unfinished and 1 for finished)

Please note that:
1. If the answer says the query can't be solved or it can't answer the query given the current information, please output Status as 0.
2. Only output Status as 1 if the query has been answered correctly and accurately.
3. If the answer only give a plan instead of a detailed answer, output Status as 0.

You must only output in a parsible JSON format. Two example outputs look like:
Example 1: {{\"Speak\": \"answer based on the full execution log to the user\", \"Status\": \"1\"}}
Example 2: {{\"Speak\": \"your words to the group if the task isn't solved\", \"Status\": \"0\"}}

This is the answer from the previous execution result:
{answer}

This is the original question: {question}
Output: """
final_answer_check_prompt = PromptTemplate.from_template(final_answer_check_prompt)
Smurfs/agents/verifier_agent/verifier.py ADDED
@@ -0,0 +1,90 @@
from Smurfs.agents.base import BaseAgent
from Smurfs.agents.verifier_agent.prompt import final_answer_check_prompt

class verifier_agent(BaseAgent):
    def __init__(self, llm, logger_dir):
        super().__init__(
            prompt = final_answer_check_prompt,
            llm = llm,
            name = "Verifier Agent",
            logger_dir = logger_dir
        )

    def run(self, question, answer, query_id):
        """agent run one step"""
        self.get_memory(question, answer)
        agent_prompt = self.get_prompt()
        message = [{'role': 'user',
                    'content': agent_prompt}]
        ind = 0
        while True:
            try:
                result = self.llm.prediction(message)
                start = result.find("{")
                end = result.find("}")
                self.colorful_print(result, "Answer Verify")
                self.log(query_id, result)
                clean_result = eval(result[start:end+1])
                speak = clean_result["Speak"]
                status = clean_result["Status"]
                return speak, status
            except Exception as e:
                print(f"final answer check fails: {e}")
                self.log(query_id, f"final answer check fails: {e}")
                if ind > self.max_retry:
                    return -1, -1
                ind += 1
                continue

    def get_memory(self, question, answer):
        """get relevant memory and add it to agent's memory"""
        self.memory = {"question": question, "answer": answer}

    def get_prompt(self):
        """get the prompt for the agent"""
        agent_prompt = self.prompt.format(**self.memory)
        return agent_prompt

class stream_verifier_agent(BaseAgent):
    def __init__(self, llm, logger_dir):
        super().__init__(
            prompt = final_answer_check_prompt,
            llm = llm,
            name = "Verifier Agent",
            logger_dir = logger_dir
        )

    def run(self, question, answer, query_id):
        """agent run one step"""
        self.get_memory(question, answer)
        agent_prompt = self.get_prompt()
        message = [{'role': 'user',
                    'content': agent_prompt}]
        ind = 0
        while True:
            try:
                result = self.llm.prediction(message)
                start = result.find("{")
                end = result.find("}")
                self.colorful_print(result, "Answer Verify")
                self.log(query_id, result)
                clean_result = eval(result[start:end+1])
                speak = clean_result["Speak"]
                status = clean_result["Status"]
                return speak, status, "Answer Verify", self.name, result
            except Exception as e:
                print(f"final answer check fails: {e}")
                self.log(query_id, f"final answer check fails: {e}")
                if ind > self.max_retry:
                    return -1, -1, "Answer Verify", self.name, str(e)
                ind += 1
                continue

    def get_memory(self, question, answer):
        """get relevant memory and add it to agent's memory"""
        self.memory = {"question": question, "answer": answer}

    def get_prompt(self):
        """get the prompt for the agent"""
        agent_prompt = self.prompt.format(**self.memory)
        return agent_prompt
Smurfs/data/.DS_Store ADDED
Binary file (6.15 kB).
 
Smurfs/data/__init__.py ADDED
File without changes
Smurfs/data/post_process.py ADDED
@@ -0,0 +1,65 @@
import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# print(sys.path)
import json
from Smurfs.data.utils import tree_steps_counter, total_path_transform
import argparse

def parse_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default="dir_to_your_data", required=False, help='the directory of the data that needs post-processing')
    parser.add_argument('--example_dir', type=str, default="dir_to_example_data", required=False, help='the directory of the example data')
    parser.add_argument('--test_sets', nargs='+', type=str, required=False, help='the test sets that need processing. It should be G2_instruction, G2_category or G3_instruction')
    args = parser.parse_args()
    return args

def main():
    args = parse_arg()
    input_dir = args.input_dir
    test_sets = args.test_sets
    example_dir = args.example_dir
    for test_set in test_sets:
        data_path = os.path.join(input_dir, f"{test_set}_raw.json")
        example_data_path = os.path.join(example_dir, f"{test_set}.json")

        with open(data_path, 'r') as file:
            g_data = json.load(file)
        with open(example_data_path, 'r') as file:
            g_example_data = json.load(file)
        g_new_data = {}
        for g_d in g_data:
            m = False
            for d in g_example_data:
                if g_data[g_d]["query"] == g_example_data[d]["query"]:
                    g_new_data_ele = {"query":"", "available_tools": [], "answer":{}}
                    g_new_answer_ele = {
                        "method": "smurfs",
                        "total_steps": 0,
                        "final_answer": "",
                        "answer_details": []
                    }
                    g_new_data_ele["query"] = g_data[g_d]["query"]
                    g_new_data_ele["available_tools"] = g_example_data[d]["available_tools"]
                    g_new_answer_ele["answer_details"] = [g_data[g_d]["answer"]["answer_details"]]
                    counter = tree_steps_counter(0)
                    counter.count_total_steps(g_data[g_d]["answer"]["answer_details"])
                    g_new_answer_ele["total_steps"] = counter.get_steps()
                    g_new_answer_ele["final_answer"] = g_data[g_d]["answer"]["final_answer"]
                    g_new_data_ele["answer"] = g_new_answer_ele
                    g_new_data[d] = g_new_data_ele
                    m = True
                    break
            if not m:
                print(f"{test_set} mismatch! The key is: {g_d}")
        if test_set == "G2_category":
            duplicate = g_new_data["43201"]
            g_new_data["43200"] = duplicate

        output_path = os.path.join(input_dir, f"{test_set}.json")
        print(output_path)
        with open(output_path, 'w') as file:
            json.dump(g_new_data, file, indent=4, ensure_ascii=False)

if __name__ == '__main__':
    main()
Smurfs/data/utils.py ADDED
@@ -0,0 +1,53 @@
class tree_steps_counter:
    def __init__(self, steps):
        self.steps = steps

    def count_total_steps(self, root):
        self.steps += 1
        if root["next"] == []:
            return
        for i in range(len(root["next"])):
            self.count_total_steps(root["next"][i])

    def get_steps(self):
        return self.steps

def total_path_transform(data, index):
    finish_template = [{
        "role": "tool",
        "message": {
            "name": "Finish",
            "arguments": {
                "return_type": "give_answer",
                "final_answer": data[index]["final_answer"]
            },
            "response": ""
        },
        "next": []
    }]
    answer_path = data[index]["total_path"]["next"][0]["next"]
    for i in range(len(answer_path)-1, -1, -1):
        if answer_path[i]["role"] == "plan_global":
            if answer_path[i]["next"] != []:
                current_log = answer_path[i]["next"][-1]
                answer_path[i] = answer_path[i]["next"]
            else:
                if i == len(answer_path)-1:
                    answer_path[i] = finish_template[0]
                continue
        else:
            current_log = answer_path[i]
        while current_log["next"] != []:
            current_log = current_log["next"][-1]
        if i == len(answer_path)-1:
            current_log["next"] = finish_template
        else:
            if not isinstance(answer_path[i+1], list):
                current_log["next"] = [answer_path[i+1]]
            else:
                current_log["next"] = answer_path[i+1]
    if not isinstance(answer_path[0], list):
        data[index]["total_path"]["next"][0]["next"] = [answer_path[0]]
    else:
        data[index]["total_path"]["next"][0]["next"] = answer_path[0]
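
tree_steps_counter is a simple preorder walk over the "next" links. A quick self-contained check (the tree shape is made up), assuming the class above is in scope:

root = {"next": [{"next": []}, {"next": [{"next": []}]}]}
counter = tree_steps_counter(0)
counter.count_total_steps(root)
print(counter.get_steps())   # 4: the root node plus its three descendants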
Smurfs/deploy/__init__.py ADDED
@@ -0,0 +1,3 @@
global_dict = {
    "knowledge_base" : None
}
Smurfs/deploy/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (203 Bytes).
 
Smurfs/deploy/cli_inference.py ADDED
@@ -0,0 +1,58 @@
import warnings

# Suppress all warnings
warnings.filterwarnings('ignore')

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from Smurfs.inference.smurfs_worker import smurfs_hotpot_worker, smurfs_worker
# from Smurfs.tools.tool_env import HotpotToolEnv
from Smurfs.tools.tool_env import tool_env
from Smurfs.model.openai_model.openai_model import OpenAI_Model, OpenRouter_Model
from Smurfs.agents.answer_agent.answer import answer_agent
from Smurfs.agents.executor_agent.executor import executor_agent
from Smurfs.agents.planning_agent.planner import hotpot_planning_agent
from Smurfs.agents.verifier_agent.verifier import verifier_agent
import json
import threading
import joblib
from tqdm import tqdm
import time

def run(worker, query, query_id):
    # global lock
    final_answer, output_file_ele, solution_file_ele = worker.run(query, query_id)
    # lock.acquire()
    worker.save_solution(output_file_ele, solution_file_ele, query_id)
    # lock.release()
    return final_answer

def cli_run(query, worker):
    pre = run(worker, query, 0)
    return pre

if __name__ == '__main__':
    # model_name = "mistralai/mistral-7b-instruct-v0.2"
    model_name = "mistralai/mistral-7b-instruct-v0.2"
    method_name = "cli_inference"
    tool_doc_path = "Smurfs/tools/math_search.json"
    # llm = OpenAI_Model(model_name=model_name)
    llm = OpenRouter_Model(model_name=model_name)
    # parser_llm = OpenAI_Model(model_name="gpt-4")
    with open(tool_doc_path, "r") as f:
        available_tools = json.load(f)

    test_set = "cli"

    output_dir = f"data/{method_name}/{test_set}/answer"
    results_dir = f"data/{method_name}/{test_set}/results.json"
    if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
        os.makedirs(f"data/{method_name}/{test_set}/parser_log")
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
    # worker = smurfs_hotpot_worker(available_tools, HotpotToolEnv, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
    worker = smurfs_worker(available_tools, tool_env, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
    query = input("Please Enter Your Task: ")
    cli_run(query, worker)
Smurfs/deploy/gradio_inference.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import gradio as gr
+
+ import warnings
+
+ # Suppress all warnings
+ warnings.filterwarnings('ignore')
+
+ import os
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+ from Smurfs.inference.smurfs_worker import smurfs_hotpot_worker, smurfs_worker, stream_smurfs_worker
+ # from Smurfs.tools.tool_env import HotpotToolEnv
+ from Smurfs.deploy import global_dict
+ from Smurfs.tools.tool_env import tool_env
+ from Smurfs.model.openai_model.openai_model import OpenAI_Model
+ from Smurfs.agents.answer_agent.answer import stream_answer_agent
+ from Smurfs.agents.executor_agent.executor import stream_executor_agent
+ from Smurfs.agents.planning_agent.planner import stream_hotpot_planning_agent
+ from Smurfs.agents.verifier_agent.verifier import stream_verifier_agent
+
+ from Smurfs.tools.docqa.api import tool_env as docqa_tool_env
+ from Smurfs.tools.hotpotQA.api import tool_env as hotpot_tool_env
+ from Smurfs.tools.math.api import tool_env as math_tool_env
+ from Smurfs.tools.shell.api import tool_env as shell_tool_env
+ from Smurfs.tools.websearch.api import tool_env as websearch_tool_env
+
+
+ import json
+ import threading
+ import joblib
+ from tqdm import tqdm
+ import time
+
+ from PyPDF2 import PdfReader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import FAISS
+ from langchain_openai import OpenAIEmbeddings
+ from datetime import datetime
+ current_datetime = datetime.now()
+ # user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(container=False)
+ # inp = gr.Textbox(placeholder="Enter your task")
+ # css = """
+ # .btn {background-color: blue; color: white;}
+ # #bot {background-color: blue; color: white;}
+ # #e {display: inline-block; vertical-align: middle;}
+ # """
+
+ # def update(name):
+ #     return f"<span style='color: red'>Welcome to Gradio, {name}!</span>"
+
+ tool_env_map = {
+     "shell": shell_tool_env,
+     "math": math_tool_env,
+     "docqa": docqa_tool_env,
+     "hotpotQA": hotpot_tool_env,
+     "websearch": websearch_tool_env
+ }
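+ # tool_env_map links each dropdown tool-system name to its tool environment;
+ # update_tools() below merges the selected entries into total_env.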
+
+ total_env, env_name_list = {}, []
+
+ def loading():
+     return "Loading..."
+
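+ # Builds an in-memory FAISS index over the uploaded PDF so the docqa tools can
+ # retrieve passages later; falls back to the OPENAI_API_KEY env var when the
+ # user supplies no key.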
+ def load_text_from_pdf(up, key=None):
+     global global_dict
+     if key == None:
+         key = os.environ.get("OPENAI_API_KEY")
+     pdf_path = up.name
+     pdf_reader = PdfReader(pdf_path)
+     text = ""
+     for page in pdf_reader.pages:
+         text += page.extract_text()
+
+     # split into chunks
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=1000,
+         chunk_overlap=200,
+         add_start_index=True
+     )
+     chunks = text_splitter.split_text(text)
+
+     # create embeddings
+     # embeddings = OpenAIEmbeddings()
+     embeddings = OpenAIEmbeddings(openai_api_key=key)
+     global_dict["knowledge_base"] = FAISS.from_texts(chunks, embeddings)
+     return "upload success!"
+     # return knowledge_base
+
+ def update(query, OPENAI_API_KEY, BING_SUBSCRIPT_KEY, WOLFRAMALPH_APP_ID, WEATHER_API_KEYS):
+     global total_env, env_name_list
+     # print(total_env)
+     # print(BING_SUBSCRIPT_KEY)
+     # print(WOLFRAMALPH_APP_ID)
+     # print(WEATHER_API_KEYS)
+     # model_name = "mistralai/mistral-7b-instruct-v0.2"
+     model_name = "gpt-4"
+     method_name = "cli_inference"
+     tool_doc_path = "Smurfs/tools/tool_doc.json"
+     if OPENAI_API_KEY == None or OPENAI_API_KEY == '':
+         yield [(query, "No OPENAI KEY provided!")]
+         raise KeyError
+     if (BING_SUBSCRIPT_KEY == None or BING_SUBSCRIPT_KEY == ''):
+         yield [(query, "No BING_SUBSCRIPT_KEY provided! Please register one from https://www.microsoft.com/en-us/bing/apis/bing-web-search-api and add it to your keys")]
+         raise KeyError
+     if WOLFRAMALPH_APP_ID == None or WOLFRAMALPH_APP_ID == '':
+         yield [(query, "No WOLFRAMALPH_APP_ID provided! Please register one from https://products.wolframalpha.com/api/ and add it to your keys")]
+         raise KeyError
+     if WEATHER_API_KEYS == None or WEATHER_API_KEYS == '':
+         yield [(query, "No WEATHER_API_KEYS provided! Please register one from https://www.weatherapi.com/ and add it to your keys")]
+         raise KeyError
+     llm = OpenAI_Model(model_name=model_name, api_key=OPENAI_API_KEY)
+     # llm = OpenRouter_Model(model_name=model_name)
+     if "docqa" in total_env:
+         sys_prompt = llm.sys_prompt + "You already have access to the file uploaded by the user. So just answer the question from the user, you don't need to find the file first."
+         llm.change_sys_prompt(sys_prompt)
+     else:
+         llm.set_default_sys_prompt()
+     # parser_llm = OpenAI_Model(model_name="gpt-4")
+     with open(tool_doc_path, "r") as f:
+         tool_doc = json.load(f)
+     tool_doc["bing_search"]["api_description"] += f"Today is {current_datetime.year}.{current_datetime.month}.{current_datetime.day}"
+     available_tools = []
+     for env_name in env_name_list:
+         available_tools.append(tool_doc[env_name])
+
+     test_set = "cli"
+
+     output_dir = f"data/{method_name}/{test_set}/answer"
+     results_dir = f"data/{method_name}/{test_set}/results.json"
+     if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
+         os.makedirs(f"data/{method_name}/{test_set}/parser_log")
+     if not os.path.exists(output_dir):
+         os.makedirs(output_dir)
+     # HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
+     # worker = smurfs_hotpot_worker(available_tools, HotpotToolEnv, llm, method_name, test_set, answer_agent, executor_agent, hotpot_planning_agent, verifier_agent)
+     # print(total_env)
+     worker = stream_smurfs_worker(available_tools, total_env, llm, method_name, test_set, stream_answer_agent, stream_executor_agent, stream_hotpot_planning_agent, stream_verifier_agent, OPENAI_API_KEY, BING_SUBSCRIPT_KEY, WOLFRAMALPH_APP_ID, WEATHER_API_KEYS)
+     stream_generator = worker.run(query, 0)
+     # messages = []
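+     # Drain the streaming generator, re-yielding each partial response as a
+     # (query, response) pair so the Gradio chatbot renders incremental output.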
+     while True:
+         try:
+             response = next(stream_generator)
+             messages = [(query, response)]
+             yield messages
+         except StopIteration:
+             break
+     # query = input("Please Enter Your Task: ")
+     # cli_run(query, worker)
+
+ def update_tools(rs):
+     global total_env, env_name_list
+     total_env = {}
+     env_name_list = []
+     for tool_system in rs:
+         tool = tool_system.split(": ")[0]
+         env = tool_env_map[tool]
+         print(f"env: {env}")
+         for e in env:
+             if e not in env_name_list:
+                 total_env[e] = env[e]
+                 env_name_list.append(e)
+     print(total_env)
+     # return total_env, env_name_list
+
+ def user(user_msg):
+     return user_msg
+
+ tools = ["math: Tool that can handle mathematical problems",
+          "docqa: Tool that can answer questions about your uploaded file",
+          "hotpotQA: Tool that can do multi-hop commonsense reasoning",
+          "websearch: Tool that can do web search to answer your question"]
+ websearch_example = ["请根据深圳明天的天气给我推荐一套穿搭方案,结果用中文输出。", "今年的中秋节是哪天?用中文输出"]
+ math_example = ["Calc integral of sin(x)+2x^2+3x+1 from 0 to 1", "When the two legs of a right triangle are 6 and 8, what is the length of the hypotenuse?"]
+ inp = gr.Textbox(placeholder="Please input your task", label="Task")
+ with gr.Blocks() as demo:
+     gr.HTML("""<h1 align="center">Smurfs</h1>""")
+     # gr.Markdown("""<figure><a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3602.gif></a><a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3607.gif></a><a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3623.gif></a><a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3625.gif></a></figure>""")
+     # gr.HTML("""<a href=https://yoursmiles.org/h-smurf.php><img src=https://yoursmiles.org/hsmile/smurf/h3602.gif>""")
+     with gr.Row():
+         with gr.Column(scale=1):
+             inp.render()
+             rs = gr.Dropdown(choices=tools, label="Tool Systems", multiselect=True)
+             file_output = gr.File(file_types=[".pdf"])
+             with gr.Accordion("Keys", open=False):
+                 # model_name = gr.Dropdown(label="Model Name", choices=["gpt-3.5", "gpt-4o", "gpt-4"])
+                 openai_key = gr.Textbox(label="OpenAI API Key", placeholder="Please Enter Your OpenAI API Key")
+                 bing_search_key = gr.Textbox(label="BingSearch Key", placeholder="Please Enter Your BingSearch Key from https://www.microsoft.com/en-us/bing/apis/bing-web-search-api")
+                 wolframalpha_key = gr.Textbox(label="Wolframalpha API Key", placeholder="Please Enter Your WOLFRAMALPH_APP_ID from https://products.wolframalpha.com/api/")
+                 weather_key = gr.Textbox(label="Weather API Key", placeholder="Please Enter Your Weather API Key from https://www.weatherapi.com/")
+
+
+             gr.Examples(["Who is the brother of the 2022 NBA FMVP?", "How much older is Lebron James than his oldest son?", "Calc integral of sin(x)+2x^2+3x+1 from 0 to 1", "Calculate the length of the hypotenuse of a right triangle when the other two sides are 6 and 8", "请根据深圳明天的天气给我推荐一套穿搭方案,结果用中文输出。", "今年的中秋节是哪天?用中文输出"], inp)
+             _submit = gr.Button("Submit")
+             stop = gr.Button("Stop")
+             clear = gr.Button("Clear")
+             # upload = gr.UploadButton("Click to upload your pdf file")
+             # btn = gr.Button("Run", elem_id="bot", elem_classes="btn")
+         # with gr.Column(scale=1, elem_id="e"):
+         #     chatbox = gr.HTML()
+         chatbox = gr.Chatbot(height=300)
+     # btn.click(fn=update, inputs=inp, outputs=chatbox)
+     file_output.upload(load_text_from_pdf, [file_output, openai_key], None)
+     # upload.upload(loading, None, inp).then(load_text_from_pdf, upload, inp)
+     rs.change(update_tools, rs, None)
+     click_event = _submit.click(user, inp, inp).then(update, [inp, openai_key, bing_search_key, wolframalpha_key, weather_key], chatbox)
+     stop.click(None, None, None, cancels=[click_event])
+     # inp.submit(user, inp, inp).then(update, inp, chatbox)
+     clear.click(lambda: (None, None), None, [inp, chatbox], queue=False)
+     # theme=gr.themes.Default().set(button_primary_border_color_dark=, hover)
+     demo.load(
+         None,
+         None,
+         _js="""
+         () => {
+             const params = new URLSearchParams(window.location.search);
+             if (!params.has('__theme')){
+                 params.set('__theme', 'dark');
+                 window.location.search = params.toString();
+             }
+         }
+         """,
+     )
+ demo.queue().launch(server_name='0.0.0.0', share=True, inbrowser=False, server_port=7001)
Smurfs/eval/hotpot_qa/__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.46 kB). View file
 
Smurfs/eval/hotpot_qa/post_process.py ADDED
@@ -0,0 +1,109 @@
+ import re
+ def remove_tags(text):
+     # Use a regular expression to strip tags of the form <xxx>
+     cleaned_text = re.sub(r'<[^>]*>', '', text)
+     return cleaned_text
+
+ import os
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+ from Smurfs.inference.smurfs_worker import smurfs_hotpot_worker
+ from Smurfs.tools.tool_env import HotpotToolEnv
+ from Smurfs.model.openai_model.openai_model import OpenAI_Model, OpenRouter_Model
+ from Smurfs.agents.answer_agent.answer import answer_agent
+ from Smurfs.agents.executor_agent.executor import executor_agent
+ from Smurfs.agents.planning_agent.planner import hotpot_planning_agent
+ from Smurfs.agents.verifier_agent.verifier import verifier_agent
+ from Smurfs.eval.hotpot_qa.utils import eval_result
+ import json
+ import threading
+ import joblib
+ from tqdm import tqdm
+ import time
+
+ def post_run(results, HP_answer_agent):
+     global new_results
+     for res in tqdm(results):
+         pre = res["pre_ans"]
+         query_id = res["id"]
+         ques = res["question"]
+         pre = remove_tags(pre)
+         if len(pre) == 0:
+             res["parsed_pre"] = ""
+             new_results.append(res)
+             continue
+         parsed_pre = HP_answer_agent.run(query_id=query_id, task="parse", question=ques, detailed_answer=pre)
+         res["parsed_pre"] = parsed_pre
+         new_results.append(res)
+
+ if __name__ == "__main__":
+     levels = ['easy', 'medium', 'hard']
+     # model_name = "meta-llama/llama-2-70b-chat"
+     method_name = "llama-2-13b-Smurfs"
+     parser_llm = OpenAI_Model(model_name="gpt-4")
+     for level in levels:
+         # level = 'easy'
+         # model_name = "gpt-4-0613"
+         new_results = []
+         test_set = f"hotpot_qa_{level}"
+         HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/post_process/parser_log")
+         output_dir = f"data/{method_name}/{test_set}/answer"
+         new_results_path = f"data/{method_name}/{test_set}/post_process/results.json"
+         new_results_dir = f"data/{method_name}/{test_set}/post_process/"
+         results_dir = f"data/{method_name}/{test_set}/results.json"
+         if not os.path.exists(new_results_path):
+             os.makedirs(f"data/{method_name}/{test_set}/post_process/parser_log")
+             # os.makedirs(new_results_dir)
+             with open(results_dir, "r") as file:
+                 results = json.load(file)
+
+             total_len = len(results)
+             print(total_len)
+
+             threads = []
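+             # Fan the results out over at most 20 worker threads; with fewer
+             # than 20 items each item gets its own thread.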
+             if total_len < 20:
+                 for i in range(total_len):
+                     if total_len == 0:
+                         break
+
+                     start = i
+                     end = i+1
+                     if i == total_len-1:
+                         query_cur = results[start:]
+                     else:
+                         query_cur = results[start: end]
+                     t = threading.Thread(target=post_run, args=(query_cur, HP_answer_agent))
+                     t.start()
+                     threads.append(t)
+
+             else:
+                 for i in range(20):
+
+                     if total_len == 0:
+                         break
+
+                     start = round(total_len/20)*i
+                     end = round(total_len/20)*(i+1)
+                     if i == 19:
+                         query_cur = results[start:]
+                     else:
+                         query_cur = results[start: end]
+                     t = threading.Thread(target=post_run, args=(query_cur, HP_answer_agent))
+                     t.start()
+                     threads.append(t)
+
+             for thread in threads:
+                 thread.join()
+
+             with open(new_results_path, "w") as file:
+                 json.dump(new_results, file, indent=4, ensure_ascii=False)
+             correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(new_results)
+             print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
+             print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
+         else:
+             with open(new_results_path, "r") as file:
+                 new_results = json.load(file)
+             correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(new_results)
+             print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
+             print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
+
Smurfs/eval/hotpot_qa/run_eval.py ADDED
@@ -0,0 +1,395 @@
+ import os
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
+ from Smurfs.inference.smurfs_worker import smurfs_hotpot_worker
+ from Smurfs.tools.tool_env import HotpotToolEnv
+ from Smurfs.model.openai_model.openai_model import OpenAI_Model, OpenRouter_Model
+ from Smurfs.agents.answer_agent.answer import answer_agent
+ from Smurfs.agents.executor_agent.executor import executor_agent
+ from Smurfs.agents.planning_agent.planner import hotpot_planning_agent
+ from Smurfs.agents.verifier_agent.verifier import verifier_agent
+ from Smurfs.eval.hotpot_qa.utils import eval_result
+ import json
+ import threading
+ import joblib
+ from tqdm import tqdm
+ import time
+ import warnings
+
+ # Suppress all warnings
+ warnings.filterwarnings('ignore')
+
+ def run(worker, query, query_id):
+     # global lock
+     final_answer, output_file_ele, solution_file_ele = worker.run(query, query_id)
+     # lock.acquire()
+     worker.save_solution(output_file_ele, solution_file_ele, query_id)
+     # lock.release()
+     return final_answer
+
+ def run_one_hotpot(ques, ans, HP_answer_agent, worker, query_id):
+     global results
+     # ques, ans = task_instructions[0]
+     # print(ques)
+     # print(ans)
+     pre = run(worker, ques, query_id)
+     print(pre)
+
+     # question = "Where was the first governor after the The Missouri Compromise from?"
+     # detailed_answer = "The first governor of Missouri after the Missouri Compromise was Alexander McNair. He was originally from central Pennsylvania, specifically, he was born in Cumberland County on May 5, 1775, and later lived in Derry township, Lancaster (now Dauphin) County. He also pursued his education in Derry and attended the University of Pennsylvania in Philadelphia for a term."
+     parsed_pre = HP_answer_agent.run(query_id=query_id, task="parse", question=ques, detailed_answer=pre)
+     print(parsed_pre)
+
+     result_ele = {"question": ques, "gt_answer": ans, "pre_ans": pre, "parsed_pre": parsed_pre, "id": query_id}
+     lock.acquire()
+     results.append(result_ele)
+     lock.release()
+
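+ # Note: run_hotpot retries each query indefinitely, sleeping 60s after any
+ # exception (e.g. API rate limits), so one persistently failing query can
+ # loop forever.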
+ def run_hotpot(query_list, HP_answer_agent, worker):
+     with tqdm(total=len(query_list), desc="Processing files", initial=0) as pbar:
+         for i, test_task_ins in enumerate(query_list, start=0):
+             idx = test_task_ins[0]
+             ques, ans = test_task_ins[1]
+             while True:
+                 try:
+                     run_one_hotpot(ques, ans, HP_answer_agent, worker, idx)
+                     break
+                 except Exception as e:
+                     print(e)
+                     print("some error occurs, continue...")
+                     time.sleep(60)
+                     continue
+
+             pbar.update(1)
+     return
+
+ # Test all three test sets
+ # if __name__ == '__main__':
+ # #store true pre and parse pre in a same json, calculate together
+ # #dump them together
+ # lock = threading.Lock()
+ # levels = ['easy', 'medium', 'hard']
+ # model_name = "gpt-3.5-turbo"
+ # method_name = "GPT3-turbo-Smurfs"
+ # llm = OpenAI_Model(model_name=model_name)
+ # parser_llm = OpenAI_Model(model_name="gpt-4")
+ # task_path = "/Users/chenjunzhi/Desktop/smurfs_more/AutoAct/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa"
+ # with open("/Users/chenjunzhi/Desktop/smurfs_more/Smurfs/Smurfs/tools/hotpot.json", "r") as f:
+ # available_tools = json.load(f)
+ # for level in levels:
+ # # level = 'hard'
+ # # model_name = "gpt-4-0613"
+ # results = []
+ # test_set = f"hotpot_qa_{level}"
+
+ # output_dir = f"data/{method_name}/{test_set}/answer"
+ # results_dir = f"data/{method_name}/{test_set}/results.json"
+ # if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
+ # os.makedirs(f"data/{method_name}/{test_set}/parser_log")
+ # if not os.path.exists(output_dir):
+ # os.makedirs(output_dir)
+ # if os.path.exists(results_dir):
+ # with open(results_dir, "r") as file:
+ # results = json.load(file)
+
+ # items = os.listdir(output_dir)
+ # for i in range(len(items)):
+ # items[i] = items[i].split(".")[0]
+
+ # HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
+ # worker = smurfs_worker(available_tools, tool_env, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
+ # hotpot = joblib.load(f'{task_path}/{level}.joblib').reset_index(drop = True)
+ # task_instructions = [(row['question'], row['answer']) for _, row in hotpot.iterrows()]
+
+
+ # query_to_do = []
+ # # if len(items) != 0:
+ # for idx, q in enumerate(task_instructions):
+ # # print(idx)
+ # if str(idx) in items:
+ # continue
+ # # query_id = q["query_id"]
+ # # if str(query_id) not in test_ids:
+ # # continue
+ # query_to_do_ele = (idx, q)
+ # query_to_do.append(query_to_do_ele)
+
+ # total_len = len(query_to_do)
+ # query_len = len(task_instructions)
+ # print(total_len)
+
+ # threads = []
+ # if total_len < 20:
+ # for i in range(total_len):
+ # if total_len == 0:
+ # break
+
+ # start = i
+ # end = i+1
+ # if i == total_len-1:
+ # query_cur = query_to_do[start:]
+ # else:
+ # query_cur = query_to_do[start: end]
+ # t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
+ # t.start()
+ # threads.append(t)
+
+ # else:
+ # for i in range(20):
+
+ # if total_len == 0:
+ # break
+
+ # start = round(total_len/20)*i
+ # end = round(total_len/20)*(i+1)
+ # if i == 19:
+ # query_cur = query_to_do[start:]
+ # else:
+ # query_cur = query_to_do[start: end]
+ # t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
+ # t.start()
+ # threads.append(t)
+
+ # for thread in threads:
+ # thread.join()
+
+ # with open(results_dir, "w") as file:
+ # json.dump(results, file, indent=4, ensure_ascii=False)
+
+ # with open(results_dir, "r") as file:
+ # eval_data = json.load(file)
+
+ # correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(eval_data)
+ # print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
+ # print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
+
+ # with open(f"data/{method_name}/{test_set}/parsed_result.json", "w") as file:
+ # json.dump(parsed_dict, file, indent=4, ensure_ascii=False)
+
+ # with open(f"data/{method_name}/{test_set}/original_result.json", "w") as file:
+ # json.dump(pre_dict, file, indent=4, ensure_ascii=False)
+
+
+
+
+
+
+
+ # Test a single query
+ # if __name__ == '__main__':
+ # #store true pre and parse pre in a same json, calculate together
+ # #dump them together
+ # lock = threading.Lock()
+ # levels = ['easy', 'medium', 'hard']
+ # model_name = "gpt-3.5-turbo"
+ # method_name = "GPT3-test-Smurfs"
+ # llm = OpenAI_Model(model_name=model_name)
+ # parser_llm = OpenAI_Model(model_name="gpt-4")
+ # task_path = "/Users/chenjunzhi/Desktop/smurfs_more/AutoAct/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa"
+ # with open("/Users/chenjunzhi/Desktop/smurfs_more/Smurfs/Smurfs/tools/hotpot.json", "r") as f:
+ # available_tools = json.load(f)
+ # # for level in levels:
+ # level = 'hard'
+ # # model_name = "gpt-4-0613"
+ # results = []
+ # test_set = f"hotpot_qa_{level}"
+
+ # output_dir = f"data/{method_name}/{test_set}/answer"
+ # results_dir = f"data/{method_name}/{test_set}/results.json"
+ # if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
+ # os.makedirs(f"data/{method_name}/{test_set}/parser_log")
+ # if not os.path.exists(output_dir):
+ # os.makedirs(output_dir)
+ # if os.path.exists(results_dir):
+ # with open(results_dir, "r") as file:
+ # results = json.load(file)
+
+ # items = os.listdir(output_dir)
+ # for i in range(len(items)):
+ # items[i] = items[i].split(".")[0]
+
+ # HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
+ # worker = smurfs_worker(available_tools, tool_env, llm, method_name, test_set, answer_agent, executor_agent,hotpot_planning_agent, verifier_agent)
+ # hotpot = joblib.load(f'{task_path}/{level}.joblib').reset_index(drop = True)
+ # task_instructions = [(row['question'], row['answer']) for _, row in hotpot.iterrows()]
+ # ques, ans = task_instructions[15][0], task_instructions[15][1]
+ # run_one_hotpot(ques, ans, HP_answer_agent, worker, 0)
+
+
+ # # query_to_do = []
+ # # if len(items) != 0:
+ # # for idx, q in enumerate(task_instructions):
+ # # # print(idx)
+ # # if str(idx) in items:
+ # # continue
+ # # # query_id = q["query_id"]
+ # # # if str(query_id) not in test_ids:
+ # # # continue
+ # # query_to_do_ele = (idx, q)
+ # # query_to_do.append(query_to_do_ele)
+
+ # # total_len = len(query_to_do)
+ # # query_len = len(task_instructions)
+ # # print(total_len)
+
+ # # threads = []
+ # # if total_len < 20:
+ # # for i in range(total_len):
+ # # if total_len == 0:
+ # # break
+
+ # # start = i
+ # # end = i+1
+ # # if i == total_len-1:
+ # # query_cur = query_to_do[start:]
+ # # else:
+ # # query_cur = query_to_do[start: end]
+ # # t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
+ # # t.start()
+ # # threads.append(t)
+
+ # # else:
+ # # for i in range(20):
+
+ # # if total_len == 0:
+ # # break
+
+ # # start = round(total_len/20)*i
+ # # end = round(total_len/20)*(i+1)
+ # # if i == 19:
+ # # query_cur = query_to_do[start:]
+ # # else:
+ # # query_cur = query_to_do[start: end]
+ # # t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
+ # # t.start()
+ # # threads.append(t)
+
+ # # for thread in threads:
+ # # thread.join()
+
+ # with open(results_dir, "w") as file:
+ # json.dump(results, file, indent=4, ensure_ascii=False)
+
+ # # with open(results_dir, "r") as file:
+ # # eval_data = json.load(file)
+
+ # # correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(eval_data)
+ # # print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
+ # # print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
+
+ # # with open(f"data/{method_name}/{test_set}/parsed_result.json", "w") as file:
+ # # json.dump(parsed_dict, file, indent=4, ensure_ascii=False)
+
+ # # with open(f"data/{method_name}/{test_set}/original_result.json", "w") as file:
+ # # json.dump(pre_dict, file, indent=4, ensure_ascii=False)
+
+
+
+
+
+ # Test the test sets one by one
+ if __name__ == '__main__':
+     # store true pre and parse pre in a same json, calculate together
+     # dump them together
+     lock = threading.Lock()
+     levels = ['easy', 'medium', 'hard']
+     model_name = "meta-llama/llama-2-70b-chat"
+     method_name = "llama-2-13b-Smurfs"
+     # llm = OpenAI_Model(model_name=model_name)
+     llm = OpenRouter_Model(model_name=model_name)
+     parser_llm = OpenAI_Model(model_name="gpt-4")
+     task_path = "/Users/chenjunzhi/Desktop/smurfs_more/AutoAct/Self_Plan/Group_Planning/benchmark_run/data/hotpotqa"
+     with open("/Users/chenjunzhi/Desktop/smurfs_more/Smurfs/Smurfs/tools/hotpot.json", "r") as f:
+         available_tools = json.load(f)
+     for level in levels:
+         # level = 'easy'
+         # model_name = "gpt-4-0613"
+         results = []
+         test_set = f"hotpot_qa_{level}"
+
+         output_dir = f"data/{method_name}/{test_set}/answer"
+         results_dir = f"data/{method_name}/{test_set}/results.json"
+         if not os.path.exists(f"data/{method_name}/{test_set}/parser_log"):
+             os.makedirs(f"data/{method_name}/{test_set}/parser_log")
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+         if os.path.exists(results_dir):
+             with open(results_dir, "r") as file:
+                 results = json.load(file)
+
+         items = os.listdir(output_dir)
+         for i in range(len(items)):
+             items[i] = items[i].split(".")[0]
+
+         HP_answer_agent = answer_agent(llm=parser_llm, logger_dir=f"data/{method_name}/{test_set}/parser_log")
+         worker = smurfs_hotpot_worker(available_tools, HotpotToolEnv, llm, method_name, test_set, answer_agent, executor_agent, hotpot_planning_agent, verifier_agent)
+         hotpot = joblib.load(f'{task_path}/{level}.joblib').reset_index(drop = True)
+         task_instructions = [(row['question'], row['answer']) for _, row in hotpot.iterrows()]
+
+
+         query_to_do = []
+         # if len(items) != 0:
+         for idx, q in enumerate(task_instructions):
+             # print(idx)
+             if str(idx) in items:
+                 continue
+             # query_id = q["query_id"]
+             # if str(query_id) not in test_ids:
+             #     continue
+             query_to_do_ele = (idx, q)
+             query_to_do.append(query_to_do_ele)
+
+         total_len = len(query_to_do)
+         query_len = len(task_instructions)
+         print(total_len)
+
+         threads = []
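+         # Same fan-out pattern as post_process.py: split the remaining queries
+         # across up to 20 threads, the last thread taking the remainder slice.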
+         if total_len < 20:
+             for i in range(total_len):
+                 if total_len == 0:
+                     break
+
+                 start = i
+                 end = i+1
+                 if i == total_len-1:
+                     query_cur = query_to_do[start:]
+                 else:
+                     query_cur = query_to_do[start: end]
+                 t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
+                 t.start()
+                 threads.append(t)
+
+         else:
+             for i in range(20):
+
+                 if total_len == 0:
+                     break
+
+                 start = round(total_len/20)*i
+                 end = round(total_len/20)*(i+1)
+                 if i == 19:
+                     query_cur = query_to_do[start:]
+                 else:
+                     query_cur = query_to_do[start: end]
+                 t = threading.Thread(target=run_hotpot, args=(query_cur, HP_answer_agent, worker))
+                 t.start()
+                 threads.append(t)
+
+         for thread in threads:
+             thread.join()
+
+         with open(results_dir, "w") as file:
+             json.dump(results, file, indent=4, ensure_ascii=False)
+
+         with open(results_dir, "r") as file:
+             eval_data = json.load(file)
+
+         correct, reward, parsed_correct, parsed_reward, pre_dict, parsed_dict = eval_result(eval_data)
+         print(f"correct rate for {test_set} is: {correct}, reward rate for {test_set} is: {reward}")
+         print(f"parsed correct rate for {test_set} is: {parsed_correct}, parsed reward rate for {test_set} is: {parsed_reward}")
+
+         with open(f"data/{method_name}/{test_set}/parsed_result.json", "w") as file:
+             json.dump(parsed_dict, file, indent=4, ensure_ascii=False)
+
+         with open(f"data/{method_name}/{test_set}/original_result.json", "w") as file:
+             json.dump(pre_dict, file, indent=4, ensure_ascii=False)
Smurfs/eval/hotpot_qa/utils.py ADDED
@@ -0,0 +1,117 @@
+ import re
+ from collections import Counter
+ import string
+ import json
+
+ def format_step(step: str) -> str:
+     step = step.strip('\n').strip().replace('\n', '')
+     if step.startswith("Thought") or step.startswith("Action"):
+         step = step.split()[2:]
+         step = " ".join(step)
+     if "Thought" in step:
+         step = step.split("Thought")[0].strip()
+     if "Action" in step:
+         step = step.split("Action")[0].strip()
+     if "Observation" in step:
+         step = step.split("Observation")[0].strip()
+     return step
+
+ def normalize_answer(s):
+     def remove_articles(text):
+         return re.sub(r"\b(a|an|the)\b", " ", text)
+
+     def white_space_fix(text):
+         return " ".join(text.split())
+
+     def remove_punc(text):
+         exclude = set(string.punctuation)
+         return "".join(ch for ch in text if ch not in exclude)
+
+     def lower(text):
+         return text.lower()
+
+     return white_space_fix(remove_articles(remove_punc(lower(s))))
+
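+ # Token-level F1 as in the official HotpotQA evaluation: yes/no/noanswer
+ # predictions must match the ground truth exactly, otherwise they score 0.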
+ def f1_score(prediction, ground_truth):
+     normalized_prediction = normalize_answer(prediction)
+     normalized_ground_truth = normalize_answer(ground_truth)
+
+     ZERO_METRIC = (0, 0, 0)
+
+     if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+         return ZERO_METRIC
+     if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
+         return ZERO_METRIC
+
+     prediction_tokens = normalized_prediction.split()
+     ground_truth_tokens = normalized_ground_truth.split()
+     common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+     num_same = sum(common.values())
+     if num_same == 0:
+         return ZERO_METRIC
+     precision = 1.0 * num_same / len(prediction_tokens)
+     recall = 1.0 * num_same / len(ground_truth_tokens)
+     f1 = (2 * precision * recall) / (precision + recall)
+     return f1, precision, recall
+
+ def EM(answer, key) -> bool:
+     return normalize_answer(answer) == normalize_answer(key)
+
+ def score_string_similarity(str1, str2):
+     if str1 == str2:
+         return 2.0
+     elif " " in str1 or " " in str2:
+         str1_split = str1.split(" ")
+         str2_split = str2.split(" ")
+         overlap = list(set(str1_split) & set(str2_split))
+         return len(overlap) / max(len(str1_split), len(str2_split))
+     else:
+         return 0.0
+
+ def eval_result_once(question, pre, gt):
+     correct = EM(pre, gt)
+     reward = f1_score(pre, gt)[0]
+     # halted = agent.is_halted()
+     # error = agent.run_error
+     # prompt = agent._build_agent_prompt()
+     save_dict = {"question": question, "answer": gt, "prediction": pre, "EM": correct, "reward": reward}
+     # with open(file_path, 'a') as f:
+     #     json.dump(save_dict, f)
+     #     f.write("\n")
+     return save_dict
+
+ def eval_result(eval_data):
+     result = []
+     parsed_result = []
+     correct = 0
+     reward = 0
+     parsed_correct = 0
+     parsed_reward = 0
+     total_len = len(eval_data)
+     for d in eval_data:
+         pre = d["pre_ans"]
+         parsed_pre = d["parsed_pre"]
+         gt = d["gt_answer"]
+         question = d["question"]
+         pre_dict = eval_result_once(question, pre, gt)
+         parsed_dict = eval_result_once(question, parsed_pre, gt)
+         result.append(pre_dict)
+         parsed_result.append(parsed_dict)
+
+         correct += pre_dict["EM"]
+         reward += pre_dict["reward"]
+
+         parsed_correct += parsed_dict["EM"]
+         parsed_reward += parsed_dict["reward"]
+
+     correct /= total_len
+     reward /= total_len
+
+     parsed_correct /= total_len
+     parsed_reward /= total_len
+
+     return correct, reward, parsed_correct, parsed_reward, result, parsed_result
+
Smurfs/inference/__init__.py ADDED
File without changes
Smurfs/inference/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (170 Bytes). View file
 
Smurfs/inference/__pycache__/inference.cpython-39.pyc ADDED
Binary file (12.1 kB). View file
 
Smurfs/inference/__pycache__/server.cpython-39.pyc ADDED
Binary file (5.27 kB). View file
 
Smurfs/inference/__pycache__/smurfs_worker.cpython-39.pyc ADDED
Binary file (17.1 kB). View file
 
Smurfs/inference/__pycache__/utils.cpython-39.pyc ADDED
Binary file (4.52 kB). View file
 
Smurfs/inference/functioncall_inference.py ADDED
@@ -0,0 +1,533 @@
+ # -*- coding: utf-8 -*-
+ import json
+ import sys
+ import argparse
+ import time
+ import requests
+ import os
+ from tqdm import tqdm
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+ from Smurfs.inference.utils import change_name, standardize, get_white_list, get_answer_log, get_observation_log, build_tree, get_answer_details, test_sets
+ from Smurfs.model.vllm_model.vllm_model import vllm_Model
+ from Smurfs.inference.server import get_rapidapi_response
+ import threading
+ from Smurfs.agents.answer_agent.answer import answer_agent
+ from Smurfs.agents.executor_agent.executor import function_call_executor_agent
+ from Smurfs.agents.planning_agent.planner import planning_agent
+ from Smurfs.agents.verifier_agent.verifier import verifier_agent
+
+ def _Call_function(category, tool_name, api_name, tool_input, strip, white_list, toolbench_key, args):
+     use_rapidapi_key = args.use_rapidapi_key
+     rapidapi_key = os.environ.get("rapidapi_key")
+     api_customization = args.api_customization
+
+     api_name = change_name(standardize(api_name))
+     tool_name = standardize(tool_name)
+     if tool_name not in white_list.keys():
+         print(f"tool name doesn't exist: {tool_name}")
+         return {}, 1
+     standard_tool_name = white_list[tool_name]["standard_tool_name"]
+     payload = {
+         "category": category,
+         "tool_name": standard_tool_name,
+         "api_name": api_name,
+         "tool_input": tool_input,
+         "strip": strip,
+         "toolbench_key": toolbench_key
+     }
+     if use_rapidapi_key or api_customization:
+         payload["rapidapi_key"] = rapidapi_key
+         response = get_rapidapi_response(payload, api_customization=api_customization)
+     else:
+         time.sleep(2)  # rate limit: 30 per minute
+         headers = {"toolbench_key": toolbench_key}
+         print(payload)
+         # if tool_input == {}:
+         #     response = requests.post("http://8.218.239.54:8080/rapidapi", headers=headers, timeout=15)
+         # else:
+         response = requests.post("http://8.218.239.54:8080/rapidapi", json=payload, headers=headers, timeout=15)
+         if response.status_code != 200:
+             return json.dumps({"error": f"request invalid, data error. status_code={response.status_code}", "response": ""}), 12
+         try:
+             response = response.json()
+         except:
+             print(response)
+             return json.dumps({"error": "request invalid, data error", "response": ""}), 12
+     # 1 Hallucinating function names
+     # 4 means that the model decides to pruning by itself
+     # 5 represents api call timeout
+     # 6 for 404
+     # 7 means not subscribed
+     # 8 represents unauthorized
+     # 9 represents too many requests
+     # 10 stands for rate limit
+     # 11 message contains "error" field
+     # 12 error sending request
+     if response["error"] == "API not working error...":
+         status_code = 6
+     elif response["error"] == "Unauthorized error...":
+         status_code = 7
+     elif response["error"] == "Unsubscribed error...":
+         status_code = 8
+     elif response["error"] == "Too many requests error...":
+         status_code = 9
+     elif response["error"] == "Rate limit per minute error...":
+         print("Reach api calling limit per minute, sleeping...")
+         time.sleep(10)
+         status_code = 10
+     elif response["error"] == "Message error...":
+         status_code = 11
+     elif response["error"] != "":
+         status_code = "unexpected error, try again!"
+     else:
+         status_code = 0
+     return json.dumps(response), status_code
+
+ def Call_function(category, tool_name, api_name, tool_input, strip, white_list, args):
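+     # Retry the call with progressively normalized parameter names (lowercased,
+     # "-" replaced by "_", backslashes stripped) before giving up and logging
+     # the failure to wrong_log.json.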
+     toolbench_key = os.environ.get("toolbench_key")
+     response, status_code = _Call_function(category, tool_name, api_name, tool_input, strip, white_list, toolbench_key, args)
+     if status_code == "unexpected error, try again!":
+         arg = {change_name(k.lower()): v for k, v in tool_input.items()}
+         response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
+         if status_code == "unexpected error, try again!":
+             arg = {change_name(k.replace("-", "_")): v for k, v in tool_input.items()}
+             response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
+             if status_code == "unexpected error, try again!":
+                 arg = {change_name(k.replace("\\", "")): v for k, v in tool_input.items()}
+                 response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
+                 if status_code == "unexpected error, try again!":
+                     print("Call function fails")
+                     with open('wrong_log.json', 'a+', encoding='utf-8') as f:
+                         line = json.dumps({
+                             "id": 0,
+                             "parameters": arg,
+                             "wrong": response
+                         }, ensure_ascii=False)
+                         f.write(line + '\n')
+                     return -1
+     return response
+
+ def inference(query, relevant_APIs, white_list, subtask, Answer_Agent, Executor_Agent, Verifier_Agent, query_id, args, max_step=3):
+     tool_check_num = Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
+     # direct answer
+     if tool_check_num == 1:
+         input_dic = {"task": query}
+         answer = Answer_Agent.run(input_dic)
+         return answer, answer, None, None
+
+     previous_log = []
+     history_log = []
+     tool_used_dic = {}
+     relevant_APIs_ids = []
+     for idx in relevant_APIs:
+         ele = relevant_APIs[idx]
+         relevant_APIs_ids.append(str(ele["ID"]))
+     restart_time = 0
+     step_num = 0
+     hint = "Beginning of the agent. No hint yet"
+     retry_tool_id = 0
+     retry_parameter = 0
+     re_time = 0
+     subtask_id = 0
+     restart = 0
+     while True:
+         if step_num >= max_step:
+             print("\n\nReach steps limits, return answers!\n\n")
+             answer_log = get_answer_log(history_log)
+             answer = Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+             return answer, previous_log, re_time, history_log
+
+         if step_num not in tool_used_dic.keys():
+             tool_used_dic[step_num] = []
+
+         tool_used = tool_used_dic[step_num]
+
+         tool_list = []
+         for idx in relevant_APIs:
+             ele = idx
+             ID = str(ele['api_name'])
+             if ID in tool_used:
+                 continue
+             # des = ele['description']
+             # name = ele["tool_name"]
+             # tool_list.append({"ID": ID, "tool_name": name, "description": des})
+             tool_list.append(ele)
+
+         if len(tool_list) == 0:
+             if len(previous_log) == 0:
+                 answer_log = get_answer_log(history_log)
+                 partial_answer = Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+                 answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
+                 return answer, previous_log, re_time, history_log
+             else:
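+                 # Backtrack: every remaining tool at this step has been tried,
+                 # so undo the last step and blacklist its tool before retrying.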
+                 delete_log = previous_log.pop()
+                 tool_used_dic[step_num] = []
+                 step_num -= 1
+                 tool_used_dic[step_num].append(delete_log["tool"])
+                 restart_time += 1
+                 re_time += 1
+                 continue
+
+         current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "", "id": subtask_id}
+
+         answer_log = get_answer_log(previous_log)
+
+         if retry_tool_id == 4:
+             # After four failed attempts, fall back to the first remaining tool.
+             tool_id = tool_list[0]["ID"]
+             tool_list = tool_list[0]
+             thought = Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+
+         else:
+             thought = Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+             tool_id = Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
+
+         try:
+             tool_id = int(tool_id)
+             tool_id = str(tool_id)
+             if tool_id not in relevant_APIs_ids:
+                 re_time += 1
+                 retry_tool_id += 1
+                 print("Tool ID wrong! Generated a tool_id that does not exist!")
+                 continue
+             tool_des_json = relevant_APIs[str(tool_id)]
+             retry_tool_id = 0
+         except:
+             retry_tool_id += 1
+             print("Tool ID wrong! Generated a tool_id that does not exist!")
+             continue
+
+         tool_name_list = tool_des_json["tool_name"].split(":")
+         category_name = tool_name_list[0]
+         tool_name = tool_name_list[1]
+         api_name = tool_name_list[2]
+         API_doc = tool_des_json
+
+         while True:
+             try:
+                 parameters = {}
+
+                 if retry_parameter == 4:
+                     restart = 1
+                     retry_parameter = 0
+                     print("No Para! Restart!")
+                     break
+
+                 parameter = Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
+                 if parameter == -1:
+                     retry_parameter += 1
+                     re_time += 1
+                     continue
+
+                 if parameter == {}:
+                     retry_parameter = 0
+                     parameters = {}
+                     break
+
+                 for key in parameter:
+                     value = parameter[key]
+                     key = change_name(key)
+                     parameters[key] = value
+
+                 retry_parameter = 0
+                 break
+
+             except:
+                 if retry_parameter == 4:
+                     parameters = {}
+                     retry_parameter = 0
+                     restart = 1
+                     break
+                 retry_parameter += 1
+                 print("parameter generation fails, try again!")
+                 re_time += 1
+                 continue
+
+
+         api_name = change_name(standardize(api_name))
+
+         if restart != 1:
+             try:
+                 observation = Call_function(category_name, tool_name, api_name, parameters, "truncate", white_list, args)
+             except:
+                 observation = -1
+
+             if observation == -1:
+                 restart = 1
+                 observation = str({"error": "", "response": "call API fails"})
+
+         if restart == 1:
+             tool_used_dic[step_num].append(str(tool_id))
+             print('****Try Again For This Step****')
+             re_time += 1
+             restart = 0
+             continue
+
+
+         if len(previous_log) != 0:
+             previous_id = previous_log[-1]["id"]
+         else:
+             previous_id = -1
+
+         current_log["tool"] = str(tool_id)
+         current_log["thought"] = thought
+         current_log["action"] = api_name
+         current_log["action_input"] = parameters
+         current_log["observation"] = observation
+         previous_log.append(current_log)
+
+         observation_log = get_observation_log(previous_log)
+
+         answer = Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
+
+         previous_log[-1]["answer"] = answer
+
+         history_log_ele = {"thought": thought, "action": tool_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
+         history_log.append(history_log_ele)
+         subtask_id += 1
+
+         speak, status = Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
+         if speak == -1 and status == -1:
+             step_num += 1
+             continue
+
+         try:
+             if int(status) == 0:
+                 hint = speak
+                 step_num += 1
+                 continue
+             else:
+                 return answer, previous_log, re_time, history_log
+         except:
+             step_num += 1
+             continue
+
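+ # Decompose-then-solve: the planner splits the query into subtasks; each
+ # subtask runs through inference() with the rolling task_log as context, and
+ # the answer agent fuses the per-subtask answers into the final answer.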
+ def decompose_inference(query, relevant_APIs, api_list, white_list, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, query_id, args):
+     while True:
+         subtasks = Planning_Agent.run(question=query, query_id=query_id)
+         if subtasks == -1:
+             continue
+         break
+     task_log = ""
+     history_log = []
+     previous_log_totals = []
+     re_time_total = 0
+     print(subtasks)
+     relevant_API_list = []
+     # tool_id = 0
+     for api in api_list:
+         for relevant_API in relevant_APIs:
+             if relevant_API[0] == api["tool_name"] and relevant_API[1] == api["api_name"]:
+                 # new_tool_name = api["category_name"]+":"+api["tool_name"]+":"+api["api_name"]
+                 ele = api
+                 # ele = {"ID": tool_id, "tool_name": new_tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
+                 # for para in api["required_parameters"]:
+                 #     para_ele = {
+                 #         para["name"]: {
+                 #             "type": para["type"],
+                 #             "description": para["description"]
+                 #         }
+                 #     }
+                 relevant_API_list.append(ele)
+                 # tool_id += 1
+
+     for subtask in subtasks:
+         task_log += f"question: {subtask}\n"
+         answer, previous_log, re_time, previous_log_total = inference(task_log, relevant_API_list, white_list, subtask, Answer_Agent, Executor_Agent, Verifier_Agent, query_id, args)
+         previous_log_totals.append(previous_log_total)
+         print(answer)
+         history_log += previous_log
+         re_time_total += re_time
+         task_log += f"answer: {answer}\n"
+     final_answer = Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
+     return final_answer, history_log, task_log, re_time_total, previous_log_totals
+
+ def test(query_json, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args):
+     while True:
+         try:
+             global lock
+             total_query = len(query_json)
+             with tqdm(total=total_query, desc="Processing files", initial=0) as pbar:
+                 for i, test_query in enumerate(query_json, start=0):
+                     idx = test_query[0]
+                     test_query = test_query[1]
+                     query = test_query["query"]
+                     relevant_APIs = test_query["relevant APIs"]
+                     api_list = test_query["api_list"]
+                     final_answer, previous_log, task_log, re_time, previous_log_totals = decompose_inference(query, relevant_APIs, api_list, white_list, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, idx, args)
+                     answer_details, total_steps = get_answer_details(final_answer, previous_log)
+                     solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
+                     output_file_ele = {
+                         "query": query,
+                         "restart_time": re_time,
+                         "answer": {
+                             "method": "decompose_dfs",
+                             "total_steps": total_steps,
+                             "final_answer": final_answer,
+                             "answer_details": answer_details
+                         }
+                     }
+
+                     solution_file_ele = {
+                         "query": query,
+                         "total_steps": solution_total_steps,
+                         "task_log": task_log,
+                         "final_answer": final_answer,
+                         "answer_path": answer_details,
+                         "total_path": solution_tree
+                     }
+                     file_name = f"{idx}.json"
+                     output_file = os.path.join(output_dir, file_name)
+                     whole_solution_file = os.path.join(whole_solution_dir, file_name)
+                     lock.acquire()
+                     with open(output_file, "w") as file:
+                         json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
+                     with open(whole_solution_file, "w") as file:
+                         json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
+                     lock.release()
+                     pbar.update(1)
+             return
+
+         except Exception as e:
+             print(e)
+             print("some error occurs, continue...")
+             time.sleep(60)
+             continue
+
+ def parse_arg():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--test_query_id_path', type=str, default="toolbench_data/data/test_query_ids", required=False, help='test query ids for different test sets')
+     parser.add_argument('--method_name', type=str, default="smurfs-all", required=False, help='the inference method')
+     parser.add_argument('--model_name', type=str, default="your_model_name", required=False, help='the model name for the vllm model')
+     parser.add_argument('--query_file_dir', type=str, default="toolbench_data/data/test_instruction", required=False, help='the directory that contains test sets')
+     parser.add_argument('--tool_env_dir', type=str, default="toolbench_data/data/toolenv/tools", required=False, help='tool environment for the toolbench')
+     parser.add_argument('--toolbench_key', type=str, default="", required=False, help='your toolbench key to request rapidapi service')
+     parser.add_argument('--rapidapi_key', type=str, default="", required=False, help='your rapidapi key to request rapidapi service')
+     parser.add_argument('--use_rapidapi_key', action="store_true", help="To use customized rapidapi service or not.")
+     parser.add_argument('--api_customization', action="store_true", help="To use customized api or not.")
+     args = parser.parse_args()
+
+     return args
+
+ if __name__ == '__main__':
+     threads = []
+     lock = threading.Lock()
+
+     args = parse_arg()
+
+     test_query_id_path = args.test_query_id_path
+     method_name = args.method_name
+     model_name = args.model_name
+     query_file_dir = args.query_file_dir
+     tool_env_dir = args.tool_env_dir
+
+     toolbench_key = args.toolbench_key
+     rapidapi_key = args.rapidapi_key
+     use_rapidapi_key = args.use_rapidapi_key
+     api_customization = args.api_customization
+
+     chat = vllm_Model(model_name=model_name)
+
+
+     for test_set in test_sets:
+         total_output_file = f"data/{method_name}/{test_set}_raw.json"
+
+         test_ids = list(json.load(open(os.path.join(test_query_id_path, test_set+".json"), "r")).keys())
+
+         query_file = f'{query_file_dir}/{test_set}.json'
+
+         output_dir = f"data/{method_name}/{test_set}/answer"
+
+         whole_solution_dir = f"data/{method_name}/{test_set}/whole"
+
+         logger_dir = f"data/{method_name}/{test_set}/agent_log"
+
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         if not os.path.exists(whole_solution_dir):
+             os.makedirs(whole_solution_dir)
+
+         if not os.path.exists(logger_dir):
+             os.makedirs(logger_dir)
+
+         Answer_Agent = answer_agent(llm=chat, logger_dir=logger_dir)
+         Executor_Agent = function_call_executor_agent(llm=chat, logger_dir=logger_dir, max_observation_length=4096, use_rapidapi_key=use_rapidapi_key, api_customization=api_customization, toolbench_key=toolbench_key, service_url="http://8.218.239.54:8080/rapidapi")
+         Planning_Agent = planning_agent(llm=chat, logger_dir=logger_dir)
+         Verifier_Agent = verifier_agent(llm=chat, logger_dir=logger_dir)
+
+         items = os.listdir(output_dir)
+         for i in range(len(items)):
+             items[i] = items[i].split(".")[0]
+
+         white_list = get_white_list(tool_env_dir)
+
+         with open(query_file) as file:
+             query_json = json.load(file)
+         # with open(tool_doc_dir) as file:
+         #     tool_doc = json.load(file)
+
+         print(len(items))
+         query_json_to_do = []
+         # if len(items) != 0:
+         for idx, q in enumerate(query_json):
+             # print(idx)
+             if str(idx) in items:
+                 continue
+             query_id = q["query_id"]
+             if str(query_id) not in test_ids:
+                 continue
+             query_json_to_do_ele = (idx, q)
+             query_json_to_do.append(query_json_to_do_ele)
+         # else:
+         #     query_json_to_do = query_json
+
+         total_len = len(query_json_to_do)
+         query_len = len(query_json)
+         print(total_len)
+
+         if total_len < 20:
+             for i in range(total_len):
+                 if total_len == 0:
+                     break
+
+                 start = i
+                 end = i+1
+                 if i == total_len-1:
+                     query_json_cur = query_json_to_do[start:]
+                 else:
+                     query_json_cur = query_json_to_do[start: end]
+                 t = threading.Thread(target=test, args=(query_json_cur, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args))
+                 t.start()
+                 threads.append(t)
+
+
+         else:
+             for i in range(20):
+
+                 if total_len == 0:
+                     break
+
+                 start = round(total_len/20)*i
+                 end = round(total_len/20)*(i+1)
+                 if i == 19:
+                     query_json_cur = query_json_to_do[start:]
+                 else:
+                     query_json_cur = query_json_to_do[start: end]
+                 t = threading.Thread(target=test, args=(query_json_cur, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args))
+                 t.start()
+                 threads.append(t)
+
+         for thread in threads:
+             thread.join()
+
+         total_json = {}
+         items = os.listdir(output_dir)
+         for item in items:
+             item_path = os.path.join(output_dir, item)
+             idx = item.split(".")[0]
+             total_json[str(idx)] = json.load(open(item_path, 'r'))
+
+         with open(total_output_file, 'w') as file:
+             json.dump(total_json, file, indent=4, ensure_ascii=False)
+
Smurfs/inference/inference.py ADDED
@@ -0,0 +1,527 @@
1
+ # -*- coding: utf-8 -*-
+ import json
+ import sys
+ import argparse
+ import time
+ import requests
+ import os
+ from tqdm import tqdm
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+ from Smurfs.inference.utils import change_name, standardize, get_white_list, get_answer_log, get_observation_log, build_tree, get_answer_details, test_sets
+ from Smurfs.model.vllm_model.vllm_model import vllm_Model
+ from Smurfs.inference.server import get_rapidapi_response
+ import threading
+ from Smurfs.agents.answer_agent.answer import answer_agent
+ from Smurfs.agents.executor_agent.executor import executor_agent
+ from Smurfs.agents.planning_agent.planner import planning_agent
+ from Smurfs.agents.verifier_agent.verifier import verifier_agent
+ import warnings
+
+ warnings.filterwarnings('ignore')
+
+ def _Call_function(category, tool_name, api_name, tool_input, strip, white_list, toolbench_key, args):
+     use_rapidapi_key = args.use_rapidapi_key
+     rapidapi_key = os.environ.get("rapidapi_key")
+     api_customization = args.api_customization
+
+     api_name = change_name(standardize(api_name))
+     tool_name = standardize(tool_name)
+     if tool_name not in white_list.keys():
+         print(f"tool name doesn't exist: {tool_name}")
+         return {}, 1
+     standard_tool_name = white_list[tool_name]["standard_tool_name"]
+     payload = {
+         "category": category,
+         "tool_name": standard_tool_name,
+         "api_name": api_name,
+         "tool_input": tool_input,
+         "strip": strip,
+         "toolbench_key": toolbench_key
+     }
+     if use_rapidapi_key or api_customization:
+         payload["rapidapi_key"] = rapidapi_key
+         response = get_rapidapi_response(payload, api_customization=api_customization)
+     else:
+         time.sleep(2)  # rate limit: 30 per minute
+         headers = {"toolbench_key": toolbench_key}
+         print(payload)
+         # if tool_input == {}:
+         #     response = requests.post("http://8.218.239.54:8080/rapidapi", headers=headers, timeout=15)
+         # else:
+         response = requests.post("http://8.218.239.54:8080/rapidapi", json=payload, headers=headers, timeout=15)
+         if response.status_code != 200:
+             return json.dumps({"error": f"request invalid, data error. status_code={response.status_code}", "response": ""}), 12
+         try:
+             response = response.json()
+         except:
+             print(response)
+             return json.dumps({"error": "request invalid, data error", "response": ""}), 12
+     # Status codes returned by this function:
+     # 1  hallucinated function name
+     # 4  the model decides to prune the branch by itself
+     # 5  API call timeout
+     # 6  404 / API not working
+     # 7  unauthorized
+     # 8  not subscribed
+     # 9  too many requests
+     # 10 rate limit per minute
+     # 11 message contains an "error" field
+     # 12 error sending the request
+     if response["error"] == "API not working error...":
+         status_code = 6
+     elif response["error"] == "Unauthorized error...":
+         status_code = 7
+     elif response["error"] == "Unsubscribed error...":
+         status_code = 8
+     elif response["error"] == "Too many requests error...":
+         status_code = 9
+     elif response["error"] == "Rate limit per minute error...":
+         print("Reached the per-minute API call limit, sleeping...")
+         time.sleep(10)
+         status_code = 10
+     elif response["error"] == "Message error...":
+         status_code = 11
+     elif response["error"] != "":
+         status_code = "unexpected error, try again!"
+     else:
+         status_code = 0
+     return json.dumps(response), status_code
+
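+ # A quick illustration of _Call_function's contract (hypothetical values, not from a
+ # real run): on success it returns the serialized response and status 0, otherwise a
+ # non-zero code from the table above.
+ #     response, status = _Call_function("Social", "olato_quotes", "love_quote", {}, "truncate", white_list, key, args)
+ #     # -> ('{"error": "", "response": "..."}', 0)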
+ def Call_function(category, tool_name, api_name, tool_input, strip, white_list, args):
+     toolbench_key = os.environ.get("toolbench_key")
+     response, status_code = _Call_function(category, tool_name, api_name, tool_input, strip, white_list, toolbench_key, args)
+     if status_code == "unexpected error, try again!":
+         arg = {change_name(k.lower()): v for k, v in tool_input.items()}
+         response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
+         if status_code == "unexpected error, try again!":
+             arg = {change_name(k.replace("-", "_")): v for k, v in tool_input.items()}
+             response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
+             if status_code == "unexpected error, try again!":
+                 arg = {change_name(k.replace("\\", "")): v for k, v in tool_input.items()}
+                 response, status_code = _Call_function(category, tool_name, api_name, arg, strip, white_list, toolbench_key, args)
+                 if status_code == "unexpected error, try again!":
+                     print("Call function fails")
+                     with open('wrong_log.json', 'a+', encoding='utf-8') as f:
+                         line = json.dumps({
+                             "id": 0,
+                             "parameters": arg,
+                             "wrong": response
+                         }, ensure_ascii=False)
+                         f.write(line + '\n')
+                     return -1
+     return response
+
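+ # The fallback chain above re-sends the call with progressively normalized parameter
+ # keys. Illustrative only: a model-produced input like {"User-Name": "bob"} is retried
+ # with the keys lowercased, then with dashes replaced by underscores, then with
+ # backslashes stripped (each pass also goes through change_name) until the backend
+ # stops reporting an unexpected error.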
+ def inference(query, relevant_APIs, white_list, subtask, Answer_Agent, Executor_Agent, Verifier_Agent, query_id, args, max_step=3):
+     tool_check_num = Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
+     # direct answer
+     if tool_check_num == 1:
+         input_dic = {"task": query}
+         answer = Answer_Agent.run(input_dic)
+         return answer, answer, None, None
+
+     previous_log = []
+     history_log = []
+     tool_used_dic = {}
+     relevant_APIs_ids = []
+     for idx in relevant_APIs:
+         ele = relevant_APIs[idx]
+         relevant_APIs_ids.append(str(ele["ID"]))
+     restart_time = 0
+     step_num = 0
+     hint = "Beginning of the agent. No hint yet"
+     retry_tool_id = 0
+     retry_parameter = 0
+     re_time = 0
+     subtask_id = 0
+     restart = 0
+     while True:
+         if step_num >= max_step:
+             print("\n\nReached the step limit, returning the answer!\n\n")
+             answer_log = get_answer_log(history_log)
+             answer = Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+             return answer, previous_log, re_time, history_log
+
+         if step_num not in tool_used_dic.keys():
+             tool_used_dic[step_num] = []
+
+         tool_used = tool_used_dic[step_num]
+
+         tool_list = []
+         for idx in relevant_APIs:
+             ele = relevant_APIs[idx]
+             ID = str(ele['ID'])
+             if ID in tool_used:
+                 continue
+             des = ele['description']
+             name = ele["tool_name"]
+             tool_list.append({"ID": ID, "tool_name": name, "description": des})
+
+         if len(tool_list) == 0:
+             if len(previous_log) == 0:
+                 answer_log = get_answer_log(history_log)
+                 partial_answer = Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+                 answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
+                 return answer, previous_log, re_time, history_log
+             else:
+                 delete_log = previous_log.pop()
+                 tool_used_dic[step_num] = []
+                 step_num -= 1
+                 tool_used_dic[step_num].append(delete_log["tool"])
+                 restart_time += 1
+                 re_time += 1
+                 continue
+
+         current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "", "id": subtask_id}
+
+         answer_log = get_answer_log(previous_log)
+
+         if retry_tool_id == 4:
+             tool_id = tool_list[0]["ID"]
+             tool_list = tool_list[0]
+             thought = Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+         else:
+             thought = Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+             tool_id = Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
+
+         try:
+             tool_id = int(tool_id)
+             tool_id = str(tool_id)
+             if tool_id not in relevant_APIs_ids:
+                 re_time += 1
+                 retry_tool_id += 1
+                 print("Tool ID wrong! Generated a tool_id that does not exist!")
+                 continue
+             tool_des_json = relevant_APIs[str(tool_id)]
+             retry_tool_id = 0
+         except:
+             retry_tool_id += 1
+             print("Tool ID wrong! Generated a tool_id that does not exist!")
+             continue
+
+         tool_name_list = tool_des_json["tool_name"].split(":")
+         category_name = tool_name_list[0]
+         tool_name = tool_name_list[1]
+         api_name = tool_name_list[2]
+         API_doc = tool_des_json
+
+         while True:
+             try:
+                 parameters = {}
+
+                 if retry_parameter == 4:
+                     restart = 1
+                     retry_parameter = 0
+                     print("No parameters! Restart!")
+                     break
+
+                 parameter = Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
+                 if parameter == -1:
+                     retry_parameter += 1
+                     re_time += 1
+                     continue
+
+                 if parameter == {}:
+                     retry_parameter = 0
+                     parameters = {}
+                     break
+
+                 for key in parameter:
+                     value = parameter[key]
+                     key = change_name(key)
+                     parameters[key] = value
+
+                 retry_parameter = 0
+                 break
+
+             except:
+                 if retry_parameter == 4:
+                     parameters = {}
+                     retry_parameter = 0
+                     restart = 1
+                     break
+                 retry_parameter += 1
+                 print("parameter generation fails, try again!")
+                 re_time += 1
+                 continue
+
+         api_name = change_name(standardize(api_name))
+
+         if restart != 1:
+             try:
+                 observation = Call_function(category_name, tool_name, api_name, parameters, "truncate", white_list, args)
+             except:
+                 observation = -1
+
+             if observation == -1:
+                 restart = 1
+                 observation = str({"error": "", "response": "call API fails"})
+
+         if restart == 1:
+             tool_used_dic[step_num].append(str(tool_id))
+             print('****Try Again For This Step****')
+             re_time += 1
+             restart = 0
+             continue
+
+         if len(previous_log) != 0:
+             previous_id = previous_log[-1]["id"]
+         else:
+             previous_id = -1
+
+         current_log["tool"] = str(tool_id)
+         current_log["thought"] = thought
+         current_log["action"] = api_name
+         current_log["action_input"] = parameters
+         current_log["observation"] = observation
+         previous_log.append(current_log)
+
+         observation_log = get_observation_log(previous_log)
+
+         answer = Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
+
+         previous_log[-1]["answer"] = answer
+
+         history_log_ele = {"thought": thought, "action": tool_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
+         history_log.append(history_log_ele)
+         subtask_id += 1
+
+         speak, status = Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
+         if speak == -1 and status == -1:
+             step_num += 1
+             continue
+
+         try:
+             if int(status) == 0:
+                 hint = speak
+                 step_num += 1
+                 continue
+         except:
+             step_num += 1
+             continue
+         else:
+             return answer, previous_log, re_time, history_log
+
+ def decompose_inference(query, relevant_APIs, api_list, white_list, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, query_id, args):
+     while True:
+         subtasks = Planning_Agent.run(question=query, query_id=query_id)
+         if subtasks == -1:
+             continue
+         break
+     task_log = ""
+     history_log = []
+     previous_log_totals = []
+     re_time_total = 0
+     print(subtasks)
+     relevant_API_list = {}
+     tool_id = 0
+     for api in api_list:
+         for relevant_API in relevant_APIs:
+             if relevant_API[0] == api["tool_name"] and relevant_API[1] == api["api_name"]:
+                 new_tool_name = api["category_name"] + ":" + api["tool_name"] + ":" + api["api_name"]
+                 ele = {"ID": tool_id, "tool_name": new_tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
+                 relevant_API_list[str(tool_id)] = ele
+                 tool_id += 1
+
+     for subtask in subtasks:
+         task_log += f"question: {subtask}\n"
+         answer, previous_log, re_time, previous_log_total = inference(task_log, relevant_API_list, white_list, subtask, Answer_Agent, Executor_Agent, Verifier_Agent, query_id, args)
+         previous_log_totals.append(previous_log_total)
+         print(answer)
+         history_log += previous_log
+         re_time_total += re_time
+         task_log += f"answer: {answer}\n"
+     final_answer = Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
+     return final_answer, history_log, task_log, re_time_total, previous_log_totals
+
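+ # Shapes assumed by the matching loop above (as consumed by this file; field names
+ # follow the ToolBench query format):
+ #     relevant_APIs: [[tool_name, api_name], ...]
+ #     api_list: [{"category_name": ..., "tool_name": ..., "api_name": ..., "api_description": ...,
+ #                 "required_parameters": [...], "optional_parameters": [...]}, ...]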
+ def test(query_json, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args):
+     while True:
+         try:
+             global lock
+             total_query = len(query_json)
+             with tqdm(total=total_query, desc="Processing files", initial=0) as pbar:
+                 for i, test_query in enumerate(query_json, start=0):
+                     idx = test_query[0]
+                     test_query = test_query[1]
+                     query = test_query["query"]
+                     relevant_APIs = test_query["relevant APIs"]
+                     api_list = test_query["api_list"]
+                     final_answer, previous_log, task_log, re_time, previous_log_totals = decompose_inference(query, relevant_APIs, api_list, white_list, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, idx, args)
+                     answer_details, total_steps = get_answer_details(final_answer, previous_log)
+                     solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
+                     output_file_ele = {
+                         "query": query,
+                         "restart_time": re_time,
+                         "answer": {
+                             "method": "decompose_dfs",
+                             "total_steps": total_steps,
+                             "final_answer": final_answer,
+                             "answer_details": answer_details
+                         }
+                     }
+
+                     solution_file_ele = {
+                         "query": query,
+                         "total_steps": solution_total_steps,
+                         "task_log": task_log,
+                         "final_answer": final_answer,
+                         "answer_path": answer_details,
+                         "total_path": solution_tree
+                     }
+                     file_name = f"{idx}.json"
+                     output_file = os.path.join(output_dir, file_name)
+                     whole_solution_file = os.path.join(whole_solution_dir, file_name)
+                     lock.acquire()
+                     with open(output_file, "w") as file:
+                         json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
+                     with open(whole_solution_file, "w") as file:
+                         json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
+                     lock.release()
+                     pbar.update(1)
+             return
+
+         except Exception as e:
+             print(e)
+             print("some error occurred, continuing...")
+             time.sleep(60)
+             continue
+
+ def parse_arg():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--test_query_id_path', type=str, default="toolbench_data/data/test_query_ids", required=False, help='test query ids for different test sets')
+     parser.add_argument('--method_name', type=str, default="smurfs-all", required=False, help='the inference method')
+     parser.add_argument('--model_name', type=str, default="your_model_name", required=False, help='the model name for the vllm model')
+     parser.add_argument('--query_file_dir', type=str, default="toolbench_data/data/test_instruction", required=False, help='the directory that contains test sets')
+     parser.add_argument('--tool_env_dir', type=str, default="toolbench_data/data/toolenv/tools", required=False, help='tool environment for the toolbench')
+     parser.add_argument('--toolbench_key', type=str, default="", required=False, help='your toolbench key to request rapidapi service')
+     parser.add_argument('--rapidapi_key', type=str, default="", required=False, help='your rapidapi key to request rapidapi service')
+     parser.add_argument('--use_rapidapi_key', action="store_true", help="To use customized rapidapi service or not.")
+     parser.add_argument('--api_customization', action="store_true", help="To use customized api or not.")
+     args = parser.parse_args()
+
+     return args
+
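+ # Typical invocation (defaults shown above; keys are placeholders). Note that
+ # Call_function reads toolbench_key and rapidapi_key from environment variables
+ # rather than from these arguments, so export them as well:
+ #     export toolbench_key=<YOUR_TOOLBENCH_KEY>
+ #     python Smurfs/inference/inference.py --method_name smurfs-all --model_name your_model_name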
+ if __name__ == '__main__':
+     threads = []
+     lock = threading.Lock()
+
+     args = parse_arg()
+
+     test_query_id_path = args.test_query_id_path
+     method_name = args.method_name
+     model_name = args.model_name
+     query_file_dir = args.query_file_dir
+     tool_env_dir = args.tool_env_dir
+
+     toolbench_key = args.toolbench_key
+     rapidapi_key = args.rapidapi_key
+     use_rapidapi_key = args.use_rapidapi_key
+     api_customization = args.api_customization
+
+     chat = vllm_Model(model_name=model_name)
+
+     for test_set in test_sets:
+         total_output_file = f"data/{method_name}/{test_set}_raw.json"
+
+         test_ids = list(json.load(open(os.path.join(test_query_id_path, test_set + ".json"), "r")).keys())
+
+         query_file = f'{query_file_dir}/{test_set}.json'
+
+         output_dir = f"data/{method_name}/{test_set}/answer"
+
+         whole_solution_dir = f"data/{method_name}/{test_set}/whole"
+
+         logger_dir = f"data/{method_name}/{test_set}/agent_log"
+
+         if not os.path.exists(output_dir):
+             os.makedirs(output_dir)
+
+         if not os.path.exists(whole_solution_dir):
+             os.makedirs(whole_solution_dir)
+
+         if not os.path.exists(logger_dir):
+             os.makedirs(logger_dir)
+
+         Answer_Agent = answer_agent(llm=chat, logger_dir=logger_dir)
+         Executor_Agent = executor_agent(llm=chat, logger_dir=logger_dir)
+         Planning_Agent = planning_agent(llm=chat, logger_dir=logger_dir)
+         Verifier_Agent = verifier_agent(llm=chat, logger_dir=logger_dir)
+
+         items = os.listdir(output_dir)
+         for i in range(len(items)):
+             items[i] = items[i].split(".")[0]
+
+         white_list = get_white_list(tool_env_dir)
+
+         with open(query_file) as file:
+             query_json = json.load(file)
+         # with open(tool_doc_dir) as file:
+         #     tool_doc = json.load(file)
+
+         print(len(items))
+         query_json_to_do = []
+         # if len(items) != 0:
+         for idx, q in enumerate(query_json):
+             # print(idx)
+             if str(idx) in items:
+                 continue
+             query_id = q["query_id"]
+             if str(query_id) not in test_ids:
+                 continue
+             query_json_to_do_ele = (idx, q)
+             query_json_to_do.append(query_json_to_do_ele)
+         # else:
+         #     query_json_to_do = query_json
+
+         total_len = len(query_json_to_do)
+         query_len = len(query_json)
+         print(total_len)
+
+         if total_len < 20:
+             for i in range(total_len):
+                 if total_len == 0:
+                     break
+
+                 start = i
+                 end = i + 1
+                 if i == total_len - 1:
+                     query_json_cur = query_json_to_do[start:]
+                 else:
+                     query_json_cur = query_json_to_do[start: end]
+                 t = threading.Thread(target=test, args=(query_json_cur, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args))
+                 t.start()
+                 threads.append(t)
+
+         else:
+             for i in range(20):
+                 if total_len == 0:
+                     break
+
+                 start = round(total_len / 20) * i
+                 end = round(total_len / 20) * (i + 1)
+                 if i == 19:
+                     query_json_cur = query_json_to_do[start:]
+                 else:
+                     query_json_cur = query_json_to_do[start: end]
+                 t = threading.Thread(target=test, args=(query_json_cur, white_list, output_dir, whole_solution_dir, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, args))
+                 t.start()
+                 threads.append(t)
+
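+         # A quick sanity check of the chunking above (illustrative numbers): with
+         # total_len = 45, round(45 / 20) = 2, so threads 0-18 take slices [0:2),
+         # [2:4), ..., [36:38) and thread 19 takes the remainder [38:45).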
+         for thread in threads:
+             thread.join()
+
+         total_json = {}
+         items = os.listdir(output_dir)
+         for item in items:
+             item_path = os.path.join(output_dir, item)
+             idx = item.split(".")[0]
+             total_json[str(idx)] = json.load(open(item_path, 'r'))
+
+         with open(total_output_file, 'w') as file:
+             json.dump(total_json, file, indent=4, ensure_ascii=False)
+
Smurfs/inference/server.py ADDED
@@ -0,0 +1,179 @@
+ from pydantic import BaseModel
+ import json
+ import os
+ from typing import Union
+ import sys
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+ from Smurfs.inference.utils import standardize, change_name
+ import random
+
+
+ class Info(BaseModel):
+     category: str
+     tool_name: str
+     api_name: str
+     tool_input: Union[str, dict]
+     strip: str
+
+ def prepare_tool_name_and_url(tools_root, info):
+     category = info.category
+     standard_category = category.replace(" ", "_").replace(",", "_").replace("/", "_")
+     while " " in standard_category or "," in standard_category:
+         standard_category = standard_category.replace(" ", "_").replace(",", "_")
+     standard_category = standard_category.replace("__", "_")
+
+     tool_name = info.tool_name
+     api_name = change_name(standardize(info.api_name))
+     if not tool_name.endswith(f"_for_{standard_category}"):
+         tool_name = standardize(info.tool_name)
+         code_string = f"""from {tools_root}.{standard_category}.{tool_name}.api import {api_name}"""
+         tool_name += f"_for_{standard_category}"
+     else:
+         tmp_tool_name = standardize(tool_name.replace(f"_for_{standard_category}", ""))
+         code_string = f"""from {tools_root}.{standard_category}.{tmp_tool_name}.api import {api_name}"""
+     return tool_name, standard_category, api_name, code_string
+
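+ # Example of the generated import line (using the tool from the __main__ block below):
+ #     category "Social", tool "olato_quotes", api "love_quote" yield
+ #     "from data.toolenv.tools.Social.olato_quotes.api import love_quote"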
+ def process_error(response):
+     save_cache_flag = False
+     switch_flag = False
+     if "The request to the API has timed out. Please try again later, or if the issue persists" in str(response):
+         return_dict = {"error": "API temporarily not working error...", "response": response}
+
+     elif "Your Client (working) ---> Gateway (working) ---> API (not working)" in str(response):
+         return_dict = {"error": "API not working error...", "response": response}
+
+     elif "Unauthorized" in str(response) or "unauthorized" in str(response):
+         save_cache_flag = True
+         return_dict = {"error": "Unauthorized error...", "response": response}
+
+     elif "You are not subscribed to this API." in str(response):
+         switch_flag = True
+         return_dict = {"error": "Unsubscribed error...", "response": response}
+
+     elif "Too many requests" in str(response):
+         switch_flag = True
+         return_dict = {"error": "Too many requests error...", "response": response}
+
+     elif "You have exceeded" in str(response) or "you are being rate limited" in str(response):
+         switch_flag = True
+         return_dict = {"error": "Rate limit error...", "response": response}
+
+     elif "Access restricted. Check credits balance or enter the correct API key." in str(response):
+         switch_flag = True
+         return_dict = {"error": "Rate limit error...", "response": response}
+
+     elif "Oops, an error in the gateway has occurred." in str(response):
+         switch_flag = True
+         return_dict = {"error": "Gateway error...", "response": response}
+
+     elif "Blocked User. Please contact your API provider." in str(response):
+         switch_flag = True
+         return_dict = {"error": "Blocked error...", "response": response}
+
+     elif "error" in str(response):
+         return_dict = {"error": "Message error...", "response": response}
+
+     else:
+         save_cache_flag = True
+         return_dict = {"error": "", "response": response}
+     return return_dict, save_cache_flag, switch_flag
+
+ def run(toolbench_code_string, toolbench_api_name, toolbench_input_params_str):
+     # Get the observation: exec() imports the tool's API function into the local
+     # namespace, then eval() calls it with the rendered parameter string.
+     success_flag = False
+     switch_flag = False
+     save_cache = False
+     exec(toolbench_code_string)
+     try:
+         eval_func_str = f"{toolbench_api_name}({toolbench_input_params_str})"
+         new_func = eval(eval_func_str)
+         response, save_cache, switch_flag = process_error(new_func)
+         success_flag = True
+     except Exception as e:
+         response = {"error": f"Function executing {toolbench_code_string} error...\n{e}", "response": ""}
+         save_cache = False
+     return success_flag, switch_flag, response, save_cache
+
+
+ def dict_shorten(origin: dict, schema: dict):
+     for key, value in list(origin.items()):
+         if key not in schema:
+             del origin[key]
+         else:
+             if isinstance(value, dict):
+                 dict_shorten(value, schema[key])  # schema[key] should be a dict
+             elif isinstance(value, list):
+                 if value:
+                     if isinstance(value[0], dict):
+                         for item in value:
+                             dict_shorten(item, schema[key][0])  # schema[key] should be a list with only one dict element
+     return origin
+
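+ # dict_shorten recursively keeps only the keys present in the schema. Illustration
+ # (made-up values):
+ #     dict_shorten({"a": 1, "b": {"c": 2, "d": 3}}, {"a": {}, "b": {"c": {}}})
+ #     -> {"a": 1, "b": {"c": 2}}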
+ def observation_shorten(schema_root, response_dict, category, tool_name, api_name, strip_method):
+     print(random.random())
+     if strip_method == "filter" or (strip_method == "random" and random.random() > 0.5):
+         if isinstance(response_dict["response"], dict):
+             if os.path.exists(os.path.join(schema_root, category)):
+                 if os.path.exists(os.path.join(schema_root, category, tool_name + ".json")):
+                     schema_dicts = json.load(open(os.path.join(schema_root, category, tool_name + ".json"), "r"))
+                     api_list = schema_dicts["api_list"]
+                     schema = None
+                     for schema_dict in api_list:
+                         schema_api_name = change_name(standardize(schema_dict["name"]))
+                         if schema_api_name == api_name and len(schema_dict["schema"]) > 0:
+                             schema = schema_dict["schema"]
+                             break
+                     if schema is not None:
+                         response_dict["response"] = dict_shorten(response_dict["response"], schema)
+     return str(response_dict["response"])
+
+
+ def get_rapidapi_response(input_dict: dict, api_customization: bool=False, tools_root: str="data.toolenv.tools", schema_root: str="data/toolenv/response_examples"):
+     info = Info(
+         category=input_dict['category'],
+         tool_name=input_dict['tool_name'],
+         api_name=input_dict['api_name'],
+         tool_input=input_dict['tool_input'],
+         strip=input_dict['strip'],
+     )
+     rapidapi_key = input_dict['rapidapi_key']
+
+     tool_name, standard_category, api_name, code_string = prepare_tool_name_and_url(tools_root, info)
+     tool_input = info.tool_input
+
+     strip_method = info.strip
+
+     try:
+         if isinstance(tool_input, str):
+             tool_input = json.loads(tool_input)
+     except Exception:
+         if tool_input == "":
+             tool_input = {}
+         else:
+             print(f"Cannot parse tool input into json: {tool_input}")
+             response_dict = {"error": "Tool input parse error...", "response": ""}
+             return response_dict
+
+     input_params_str = ""
+     if len(tool_input) > 0:
+         for key, value in tool_input.items():
+             if isinstance(value, str):
+                 input_params_str += f'{key}="{value}", '
+             else:
+                 input_params_str += f'{key}={value}, '
+     if not api_customization:
+         input_params_str += f"toolbench_rapidapi_key='{rapidapi_key}'"
+     success_flag, switch_flag, response_dict, save_cache = run(code_string, api_name, input_params_str)
+     observation = observation_shorten(schema_root, response_dict, standard_category, tool_name.replace(f"_for_{standard_category}", ""), api_name, strip_method)
+     result = str(observation)[:2048]
+     return {"error": response_dict['error'], "response": result}
+
+
+ if __name__ == "__main__":
+     result = get_rapidapi_response({
+         "category": "Social",
+         "tool_name": "olato_quotes",
+         "api_name": "love_quote",
+         "tool_input": '{}',
+         "strip": "filter",
+         "rapidapi_key": ""
+     })
+     print(result)
Smurfs/inference/smurfs_worker.py ADDED
@@ -0,0 +1,1040 @@
+ # -*- coding: utf-8 -*-
+ import json
+ import sys
+ import argparse
+ import time
+ import requests
+ import os
+ from tqdm import tqdm
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+ from Smurfs.inference.utils import change_name, standardize, get_white_list, get_answer_log, get_observation_log, build_tree, get_answer_details, test_sets
+ from Smurfs.model.vllm_model.vllm_model import vllm_Model
+ from Smurfs.inference.server import get_rapidapi_response
+ import threading
+ from Smurfs.agents.answer_agent.answer import answer_agent
+ from Smurfs.agents.executor_agent.executor import executor_agent
+ from Smurfs.agents.planning_agent.planner import planning_agent
+ from Smurfs.agents.verifier_agent.verifier import verifier_agent
+ from termcolor import colored
+
+
+ class smurfs_worker:
+     def __init__(self, available_tools, tool_env, llm, method_name, test_set, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent):
+         # available_tools follows the format of the entries in ToolBench's api_list; only api_name is required
+         # tool_env is a dict of all functions defined in the .py file that stores the tool code: keys are function names, values are the function objects
+         self.available_tools = available_tools
+
+         self.output_dir = f"data/{method_name}/{test_set}/answer"
+
+         self.whole_solution_dir = f"data/{method_name}/{test_set}/whole"
+
+         self.logger_dir = f"data/{method_name}/{test_set}/agent_log"
+
+         if not os.path.exists(self.output_dir):
+             os.makedirs(self.output_dir)
+
+         if not os.path.exists(self.whole_solution_dir):
+             os.makedirs(self.whole_solution_dir)
+
+         if not os.path.exists(self.logger_dir):
+             os.makedirs(self.logger_dir)
+
+         self.Answer_Agent = Answer_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Executor_Agent = Executor_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Planning_Agent = Planning_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Verifier_Agent = Verifier_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.tool_env = tool_env
+
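+     # Illustrative constructor inputs (hypothetical names, mirroring the comments above):
+     #     available_tools = [{"api_name": "BingSearch", "api_description": "...",
+     #                         "required_parameters": [...], "optional_parameters": [...]}]
+     #     tool_env = {"BingSearch": bing_search_fn}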
+     def inference(self, query, relevant_APIs, subtask, query_id, max_step=3):
+         # tool_check_num, reason = self.Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
+         # # direct answer
+         # if tool_check_num == 1:
+         #     # input_dic = {"task": query}
+         #     answer = self.Answer_Agent.run(question=query, task="direct", query_id=query_id)
+         #     previous_log = [{"thought": reason, "action": "", "action_input": "", "observation": "", "answer": answer, "tool": "", "id": 0}]
+         #     history_log = [{"thought": reason, "action": "", "action_input": "", "observation": "", "answer": answer, "previous_id": -1, "id": 0}]
+         #     return answer, previous_log, 0, history_log
+
+         previous_log = []
+         history_log = []
+         tool_used_dic = {}
+         relevant_APIs_ids = []
+         for idx in relevant_APIs:
+             ele = relevant_APIs[idx]
+             relevant_APIs_ids.append(str(ele["ID"]))
+         restart_time = 0
+         step_num = 0
+         hint = "Beginning of the agent. No hint yet"
+         retry_tool_id = 0
+         retry_parameter = 0
+         re_time = 0
+         subtask_id = 0
+         restart = 0
+         while True:
+             if step_num >= max_step:
+                 print("\n\nReached the step limit, returning the answer!\n\n")
+                 answer_log = get_answer_log(history_log)
+                 answer = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+                 return answer, previous_log, re_time, history_log
+
+             if step_num not in tool_used_dic.keys():
+                 tool_used_dic[step_num] = []
+
+             tool_used = tool_used_dic[step_num]
+
+             tool_list = []
+             for idx in relevant_APIs:
+                 ele = relevant_APIs[idx]
+                 ID = str(ele['ID'])
+                 if ID in tool_used:
+                     continue
+                 des = ele['description']
+                 name = ele["tool_name"]
+                 tool_list.append({"ID": ID, "tool_name": name, "description": des})
+
+             if len(tool_list) == 0:
+                 if len(previous_log) == 0:
+                     answer_log = get_answer_log(history_log)
+                     partial_answer = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+                     answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
+                     return answer, previous_log, re_time, history_log
+                 else:
+                     delete_log = previous_log.pop()
+                     tool_used_dic[step_num] = []
+                     step_num -= 1
+                     tool_used_dic[step_num].append(delete_log["tool"])
+                     restart_time += 1
+                     re_time += 1
+                     continue
+
+             current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "", "id": subtask_id}
+
+             answer_log = get_answer_log(previous_log)
+
+             if retry_tool_id == 4:
+                 tool_id = tool_list[0]["ID"]
+                 tool_list = tool_list[0]
+                 thought = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+             else:
+                 thought = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+                 tool_id = self.Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
+
+             try:
+                 tool_id = int(tool_id)
+                 tool_id = str(tool_id)
+                 if tool_id not in relevant_APIs_ids:
+                     re_time += 1
+                     retry_tool_id += 1
+                     print("Tool ID wrong! Generated a tool_id that does not exist!")
+                     continue
+                 tool_des_json = relevant_APIs[str(tool_id)]
+                 retry_tool_id = 0
+             except:
+                 retry_tool_id += 1
+                 print("Tool ID wrong! Generated a tool_id that does not exist!")
+                 continue
+
+             # tool_name_list = tool_des_json["tool_name"].split(":")
+             # category_name = tool_name_list[0]
+             # tool_name = tool_name_list[1]
+             api_name = tool_des_json["tool_name"]
+             API_doc = tool_des_json
+
+             while True:
+                 try:
+                     parameters = {}
+
+                     if retry_parameter == 4:
+                         restart = 1
+                         retry_parameter = 0
+                         print("No parameters! Restart!")
+                         break
+
+                     parameter = self.Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
+                     if parameter == -1:
+                         retry_parameter += 1
+                         re_time += 1
+                         continue
+
+                     if parameter == {}:
+                         retry_parameter = 0
+                         parameters = {}
+                         break
+
+                     for key in parameter:
+                         value = parameter[key]
+                         key = change_name(key)
+                         parameters[key] = value
+
+                     retry_parameter = 0
+                     break
+
+                 except:
+                     if retry_parameter == 4:
+                         parameters = {}
+                         retry_parameter = 0
+                         restart = 1
+                         break
+                     retry_parameter += 1
+                     print("parameter generation fails, try again!")
+                     re_time += 1
+                     continue
+
+             # api_name = change_name(standardize(api_name))
+
+             if restart != 1:
+                 try:
+                     observation = self.Call_function(api_name, parameters)
+                 except:
+                     observation = -1
+
+                 if observation == -1:
+                     restart = 1
+                     observation = str({"error": "", "response": "call API fails"})
+
+             if restart == 1:
+                 tool_used_dic[step_num].append(str(tool_id))
+                 print('****Try Again For This Step****')
+                 re_time += 1
+                 restart = 0
+                 continue
+
+             if len(previous_log) != 0:
+                 previous_id = previous_log[-1]["id"]
+             else:
+                 previous_id = -1
+
+             current_log["tool"] = str(tool_id)
+             current_log["thought"] = thought
+             current_log["action"] = api_name
+             current_log["action_input"] = parameters
+             current_log["observation"] = observation
+             previous_log.append(current_log)
+             print("##########Tool Response##########")
+             print(f"{observation}\n")
+             observation_log = get_observation_log(previous_log)
+
+             answer = self.Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
+
+             previous_log[-1]["answer"] = answer
+
+             history_log_ele = {"thought": thought, "action": api_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
+             history_log.append(history_log_ele)
+             subtask_id += 1
+
+             speak, status = self.Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
+             if speak == -1 and status == -1:
+                 step_num += 1
+                 continue
+
+             try:
+                 if int(status) == 0:
+                     hint = speak
+                     step_num += 1
+                     continue
+             except:
+                 step_num += 1
+                 continue
+             else:
+                 return answer, previous_log, re_time, history_log
+
+     def decompose_inference(self, query, query_id):
+         while True:
+             subtasks = self.Planning_Agent.run(question=query, query_id=query_id)
+             if subtasks == -1:
+                 continue
+             break
+         task_log = ""
+         history_log = []
+         previous_log_totals = []
+         re_time_total = 0
+         # print(subtasks)
+         relevant_API_list = {}
+         tool_id = 0
+         for api in self.available_tools:
+             tool_name = api["api_name"]
+             ele = {"ID": tool_id, "tool_name": tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
+             relevant_API_list[str(tool_id)] = ele
+             tool_id += 1
+
+         for subtask in subtasks:
+             task_log += f"question: {subtask}\n"
+             answer, previous_log, re_time, previous_log_total = self.inference(task_log, relevant_API_list, subtask, query_id)
+             previous_log_totals.append(previous_log_total)
+             # print(answer)
+             history_log += previous_log
+             re_time_total += re_time
+             task_log += f"answer: {answer}\n"
+         final_answer = self.Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
+         return final_answer, history_log, task_log, re_time_total, previous_log_totals
+
+     def run(self, input, query_id):
+         # result = {}
+         # st = time.time()
+         final_answer, previous_log, task_log, re_time, previous_log_totals = self.decompose_inference(input, query_id)
+         answer_details, total_steps = get_answer_details(final_answer, previous_log)
+         solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
+         output_file_ele = {
+             "query": input,
+             "restart_time": re_time,
+             "answer": {
+                 "method": "decompose_dfs",
+                 "total_steps": total_steps,
+                 "final_answer": final_answer,
+                 "answer_details": answer_details
+             }
+         }
+
+         solution_file_ele = {
+             "query": input,
+             "total_steps": solution_total_steps,
+             "task_log": task_log,
+             "final_answer": final_answer,
+             "answer_path": answer_details,
+             "total_path": solution_tree
+         }
+         return final_answer, output_file_ele, solution_file_ele
+
+     def save_solution(self, output_file_ele, solution_file_ele, idx):
+         file_name = f"{idx}.json"
+         output_file = os.path.join(self.output_dir, file_name)
+         whole_solution_file = os.path.join(self.whole_solution_dir, file_name)
+         with open(output_file, "w") as file:
+             json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
+         with open(whole_solution_file, "w") as file:
+             json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
+
+     def Call_function(self, tool_name, args):
+         try:
+             print(tool_name)
+             func = self.tool_env[tool_name]
+             observation = func(**args)
+             return observation
+         except Exception as e:
+             print(e)
+             print("Call function fails")
+             with open('wrong_log.json', 'a+', encoding='utf-8') as f:
+                 line = json.dumps({
+                     "id": 0,
+                     "parameters": args,
+                     "tool": tool_name,
+                     "wrong": str(e)
+                 }, ensure_ascii=False)
+                 f.write(line + '\n')
+             return -1
+
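+     # Dispatch sketch (hypothetical tool): Call_function looks the name up in tool_env
+     # and invokes it with the generated parameters, e.g.
+     #     self.Call_function("BingSearch", {"query": "..."})
+     # On any exception it appends a record to wrong_log.json and returns -1.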
+ class smurfs_hotpot_worker:
+     def __init__(self, available_tools, tool_env, llm, method_name, test_set, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent):
+         # available_tools follows the format of the entries in ToolBench's api_list; only api_name is required
+         # tool_env is a dict of all functions defined in the .py file that stores the tool code: keys are function names, values are the function objects
+         self.available_tools = available_tools
+
+         self.output_dir = f"data/{method_name}/{test_set}/answer"
+
+         self.whole_solution_dir = f"data/{method_name}/{test_set}/whole"
+
+         self.logger_dir = f"data/{method_name}/{test_set}/agent_log"
+
+         if not os.path.exists(self.output_dir):
+             os.makedirs(self.output_dir)
+
+         if not os.path.exists(self.whole_solution_dir):
+             os.makedirs(self.whole_solution_dir)
+
+         if not os.path.exists(self.logger_dir):
+             os.makedirs(self.logger_dir)
+
+         self.Answer_Agent = Answer_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Executor_Agent = Executor_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Planning_Agent = Planning_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Verifier_Agent = Verifier_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.tool_class = tool_env
+         self.tool_env = {}
+
+     def inference(self, query, relevant_APIs, subtask, query_id, max_step=3):
+         # tool_check_num = self.Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
+         # # direct answer
+         # if tool_check_num == 1:
+         #     input_dic = {"task": query}
+         #     answer = self.Answer_Agent.run(input_dic)
+         #     return answer, answer, None, None
+
+         previous_log = []
+         history_log = []
+         tool_used_dic = {}
+         relevant_APIs_ids = []
+         for idx in relevant_APIs:
+             ele = relevant_APIs[idx]
+             relevant_APIs_ids.append(str(ele["ID"]))
+         restart_time = 0
+         step_num = 0
+         hint = "Beginning of the agent. No hint yet"
+         retry_tool_id = 0
+         retry_parameter = 0
+         re_time = 0
+         subtask_id = 0
+         restart = 0
+         while True:
+             if step_num >= max_step:
+                 print("\n\nReached the step limit, returning the answer!\n\n")
+                 answer_log = get_answer_log(history_log)
+                 answer = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+                 return answer, previous_log, re_time, history_log
+
+             if step_num not in tool_used_dic.keys():
+                 tool_used_dic[step_num] = []
+
+             tool_used = tool_used_dic[step_num]
+
+             tool_list = []
+             for idx in relevant_APIs:
+                 ele = relevant_APIs[idx]
+                 ID = str(ele['ID'])
+                 if ID in tool_used:
+                     continue
+                 des = ele['description']
+                 name = ele["tool_name"]
+                 tool_list.append({"ID": ID, "tool_name": name, "description": des})
+
+             if len(tool_list) == 0:
+                 if len(previous_log) == 0:
+                     answer_log = get_answer_log(history_log)
+                     partial_answer = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+                     answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
+                     return answer, previous_log, re_time, history_log
+                 else:
+                     delete_log = previous_log.pop()
+                     tool_used_dic[step_num] = []
+                     step_num -= 1
+                     tool_used_dic[step_num].append(delete_log["tool"])
+                     restart_time += 1
+                     re_time += 1
+                     continue
+
+             current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "", "id": subtask_id}
+
+             answer_log = get_answer_log(previous_log)
+
+             if retry_tool_id == 4:
+                 tool_id = tool_list[0]["ID"]
+                 tool_list = tool_list[0]
+                 thought = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+             else:
+                 thought = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+                 tool_id = self.Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
+
+             try:
+                 tool_id = int(tool_id)
+                 tool_id = str(tool_id)
+                 if tool_id not in relevant_APIs_ids:
+                     re_time += 1
+                     retry_tool_id += 1
+                     print("Tool ID wrong! Generated a tool_id that does not exist!")
+                     continue
+                 tool_des_json = relevant_APIs[str(tool_id)]
+                 retry_tool_id = 0
+             except:
+                 retry_tool_id += 1
+                 print("Tool ID wrong! Generated a tool_id that does not exist!")
+                 continue
+
+             # tool_name_list = tool_des_json["tool_name"].split(":")
+             # category_name = tool_name_list[0]
+             # tool_name = tool_name_list[1]
+             api_name = tool_des_json["tool_name"]
+             API_doc = tool_des_json
+
+             while True:
+                 try:
+                     parameters = {}
+
+                     if retry_parameter == 4:
+                         restart = 1
+                         retry_parameter = 0
+                         print("No parameters! Restart!")
+                         break
+
+                     parameter = self.Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
+                     if parameter == -1:
+                         retry_parameter += 1
+                         re_time += 1
+                         continue
+
+                     if parameter == {}:
+                         retry_parameter = 0
+                         parameters = {}
+                         break
+
+                     for key in parameter:
+                         value = parameter[key]
+                         key = change_name(key)
+                         parameters[key] = value
+
+                     retry_parameter = 0
+                     break
+
+                 except:
+                     if retry_parameter == 4:
+                         parameters = {}
+                         retry_parameter = 0
+                         restart = 1
+                         break
+                     retry_parameter += 1
+                     print("parameter generation fails, try again!")
+                     re_time += 1
+                     continue
+
+             # api_name = change_name(standardize(api_name))
+
+             if restart != 1:
+                 try:
+                     observation = self.Call_function(api_name, parameters)
+                 except:
+                     observation = -1
+
+                 if observation == -1:
+                     restart = 1
+                     observation = str({"error": "", "response": "call API fails"})
+
+             if restart == 1:
+                 tool_used_dic[step_num].append(str(tool_id))
+                 print('****Try Again For This Step****')
+                 re_time += 1
+                 restart = 0
+                 continue
+
+             if len(previous_log) != 0:
+                 previous_id = previous_log[-1]["id"]
+             else:
+                 previous_id = -1
+
+             current_log["tool"] = str(tool_id)
+             current_log["thought"] = thought
+             current_log["action"] = api_name
+             current_log["action_input"] = parameters
+             current_log["observation"] = observation
+             print("##########Tool Response##########")
+             print(f"{observation}\n")
+             previous_log.append(current_log)
+
+             observation_log = get_observation_log(previous_log)
+
+             answer = self.Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
+
+             previous_log[-1]["answer"] = answer
+
+             history_log_ele = {"thought": thought, "action": api_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
+             history_log.append(history_log_ele)
+             subtask_id += 1
+
+             speak, status = self.Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
+             if speak == -1 and status == -1:
+                 step_num += 1
+                 continue
+
+             try:
+                 if int(status) == 0:
+                     hint = speak
+                     step_num += 1
+                     continue
+             except:
+                 step_num += 1
+                 continue
+             else:
+                 return answer, previous_log, re_time, history_log
+
+     def decompose_inference(self, query, query_id):
+         while True:
+             subtasks = self.Planning_Agent.run(question=query, query_id=query_id)
+             if subtasks == -1:
+                 continue
+             break
+         task_log = ""
+         history_log = []
+         previous_log_totals = []
+         re_time_total = 0
+         # print(subtasks)
+         relevant_API_list = {}
+         tool_id = 0
+         for api in self.available_tools:
+             tool_name = api["api_name"]
+             ele = {"ID": tool_id, "tool_name": tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
+             relevant_API_list[str(tool_id)] = ele
+             tool_id += 1
+
+         for subtask in subtasks:
+             task_log += f"question: {subtask}\n"
+             answer, previous_log, re_time, previous_log_total = self.inference(task_log, relevant_API_list, subtask, query_id)
+             previous_log_totals.append(previous_log_total)
+             # print(answer)
+             history_log += previous_log
+             re_time_total += re_time
+             task_log += f"answer: {answer}\n"
+         final_answer = self.Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
+         return final_answer, history_log, task_log, re_time_total, previous_log_totals
+
+     def run(self, input, query_id):
+         # result = {}
+         # st = time.time()
+         HPEnv = self.tool_class()
+         self.tool_env = {
+             "BingSearch": HPEnv.BingSearch,
+             "Retrieve": HPEnv.Retrieve,
+             "Lookup": HPEnv.Lookup
+         }
+         final_answer, previous_log, task_log, re_time, previous_log_totals = self.decompose_inference(input, query_id)
+         answer_details, total_steps = get_answer_details(final_answer, previous_log)
+         solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
+         output_file_ele = {
+             "query": input,
+             "restart_time": re_time,
+             "answer": {
+                 "method": "decompose_dfs",
+                 "total_steps": total_steps,
+                 "final_answer": final_answer,
+                 "answer_details": answer_details
+             }
+         }
+
+         solution_file_ele = {
+             "query": input,
+             "total_steps": solution_total_steps,
+             "task_log": task_log,
+             "final_answer": final_answer,
+             "answer_path": answer_details,
+             "total_path": solution_tree
+         }
+         return final_answer, output_file_ele, solution_file_ele
+
+     def save_solution(self, output_file_ele, solution_file_ele, idx):
+         file_name = f"{idx}.json"
+         output_file = os.path.join(self.output_dir, file_name)
+         whole_solution_file = os.path.join(self.whole_solution_dir, file_name)
+         with open(output_file, "w") as file:
+             json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
+         with open(whole_solution_file, "w") as file:
+             json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
+
+     def Call_function(self, tool_name, args):
+         try:
+             print(tool_name)
+             func = self.tool_env[tool_name]
+             observation = func(**args)
+             return observation
+         except Exception as e:
+             print(e)
+             print("Call function fails")
+             with open('wrong_log.json', 'a+', encoding='utf-8') as f:
+                 line = json.dumps({
+                     "id": 0,
+                     "parameters": args,
+                     "tool": tool_name,
+                     "wrong": str(e)
+                 }, ensure_ascii=False)
+                 f.write(line + '\n')
+             return -1
+
+
+ class stream_smurfs_worker:
+     def __init__(self, available_tools, tool_env, llm, method_name, test_set, Answer_Agent, Executor_Agent, Planning_Agent, Verifier_Agent, OPENAI_API_KEY, BING_SUBSCRIPT_KEY, WOLFRAMALPH_APP_ID, WEATHER_API_KEYS):
+         # available_tools follows the format of the entries in ToolBench's api_list; only api_name is required
+         # tool_env is a dict of all functions defined in the .py file that stores the tool code: keys are function names, values are the function objects
+         self.OPENAI_API_KEY = OPENAI_API_KEY
+         self.BING_SUBSCRIPT_KEY = BING_SUBSCRIPT_KEY
+         self.WOLFRAMALPH_APP_ID = WOLFRAMALPH_APP_ID
+         self.WEATHER_API_KEYS = WEATHER_API_KEYS
+         # print(self.BING_SUBSCRIPT_KEY)
+         self.available_tools = available_tools
+
+         self.output_dir = f"data/{method_name}/{test_set}/answer"
+
+         self.whole_solution_dir = f"data/{method_name}/{test_set}/whole"
+
+         self.logger_dir = f"data/{method_name}/{test_set}/agent_log"
+
+         if not os.path.exists(self.output_dir):
+             os.makedirs(self.output_dir)
+
+         if not os.path.exists(self.whole_solution_dir):
+             os.makedirs(self.whole_solution_dir)
+
+         if not os.path.exists(self.logger_dir):
+             os.makedirs(self.logger_dir)
+
+         self.Answer_Agent = Answer_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Executor_Agent = Executor_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Planning_Agent = Planning_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.Verifier_Agent = Verifier_Agent(llm=llm, logger_dir=self.logger_dir)
+         self.tool_env = tool_env
+
+     # def colorful_html(self, task, content, name):
+     #     """print out message in different color"""
+     #     role_to_color = {
+     #         "Answer Agent": "red",
+     #         "Executor Agent": "green",
+     #         "Planning Agent": "blue",
+     #         "Verifier Agent": "yellow",
+     #     }
+     #     color = role_to_color[name]
+     #     html_text = f"<span style='color: {color}'>##########{task}##########<br>{content}<br></span>"
+     #     # print(colored(f"##########{task}##########\n{content}\n", role_to_color[name]))
+     #     return html_text
+
+     def colorful_html(self, task, content, name):
+         """Render an agent message as HTML: intermediate steps become collapsible blocks, the final answer is returned as plain text."""
+         role_to_color = {
+             "Answer Agent": "red",
+             "Executor Agent": "green",
+             "Planning Agent": "blue",
+             "Verifier Agent": "yellow",
+         }
+         color = role_to_color[name]
+         if task != "Final Answer Generation":
+             html_text = f"""<details><summary>{task}<br></summary>{content}<br></details>"""
+         else:
+             html_text = content
+         # html_text = f"<span style='color: {color}'>##########{task}##########<br>{content}<br></span>"
+         # print(colored(f"##########{task}##########\n{content}\n", role_to_color[name]))
+         return html_text
+
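+     # Example output (hypothetical content): colorful_html("Thought Generation", "use tool 0", "Executor Agent")
+     # -> "<details><summary>Thought Generation<br></summary>use tool 0<br></details>"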
+ def inference(self, query, relevant_APIs, subtask, query_id, max_step=3):
707
+ # tool_check_num, reason = self.Answer_Agent.run(question=query, task="tool_check", query_id=query_id)
708
+ # #direct answer
709
+ # if tool_check_num == 1:
710
+ # # input_dic = {"task": query}
711
+ # answer = self.Answer_Agent.run(question=query, task="direct", query_id=query_id)
712
+ # previous_log = [{"thought": reason, "action": "", "action_input": "", "observation": "", "answer": answer, "tool": "","id": 0}]
713
+ # history_log = [{"thought": reason, "action": "", "action_input": "", "observation": "", "answer": answer, "previous_id": -1, "id": 0}]
714
+ # return answer, previous_log, 0, history_log
715
+
716
+         previous_log = []
+         history_log = []
+         tool_used_dic = {}
+         relevant_APIs_ids = []
+         for idx in relevant_APIs:
+             ele = relevant_APIs[idx]
+             relevant_APIs_ids.append(str(ele["ID"]))
+         restart_time = 0
+         step_num = 0
+         hint = "Beginning of the agent. No hint yet"
+         retry_tool_id = 0
+         retry_parameter = 0
+         re_time = 0
+         subtask_id = 0
+         restart = 0
+         while True:
+             if step_num >= max_step:
+                 yield "<br><br>Reached the step limit, returning the answer!<br><br>"
+                 answer_log = get_answer_log(history_log)
+                 answer, task, agent_name, result = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+                 yield self.colorful_html(task, result, agent_name)
+                 yield answer, previous_log, re_time, history_log
+                 yield "stop"
+
+             if step_num not in tool_used_dic:
+                 tool_used_dic[step_num] = []
+
+             tool_used = tool_used_dic[step_num]
+
+             tool_list = []
+             for idx in relevant_APIs:
+                 ele = relevant_APIs[idx]
+                 ID = str(ele['ID'])
+                 if ID in tool_used:
+                     continue
+                 des = ele['description']
+                 name = ele["tool_name"]
+                 tool_list.append({"ID": ID, "tool_name": name, "description": des})
+
+             # all candidate tools for this step are exhausted: give a partial
+             # answer if there is nothing to backtrack to, otherwise backtrack one step
+             if len(tool_list) == 0:
+                 if len(previous_log) == 0:
+                     answer_log = get_answer_log(history_log)
+                     partial_answer, task, agent_name, result = self.Answer_Agent.run(question=query, previous_log=answer_log, task="final", query_id=query_id)
+                     answer = f"Sorry, I can't answer this question accurately using the existing tools. A partial answer is: {partial_answer}"
+                     yield self.colorful_html(task, result, agent_name)
+                     yield answer, previous_log, re_time, history_log
+                     yield "stop"
+                 else:
+                     delete_log = previous_log.pop()
+                     tool_used_dic[step_num] = []
+                     step_num -= 1
+                     tool_used_dic[step_num].append(delete_log["tool"])
+                     restart_time += 1
+                     re_time += 1
+                     continue
+
+             current_log = {"thought": "", "action": "", "action_input": {}, "observation": "", "answer": "", "tool": "", "id": subtask_id}
+
+             answer_log = get_answer_log(previous_log)
+
+             if retry_tool_id == 4:
+                 # after four invalid tool IDs, fall back to the first remaining tool
+                 tool_id = tool_list[0]["ID"]
+                 tool_list = tool_list[0]
+                 thought, task, agent_name, result = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+                 yield self.colorful_html(task, result, agent_name)
+             else:
+                 thought, task, agent_name, result = self.Executor_Agent.run(question=query, tool_list=tool_list, previous_log=answer_log, hint=hint, query_id=query_id, task="thought")
+                 yield self.colorful_html(task, result, agent_name)
+                 tool_id, task, agent_name, result = self.Executor_Agent.run(question=subtask, tool_list=tool_list, thought=thought, query_id=query_id, task="tool")
+                 yield self.colorful_html(task, result, agent_name)
+
+             try:
+                 tool_id = str(int(tool_id))
+                 if tool_id not in relevant_APIs_ids:
+                     re_time += 1
+                     retry_tool_id += 1
+                     yield "Tool ID wrong! Generated a tool_id that does not exist!<br>"
+                     continue
+                 tool_des_json = relevant_APIs[str(tool_id)]
+                 retry_tool_id = 0
+             except Exception:
+                 retry_tool_id += 1
+                 yield "Tool ID wrong! Generated a tool_id that does not exist!<br>"
+                 continue
+
+             # tool_name_list = tool_des_json["tool_name"].split(":")
+             # category_name = tool_name_list[0]
+             # tool_name = tool_name_list[1]
+             api_name = tool_des_json["tool_name"]
+             API_doc = tool_des_json
+
+             while True:
+                 try:
+                     parameters = {}
+
+                     if retry_parameter == 4:
+                         # four failed attempts: give up on this tool and restart the step
+                         restart = 1
+                         retry_parameter = 0
+                         yield "No parameters! Restart!<br>"
+                         break
+
+                     parameter, task, agent_name, result = self.Executor_Agent.run(api_dic=API_doc, question=query, previous_log=answer_log, thought=thought, query_id=query_id, task="parameter")
+                     yield self.colorful_html(task, result, agent_name)
+
+                     if parameter == -1:
+                         retry_parameter += 1
+                         re_time += 1
+                         continue
+
+                     if parameter == {}:
+                         retry_parameter = 0
+                         parameters = {}
+                         break
+
+                     for key in parameter:
+                         value = parameter[key]
+                         key = change_name(key)
+                         parameters[key] = value
+
+                     retry_parameter = 0
+                     break
+
+                 except Exception:
+                     if retry_parameter == 4:
+                         parameters = {}
+                         retry_parameter = 0
+                         restart = 1
+                         break
+                     retry_parameter += 1
+                     yield "Parameter generation failed, trying again!<br>"
+                     re_time += 1
+                     continue
+
+             # api_name = change_name(standardize(api_name))
+
+             if restart != 1:
+                 try:
+                     observation = self.Call_function(api_name, parameters)
+                 except Exception:
+                     observation = -1
+
+                 if observation == -1:
+                     restart = 1
+                     observation = str({"error": "", "response": "call API fails"})
+
+             if restart == 1:
+                 # mark this tool as tried for the current step and pick another one
+                 tool_used_dic[step_num].append(str(tool_id))
+                 yield '****Try Again For This Step****<br>'
+                 re_time += 1
+                 restart = 0
+                 continue
+
+             if len(previous_log) != 0:
+                 previous_id = previous_log[-1]["id"]
+             else:
+                 previous_id = -1
+
+             current_log["tool"] = str(tool_id)
+             current_log["thought"] = thought
+             current_log["action"] = api_name
+             current_log["action_input"] = parameters
+             current_log["observation"] = observation
+             previous_log.append(current_log)
+             yield f"<details><summary>Tool Response</summary>{observation}<br></details>"
+             # print(f"{observation}\n")
+             observation_log = get_observation_log(previous_log)
+
+             answer, task, agent_name, result = self.Answer_Agent.run(question=subtask, call_result=observation_log, query_id=query_id, task="answer")
+             yield self.colorful_html(task, result, agent_name)
+             previous_log[-1]["answer"] = answer
+
+             history_log_ele = {"thought": thought, "action": api_name, "action_input": parameters, "observation": observation, "answer": answer, "previous_id": previous_id, "id": subtask_id}
+             history_log.append(history_log_ele)
+             subtask_id += 1
+
+             speak, status, task, agent_name, result = self.Verifier_Agent.run(question=subtask, answer=answer, query_id=query_id)
+             yield self.colorful_html(task, result, agent_name)
+
+             if speak == -1 and status == -1:
+                 step_num += 1
+                 continue
+
+             try:
+                 if int(status) == 0:
+                     # the verifier rejected the answer; keep its feedback as the next hint
+                     hint = speak
+                     step_num += 1
+                     continue
+             except Exception:
+                 step_num += 1
+                 continue
+             else:
+                 yield answer, previous_log, re_time, history_log
+                 yield "stop"
+
+     def run(self, query, query_id):
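+         """Plan the query into subtasks, solve each via inference(), then
+         compose the final answer. Yields the accumulated HTML after every
+         update so a streaming UI can re-render incrementally."""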
+         output = ""
+         count = 0
+         while True:
+             subtasks, task, agent_name, result = self.Planning_Agent.run(question=query, query_id=query_id)
+             if subtasks == -1:
+                 count += 1
+                 if count >= 1:
+                     yield "Task decomposition failed! Your OpenAI key cannot function correctly."
+                     raise RuntimeError
+                 continue
+
+             break
+         output += self.colorful_html(task, result, agent_name)
+         yield output
+         task_log = ""
+         history_log = []
+         previous_log_totals = []
+         re_time_total = 0
+         # print(subtasks)
+         relevant_API_list = {}
+         tool_id = 0
+         for api in self.available_tools:
+             tool_name = api["api_name"]
+             ele = {"ID": tool_id, "tool_name": tool_name, "description": api["api_description"], "required_parameters": api["required_parameters"], "optional_parameters": api["optional_parameters"]}
+             relevant_API_list[str(tool_id)] = ele
+             tool_id += 1
+
+         for subtask in subtasks:
+             sub_output = ""
+             output_ele = "<details><summary>subtask: {subtask}</summary>{sub_output}</details>"
+             task_log += f"question: {subtask}\n"
+             inference_generator = self.inference(task_log, relevant_API_list, subtask, query_id)
+             # old = None
+             while True:
+                 try:
+                     result = next(inference_generator)
+                     # if result == "stop":
+                     #     break
+                     if isinstance(result, str):
+                         # streamed HTML update: re-render this subtask's block
+                         sub_output += result
+                         sub_out = output_ele.format(subtask=subtask, sub_output=sub_output)
+                         output_ins = output + sub_out
+                         yield output_ins
+                     else:
+                         # a non-string result is the final tuple for this subtask
+                         break
+                 except StopIteration:
+                     break
+             output += sub_out
+             answer, previous_log, re_time, previous_log_total = result
+             # answer, previous_log, re_time, previous_log_total = self.inference(task_log, relevant_API_list, subtask, query_id)
+             previous_log_totals.append(previous_log_total)
+             # print(answer)
+             history_log += previous_log
+             re_time_total += re_time
+             task_log += f"answer: {answer}\n"
+         final_answer, task, agent_name, result = self.Answer_Agent.run(question=query, previous_log=task_log, task="final", query_id=query_id)
+         output += self.colorful_html(task, result, agent_name)
+         yield output
+         # return final_answer, history_log, task_log, re_time_total, previous_log_totals
+
+     # def run(self, input, query_id):
+     #     # result = {}
+     #     # st = time.time()
+     #     final_answer, previous_log, task_log, re_time, previous_log_totals = self.decompose_inference(input, query_id)
+     #     answer_details, total_steps = get_answer_details(final_answer, previous_log)
+     #     solution_tree, solution_total_steps = build_tree(previous_log_totals, task_log)
+     #     output_file_ele = {
+     #         "query": input,
+     #         "restart_time": re_time,
+     #         "answer": {
+     #             "method": "decompose_dfs",
+     #             "total_steps": total_steps,
+     #             "final_answer": final_answer,
+     #             "answer_details": answer_details
+     #         }
+     #     }
+     #
+     #     solution_file_ele = {
+     #         "query": input,
+     #         "total_steps": solution_total_steps,
+     #         "task_log": task_log,
+     #         "final_answer": final_answer,
+     #         "answer_path": answer_details,
+     #         "total_path": solution_tree
+     #     }
+     #     return final_answer, output_file_ele, solution_file_ele
+
+     def save_solution(self, output_file_ele, solution_file_ele, idx):
+         file_name = f"{idx}.json"
+         output_file = os.path.join(self.output_dir, file_name)
+         whole_solution_file = os.path.join(self.whole_solution_dir, file_name)
+         with open(output_file, "w") as file:
+             json.dump(output_file_ele, file, ensure_ascii=False, indent=4)
+         with open(whole_solution_file, "w") as file:
+             json.dump(solution_file_ele, file, ensure_ascii=False, indent=4)
+
+     def Call_function(self, tool_name, args):
+         try:
+             print(tool_name)
+             # print(self.BING_SUBSCRIPT_KEY)
+             # inject the API keys the individual tools expect
+             if tool_name == "bing_search" or tool_name == "BingSearch":
+                 args["key"] = self.BING_SUBSCRIPT_KEY
+             if tool_name == "forecast_weather" or tool_name == "get_weather_today":
+                 args["KEY"] = self.WEATHER_API_KEYS
+             if tool_name == "getWolframAlphaResults":
+                 args["APPID"] = self.WOLFRAMALPH_APP_ID
+
+             print(args)
+             func = self.tool_env[tool_name]
+             observation = func(**args)
+             return observation
+         except Exception as e:
+             print(e)
+             print("Call function fails")
+             with open('wrong_log.json', 'a+', encoding='utf-8') as f:
+                 line = json.dumps({
+                     "id": 0,
+                     "parameters": args,
+                     "tool": tool_name,
+                     "wrong": str(e)
+                 }, ensure_ascii=False)
+                 f.write(line + '\n')
+             return -1
+
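+     # Illustrative usage (hypothetical argument values, not from this repo):
+     #   worker.Call_function("bing_search", {"query": "weather in Paris"})
+     # looks up tool_env["bing_search"], injects BING_SUBSCRIPT_KEY as args["key"],
+     # and returns the tool's result, or -1 if the call raised an exception.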
Smurfs/inference/utils.py ADDED
@@ -0,0 +1,356 @@
+ # -*- coding: utf-8 -*-
+ import json
+ import re
+ import os
+ from tqdm import tqdm
+
+ def get_white_list(tool_root_dir):
+     # print(tool_root_dir)
+     white_list_dir = os.path.join(tool_root_dir)
+     white_list = {}
+     for cate in tqdm(os.listdir(white_list_dir)):
+         if not os.path.isdir(os.path.join(white_list_dir, cate)):
+             continue
+         for file in os.listdir(os.path.join(white_list_dir, cate)):
+             if not file.endswith(".json"):
+                 continue
+             standard_tool_name = file.split(".")[0]
+             # print(standard_tool_name)
+             with open(os.path.join(white_list_dir, cate, file)) as reader:
+                 js_data = json.load(reader)
+                 origin_tool_name = js_data["tool_name"]
+                 white_list[standardize(origin_tool_name)] = {"description": js_data["tool_description"], "standard_tool_name": standard_tool_name}
+     return white_list
+
+ def build_index(base_path):
+     index = {}
+     for root, dirs, files in os.walk(base_path):
+         for dir_name in dirs:
+             if dir_name not in index:
+                 index[dir_name] = []
+             index[dir_name].append(root)
+     return index
+
+
+ def change_name(name):
+     change_list = ["from", "class", "return", "false", "true", "id", "and", "", "ID"]
+     if name in change_list:
+         name = "is_" + name.lower()
+     return name
+
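+ # Example (grounded in the list above): change_name("from") -> "is_from",
+ # change_name("ID") -> "is_id"; any other name passes through unchanged.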
+
+ def standardize(string):
+     res = re.compile("[^\\u4e00-\\u9fa5^a-z^A-Z^0-9^_]")
+     string = res.sub("_", string)
+     string = re.sub(r"(_)\1+", "_", string).lower()
+     while True:
+         if len(string) == 0:
+             return string
+         if string[0] == "_":
+             string = string[1:]
+         else:
+             break
+     while True:
+         if len(string) == 0:
+             return string
+         if string[-1] == "_":
+             string = string[:-1]
+         else:
+             break
+     if string[0].isdigit():
+         string = "get_" + string
+     return string
+
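+ # Example (illustrative inputs): standardize("Bing Search!") -> "bing_search";
+ # runs of non-alphanumerics collapse to "_", leading/trailing "_" are trimmed,
+ # and a leading digit gets a "get_" prefix, e.g. "24h forecast" -> "get_24h_forecast".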
+ def get_answer_log(log):
+     if log == []:
+         return "Beginning of the agent. No log yet"
+     answer_logs = []
+     for ele in log:
+         answer_log = {"thought": "", "answer": ""}
+         answer_log["thought"] = ele["thought"]
+         answer_log["answer"] = ele["answer"]
+         answer_logs.append(answer_log)
+     return answer_logs
+
+ def get_observation_log(log):
+     if log == []:
+         return ""
+     answer_logs = []
+     for i, ele in enumerate(log):
+         if i == len(log) - 1:
+             # only the latest step keeps its raw observation
+             answer_log = {"thought": "", "observation": ""}
+             answer_log["thought"] = ele["thought"]
+             answer_log["observation"] = ele["observation"]
+             answer_logs.append(answer_log)
+         else:
+             answer_log = {"thought": "", "answer": ""}
+             answer_log["thought"] = ele["thought"]
+             answer_log["answer"] = ele["answer"]
+             answer_logs.append(answer_log)
+     return answer_logs
+
+ def build_tree(previous_log_totals, task_log):
+     total_root_list = []
+     total_total_steps = 0
+     task_log_list = task_log.split("question: ")[1:]
+     for i in range(len(task_log_list)):
+         task_log_list[i] = task_log_list[i].split("answer: ")
+
+     for j, previous_log_total in enumerate(previous_log_totals):
+         if previous_log_total is None:
+             answer_detail = {
+                 "role": "plan_global",
+                 "message": {
+                     "subtask": task_log_list[j][0],
+                     "subtask_answer": task_log_list[j][1]
+                 },
+                 "total_steps": 0,
+                 "next": []
+             }
+             total_root_list.append(answer_detail)
+             continue
+
+         next_list = []
+         root_list = []
+         total_steps = 0
+         for i in range(len(previous_log_total)):
+             current_log = previous_log_total[i]
+             tool_call_list = []
+             api_name = current_log["action"]
+             parameter = current_log["action_input"]
+             response = current_log["observation"]
+             next_ele = {
+                 "role": "tool",
+                 "message": {
+                     "name": api_name,
+                     "arguments": parameter,
+                     "response": response
+                 },
+                 "next": []
+             }
+             tool_call_list.append(next_ele)
+             total_steps += 1
+             if len(tool_call_list) > 1:
+                 for k in range(len(tool_call_list) - 2, -1, -1):
+                     tool_call_list[k]["next"].append(tool_call_list[k + 1])
+             next_list.append(tool_call_list[0])
+         total_total_steps += total_steps
+
+         # rewire each node under its parent; nodes with previous_id == -1 are roots
+         for i in range(len(next_list) - 1, -1, -1):
+             current_log = next_list[i]
+             current_log_pre_id = previous_log_total[i]["previous_id"]
+             if current_log_pre_id == -1:
+                 # print(current_log)
+                 root_list.append(current_log)
+             else:
+                 next_list[current_log_pre_id]["next"].append(current_log)
+         answer_detail = {
+             "role": "plan_global",
+             "message": {
+                 "subtask": task_log_list[j][0],
+                 "subtask_answer": task_log_list[j][1]
+             },
+             "total_steps": total_steps,
+             "next": root_list
+         }
+         total_root_list.append(answer_detail)
+
+     answer_details = {
+         "role": "system",
+         "message": "",
+         "next": [
+             {
+                 "role": "user",
+                 "message": "",
+                 "next": total_root_list
+             }
+         ]
+     }
+     return answer_details, total_total_steps
+
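+ # Shape sketch (illustrative) of what build_tree returns:
+ #   system -> user -> [plan_global(subtask) -> tool -> tool -> ...]
+ # where each plan_global node carries its subtask text, subtask answer, and step count.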
+ def get_answer_details(final_answer, previous_log):
+     next_list = []
+     total_steps = 0
+     for i in range(len(previous_log)):
+         current_log = previous_log[i]
+         if not isinstance(current_log, dict):
+             next_ele = {
+                 "role": "assistant",
+                 "message": current_log,
+                 "next": []
+             }
+             next_list.append(next_ele)
+             total_steps += 1
+             continue
+
+         api_name = current_log["action"]
+         parameter = current_log["action_input"]
+         response = current_log["observation"]
+         next_ele = {
+             "role": "tool",
+             "message": {
+                 "name": api_name,
+                 "arguments": parameter,
+                 "response": response
+             },
+             "next": []
+         }
+         next_list.append(next_ele)
+         total_steps += 1
+     answer_ele = {
+         "role": "tool",
+         "message": {
+             "name": "Finish",
+             "arguments": {
+                 "return_type": "give_answer",
+                 "final_answer": final_answer
+             },
+             "response": ""
+         },
+         "next": []
+     }
+     next_list.append(answer_ele)
+     # chain the flat list into a linked path ending at the Finish node
+     for i in range(len(next_list) - 2, -1, -1):
+         next_list[i]["next"].append(next_list[i + 1])
+     next_result = next_list[0]
+     answer_details = {
+         "role": "system",
+         "message": "",
+         "next": [
+             {
+                 "role": "user",
+                 "message": "",
+                 "next": [next_result]
+             }
+         ]
+     }
+     return answer_details, total_steps
+
+ def contain(candidate_list, white_list):
+     output = []
+     for cand in candidate_list:
+         if cand not in white_list.keys():
+             return False
+         output.append(white_list[cand])
+     return output
+
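+ # Example (illustrative): contain(["bing_search"], white_list) returns the matching
+ # white-list entries in order, or False as soon as any candidate is missing.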
+ # def fetch_api_json(api_list, tool_root_dir):
+ #     data_dict = {"api_list": []}
+ #     for item in api_list:
+ #         cate_name = item["category_name"]
+ #         tool_name = standardize(item["tool_name"])
+ #         api_name = change_name(standardize(item["api_name"]))
+ #         tool_json = json.load(open(os.path.join(tool_root_dir, cate_name, tool_name + ".json"), "r"))
+ #         append_flag = False
+ #         api_dict_names = []
+ #         for api_dict in tool_json["api_list"]:
+ #             api_dict_names.append(api_dict["name"])
+ #             pure_api_name = change_name(standardize(api_dict["name"]))
+ #             if pure_api_name != api_name:
+ #                 continue
+ #             api_json = {}
+ #             api_json["category_name"] = cate_name
+ #             api_json["api_name"] = api_dict["name"]
+ #             api_json["api_description"] = api_dict["description"]
+ #             api_json["required_parameters"] = api_dict["required_parameters"]
+ #             api_json["optional_parameters"] = api_dict["optional_parameters"]
+ #             api_json["tool_name"] = tool_json["tool_name"]
+ #             data_dict["api_list"].append(api_json)
+ #             append_flag = True
+ #             break
+ #         if not append_flag:
+ #             print(api_name, api_dict_names)
+ #     return data_dict
+
+ # def api_json_to_openai_json(api_json, standard_tool_name):
+ #     description_max_length = 256
+ #     templete = {
+ #         "name": "",
+ #         "description": "",
+ #         "parameters": {
+ #             "type": "object",
+ #             "properties": {
+ #             },
+ #             "required": [],
+ #             "optional": [],
+ #         }
+ #     }
+
+ #     map_type = {
+ #         "NUMBER": "integer",
+ #         "STRING": "string",
+ #         "BOOLEAN": "boolean"
+ #     }
+
+ #     pure_api_name = change_name(standardize(api_json["api_name"]))
+ #     templete["name"] = pure_api_name + f"_for_{standard_tool_name}"
+ #     templete["name"] = templete["name"][-64:]
+
+ #     templete["description"] = f"This is the subfunction for tool \"{standard_tool_name}\", you can use this tool."
+
+ #     if api_json["api_description"].strip() != "":
+ #         tuncated_description = api_json['api_description'].strip().replace(api_json['api_name'], templete['name'])[:description_max_length]
+ #         templete["description"] = templete["description"] + f"The description of this function is: \"{tuncated_description}\""
+ #     if "required_parameters" in api_json.keys() and len(api_json["required_parameters"]) > 0:
+ #         for para in api_json["required_parameters"]:
+ #             name = standardize(para["name"])
+ #             name = change_name(name)
+ #             if para["type"] in map_type:
+ #                 param_type = map_type[para["type"]]
+ #             else:
+ #                 param_type = "string"
+ #             prompt = {
+ #                 "type": param_type,
+ #                 "description": para["description"][:description_max_length],
+ #             }
+
+ #             default_value = para['default']
+ #             if len(str(default_value)) != 0:
+ #                 prompt = {
+ #                     "type": param_type,
+ #                     "description": para["description"][:description_max_length],
+ #                     "example_value": default_value
+ #                 }
+ #             else:
+ #                 prompt = {
+ #                     "type": param_type,
+ #                     "description": para["description"][:description_max_length]
+ #                 }
+
+ #             templete["parameters"]["properties"][name] = prompt
+ #             templete["parameters"]["required"].append(name)
+ #         for para in api_json["optional_parameters"]:
+ #             name = standardize(para["name"])
+ #             name = change_name(name)
+ #             if para["type"] in map_type:
+ #                 param_type = map_type[para["type"]]
+ #             else:
+ #                 param_type = "string"
+
+ #             default_value = para['default']
+ #             if len(str(default_value)) != 0:
+ #                 prompt = {
+ #                     "type": param_type,
+ #                     "description": para["description"][:description_max_length],
+ #                     "example_value": default_value
+ #                 }
+ #             else:
+ #                 prompt = {
+ #                     "type": param_type,
+ #                     "description": para["description"][:description_max_length]
+ #                 }
+
+ #             templete["parameters"]["properties"][name] = prompt
+ #             templete["parameters"]["optional"].append(name)
+
+ #     return templete, api_json["category_name"], pure_api_name
+
+
+ test_sets = ["G2_category"]
Smurfs/model/__init__.py ADDED
File without changes